1mod hnsw;
19pub mod gnn_index;
20pub mod ruvector_integration;
21pub mod simd_ops;
22
23pub use gnn_index::{GNNConfig, GNNIndex, GNNNode, GNNSearchResult, GNNStats};
24pub use ruvector_integration::{
25 GNNLayer, GNNStats as RuVectorGNNStats, GraphEdge, GraphQueryResult,
26 RuVectorConfig, RuVectorError, RuVectorIndex, RuVectorResult, SimdLevel, VectorEntry,
27};
28pub use simd_ops::DistanceMetric;
29
30use serde::{Deserialize, Serialize};
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use chrono::{DateTime, Utc};
34use hnsw::{HnswIndex, HnswConfig, VectorPoint};
35
/// Identifier of a stored vector; `AgentDB::vector_store` produces it as a UUID v4 string.
pub type VectorId = String;
/// Identifier of a reflexion episode; `AgentDB::reflexion_store` produces it as a UUID v4 string.
pub type ReflexionId = String;
/// Identifier of a skill; `AgentDB::skill_create` produces it as a UUID v4 string.
pub type SkillId = String;
/// Dense embedding vector of `f32` components.
pub type Embedding = Vec<f32>;
40
/// One recorded attempt at a task, stored by `AgentDB` for later retrieval and aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflexionEpisode {
    /// Identifier; `AgentDB::reflexion_store` overwrites it with a freshly generated UUID string.
    pub id: Option<ReflexionId>,
    /// Session key; `AgentDB::reflexion_by_session` matches it by exact equality.
    pub session_id: String,
    /// Task text; matched case-insensitively by `reflexion_retrieve` and by prefix in
    /// `reflexion_analyze`.
    pub task: String,
    /// Arbitrary JSON describing the input of the attempt.
    pub input: serde_json::Value,
    /// Arbitrary JSON describing the output of the attempt.
    pub output: serde_json::Value,
    /// Reward value; averaged by `reflexion_analyze`.
    pub reward: f64,
    /// Whether the attempt succeeded; counted by `reflexion_analyze`.
    pub success: bool,
    /// Free-form critique text (not interpreted by the code in this module).
    pub critique: String,
    /// Latency in milliseconds; averaged (integer division) by `reflexion_analyze`.
    pub latency_ms: u64,
    /// Token count; averaged (integer division) by `reflexion_analyze`.
    pub tokens: u64,
    /// Timestamp used by `reflexion_retrieve` to order results newest first.
    pub timestamp: DateTime<Utc>,
    /// Optional embedding; not read by the code visible in this module.
    pub embedding: Option<Embedding>,
}
57
/// Directed cause → effect relation with an estimated uplift.
///
/// `AgentDB::causal_add_edge` treats two edges with equal `cause` and `effect` as the same
/// edge and merges them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalEdge {
    /// Name of the cause node; compared by exact string equality.
    pub cause: String,
    /// Name of the effect node; compared by exact string equality.
    pub effect: String,
    /// Estimated uplift; the causal queries sort by this value in descending order.
    pub uplift: f64,
    /// Confidence value; merging keeps the larger of the two.
    pub confidence: f64,
    /// Number of samples behind `uplift`; used as the weight when merging.
    pub sample_size: u64,
    /// Timestamp of the first observation of this relation.
    pub first_observed: DateTime<Utc>,
    /// Timestamp of the latest observation of this relation.
    pub last_observed: DateTime<Utc>,
}
69
/// A named skill with an embedding and running usage statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Skill {
    /// Identifier; `AgentDB::skill_create` overwrites it with a freshly generated UUID string.
    pub id: Option<SkillId>,
    /// Skill name; searched case-insensitively by `AgentDB::skill_search`.
    pub name: String,
    /// Skill description; searched case-insensitively by `AgentDB::skill_search`.
    pub description: String,
    /// Embedding; `skill_create` requires its length to equal the configured dimension.
    pub embedding: Embedding,
    /// Number of recorded uses; incremented by `AgentDB::skill_update_stats`.
    pub usage_count: u64,
    /// Running fraction of successful uses; maintained by `AgentDB::skill_update_stats`.
    pub success_rate: f64,
    /// Creation timestamp supplied by the caller.
    pub created_at: DateTime<Utc>,
}
81
/// One hit returned by `AgentDB::vector_search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorResult {
    /// Identifier of the matching vector.
    pub id: VectorId,
    /// Similarity reported by the index, widened to `f64`.
    pub similarity: f64,
    /// Metadata stored alongside the vector.
    pub metadata: serde_json::Value,
}
89
/// Tunable settings for an `AgentDB` instance.
#[derive(Debug, Clone)]
pub struct AgentDBConfig {
    /// Required length of every embedding stored or queried.
    pub dimension: usize,
    /// Forwarded to the HNSW index as its `m` parameter.
    pub hnsw_m: usize,
    /// Forwarded to the HNSW index as both `ef_construction` and `ef_search`.
    pub hnsw_ef: usize,
    /// Cache size setting; not read by the code visible in this module.
    pub cache_size: usize,
}

