enact_memory/
embeddings.rs1use async_trait::async_trait;
4
5#[async_trait]
7pub trait EmbeddingProvider: Send + Sync {
8 fn name(&self) -> &str;
10
11 fn dimensions(&self) -> usize;
13
14 async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
16
17 async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
19 let mut results = self.embed(&[text]).await?;
20 results
21 .pop()
22 .ok_or_else(|| anyhow::anyhow!("Empty embedding result"))
23 }
24}
25
/// Placeholder embedding provider that carries no state.
///
/// Its `EmbeddingProvider` implementation in this file reports the name
/// `"none"`, zero dimensions, and an empty result for every batch.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoopEmbedding;
30
31#[async_trait]
32impl EmbeddingProvider for NoopEmbedding {
33 fn name(&self) -> &str {
34 "none"
35 }
36
37 fn dimensions(&self) -> usize {
38 0
39 }
40
41 async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
42 Ok(Vec::new())
43 }
44}
45
46pub struct OpenAiEmbedding {
50 base_url: String,
51 api_key: String,
52 model: String,
53 dims: usize,
54 client: reqwest::Client,
55}
56
57impl OpenAiEmbedding {
58 pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
59 Self {
60 base_url: base_url.trim_end_matches('/').to_string(),
61 api_key: api_key.to_string(),
62 model: model.to_string(),
63 dims,
64 client: reqwest::Client::new(),
65 }
66 }
67
68 fn embeddings_url(&self) -> String {
69 let url = reqwest::Url::parse(&self.base_url).ok();
70 let has_embeddings = url
71 .as_ref()
72 .map(|u| u.path().trim_end_matches('/').ends_with("/embeddings"))
73 .unwrap_or(false);
74
75 if has_embeddings {
76 return self.base_url.clone();
77 }
78
79 let has_path = url
80 .as_ref()
81 .map(|u| {
82 let path = u.path().trim_end_matches('/');
83 !path.is_empty() && path != "/"
84 })
85 .unwrap_or(false);
86
87 if has_path {
88 format!("{}/embeddings", self.base_url)
89 } else {
90 format!("{}/v1/embeddings", self.base_url)
91 }
92 }
93}
94
95#[async_trait]
96impl EmbeddingProvider for OpenAiEmbedding {
97 fn name(&self) -> &str {
98 "openai"
99 }
100
101 fn dimensions(&self) -> usize {
102 self.dims
103 }
104
105 async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
106 if texts.is_empty() {
107 return Ok(Vec::new());
108 }
109
110 let body = serde_json::json!({
111 "model": self.model,
112 "input": texts,
113 });
114
115 let resp = self
116 .client
117 .post(self.embeddings_url())
118 .header("Authorization", format!("Bearer {}", self.api_key))
119 .header("Content-Type", "application/json")
120 .json(&body)
121 .send()
122 .await?;
123
124 if !resp.status().is_success() {
125 let status = resp.status();
126 let text = resp.text().await.unwrap_or_default();
127 anyhow::bail!("Embedding API error {status}: {text}");
128 }
129
130 let json: serde_json::Value = resp.json().await?;
131 let data = json
132 .get("data")
133 .and_then(|d| d.as_array())
134 .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data'"))?;
135
136 let mut embeddings = Vec::with_capacity(data.len());
137 for item in data {
138 let embedding = item
139 .get("embedding")
140 .and_then(|e| e.as_array())
141 .ok_or_else(|| anyhow::anyhow!("Invalid embedding item"))?;
142
143 #[allow(clippy::cast_possible_truncation)]
144 let vec: Vec<f32> = embedding
145 .iter()
146 .filter_map(|v| v.as_f64().map(|f| f as f32))
147 .collect();
148
149 embeddings.push(vec);
150 }
151
152 Ok(embeddings)
153 }
154}
155
156pub fn create_embedding_provider(
160 provider: &str,
161 api_key: Option<&str>,
162 model: &str,
163 dims: usize,
164) -> Box<dyn EmbeddingProvider> {
165 match provider {
166 "openai" => {
167 let key = api_key.unwrap_or("");
168 Box::new(OpenAiEmbedding::new(
169 "https://api.openai.com",
170 key,
171 model,
172 dims,
173 ))
174 }
175 name if name.starts_with("custom:") => {
176 let base_url = name.strip_prefix("custom:").unwrap_or("");
177 let key = api_key.unwrap_or("");
178 Box::new(OpenAiEmbedding::new(base_url, key, model, dims))
179 }
180 _ => Box::new(NoopEmbedding),
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// The no-op provider reports the name "none" and zero dimensions.
    #[test]
    fn noop_name() {
        let provider = NoopEmbedding;
        assert_eq!(provider.name(), "none");
        assert_eq!(provider.dimensions(), 0);
    }

    /// Embedding through the no-op provider yields no vectors, even for non-empty input.
    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let provider = NoopEmbedding;
        let vectors = provider.embed(&["hello"]).await.unwrap();
        assert!(vectors.is_empty());
    }

    /// An unrecognised selector produces the no-op provider.
    #[test]
    fn factory_none() {
        let built = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(built.name(), "none");
    }

    /// The "openai" selector produces an OpenAI provider with the requested dimensions.
    #[test]
    fn factory_openai() {
        let built =
            create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(built.name(), "openai");
        assert_eq!(built.dimensions(), 1536);
    }

    /// A "custom:<url>" selector also produces an OpenAI-style provider.
    #[test]
    fn factory_custom_url() {
        let built = create_embedding_provider("custom:http://localhost:1234", None, "model", 768);
        assert_eq!(built.name(), "openai");
        assert_eq!(built.dimensions(), 768);
    }
}