1use anyhow::{Context, Result};
2use std::sync::Arc;
3
4use brainwires_storage::CachedEmbeddingProvider;
5use brainwires_storage::databases::{
6 FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
7};
8
/// Name of the backend table holding message rows; every `MessageStore` operation targets it.
const TABLE_NAME: &str = "messages";
10
/// One conversation message, as stored in (and read back from) the `messages` table.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MessageMetadata {
    /// Unique message identifier; `MessageStore::get` / `delete` match on this column.
    pub message_id: String,
    /// Owning conversation; per-conversation queries, searches and deletes filter on it.
    pub conversation_id: String,
    /// Author role. NOTE(review): presumably values like "user"/"assistant" — not validated here.
    pub role: String,
    /// Message text; this is the string that gets embedded into the `vector` column.
    pub content: String,
    /// Token count, if known (stored as a nullable Int32 column).
    pub token_count: Option<i32>,
    /// Associated model identifier, if any (nullable text column).
    pub model_id: Option<String>,
    /// Attached images packed into one string. NOTE(review): presumably a serialized list — confirm the encoding with writers.
    pub images: Option<String>,
    /// Creation timestamp. NOTE(review): units are not fixed here; likely Unix seconds like `expires_at` — confirm.
    pub created_at: i64,
    /// Expiry as Unix seconds; `MessageStore::delete_expired` removes rows whose value is <= now.
    pub expires_at: Option<i64>,
}
37
38fn table_schema(embedding_dim: usize) -> Vec<FieldDef> {
40 vec![
41 FieldDef::required("vector", FieldType::Vector(embedding_dim)),
42 FieldDef::required("message_id", FieldType::Utf8),
43 FieldDef::required("conversation_id", FieldType::Utf8),
44 FieldDef::required("role", FieldType::Utf8),
45 FieldDef::required("content", FieldType::Utf8),
46 FieldDef::optional("token_count", FieldType::Int32),
47 FieldDef::optional("model_id", FieldType::Utf8),
48 FieldDef::optional("images", FieldType::Utf8),
49 FieldDef::required("created_at", FieldType::Int64),
50 FieldDef::optional("expires_at", FieldType::Int64),
51 ]
52}
53
54pub fn messages_schema(embedding_dim: usize) -> Arc<arrow_schema::Schema> {
56 use arrow_schema::{DataType, Field, Schema};
57
58 Arc::new(Schema::new(vec![
59 Field::new(
60 "vector",
61 DataType::FixedSizeList(
62 Arc::new(Field::new("item", DataType::Float32, true)),
63 embedding_dim as i32,
64 ),
65 false,
66 ),
67 Field::new("message_id", DataType::Utf8, false),
68 Field::new("conversation_id", DataType::Utf8, false),
69 Field::new("role", DataType::Utf8, false),
70 Field::new("content", DataType::Utf8, false),
71 Field::new("token_count", DataType::Int32, true),
72 Field::new("model_id", DataType::Utf8, true),
73 Field::new("images", DataType::Utf8, true),
74 Field::new("created_at", DataType::Int64, false),
75 Field::new("expires_at", DataType::Int64, true),
76 ]))
77}
78
79fn to_record(m: &MessageMetadata, embedding: Vec<f32>) -> Record {
80 vec![
81 ("vector".into(), FieldValue::Vector(embedding)),
82 (
83 "message_id".into(),
84 FieldValue::Utf8(Some(m.message_id.clone())),
85 ),
86 (
87 "conversation_id".into(),
88 FieldValue::Utf8(Some(m.conversation_id.clone())),
89 ),
90 ("role".into(), FieldValue::Utf8(Some(m.role.clone()))),
91 ("content".into(), FieldValue::Utf8(Some(m.content.clone()))),
92 ("token_count".into(), FieldValue::Int32(m.token_count)),
93 ("model_id".into(), FieldValue::Utf8(m.model_id.clone())),
94 ("images".into(), FieldValue::Utf8(m.images.clone())),
95 ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
96 ("expires_at".into(), FieldValue::Int64(m.expires_at)),
97 ]
98}
99
100fn from_record(r: &Record) -> Result<MessageMetadata> {
101 Ok(MessageMetadata {
102 message_id: record_get(r, "message_id")
103 .and_then(|v| v.as_str())
104 .context("missing message_id")?
105 .to_string(),
106 conversation_id: record_get(r, "conversation_id")
107 .and_then(|v| v.as_str())
108 .context("missing conversation_id")?
109 .to_string(),
110 role: record_get(r, "role")
111 .and_then(|v| v.as_str())
112 .context("missing role")?
113 .to_string(),
114 content: record_get(r, "content")
115 .and_then(|v| v.as_str())
116 .context("missing content")?
117 .to_string(),
118 token_count: record_get(r, "token_count").and_then(|v| v.as_i32()),
119 model_id: record_get(r, "model_id")
120 .and_then(|v| v.as_str())
121 .map(String::from),
122 images: record_get(r, "images")
123 .and_then(|v| v.as_str())
124 .map(String::from),
125 created_at: record_get(r, "created_at")
126 .and_then(|v| v.as_i64())
127 .context("missing created_at")?,
128 expires_at: record_get(r, "expires_at").and_then(|v| v.as_i64()),
129 })
130}
131
132pub struct MessageStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
134 backend: Arc<B>,
135 embeddings: Arc<CachedEmbeddingProvider>,
136}
137
138impl<B: StorageBackend> MessageStore<B> {
139 pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
141 Self {
142 backend,
143 embeddings,
144 }
145 }
146
147 pub async fn ensure_table(&self) -> Result<()> {
149 self.backend
150 .ensure_table(TABLE_NAME, &table_schema(self.embeddings.dimension()))
151 .await
152 }
153
154 pub async fn add(&self, message: MessageMetadata) -> Result<()> {
156 let embedding = self.embeddings.embed(&message.content)?;
158 let record = to_record(&message, embedding);
159
160 self.backend
161 .insert(TABLE_NAME, vec![record])
162 .await
163 .context("Failed to add message")?;
164
165 Ok(())
166 }
167
168 pub async fn add_batch(&self, messages: Vec<MessageMetadata>) -> Result<()> {
170 if messages.is_empty() {
171 return Ok(());
172 }
173
174 let contents: Vec<String> = messages.iter().map(|m| m.content.clone()).collect();
176 let embeddings = self.embeddings.embed_batch(&contents)?;
177
178 let records: Vec<Record> = messages
179 .iter()
180 .zip(embeddings.into_iter())
181 .map(|(m, emb)| to_record(m, emb))
182 .collect();
183
184 self.backend
185 .insert(TABLE_NAME, records)
186 .await
187 .context("Failed to add messages")?;
188
189 Ok(())
190 }
191
192 pub async fn get(&self, message_id: &str) -> Result<Option<MessageMetadata>> {
194 let filter = Filter::Eq(
195 "message_id".into(),
196 FieldValue::Utf8(Some(message_id.to_string())),
197 );
198 let records = self
199 .backend
200 .query(TABLE_NAME, Some(&filter), Some(1))
201 .await?;
202
203 match records.first() {
204 Some(r) => Ok(Some(from_record(r)?)),
205 None => Ok(None),
206 }
207 }
208
209 pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<MessageMetadata>> {
211 let filter = Filter::Eq(
212 "conversation_id".into(),
213 FieldValue::Utf8(Some(conversation_id.to_string())),
214 );
215 let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
216
217 records.iter().map(from_record).collect()
218 }
219
220 pub async fn search(
222 &self,
223 query: &str,
224 limit: usize,
225 min_score: f32,
226 ) -> Result<Vec<(MessageMetadata, f32)>> {
227 self.search_with_filter(query, limit, min_score, None).await
228 }
229
230 pub async fn search_conversation(
232 &self,
233 conversation_id: &str,
234 query: &str,
235 limit: usize,
236 min_score: f32,
237 ) -> Result<Vec<(MessageMetadata, f32)>> {
238 let filter = Filter::Eq(
239 "conversation_id".into(),
240 FieldValue::Utf8(Some(conversation_id.to_string())),
241 );
242 self.search_with_filter(query, limit, min_score, Some(filter))
243 .await
244 }
245
246 async fn search_with_filter(
248 &self,
249 query: &str,
250 limit: usize,
251 min_score: f32,
252 filter: Option<Filter>,
253 ) -> Result<Vec<(MessageMetadata, f32)>> {
254 let query_embedding = self.embeddings.embed_cached(query)?;
256
257 let scored = self
258 .backend
259 .vector_search(
260 TABLE_NAME,
261 "vector",
262 query_embedding,
263 limit,
264 filter.as_ref(),
265 )
266 .await?;
267
268 let mut messages_with_scores = Vec::new();
269
270 for sr in scored {
271 if sr.score >= min_score {
272 let message = from_record(&sr.record)?;
273 messages_with_scores.push((message, sr.score));
274 }
275 }
276
277 Ok(messages_with_scores)
278 }
279
280 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
282 let filter = Filter::Eq(
283 "conversation_id".into(),
284 FieldValue::Utf8(Some(conversation_id.to_string())),
285 );
286 self.backend.delete(TABLE_NAME, &filter).await?;
287 Ok(())
288 }
289
290 pub async fn delete(&self, message_id: &str) -> Result<()> {
292 let filter = Filter::Eq(
293 "message_id".into(),
294 FieldValue::Utf8(Some(message_id.to_string())),
295 );
296 self.backend.delete(TABLE_NAME, &filter).await?;
297 Ok(())
298 }
299
300 pub async fn delete_expired(&self) -> Result<usize> {
308 use chrono::Utc;
309 let now = Utc::now().timestamp();
310
311 let filter = Filter::And(vec![
312 Filter::NotNull("expires_at".into()),
313 Filter::Lte("expires_at".into(), FieldValue::Int64(Some(now))),
314 ]);
315
316 let count = self.backend.count(TABLE_NAME, Some(&filter)).await?;
317 if count > 0 {
318 self.backend.delete(TABLE_NAME, &filter).await?;
319 }
320 Ok(count)
321 }
322}