use crate::error::Result;
use crate::optimizer::{ExampleSet, OptimizationResult, Optimizer, OptimizerConfig, Rng};
use smallvec::SmallVec;
/// How [`LabeledFewShot`] chooses which training examples become demonstrations.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SelectionStrategy {
    /// Take the first `k` examples in training-set order.
    #[default]
    First,
    /// Take the last `k` examples in training-set order.
    Last,
    /// Take `k` distinct examples chosen with an RNG seeded from the config.
    Random,
    /// Aim for one seeded-random example from each of `k` contiguous slices
    /// (strata) of the training set, spreading picks across the whole set.
    Stratified,
}
/// Parameters for [`LabeledFewShot`] demonstration selection.
#[derive(Clone, Copy)]
pub struct LabeledConfig {
    /// Base optimizer configuration (not read by `LabeledFewShot::select`).
    pub base: OptimizerConfig,
    /// Upper bound on the number of demonstrations; the count actually
    /// selected is `min(k, trainset.len())`.
    pub k: u8,
    /// Which examples to pick.
    pub strategy: SelectionStrategy,
    /// RNG seed used by the `Random` and `Stratified` strategies.
    pub seed: u64,
}
impl Default for LabeledConfig {
    /// Returns the same `k`, strategy and seed as [`LabeledConfig::new`], with
    /// `OptimizerConfig::default()` as the base.
    ///
    /// The non-base fields come from `new()` so the two constructors cannot
    /// drift apart. `base` deliberately stays `OptimizerConfig::default()`
    /// rather than `OptimizerConfig::new()`; whether those two differ is not
    /// visible from this file.
    fn default() -> Self {
        Self {
            base: OptimizerConfig::default(),
            ..Self::new()
        }
    }
}
impl LabeledConfig {
    /// Baseline configuration: `OptimizerConfig::new()` as the base, `k = 5`,
    /// [`SelectionStrategy::First`], seed `42`.
    pub const fn new() -> Self {
        Self {
            seed: 42,
            strategy: SelectionStrategy::First,
            k: 5,
            base: OptimizerConfig::new(),
        }
    }

    /// Returns a copy with the demonstration bound `k` replaced.
    pub const fn with_k(self, k: u8) -> Self {
        Self { k, ..self }
    }

    /// Returns a copy with the selection strategy replaced.
    pub const fn with_strategy(self, strategy: SelectionStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Returns a copy with the RNG seed replaced.
    pub const fn with_seed(self, seed: u64) -> Self {
        Self { seed, ..self }
    }
}
/// Few-shot optimizer that takes demonstrations directly from the labeled
/// training set, as directed by a [`LabeledConfig`].
#[derive(Clone, Copy)]
pub struct LabeledFewShot {
    /// Selection parameters; only exposed by shared reference via `config()`.
    config: LabeledConfig,
}
impl LabeledFewShot {
    /// Creates an optimizer that selects demonstrations according to `config`.
    pub const fn new(config: LabeledConfig) -> Self {
        Self { config }
    }

    /// Creates an optimizer from [`LabeledConfig::new`] (`k = 5`, `First`, seed 42).
    ///
    /// Unlike a `Default` trait impl, this is a `const fn`.
    pub const fn default() -> Self {
        Self::new(LabeledConfig::new())
    }

    /// Returns the configuration this optimizer was built with.
    pub const fn config(&self) -> &LabeledConfig {
        &self.config
    }

    /// Selects up to `k` demonstration indices from `trainset`.
    ///
    /// Exactly `min(k, trainset.len())` indices are returned. All of them are
    /// `< trainset.len()` and none repeats. `Random` and `Stratified` reseed
    /// a fresh [`Rng`] from `config.seed` on every call, so repeated calls
    /// give the same result.
    #[must_use]
    pub fn select<'a>(&self, trainset: &ExampleSet<'a>) -> SmallVec<[u32; 8]> {
        let len = trainset.len();
        // Indices are returned as u32; larger sets would be truncated by the casts below.
        debug_assert!(
            u32::try_from(len).is_ok(),
            "training set too large for u32 demo indices"
        );
        let n = usize::from(self.config.k).min(len);
        match self.config.strategy {
            SelectionStrategy::First => Self::select_first(n),
            SelectionStrategy::Last => Self::select_last(n, len),
            SelectionStrategy::Random => self.select_random(n, len),
            SelectionStrategy::Stratified => self.select_stratified(n, len),
        }
    }

    /// Returns indices `0..n`.
    fn select_first(n: usize) -> SmallVec<[u32; 8]> {
        (0..n as u32).collect()
    }

