use crate::buffer::Buffer;
use crate::error::Result;
use crate::intern::Sym;
use crate::predict::FieldRange;
use smallvec::SmallVec;
use std::future::Future;
/// Borrowed collection of examples: a byte [`Buffer`] plus one
/// [`ExampleMeta`] record per example.
///
/// The set owns nothing; both references must outlive `'a`. Field text is
/// looked up by slicing `buffer` with the ranges stored in each record.
pub struct ExampleSet<'a> {
// Backing bytes that the `FieldRange`s in `examples` are applied to.
buffer: &'a Buffer,
// One metadata record per example, in example order.
examples: &'a [ExampleMeta],
}
/// Fixed-capacity record describing the fields of a single example.
///
/// Each slot pairs a field symbol with a [`FieldRange`]. Only the first
/// `input_count` / `output_count` slots are yielded by
/// [`ExampleMeta::inputs`] / [`ExampleMeta::outputs`]; the remaining slots are
/// padding.
#[derive(Clone, Copy, Debug)]
pub struct ExampleMeta {
/// Input field slots (capacity 4).
pub input_ranges: [(Sym, FieldRange); 4],
/// Number of leading entries of `input_ranges` that are in use.
pub input_count: u8,
/// Output field slots (capacity 2).
pub output_ranges: [(Sym, FieldRange); 2],
/// Number of leading entries of `output_ranges` that are in use.
pub output_count: u8,
}
impl ExampleMeta {
    /// Returns a record with no inputs and no outputs.
    ///
    /// Every slot is filled with `(Sym::EMPTY, FieldRange::new(0, 0))` and both
    /// counts are zero.
    pub const fn empty() -> Self {
        const BLANK: (Sym, FieldRange) = (Sym::EMPTY, FieldRange::new(0, 0));
        Self {
            input_ranges: [BLANK; 4],
            input_count: 0,
            output_ranges: [BLANK; 2],
            output_count: 0,
        }
    }

    /// Iterates over the populated input slots, in slot order.
    ///
    /// `input_count` is a public field, so it may exceed the slot capacity.
    /// The count is clamped to the array length, so an oversized value yields
    /// every slot instead of panicking.
    pub fn inputs(&self) -> impl Iterator<Item = (Sym, FieldRange)> + '_ {
        let used = usize::from(self.input_count).min(self.input_ranges.len());
        self.input_ranges[..used].iter().copied()
    }

    /// Iterates over the populated output slots, in slot order.
    ///
    /// `output_count` is clamped to the array length for the same reason as in
    /// [`ExampleMeta::inputs`].
    pub fn outputs(&self) -> impl Iterator<Item = (Sym, FieldRange)> + '_ {
        let used = usize::from(self.output_count).min(self.output_ranges.len());
        self.output_ranges[..used].iter().copied()
    }
}
impl<'a> ExampleSet<'a> {
    /// Creates a set that borrows `buffer` and `examples` for `'a`.
    pub const fn new(buffer: &'a Buffer, examples: &'a [ExampleMeta]) -> Self {
        Self { buffer, examples }
    }

