1#![allow(missing_docs)]
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use serde::Deserialize;
9
10use crate::auth::TenantScope;
11use crate::error::Error;
12
13use super::{Memory, MemoryEntry};
14
15#[allow(clippy::type_complexity)]
17pub trait EmbeddingProvider: Send + Sync {
18 fn embed(
19 &self,
20 texts: &[&str],
21 ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>>;
22
23 fn dimension(&self) -> usize;
24}
25
26pub struct NoopEmbedding;
29
30impl EmbeddingProvider for NoopEmbedding {
31 fn embed(
32 &self,
33 texts: &[&str],
34 ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
35 let len = texts.len();
36 Box::pin(async move { Ok(vec![vec![]; len]) })
37 }
38
39 fn dimension(&self) -> usize {
40 0
41 }
42}
43
44pub struct OpenAiEmbedding {
49 client: reqwest::Client,
50 api_key: String,
51 model: String,
52 base_url: String,
53 dimension: usize,
54}
55
56impl OpenAiEmbedding {
57 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
65 let model = model.into();
66 let dimension = match model.as_str() {
67 "text-embedding-3-small" => 1536,
68 "text-embedding-3-large" => 3072,
69 "text-embedding-ada-002" => 1536,
70 _ => 1536, };
72 let client = reqwest::Client::builder()
73 .redirect(reqwest::redirect::Policy::none())
74 .https_only(true)
75 .no_proxy()
76 .connect_timeout(std::time::Duration::from_secs(10))
77 .timeout(std::time::Duration::from_secs(60))
78 .build()
79 .expect("failed to build hardened HTTPS client for OpenAiEmbedding");
80 Self {
81 client,
82 api_key: api_key.into(),
83 model,
84 base_url: "https://api.openai.com".into(),
85 dimension,
86 }
87 }
88
89 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
96 self.base_url = base_url.into();
97 self
98 }
99
100 pub fn with_dimension(mut self, dimension: usize) -> Self {
101 self.dimension = dimension;
102 self
103 }
104}
105
106#[derive(Deserialize)]
107struct EmbeddingResponse {
108 data: Vec<EmbeddingData>,
109}
110
111#[derive(Deserialize)]
112struct EmbeddingData {
113 embedding: Vec<f32>,
114}
115
116impl EmbeddingProvider for OpenAiEmbedding {
117 fn embed(
118 &self,
119 texts: &[&str],
120 ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
121 let input: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
122 Box::pin(async move {
123 if input.is_empty() {
124 return Ok(vec![]);
125 }
126
127 let body = serde_json::json!({
128 "model": self.model,
129 "input": input,
130 });
131
132 let resp = self
133 .client
134 .post(format!("{}/v1/embeddings", self.base_url))
135 .header("Authorization", format!("Bearer {}", self.api_key))
136 .header("Content-Type", "application/json")
137 .json(&body)
138 .send()
139 .await
140 .map_err(|e| Error::Memory(format!("embedding request failed: {e}")))?;
141
142 if !resp.status().is_success() {
143 let status = resp.status();
144 let text = resp.text().await.unwrap_or_else(|_| "unknown error".into());
145 return Err(Error::Memory(format!(
146 "embedding API returned {status}: {text}"
147 )));
148 }
149
150 let response: EmbeddingResponse = resp
151 .json()
152 .await
153 .map_err(|e| Error::Memory(format!("failed to parse embedding response: {e}")))?;
154
155 Ok(response.data.into_iter().map(|d| d.embedding).collect())
156 })
157 }
158
159 fn dimension(&self) -> usize {
160 self.dimension
161 }
162}
163
164pub struct EmbeddingMemory {
170 inner: Arc<dyn Memory>,
171 embedder: Arc<dyn EmbeddingProvider>,
172}
173
174impl EmbeddingMemory {
175 pub fn new(inner: Arc<dyn Memory>, embedder: Arc<dyn EmbeddingProvider>) -> Self {
176 Self { inner, embedder }
177 }
178}
179
180impl Memory for EmbeddingMemory {
181 fn store(
182 &self,
183 scope: &TenantScope,
184 entry: MemoryEntry,
185 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
186 let scope = scope.clone();
187 Box::pin(async move {
188 let mut entry = entry;
189 if entry.embedding.is_none() && self.embedder.dimension() > 0 {
191 match self.embedder.embed(&[&entry.content]).await {
192 Ok(mut embeddings) if !embeddings.is_empty() => {
193 let emb = embeddings.swap_remove(0);
194 if !emb.is_empty() {
195 entry.embedding = Some(emb);
196 }
197 }
198 Ok(_) => {} Err(e) => {
200 tracing::warn!("failed to generate embedding for memory {}: {e}", entry.id);
202 }
203 }
204 }
205 self.inner.store(&scope, entry).await
206 })
207 }
208
209 fn recall(
210 &self,
211 scope: &TenantScope,
212 query: super::MemoryQuery,
213 ) -> Pin<Box<dyn Future<Output = Result<Vec<MemoryEntry>, Error>> + Send + '_>> {
214 let scope = scope.clone();
215 Box::pin(async move {
216 let mut query = query;
217 if query.query_embedding.is_none()
220 && query.text.is_some()
221 && self.embedder.dimension() > 0
222 {
223 let text = query.text.as_deref().unwrap_or_default();
224 match self.embedder.embed(&[text]).await {
225 Ok(mut embeddings) if !embeddings.is_empty() => {
226 let emb = embeddings.swap_remove(0);
227 if !emb.is_empty() {
228 query.query_embedding = Some(emb);
229 }
230 }
231 Ok(_) => {}
232 Err(e) => {
233 tracing::warn!("failed to generate query embedding: {e}");
235 }
236 }
237 }
238 self.inner.recall(&scope, query).await
239 })
240 }
241
242 fn update(
243 &self,
244 scope: &TenantScope,
245 id: &str,
246 content: String,
247 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
248 let scope = scope.clone();
249 let id = id.to_string();
250 Box::pin(async move { self.inner.update(&scope, &id, content).await })
251 }
252
253 fn forget(
254 &self,
255 scope: &TenantScope,
256 id: &str,
257 ) -> Pin<Box<dyn Future<Output = Result<bool, Error>> + Send + '_>> {
258 let scope = scope.clone();
259 let id = id.to_string();
260 Box::pin(async move { self.inner.forget(&scope, &id).await })
261 }
262
263 fn add_link(
264 &self,
265 scope: &TenantScope,
266 id: &str,
267 related_id: &str,
268 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
269 let scope = scope.clone();
270 let id = id.to_string();
271 let related_id = related_id.to_string();
272 Box::pin(async move { self.inner.add_link(&scope, &id, &related_id).await })
273 }
274
275 fn prune(
276 &self,
277 scope: &TenantScope,
278 min_strength: f64,
279 min_age: chrono::Duration,
280 agent_prefix: Option<&str>,
281 ) -> Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + '_>> {
282 let scope = scope.clone();
283 let agent_prefix = agent_prefix.map(String::from);
284 Box::pin(async move {
285 self.inner
286 .prune(&scope, min_strength, min_age, agent_prefix.as_deref())
287 .await
288 })
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryStore;
    use crate::memory::{Confidentiality, MemoryEntry, MemoryQuery, MemoryType};
    use chrono::Utc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Scope shared by every test.
    fn test_scope() -> TenantScope {
        TenantScope::default()
    }

    /// Builds a minimal entry with the given id and content and no embedding.
    fn make_entry(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.into(),
            agent: "test".into(),
            content: content.into(),
            category: "fact".into(),
            tags: vec![],
            created_at: Utc::now(),
            last_accessed: Utc::now(),
            access_count: 0,
            importance: 5,
            memory_type: MemoryType::default(),
            keywords: vec![],
            summary: None,
            strength: 1.0,
            related_ids: vec![],
            source_ids: vec![],
            embedding: None,
            confidentiality: Confidentiality::default(),
            author_user_id: None,
            author_tenant_id: None,
        }
    }

    /// Wraps a fresh in-memory store with `embedder`, returning both the raw
    /// store and the wrapper so tests can inspect either side.
    fn wrap<E: EmbeddingProvider + 'static>(embedder: E) -> (Arc<dyn Memory>, EmbeddingMemory) {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let wrapped = EmbeddingMemory::new(Arc::clone(&store), Arc::new(embedder));
        (store, wrapped)
    }

    /// Recalls up to ten entries from `memory` with an otherwise default query.
    async fn recall_up_to_ten(memory: &dyn Memory) -> Vec<MemoryEntry> {
        memory
            .recall(
                &test_scope(),
                MemoryQuery {
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap()
    }

    /// Deterministic provider: the first three bytes of each text, scaled by
    /// 1/255, become a three-component vector.
    struct FakeEmbedding;

    impl EmbeddingProvider for FakeEmbedding {
        fn embed(
            &self,
            texts: &[&str],
        ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
            let component = |bytes: &[u8], at: usize| -> f32 {
                bytes.get(at).copied().unwrap_or(0) as f32 / 255.0
            };
            let results: Vec<Vec<f32>> = texts
                .iter()
                .map(|text| {
                    let bytes = text.as_bytes();
                    vec![
                        component(bytes, 0),
                        component(bytes, 1),
                        component(bytes, 2),
                    ]
                })
                .collect();
            Box::pin(async move { Ok(results) })
        }

        fn dimension(&self) -> usize {
            3
        }
    }

    /// Provider that records whether `embed` was ever invoked.
    struct TrackingEmbedding {
        called: Arc<AtomicBool>,
    }

    impl EmbeddingProvider for TrackingEmbedding {
        fn embed(
            &self,
            _texts: &[&str],
        ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
            self.called.store(true, Ordering::SeqCst);
            Box::pin(async { Ok(vec![vec![0.5, 0.5, 0.5]]) })
        }

        fn dimension(&self) -> usize {
            3
        }
    }

    /// Creates a tracking provider together with the flag it sets.
    fn tracking() -> (TrackingEmbedding, Arc<AtomicBool>) {
        let called = Arc::new(AtomicBool::new(false));
        let provider = TrackingEmbedding {
            called: Arc::clone(&called),
        };
        (provider, called)
    }

    #[test]
    fn noop_embedding_returns_empty() {
        let noop = NoopEmbedding;
        assert_eq!(noop.dimension(), 0);

        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let result = rt.block_on(noop.embed(&["hello", "world"])).unwrap();

        assert_eq!(result.len(), 2);
        assert!(result[0].is_empty());
        assert!(result[1].is_empty());
    }

    #[test]
    fn embedding_provider_is_object_safe() {
        // Compile-time check only: the signature must accept a trait object.
        fn _accepts_dyn(_p: &dyn EmbeddingProvider) {}
    }

    #[test]
    fn embedding_memory_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<EmbeddingMemory>();
    }

    #[tokio::test]
    async fn noop_embedding_skips_embedding_on_store() {
        let (store, em) = wrap(NoopEmbedding);

        em.store(&test_scope(), make_entry("m1", "test content"))
            .await
            .unwrap();

        let results = recall_up_to_ten(&*store).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].embedding.is_none());
    }

    #[tokio::test]
    async fn embedding_memory_generates_embedding_on_store() {
        let (store, em) = wrap(FakeEmbedding);

        em.store(&test_scope(), make_entry("m1", "hello"))
            .await
            .unwrap();

        let results = recall_up_to_ten(&*store).await;
        assert_eq!(results.len(), 1);
        let emb = results[0]
            .embedding
            .as_ref()
            .expect("embedding should be set");
        assert_eq!(emb.len(), 3);
    }

    #[tokio::test]
    async fn embedding_memory_preserves_existing_embedding() {
        let (store, em) = wrap(FakeEmbedding);

        let mut entry = make_entry("m1", "hello");
        entry.embedding = Some(vec![9.0, 8.0, 7.0]);
        em.store(&test_scope(), entry).await.unwrap();

        let results = recall_up_to_ten(&*store).await;
        let emb = results[0].embedding.as_ref().unwrap();
        assert!((emb[0] - 9.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn embedding_memory_delegates_recall() {
        let (store, em) = wrap(NoopEmbedding);

        store
            .store(&test_scope(), make_entry("m1", "test"))
            .await
            .unwrap();

        let results = recall_up_to_ten(&em).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "m1");
    }

    #[tokio::test]
    async fn embedding_memory_generates_query_embedding_on_recall() {
        let (provider, called) = tracking();
        let (store, em) = wrap(provider);

        store
            .store(&test_scope(), make_entry("m1", "hello world"))
            .await
            .unwrap();

        let _results = em
            .recall(
                &test_scope(),
                MemoryQuery {
                    text: Some("hello".into()),
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert!(
            called.load(Ordering::SeqCst),
            "embed() should have been called for query text"
        );
    }

    #[tokio::test]
    async fn embedding_memory_skips_query_embedding_without_text() {
        let (provider, called) = tracking();
        let (store, em) = wrap(provider);

        store
            .store(&test_scope(), make_entry("m1", "hello world"))
            .await
            .unwrap();

        let _results = recall_up_to_ten(&em).await;

        assert!(
            !called.load(Ordering::SeqCst),
            "embed() should NOT be called when no text query"
        );
    }

    #[tokio::test]
    async fn embedding_memory_delegates_forget() {
        let (store, em) = wrap(NoopEmbedding);

        store
            .store(&test_scope(), make_entry("m1", "test"))
            .await
            .unwrap();
        let removed = em.forget(&test_scope(), "m1").await.unwrap();
        assert!(removed);

        let results = recall_up_to_ten(&*store).await;
        assert!(results.is_empty());
    }
}