    /// Returns the trailing `n` indices, `len - n .. len`.
    fn select_last(n: usize, len: usize) -> SmallVec<[u32; 8]> {
        let start = len.saturating_sub(n);
        (start as u32..len as u32).collect()
    }

    /// Draws `n` distinct indices without replacement, one random pick from
    /// the remaining pool per step (same sequence as earlier versions for a
    /// given seed).
    fn select_random(&self, n: usize, len: usize) -> SmallVec<[u32; 8]> {
        let mut rng = Rng::new(self.config.seed);
        let mut pool: Vec<u32> = (0..len as u32).collect();
        let mut picked: SmallVec<[u32; 8]> = SmallVec::with_capacity(n);
        for _ in 0..n {
            if pool.is_empty() {
                break;
            }
            // gen_range(lo, hi) is half-open here: its result indexes `pool`.
            let slot = rng.gen_range(0, pool.len() as u64) as usize;
            picked.push(pool.swap_remove(slot));
        }
        picked
    }

    /// Picks one random index from each of `n` contiguous strata of `0..len`.
    ///
    /// Stratum `i` is `[i*len/n, (i+1)*len/n)` in integer arithmetic. Adjacent
    /// strata share a boundary, so the strata are disjoint and together cover
    /// every index. The result is therefore strictly increasing and free of
    /// duplicates, and every example can be chosen.
    fn select_stratified(&self, n: usize, len: usize) -> SmallVec<[u32; 8]> {
        if len <= n {
            return (0..len as u32).collect();
        }
        let mut rng = Rng::new(self.config.seed);
        (0..n)
            .map(|i| {
                let lo = i * len / n;
                let hi = (i + 1) * len / n;
                // hi - lo >= len / n >= 1 because n < len on this path.
                let offset = rng.gen_range(0, (hi - lo) as u64) as usize;
                (lo + offset) as u32
            })
            .collect()
    }
}
impl Optimizer for LabeledFewShot {
    type Output<'a> = OptimizationResult;
    /// Selection is synchronous, so the future is already resolved.
    type OptimizeFut<'a> = std::future::Ready<Result<OptimizationResult>>;