    /// Builds a `'static` set from an owned buffer and `count` empty records.
    ///
    /// **This intentionally leaks memory.** `buffer` and a newly allocated
    /// slice of `count` [`ExampleMeta::empty`] records are both handed to
    /// [`Box::leak`], so neither allocation is ever freed. Each call leaks
    /// again; prefer [`ExampleSet::new`] with caller-owned storage unless a
    /// `'static` set is really needed.
    pub fn from_buffer(buffer: Buffer, count: usize) -> ExampleSet<'static> {
        let buffer: &'static Buffer = Box::leak(Box::new(buffer));
        let examples: &'static [ExampleMeta] =
            Box::leak(vec![ExampleMeta::empty(); count].into_boxed_slice());
        ExampleSet { buffer, examples }
    }

    /// Number of examples in the set.
    #[inline]
    pub const fn len(&self) -> usize {
        self.examples.len()
    }

    /// Returns `true` when the set holds no examples.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.examples.is_empty()
    }

    /// The backing buffer the field ranges are applied to.
    #[inline]
    pub const fn buffer(&self) -> &'a Buffer {
        self.buffer
    }

    /// Returns the metadata of example `idx`, or `None` if `idx` is out of
    /// bounds.
    ///
    /// The reference borrows from the underlying slice (`'a`), not from this
    /// `ExampleSet` value, so it may outlive the set itself.
    #[inline]
    pub fn get(&self, idx: usize) -> Option<&'a ExampleMeta> {
        self.examples.get(idx)
    }

    /// Iterates over all examples as [`ExampleView`]s, in order.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = ExampleView<'a>> + '_ {
        // Copy the buffer reference so the closure does not need to hold `self`.
        let buffer = self.buffer;
        self.examples
            .iter()
            .map(move |meta| ExampleView { buffer, meta })
    }

    /// Returns the text of the first input field of example `idx` whose
    /// symbol equals `sym`.
    ///
    /// Yields `None` when `idx` is out of bounds, no input carries `sym`, the
    /// stored range does not fit inside the buffer, or the bytes are not valid
    /// UTF-8.
    pub fn get_input(&self, idx: usize, sym: Sym) -> Option<&'a str> {
        let meta = self.examples.get(idx)?;
        let (_, range) = meta.inputs().find(|(s, _)| *s == sym)?;
        // Checked slicing: a bad range becomes `None` rather than a panic.
        let bytes = self.buffer.as_slice().get(range.as_range())?;
        std::str::from_utf8(bytes).ok()
    }

    /// Returns the text of the first output field of example `idx` whose
    /// symbol equals `sym`.
    ///
    /// Yields `None` under the same conditions as [`ExampleSet::get_input`].
    pub fn get_output(&self, idx: usize, sym: Sym) -> Option<&'a str> {
        let meta = self.examples.get(idx)?;
        let (_, range) = meta.outputs().find(|(s, _)| *s == sym)?;
        // Checked slicing: a bad range becomes `None` rather than a panic.
        let bytes = self.buffer.as_slice().get(range.as_range())?;
        std::str::from_utf8(bytes).ok()
    }
}
#[derive(Clone, Copy)]
pub struct ExampleView<'a> {
buffer: &'a Buffer,
meta: &'a ExampleMeta,
}
impl<'a> ExampleView<'a> {
pub fn input_text(&self) -> crate::str_view::StrView<'a> {
if let Some((_, fr)) = self.meta.inputs().next() {
let bytes = &self.buffer.as_slice()[fr.as_range()];
if let Ok(s) = std::str::from_utf8(bytes) {
return crate::str_view::StrView::new(s);
}
}
crate::str_view::StrView::new("")
}
pub fn get_input(&self, sym: Sym) -> Option<&'a str> {
for (s, fr) in self.meta.inputs() {
if s == sym {
let bytes = &self.buffer.as_slice()[fr.as_range()];
return std::str::from_utf8(bytes).ok();
}
}
None
}
pub fn get_output(&self, sym: Sym) -> Option<&'a str> {
for (s, fr) in self.meta.outputs() {
if s == sym {
let bytes = &self.buffer.as_slice()[fr.as_range()];
return std::str::from_utf8(bytes).ok();
}
}
None
}
pub fn inputs(&self) -> impl Iterator<Item = (Sym, &'a str)> + '_ {
self.meta.inputs().filter_map(|(sym, fr)| {
let bytes = &self.buffer.as_slice()[fr.as_range()];
std::str::from_utf8(bytes).ok().map(|s| (sym, s))
})
}
pub fn outputs(&self) -> impl Iterator<Item = (Sym, &'a str)> + '_ {
self.meta.outputs().filter_map(|(sym, fr)| {
let bytes = &self.buffer.as_slice()[fr.as_range()];
std::str::from_utf8(bytes).ok().map(|s| (sym, s))
})
}
}
/// Interface for optimizers that consume an [`ExampleSet`] and asynchronously
/// produce a result.
///
/// Implementors must be `Send + Sync`. The future type is an associated type,
/// so implementations name their own future rather than returning a boxed one.
pub trait Optimizer: Send + Sync {
/// Value produced by a successful optimization; may borrow for `'a`.
type Output<'a>
where
Self: 'a;
/// Future returned by [`Optimizer::optimize`]; it must be `Send` and resolves
/// to `Result<Self::Output<'a>>`.
type OptimizeFut<'a>: Future<Output = Result<Self::Output<'a>>> + Send + 'a
where
Self: 'a;
/// Starts an optimization over `trainset`.
///
/// The optimizer, the set reference and the set's borrowed data all share
/// the lifetime `'a`, which also bounds the returned future.
fn optimize<'a>(&'a self, trainset: &'a ExampleSet<'a>) -> Self::OptimizeFut<'a>;
/// Static name of this optimizer.
fn name(&self) -> &'static str;
/// Symbol for this optimizer, obtained by passing [`Optimizer::name`] to
/// `crate::intern::sym`.
fn id(&self) -> Sym {
crate::intern::sym(self.name())
}
}
/// Tunable parameters shared by optimizers.
///
/// This type only stores the values; how each one is interpreted is up to the
/// optimizer that reads it.
#[derive(Debug, Clone, Copy)]
pub struct OptimizerConfig {
    /// Iteration limit (default 10).
    pub max_iterations: u16,
    /// Batch size (default 50).
    pub batch_size: u16,
    /// Seed for random number generation (default 42).
    pub seed: u64,
    /// Metric threshold (default 0.5).
    pub metric_threshold: f32,
    /// Demo limit (default 4).
    pub max_demos: u8,
}

