use crate::error::Result;
use crate::thinktool::modules::{
ThinkToolContext, ThinkToolModule, ThinkToolModuleConfig, ThinkToolOutput,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A worked reasoning example: a query, the reasoning that solves it, the final
/// answer, and an embedding vector used for similarity lookups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Example {
/// Identifier. `Example::new` sets it to `ex_` followed by a hex hash of the
/// query; `with_id` replaces it.
pub id: String,
/// The question text; substituted for `{query}` in templates.
pub query: String,
/// The reasoning text; substituted for `{reasoning}` in templates.
pub reasoning: String,
/// The answer text; substituted for `{answer}` in templates.
pub answer: String,
/// Embedding vector. `ExampleDatabase::add_example` rejects the example when
/// this vector's length differs from the database dimension.
pub embedding: Vec<f32>,
/// Optional category; rendered as `general` by `to_prompt_with_template` when absent.
pub category: Option<String>,
/// Difficulty score. `new` sets 0.5 and `with_difficulty` clamps to `[0.0, 1.0]`;
/// the field is public, so direct writes are not clamped.
pub difficulty: f32,
/// Quality score. `new` sets 1.0 and `with_quality` clamps to `[0.0, 1.0]`.
pub quality_score: f32,
/// Free-form key/value annotations added through `with_metadata`.
pub metadata: HashMap<String, String>,
}
impl Example {
    /// Creates an example from its query, reasoning, answer and embedding.
    ///
    /// The id is `ex_` followed by the hex hash of the query. The category is
    /// unset, difficulty is 0.5, quality is 1.0 and the metadata map is empty.
    pub fn new(
        query: impl Into<String>,
        reasoning: impl Into<String>,
        answer: impl Into<String>,
        embedding: Vec<f32>,
    ) -> Self {
        let query = query.into();
        let id = format!("ex_{:x}", hash_string(&query));
        Self {
            id,
            query,
            reasoning: reasoning.into(),
            answer: answer.into(),
            embedding,
            category: None,
            difficulty: 0.5,
            quality_score: 1.0,
            metadata: HashMap::new(),
        }
    }

    /// Replaces the generated id.
    pub fn with_id(mut self, id: impl Into<String>) -> Self {
        self.id = id.into();
        self
    }

    /// Sets the category.
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
        self.category = Some(category.into());
        self
    }

    /// Sets the difficulty, clamped to `[0.0, 1.0]`.
    pub fn with_difficulty(mut self, difficulty: f32) -> Self {
        self.difficulty = difficulty.clamp(0.0, 1.0);
        self
    }

    /// Sets the quality score, clamped to `[0.0, 1.0]`.
    pub fn with_quality(mut self, quality: f32) -> Self {
        self.quality_score = quality.clamp(0.0, 1.0);
        self
    }

    /// Inserts one metadata entry, overwriting any existing value for `key`.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Renders the example in the fixed `Q: / Thinking: / A:` layout.
    pub fn to_prompt_text(&self) -> String {
        format!(
            "Q: {}\n\nThinking: {}\n\nA: {}",
            self.query, self.reasoning, self.answer
        )
    }

    /// Renders the example through `template`.
    ///
    /// Recognised placeholders are `{query}`, `{reasoning}`, `{answer}` and
    /// `{category}` (rendered as `general` when no category is set). Any other
    /// text, including unknown `{...}` sequences, is copied through unchanged.
    ///
    /// Substitution is a single left-to-right pass over `template`, so text
    /// that merely looks like a placeholder inside a substituted value (for
    /// example a query containing `{answer}`) is emitted verbatim instead of
    /// being expanded again.
    pub fn to_prompt_with_template(&self, template: &str) -> String {
        let substitutions = [
            ("{query}", self.query.as_str()),
            ("{reasoning}", self.reasoning.as_str()),
            ("{answer}", self.answer.as_str()),
            ("{category}", self.category.as_deref().unwrap_or("general")),
        ];

        let mut rendered = String::with_capacity(template.len());
        let mut rest = template;
        while let Some(brace) = rest.find('{') {
            rendered.push_str(&rest[..brace]);
            let tail = &rest[brace..];
            let hit = substitutions
                .iter()
                .find_map(|&(token, value)| tail.strip_prefix(token).map(|after| (value, after)));
            match hit {
                Some((value, after)) => {
                    rendered.push_str(value);
                    rest = after;
                }
                None => {
                    // Not a known placeholder: keep the brace and move past it.
                    // '{' is one byte in UTF-8, so slicing at 1 is always valid.
                    rendered.push('{');
                    rest = &tail[1..];
                }
            }
        }
        rendered.push_str(rest);
        rendered
    }
}
/// Hashes `s` with the standard library's `DefaultHasher` and returns the
/// resulting 64-bit value.
fn hash_string(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    s.hash(&mut state);
    state.finish()
}
/// In-memory collection of `Example`s with a per-category index, searched by
/// cosine similarity of embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExampleDatabase {
/// Stored examples in insertion order; `category_index` refers to positions here.
examples: Vec<Example>,
/// Required embedding length for stored examples and for query embeddings.
dimension: usize,
/// Maps a category name to the indices (into `examples`) of its members.
category_index: HashMap<String, Vec<usize>>,
/// Similarity searches skip examples whose `quality_score` is below this value.
/// `new` sets 0.0.
min_quality: f32,
}
impl ExampleDatabase {
    /// Creates an empty database whose embeddings must all have `dimension` components.
    pub fn new(dimension: usize) -> Self {
        Self {
            examples: Vec::new(),
            dimension,
            category_index: HashMap::new(),
            min_quality: 0.0,
        }
    }

