1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{Context, Result};
5use arrow_array::{
6 Array, BooleanArray, FixedSizeListArray, Float32Array, Int64Array, RecordBatch,
7 RecordBatchIterator, StringArray,
8};
9use arrow_schema::{DataType, Field, Schema};
10use chrono::Utc;
11use futures::TryStreamExt;
12use lancedb::query::{ExecutableQuery, QueryBase};
13use tracing;
14
15use crate::knowledge::bks_pks::{
16 BehavioralKnowledgeCache, PersonalFactCollector, PersonalKnowledgeCache,
17};
18use brainwires_storage::{EmbeddingProvider, LanceClient};
19
20use crate::knowledge::fact_extractor;
21use crate::knowledge::thought::{Thought, ThoughtCategory, ThoughtSource};
22use crate::knowledge::types::*;
23
/// High-level client over the three brain stores: a LanceDB vector table of
/// thoughts, a personal-knowledge cache (PKS), and a behavioral-knowledge
/// cache (BKS).
pub struct BrainClient {
    /// Shared LanceDB connection holding the thoughts table.
    lance: Arc<LanceClient>,
    /// Embedding provider; shared so futures can hold a cheap clone.
    embeddings: Arc<EmbeddingProvider>,
    /// Personal facts store (key/value facts mined from captured text).
    pks_cache: PersonalKnowledgeCache,
    /// Behavioral truths store (context-pattern → rule entries).
    bks_cache: BehavioralKnowledgeCache,
    /// Extracts personal facts from free-form messages.
    fact_collector: PersonalFactCollector,
}

/// Name of the LanceDB table storing thought rows.
const THOUGHTS_TABLE: &str = "thoughts";
35impl BrainClient {
36 pub async fn new() -> Result<Self> {
42 let base = dirs::home_dir()
43 .context("Cannot determine home directory")?
44 .join(".brainwires");
45
46 std::fs::create_dir_all(&base)?;
47
48 let lance_path = base.join("brain");
49 let pks_path = base.join("pks.db");
50 let bks_path = base.join("bks.db");
51
52 Self::with_paths(
53 lance_path
54 .to_str()
55 .context("lance path is not valid UTF-8")?,
56 pks_path.to_str().context("pks path is not valid UTF-8")?,
57 bks_path.to_str().context("bks path is not valid UTF-8")?,
58 )
59 .await
60 }
61
62 pub async fn with_paths(lance_path: &str, pks_path: &str, bks_path: &str) -> Result<Self> {
64 let embeddings = Arc::new(EmbeddingProvider::new()?);
65 let lance = Arc::new(LanceClient::new(lance_path).await?);
66
67 Self::ensure_thoughts_table(&lance, embeddings.dimension()).await?;
69
70 let pks_cache = PersonalKnowledgeCache::new(pks_path, 1000)?;
71 let bks_cache = BehavioralKnowledgeCache::new(bks_path, 1000)?;
72 let fact_collector = PersonalFactCollector::default();
73
74 Ok(Self {
75 lance,
76 embeddings,
77 pks_cache,
78 bks_cache,
79 fact_collector,
80 })
81 }
82
83 async fn ensure_thoughts_table(lance: &LanceClient, dim: usize) -> Result<()> {
86 let conn = lance.connection();
87 let tables = conn.table_names().execute().await?;
88 if tables.contains(&THOUGHTS_TABLE.to_string()) {
89 return Ok(());
90 }
91
92 let schema = Self::thoughts_schema(dim);
93 let empty = RecordBatch::new_empty(schema.clone());
94 let batches = RecordBatchIterator::new(vec![Ok(empty)], schema);
95
96 conn.create_table(THOUGHTS_TABLE, Box::new(batches))
97 .execute()
98 .await
99 .context("Failed to create thoughts table")?;
100
101 tracing::info!("Created thoughts LanceDB table");
102 Ok(())
103 }
104
105 fn thoughts_schema(dim: usize) -> Arc<Schema> {
106 Arc::new(Schema::new(vec![
107 Field::new(
108 "vector",
109 DataType::FixedSizeList(
110 Arc::new(Field::new("item", DataType::Float32, true)),
111 dim as i32,
112 ),
113 false,
114 ),
115 Field::new("id", DataType::Utf8, false),
116 Field::new("content", DataType::Utf8, false),
117 Field::new("category", DataType::Utf8, false),
118 Field::new("tags", DataType::Utf8, false), Field::new("source", DataType::Utf8, false),
120 Field::new("importance", DataType::Float32, false),
121 Field::new("created_at", DataType::Int64, false),
122 Field::new("updated_at", DataType::Int64, false),
123 Field::new("deleted", DataType::Boolean, false),
124 ]))
125 }
126
    // Opens the thoughts table. Written as a hand-rolled `impl Future` over a
    // cloned connection handle (rather than an `async fn`) so the returned
    // future is explicitly `Send` and only captures the cloned connection.
    // NOTE(review): an `async fn` may be equivalent here — confirm against
    // the Send bounds required by callers before simplifying.
    fn thoughts_table(
        &self,
    ) -> impl std::future::Future<Output = Result<lancedb::Table>> + Send + '_ {
        let conn = self.lance.connection().clone();
        async move {
            conn.open_table(THOUGHTS_TABLE)
                .execute()
                .await
                .context("Failed to open thoughts table")
        }
    }