impl Default for OptimizerConfig {
    /// Same values as [`OptimizerConfig::new`].
    ///
    /// Delegates to `new` so the defaults are defined in exactly one place.
    fn default() -> Self {
        Self::new()
    }
}

impl OptimizerConfig {
    /// Creates a configuration with the default values. Usable in `const`
    /// contexts.
    pub const fn new() -> Self {
        Self {
            max_iterations: 10,
            batch_size: 50,
            seed: 42,
            metric_threshold: 0.5,
            max_demos: 4,
        }
    }

    /// Returns the configuration with `max_iterations` replaced by `n`.
    pub const fn with_max_iterations(mut self, n: u16) -> Self {
        self.max_iterations = n;
        self
    }

    /// Returns the configuration with `batch_size` replaced by `n`.
    pub const fn with_batch_size(mut self, n: u16) -> Self {
        self.batch_size = n;
        self
    }

    /// Returns the configuration with `seed` replaced by `seed`.
    pub const fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Returns the configuration with `metric_threshold` replaced by
    /// `threshold`.
    pub const fn with_metric_threshold(mut self, threshold: f32) -> Self {
        self.metric_threshold = threshold;
        self
    }

    /// Returns the configuration with `max_demos` replaced by `n`.
    pub const fn with_max_demos(mut self, n: u8) -> Self {
        self.max_demos = n;
        self
    }
}
/// Small, deterministic pseudo-random number generator (SplitMix64).
///
/// The same seed always produces the same sequence. Not suitable for
/// cryptographic use.
///
/// The previous implementation was a linear congruential generator with
/// 32-bit constants on a 64-bit state. For small seeds, including the default
/// seed 42, its first outputs had nearly all-zero high bits, so `next_f64`
/// started at about 0 and the first `next_usize` was always 0. Its lowest bit
/// also strictly alternated. SplitMix64 mixes every bit of the state for any
/// seed, including 0.
#[derive(Clone, Copy)]
pub struct Rng(u64);

impl Rng {
    /// Creates a generator from `seed`. Any value, including 0, is valid.
    pub const fn new(seed: u64) -> Self {
        Self(seed)
    }

