Skip to main content

enact_memory/
embeddings.rs

1//! Embedding providers for semantic memory search
2
3use async_trait::async_trait;
4
5/// Trait for embedding providers — convert text to vectors
6#[async_trait]
7pub trait EmbeddingProvider: Send + Sync {
8    /// Provider name
9    fn name(&self) -> &str;
10
11    /// Embedding dimensions
12    fn dimensions(&self) -> usize;
13
14    /// Embed a batch of texts into vectors
15    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
16
17    /// Embed a single text
18    async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
19        let mut results = self.embed(&[text]).await?;
20        results
21            .pop()
22            .ok_or_else(|| anyhow::anyhow!("Empty embedding result"))
23    }
24}
25
26// ── Noop provider (keyword-only fallback) ────────────────────
27
28/// No-op embedding provider that returns empty vectors
29pub struct NoopEmbedding;
30
31#[async_trait]
32impl EmbeddingProvider for NoopEmbedding {
33    fn name(&self) -> &str {
34        "none"
35    }
36
37    fn dimensions(&self) -> usize {
38        0
39    }
40
41    async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
42        Ok(Vec::new())
43    }
44}
45
46// ── OpenAI-compatible embedding provider ─────────────────────
47
48/// OpenAI-compatible embedding provider
49pub struct OpenAiEmbedding {
50    base_url: String,
51    api_key: String,
52    model: String,
53    dims: usize,
54    client: reqwest::Client,
55}
56
57impl OpenAiEmbedding {
58    pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
59        Self {
60            base_url: base_url.trim_end_matches('/').to_string(),
61            api_key: api_key.to_string(),
62            model: model.to_string(),
63            dims,
64            client: reqwest::Client::new(),
65        }
66    }
67
68    fn embeddings_url(&self) -> String {
69        let url = reqwest::Url::parse(&self.base_url).ok();
70        let has_embeddings = url
71            .as_ref()
72            .map(|u| u.path().trim_end_matches('/').ends_with("/embeddings"))
73            .unwrap_or(false);
74
75        if has_embeddings {
76            return self.base_url.clone();
77        }
78
79        let has_path = url
80            .as_ref()
81            .map(|u| {
82                let path = u.path().trim_end_matches('/');
83                !path.is_empty() && path != "/"
84            })
85            .unwrap_or(false);
86
87        if has_path {
88            format!("{}/embeddings", self.base_url)
89        } else {
90            format!("{}/v1/embeddings", self.base_url)
91        }
92    }
93}
94
95#[async_trait]
96impl EmbeddingProvider for OpenAiEmbedding {
97    fn name(&self) -> &str {
98        "openai"
99    }
100
101    fn dimensions(&self) -> usize {
102        self.dims
103    }
104
105    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
106        if texts.is_empty() {
107            return Ok(Vec::new());
108        }
109
110        let body = serde_json::json!({
111            "model": self.model,
112            "input": texts,
113        });
114
115        let resp = self
116            .client
117            .post(self.embeddings_url())
118            .header("Authorization", format!("Bearer {}", self.api_key))
119            .header("Content-Type", "application/json")
120            .json(&body)
121            .send()
122            .await?;
123
124        if !resp.status().is_success() {
125            let status = resp.status();
126            let text = resp.text().await.unwrap_or_default();
127            anyhow::bail!("Embedding API error {status}: {text}");
128        }
129
130        let json: serde_json::Value = resp.json().await?;
131        let data = json
132            .get("data")
133            .and_then(|d| d.as_array())
134            .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data'"))?;
135
136        let mut embeddings = Vec::with_capacity(data.len());
137        for item in data {
138            let embedding = item
139                .get("embedding")
140                .and_then(|e| e.as_array())
141                .ok_or_else(|| anyhow::anyhow!("Invalid embedding item"))?;
142
143            #[allow(clippy::cast_possible_truncation)]
144            let vec: Vec<f32> = embedding
145                .iter()
146                .filter_map(|v| v.as_f64().map(|f| f as f32))
147                .collect();
148
149            embeddings.push(vec);
150        }
151
152        Ok(embeddings)
153    }
154}
155
156// ── Factory ──────────────────────────────────────────────────
157
158/// Create an embedding provider
159pub fn create_embedding_provider(
160    provider: &str,
161    api_key: Option<&str>,
162    model: &str,
163    dims: usize,
164) -> Box<dyn EmbeddingProvider> {
165    match provider {
166        "openai" => {
167            let key = api_key.unwrap_or("");
168            Box::new(OpenAiEmbedding::new(
169                "https://api.openai.com",
170                key,
171                model,
172                dims,
173            ))
174        }
175        name if name.starts_with("custom:") => {
176            let base_url = name.strip_prefix("custom:").unwrap_or("");
177            let key = api_key.unwrap_or("");
178            Box::new(OpenAiEmbedding::new(base_url, key, model, dims))
179        }
180        _ => Box::new(NoopEmbedding),
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// The no-op provider reports the name "none" and zero dimensions.
    #[test]
    fn noop_name() {
        let provider = NoopEmbedding;
        assert_eq!(provider.name(), "none");
        assert_eq!(provider.dimensions(), 0);
    }

    /// The no-op provider yields an empty batch even for non-empty input.
    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let vectors = NoopEmbedding.embed(&["hello"]).await.unwrap();
        assert!(vectors.is_empty());
    }

    /// An unrecognised provider name selects the no-op provider.
    #[test]
    fn factory_none() {
        let provider = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(provider.name(), "none");
    }

    /// "openai" selects the OpenAI client and keeps the requested dimensions.
    #[test]
    fn factory_openai() {
        let provider =
            create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 1536);
    }

    /// A "custom:" prefix selects the OpenAI-compatible client as well.
    #[test]
    fn factory_custom_url() {
        let provider =
            create_embedding_provider("custom:http://localhost:1234", None, "model", 768);
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 768);
    }
}
221}