Skip to main content

engram/embedding/
queue.rs

1//! Async embedding queue with batch processing (RML-873)
2//!
3//! Embeddings are computed in the background to avoid blocking writes.
4//! The queue supports batching for efficient API usage.
5
6use async_channel::{bounded, Receiver, Sender};
7use chrono::Utc;
8use parking_lot::Mutex;
9use rusqlite::{params, Connection};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::time::interval;
13
14use super::{create_embedder, Embedder};
15use crate::error::{EngramError, Result};
16use crate::types::{EmbeddingConfig, EmbeddingState, EmbeddingStatus, MemoryId};
17
18/// Message for the embedding queue
19#[derive(Debug)]
20pub struct EmbeddingRequest {
21    pub memory_id: MemoryId,
22    pub content: String,
23}
24
25/// Embedding queue for async processing
26pub struct EmbeddingQueue {
27    sender: Sender<EmbeddingRequest>,
28    receiver: Receiver<EmbeddingRequest>,
29    batch_size: usize,
30}
31
32impl EmbeddingQueue {
33    /// Create a new embedding queue
34    pub fn new(batch_size: usize) -> Self {
35        let (sender, receiver) = bounded(10000); // Buffer up to 10k requests
36        Self {
37            sender,
38            receiver,
39            batch_size,
40        }
41    }
42
43    /// Queue a memory for embedding
44    pub async fn queue(&self, memory_id: MemoryId, content: String) -> Result<()> {
45        self.sender
46            .send(EmbeddingRequest { memory_id, content })
47            .await
48            .map_err(|e| EngramError::Embedding(format!("Queue send error: {}", e)))?;
49        Ok(())
50    }
51
52    /// Queue a memory (blocking version for sync contexts)
53    pub fn queue_blocking(&self, memory_id: MemoryId, content: String) -> Result<()> {
54        self.sender
55            .send_blocking(EmbeddingRequest { memory_id, content })
56            .map_err(|e| EngramError::Embedding(format!("Queue send error: {}", e)))?;
57        Ok(())
58    }
59
60    /// Get queue length
61    pub fn len(&self) -> usize {
62        self.receiver.len()
63    }
64
65    /// Check if queue is empty
66    pub fn is_empty(&self) -> bool {
67        self.receiver.is_empty()
68    }
69
70    /// Get receiver for worker
71    pub fn receiver(&self) -> Receiver<EmbeddingRequest> {
72        self.receiver.clone()
73    }
74}
75
76impl Clone for EmbeddingQueue {
77    fn clone(&self) -> Self {
78        Self {
79            sender: self.sender.clone(),
80            receiver: self.receiver.clone(),
81            batch_size: self.batch_size,
82        }
83    }
84}
85
86/// Background worker for processing embeddings
87pub struct EmbeddingWorker {
88    embedder: Arc<dyn Embedder>,
89    queue: EmbeddingQueue,
90    conn: Arc<Mutex<Connection>>,
91    batch_size: usize,
92    batch_timeout: Duration,
93}
94
95impl EmbeddingWorker {
96    /// Create a new embedding worker
97    pub fn new(
98        config: EmbeddingConfig,
99        queue: EmbeddingQueue,
100        conn: Arc<Mutex<Connection>>,
101    ) -> Result<Self> {
102        let embedder = create_embedder(&config)?;
103        let batch_size = config.batch_size;
104
105        Ok(Self {
106            embedder,
107            queue,
108            conn,
109            batch_size,
110            batch_timeout: Duration::from_secs(5),
111        })
112    }
113
114    /// Run the worker (call in a spawned task)
115    pub async fn run(&self) {
116        let receiver = self.queue.receiver();
117        let mut batch: Vec<EmbeddingRequest> = Vec::with_capacity(self.batch_size);
118        let mut batch_timer = interval(self.batch_timeout);
119
120        loop {
121            tokio::select! {
122                // Receive new request
123                Ok(request) = receiver.recv() => {
124                    batch.push(request);
125
126                    // Process if batch is full
127                    if batch.len() >= self.batch_size {
128                        self.process_batch(&mut batch).await;
129                    }
130                }
131
132                // Process on timeout even if batch isn't full
133                _ = batch_timer.tick() => {
134                    if !batch.is_empty() {
135                        self.process_batch(&mut batch).await;
136                    }
137                }
138            }
139        }
140    }
141
142    /// Process a batch of embedding requests
143    async fn process_batch(&self, batch: &mut Vec<EmbeddingRequest>) {
144        if batch.is_empty() {
145            return;
146        }
147
148        let memory_ids: Vec<MemoryId> = batch.iter().map(|r| r.memory_id).collect();
149        let contents: Vec<&str> = batch.iter().map(|r| r.content.as_str()).collect();
150
151        // Mark as processing
152        {
153            let conn = self.conn.lock();
154            let now = Utc::now().to_rfc3339();
155            for &id in &memory_ids {
156                let _ = conn.execute(
157                    "UPDATE embedding_queue SET status = 'processing', started_at = ? WHERE memory_id = ?",
158                    params![now, id],
159                );
160            }
161        }
162
163        // Generate embeddings
164        match self.embedder.embed_batch(&contents) {
165            Ok(embeddings) => {
166                let conn = self.conn.lock();
167                let now = Utc::now().to_rfc3339();
168                let model = self.embedder.model_name();
169                let dimensions = self.embedder.dimensions();
170
171                for (id, embedding) in memory_ids.iter().zip(embeddings.iter()) {
172                    // Serialize embedding to bytes
173                    let embedding_bytes: Vec<u8> =
174                        embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
175
176                    // Store embedding
177                    let _ = conn.execute(
178                        "INSERT OR REPLACE INTO embeddings (memory_id, embedding, model, dimensions, created_at)
179                         VALUES (?, ?, ?, ?, ?)",
180                        params![id, embedding_bytes, model, dimensions, now],
181                    );
182
183                    // Update memory
184                    let _ = conn.execute(
185                        "UPDATE memories SET has_embedding = 1 WHERE id = ?",
186                        params![id],
187                    );
188
189                    // Mark as complete
190                    let _ = conn.execute(
191                        "UPDATE embedding_queue SET status = 'complete', completed_at = ? WHERE memory_id = ?",
192                        params![now, id],
193                    );
194                }
195
196                tracing::info!("Processed {} embeddings", memory_ids.len());
197            }
198            Err(e) => {
199                let conn = self.conn.lock();
200                let error_time = Utc::now().to_rfc3339();
201                let error_msg = e.to_string();
202                let _ = error_time; // suppress unused warning
203
204                for &id in &memory_ids {
205                    let _ = conn.execute(
206                        "UPDATE embedding_queue SET status = 'failed', error = ?, retry_count = retry_count + 1 WHERE memory_id = ?",
207                        params![error_msg, id],
208                    );
209                }
210
211                tracing::error!("Embedding batch failed: {}", e);
212            }
213        }
214
215        batch.clear();
216    }
217}
218
219/// Get embedding status for a memory
220pub fn get_embedding_status(conn: &Connection, memory_id: MemoryId) -> Result<EmbeddingStatus> {
221    let row = conn.query_row(
222        "SELECT status, queued_at, completed_at, error FROM embedding_queue WHERE memory_id = ?",
223        params![memory_id],
224        |row| {
225            let status_str: String = row.get(0)?;
226            let queued_at: Option<String> = row.get(1)?;
227            let completed_at: Option<String> = row.get(2)?;
228            let error: Option<String> = row.get(3)?;
229
230            let status = match status_str.as_str() {
231                "pending" => EmbeddingState::Pending,
232                "processing" => EmbeddingState::Processing,
233                "complete" => EmbeddingState::Complete,
234                "failed" => EmbeddingState::Failed,
235                _ => EmbeddingState::Pending,
236            };
237
238            Ok(EmbeddingStatus {
239                memory_id,
240                status,
241                queued_at: queued_at.and_then(|s| {
242                    chrono::DateTime::parse_from_rfc3339(&s)
243                        .map(|dt| dt.with_timezone(&Utc))
244                        .ok()
245                }),
246                completed_at: completed_at.and_then(|s| {
247                    chrono::DateTime::parse_from_rfc3339(&s)
248                        .map(|dt| dt.with_timezone(&Utc))
249                        .ok()
250                }),
251                error,
252            })
253        },
254    );
255
256    match row {
257        Ok(status) => Ok(status),
258        Err(rusqlite::Error::QueryReturnedNoRows) => {
259            // Check if memory has embedding
260            let has_embedding: bool = conn
261                .query_row(
262                    "SELECT has_embedding FROM memories WHERE id = ?",
263                    params![memory_id],
264                    |row| row.get(0),
265                )
266                .unwrap_or(false);
267
268            Ok(EmbeddingStatus {
269                memory_id,
270                status: if has_embedding {
271                    EmbeddingState::Complete
272                } else {
273                    EmbeddingState::Pending
274                },
275                queued_at: None,
276                completed_at: None,
277                error: None,
278            })
279        }
280        Err(e) => Err(EngramError::Database(e)),
281    }
282}
283
284/// Get embedding for a memory
285pub fn get_embedding(conn: &Connection, memory_id: MemoryId) -> Result<Option<Vec<f32>>> {
286    let row = conn.query_row(
287        "SELECT embedding, dimensions FROM embeddings WHERE memory_id = ?",
288        params![memory_id],
289        |row| {
290            let bytes: Vec<u8> = row.get(0)?;
291            let dimensions: usize = row.get(1)?;
292            Ok((bytes, dimensions))
293        },
294    );
295
296    match row {
297        Ok((bytes, dimensions)) => {
298            let expected_len = dimensions.checked_mul(4).ok_or_else(|| {
299                EngramError::InvalidInput("Embedding dimensions too large".to_string())
300            })?;
301            if bytes.len() != expected_len {
302                return Err(EngramError::InvalidInput(format!(
303                    "Embedding byte length {} does not match dimensions {}",
304                    bytes.len(),
305                    dimensions
306                )));
307            }
308
309            // Deserialize from bytes
310            let mut embedding = Vec::with_capacity(dimensions);
311            for chunk in bytes.chunks_exact(4) {
312                let arr: [u8; 4] = chunk.try_into().unwrap();
313                embedding.push(f32::from_le_bytes(arr));
314            }
315            Ok(Some(embedding))
316        }
317        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
318        Err(e) => Err(EngramError::Database(e)),
319    }
320}
321
322/// Retry failed embeddings
323#[allow(dead_code)]
324pub fn retry_failed_embeddings(conn: &Connection, max_retries: i32) -> Result<Vec<MemoryId>> {
325    let mut stmt = conn.prepare(
326        "SELECT eq.memory_id, m.content FROM embedding_queue eq
327         JOIN memories m ON eq.memory_id = m.id
328         WHERE eq.status = 'failed' AND eq.retry_count < ?",
329    )?;
330
331    let failed: Vec<(MemoryId, String)> = stmt
332        .query_map([max_retries], |row| Ok((row.get(0)?, row.get(1)?)))?
333        .filter_map(|r| r.ok())
334        .collect();
335
336    let ids: Vec<MemoryId> = failed.iter().map(|(id, _)| *id).collect();
337
338    // Reset status to pending
339    for &id in &ids {
340        conn.execute(
341            "UPDATE embedding_queue SET status = 'pending', error = NULL WHERE memory_id = ?",
342            params![id],
343        )?;
344    }
345
346    Ok(ids)
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::queries::create_memory;
    use crate::storage::Storage;
    use crate::types::{CreateMemoryInput, MemoryType};

    /// Input for a plain `Note` memory with `defer_embedding: true` and every
    /// other option left at its default.
    fn note_input(content: &str) -> CreateMemoryInput {
        CreateMemoryInput {
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec![],
            metadata: std::collections::HashMap::new(),
            importance: None,
            scope: Default::default(),
            workspace: None,
            tier: Default::default(),
            defer_embedding: true,
            ttl_seconds: None,
            dedup_mode: Default::default(),
            dedup_threshold: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            summary_of_id: None,
        }
    }

    #[tokio::test]
    async fn test_embedding_queue() {
        let queue = EmbeddingQueue::new(10);
        assert!(queue.is_empty());

        queue.queue(1, "Hello world".to_string()).await.unwrap();
        queue.queue(2, "Test content".to_string()).await.unwrap();

        assert_eq!(queue.len(), 2);
        assert!(!queue.is_empty());
    }

    #[test]
    fn test_get_embedding_roundtrip() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let memory = create_memory(conn, &note_input("Roundtrip embedding"))?;

                // Values chosen to be exactly representable so equality is exact.
                let values = vec![1.0f32, -2.5, 0.125];
                let bytes: Vec<u8> = values.iter().flat_map(|f| f.to_le_bytes()).collect();

                conn.execute(
                    "INSERT INTO embeddings (memory_id, embedding, model, dimensions, created_at)
                     VALUES (?, ?, ?, ?, datetime('now'))",
                    params![memory.id, bytes, "test", 3],
                )?;

                assert_eq!(get_embedding(conn, memory.id)?, Some(values));
                Ok::<(), EngramError>(())
            })
            .unwrap();
    }

    #[test]
    fn test_get_embedding_missing() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let memory = create_memory(conn, &note_input("No embedding yet"))?;
                assert_eq!(get_embedding(conn, memory.id)?, None);
                Ok::<(), EngramError>(())
            })
            .unwrap();
    }

    #[test]
    fn test_get_embedding_length_mismatch() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let memory = create_memory(conn, &note_input("Test embedding"))?;

                // Insert embedding with incorrect byte length (dimensions=2 => expected 8 bytes)
                conn.execute(
                    "INSERT INTO embeddings (memory_id, embedding, model, dimensions, created_at)
                     VALUES (?, ?, ?, ?, datetime('now'))",
                    params![memory.id, vec![0u8; 4], "test", 2],
                )?;

                match get_embedding(conn, memory.id) {
                    Err(EngramError::InvalidInput(_)) => Ok(()),
                    Err(e) => Err(e),
                    Ok(_) => Err(EngramError::Internal(
                        "Expected embedding length mismatch error".to_string(),
                    )),
                }
            })
            .unwrap();
    }
}