Skip to main content

brainwires_storage/stores/
message_store.rs

1use anyhow::{Context, Result};
2use std::sync::Arc;
3
4use crate::databases::{
5    FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
6};
7use crate::embeddings::EmbeddingProvider;
8
/// Name of the backend table that every message operation in this file
/// reads from and writes to.
const TABLE_NAME: &str = "messages";
10
/// Metadata for a single stored message.
///
/// This is the row shape persisted in the messages table, minus the embedding
/// vector, which [`MessageStore::add`] and [`MessageStore::add_batch`] derive
/// from `content` at insert time.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MessageMetadata {
    /// Unique message identifier.
    pub message_id: String,
    /// Conversation this message belongs to.
    pub conversation_id: String,
    /// Message role (e.g., "user", "assistant").
    pub role: String,
    /// Message text content. This is the text that gets embedded for
    /// semantic search.
    pub content: String,
    /// Token count estimate, if one is available.
    pub token_count: Option<i32>,
    /// Model that generated this message, if any.
    pub model_id: Option<String>,
    /// Image references, stored as a JSON array serialized into a string.
    pub images: Option<String>,
    /// Creation timestamp (Unix seconds).
    pub created_at: i64,
    /// Optional Unix timestamp after which this entry should be evicted.
    ///
    /// `None` means no expiry (the entry persists indefinitely).  Use
    /// [`MessageStore::delete_expired`] to perform bulk eviction, or call
    /// [`TieredMemory::evict_expired`](crate::TieredMemory::evict_expired) for tier-aware cleanup.
    pub expires_at: Option<i64>,
}
37
38/// Return the backend-agnostic table schema for messages.
39fn table_schema(embedding_dim: usize) -> Vec<FieldDef> {
40    vec![
41        FieldDef::required("vector", FieldType::Vector(embedding_dim)),
42        FieldDef::required("message_id", FieldType::Utf8),
43        FieldDef::required("conversation_id", FieldType::Utf8),
44        FieldDef::required("role", FieldType::Utf8),
45        FieldDef::required("content", FieldType::Utf8),
46        FieldDef::optional("token_count", FieldType::Int32),
47        FieldDef::optional("model_id", FieldType::Utf8),
48        FieldDef::optional("images", FieldType::Utf8),
49        FieldDef::required("created_at", FieldType::Int64),
50        FieldDef::optional("expires_at", FieldType::Int64),
51    ]
52}
53
/// Arrow `Schema` for the messages table (LanceDatabase compatibility).
///
/// Declares the same columns, in the same order, as the backend-agnostic
/// schema: a fixed-size `Float32` list of length `embedding_dim` followed by
/// the scalar message columns.
#[cfg(feature = "native")]
pub fn messages_schema(embedding_dim: usize) -> Arc<arrow_schema::Schema> {
    use arrow_schema::{DataType, Field, Schema};

    // Arrow sizes fixed-size lists with an `i32`.
    // NOTE(review): `as` wraps silently if `embedding_dim` exceeds `i32::MAX`.
    let list_len = embedding_dim as i32;
    let list_item = Arc::new(Field::new("item", DataType::Float32, true));
    let vector_type = DataType::FixedSizeList(list_item, list_len);

    // The third argument to `Field::new` is the nullable flag.
    let columns = vec![
        Field::new("vector", vector_type, false),
        Field::new("message_id", DataType::Utf8, false),
        Field::new("conversation_id", DataType::Utf8, false),
        Field::new("role", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("token_count", DataType::Int32, true),
        Field::new("model_id", DataType::Utf8, true),
        Field::new("images", DataType::Utf8, true),
        Field::new("created_at", DataType::Int64, false),
        Field::new("expires_at", DataType::Int64, true),
    ];

    Arc::new(Schema::new(columns))
}
79
80fn to_record(m: &MessageMetadata, embedding: Vec<f32>) -> Record {
81    vec![
82        ("vector".into(), FieldValue::Vector(embedding)),
83        (
84            "message_id".into(),
85            FieldValue::Utf8(Some(m.message_id.clone())),
86        ),
87        (
88            "conversation_id".into(),
89            FieldValue::Utf8(Some(m.conversation_id.clone())),
90        ),
91        ("role".into(), FieldValue::Utf8(Some(m.role.clone()))),
92        ("content".into(), FieldValue::Utf8(Some(m.content.clone()))),
93        ("token_count".into(), FieldValue::Int32(m.token_count)),
94        ("model_id".into(), FieldValue::Utf8(m.model_id.clone())),
95        ("images".into(), FieldValue::Utf8(m.images.clone())),
96        ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
97        ("expires_at".into(), FieldValue::Int64(m.expires_at)),
98    ]
99}
100
101fn from_record(r: &Record) -> Result<MessageMetadata> {
102    Ok(MessageMetadata {
103        message_id: record_get(r, "message_id")
104            .and_then(|v| v.as_str())
105            .context("missing message_id")?
106            .to_string(),
107        conversation_id: record_get(r, "conversation_id")
108            .and_then(|v| v.as_str())
109            .context("missing conversation_id")?
110            .to_string(),
111        role: record_get(r, "role")
112            .and_then(|v| v.as_str())
113            .context("missing role")?
114            .to_string(),
115        content: record_get(r, "content")
116            .and_then(|v| v.as_str())
117            .context("missing content")?
118            .to_string(),
119        token_count: record_get(r, "token_count").and_then(|v| v.as_i32()),
120        model_id: record_get(r, "model_id")
121            .and_then(|v| v.as_str())
122            .map(String::from),
123        images: record_get(r, "images")
124            .and_then(|v| v.as_str())
125            .map(String::from),
126        created_at: record_get(r, "created_at")
127            .and_then(|v| v.as_i64())
128            .context("missing created_at")?,
129        expires_at: record_get(r, "expires_at").and_then(|v| v.as_i64()),
130    })
131}
132
/// Store for managing messages with semantic search.
///
/// Pairs a storage backend, which persists and queries message records, with
/// an embedding provider, which turns message content and search queries into
/// vectors. The backend type defaults to `LanceDatabase`.
pub struct MessageStore<B: StorageBackend = crate::databases::lance::LanceDatabase> {
    /// Backend that stores, queries, and deletes message records.
    backend: Arc<B>,
    /// Provider used to embed message content and search queries.
    embeddings: Arc<EmbeddingProvider>,
}
138
139impl<B: StorageBackend> MessageStore<B> {
140    /// Create a new message store
141    pub fn new(backend: Arc<B>, embeddings: Arc<EmbeddingProvider>) -> Self {
142        Self {
143            backend,
144            embeddings,
145        }
146    }
147
148    /// Ensure the underlying table exists.
149    pub async fn ensure_table(&self) -> Result<()> {
150        self.backend
151            .ensure_table(TABLE_NAME, &table_schema(self.embeddings.dimension()))
152            .await
153    }
154
155    /// Add a message to the store
156    pub async fn add(&self, message: MessageMetadata) -> Result<()> {
157        // Generate embedding for the content
158        let embedding = self.embeddings.embed(&message.content)?;
159        let record = to_record(&message, embedding);
160
161        self.backend
162            .insert(TABLE_NAME, vec![record])
163            .await
164            .context("Failed to add message")?;
165
166        Ok(())
167    }
168
169    /// Add multiple messages in batch
170    pub async fn add_batch(&self, messages: Vec<MessageMetadata>) -> Result<()> {
171        if messages.is_empty() {
172            return Ok(());
173        }
174
175        // Generate embeddings for all messages
176        let contents: Vec<String> = messages.iter().map(|m| m.content.clone()).collect();
177        let embeddings = self.embeddings.embed_batch(&contents)?;
178
179        let records: Vec<Record> = messages
180            .iter()
181            .zip(embeddings.into_iter())
182            .map(|(m, emb)| to_record(m, emb))
183            .collect();
184
185        self.backend
186            .insert(TABLE_NAME, records)
187            .await
188            .context("Failed to add messages")?;
189
190        Ok(())
191    }
192
193    /// Get a single message by ID
194    pub async fn get(&self, message_id: &str) -> Result<Option<MessageMetadata>> {
195        let filter = Filter::Eq(
196            "message_id".into(),
197            FieldValue::Utf8(Some(message_id.to_string())),
198        );
199        let records = self
200            .backend
201            .query(TABLE_NAME, Some(&filter), Some(1))
202            .await?;
203
204        match records.first() {
205            Some(r) => Ok(Some(from_record(r)?)),
206            None => Ok(None),
207        }
208    }
209
210    /// Get messages for a conversation
211    pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<MessageMetadata>> {
212        let filter = Filter::Eq(
213            "conversation_id".into(),
214            FieldValue::Utf8(Some(conversation_id.to_string())),
215        );
216        let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
217
218        records.iter().map(from_record).collect()
219    }
220
221    /// Search messages by semantic similarity
222    pub async fn search(
223        &self,
224        query: &str,
225        limit: usize,
226        min_score: f32,
227    ) -> Result<Vec<(MessageMetadata, f32)>> {
228        self.search_with_filter(query, limit, min_score, None).await
229    }
230
231    /// Search messages within a specific conversation by semantic similarity
232    pub async fn search_conversation(
233        &self,
234        conversation_id: &str,
235        query: &str,
236        limit: usize,
237        min_score: f32,
238    ) -> Result<Vec<(MessageMetadata, f32)>> {
239        let filter = Filter::Eq(
240            "conversation_id".into(),
241            FieldValue::Utf8(Some(conversation_id.to_string())),
242        );
243        self.search_with_filter(query, limit, min_score, Some(filter))
244            .await
245    }
246
247    /// Search messages with optional filter by semantic similarity
248    async fn search_with_filter(
249        &self,
250        query: &str,
251        limit: usize,
252        min_score: f32,
253        filter: Option<Filter>,
254    ) -> Result<Vec<(MessageMetadata, f32)>> {
255        // Generate query embedding (use cached version for repeated queries)
256        let query_embedding = self.embeddings.embed_cached(query)?;
257
258        let scored = self
259            .backend
260            .vector_search(
261                TABLE_NAME,
262                "vector",
263                query_embedding,
264                limit,
265                filter.as_ref(),
266            )
267            .await?;
268
269        let mut messages_with_scores = Vec::new();
270
271        for sr in scored {
272            if sr.score >= min_score {
273                let message = from_record(&sr.record)?;
274                messages_with_scores.push((message, sr.score));
275            }
276        }
277
278        Ok(messages_with_scores)
279    }
280
281    /// Delete all messages for a conversation
282    pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
283        let filter = Filter::Eq(
284            "conversation_id".into(),
285            FieldValue::Utf8(Some(conversation_id.to_string())),
286        );
287        self.backend.delete(TABLE_NAME, &filter).await?;
288        Ok(())
289    }
290
291    /// Delete a specific message
292    pub async fn delete(&self, message_id: &str) -> Result<()> {
293        let filter = Filter::Eq(
294            "message_id".into(),
295            FieldValue::Utf8(Some(message_id.to_string())),
296        );
297        self.backend.delete(TABLE_NAME, &filter).await?;
298        Ok(())
299    }
300
301    /// Delete all messages whose `expires_at` timestamp is in the past.
302    ///
303    /// Returns the number of rows deleted.  Rows with `expires_at = NULL`
304    /// (no TTL) are never touched.
305    ///
306    /// Call this at agent run completion or on a periodic background schedule
307    /// to enforce session-tier TTL policies.
308    pub async fn delete_expired(&self) -> Result<usize> {
309        use chrono::Utc;
310        let now = Utc::now().timestamp();
311
312        let filter = Filter::And(vec![
313            Filter::NotNull("expires_at".into()),
314            Filter::Lte("expires_at".into(), FieldValue::Int64(Some(now))),
315        ]);
316
317        let count = self.backend.count(TABLE_NAME, Some(&filter)).await?;
318        if count > 0 {
319            self.backend.delete(TABLE_NAME, &filter).await?;
320        }
321        Ok(count)
322    }
323}