use serde::{Deserialize, Serialize};
/// Selects which embedder produces embeddings.
///
/// Serialized as an internally tagged object. The variant name, converted to
/// `snake_case`, is stored under the `"type"` key, and any variant fields sit
/// alongside it. For example, `{"type": "precomputed"}` or
/// `{"type": "candle_bert", "model": "sentence-transformers/all-MiniLM-L6-v2"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbedderDefinition {
    /// No model is configured; serialized as `{"type": "precomputed"}`.
    ///
    /// NOTE(review): presumably embeddings are supplied already computed —
    /// confirm against the consumers of this type.
    Precomputed,
    /// Tagged as `"candle_bert"`.
    CandleBert {
        /// Model identifier. Its format is not validated here.
        model: String,
    },
    /// Tagged as `"candle_clip"`.
    CandleClip {
        /// Model identifier. Its format is not validated here.
        model: String,
    },
    /// Tagged as `"openai"`.
    // Deliberately spelled `Openai` rather than `OpenAi`: with
    // `rename_all = "snake_case"` the latter would serialize as `"open_ai"`,
    // changing the wire format.
    Openai {
        /// Model identifier. Its format is not validated here.
        model: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Deserializes `json` into an [`EmbedderDefinition`] and verifies a full
    /// serde round trip:
    ///
    /// 1. Re-serializing the parsed value yields JSON that is structurally equal
    ///    to the input. The comparison uses [`Value`], so key order and
    ///    whitespace are ignored.
    /// 2. That serialized form deserializes again and re-serializes identically.
    ///
    /// Returns the value parsed from `json` so callers can inspect its fields.
    ///
    /// # Panics
    ///
    /// Panics if any (de)serialization step fails or either comparison does
    /// not hold.
    fn assert_roundtrip(json: &str) -> EmbedderDefinition {
        let expected: Value = serde_json::from_str(json).expect("test input must be valid JSON");
        let def: EmbedderDefinition =
            serde_json::from_str(json).expect("test input must deserialize");

        let serialized = serde_json::to_value(&def).expect("serialization must succeed");
        assert_eq!(serialized, expected, "serialized form differs from input");

        let reparsed: EmbedderDefinition = serde_json::from_value(serialized.clone())
            .expect("serialized form must deserialize");
        assert_eq!(
            serde_json::to_value(&reparsed).expect("re-serialization must succeed"),
            serialized,
            "second round trip changed the representation"
        );

        def
    }

    #[test]
    fn test_precomputed_serde_roundtrip() {
        let def = assert_roundtrip(r#"{"type": "precomputed"}"#);
        assert!(matches!(def, EmbedderDefinition::Precomputed));
    }

    #[test]
    fn test_candle_bert_serde_roundtrip() {
        let def = assert_roundtrip(
            r#"{"type": "candle_bert", "model": "sentence-transformers/all-MiniLM-L6-v2"}"#,
        );
        match def {
            EmbedderDefinition::CandleBert { model } => {
                assert_eq!(model, "sentence-transformers/all-MiniLM-L6-v2");
            }
            other => panic!("Expected CandleBert, got {other:?}"),
        }
    }

    #[test]
    fn test_candle_clip_serde_roundtrip() {
        let def =
            assert_roundtrip(r#"{"type": "candle_clip", "model": "openai/clip-vit-base-patch32"}"#);
        match def {
            EmbedderDefinition::CandleClip { model } => {
                assert_eq!(model, "openai/clip-vit-base-patch32");
            }
            other => panic!("Expected CandleClip, got {other:?}"),
        }
    }

    #[test]
    fn test_openai_serde_roundtrip() {
        let def = assert_roundtrip(r#"{"type": "openai", "model": "text-embedding-3-small"}"#);
        match def {
            EmbedderDefinition::Openai { model } => {
                assert_eq!(model, "text-embedding-3-small");
            }
            other => panic!("Expected Openai, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_type_is_rejected() {
        let result = serde_json::from_str::<EmbedderDefinition>(r#"{"type": "unknown"}"#);
        assert!(result.is_err(), "unknown tag should not deserialize");
    }

    #[test]
    fn test_missing_model_is_rejected() {
        let result = serde_json::from_str::<EmbedderDefinition>(r#"{"type": "candle_bert"}"#);
        assert!(result.is_err(), "variant with a required field should reject its absence");
    }

    #[test]
    fn test_missing_type_tag_is_rejected() {
        let result = serde_json::from_str::<EmbedderDefinition>(r#"{"model": "x"}"#);
        assert!(result.is_err(), "object without a \"type\" tag should not deserialize");
    }
}