    /// Runs [`LabeledFewShot::select`] on `trainset` and wraps the chosen
    /// indices in a ready future. No evaluation happens, so `score` is `0.0`
    /// and `iterations` is `0`.
    fn optimize<'a>(&'a self, trainset: &'a ExampleSet<'a>) -> Self::OptimizeFut<'a> {
        let outcome = OptimizationResult {
            demo_indices: self.select(trainset),
            score: 0.0,
            iterations: 0,
        };
        std::future::ready(Ok(outcome))
    }

    /// Stable identifier for this optimizer.
    fn name(&self) -> &'static str {
        "LabeledFewShot"
    }
}
/// Fluent builder for [`LabeledFewShot`]; starts from [`LabeledConfig::default`].
pub struct LabeledFewShotBuilder {
    /// Configuration accumulated by the builder methods.
    config: LabeledConfig,
}
impl LabeledFewShotBuilder {
    /// Starts a builder from [`LabeledConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: LabeledConfig::default(),
        }
    }

    /// Sets the maximum number of demonstrations to select.
    #[must_use = "builder methods return the updated builder; the original is consumed"]
    pub fn k(mut self, k: u8) -> Self {
        self.config.k = k;
        self
    }

    /// Uses [`SelectionStrategy::First`]; the seed is left unchanged.
    #[must_use = "builder methods return the updated builder; the original is consumed"]
    pub fn first(mut self) -> Self {
        self.config.strategy = SelectionStrategy::First;
        self
    }

    /// Uses [`SelectionStrategy::Last`]; the seed is left unchanged.
    #[must_use = "builder methods return the updated builder; the original is consumed"]
    pub fn last(mut self) -> Self {
        self.config.strategy = SelectionStrategy::Last;
        self
    }

    /// Uses [`SelectionStrategy::Random`] with the given seed.
    #[must_use = "builder methods return the updated builder; the original is consumed"]
    pub fn random(mut self, seed: u64) -> Self {
        self.config.strategy = SelectionStrategy::Random;
        self.config.seed = seed;
        self
    }

    /// Uses [`SelectionStrategy::Stratified`] with the given seed.
    #[must_use = "builder methods return the updated builder; the original is consumed"]
    pub fn stratified(mut self, seed: u64) -> Self {
        self.config.strategy = SelectionStrategy::Stratified;
        self.config.seed = seed;
        self
    }

    /// Finishes the builder and returns the configured optimizer.
    #[must_use]
    pub fn build(self) -> LabeledFewShot {
        LabeledFewShot::new(self.config)
    }
}
impl Default for LabeledFewShotBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands to the five-example training set shared by the selection tests.
    ///
    /// A macro (not a helper fn) keeps the construction identical to writing it
    /// inline, so no test has to spell out `ExampleSet`'s lifetime.
    macro_rules! five_examples {
        () => {
            ExampleSet::from_buffer(
                crate::buffer::Buffer::Static(
                    b"input1\noutput1\ninput2\noutput2\ninput3\noutput3\ninput4\noutput4\ninput5\noutput5",
                ),
                5,
            )
        };
    }

    /// All strategies, for tests that must hold regardless of strategy.
    const ALL_STRATEGIES: [SelectionStrategy; 4] = [
        SelectionStrategy::First,
        SelectionStrategy::Last,
        SelectionStrategy::Random,
        SelectionStrategy::Stratified,
    ];

    #[test]
    fn test_labeled_creation() {
        let labeled = LabeledFewShot::default();
        assert_eq!(labeled.name(), "LabeledFewShot");
        assert_eq!(labeled.config().k, 5);
    }

    #[test]
    fn test_labeled_config() {
        let config = LabeledConfig::new()
            .with_k(10)
            .with_strategy(SelectionStrategy::Random)
            .with_seed(123);
        assert_eq!(config.k, 10);
        assert_eq!(config.seed, 123);
    }

    #[test]
    fn test_builder() {
        let labeled = LabeledFewShotBuilder::new().k(7).random(999).build();
        assert_eq!(labeled.config().k, 7);
        assert_eq!(labeled.config().seed, 999);
        assert!(matches!(
            labeled.config().strategy,
            SelectionStrategy::Random
        ));
    }

    #[test]
    fn test_first_strategy() {
        let labeled = LabeledFewShot::new(
            LabeledConfig::new()
                .with_k(3)
                .with_strategy(SelectionStrategy::First),
        );
        let trainset = five_examples!();
        let indices = labeled.select(&trainset);
        assert_eq!(indices.as_slice(), &[0, 1, 2]);
    }

    #[test]
    fn test_last_strategy() {
        let labeled = LabeledFewShot::new(
            LabeledConfig::new()
                .with_k(3)
                .with_strategy(SelectionStrategy::Last),
        );
        let trainset = five_examples!();
        let indices = labeled.select(&trainset);
        assert_eq!(indices.as_slice(), &[2, 3, 4]);
    }

    #[test]
    fn test_random_strategy() {
        let labeled = LabeledFewShot::new(
            LabeledConfig::new()
                .with_k(3)
                .with_strategy(SelectionStrategy::Random)
                .with_seed(42),
        );
        let trainset = five_examples!();
        let indices = labeled.select(&trainset);
        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&idx| idx < 5));
        let mut unique = indices.to_vec();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 3, "random picks must be distinct");
        assert_eq!(indices, labeled.select(&trainset));
    }

    #[test]
    fn test_stratified_strategy() {
        let labeled = LabeledFewShot::new(
            LabeledConfig::new()
                .with_k(2)
                .with_strategy(SelectionStrategy::Stratified)
                .with_seed(7),
        );
        let trainset = five_examples!();
        let indices = labeled.select(&trainset);
        assert_eq!(indices.len(), 2);
        // With 5 examples and k = 2 the strata are [0, 2) and [2, 5).
        assert!(indices[0] < 2, "first pick {} outside [0, 2)", indices[0]);
        assert!(
            (2..5).contains(&indices[1]),
            "second pick {} outside [2, 5)",
            indices[1]
        );
        assert_eq!(indices, labeled.select(&trainset));
    }

    #[test]
    fn test_k_larger_than_trainset_returns_every_index() {
        let trainset = five_examples!();
        for strategy in ALL_STRATEGIES {
            let labeled =
                LabeledFewShot::new(LabeledConfig::new().with_k(10).with_strategy(strategy));
            let mut indices = labeled.select(&trainset).to_vec();
            indices.sort_unstable();
            assert_eq!(indices, vec![0, 1, 2, 3, 4]);
        }
    }

    #[test]
    fn test_k_zero_selects_nothing() {
        let trainset = five_examples!();
        for strategy in ALL_STRATEGIES {
            let labeled =
                LabeledFewShot::new(LabeledConfig::new().with_k(0).with_strategy(strategy));
            assert!(labeled.select(&trainset).is_empty());
        }
    }
}