    /// Sets the minimum quality score (clamped to `[0.0, 1.0]`) an example needs
    /// to be returned by the similarity searches.
    pub fn with_min_quality(mut self, min_quality: f32) -> Self {
        self.min_quality = min_quality.clamp(0.0, 1.0);
        self
    }

    /// Appends `example` and indexes it under its category, if it has one.
    ///
    /// # Errors
    ///
    /// Returns a validation error when the embedding length differs from the
    /// database dimension; the database is left unchanged in that case.
    pub fn add_example(&mut self, example: Example) -> Result<()> {
        if example.embedding.len() != self.dimension {
            return Err(crate::error::Error::validation(format!(
                "Embedding dimension mismatch: expected {}, got {}",
                self.dimension,
                example.embedding.len()
            )));
        }
        let idx = self.examples.len();
        if let Some(ref cat) = example.category {
            self.category_index
                .entry(cat.clone())
                .or_default()
                .push(idx);
        }
        self.examples.push(example);
        Ok(())
    }

    /// Adds the examples in order, stopping at the first one that is rejected.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `add_example`. Examples that preceded
    /// the failing one remain in the database.
    pub fn add_examples(&mut self, examples: Vec<Example>) -> Result<()> {
        for example in examples {
            self.add_example(example)?;
        }
        Ok(())
    }

    /// Returns the first stored example whose id equals `id` (linear scan).
    pub fn get_by_id(&self, id: &str) -> Option<&Example> {
        self.examples.iter().find(|e| e.id == id)
    }

    /// Returns the examples indexed under `category`, in insertion order.
    pub fn get_by_category(&self, category: &str) -> Vec<&Example> {
        self.category_index
            .get(category)
            .map(|indices| indices.iter().map(|&i| &self.examples[i]).collect())
            .unwrap_or_default()
    }

    /// Number of stored examples.
    pub fn len(&self) -> usize {
        self.examples.len()
    }

    /// Whether the database holds no examples.
    pub fn is_empty(&self) -> bool {
        self.examples.is_empty()
    }

    /// The embedding length this database accepts.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Names of all categories that have at least one indexed example
    /// (order follows `HashMap` iteration and is unspecified).
    pub fn categories(&self) -> Vec<&str> {
        self.category_index.keys().map(|s| s.as_str()).collect()
    }

    /// Returns up to `k` examples ordered by descending cosine similarity to
    /// `query_embedding`.
    ///
    /// Only examples meeting the minimum quality are considered, optionally
    /// restricted to `category_filter`. Returns an empty vector when the query
    /// length differs from the database dimension.
    pub fn find_similar(
        &self,
        query_embedding: &[f32],
        k: usize,
        category_filter: Option<&str>,
    ) -> Vec<SimilarExample> {
        if query_embedding.len() != self.dimension {
            return Vec::new();
        }
        // Rank lightweight (index, score) pairs and clone only the survivors,
        // rather than deep-copying every candidate (embedding included).
        let mut ranked = self.scored_candidates(query_embedding, category_filter);
        ranked.sort_by(|a, b| Self::cmp_similarity_desc(a.1, b.1));
        ranked.truncate(k);
        ranked
            .into_iter()
            .map(|(idx, similarity)| SimilarExample {
                example: self.examples[idx].clone(),
                similarity,
            })
            .collect()
    }

