anamnesis_core/
embedding.rs1use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26
27use crate::error::Result;
28
29#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
34#[serde(transparent)]
35pub struct ModelId(pub String);
36
37impl ModelId {
38 pub fn new(provider: &str, model: &str, version: u32) -> Self {
40 Self(format!("{provider}:{model}:{version}"))
41 }
42
43 pub fn as_str(&self) -> &str {
45 &self.0
46 }
47}
48
49impl std::fmt::Display for ModelId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.write_str(&self.0)
52 }
53}
54
/// Tag passed to `EmbeddingProvider::embed_batch` describing the kind of text
/// being embedded.
// NOTE(review): presumably lets a provider treat queries and documents
// differently (e.g. task-specific prefixes) — confirm against implementations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmbeddingTask {
    /// The text is a query; the default `embed_query` uses this variant.
    Query,
    /// The text is a document.
    Document,
}
64
65#[async_trait]
67pub trait EmbeddingProvider: Send + Sync {
68 fn model_id(&self) -> ModelId;
70
71 fn dim(&self) -> u16;
73
74 async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
76 let mut out = self.embed_batch(&[text], EmbeddingTask::Query).await?;
77 out.pop()
78 .ok_or_else(|| crate::error::Error::Other("provider returned no vector".into()))
79 }
80
81 async fn embed_batch(&self, texts: &[&str], task: EmbeddingTask) -> Result<Vec<Vec<f32>>>;
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test double: every vector is `dim` copies of
    /// `text.len() / 100`.
    struct FakeProvider {
        id: ModelId,
        dim: u16,
    }

    #[async_trait]
    impl EmbeddingProvider for FakeProvider {
        fn model_id(&self) -> ModelId {
            self.id.clone()
        }

        fn dim(&self) -> u16 {
            self.dim
        }

        async fn embed_batch(&self, texts: &[&str], _task: EmbeddingTask) -> Result<Vec<Vec<f32>>> {
            let width = self.dim as usize;
            let mut rows = Vec::with_capacity(texts.len());
            for text in texts {
                let fill = (text.len() as f32) / 100.0;
                rows.push(vec![fill; width]);
            }
            Ok(rows)
        }
    }

    /// Provider that always answers with an empty batch, used to exercise the
    /// error path of the default `embed_query`.
    struct EmptyProvider;

    #[async_trait]
    impl EmbeddingProvider for EmptyProvider {
        fn model_id(&self) -> ModelId {
            ModelId::new("test", "empty", 1)
        }

        fn dim(&self) -> u16 {
            4
        }

        async fn embed_batch(
            &self,
            _texts: &[&str],
            _task: EmbeddingTask,
        ) -> Result<Vec<Vec<f32>>> {
            Ok(Vec::new())
        }
    }

    /// Four-dimensional fake provider shared by the async tests.
    fn fake_provider() -> FakeProvider {
        FakeProvider {
            id: ModelId::new("test", "fake", 1),
            dim: 4,
        }
    }

    #[test]
    fn model_id_format_is_stable() {
        let id = ModelId::new("local", "multilingual-e5-small", 1);
        assert_eq!(id.as_str(), "local:multilingual-e5-small:1");
    }

    #[test]
    fn model_id_roundtrips_through_json() {
        let original = ModelId::new("local", "bge-m3", 1);

        // `#[serde(transparent)]` means the JSON form is a bare string.
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, "\"local:bge-m3:1\"");

        let decoded: ModelId = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[tokio::test]
    async fn default_embed_query_forwards_to_batch() {
        let provider = fake_provider();
        let vector = provider.embed_query("hello world").await.unwrap();
        assert_eq!(vector.len(), 4);
        // "hello world" is 11 bytes long, so the fake fills with 11 / 100.
        assert!((vector[0] - 0.11).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn batch_returns_one_vector_per_input() {
        let provider = fake_provider();
        let inputs = ["a", "bb", "ccc"];
        let rows = provider
            .embed_batch(&inputs, EmbeddingTask::Document)
            .await
            .unwrap();
        assert_eq!(rows.len(), 3);
        assert!(rows.iter().all(|row| row.len() == 4));
    }

    #[tokio::test]
    async fn embed_query_propagates_empty_provider_result() {
        let err = EmptyProvider.embed_query("x").await.unwrap_err();
        assert!(format!("{err}").contains("no vector"));
    }
}