Skip to main content

heartbit_core/memory/
embedding.rs

1//! Embedding providers for semantic memory retrieval.
2
3#![allow(missing_docs)]
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use serde::Deserialize;
9
10use crate::auth::TenantScope;
11use crate::error::Error;
12
13use super::{Memory, MemoryEntry};
14
15/// Trait for generating text embeddings.
16#[allow(clippy::type_complexity)]
17pub trait EmbeddingProvider: Send + Sync {
18    fn embed(
19        &self,
20        texts: &[&str],
21    ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>>;
22
23    fn dimension(&self) -> usize;
24}
25
26/// No-op embedding provider — returns empty results.
27/// Used when no embedding API is configured (graceful degradation).
28pub struct NoopEmbedding;
29
30impl EmbeddingProvider for NoopEmbedding {
31    fn embed(
32        &self,
33        texts: &[&str],
34    ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
35        let len = texts.len();
36        Box::pin(async move { Ok(vec![vec![]; len]) })
37    }
38
39    fn dimension(&self) -> usize {
40        0
41    }
42}
43
44/// OpenAI-compatible embedding provider.
45///
46/// Calls `POST /v1/embeddings` with the configured model.
47/// Works with OpenAI API and compatible endpoints.
48pub struct OpenAiEmbedding {
49    client: reqwest::Client,
50    api_key: String,
51    model: String,
52    base_url: String,
53    dimension: usize,
54}
55
56impl OpenAiEmbedding {
57    /// Create an `OpenAiEmbedding` provider.
58    ///
59    /// SECURITY (F-MEM-4): the HTTP client is hardened with
60    /// `redirect::Policy::none()`, `https_only(true)`,
61    /// `connect_timeout(10s)`, `timeout(60s)`, and `.no_proxy()`. Without
62    /// these, a slow-loris embedding endpoint wedges every `memory_store`,
63    /// and a redirect to a non-HTTPS host would leak the Bearer API key.
64    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
65        let model = model.into();
66        let dimension = match model.as_str() {
67            "text-embedding-3-small" => 1536,
68            "text-embedding-3-large" => 3072,
69            "text-embedding-ada-002" => 1536,
70            _ => 1536, // default
71        };
72        let client = reqwest::Client::builder()
73            .redirect(reqwest::redirect::Policy::none())
74            .https_only(true)
75            .no_proxy()
76            .connect_timeout(std::time::Duration::from_secs(10))
77            .timeout(std::time::Duration::from_secs(60))
78            .build()
79            .expect("failed to build hardened HTTPS client for OpenAiEmbedding");
80        Self {
81            client,
82            api_key: api_key.into(),
83            model,
84            base_url: "https://api.openai.com".into(),
85            dimension,
86        }
87    }
88
89    /// Override the base URL.
90    ///
91    /// SECURITY (F-MEM-4): the URL must be HTTPS — `https_only(true)` is
92    /// already set on the client, so a plaintext URL here will fail at
93    /// request time. For local non-secret endpoints (Ollama embeddings,
94    /// vLLM), build a separate provider with a custom client.
95    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
96        self.base_url = base_url.into();
97        self
98    }
99
100    pub fn with_dimension(mut self, dimension: usize) -> Self {
101        self.dimension = dimension;
102        self
103    }
104}
105
106#[derive(Deserialize)]
107struct EmbeddingResponse {
108    data: Vec<EmbeddingData>,
109}
110
111#[derive(Deserialize)]
112struct EmbeddingData {
113    embedding: Vec<f32>,
114}
115
116impl EmbeddingProvider for OpenAiEmbedding {
117    fn embed(
118        &self,
119        texts: &[&str],
120    ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
121        let input: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
122        Box::pin(async move {
123            if input.is_empty() {
124                return Ok(vec![]);
125            }
126
127            let body = serde_json::json!({
128                "model": self.model,
129                "input": input,
130            });
131
132            let resp = self
133                .client
134                .post(format!("{}/v1/embeddings", self.base_url))
135                .header("Authorization", format!("Bearer {}", self.api_key))
136                .header("Content-Type", "application/json")
137                .json(&body)
138                .send()
139                .await
140                .map_err(|e| Error::Memory(format!("embedding request failed: {e}")))?;
141
142            if !resp.status().is_success() {
143                let status = resp.status();
144                let text = resp.text().await.unwrap_or_else(|_| "unknown error".into());
145                return Err(Error::Memory(format!(
146                    "embedding API returned {status}: {text}"
147                )));
148            }
149
150            let response: EmbeddingResponse = resp
151                .json()
152                .await
153                .map_err(|e| Error::Memory(format!("failed to parse embedding response: {e}")))?;
154
155            Ok(response.data.into_iter().map(|d| d.embedding).collect())
156        })
157    }
158
159    fn dimension(&self) -> usize {
160        self.dimension
161    }
162}
163
164/// Decorator that generates embeddings on store and passes through to inner Memory.
165///
166/// When storing a `MemoryEntry` without an embedding, this wrapper generates
167/// one via the configured `EmbeddingProvider` before delegating to the inner store.
168/// All other operations pass through unchanged.
169pub struct EmbeddingMemory {
170    inner: Arc<dyn Memory>,
171    embedder: Arc<dyn EmbeddingProvider>,
172}
173
174impl EmbeddingMemory {
175    pub fn new(inner: Arc<dyn Memory>, embedder: Arc<dyn EmbeddingProvider>) -> Self {
176        Self { inner, embedder }
177    }
178}
179
180impl Memory for EmbeddingMemory {
181    fn store(
182        &self,
183        scope: &TenantScope,
184        entry: MemoryEntry,
185    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
186        let scope = scope.clone();
187        Box::pin(async move {
188            let mut entry = entry;
189            // Only generate embedding if not already present and embedder is real (dimension > 0)
190            if entry.embedding.is_none() && self.embedder.dimension() > 0 {
191                match self.embedder.embed(&[&entry.content]).await {
192                    Ok(mut embeddings) if !embeddings.is_empty() => {
193                        let emb = embeddings.swap_remove(0);
194                        if !emb.is_empty() {
195                            entry.embedding = Some(emb);
196                        }
197                    }
198                    Ok(_) => {} // empty result, skip
199                    Err(e) => {
200                        // Log but don't fail — embedding is optional
201                        tracing::warn!("failed to generate embedding for memory {}: {e}", entry.id);
202                    }
203                }
204            }
205            self.inner.store(&scope, entry).await
206        })
207    }
208
209    fn recall(
210        &self,
211        scope: &TenantScope,
212        query: super::MemoryQuery,
213    ) -> Pin<Box<dyn Future<Output = Result<Vec<MemoryEntry>, Error>> + Send + '_>> {
214        let scope = scope.clone();
215        Box::pin(async move {
216            let mut query = query;
217            // Generate query embedding for hybrid retrieval when text is present
218            // and no embedding was already provided.
219            if query.query_embedding.is_none()
220                && query.text.is_some()
221                && self.embedder.dimension() > 0
222            {
223                let text = query.text.as_deref().unwrap_or_default();
224                match self.embedder.embed(&[text]).await {
225                    Ok(mut embeddings) if !embeddings.is_empty() => {
226                        let emb = embeddings.swap_remove(0);
227                        if !emb.is_empty() {
228                            query.query_embedding = Some(emb);
229                        }
230                    }
231                    Ok(_) => {}
232                    Err(e) => {
233                        // Log but don't fail — fall back to BM25-only
234                        tracing::warn!("failed to generate query embedding: {e}");
235                    }
236                }
237            }
238            self.inner.recall(&scope, query).await
239        })
240    }
241
242    fn update(
243        &self,
244        scope: &TenantScope,
245        id: &str,
246        content: String,
247    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
248        let scope = scope.clone();
249        let id = id.to_string();
250        Box::pin(async move { self.inner.update(&scope, &id, content).await })
251    }
252
253    fn forget(
254        &self,
255        scope: &TenantScope,
256        id: &str,
257    ) -> Pin<Box<dyn Future<Output = Result<bool, Error>> + Send + '_>> {
258        let scope = scope.clone();
259        let id = id.to_string();
260        Box::pin(async move { self.inner.forget(&scope, &id).await })
261    }
262
263    fn add_link(
264        &self,
265        scope: &TenantScope,
266        id: &str,
267        related_id: &str,
268    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
269        let scope = scope.clone();
270        let id = id.to_string();
271        let related_id = related_id.to_string();
272        Box::pin(async move { self.inner.add_link(&scope, &id, &related_id).await })
273    }
274
275    fn prune(
276        &self,
277        scope: &TenantScope,
278        min_strength: f64,
279        min_age: chrono::Duration,
280        agent_prefix: Option<&str>,
281    ) -> Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + '_>> {
282        let scope = scope.clone();
283        let agent_prefix = agent_prefix.map(String::from);
284        Box::pin(async move {
285            self.inner
286                .prune(&scope, min_strength, min_age, agent_prefix.as_deref())
287                .await
288        })
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryStore;
    use crate::memory::{Confidentiality, MemoryEntry, MemoryQuery, MemoryType};
    use chrono::Utc;
    use std::sync::atomic::{AtomicBool, Ordering};

    fn test_scope() -> TenantScope {
        TenantScope::default()
    }

    /// Build an entry with the given id and content and no embedding.
    fn make_entry(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.into(),
            agent: "test".into(),
            content: content.into(),
            category: "fact".into(),
            tags: vec![],
            created_at: Utc::now(),
            last_accessed: Utc::now(),
            access_count: 0,
            importance: 5,
            memory_type: MemoryType::default(),
            keywords: vec![],
            summary: None,
            strength: 1.0,
            related_ids: vec![],
            source_ids: vec![],
            embedding: None,
            confidentiality: Confidentiality::default(),
            author_user_id: None,
            author_tenant_id: None,
        }
    }

    /// Recall up to 10 entries from `memory` with an otherwise default query.
    async fn recall_all(memory: &dyn Memory) -> Vec<MemoryEntry> {
        let found = memory
            .recall(
                &test_scope(),
                MemoryQuery {
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        found
    }

    /// Fake embedding provider for testing that returns deterministic vectors.
    struct FakeEmbedding;

    impl EmbeddingProvider for FakeEmbedding {
        fn embed(
            &self,
            texts: &[&str],
        ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
            let results: Vec<Vec<f32>> = texts
                .iter()
                .map(|t| {
                    // Simple deterministic embedding: first 3 bytes scaled into [0, 1]
                    let bytes = t.as_bytes();
                    vec![
                        bytes.first().copied().unwrap_or(0) as f32 / 255.0,
                        bytes.get(1).copied().unwrap_or(0) as f32 / 255.0,
                        bytes.get(2).copied().unwrap_or(0) as f32 / 255.0,
                    ]
                })
                .collect();
            Box::pin(async move { Ok(results) })
        }

        fn dimension(&self) -> usize {
            3
        }
    }

    /// Embedding provider that records whether `embed()` was called.
    struct TrackingEmbedding {
        called: Arc<AtomicBool>,
    }

    impl EmbeddingProvider for TrackingEmbedding {
        fn embed(
            &self,
            _texts: &[&str],
        ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
            self.called.store(true, Ordering::SeqCst);
            Box::pin(async { Ok(vec![vec![0.5, 0.5, 0.5]]) })
        }

        fn dimension(&self) -> usize {
            3
        }
    }

    /// Build a `TrackingEmbedding` together with the flag it sets.
    fn tracking_embedder() -> (Arc<dyn EmbeddingProvider>, Arc<AtomicBool>) {
        let called = Arc::new(AtomicBool::new(false));
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(TrackingEmbedding {
            called: Arc::clone(&called),
        });
        (embedder, called)
    }

    /// Embedding provider whose `embed()` always fails.
    struct FailingEmbedding;

    impl EmbeddingProvider for FailingEmbedding {
        fn embed(
            &self,
            _texts: &[&str],
        ) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, Error>> + Send + '_>> {
            Box::pin(async { Err(Error::Memory("simulated embedding failure".into())) })
        }

        fn dimension(&self) -> usize {
            3
        }
    }

    #[tokio::test]
    async fn noop_embedding_returns_empty() {
        let noop = NoopEmbedding;
        assert_eq!(noop.dimension(), 0);
        let result = noop.embed(&["hello", "world"]).await.unwrap();
        assert_eq!(result.len(), 2);
        assert!(result[0].is_empty());
        assert!(result[1].is_empty());
    }

    #[test]
    fn embedding_provider_is_object_safe() {
        fn _accepts_dyn(_p: &dyn EmbeddingProvider) {}
    }

    #[test]
    fn embedding_memory_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<EmbeddingMemory>();
    }

    #[tokio::test]
    async fn noop_embedding_skips_embedding_on_store() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(NoopEmbedding);
        let em = EmbeddingMemory::new(store.clone(), embedder);

        em.store(&test_scope(), make_entry("m1", "test content"))
            .await
            .unwrap();

        let results = recall_all(store.as_ref()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].embedding.is_none());
    }

    #[tokio::test]
    async fn embedding_memory_generates_embedding_on_store() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(FakeEmbedding);
        let em = EmbeddingMemory::new(store.clone(), embedder);

        em.store(&test_scope(), make_entry("m1", "hello"))
            .await
            .unwrap();

        let results = recall_all(store.as_ref()).await;
        assert_eq!(results.len(), 1);
        let emb = results[0]
            .embedding
            .as_ref()
            .expect("embedding should be set");
        assert_eq!(emb.len(), 3);
    }

    #[tokio::test]
    async fn embedding_memory_preserves_existing_embedding() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(FakeEmbedding);
        let em = EmbeddingMemory::new(store.clone(), embedder);

        let mut entry = make_entry("m1", "hello");
        entry.embedding = Some(vec![9.0, 8.0, 7.0]);
        em.store(&test_scope(), entry).await.unwrap();

        let results = recall_all(store.as_ref()).await;
        assert_eq!(results.len(), 1);
        let emb = results[0].embedding.as_ref().unwrap();
        // Should keep original, not overwrite with FakeEmbedding output
        assert!((emb[0] - 9.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn embedding_memory_stores_entry_when_embedding_fails() {
        // A provider error must not prevent the entry from being stored.
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(FailingEmbedding);
        let em = EmbeddingMemory::new(store.clone(), embedder);

        em.store(&test_scope(), make_entry("m1", "hello"))
            .await
            .unwrap();

        let results = recall_all(store.as_ref()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].embedding.is_none());
    }

    #[tokio::test]
    async fn embedding_memory_delegates_recall() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(NoopEmbedding);
        let em = EmbeddingMemory::new(store.clone(), embedder);

        store
            .store(&test_scope(), make_entry("m1", "test"))
            .await
            .unwrap();

        let results = recall_all(&em).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "m1");
    }

    #[tokio::test]
    async fn embedding_memory_generates_query_embedding_on_recall() {
        // When EmbeddingMemory wraps a store and query has text,
        // it should generate a query embedding for hybrid retrieval.
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let (embedder, called) = tracking_embedder();
        let em = EmbeddingMemory::new(store.clone(), embedder);

        store
            .store(&test_scope(), make_entry("m1", "hello world"))
            .await
            .unwrap();

        // Recall with text query should trigger embedding generation
        let _results = em
            .recall(
                &test_scope(),
                MemoryQuery {
                    text: Some("hello".into()),
                    limit: 10,
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert!(
            called.load(Ordering::SeqCst),
            "embed() should have been called for query text"
        );
    }

    #[tokio::test]
    async fn embedding_memory_skips_query_embedding_without_text() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let (embedder, called) = tracking_embedder();
        let em = EmbeddingMemory::new(store.clone(), embedder);

        store
            .store(&test_scope(), make_entry("m1", "hello world"))
            .await
            .unwrap();

        // Recall WITHOUT text query should NOT generate embedding
        let _results = recall_all(&em).await;

        assert!(
            !called.load(Ordering::SeqCst),
            "embed() should NOT be called when no text query"
        );
    }

    #[tokio::test]
    async fn embedding_memory_delegates_forget() {
        let store: Arc<dyn Memory> = Arc::new(InMemoryStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(NoopEmbedding);
        let em = EmbeddingMemory::new(store.clone(), embedder);

        store
            .store(&test_scope(), make_entry("m1", "test"))
            .await
            .unwrap();
        let removed = em.forget(&test_scope(), "m1").await.unwrap();
        assert!(removed);

        let results = recall_all(store.as_ref()).await;
        assert!(results.is_empty());
    }
}