use serde::{Deserialize, Serialize};
use super::provider_config::ProviderConfig;
/// Settings bundle for embeddings, serializable and deserializable via serde.
///
/// Defaults (from the `Default` impl in this file): `ProviderConfig::default()`,
/// threshold `0.7`, batch size `32`, caching enabled, timeout `30`.
///
/// NOTE(review): this file only declares and defaults these fields; how each
/// one is consumed is not visible here, so the per-field notes below are
/// inferred from the field names and should be confirmed against call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Which embedding provider to use, together with its own settings.
pub provider: ProviderConfig,
/// Presumably the minimum similarity score treated as a match — TODO confirm
/// the expected range and whether the comparison is inclusive.
pub similarity_threshold: f32,
/// Presumably the number of inputs grouped into one embedding call — TODO confirm.
pub batch_size: usize,
/// Presumably toggles reuse of previously computed embeddings — TODO confirm.
pub cache_embeddings: bool,
/// Timeout expressed in seconds; which operation it bounds is not shown here.
pub timeout_seconds: u64,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
provider: ProviderConfig::default(),
similarity_threshold: 0.7,
batch_size: 32,
cache_embeddings: true,
timeout_seconds: 30,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must use the local provider with the MiniLM model
    /// (384-dimensional embeddings) and the documented scalar defaults.
    #[test]
    fn test_default_embedding_config() {
        let defaults = EmbeddingConfig::default();

        // Any provider other than `Local` is a regression in the defaults.
        if let ProviderConfig::Local(local) = &defaults.provider {
            assert_eq!(
                local.model_name,
                "sentence-transformers/all-MiniLM-L6-v2"
            );
            assert_eq!(local.embedding_dimension, 384);
        } else {
            panic!("Expected Local provider in default config");
        }

        assert_eq!(defaults.similarity_threshold, 0.7);
        assert_eq!(defaults.batch_size, 32);
        assert!(defaults.cache_embeddings);
        assert_eq!(defaults.timeout_seconds, 30);
    }

    /// A config assembled by hand around the OpenAI small preset keeps every
    /// scalar field exactly as supplied.
    #[test]
    fn test_embedding_config_with_openai() {
        let custom = EmbeddingConfig {
            provider: ProviderConfig::openai_3_small(),
            similarity_threshold: 0.8,
            batch_size: 64,
            cache_embeddings: false,
            timeout_seconds: 60,
        };

        // Read the scalar fields back out in one go; the provider itself is
        // not inspected by this test.
        let EmbeddingConfig {
            similarity_threshold,
            batch_size,
            cache_embeddings,
            timeout_seconds,
            ..
        } = custom;

        assert_eq!(similarity_threshold, 0.8);
        assert_eq!(batch_size, 64);
        assert!(!cache_embeddings);
        assert_eq!(timeout_seconds, 60);
    }
}