138
139 pub async fn capture_thought(
143 &mut self,
144 req: CaptureThoughtRequest,
145 ) -> Result<CaptureThoughtResponse> {
146 let category = match &req.category {
148 Some(c) => ThoughtCategory::parse(c),
149 None => fact_extractor::detect_category(&req.content),
150 };
151
152 let mut auto_tags = fact_extractor::extract_tags(&req.content);
153 if let Some(ref user_tags) = req.tags {
154 for t in user_tags {
155 let lower = t.to_lowercase();
156 if !auto_tags.contains(&lower) {
157 auto_tags.push(lower);
158 }
159 }
160 }
161
162 let source = req
163 .source
164 .as_deref()
165 .map(ThoughtSource::parse)
166 .unwrap_or(ThoughtSource::ManualCapture);
167
168 let thought = Thought::new(req.content.clone())
169 .with_category(category)
170 .with_tags(auto_tags.clone())
171 .with_source(source)
172 .with_importance(req.importance.unwrap_or(0.5));
173
174 let embedding = self.embeddings.embed(&thought.content)?;
176
177 let batch = self.thought_to_batch(&thought, &embedding)?;
179 let table = self.thoughts_table().await?;
180 let schema = batch.schema();
181 let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
182 table
183 .add(Box::new(batches))
184 .execute()
185 .await
186 .context("Failed to store thought")?;
187
188 let facts = self.fact_collector.process_message(&req.content);
190 let facts_count = facts.len();
191 for fact in facts {
192 if let Err(e) = self.pks_cache.upsert_fact(fact) {
193 tracing::warn!("Failed to upsert PKS fact: {}", e);
194 }
195 }
196
197 tracing::info!(
198 id = %thought.id,
199 category = %category,
200 facts = facts_count,
201 "Captured thought"
202 );
203
204 Ok(CaptureThoughtResponse {
205 id: thought.id,
206 category: category.to_string(),
207 tags: auto_tags,
208 importance: thought.importance,
209 facts_extracted: facts_count,
210 })
211 }
212
213 pub async fn search_memory(&self, req: SearchMemoryRequest) -> Result<SearchMemoryResponse> {
217 let search_thoughts = req
218 .sources
219 .as_ref()
220 .is_none_or(|s| s.iter().any(|x| x == "thoughts"));
221 let search_facts = req
222 .sources
223 .as_ref()
224 .is_none_or(|s| s.iter().any(|x| x == "facts"));
225
226 let mut results = Vec::new();
227
228 if search_thoughts {
230 let query_embedding = self.embeddings.embed_cached(&req.query)?;
231 let table = self.thoughts_table().await?;
232
233 let mut search = table
234 .vector_search(query_embedding)
235 .context("Failed to create vector search")?;
236
237 search = search.only_if("deleted = false");
239
240 if let Some(ref cat) = req.category {
242 let cat_str = ThoughtCategory::parse(cat).as_str().to_string();
243 search = search.only_if(format!("category = '{}'", cat_str));
244 }
245
246 let stream = search.limit(req.limit).execute().await?;
247 let batches: Vec<RecordBatch> = stream.try_collect().await?;
248
249 for batch in &batches {
250 let distances = batch
251 .column_by_name("_distance")
252 .context("Missing _distance column")?
253 .as_any()
254 .downcast_ref::<Float32Array>()
255 .context("Invalid _distance type")?;
256
257 let thoughts = self.batch_to_thoughts(std::slice::from_ref(batch))?;
258
259 for (i, thought) in thoughts.into_iter().enumerate() {
260 let distance = distances.value(i);
261 let score = 1.0 / (1.0 + distance);
262 if score >= req.min_score {
263 results.push(MemorySearchResult {
264 content: thought.content,
265 score,
266 source: "thoughts".into(),
267 thought_id: Some(thought.id),
268 category: Some(thought.category.to_string()),
269 tags: Some(thought.tags),
270 created_at: Some(thought.created_at),
271 });
272 }
273 }
274 }
275 }
276
277 if search_facts {
279 let pks_results = self.pks_cache.search_facts(&req.query);
280 for fact in pks_results {
281 let score = 0.7; if score >= req.min_score {
283 results.push(MemorySearchResult {
284 content: format!("{}: {}", fact.key, fact.value),
285 score,
286 source: "facts".into(),
287 thought_id: None,
288 category: Some(format!("{:?}", fact.category)),
289 tags: None,
290 created_at: Some(fact.created_at),
291 });
292 }
293 }
294 }
295
296 results.sort_by(|a, b| {
298 b.score
299 .partial_cmp(&a.score)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 });
302 results.truncate(req.limit);
303
304 let total = results.len();
305 Ok(SearchMemoryResponse { results, total })
306 }
307
308 pub async fn list_recent(&self, req: ListRecentRequest) -> Result<ListRecentResponse> {
312 let since_ts = match &req.since {
313 Some(s) => chrono::DateTime::parse_from_rfc3339(s)
314 .map(|dt| dt.timestamp())
315 .unwrap_or_else(|_| Utc::now().timestamp() - 7 * 86400),
316 None => Utc::now().timestamp() - 7 * 86400,
317 };
318
319 let table = self.thoughts_table().await?;
320
321 let mut filter = format!("deleted = false AND created_at >= {}", since_ts);
322 if let Some(ref cat) = req.category {
323 let cat_str = ThoughtCategory::parse(cat).as_str().to_string();
324 filter.push_str(&format!(" AND category = '{}'", cat_str));
325 }
326
327 let stream = table
328 .query()
329 .only_if(filter)
330 .limit(req.limit)
331 .execute()
332 .await?;
333
334 let batches: Vec<RecordBatch> = stream.try_collect().await?;
335 let mut thoughts = self.batch_to_thoughts(&batches)?;
336 thoughts.sort_by(|a, b| b.created_at.cmp(&a.created_at));
337 thoughts.truncate(req.limit);
338
339 let total = thoughts.len();
340 let summaries = thoughts
341 .into_iter()
342 .map(|t| ThoughtSummary {
343 id: t.id,
344 content: t.content,
345 category: t.category.to_string(),
346 tags: t.tags,
347 importance: t.importance,
348 created_at: t.created_at,
349 })
350 .collect();
351
352 Ok(ListRecentResponse {
353 thoughts: summaries,
354 total,
355 })
356 }
357
358 pub async fn get_thought(&self, id: &str) -> Result<Option<GetThoughtResponse>> {
362 let table = self.thoughts_table().await?;
363 let filter = format!("id = '{}' AND deleted = false", id);
364 let stream = table.query().only_if(filter).limit(1).execute().await?;
365 let batches: Vec<RecordBatch> = stream.try_collect().await?;
366 let thoughts = self.batch_to_thoughts(&batches)?;
367
368 Ok(thoughts.into_iter().next().map(|t| GetThoughtResponse {
369 id: t.id,
370 content: t.content,
371 category: t.category.to_string(),
372 tags: t.tags,
373 source: t.source.to_string(),
374 importance: t.importance,
375 created_at: t.created_at,
376 updated_at: t.updated_at,
377 }))
378 }
379
380 pub fn search_knowledge(&self, req: SearchKnowledgeRequest) -> Result<SearchKnowledgeResponse> {
384 let search_pks = req
385 .source
386 .as_ref()
387 .is_none_or(|s| s == "all" || s == "personal");
388 let search_bks = req
389 .source
390 .as_ref()
391 .is_none_or(|s| s == "all" || s == "behavioral");
392
393 let mut results = Vec::new();
394
395 if search_pks {
396 let pks_results = self.pks_cache.search_facts(&req.query);
397 for fact in pks_results {
398 if fact.confidence >= req.min_confidence {
399 results.push(KnowledgeResult {
400 source: "personal".into(),
401 category: format!("{:?}", fact.category),
402 key: fact.key.clone(),
403 value: fact.value.clone(),
404 confidence: fact.confidence,
405 context: fact.context.clone(),
406 });
407 }
408 }
409 }
410
411 if search_bks {
412 let bks_results = self
413 .bks_cache
414 .get_matching_truths_with_scores(&req.query, req.min_confidence, req.limit)
415 .unwrap_or_default();
416 for (truth, score) in bks_results {
417 results.push(KnowledgeResult {
418 source: "behavioral".into(),
419 category: format!("{:?}", truth.category),
420 key: truth.context_pattern.clone(),
421 value: truth.rule.clone(),
422 confidence: score,
423 context: Some(truth.rationale.clone()),
424 });
425 }
426 }
427
428 results.sort_by(|a, b| {
429 b.confidence
430 .partial_cmp(&a.confidence)
431 .unwrap_or(std::cmp::Ordering::Equal)
432 });
433 results.truncate(req.limit);
434
435 let total = results.len();
436 Ok(SearchKnowledgeResponse { results, total })
437 }
438
439 pub async fn memory_stats(&self) -> Result<MemoryStatsResponse> {
443 let now = Utc::now().timestamp();
444 let one_day = 86_400i64;
445
446 let table = self.thoughts_table().await?;
448 let stream = table.query().only_if("deleted = false").execute().await?;
449 let batches: Vec<RecordBatch> = stream.try_collect().await?;
450 let all_thoughts = self.batch_to_thoughts(&batches)?;
451
452 let total = all_thoughts.len();
453 let mut by_category: HashMap<String, usize> = HashMap::new();
454 let mut tag_counts: HashMap<String, usize> = HashMap::new();
455 let mut recent_24h = 0usize;
456 let mut recent_7d = 0usize;
457 let mut recent_30d = 0usize;
458
459 for t in &all_thoughts {
460 *by_category.entry(t.category.to_string()).or_insert(0) += 1;
461 for tag in &t.tags {
462 *tag_counts.entry(tag.clone()).or_insert(0) += 1;
463 }
464 let age = now - t.created_at;
465 if age <= one_day {
466 recent_24h += 1;
467 }
468 if age <= 7 * one_day {
469 recent_7d += 1;
470 }
471 if age <= 30 * one_day {
472 recent_30d += 1;
473 }
474 }
475
476 let mut top_tags: Vec<(String, usize)> = tag_counts.into_iter().collect();
477 top_tags.sort_by(|a, b| b.1.cmp(&a.1));
478 top_tags.truncate(10);
479
480 let pks_stats_raw = self.pks_cache.stats();
482 let pks_by_cat: HashMap<String, u32> = pks_stats_raw
483 .by_category
484 .into_iter()
485 .map(|(k, v)| (format!("{:?}", k), v))
486 .collect();
487
488 let bks_stats_raw = self.bks_cache.stats();
490 let bks_by_cat: HashMap<String, u32> = bks_stats_raw
491 .by_category
492 .into_iter()
493 .map(|(k, v)| (format!("{:?}", k), v))
494 .collect();
495
496 Ok(MemoryStatsResponse {
497 thoughts: ThoughtStats {
498 total,
499 by_category,
500 recent_24h,
501 recent_7d,
502 recent_30d,
503 top_tags,
504 },
505 pks: PksStats {
506 total_facts: pks_stats_raw.total_facts,
507 by_category: pks_by_cat,
508 avg_confidence: pks_stats_raw.avg_confidence,
509 },
510 bks: BksStats {
511 total_truths: bks_stats_raw.total_truths,
512 by_category: bks_by_cat,
513 },
514 })
515 }
516
517 pub async fn delete_thought(&self, id: &str) -> Result<DeleteThoughtResponse> {
521 let table = self.thoughts_table().await?;
522
523 let filter = format!("id = '{}' AND deleted = false", id);
525 let count = table.count_rows(Some(filter.clone())).await?;
526 if count == 0 {
527 return Ok(DeleteThoughtResponse {
528 deleted: false,
529 id: id.to_string(),
530 });
531 }
532
533 let delete_filter = format!("id = '{}'", id);
536 table.delete(&delete_filter).await?;
537
538 tracing::info!(id = id, "Deleted thought");
539 Ok(DeleteThoughtResponse {
540 deleted: true,
541 id: id.to_string(),
542 })
543 }
544
    /// Encode a single `Thought` plus its embedding as a one-row Arrow
    /// `RecordBatch` matching `thoughts_schema`.
    fn thought_to_batch(&self, thought: &Thought, embedding: &[f32]) -> Result<RecordBatch> {
        let dim = self.embeddings.dimension();
        let schema = Self::thoughts_schema(dim);

        // The embedding becomes a fixed-size list of length `dim`; the final
        // `None` argument means no null bitmap (every slot is valid).
        let embedding_array = Float32Array::from(embedding.to_vec());
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vectors =
            FixedSizeListArray::new(vector_field, dim as i32, Arc::new(embedding_array), None);

        let ids = StringArray::from(vec![thought.id.as_str()]);
        let contents = StringArray::from(vec![thought.content.as_str()]);
        let categories = StringArray::from(vec![thought.category.as_str()]);
        // Tags are JSON-encoded into a plain Utf8 column; serialization
        // failure degrades to an empty list rather than erroring.
        let tags_json = serde_json::to_string(&thought.tags).unwrap_or_else(|_| "[]".into());
        let tags = StringArray::from(vec![tags_json.as_str()]);
        let sources = StringArray::from(vec![thought.source.as_str()]);
        let importances = Float32Array::from(vec![thought.importance]);
        let created_ats = Int64Array::from(vec![thought.created_at]);
        let updated_ats = Int64Array::from(vec![thought.updated_at]);
        let deleteds = BooleanArray::from(vec![thought.deleted]);

        // Column order here must match the field order in `thoughts_schema`.
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(vectors),
                Arc::new(ids),
                Arc::new(contents),
                Arc::new(categories),
                Arc::new(tags),
                Arc::new(sources),
                Arc::new(importances),
                Arc::new(created_ats),
                Arc::new(updated_ats),
                Arc::new(deleteds),
            ],
        )
        .context("Failed to create thought record batch")
    }
