use crate::error::Result;
use crate::optimizer::{ExampleSet, OptimizationResult, Optimizer, OptimizerConfig};
use crate::str_view::StrView;
use smallvec::SmallVec;
use std::future::Future;
use std::pin::Pin;
/// Produces fixed-dimension embedding vectors for text.
///
/// `Send + Sync` so one embedder can be shared across async tasks.
pub trait Embedder: Send + Sync {
    /// Embeds `text` into `output`, resolving to `Ok(())` on success.
    ///
    /// NOTE(review): callers in this file pre-size `output` to
    /// `dimension()` floats and the test embedder clears/overwrites it, so
    /// implementations presumably must leave `output` holding exactly
    /// `dimension()` values — confirm this contract before relying on it.
    fn embed<'a>(
        &'a self,
        text: StrView<'a>,
        output: &'a mut Vec<f32>,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
    /// Number of floats each embedding contains.
    fn dimension(&self) -> usize;
}
/// Configuration for [`KNNFewShot`] selection.
#[derive(Clone, Copy)]
pub struct KNNConfig {
    /// Shared optimizer settings.
    pub base: OptimizerConfig,
    /// Number of nearest neighbors to select (defaults to 3).
    pub k: u8,
    /// Neighbor-weighting flag (defaults to true).
    /// NOTE(review): not consulted anywhere in this file — verify it is
    /// read by code outside this chunk.
    pub weighted: bool,
}
impl Default for KNNConfig {
    /// Default selection settings: 3 neighbors, weighting enabled.
    fn default() -> Self {
        let base = OptimizerConfig::default();
        Self {
            base,
            k: 3,
            weighted: true,
        }
    }
}
impl KNNConfig {
    /// Const constructor with the default settings: `k = 3`, weighting on.
    pub const fn new() -> Self {
        Self {
            base: OptimizerConfig::new(),
            k: 3,
            weighted: true,
        }
    }
    /// Builder-style setter for the neighbor count.
    pub const fn with_k(self, k: u8) -> Self {
        Self {
            base: self.base,
            k,
            weighted: self.weighted,
        }
    }
    /// Builder-style setter for the weighting flag.
    pub const fn with_weighted(self, weighted: bool) -> Self {
        Self {
            base: self.base,
            k: self.k,
            weighted,
        }
    }
}
/// Flat, append-only store of embedding vectors with cached L2 norms,
/// supporting cosine-similarity nearest-neighbor lookup.
pub struct EmbeddingIndex {
    // All vectors concatenated row-major: vector `i` occupies
    // `embeddings[i * dim .. (i + 1) * dim]`.
    embeddings: Vec<f32>,
    // Dimensionality of every stored vector.
    dim: usize,
    // Number of vectors stored.
    count: usize,
    // Cached L2 norm of each vector, parallel to the rows.
    norms: Vec<f32>,
}
impl EmbeddingIndex {
    /// Creates an empty index for `dim`-dimensional vectors.
    pub fn new(dim: usize) -> Self {
        Self {
            embeddings: Vec::new(),
            dim,
            count: 0,
            norms: Vec::new(),
        }
    }
    /// Appends one embedding. `embedding.len()` must equal the index
    /// dimension (checked in debug builds only).
    pub fn add(&mut self, embedding: &[f32]) {
        debug_assert_eq!(embedding.len(), self.dim);
        self.embeddings.extend_from_slice(embedding);
        // Cache the L2 norm so a later cosine similarity needs only a dot product.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        self.norms.push(norm);
        self.count += 1;
    }
    /// Returns the `idx`-th stored vector.
    ///
    /// # Panics
    /// Panics if `idx >= len()`.
    #[inline]
    pub fn get(&self, idx: usize) -> &[f32] {
        let start = idx * self.dim;
        &self.embeddings[start..start + self.dim]
    }
    /// Number of stored vectors.
    #[inline]
    pub fn len(&self) -> usize {
        self.count
    }
    /// `true` when no vectors have been added.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
    /// Returns up to `k` `(index, cosine_similarity)` pairs, most similar
    /// first.
    ///
    /// A zero-norm query or `k == 0` yields an empty result; zero-norm
    /// stored vectors are assigned similarity `0.0` rather than dividing
    /// by zero.
    pub fn find_nearest(&self, query: &[f32], k: usize) -> SmallVec<[(u32, f32); 8]> {
        debug_assert_eq!(query.len(), self.dim);
        if k == 0 {
            return SmallVec::new();
        }
        let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
        if query_norm == 0.0 {
            return SmallVec::new();
        }
        let mut similarities: Vec<(u32, f32)> = (0..self.count)
            .map(|i| {
                let emb = self.get(i);
                let emb_norm = self.norms[i];
                if emb_norm == 0.0 {
                    return (i as u32, 0.0);
                }
                let dot: f32 = query.iter().zip(emb.iter()).map(|(a, b)| a * b).sum();
                (i as u32, dot / (query_norm * emb_norm))
            })
            .collect();
        // `total_cmp` is NaN-safe: the previous `partial_cmp(..).unwrap()`
        // panicked if an embedder ever produced a non-finite value.
        // Partition the k best to the front in O(n), then sort only those
        // k entries instead of the whole candidate list.
        if similarities.len() > k {
            similarities.select_nth_unstable_by(k - 1, |a, b| b.1.total_cmp(&a.1));
            similarities.truncate(k);
        }
        similarities.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        similarities.into_iter().collect()
    }
    /// SIMD variant placeholder — currently delegates to the scalar path.
    #[cfg(target_arch = "x86_64")]
    pub fn find_nearest_simd(&self, query: &[f32], k: usize) -> SmallVec<[(u32, f32); 8]> {
        self.find_nearest(query, k)
    }
}
/// Query-time demo selector: picks the training examples whose embeddings
/// are most similar to a query embedding.
///
/// `#[derive(Clone)]` adds an implicit `E: Clone` bound on the generated impl.
#[derive(Clone)]
pub struct KNNFewShot<E: Embedder> {
    // Selection parameters (neighbor count, weighting flag).
    config: KNNConfig,
    // Produces embeddings for both training examples and queries.
    embedder: E,
}
impl<E: Embedder> KNNFewShot<E> {
    /// Creates a selector from a configuration and an embedder.
    pub fn new(config: KNNConfig, embedder: E) -> Self {
        Self { config, embedder }
    }
    /// Borrows the configuration.
    pub fn config(&self) -> &KNNConfig {
        &self.config
    }
    /// Borrows the embedder.
    pub fn embedder(&self) -> &E {
        &self.embedder
    }
    /// Embeds every training example and collects the vectors into an
    /// [`EmbeddingIndex`]. `buffer` is scratch space reused across calls.
    pub async fn build_index<'a>(
        &self,
        trainset: &ExampleSet<'a>,
        buffer: &mut Vec<f32>,
    ) -> Result<EmbeddingIndex> {
        let dim = self.embedder.dimension();
        buffer.clear();
        buffer.resize(dim, 0.0);
        let mut index = EmbeddingIndex::new(dim);
        for example in trainset.iter() {
            self.embedder.embed(example.input_text(), buffer).await?;
            index.add(buffer);
        }
        Ok(index)
    }
    /// Embeds `query` and returns the indices of its `k` nearest training
    /// examples in `index`, most similar first.
    pub async fn select_demos<'a>(
        &self,
        query: StrView<'a>,
        index: &EmbeddingIndex,
        buffer: &mut Vec<f32>,
    ) -> Result<SmallVec<[u32; 8]>> {
        buffer.clear();
        buffer.resize(self.embedder.dimension(), 0.0);
        self.embedder.embed(query, buffer).await?;
        let hits = index.find_nearest(buffer, self.config.k as usize);
        Ok(hits.into_iter().map(|(idx, _)| idx).collect())
    }
}
/// A selector pairing a configuration with an already-built embedding
/// index, for callers that compute query embeddings themselves.
pub struct KNNSelector {
    // Supplies `k` for lookups.
    config: KNNConfig,
    // The pre-built index searched by `select`.
    index: EmbeddingIndex,
}
impl KNNSelector {
    /// Wraps a configuration and a pre-built index.
    pub fn new(config: KNNConfig, index: EmbeddingIndex) -> Self {
        Self { config, index }
    }
    /// Returns the indices of the `k` stored vectors most similar to
    /// `query_embedding`, most similar first.
    pub fn select(&self, query_embedding: &[f32]) -> SmallVec<[u32; 8]> {
        self.index
            .find_nearest(query_embedding, usize::from(self.config.k))
            .into_iter()
            .map(|(idx, _)| idx)
            .collect()
    }
    /// Borrows the underlying index.
    pub fn index(&self) -> &EmbeddingIndex {
        &self.index
    }
}
impl<E: Embedder> Optimizer for KNNFewShot<E> {
    type Output<'a>
        = OptimizationResult
    where
        E: 'a;
    type OptimizeFut<'a>
        = std::future::Ready<Result<OptimizationResult>>
    where
        E: 'a;
    /// Resolves immediately with the first `min(k, trainset.len())` example
    /// indices; the actual similarity-based selection for this optimizer
    /// happens per-query via `select_demos`, not here.
    fn optimize<'a>(&'a self, trainset: &'a ExampleSet<'a>) -> Self::OptimizeFut<'a> {
        let take = trainset.len().min(self.config.k as usize);
        let demo_indices: SmallVec<[u32; 8]> = (0..take).map(|i| i as u32).collect();
        std::future::ready(Ok(OptimizationResult {
            demo_indices,
            score: 0.0,
            iterations: 0,
        }))
    }
    fn name(&self) -> &'static str {
        "KNNFewShot"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Deterministic embedder: folds byte values round-robin into a
    /// fixed-size vector and L2-normalizes it.
    struct MockEmbedder {
        dim: usize,
    }
    impl Embedder for MockEmbedder {
        fn embed<'a>(
            &'a self,
            text: StrView<'a>,
            output: &'a mut Vec<f32>,
        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
            Box::pin(async move {
                output.clear();
                output.resize(self.dim, 0.0);
                let bytes = text.as_str().as_bytes();
                for (i, &b) in bytes.iter().enumerate() {
                    output[i % self.dim] += (b as f32) / 255.0;
                }
                // Normalize so cosine similarity reduces to a dot product.
                let norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for x in output.iter_mut() {
                        *x /= norm;
                    }
                }
                Ok(())
            })
        }
        fn dimension(&self) -> usize {
            self.dim
        }
    }
    #[test]
    fn test_knn_creation() {
        let embedder = MockEmbedder { dim: 64 };
        let knn = KNNFewShot::new(KNNConfig::default(), embedder);
        assert_eq!(knn.name(), "KNNFewShot");
        assert_eq!(knn.config().k, 3);
    }
    #[test]
    fn test_knn_config() {
        let config = KNNConfig::new().with_k(5).with_weighted(false);
        assert_eq!(config.k, 5);
        assert!(!config.weighted);
    }
    #[test]
    fn test_embedding_index() {
        let mut index = EmbeddingIndex::new(4);
        index.add(&[1.0, 0.0, 0.0, 0.0]);
        index.add(&[0.0, 1.0, 0.0, 0.0]);
        index.add(&[0.7, 0.7, 0.0, 0.0]);
        assert_eq!(index.len(), 3);
        assert!(!index.is_empty());
        let neighbors = index.find_nearest(&[0.9, 0.1, 0.0, 0.0], 2);
        assert_eq!(neighbors.len(), 2);
        // The axis-aligned x vector is closest to the x-heavy query.
        assert_eq!(neighbors[0].0, 0);
    }
    #[test]
    fn test_cosine_similarity() {
        let mut index = EmbeddingIndex::new(3);
        index.add(&[1.0, 0.0, 0.0]);
        index.add(&[0.0, 1.0, 0.0]);
        index.add(&[0.0, 0.0, 1.0]);
        let neighbors = index.find_nearest(&[1.0, 0.0, 0.0], 3);
        // Identical direction -> similarity ~1; orthogonal -> ~0.
        assert!(neighbors[0].1 > 0.99);
        assert!(neighbors[1].1.abs() < 0.01);
    }
    #[test]
    fn test_zero_norm_query_returns_empty() {
        let mut index = EmbeddingIndex::new(2);
        index.add(&[1.0, 0.0]);
        let neighbors = index.find_nearest(&[0.0, 0.0], 3);
        assert!(neighbors.is_empty());
    }
    #[test]
    fn test_k_larger_than_index() {
        let mut index = EmbeddingIndex::new(2);
        index.add(&[1.0, 0.0]);
        let neighbors = index.find_nearest(&[1.0, 0.0], 10);
        // Results are capped at the number of stored vectors.
        assert_eq!(neighbors.len(), 1);
    }
}