1use async_channel::{bounded, Receiver, Sender};
7use chrono::Utc;
8use parking_lot::Mutex;
9use rusqlite::{params, Connection};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::time::interval;
13
14use super::{create_embedder, Embedder};
15use crate::error::{EngramError, Result};
16use crate::types::{EmbeddingConfig, EmbeddingState, EmbeddingStatus, MemoryId};
17
18#[derive(Debug)]
20pub struct EmbeddingRequest {
21 pub memory_id: MemoryId,
22 pub content: String,
23}
24
25pub struct EmbeddingQueue {
27 sender: Sender<EmbeddingRequest>,
28 receiver: Receiver<EmbeddingRequest>,
29 batch_size: usize,
30}
31
32impl EmbeddingQueue {
33 pub fn new(batch_size: usize) -> Self {
35 let (sender, receiver) = bounded(10000); Self {
37 sender,
38 receiver,
39 batch_size,
40 }
41 }
42
43 pub async fn queue(&self, memory_id: MemoryId, content: String) -> Result<()> {
45 self.sender
46 .send(EmbeddingRequest { memory_id, content })
47 .await
48 .map_err(|e| EngramError::Embedding(format!("Queue send error: {}", e)))?;
49 Ok(())
50 }
51
52 pub fn queue_blocking(&self, memory_id: MemoryId, content: String) -> Result<()> {
54 self.sender
55 .send_blocking(EmbeddingRequest { memory_id, content })
56 .map_err(|e| EngramError::Embedding(format!("Queue send error: {}", e)))?;
57 Ok(())
58 }
59
60 pub fn len(&self) -> usize {
62 self.receiver.len()
63 }
64
65 pub fn is_empty(&self) -> bool {
67 self.receiver.is_empty()
68 }
69
70 pub fn receiver(&self) -> Receiver<EmbeddingRequest> {
72 self.receiver.clone()
73 }
74}
75
76impl Clone for EmbeddingQueue {
77 fn clone(&self) -> Self {
78 Self {
79 sender: self.sender.clone(),
80 receiver: self.receiver.clone(),
81 batch_size: self.batch_size,
82 }
83 }
84}
85
86pub struct EmbeddingWorker {
88 embedder: Arc<dyn Embedder>,
89 queue: EmbeddingQueue,
90 conn: Arc<Mutex<Connection>>,
91 batch_size: usize,
92 batch_timeout: Duration,
93}
94
95impl EmbeddingWorker {
96 pub fn new(
98 config: EmbeddingConfig,
99 queue: EmbeddingQueue,
100 conn: Arc<Mutex<Connection>>,
101 ) -> Result<Self> {
102 let embedder = create_embedder(&config)?;
103 let batch_size = config.batch_size;
104
105 Ok(Self {
106 embedder,
107 queue,
108 conn,
109 batch_size,
110 batch_timeout: Duration::from_secs(5),
111 })
112 }
113
114 pub async fn run(&self) {
116 let receiver = self.queue.receiver();
117 let mut batch: Vec<EmbeddingRequest> = Vec::with_capacity(self.batch_size);
118 let mut batch_timer = interval(self.batch_timeout);
119
120 loop {
121 tokio::select! {
122 Ok(request) = receiver.recv() => {
124 batch.push(request);
125
126 if batch.len() >= self.batch_size {
128 self.process_batch(&mut batch).await;
129 }
130 }
131
132 _ = batch_timer.tick() => {
134 if !batch.is_empty() {
135 self.process_batch(&mut batch).await;
136 }
137 }
138 }
139 }
140 }
141
142 async fn process_batch(&self, batch: &mut Vec<EmbeddingRequest>) {
144 if batch.is_empty() {
145 return;
146 }
147
148 let memory_ids: Vec<MemoryId> = batch.iter().map(|r| r.memory_id).collect();
149 let contents: Vec<&str> = batch.iter().map(|r| r.content.as_str()).collect();
150
151 {
153 let conn = self.conn.lock();
154 let now = Utc::now().to_rfc3339();
155 for &id in &memory_ids {
156 let _ = conn.execute(
157 "UPDATE embedding_queue SET status = 'processing', started_at = ? WHERE memory_id = ?",
158 params![now, id],
159 );
160 }
161 }
162
163 match self.embedder.embed_batch(&contents) {
165 Ok(embeddings) => {
166 let conn = self.conn.lock();
167 let now = Utc::now().to_rfc3339();
168 let model = self.embedder.model_name();
169 let dimensions = self.embedder.dimensions();
170
171 for (id, embedding) in memory_ids.iter().zip(embeddings.iter()) {
172 let embedding_bytes: Vec<u8> =
174 embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
175
176 let _ = conn.execute(
178 "INSERT OR REPLACE INTO embeddings (memory_id, embedding, model, dimensions, created_at)
179 VALUES (?, ?, ?, ?, ?)",
180 params![id, embedding_bytes, model, dimensions, now],
181 );
182
183 let _ = conn.execute(
185 "UPDATE memories SET has_embedding = 1 WHERE id = ?",
186 params![id],
187 );
188
189 let _ = conn.execute(
191 "UPDATE embedding_queue SET status = 'complete', completed_at = ? WHERE memory_id = ?",
192 params![now, id],
193 );
194 }
195
196 tracing::info!("Processed {} embeddings", memory_ids.len());
197 }
198 Err(e) => {
199 let conn = self.conn.lock();
200 let error_time = Utc::now().to_rfc3339();
201 let error_msg = e.to_string();
202 let _ = error_time; for &id in &memory_ids {
205 let _ = conn.execute(
206 "UPDATE embedding_queue SET status = 'failed', error = ?, retry_count = retry_count + 1 WHERE memory_id = ?",
207 params![error_msg, id],
208 );
209 }
210
211 tracing::error!("Embedding batch failed: {}", e);
212 }
213 }
214
215 batch.clear();
216 }
217}
218
219pub fn get_embedding_status(conn: &Connection, memory_id: MemoryId) -> Result<EmbeddingStatus> {
221 let row = conn.query_row(
222 "SELECT status, queued_at, completed_at, error FROM embedding_queue WHERE memory_id = ?",
223 params![memory_id],
224 |row| {
225 let status_str: String = row.get(0)?;
226 let queued_at: Option<String> = row.get(1)?;
227 let completed_at: Option<String> = row.get(2)?;
228 let error: Option<String> = row.get(3)?;
229
230 let status = match status_str.as_str() {
231 "pending" => EmbeddingState::Pending,
232 "processing" => EmbeddingState::Processing,
233 "complete" => EmbeddingState::Complete,
234 "failed" => EmbeddingState::Failed,
235 _ => EmbeddingState::Pending,
236 };
237
238 Ok(EmbeddingStatus {
239 memory_id,
240 status,
241 queued_at: queued_at.and_then(|s| {
242 chrono::DateTime::parse_from_rfc3339(&s)
243 .map(|dt| dt.with_timezone(&Utc))
244 .ok()
245 }),
246 completed_at: completed_at.and_then(|s| {
247 chrono::DateTime::parse_from_rfc3339(&s)
248 .map(|dt| dt.with_timezone(&Utc))
249 .ok()
250 }),
251 error,
252 })
253 },
254 );
255
256 match row {
257 Ok(status) => Ok(status),
258 Err(rusqlite::Error::QueryReturnedNoRows) => {
259 let has_embedding: bool = conn
261 .query_row(
262 "SELECT has_embedding FROM memories WHERE id = ?",
263 params![memory_id],
264 |row| row.get(0),
265 )
266 .unwrap_or(false);
267
268 Ok(EmbeddingStatus {
269 memory_id,
270 status: if has_embedding {
271 EmbeddingState::Complete
272 } else {
273 EmbeddingState::Pending
274 },
275 queued_at: None,
276 completed_at: None,
277 error: None,
278 })
279 }
280 Err(e) => Err(EngramError::Database(e)),
281 }
282}
283
284pub fn get_embedding(conn: &Connection, memory_id: MemoryId) -> Result<Option<Vec<f32>>> {
286 let row = conn.query_row(
287 "SELECT embedding, dimensions FROM embeddings WHERE memory_id = ?",
288 params![memory_id],
289 |row| {
290 let bytes: Vec<u8> = row.get(0)?;
291 let dimensions: usize = row.get(1)?;
292 Ok((bytes, dimensions))
293 },
294 );
295
296 match row {
297 Ok((bytes, dimensions)) => {
298 let expected_len = dimensions.checked_mul(4).ok_or_else(|| {
299 EngramError::InvalidInput("Embedding dimensions too large".to_string())
300 })?;
301 if bytes.len() != expected_len {
302 return Err(EngramError::InvalidInput(format!(
303 "Embedding byte length {} does not match dimensions {}",
304 bytes.len(),
305 dimensions
306 )));
307 }
308
309 let mut embedding = Vec::with_capacity(dimensions);
311 for chunk in bytes.chunks_exact(4) {
312 let arr: [u8; 4] = chunk.try_into().unwrap();
313 embedding.push(f32::from_le_bytes(arr));
314 }
315 Ok(Some(embedding))
316 }
317 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
318 Err(e) => Err(EngramError::Database(e)),
319 }
320}
321
322#[allow(dead_code)]
324pub fn retry_failed_embeddings(conn: &Connection, max_retries: i32) -> Result<Vec<MemoryId>> {
325 let mut stmt = conn.prepare(
326 "SELECT eq.memory_id, m.content FROM embedding_queue eq
327 JOIN memories m ON eq.memory_id = m.id
328 WHERE eq.status = 'failed' AND eq.retry_count < ?",
329 )?;
330
331 let failed: Vec<(MemoryId, String)> = stmt
332 .query_map([max_retries], |row| Ok((row.get(0)?, row.get(1)?)))?
333 .filter_map(|r| r.ok())
334 .collect();
335
336 let ids: Vec<MemoryId> = failed.iter().map(|(id, _)| *id).collect();
337
338 for &id in &ids {
340 conn.execute(
341 "UPDATE embedding_queue SET status = 'pending', error = NULL WHERE memory_id = ?",
342 params![id],
343 )?;
344 }
345
346 Ok(ids)
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::queries::create_memory;
    use crate::storage::Storage;
    use crate::types::{CreateMemoryInput, MemoryType};

    /// Builds a minimal note input with embedding deferred, so no embedder is involved.
    fn note_input(content: &str) -> CreateMemoryInput {
        CreateMemoryInput {
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec![],
            metadata: std::collections::HashMap::new(),
            importance: None,
            scope: Default::default(),
            workspace: None,
            tier: Default::default(),
            defer_embedding: true,
            ttl_seconds: None,
            dedup_mode: Default::default(),
            dedup_threshold: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            summary_of_id: None,
        }
    }

    #[tokio::test]
    async fn test_embedding_queue() {
        let queue = EmbeddingQueue::new(10);

        queue.queue(1, "Hello world".to_string()).await.unwrap();
        queue.queue(2, "Test content".to_string()).await.unwrap();

        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_get_embedding_length_mismatch() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let memory = create_memory(conn, &note_input("Test embedding"))?;

                // 4 bytes can hold only one f32, but the row claims 2 dimensions.
                conn.execute(
                    "INSERT INTO embeddings (memory_id, embedding, model, dimensions, created_at)
                     VALUES (?, ?, ?, ?, datetime('now'))",
                    params![memory.id, vec![0u8; 4], "test", 2],
                )?;

                match get_embedding(conn, memory.id) {
                    Err(EngramError::InvalidInput(_)) => Ok(()),
                    Err(e) => Err(e),
                    Ok(_) => Err(EngramError::Internal(
                        "Expected embedding length mismatch error".to_string(),
                    )),
                }
            })
            .unwrap();
    }

    #[test]
    fn test_get_embedding_round_trip() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let memory = create_memory(conn, &note_input("Round trip"))?;

                let values = [1.0f32, -2.5];
                let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
                conn.execute(
                    "INSERT INTO embeddings (memory_id, embedding, model, dimensions, created_at)
                     VALUES (?, ?, ?, ?, datetime('now'))",
                    params![memory.id, bytes, "test", 2],
                )?;

                assert_eq!(get_embedding(conn, memory.id)?, Some(values.to_vec()));
                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_get_embedding_missing_returns_none() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                assert_eq!(get_embedding(conn, 424_242)?, None);
                Ok(())
            })
            .unwrap();
    }
}