    /// Returns the next 64 uniformly distributed bits and advances the state.
    #[inline]
    pub fn next_u64(&mut self) -> u64 {
        // SplitMix64: Weyl-sequence state update followed by a bit mixer.
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns a float in the half-open interval `[0, 1)`.
    #[inline]
    pub fn next_f64(&mut self) -> f64 {
        // Use the top 53 bits (the f64 mantissa width) so every value is
        // exactly representable and 1.0 can never be produced.
        const SCALE: f64 = 1.0 / (1u64 << 53) as f64;
        (self.next_u64() >> 11) as f64 * SCALE
    }

    /// Returns an index in `[0, max)`. Returns 0 when `max` is 0.
    #[inline]
    pub fn next_usize(&mut self, max: usize) -> usize {
        // Multiply-shift range reduction: floor(r * max / 2^64) is always
        // strictly below `max`, and is 0 when `max` is 0.
        let scaled = (u128::from(self.next_u64()) * max as u128) >> 64;
        scaled as usize
    }

    /// Shuffles `slice` in place with a Fisher-Yates shuffle.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        for i in (1..slice.len()).rev() {
            let j = self.next_usize(i + 1);
            slice.swap(i, j);
        }
    }

    /// Returns a value in `[min, max)`. Returns `min` when `max <= min`.
    #[inline]
    pub fn gen_range(&mut self, min: u64, max: u64) -> u64 {
        if max <= min {
            return min;
        }
        let span = max - min;
        // offset < span, so `min + offset` cannot overflow or reach `max`.
        let offset = (u128::from(self.next_u64()) * u128::from(span)) >> 64;
        min + offset as u64
    }

    /// Returns a float in the half-open interval `[0, 1)`.
    #[inline]
    pub fn gen_float(&mut self) -> f32 {
        // Use 24 bits (the f32 mantissa width); narrowing an f64 could round
        // up to exactly 1.0.
        const SCALE: f32 = 1.0 / (1u32 << 24) as f32;
        (self.next_u64() >> 40) as f32 * SCALE
    }
}
/// Outcome of an optimization run.
///
/// NOTE(review): nothing in this file constructs or reads this type, so the
/// field descriptions below are inferred from names and types — confirm
/// against the optimizer implementations.
#[derive(Clone, Debug)]
pub struct OptimizationResult {
/// Presumably indices of the examples chosen as demos; up to 8 are stored
/// inline before the `SmallVec` spills to the heap.
pub demo_indices: SmallVec<[u32; 8]>,
/// Presumably the metric score achieved by the run.
pub score: f64,
/// Presumably the number of iterations performed.
pub iterations: u16,
}
// Unit tests for the configuration builder, the RNG and `ExampleSet`.
#[cfg(test)]
mod tests {
use super::*;
// `Default` yields the documented iteration and batch defaults.
#[test]
fn test_config_default() {
let config = OptimizerConfig::default();
assert_eq!(config.max_iterations, 10);
assert_eq!(config.batch_size, 50);
}
// Chained `with_*` calls each override their own field.
#[test]
fn test_config_builder() {
let config = OptimizerConfig::new()
.with_max_iterations(20)
.with_batch_size(100)
.with_seed(123);
assert_eq!(config.max_iterations, 20);
assert_eq!(config.batch_size, 100);
assert_eq!(config.seed, 123);
}
// Two consecutive draws from one generator differ.
#[test]
fn test_rng() {
let mut rng = Rng::new(42);
let a = rng.next_u64();
let b = rng.next_u64();
assert_ne!(a, b);
}
// Two generators with the same seed produce the same sequence.
#[test]
fn test_rng_deterministic() {
let mut rng1 = Rng::new(42);
let mut rng2 = Rng::new(42);
for _ in 0..10 {
assert_eq!(rng1.next_u64(), rng2.next_u64());
}
}
// A set built over an empty metadata slice reports itself as empty.
#[test]
fn test_example_set_empty() {
static BUFFER: Buffer = Buffer::Static(b"");
let set = ExampleSet::new(&BUFFER, &[]);
assert!(set.is_empty());
assert_eq!(set.len(), 0);
}
}