584
585 fn batch_to_thoughts(&self, batches: &[RecordBatch]) -> Result<Vec<Thought>> {
586 let mut result = Vec::new();
587
588 for batch in batches {
589 let ids = batch
590 .column_by_name("id")
591 .context("Missing id column")?
592 .as_any()
593 .downcast_ref::<StringArray>()
594 .context("Invalid id type")?;
595 let contents = batch
596 .column_by_name("content")
597 .context("Missing content column")?
598 .as_any()
599 .downcast_ref::<StringArray>()
600 .context("Invalid content type")?;
601 let categories = batch
602 .column_by_name("category")
603 .context("Missing category column")?
604 .as_any()
605 .downcast_ref::<StringArray>()
606 .context("Invalid category type")?;
607 let tags_col = batch
608 .column_by_name("tags")
609 .context("Missing tags column")?
610 .as_any()
611 .downcast_ref::<StringArray>()
612 .context("Invalid tags type")?;
613 let sources = batch
614 .column_by_name("source")
615 .context("Missing source column")?
616 .as_any()
617 .downcast_ref::<StringArray>()
618 .context("Invalid source type")?;
619 let importances = batch
620 .column_by_name("importance")
621 .context("Missing importance column")?
622 .as_any()
623 .downcast_ref::<Float32Array>()
624 .context("Invalid importance type")?;
625 let created_ats = batch
626 .column_by_name("created_at")
627 .context("Missing created_at column")?
628 .as_any()
629 .downcast_ref::<Int64Array>()
630 .context("Invalid created_at type")?;
631 let updated_ats = batch
632 .column_by_name("updated_at")
633 .context("Missing updated_at column")?
634 .as_any()
635 .downcast_ref::<Int64Array>()
636 .context("Invalid updated_at type")?;
637 let deleteds = batch
638 .column_by_name("deleted")
639 .context("Missing deleted column")?
640 .as_any()
641 .downcast_ref::<BooleanArray>()
642 .context("Invalid deleted type")?;
643
644 for i in 0..batch.num_rows() {
645 let tags_str = tags_col.value(i);
646 let tags: Vec<String> = serde_json::from_str(tags_str).unwrap_or_default();
647
648 result.push(Thought {
649 id: ids.value(i).to_string(),
650 content: contents.value(i).to_string(),
651 category: ThoughtCategory::parse(categories.value(i)),
652 tags,
653 source: ThoughtSource::parse(sources.value(i)),
654 importance: importances.value(i),
655 created_at: created_ats.value(i),
656 updated_at: updated_ats.value(i),
657 deleted: deleteds.value(i),
658 });
659 }
660 }
661
662 Ok(result)
663 }
664}
665
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Builds a BrainClient backed by throwaway store files. The returned
    // TempDir must stay alive for the whole test — dropping it deletes the
    // underlying files.
    async fn setup() -> (TempDir, BrainClient) {
        let temp = TempDir::new().unwrap();
        let lance_path = temp.path().join("brain.lance");
        let pks_path = temp.path().join("pks.db");
        let bks_path = temp.path().join("bks.db");

        let client = BrainClient::with_paths(
            lance_path.to_str().unwrap(),
            pks_path.to_str().unwrap(),
            bks_path.to_str().unwrap(),
        )
        .await
        .unwrap();

        (temp, client)
    }

    // Capture round-trip: the content is expected to auto-classify as a
    // decision, merge the user tag, and read back by id.
    #[tokio::test]
    async fn test_capture_and_get() {
        let (_temp, mut client) = setup().await;

        let resp = client
            .capture_thought(CaptureThoughtRequest {
                content: "Decided to use PostgreSQL for auth service".into(),
                category: None,
                tags: Some(vec!["db".into()]),
                importance: Some(0.8),
                source: None,
            })
            .await
            .unwrap();

        assert_eq!(resp.category, "decision");
        assert!(resp.tags.contains(&"db".to_string()));

        let thought = client.get_thought(&resp.id).await.unwrap();
        assert!(thought.is_some());
        let t = thought.unwrap();
        assert_eq!(t.category, "decision");
    }

    // Vector search over a semantically related query should surface the
    // captured thought (min_score 0 so any match passes).
    #[tokio::test]
    async fn test_search_memory() {
        let (_temp, mut client) = setup().await;

        client
            .capture_thought(CaptureThoughtRequest {
                content: "Rust is great for systems programming".into(),
                category: Some("insight".into()),
                tags: None,
                importance: None,
                source: None,
            })
            .await
            .unwrap();

        let results = client
            .search_memory(SearchMemoryRequest {
                query: "programming languages".into(),
                limit: 10,
                min_score: 0.0,
                category: None,
                sources: None,
            })
            .await
            .unwrap();

        assert!(!results.results.is_empty());
    }

    // Deleting a thought must make it invisible to get_thought.
    #[tokio::test]
    async fn test_delete_thought() {
        let (_temp, mut client) = setup().await;

        let resp = client
            .capture_thought(CaptureThoughtRequest {
                content: "Something to delete".into(),
                category: None,
                tags: None,
                importance: None,
                source: None,
            })
            .await
            .unwrap();

        let del = client.delete_thought(&resp.id).await.unwrap();
        assert!(del.deleted);

        let thought = client.get_thought(&resp.id).await.unwrap();
        assert!(thought.is_none());
    }

    // A freshly-created brain reports zero thoughts.
    #[tokio::test]
    async fn test_memory_stats() {
        let (_temp, client) = setup().await;
        let stats = client.memory_stats().await.unwrap();
        assert_eq!(stats.thoughts.total, 0);
    }
}