    /// Returns up to `k` examples chosen greedily by a maximal-marginal-relevance
    /// style score: `(1 - w) * sim(query, e) - w * max sim(e, already selected)`,
    /// where `w` is `diversity_weight` clamped to `[0.0, 1.0]`.
    ///
    /// The most similar candidate is always picked first. Returns an empty
    /// vector when `k` is 0, the query length differs from the database
    /// dimension, or no candidate passes the quality/category filters. The
    /// result is in selection order, which is not necessarily sorted by
    /// similarity.
    pub fn find_similar_diverse(
        &self,
        query_embedding: &[f32],
        k: usize,
        diversity_weight: f32,
        category_filter: Option<&str>,
    ) -> Vec<SimilarExample> {
        if query_embedding.len() != self.dimension || k == 0 {
            return Vec::new();
        }
        let diversity_weight = diversity_weight.clamp(0.0, 1.0);

        let mut ranked = self.scored_candidates(query_embedding, category_filter);
        if ranked.is_empty() {
            return Vec::new();
        }
        ranked.sort_by(|a, b| Self::cmp_similarity_desc(a.1, b.1));

        // Each entry: (example index, similarity to the query, highest
        // similarity to any example selected so far).
        let mut remaining: Vec<(usize, f32, f32)> = ranked
            .into_iter()
            .map(|(idx, sim)| (idx, sim, f32::NEG_INFINITY))
            .collect();

        // Cap the reservation at the candidate count so a very large `k`
        // cannot trigger a capacity-overflow panic.
        let mut selected: Vec<(usize, f32)> = Vec::with_capacity(k.min(remaining.len()));

        let (first_idx, first_sim, _) = remaining.remove(0);
        selected.push((first_idx, first_sim));
        let mut last_selected = first_idx;

        while selected.len() < k && !remaining.is_empty() {
            let last_embedding = &self.examples[last_selected].embedding;
            let mut best_score = f32::NEG_INFINITY;
            let mut best_pos = 0;
            for (pos, (idx, sim, max_sim_selected)) in remaining.iter_mut().enumerate() {
                // Fold in only the newest selection; earlier selections are
                // already reflected in the running maximum.
                let sim_to_last =
                    cosine_similarity(&self.examples[*idx].embedding, last_embedding);
                *max_sim_selected = (*max_sim_selected).max(sim_to_last);
                let mmr_score =
                    (1.0 - diversity_weight) * *sim - diversity_weight * *max_sim_selected;
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best_pos = pos;
                }
            }
            let (idx, sim, _) = remaining.remove(best_pos);
            selected.push((idx, sim));
            last_selected = idx;
        }

        selected
            .into_iter()
            .map(|(idx, similarity)| SimilarExample {
                example: self.examples[idx].clone(),
                similarity,
            })
            .collect()
    }

    /// Collects `(index, similarity to query)` for every example that passes
    /// the category filter and the minimum-quality threshold, in index order.
    fn scored_candidates(
        &self,
        query_embedding: &[f32],
        category_filter: Option<&str>,
    ) -> Vec<(usize, f32)> {
        let score = |idx: usize| {
            let example = &self.examples[idx];
            (example.quality_score >= self.min_quality)
                .then(|| (idx, cosine_similarity(query_embedding, &example.embedding)))
        };
        match category_filter {
            Some(cat) => match self.category_index.get(cat) {
                Some(indices) => indices.iter().copied().filter_map(score).collect(),
                None => Vec::new(),
            },
            None => (0..self.examples.len()).filter_map(score).collect(),
        }
    }

    /// Orders similarity scores from highest to lowest, placing NaN last.
    ///
    /// Treating NaN as equal to everything is not a total order, which
    /// `sort_by` is allowed to reject; ranking NaN last keeps the comparator
    /// consistent while leaving the order of ordinary scores unchanged.
    fn cmp_similarity_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}
/// An example paired with its similarity score for a particular query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarExample {
/// Owned copy of the matched example.
pub example: Example,
/// Cosine similarity between the query embedding and the example's embedding.
pub similarity: f32,
}
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns 0.0 when the slices are empty, differ in length, or when the
/// product of their magnitudes is not greater than zero (for example when
/// either is an all-zero vector).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared magnitudes in one pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude > 0.0 {
        dot / magnitude
    } else {
        0.0
    }
}
/// Scales `v` in place to unit Euclidean length.
///
/// The slice is left untouched when its length is not greater than zero
/// (empty or all-zero input).
pub fn normalize_vector(v: &mut [f32]) {
    let squared_len: f32 = v.iter().map(|c| c * c).sum();
    let length = squared_len.sqrt();
    if length.is_nan() || length <= 0.0 {
        return;
    }
    v.iter_mut().for_each(|c| *c /= length);
}
/// Settings controlling how `ESCoT` selects examples and assembles the prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ESCoTConfig {
/// Maximum number of examples requested from the database.
pub k: usize,
/// `ESCoT::select_examples` drops results whose similarity is below this value.
pub min_similarity: f32,
/// When greater than 0.0, selection uses `find_similar_diverse` with this
/// weight; otherwise plain `find_similar` is used.
pub diversity_weight: f32,
/// Restricts selection to one category when set.
pub category_filter: Option<String>,
/// Template for each example; passed to `Example::to_prompt_with_template`.
pub example_template: String,
/// Template for the whole prompt; `ESCoT::build_prompt` fills `{examples}` and `{query}`.
pub prompt_template: String,
/// NOTE(review): not read by any code visible in this file — confirm where it is consumed.
pub include_similarity_scores: bool,
/// NOTE(review): not read by any code visible in this file — confirm where it is consumed.
pub max_example_tokens: usize,
}
impl Default for ESCoTConfig {
    /// Defaults: three examples, no similarity floor, diversity weight 0.3,
    /// no category filter, the built-in step-by-step templates, similarity
    /// scores flagged for inclusion, and a `max_example_tokens` of 0.
    fn default() -> Self {
        const EXAMPLE_TEMPLATE: &str =
            "Question: {query}\n\nLet me think step by step:\n{reasoning}\n\nAnswer: {answer}";
        const PROMPT_TEMPLATE: &str = "Here are some examples of step-by-step reasoning:\n\n{examples}\n\n---\n\nNow solve this problem:\n\nQuestion: {query}\n\nLet me think step by step:";

        Self {
            k: 3,
            min_similarity: 0.0,
            diversity_weight: 0.3,
            category_filter: None,
            example_template: EXAMPLE_TEMPLATE.to_string(),
            prompt_template: PROMPT_TEMPLATE.to_string(),
            include_similarity_scores: true,
            max_example_tokens: 0,
        }
    }
}
impl ESCoTConfig {
    /// Returns the default configuration with the example count set to `k`.
    pub fn with_k(k: usize) -> Self {
        let defaults = Self::default();
        Self { k, ..defaults }
    }
}
/// Example-selection chain-of-thought module: picks stored examples similar to
/// a query embedding and assembles them into a few-shot prompt.
#[derive(Debug, Clone)]
pub struct ESCoT {
/// Module configuration returned by `ThinkToolModule::config`.
config: ThinkToolModuleConfig,
/// Selection and prompt-formatting settings.
escot_config: ESCoTConfig,
/// The examples that selection draws from.
database: ExampleDatabase,
}
impl ESCoT {
    /// Creates a module over `database` with the default `ESCoTConfig`.
    pub fn new(database: ExampleDatabase) -> Self {
        Self {
            config: ThinkToolModuleConfig::new(
                "ESCoT",
                "1.0.0",
                "Example Selection Chain-of-Thought - embedding-based few-shot example selection",
            ),
            escot_config: ESCoTConfig::default(),
            database,
        }
    }

