Skip to main content

brainwires_stores/memory/
message_store.rs

1use anyhow::{Context, Result};
2use std::sync::Arc;
3
4use brainwires_storage::CachedEmbeddingProvider;
5use brainwires_storage::databases::{
6    FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
7};
8
/// Name of the backend table that holds message rows; every `MessageStore`
/// operation reads from or writes to this table.
const TABLE_NAME: &str = "messages";
10
/// Metadata for a single stored message.
///
/// Each field maps one-to-one onto a column of the `messages` table (see
/// `table_schema` and [`messages_schema`]); `Option` fields are stored in nullable
/// columns. The content embedding is not part of this struct — [`MessageStore`]
/// computes it from `content` when the message is added.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MessageMetadata {
    /// Unique message identifier.
    ///
    /// [`MessageStore::get`] and [`MessageStore::delete`] match rows on this value.
    /// The store itself does not enforce uniqueness on insert.
    pub message_id: String,
    /// Conversation this message belongs to; used by the per-conversation queries,
    /// searches and deletes on [`MessageStore`].
    pub conversation_id: String,
    /// Message role (e.g., "user", "assistant").
    pub role: String,
    /// Message text content. This is the text that is embedded for semantic search.
    pub content: String,
    /// Token count estimate.
    pub token_count: Option<i32>,
    /// Model that generated this message.
    pub model_id: Option<String>,
    /// Image references, serialized as a JSON array string.
    ///
    /// The store writes and reads this value back as an opaque string; it does not
    /// parse or validate the JSON. NOTE(review): element format is not visible here.
    pub images: Option<String>,
    /// Creation timestamp (Unix seconds).
    pub created_at: i64,
    /// Optional Unix timestamp after which this entry should be evicted.
    ///
    /// `None` means no expiry (the entry persists indefinitely). Use
    /// [`MessageStore::delete_expired`] to perform bulk eviction (it removes rows whose
    /// `expires_at` is at or before the current time), or call
    /// `TieredMemory::evict_expired` (in `brainwires-memory`) for tier-aware cleanup.
    pub expires_at: Option<i64>,
}
37
38/// Return the backend-agnostic table schema for messages.
39fn table_schema(embedding_dim: usize) -> Vec<FieldDef> {
40    vec![
41        FieldDef::required("vector", FieldType::Vector(embedding_dim)),
42        FieldDef::required("message_id", FieldType::Utf8),
43        FieldDef::required("conversation_id", FieldType::Utf8),
44        FieldDef::required("role", FieldType::Utf8),
45        FieldDef::required("content", FieldType::Utf8),
46        FieldDef::optional("token_count", FieldType::Int32),
47        FieldDef::optional("model_id", FieldType::Utf8),
48        FieldDef::optional("images", FieldType::Utf8),
49        FieldDef::required("created_at", FieldType::Int64),
50        FieldDef::optional("expires_at", FieldType::Int64),
51    ]
52}
53
54/// Arrow `Schema` for the messages table (LanceDatabase compatibility).
55pub fn messages_schema(embedding_dim: usize) -> Arc<arrow_schema::Schema> {
56    use arrow_schema::{DataType, Field, Schema};
57
58    Arc::new(Schema::new(vec![
59        Field::new(
60            "vector",
61            DataType::FixedSizeList(
62                Arc::new(Field::new("item", DataType::Float32, true)),
63                embedding_dim as i32,
64            ),
65            false,
66        ),
67        Field::new("message_id", DataType::Utf8, false),
68        Field::new("conversation_id", DataType::Utf8, false),
69        Field::new("role", DataType::Utf8, false),
70        Field::new("content", DataType::Utf8, false),
71        Field::new("token_count", DataType::Int32, true),
72        Field::new("model_id", DataType::Utf8, true),
73        Field::new("images", DataType::Utf8, true),
74        Field::new("created_at", DataType::Int64, false),
75        Field::new("expires_at", DataType::Int64, true),
76    ]))
77}
78
79fn to_record(m: &MessageMetadata, embedding: Vec<f32>) -> Record {
80    vec![
81        ("vector".into(), FieldValue::Vector(embedding)),
82        (
83            "message_id".into(),
84            FieldValue::Utf8(Some(m.message_id.clone())),
85        ),
86        (
87            "conversation_id".into(),
88            FieldValue::Utf8(Some(m.conversation_id.clone())),
89        ),
90        ("role".into(), FieldValue::Utf8(Some(m.role.clone()))),
91        ("content".into(), FieldValue::Utf8(Some(m.content.clone()))),
92        ("token_count".into(), FieldValue::Int32(m.token_count)),
93        ("model_id".into(), FieldValue::Utf8(m.model_id.clone())),
94        ("images".into(), FieldValue::Utf8(m.images.clone())),
95        ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
96        ("expires_at".into(), FieldValue::Int64(m.expires_at)),
97    ]
98}
99
100fn from_record(r: &Record) -> Result<MessageMetadata> {
101    Ok(MessageMetadata {
102        message_id: record_get(r, "message_id")
103            .and_then(|v| v.as_str())
104            .context("missing message_id")?
105            .to_string(),
106        conversation_id: record_get(r, "conversation_id")
107            .and_then(|v| v.as_str())
108            .context("missing conversation_id")?
109            .to_string(),
110        role: record_get(r, "role")
111            .and_then(|v| v.as_str())
112            .context("missing role")?
113            .to_string(),
114        content: record_get(r, "content")
115            .and_then(|v| v.as_str())
116            .context("missing content")?
117            .to_string(),
118        token_count: record_get(r, "token_count").and_then(|v| v.as_i32()),
119        model_id: record_get(r, "model_id")
120            .and_then(|v| v.as_str())
121            .map(String::from),
122        images: record_get(r, "images")
123            .and_then(|v| v.as_str())
124            .map(String::from),
125        created_at: record_get(r, "created_at")
126            .and_then(|v| v.as_i64())
127            .context("missing created_at")?,
128        expires_at: record_get(r, "expires_at").and_then(|v| v.as_i64()),
129    })
130}
131
/// Store for messages kept in a [`StorageBackend`] table, with semantic search over
/// message content.
///
/// Message content is embedded with the shared [`CachedEmbeddingProvider`] on insert.
/// Searches embed the query with the same provider and run a vector search over the
/// `vector` column. `B` defaults to the Lance backend.
pub struct MessageStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
    // Backend holding the `messages` table.
    backend: Arc<B>,
    // Embedding provider used for both stored content and search queries.
    embeddings: Arc<CachedEmbeddingProvider>,
}
137
138impl<B: StorageBackend> MessageStore<B> {
139    /// Create a new message store
140    pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
141        Self {
142            backend,
143            embeddings,
144        }
145    }
146
147    /// Ensure the underlying table exists.
148    pub async fn ensure_table(&self) -> Result<()> {
149        self.backend
150            .ensure_table(TABLE_NAME, &table_schema(self.embeddings.dimension()))
151            .await
152    }
153
154    /// Add a message to the store
155    pub async fn add(&self, message: MessageMetadata) -> Result<()> {
156        // Generate embedding for the content
157        let embedding = self.embeddings.embed(&message.content)?;
158        let record = to_record(&message, embedding);
159
160        self.backend
161            .insert(TABLE_NAME, vec![record])
162            .await
163            .context("Failed to add message")?;
164
165        Ok(())
166    }
167
168    /// Add multiple messages in batch
169    pub async fn add_batch(&self, messages: Vec<MessageMetadata>) -> Result<()> {
170        if messages.is_empty() {
171            return Ok(());
172        }
173
174        // Generate embeddings for all messages
175        let contents: Vec<String> = messages.iter().map(|m| m.content.clone()).collect();
176        let embeddings = self.embeddings.embed_batch(&contents)?;
177
178        let records: Vec<Record> = messages
179            .iter()
180            .zip(embeddings.into_iter())
181            .map(|(m, emb)| to_record(m, emb))
182            .collect();
183
184        self.backend
185            .insert(TABLE_NAME, records)
186            .await
187            .context("Failed to add messages")?;
188
189        Ok(())
190    }
191
192    /// Get a single message by ID
193    pub async fn get(&self, message_id: &str) -> Result<Option<MessageMetadata>> {
194        let filter = Filter::Eq(
195            "message_id".into(),
196            FieldValue::Utf8(Some(message_id.to_string())),
197        );
198        let records = self
199            .backend
200            .query(TABLE_NAME, Some(&filter), Some(1))
201            .await?;
202
203        match records.first() {
204            Some(r) => Ok(Some(from_record(r)?)),
205            None => Ok(None),
206        }
207    }
208
209    /// Get messages for a conversation
210    pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<MessageMetadata>> {
211        let filter = Filter::Eq(
212            "conversation_id".into(),
213            FieldValue::Utf8(Some(conversation_id.to_string())),
214        );
215        let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
216
217        records.iter().map(from_record).collect()
218    }
219
220    /// Search messages by semantic similarity
221    pub async fn search(
222        &self,
223        query: &str,
224        limit: usize,
225        min_score: f32,
226    ) -> Result<Vec<(MessageMetadata, f32)>> {
227        self.search_with_filter(query, limit, min_score, None).await
228    }
229
230    /// Search messages within a specific conversation by semantic similarity
231    pub async fn search_conversation(
232        &self,
233        conversation_id: &str,
234        query: &str,
235        limit: usize,
236        min_score: f32,
237    ) -> Result<Vec<(MessageMetadata, f32)>> {
238        let filter = Filter::Eq(
239            "conversation_id".into(),
240            FieldValue::Utf8(Some(conversation_id.to_string())),
241        );
242        self.search_with_filter(query, limit, min_score, Some(filter))
243            .await
244    }
245
246    /// Search messages with optional filter by semantic similarity
247    async fn search_with_filter(
248        &self,
249        query: &str,
250        limit: usize,
251        min_score: f32,
252        filter: Option<Filter>,
253    ) -> Result<Vec<(MessageMetadata, f32)>> {
254        // Generate query embedding (use cached version for repeated queries)
255        let query_embedding = self.embeddings.embed_cached(query)?;
256
257        let scored = self
258            .backend
259            .vector_search(
260                TABLE_NAME,
261                "vector",
262                query_embedding,
263                limit,
264                filter.as_ref(),
265            )
266            .await?;
267
268        let mut messages_with_scores = Vec::new();
269
270        for sr in scored {
271            if sr.score >= min_score {
272                let message = from_record(&sr.record)?;
273                messages_with_scores.push((message, sr.score));
274            }
275        }
276
277        Ok(messages_with_scores)
278    }
279
280    /// Delete all messages for a conversation
281    pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
282        let filter = Filter::Eq(
283            "conversation_id".into(),
284            FieldValue::Utf8(Some(conversation_id.to_string())),
285        );
286        self.backend.delete(TABLE_NAME, &filter).await?;
287        Ok(())
288    }
289
290    /// Delete a specific message
291    pub async fn delete(&self, message_id: &str) -> Result<()> {
292        let filter = Filter::Eq(
293            "message_id".into(),
294            FieldValue::Utf8(Some(message_id.to_string())),
295        );
296        self.backend.delete(TABLE_NAME, &filter).await?;
297        Ok(())
298    }
299
300    /// Delete all messages whose `expires_at` timestamp is in the past.
301    ///
302    /// Returns the number of rows deleted.  Rows with `expires_at = NULL`
303    /// (no TTL) are never touched.
304    ///
305    /// Call this at agent run completion or on a periodic background schedule
306    /// to enforce session-tier TTL policies.
307    pub async fn delete_expired(&self) -> Result<usize> {
308        use chrono::Utc;
309        let now = Utc::now().timestamp();
310
311        let filter = Filter::And(vec![
312            Filter::NotNull("expires_at".into()),
313            Filter::Lte("expires_at".into(), FieldValue::Int64(Some(now))),
314        ]);
315
316        let count = self.backend.count(TABLE_NAME, Some(&filter)).await?;
317        if count > 0 {
318            self.backend.delete(TABLE_NAME, &filter).await?;
319        }
320        Ok(count)
321    }
322}