impl Default for AgentDBConfig {
    /// Stock settings: 4096-dimensional embeddings, `hnsw_m` = 32, `hnsw_ef` = 100 and a
    /// `cache_size` of 100 000.
    fn default() -> Self {
        AgentDBConfig {
            cache_size: 100_000,
            hnsw_ef: 100,
            hnsw_m: 32,
            dimension: 4096,
        }
    }
}
109
/// In-memory store combining an HNSW vector index with reflexion episodes, causal edges and
/// skills. Each collection sits behind its own `Arc<RwLock<..>>`.
pub struct AgentDB {
    /// Settings captured at construction; `dimension` is checked on every embedding input.
    config: AgentDBConfig,
    /// Vector index holding the stored embeddings and their metadata.
    vector_index: Arc<RwLock<HnswIndex>>,
    /// Recorded reflexion episodes, in insertion order.
    episodes: Arc<RwLock<Vec<ReflexionEpisode>>>,
    /// Causal edges; at most one entry per (cause, effect) pair when added via `causal_add_edge`.
    causal_edges: Arc<RwLock<Vec<CausalEdge>>>,
    /// Registered skills, in insertion order.
    skills: Arc<RwLock<Vec<Skill>>>,
}
118
119impl AgentDB {
120 pub async fn new(config: AgentDBConfig) -> Result<Self, AgentDBError> {
122 let hnsw_config = HnswConfig {
123 ef_construction: config.hnsw_ef,
124 ef_search: config.hnsw_ef,
125 m: config.hnsw_m,
126 };
127
128 Ok(Self {
129 config,
130 vector_index: Arc::new(RwLock::new(HnswIndex::new(hnsw_config))),
131 episodes: Arc::new(RwLock::new(Vec::new())),
132 causal_edges: Arc::new(RwLock::new(Vec::new())),
133 skills: Arc::new(RwLock::new(Vec::new())),
134 })
135 }
136
137 pub async fn vector_store(
141 &self,
142 embedding: Embedding,
143 metadata: serde_json::Value,
144 ) -> Result<VectorId, AgentDBError> {
145 if embedding.len() != self.config.dimension {
146 return Err(AgentDBError::StorageError(format!(
147 "Embedding dimension {} does not match configured dimension {}",
148 embedding.len(),
149 self.config.dimension
150 )));
151 }
152
153 let id = uuid::Uuid::new_v4().to_string();
154
155 let point = VectorPoint {
156 id: id.clone(),
157 embedding,
158 metadata,
159 };
160
161 self.vector_index.write().await.insert(point);
162 Ok(id)
163 }
164
165 pub async fn vector_search(
167 &self,
168 query: &Embedding,
169 k: usize,
170 ) -> Result<Vec<VectorResult>, AgentDBError> {
171 if query.len() != self.config.dimension {
172 return Err(AgentDBError::QueryError(format!(
173 "Query dimension {} does not match configured dimension {}",
174 query.len(),
175 self.config.dimension
176 )));
177 }
178
179 let results = self.vector_index.write().await.search(query, k);
180
181 Ok(results.into_iter().map(|r| VectorResult {
182 id: r.id,
183 similarity: r.similarity as f64,
184 metadata: r.metadata,
185 }).collect())
186 }
187
188 pub async fn vector_get(&self, id: &str) -> Result<(Embedding, serde_json::Value), AgentDBError> {
190 let index = self.vector_index.read().await;
191 let point = index
192 .get(id)
193 .ok_or_else(|| AgentDBError::NotFound(format!("Vector {} not found", id)))?;
194
195 Ok((point.embedding.clone(), point.metadata.clone()))
196 }
197
198 pub async fn vector_delete(&self, id: &str) -> Result<(), AgentDBError> {
200 let mut index = self.vector_index.write().await;
201 if index.remove(id) {
202 Ok(())
203 } else {
204 Err(AgentDBError::NotFound(format!("Vector {} not found", id)))
205 }
206 }
207
208 pub async fn reflexion_store(
212 &self,
213 mut episode: ReflexionEpisode,
214 ) -> Result<ReflexionId, AgentDBError> {
215 let id = uuid::Uuid::new_v4().to_string();
216 episode.id = Some(id.clone());
217
218 let mut episodes = self.episodes.write().await;
219 episodes.push(episode);
220
221 Ok(id)
222 }
223
224 pub async fn reflexion_retrieve(
226 &self,
227 task: &str,
228 limit: usize,
229 ) -> Result<Vec<ReflexionEpisode>, AgentDBError> {
230 let episodes = self.episodes.read().await;
231
232 let mut matching: Vec<ReflexionEpisode> = episodes
234 .iter()
235 .filter(|ep| {
236 ep.task.to_lowercase().contains(&task.to_lowercase())
237 || task.to_lowercase().contains(&ep.task.to_lowercase())
238 })
239 .cloned()
240 .collect();
241
242 matching.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
244 matching.truncate(limit);
245
246 Ok(matching)
247 }
248
249 pub async fn reflexion_by_session(
251 &self,
252 session_id: &str,
253 ) -> Result<Vec<ReflexionEpisode>, AgentDBError> {
254 let episodes = self.episodes.read().await;
255 let matching: Vec<ReflexionEpisode> = episodes
256 .iter()
257 .filter(|ep| ep.session_id == session_id)
258 .cloned()
259 .collect();
260
261 Ok(matching)
262 }
263
264 pub async fn reflexion_analyze(&self, task_prefix: &str) -> Result<ReflexionStats, AgentDBError> {
266 let episodes = self.episodes.read().await;
267 let matching: Vec<&ReflexionEpisode> = episodes
268 .iter()
269 .filter(|ep| ep.task.starts_with(task_prefix))
270 .collect();
271
272 if matching.is_empty() {
273 return Ok(ReflexionStats::default());
274 }
275
276 let total = matching.len();
277 let successful = matching.iter().filter(|ep| ep.success).count();
278 let avg_reward = matching.iter().map(|ep| ep.reward).sum::<f64>() / total as f64;
279 let avg_latency = matching.iter().map(|ep| ep.latency_ms).sum::<u64>() / total as u64;
280 let avg_tokens = matching.iter().map(|ep| ep.tokens).sum::<u64>() / total as u64;
281
282 Ok(ReflexionStats {
283 total_episodes: total,
284 successful_episodes: successful,
285 success_rate: successful as f64 / total as f64,
286 avg_reward,
287 avg_latency_ms: avg_latency,
288 avg_tokens,
289 })
290 }
291
292 pub async fn causal_add_edge(&self, edge: CausalEdge) -> Result<(), AgentDBError> {
296 let mut edges = self.causal_edges.write().await;
297
298 if let Some(existing) = edges.iter_mut().find(|e| e.cause == edge.cause && e.effect == edge.effect) {
300 existing.uplift = (existing.uplift * existing.sample_size as f64
302 + edge.uplift * edge.sample_size as f64)
303 / (existing.sample_size + edge.sample_size) as f64;
304 existing.confidence = edge.confidence.max(existing.confidence);
305 existing.sample_size += edge.sample_size;
306 existing.last_observed = edge.last_observed;
307 } else {
308 edges.push(edge);
310 }
311
312 Ok(())
313 }
314
315 pub async fn causal_query_effects(&self, cause: &str) -> Result<Vec<CausalEdge>, AgentDBError> {
317 let edges = self.causal_edges.read().await;
318 let mut matching: Vec<CausalEdge> = edges
319 .iter()
320 .filter(|e| e.cause == cause)
321 .cloned()
322 .collect();
323
324 matching.sort_by(|a, b| b.uplift.partial_cmp(&a.uplift).unwrap());
326
327 Ok(matching)
328 }
329
330 pub async fn causal_query_causes(&self, effect: &str) -> Result<Vec<CausalEdge>, AgentDBError> {
332 let edges = self.causal_edges.read().await;
333 let mut matching: Vec<CausalEdge> = edges
334 .iter()
335 .filter(|e| e.effect == effect)
336 .cloned()
337 .collect();
338
339 matching.sort_by(|a, b| b.uplift.partial_cmp(&a.uplift).unwrap());
341
342 Ok(matching)
343 }
344
345 pub async fn causal_find_path(
347 &self,
348 start: &str,
349 end: &str,
350 max_depth: usize,
351 ) -> Result<Vec<Vec<String>>, AgentDBError> {
352 let edges = self.causal_edges.read().await;
353 let mut paths: Vec<Vec<String>> = Vec::new();
354 let mut current_path: Vec<String> = vec![start.to_string()];
355
356 Self::dfs_causal_path(&edges, start, end, &mut current_path, &mut paths, max_depth);
357
358 Ok(paths)
359 }
360
361 fn dfs_causal_path(
363 edges: &[CausalEdge],
364 current: &str,
365 target: &str,
366 path: &mut Vec<String>,
367 paths: &mut Vec<Vec<String>>,
368 max_depth: usize,
369 ) {
370 if path.len() > max_depth {
371 return;
372 }
373
374 if current == target {
375 paths.push(path.clone());
376 return;
377 }
378
379 for edge in edges.iter().filter(|e| e.cause == current) {
380 if !path.contains(&edge.effect) {
381 path.push(edge.effect.clone());
382 Self::dfs_causal_path(edges, &edge.effect, target, path, paths, max_depth);
383 path.pop();
384 }
385 }
386 }
387
388 pub async fn skill_create(&self, mut skill: Skill) -> Result<SkillId, AgentDBError> {
392 if skill.embedding.len() != self.config.dimension {
393 return Err(AgentDBError::StorageError(format!(
394 "Skill embedding dimension {} does not match configured dimension {}",
395 skill.embedding.len(),
396 self.config.dimension
397 )));
398 }
399
400 let id = uuid::Uuid::new_v4().to_string();
401 skill.id = Some(id.clone());
402
403 let mut skills = self.skills.write().await;
404 skills.push(skill);
405
406 Ok(id)
407 }
408
409 pub async fn skill_search(&self, query: &str, limit: usize) -> Result<Vec<Skill>, AgentDBError> {
411 let skills = self.skills.read().await;
412
413 let query_lower = query.to_lowercase();
415 let mut matching: Vec<Skill> = skills
416 .iter()
417 .filter(|s| {
418 s.name.to_lowercase().contains(&query_lower)
419 || s.description.to_lowercase().contains(&query_lower)
420 })
421 .cloned()
422 .collect();
423
424 matching.sort_by(|a, b| {
426 let score_a = a.usage_count as f64 * a.success_rate;
427 let score_b = b.usage_count as f64 * b.success_rate;
428 score_b.partial_cmp(&score_a).unwrap()
429 });
430
431 matching.truncate(limit);
432 Ok(matching)
433 }
434
435 pub async fn skill_search_by_embedding(
437 &self,
438 query_embedding: &Embedding,
439 limit: usize,
440 ) -> Result<Vec<(Skill, f64)>, AgentDBError> {
441 if query_embedding.len() != self.config.dimension {
442 return Err(AgentDBError::QueryError(format!(
443 "Query embedding dimension {} does not match configured dimension {}",
444 query_embedding.len(),
445 self.config.dimension
446 )));
447 }
448
449 let skills = self.skills.read().await;
450 let mut results: Vec<(Skill, f64)> = skills
451 .iter()
452 .map(|skill| {
453 let similarity = cosine_similarity(query_embedding, &skill.embedding);
454 (skill.clone(), similarity)
455 })
456 .collect();
457
458 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
460 results.truncate(limit);
461
462 Ok(results)
463 }
464
465 pub async fn skill_update_stats(
467 &self,
468 skill_id: &str,
469 success: bool,
470 ) -> Result<(), AgentDBError> {
471 let mut skills = self.skills.write().await;
472 let skill = skills
473 .iter_mut()
474 .find(|s| s.id.as_ref() == Some(&skill_id.to_string()))
475 .ok_or_else(|| AgentDBError::NotFound(format!("Skill {} not found", skill_id)))?;
476
477 skill.usage_count += 1;
478 let new_successes = (skill.success_rate * (skill.usage_count - 1) as f64)
479 + if success { 1.0 } else { 0.0 };
480 skill.success_rate = new_successes / skill.usage_count as f64;
481
482 Ok(())
483 }
484
485 pub async fn skill_get(&self, skill_id: &str) -> Result<Skill, AgentDBError> {
487 let skills = self.skills.read().await;
488 skills
489 .iter()
490 .find(|s| s.id.as_ref() == Some(&skill_id.to_string()))
491 .cloned()
492 .ok_or_else(|| AgentDBError::NotFound(format!("Skill {} not found", skill_id)))
493 }
494
495 pub async fn stats(&self) -> AgentDBStats {
499 let vector_index = self.vector_index.read().await;
500 let episodes = self.episodes.read().await;
501 let edges = self.causal_edges.read().await;
502 let skills = self.skills.read().await;
503
504 AgentDBStats {
505 vector_count: vector_index.len(),
506 episode_count: episodes.len(),
507 causal_edge_count: edges.len(),
508 skill_count: skills.len(),
509 }
510 }
511
512 pub async fn clear(&self) -> Result<(), AgentDBError> {
514 let hnsw_config = HnswConfig {
515 ef_construction: self.config.hnsw_ef,
516 ef_search: self.config.hnsw_ef,
517 m: self.config.hnsw_m,
518 };
519
520 let mut vector_index = self.vector_index.write().await;
521 let mut episodes = self.episodes.write().await;
522 let mut edges = self.causal_edges.write().await;
523 let mut skills = self.skills.write().await;
524
525 *vector_index = HnswIndex::new(hnsw_config);
526 episodes.clear();
527 edges.clear();
528 skills.clear();
529
530 Ok(())
531 }
532}
533
/// Aggregate figures produced by `AgentDB::reflexion_analyze`; the `Default` value (all
/// zeros) is returned when no episode matches.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReflexionStats {
    /// Number of episodes that matched the prefix.
    pub total_episodes: usize,
    /// Number of matching episodes with `success == true`.
    pub successful_episodes: usize,
    /// `successful_episodes / total_episodes`.
    pub success_rate: f64,
    /// Mean of the matching episodes' `reward`.
    pub avg_reward: f64,
    /// Mean `latency_ms`, truncated by integer division.
    pub avg_latency_ms: u64,
    /// Mean `tokens`, truncated by integer division.
    pub avg_tokens: u64,
}
544
/// Item counts reported by `AgentDB::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentDBStats {
    /// Number of entries reported by the vector index.
    pub vector_count: usize,
    /// Number of stored reflexion episodes.
    pub episode_count: usize,
    /// Number of stored causal edges.
    pub causal_edge_count: usize,
    /// Number of stored skills.
    pub skill_count: usize,
}
553
554fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
556 use simsimd::SpatialSimilarity;
557
558 if a.len() != b.len() {
559 return 0.0;
560 }
561
562 match f32::cosine(a, b) {
566 Some(distance) => 1.0 - distance,
567 None => 0.0,
568 }
569}
570
/// Errors returned by `AgentDB` operations.
#[derive(Debug, thiserror::Error)]
pub enum AgentDBError {
    /// A value could not be stored; used for embedding-dimension mismatches on insert.
    #[error("Storage error: {0}")]
    StorageError(String),
    /// A query was rejected; used for embedding-dimension mismatches on search.
    #[error("Query error: {0}")]
    QueryError(String),
    /// No vector or skill exists for the requested id.
    #[error("Not found: {0}")]
    NotFound(String),
}
581
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a database that differs from the default configuration only in `dimension`.
    async fn db_with_dimension(dimension: usize) -> AgentDB {
        let config = AgentDBConfig {
            dimension,
            ..Default::default()
        };
        AgentDB::new(config).await.unwrap()
    }

    /// Builds the ramp embedding `[0/len, 1/len, ..., (len-1)/len]`.
    fn ramp(len: usize) -> Embedding {
        (0..len).map(|i| i as f32 / len as f32).collect()
    }

    // Store → get → search round trip for a single vector.
    #[tokio::test]
    async fn test_vector_operations() {
        let db = db_with_dimension(128).await;
        let embedding = ramp(128);
        let metadata = serde_json::json!({"test": "data"});

        let id = db
            .vector_store(embedding.clone(), metadata.clone())
            .await
            .unwrap();

        let (stored_embedding, stored_metadata) = db.vector_get(&id).await.unwrap();
        assert_eq!(stored_embedding.len(), 128);
        assert_eq!(stored_metadata, metadata);

        let hits = db.vector_search(&embedding, 1).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].similarity > 0.99);
    }

    // A stored episode is found again through a substring of its task.
    #[tokio::test]
    async fn test_reflexion_operations() {
        let db = AgentDB::new(AgentDBConfig::default()).await.unwrap();

        let id = db
            .reflexion_store(ReflexionEpisode {
                id: None,
                session_id: "session-1".to_string(),
                task: "solve math problem".to_string(),
                input: serde_json::json!({"problem": "2+2"}),
                output: serde_json::json!({"answer": 4}),
                reward: 1.0,
                success: true,
                critique: "Correct answer".to_string(),
                latency_ms: 100,
                tokens: 50,
                timestamp: Utc::now(),
                embedding: None,
            })
            .await
            .unwrap();
        assert!(!id.is_empty());

        let found = db.reflexion_retrieve("math", 10).await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].task, "solve math problem");
    }

    // A single edge is returned when querying the effects of its cause.
    #[tokio::test]
    async fn test_causal_operations() {
        let db = AgentDB::new(AgentDBConfig::default()).await.unwrap();

        db.causal_add_edge(CausalEdge {
            cause: "use_cache".to_string(),
            effect: "faster_response".to_string(),
            uplift: 0.5,
            confidence: 0.95,
            sample_size: 100,
            first_observed: Utc::now(),
            last_observed: Utc::now(),
        })
        .await
        .unwrap();

        let effects = db.causal_query_effects("use_cache").await.unwrap();
        assert_eq!(effects.len(), 1);
        assert_eq!(effects[0].effect, "faster_response");
    }

    // A created skill is found by a substring of its name.
    #[tokio::test]
    async fn test_skill_operations() {
        let db = db_with_dimension(64).await;

        let id = db
            .skill_create(Skill {
                id: None,
                name: "code_generation".to_string(),
                description: "Generate Python code from natural language".to_string(),
                embedding: ramp(64),
                usage_count: 0,
                success_rate: 0.0,
                created_at: Utc::now(),
            })
            .await
            .unwrap();
        assert!(!id.is_empty());

        let found = db.skill_search("code", 10).await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "code_generation");
    }

    // Identical vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let cases: [(&[f32], &[f32], f64); 3] = [
            (&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0], 1.0),
            (&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0], 0.0),
            (&[1.0, 1.0, 0.0], &[1.0, 1.0, 0.0], 1.0),
        ];

        for &(lhs, rhs, expected) in cases.iter() {
            assert!((cosine_similarity(lhs, rhs) - expected).abs() < 0.001);
        }
    }

    // Get and delete work on one of two stored vectors and the count follows.
    #[tokio::test]
    async fn test_hnsw_vector_operations() {
        let db = db_with_dimension(128).await;

        let rising = ramp(128);
        let falling: Embedding = (0..128).map(|i| (128 - i) as f32 / 128.0).collect();

        let _first = db
            .vector_store(rising, serde_json::json!({"name": "v1"}))
            .await
            .unwrap();
        let second = db
            .vector_store(falling, serde_json::json!({"name": "v2"}))
            .await
            .unwrap();

        let (stored_embedding, stored_metadata) = db.vector_get(&second).await.unwrap();
        assert_eq!(stored_embedding.len(), 128);
        assert_eq!(stored_metadata["name"], "v2");

        db.vector_delete(&second).await.unwrap();
        assert!(db.vector_get(&second).await.is_err());

        assert_eq!(db.stats().await.vector_count, 1);
    }

    // Searching among 100 vectors returns a bounded, non-empty, reasonably similar result set.
    #[tokio::test]
    async fn test_hnsw_large_dataset() {
        let config = AgentDBConfig {
            dimension: 64,
            hnsw_m: 16,
            hnsw_ef: 100,
            ..Default::default()
        };
        let db = AgentDB::new(config).await.unwrap();

        for i in 0..100 {
            let embedding: Embedding = (0..64).map(|j| ((i * j) as f32) / 1000.0).collect();
            db.vector_store(embedding, serde_json::json!({"index": i}))
                .await
                .unwrap();
        }

        let query: Embedding = (0..64).map(|j| ((50 * j) as f32) / 1000.0).collect();
        let hits = db.vector_search(&query, 10).await.unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 10);
        assert!(hits[0].similarity > 0.5, "Top result should have >50% similarity");

        assert_eq!(db.stats().await.vector_count, 100);
    }

    // Searching an empty index yields no hits.
    #[tokio::test]
    async fn test_hnsw_empty_search() {
        let db = db_with_dimension(32).await;

        let query: Embedding = vec![0.1; 32];
        let hits = db.vector_search(&query, 10).await.unwrap();
        assert!(hits.is_empty());
    }

    // The vector count tracks inserts and drops to zero after `clear`.
    #[tokio::test]
    async fn test_hnsw_stats() {
        let db = db_with_dimension(16).await;
        assert_eq!(db.stats().await.vector_count, 0);

        for i in 0..5 {
            let embedding: Embedding = vec![i as f32; 16];
            db.vector_store(embedding, serde_json::json!({}))
                .await
                .unwrap();
        }
        assert_eq!(db.stats().await.vector_count, 5);

        db.clear().await.unwrap();
        assert_eq!(db.stats().await.vector_count, 0);
    }

    // Embeddings of the wrong length are rejected on both store and search.
    #[tokio::test]
    async fn test_hnsw_dimension_validation() {
        let db = db_with_dimension(64).await;

        let too_short: Embedding = vec![1.0; 32];
        let stored = db.vector_store(too_short, serde_json::json!({})).await;
        assert!(stored.is_err());

        let short_query: Embedding = vec![1.0; 32];
        let searched = db.vector_search(&short_query, 5).await;
        assert!(searched.is_err());
    }
}