    /// Starts a builder for configuring a module step by step.
    pub fn builder() -> ESCoTBuilder {
        ESCoTBuilder::new()
    }

    /// The selection and prompt settings in use.
    pub fn escot_config(&self) -> &ESCoTConfig {
        &self.escot_config
    }

    /// The example database in use.
    pub fn database(&self) -> &ExampleDatabase {
        &self.database
    }

    /// Selects up to `k` examples for `query_embedding`.
    ///
    /// Uses diversity-aware selection when the configured diversity weight is
    /// greater than 0.0 and plain similarity ranking otherwise, then drops
    /// results below the configured minimum similarity.
    pub fn select_examples(&self, query_embedding: &[f32]) -> Vec<SimilarExample> {
        let cfg = &self.escot_config;
        let category = cfg.category_filter.as_deref();
        let mut selected = if cfg.diversity_weight > 0.0 {
            self.database
                .find_similar_diverse(query_embedding, cfg.k, cfg.diversity_weight, category)
        } else {
            self.database.find_similar(query_embedding, cfg.k, category)
        };
        selected.retain(|se| se.similarity >= cfg.min_similarity);
        selected
    }

    /// Builds the final prompt from the configured templates.
    ///
    /// Each example is rendered with the example template and the results are
    /// joined with a `---` separator. The prompt template's `{examples}` and
    /// `{query}` placeholders are then filled in a single left-to-right pass,
    /// so placeholder-looking text inside the rendered examples or inside
    /// `query` is emitted verbatim rather than substituted again.
    pub fn build_prompt(&self, query: &str, examples: &[SimilarExample]) -> String {
        let examples_text = examples
            .iter()
            .map(|se| {
                se.example
                    .to_prompt_with_template(&self.escot_config.example_template)
            })
            .collect::<Vec<String>>()
            .join("\n\n---\n\n");

        let template = self.escot_config.prompt_template.as_str();
        let mut prompt =
            String::with_capacity(template.len() + examples_text.len() + query.len());
        let mut rest = template;
        while let Some(brace) = rest.find('{') {
            prompt.push_str(&rest[..brace]);
            let tail = &rest[brace..];
            if let Some(after) = tail.strip_prefix("{examples}") {
                prompt.push_str(&examples_text);
                rest = after;
            } else if let Some(after) = tail.strip_prefix("{query}") {
                prompt.push_str(query);
                rest = after;
            } else {
                // Unknown brace sequence: copy the brace and continue after it.
                // '{' is one byte in UTF-8, so slicing at 1 is always valid.
                prompt.push('{');
                rest = &tail[1..];
            }
        }
        prompt.push_str(rest);
        prompt
    }

    /// Selects examples for `query_embedding`, builds the prompt for `query`,
    /// and derives a confidence value.
    ///
    /// Confidence is 0.5 when no example was selected and
    /// `0.5 + 0.4 * average similarity` otherwise. This implementation always
    /// returns `Ok`.
    pub fn execute_with_embedding(
        &self,
        query: &str,
        query_embedding: &[f32],
    ) -> Result<ESCoTResult> {
        let selected = self.select_examples(query_embedding);
        let prompt = self.build_prompt(query, &selected);
        let confidence = if selected.is_empty() {
            0.5
        } else {
            let avg_sim: f32 =
                selected.iter().map(|se| se.similarity).sum::<f32>() / selected.len() as f32;
            0.5 + (avg_sim * 0.4)
        };
        Ok(ESCoTResult {
            prompt,
            selected_examples: selected,
            confidence,
            k: self.escot_config.k,
        })
    }
}
/// Step-by-step builder for `ESCoT`.
pub struct ESCoTBuilder {
/// Selection and prompt settings accumulated by the `with_*` methods.
escot_config: ESCoTConfig,
/// Database to use; `build` falls back to an empty 1536-dimension database when unset.
database: Option<ExampleDatabase>,
/// Module configuration override. No builder method visible in this file sets
/// it, so `build` uses its built-in "ESCoT" configuration.
module_config: Option<ThinkToolModuleConfig>,
}
impl ESCoTBuilder {
    /// Creates a builder with the default `ESCoTConfig` and no database.
    pub fn new() -> Self {
        let escot_config = ESCoTConfig::default();
        Self {
            escot_config,
            database: None,
            module_config: None,
        }
    }

    /// Applies `edit` to the pending configuration and returns the builder.
    fn configure(mut self, edit: impl FnOnce(&mut ESCoTConfig)) -> Self {
        edit(&mut self.escot_config);
        self
    }

    /// Uses `database` as the example source.
    pub fn with_database(self, database: ExampleDatabase) -> Self {
        Self {
            database: Some(database),
            ..self
        }
    }

    /// Sets how many examples to request.
    pub fn with_k(self, k: usize) -> Self {
        self.configure(|cfg| cfg.k = k)
    }

    /// Sets the minimum similarity, clamped to `[0.0, 1.0]`.
    pub fn with_min_similarity(self, threshold: f32) -> Self {
        let bounded = threshold.clamp(0.0, 1.0);
        self.configure(|cfg| cfg.min_similarity = bounded)
    }

    /// Sets the diversity weight, clamped to `[0.0, 1.0]`.
    pub fn with_diversity(self, weight: f32) -> Self {
        let bounded = weight.clamp(0.0, 1.0);
        self.configure(|cfg| cfg.diversity_weight = bounded)
    }

    /// Restricts selection to `category`.
    pub fn with_category(self, category: impl Into<String>) -> Self {
        let category = category.into();
        self.configure(|cfg| cfg.category_filter = Some(category))
    }

    /// Replaces the per-example template.
    pub fn with_example_template(self, template: impl Into<String>) -> Self {
        let template = template.into();
        self.configure(|cfg| cfg.example_template = template)
    }

    /// Replaces the overall prompt template.
    pub fn with_prompt_template(self, template: impl Into<String>) -> Self {
        let template = template.into();
        self.configure(|cfg| cfg.prompt_template = template)
    }

    /// Sets `max_example_tokens` in the configuration.
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        self.configure(|cfg| cfg.max_example_tokens = max_tokens)
    }

    /// Replaces the whole configuration, discarding earlier `with_*` settings.
    pub fn with_config(self, config: ESCoTConfig) -> Self {
        Self {
            escot_config: config,
            ..self
        }
    }

    /// Finishes the builder.
    ///
    /// Without a database, an empty one of dimension 1536 is created; without
    /// a module configuration, the standard "ESCoT" 1.0.0 configuration is used.
    pub fn build(self) -> ESCoT {
        let Self {
            escot_config,
            database,
            module_config,
        } = self;

        let database = match database {
            Some(db) => db,
            None => ExampleDatabase::new(1536),
        };
        let config = match module_config {
            Some(cfg) => cfg,
            None => ThinkToolModuleConfig::new(
                "ESCoT",
                "1.0.0",
                "Example Selection Chain-of-Thought - embedding-based few-shot example selection",
            ),
        };

        ESCoT {
            config,
            escot_config,
            database,
        }
    }
}
impl Default for ESCoTBuilder {
fn default() -> Self {
Self::new()
}
}
/// Outcome of one ESCoT run: the assembled prompt plus the examples behind it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ESCoTResult {
/// The assembled few-shot prompt.
pub prompt: String,
/// The examples that were selected, each with its similarity to the query.
pub selected_examples: Vec<SimilarExample>,
/// As produced by `ESCoT::execute_with_embedding`: 0.5 when no example was
/// selected, otherwise `0.5 + 0.4 * average similarity`.
pub confidence: f32,
/// The number of examples that was requested (the configured `k`).
pub k: usize,
}
impl ESCoTResult {
    /// Number of examples that were selected.
    pub fn num_examples(&self) -> usize {
        self.selected_examples.len()
    }

    /// Whether at least one example was selected.
    pub fn has_examples(&self) -> bool {
        !self.selected_examples.is_empty()
    }

    /// Mean similarity of the selected examples, or 0.0 when none were selected.
    pub fn avg_similarity(&self) -> f32 {
        if self.selected_examples.is_empty() {
            0.0
        } else {
            self.selected_examples
                .iter()
                .map(|se| se.similarity)
                .sum::<f32>()
                / self.selected_examples.len() as f32
        }
    }

    /// Lowest similarity among the selected examples, or 0.0 when none were
    /// selected.
    ///
    /// The empty case returns 0.0, matching `avg_similarity`, instead of
    /// leaking the fold's `f32::INFINITY` seed to callers.
    pub fn min_similarity(&self) -> f32 {
        if self.selected_examples.is_empty() {
            return 0.0;
        }
        self.selected_examples
            .iter()
            .map(|se| se.similarity)
            .fold(f32::INFINITY, f32::min)
    }
}
impl ThinkToolModule for ESCoT {
    /// Returns the module configuration.
    fn config(&self) -> &ThinkToolModuleConfig {
        &self.config
    }

    /// Runs example selection for `context.query`.
    ///
    /// The query embedding is a deterministic mock derived from the query
    /// text, not the output of a real embedding model. The JSON output holds
    /// the prompt, the number of examples, their average similarity, the
    /// requested `k`, and a summary (id, query, similarity, category) of each
    /// selected example.
    fn execute(&self, context: &ThinkToolContext) -> Result<ThinkToolOutput> {
        let embedding = generate_mock_embedding(&context.query, self.database.dimension());
        let result = self.execute_with_embedding(&context.query, &embedding)?;

        let mut example_summaries: Vec<serde_json::Value> =
            Vec::with_capacity(result.selected_examples.len());
        for se in &result.selected_examples {
            example_summaries.push(serde_json::json!({
                "id": se.example.id,
                "query": se.example.query,
                "similarity": se.similarity,
                "category": se.example.category,
            }));
        }

        let output = serde_json::json!({
            "prompt": result.prompt,
            "num_examples": result.num_examples(),
            "avg_similarity": result.avg_similarity(),
            "k_requested": result.k,
            "selected_examples": example_summaries,
        });

        let confidence = result.confidence as f64;
        Ok(ThinkToolOutput::new("ESCoT", confidence, output))
    }
}
/// Produces a deterministic pseudo-random unit vector of length `dimension`
/// from `query`.
///
/// The query's hash seeds a linear congruential generator; each step's upper
/// 32 bits are mapped into `[-1.0, 1.0]`, and the vector is then normalised.
/// This is a stand-in for a real embedding and carries no semantic meaning.
fn generate_mock_embedding(query: &str, dimension: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    let mut hasher = DefaultHasher::new();
    query.hash(&mut hasher);
    let mut state = hasher.finish();

    let mut embedding: Vec<f32> = (0..dimension)
        .map(|_| {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            ((state >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0
        })
        .collect();

    normalize_vector(&mut embedding);
    embedding
}
// Unit tests for examples, the example database, the vector helpers, and the
// ESCoT module. Embeddings come from `generate_mock_embedding`, so similarity
// values are deterministic but carry no semantic meaning.
#[cfg(test)]
mod tests {
use super::*;
// Builds five fixture examples: three in "math", one in "geography", one in "logic".
fn create_test_examples(dimension: usize) -> Vec<Example> {
vec![
Example::new(
"What is 2+2?",
"Let me add these numbers. 2 plus 2 equals 4.",
"4",
generate_mock_embedding("What is 2+2?", dimension),
).with_category("math"),
Example::new(
"What is 3+3?",
"Adding 3 and 3 together. 3 plus 3 equals 6.",
"6",
generate_mock_embedding("What is 3+3?", dimension),
).with_category("math"),
Example::new(
"What is the capital of France?",
"France is a European country. Its capital city is Paris.",
"Paris",
generate_mock_embedding("What is the capital of France?", dimension),
).with_category("geography"),
Example::new(
"Is it valid: All A are B, X is A, therefore X is B?",
"This is a categorical syllogism. The premises establish that A is a subset of B, and X belongs to A. Therefore X must also belong to B. This is valid.",
"Yes, valid (Barbara syllogism)",
generate_mock_embedding("Is it valid: All A are B, X is A, therefore X is B?", dimension),
).with_category("logic"),
Example::new(
"What is 10 divided by 2?",
"Division: 10 ÷ 2. How many times does 2 go into 10? 5 times.",
"5",
generate_mock_embedding("What is 10 divided by 2?", dimension),
).with_category("math"),
]
}
// `Example::new` fills the id prefix and the default difficulty/quality.
#[test]
fn test_example_creation() {
let example = Example::new(
"Test query",
"Test reasoning",
"Test answer",
vec![0.1, 0.2, 0.3],
);
assert!(example.id.starts_with("ex_"));
assert_eq!(example.query, "Test query");
assert_eq!(example.reasoning, "Test reasoning");
assert_eq!(example.answer, "Test answer");
assert_eq!(example.difficulty, 0.5);
assert_eq!(example.quality_score, 1.0);
}
// The `with_*` methods set category, difficulty, quality and metadata.
#[test]
fn test_example_with_metadata() {
let example = Example::new("Q", "R", "A", vec![0.1])
.with_category("math")
.with_difficulty(0.8)
.with_quality(0.95)
.with_metadata("source", "textbook");
assert_eq!(example.category, Some("math".to_string()));
assert!((example.difficulty - 0.8).abs() < 0.001);
assert!((example.quality_score - 0.95).abs() < 0.001);
assert_eq!(
example.metadata.get("source"),
Some(&"textbook".to_string())
);
}
// The fixed prompt layout contains the query, reasoning and answer.
#[test]
fn test_example_to_prompt() {
let example = Example::new("What is 2+2?", "Adding numbers: 2+2=4", "4", vec![0.1]);
let prompt = example.to_prompt_text();
assert!(prompt.contains("What is 2+2?"));
assert!(prompt.contains("Adding numbers: 2+2=4"));
assert!(prompt.contains("A: 4"));
}
// A new database is empty and reports its dimension.
#[test]
fn test_database_creation() {
let db = ExampleDatabase::new(128);
assert_eq!(db.dimension(), 128);
assert!(db.is_empty());
assert_eq!(db.len(), 0);
}
// Adding an example grows the database and registers its category.
#[test]
fn test_database_add_example() {
let mut db = ExampleDatabase::new(3);
let example = Example::new("Q", "R", "A", vec![0.1, 0.2, 0.3]).with_category("test");
db.add_example(example).unwrap();
assert_eq!(db.len(), 1);
assert!(!db.is_empty());
assert!(db.categories().contains(&"test"));
}
// An embedding of the wrong length is rejected.
#[test]
fn test_database_dimension_mismatch() {
let mut db = ExampleDatabase::new(3);
let example = Example::new("Q", "R", "A", vec![0.1, 0.2]);
let result = db.add_example(example);
assert!(result.is_err());
}
// `find_similar` honours `k` and returns results in descending similarity.
#[test]
fn test_find_similar() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
let examples = create_test_examples(dimension);
db.add_examples(examples).unwrap();
let query_embedding = generate_mock_embedding("What is 5+5?", dimension);
let results = db.find_similar(&query_embedding, 2, None);
assert_eq!(results.len(), 2);
assert!(results[0].similarity >= results[1].similarity);
}
// A category filter limits results to that category (three "math" fixtures).
#[test]
fn test_find_similar_with_category_filter() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
let examples = create_test_examples(dimension);
db.add_examples(examples).unwrap();
let query_embedding = generate_mock_embedding("arithmetic question", dimension);
let results = db.find_similar(&query_embedding, 10, Some("math"));
assert_eq!(results.len(), 3);
for result in &results {
assert_eq!(result.example.category, Some("math".to_string()));
}
}
// Diversity-aware selection returns the requested number of examples.
#[test]
fn test_find_similar_diverse() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
let examples = create_test_examples(dimension);
db.add_examples(examples).unwrap();
let query_embedding = generate_mock_embedding("question", dimension);
let diverse_results = db.find_similar_diverse(&query_embedding, 3, 0.7, None);
assert_eq!(diverse_results.len(), 3);
}
// Orthogonal vectors score 0, identical vectors 1, mismatched lengths 0.
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
let c = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 0.0).abs() < 0.001);
assert!((cosine_similarity(&a, &c) - 1.0).abs() < 0.001);
assert_eq!(cosine_similarity(&a, &[1.0, 0.0]), 0.0);
}
// Normalising (3, 4) yields (0.6, 0.8) with unit magnitude.
#[test]
fn test_normalize_vector() {
let mut v = vec![3.0, 4.0];
normalize_vector(&mut v);
assert!((v[0] - 0.6).abs() < 0.001);
assert!((v[1] - 0.8).abs() < 0.001);
let mag: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((mag - 1.0).abs() < 0.001);
}
// Default configuration values.
#[test]
fn test_escot_config_default() {
let config = ESCoTConfig::default();
assert_eq!(config.k, 3);
assert_eq!(config.min_similarity, 0.0);
assert!((config.diversity_weight - 0.3).abs() < 0.001);
assert!(config.category_filter.is_none());
}
// Builder settings end up in the built module's configuration.
#[test]
fn test_escot_builder() {
let dimension = 64;
let db = ExampleDatabase::new(dimension);
let escot = ESCoT::builder()
.with_database(db)
.with_k(5)
.with_min_similarity(0.5)
.with_diversity(0.4)
.with_category("math")
.build();
assert_eq!(escot.escot_config().k, 5);
assert!((escot.escot_config().min_similarity - 0.5).abs() < 0.001);
assert!((escot.escot_config().diversity_weight - 0.4).abs() < 0.001);
assert_eq!(
escot.escot_config().category_filter,
Some("math".to_string())
);
}
// Selection returns exactly `k` examples when enough are available.
#[test]
fn test_escot_select_examples() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
db.add_examples(create_test_examples(dimension)).unwrap();
let escot = ESCoT::builder().with_database(db).with_k(2).build();
let query_embedding = generate_mock_embedding("What is 4+4?", dimension);
let selected = escot.select_examples(&query_embedding);
assert_eq!(selected.len(), 2);
}
// The built prompt contains the query and the default template text.
#[test]
fn test_escot_build_prompt() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
db.add_examples(create_test_examples(dimension)).unwrap();
let escot = ESCoT::builder().with_database(db).with_k(2).build();
let query_embedding = generate_mock_embedding("What is 4+4?", dimension);
let selected = escot.select_examples(&query_embedding);
let prompt = escot.build_prompt("What is 4+4?", &selected);
assert!(prompt.contains("What is 4+4?"));
assert!(prompt.contains("Here are some examples"));
assert!(prompt.contains("Let me think step by step"));
}
// End-to-end run with an explicit embedding yields examples and a prompt.
#[test]
fn test_escot_execute_with_embedding() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
db.add_examples(create_test_examples(dimension)).unwrap();
let escot = ESCoT::builder().with_database(db).with_k(3).build();
let query = "What is 7+7?";
let query_embedding = generate_mock_embedding(query, dimension);
let result = escot
.execute_with_embedding(query, &query_embedding)
.unwrap();
assert!(result.has_examples());
assert!(result.num_examples() <= 3);
assert!(result.confidence > 0.0);
assert!(!result.prompt.is_empty());
}
// Count, average and minimum are computed over the selected examples.
#[test]
fn test_escot_result_metrics() {
let result = ESCoTResult {
prompt: "test".to_string(),
selected_examples: vec![
SimilarExample {
example: Example::new("Q1", "R1", "A1", vec![0.1]),
similarity: 0.9,
},
SimilarExample {
example: Example::new("Q2", "R2", "A2", vec![0.2]),
similarity: 0.7,
},
],
confidence: 0.8,
k: 3,
};
assert_eq!(result.num_examples(), 2);
assert!(result.has_examples());
assert!((result.avg_similarity() - 0.8).abs() < 0.001);
assert!((result.min_similarity() - 0.7).abs() < 0.001);
}
// The `ThinkToolModule::execute` entry point produces the expected output keys.
#[test]
fn test_escot_execute_thinktool() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
db.add_examples(create_test_examples(dimension)).unwrap();
let escot = ESCoT::builder().with_database(db).with_k(2).build();
let context = ThinkToolContext::new("What is 8+8?");
let output = escot.execute(&context).unwrap();
assert_eq!(output.module, "ESCoT");
assert!(output.confidence > 0.0);
assert!(output.get("prompt").is_some());
assert!(output.get("num_examples").is_some());
}
// A module built without a database still reports its name and version.
#[test]
fn test_escot_module_name() {
let escot = ESCoT::builder().build();
assert_eq!(escot.name(), "ESCoT");
assert_eq!(escot.version(), "1.0.0");
}
// Every returned example meets the similarity floor; the loop passes
// vacuously when nothing is returned.
#[test]
fn test_escot_min_similarity_filter() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
db.add_examples(create_test_examples(dimension)).unwrap();
let escot = ESCoT::builder()
.with_database(db)
.with_k(10)
.with_min_similarity(0.99) .build();
let query_embedding = generate_mock_embedding("random unrelated query", dimension);
let selected = escot.select_examples(&query_embedding);
for se in &selected {
assert!(se.similarity >= 0.99);
}
}
// An empty database yields no examples and the baseline confidence of 0.5.
#[test]
fn test_escot_empty_database() {
let dimension = 64;
let db = ExampleDatabase::new(dimension);
let escot = ESCoT::builder().with_database(db).with_k(3).build();
let query_embedding = generate_mock_embedding("test query", dimension);
let result = escot
.execute_with_embedding("test query", &query_embedding)
.unwrap();
assert!(!result.has_examples());
assert_eq!(result.num_examples(), 0);
assert!((result.confidence - 0.5).abs() < 0.001); }
// Lookup by a custom id finds the example; unknown ids return `None`.
#[test]
fn test_database_get_by_id() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
let example = Example::new("Test", "R", "A", generate_mock_embedding("Test", dimension))
.with_id("custom_id_123");
db.add_example(example).unwrap();
let found = db.get_by_id("custom_id_123");
assert!(found.is_some());
assert_eq!(found.unwrap().query, "Test");
let not_found = db.get_by_id("nonexistent");
assert!(not_found.is_none());
}
// Category lookup returns the right counts and is empty for unknown names.
#[test]
fn test_database_get_by_category() {
let dimension = 64;
let mut db = ExampleDatabase::new(dimension);
db.add_examples(create_test_examples(dimension)).unwrap();
let math_examples = db.get_by_category("math");
assert_eq!(math_examples.len(), 3);
let logic_examples = db.get_by_category("logic");
assert_eq!(logic_examples.len(), 1);
let nonexistent = db.get_by_category("nonexistent");
assert!(nonexistent.is_empty());
}
// The mock embedding is stable per query and differs between queries.
#[test]
fn test_generate_mock_embedding_deterministic() {
let dim = 64;
let query = "test query";
let emb1 = generate_mock_embedding(query, dim);
let emb2 = generate_mock_embedding(query, dim);
assert_eq!(emb1, emb2);
let emb3 = generate_mock_embedding("different query", dim);
assert_ne!(emb1, emb3);
}
// The mock embedding has unit magnitude.
#[test]
fn test_generate_mock_embedding_normalized() {
let emb = generate_mock_embedding("test", 128);
let magnitude: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((magnitude - 1.0).abs() < 0.001);
}
}