omega_agentdb/
lib.rs

1//! AgentDB wrapper for ExoGenesis Omega
2//! Provides ReasoningBank, Reflexion, Causal, and Skill storage
3//!
4//! This is an in-memory implementation that mimics AgentDB's functionality
5//! for the ExoGenesis Omega cognitive architecture.
6//!
7//! ## Features
8//! - HNSW index for fast approximate nearest neighbor search
9//! - SIMD-accelerated distance computations
10//! - Self-learning GNN index with adaptive navigation
11//! - RuVector integration for advanced vector operations
12//!
13//! ## RuVector Integration
14//! - ruvector-core: HNSW with SIMD acceleration
15//! - ruvector-gnn: Self-learning graph neural networks
16//! - ruvector-graph: Cypher-like graph queries
17
18mod hnsw;
19pub mod gnn_index;
20pub mod ruvector_integration;
21pub mod simd_ops;
22
23pub use gnn_index::{GNNConfig, GNNIndex, GNNNode, GNNSearchResult, GNNStats};
24pub use ruvector_integration::{
25    GNNLayer, GNNStats as RuVectorGNNStats, GraphEdge, GraphQueryResult,
26    RuVectorConfig, RuVectorError, RuVectorIndex, RuVectorResult, SimdLevel, VectorEntry,
27};
28pub use simd_ops::DistanceMetric;
29
30use serde::{Deserialize, Serialize};
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use chrono::{DateTime, Utc};
34use hnsw::{HnswIndex, HnswConfig, VectorPoint};
35
/// Identifier of a stored vector: a UUID v4 string generated by `AgentDB::vector_store`.
pub type VectorId = String;
/// Identifier of a stored reflexion episode: a UUID v4 string generated by `AgentDB::reflexion_store`.
pub type ReflexionId = String;
/// Identifier of a stored skill: a UUID v4 string generated by `AgentDB::skill_create`.
pub type SkillId = String;
/// Dense embedding vector. Vector and skill embeddings must have exactly
/// `AgentDBConfig::dimension` elements or `AgentDB` rejects them.
pub type Embedding = Vec<f32>;
40
41/// Represents a single reflexion episode capturing agent learning
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ReflexionEpisode {
44    pub id: Option<ReflexionId>,
45    pub session_id: String,
46    pub task: String,
47    pub input: serde_json::Value,
48    pub output: serde_json::Value,
49    pub reward: f64,
50    pub success: bool,
51    pub critique: String,
52    pub latency_ms: u64,
53    pub tokens: u64,
54    pub timestamp: DateTime<Utc>,
55    pub embedding: Option<Embedding>,
56}
57
58/// Represents a causal relationship between actions and outcomes
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CausalEdge {
61    pub cause: String,
62    pub effect: String,
63    pub uplift: f64,
64    pub confidence: f64,
65    pub sample_size: u64,
66    pub first_observed: DateTime<Utc>,
67    pub last_observed: DateTime<Utc>,
68}
69
70/// Represents a learned skill with semantic embedding
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct Skill {
73    pub id: Option<SkillId>,
74    pub name: String,
75    pub description: String,
76    pub embedding: Embedding,
77    pub usage_count: u64,
78    pub success_rate: f64,
79    pub created_at: DateTime<Utc>,
80}
81
82/// Result of a vector similarity search
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct VectorResult {
85    pub id: VectorId,
86    pub similarity: f64,
87    pub metadata: serde_json::Value,
88}
89
90/// Configuration for AgentDB instance
91#[derive(Debug, Clone)]
92pub struct AgentDBConfig {
93    pub dimension: usize,
94    pub hnsw_m: usize,
95    pub hnsw_ef: usize,
96    pub cache_size: usize,
97}
98
99impl Default for AgentDBConfig {
100    fn default() -> Self {
101        Self {
102            dimension: 4096,
103            hnsw_m: 32,
104            hnsw_ef: 100,
105            cache_size: 100_000,
106        }
107    }
108}
109
110/// Main AgentDB interface providing vector storage, reflexion, causal, and skill management
111pub struct AgentDB {
112    config: AgentDBConfig,
113    vector_index: Arc<RwLock<HnswIndex>>,
114    episodes: Arc<RwLock<Vec<ReflexionEpisode>>>,
115    causal_edges: Arc<RwLock<Vec<CausalEdge>>>,
116    skills: Arc<RwLock<Vec<Skill>>>,
117}
118
119impl AgentDB {
120    /// Creates a new AgentDB instance with the given configuration
121    pub async fn new(config: AgentDBConfig) -> Result<Self, AgentDBError> {
122        let hnsw_config = HnswConfig {
123            ef_construction: config.hnsw_ef,
124            ef_search: config.hnsw_ef,
125            m: config.hnsw_m,
126        };
127
128        Ok(Self {
129            config,
130            vector_index: Arc::new(RwLock::new(HnswIndex::new(hnsw_config))),
131            episodes: Arc::new(RwLock::new(Vec::new())),
132            causal_edges: Arc::new(RwLock::new(Vec::new())),
133            skills: Arc::new(RwLock::new(Vec::new())),
134        })
135    }
136
137    // ==================== Vector Operations ====================
138
139    /// Stores a vector embedding with associated metadata
140    pub async fn vector_store(
141        &self,
142        embedding: Embedding,
143        metadata: serde_json::Value,
144    ) -> Result<VectorId, AgentDBError> {
145        if embedding.len() != self.config.dimension {
146            return Err(AgentDBError::StorageError(format!(
147                "Embedding dimension {} does not match configured dimension {}",
148                embedding.len(),
149                self.config.dimension
150            )));
151        }
152
153        let id = uuid::Uuid::new_v4().to_string();
154
155        let point = VectorPoint {
156            id: id.clone(),
157            embedding,
158            metadata,
159        };
160
161        self.vector_index.write().await.insert(point);
162        Ok(id)
163    }
164
165    /// Searches for the k most similar vectors using HNSW index
166    pub async fn vector_search(
167        &self,
168        query: &Embedding,
169        k: usize,
170    ) -> Result<Vec<VectorResult>, AgentDBError> {
171        if query.len() != self.config.dimension {
172            return Err(AgentDBError::QueryError(format!(
173                "Query dimension {} does not match configured dimension {}",
174                query.len(),
175                self.config.dimension
176            )));
177        }
178
179        let results = self.vector_index.write().await.search(query, k);
180
181        Ok(results.into_iter().map(|r| VectorResult {
182            id: r.id,
183            similarity: r.similarity as f64,
184            metadata: r.metadata,
185        }).collect())
186    }
187
188    /// Retrieves a specific vector by ID
189    pub async fn vector_get(&self, id: &str) -> Result<(Embedding, serde_json::Value), AgentDBError> {
190        let index = self.vector_index.read().await;
191        let point = index
192            .get(id)
193            .ok_or_else(|| AgentDBError::NotFound(format!("Vector {} not found", id)))?;
194
195        Ok((point.embedding.clone(), point.metadata.clone()))
196    }
197
198    /// Deletes a vector by ID
199    pub async fn vector_delete(&self, id: &str) -> Result<(), AgentDBError> {
200        let mut index = self.vector_index.write().await;
201        if index.remove(id) {
202            Ok(())
203        } else {
204            Err(AgentDBError::NotFound(format!("Vector {} not found", id)))
205        }
206    }
207
208    // ==================== Reflexion Operations ====================
209
210    /// Stores a reflexion episode for learning from experience
211    pub async fn reflexion_store(
212        &self,
213        mut episode: ReflexionEpisode,
214    ) -> Result<ReflexionId, AgentDBError> {
215        let id = uuid::Uuid::new_v4().to_string();
216        episode.id = Some(id.clone());
217
218        let mut episodes = self.episodes.write().await;
219        episodes.push(episode);
220
221        Ok(id)
222    }
223
224    /// Retrieves similar reflexion episodes for a given task
225    pub async fn reflexion_retrieve(
226        &self,
227        task: &str,
228        limit: usize,
229    ) -> Result<Vec<ReflexionEpisode>, AgentDBError> {
230        let episodes = self.episodes.read().await;
231
232        // Simple substring matching for task similarity
233        let mut matching: Vec<ReflexionEpisode> = episodes
234            .iter()
235            .filter(|ep| {
236                ep.task.to_lowercase().contains(&task.to_lowercase())
237                    || task.to_lowercase().contains(&ep.task.to_lowercase())
238            })
239            .cloned()
240            .collect();
241
242        // Sort by timestamp descending (most recent first)
243        matching.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
244        matching.truncate(limit);
245
246        Ok(matching)
247    }
248
249    /// Retrieves reflexion episodes by session ID
250    pub async fn reflexion_by_session(
251        &self,
252        session_id: &str,
253    ) -> Result<Vec<ReflexionEpisode>, AgentDBError> {
254        let episodes = self.episodes.read().await;
255        let matching: Vec<ReflexionEpisode> = episodes
256            .iter()
257            .filter(|ep| ep.session_id == session_id)
258            .cloned()
259            .collect();
260
261        Ok(matching)
262    }
263
264    /// Analyzes reflexion episodes to calculate success metrics
265    pub async fn reflexion_analyze(&self, task_prefix: &str) -> Result<ReflexionStats, AgentDBError> {
266        let episodes = self.episodes.read().await;
267        let matching: Vec<&ReflexionEpisode> = episodes
268            .iter()
269            .filter(|ep| ep.task.starts_with(task_prefix))
270            .collect();
271
272        if matching.is_empty() {
273            return Ok(ReflexionStats::default());
274        }
275
276        let total = matching.len();
277        let successful = matching.iter().filter(|ep| ep.success).count();
278        let avg_reward = matching.iter().map(|ep| ep.reward).sum::<f64>() / total as f64;
279        let avg_latency = matching.iter().map(|ep| ep.latency_ms).sum::<u64>() / total as u64;
280        let avg_tokens = matching.iter().map(|ep| ep.tokens).sum::<u64>() / total as u64;
281
282        Ok(ReflexionStats {
283            total_episodes: total,
284            successful_episodes: successful,
285            success_rate: successful as f64 / total as f64,
286            avg_reward,
287            avg_latency_ms: avg_latency,
288            avg_tokens,
289        })
290    }
291
292    // ==================== Causal Operations ====================
293
294    /// Adds or updates a causal edge
295    pub async fn causal_add_edge(&self, edge: CausalEdge) -> Result<(), AgentDBError> {
296        let mut edges = self.causal_edges.write().await;
297
298        // Check if edge already exists
299        if let Some(existing) = edges.iter_mut().find(|e| e.cause == edge.cause && e.effect == edge.effect) {
300            // Update existing edge with new observations
301            existing.uplift = (existing.uplift * existing.sample_size as f64
302                + edge.uplift * edge.sample_size as f64)
303                / (existing.sample_size + edge.sample_size) as f64;
304            existing.confidence = edge.confidence.max(existing.confidence);
305            existing.sample_size += edge.sample_size;
306            existing.last_observed = edge.last_observed;
307        } else {
308            // Add new edge
309            edges.push(edge);
310        }
311
312        Ok(())
313    }
314
315    /// Queries effects caused by a specific cause
316    pub async fn causal_query_effects(&self, cause: &str) -> Result<Vec<CausalEdge>, AgentDBError> {
317        let edges = self.causal_edges.read().await;
318        let mut matching: Vec<CausalEdge> = edges
319            .iter()
320            .filter(|e| e.cause == cause)
321            .cloned()
322            .collect();
323
324        // Sort by uplift descending
325        matching.sort_by(|a, b| b.uplift.partial_cmp(&a.uplift).unwrap());
326
327        Ok(matching)
328    }
329
330    /// Queries causes that lead to a specific effect
331    pub async fn causal_query_causes(&self, effect: &str) -> Result<Vec<CausalEdge>, AgentDBError> {
332        let edges = self.causal_edges.read().await;
333        let mut matching: Vec<CausalEdge> = edges
334            .iter()
335            .filter(|e| e.effect == effect)
336            .cloned()
337            .collect();
338
339        // Sort by uplift descending
340        matching.sort_by(|a, b| b.uplift.partial_cmp(&a.uplift).unwrap());
341
342        Ok(matching)
343    }
344
345    /// Finds causal paths between a cause and effect
346    pub async fn causal_find_path(
347        &self,
348        start: &str,
349        end: &str,
350        max_depth: usize,
351    ) -> Result<Vec<Vec<String>>, AgentDBError> {
352        let edges = self.causal_edges.read().await;
353        let mut paths: Vec<Vec<String>> = Vec::new();
354        let mut current_path: Vec<String> = vec![start.to_string()];
355
356        Self::dfs_causal_path(&edges, start, end, &mut current_path, &mut paths, max_depth);
357
358        Ok(paths)
359    }
360
361    // Helper for DFS path finding
362    fn dfs_causal_path(
363        edges: &[CausalEdge],
364        current: &str,
365        target: &str,
366        path: &mut Vec<String>,
367        paths: &mut Vec<Vec<String>>,
368        max_depth: usize,
369    ) {
370        if path.len() > max_depth {
371            return;
372        }
373
374        if current == target {
375            paths.push(path.clone());
376            return;
377        }
378
379        for edge in edges.iter().filter(|e| e.cause == current) {
380            if !path.contains(&edge.effect) {
381                path.push(edge.effect.clone());
382                Self::dfs_causal_path(edges, &edge.effect, target, path, paths, max_depth);
383                path.pop();
384            }
385        }
386    }
387
388    // ==================== Skill Operations ====================
389
390    /// Creates a new skill with embedding
391    pub async fn skill_create(&self, mut skill: Skill) -> Result<SkillId, AgentDBError> {
392        if skill.embedding.len() != self.config.dimension {
393            return Err(AgentDBError::StorageError(format!(
394                "Skill embedding dimension {} does not match configured dimension {}",
395                skill.embedding.len(),
396                self.config.dimension
397            )));
398        }
399
400        let id = uuid::Uuid::new_v4().to_string();
401        skill.id = Some(id.clone());
402
403        let mut skills = self.skills.write().await;
404        skills.push(skill);
405
406        Ok(id)
407    }
408
409    /// Searches for skills using semantic similarity
410    pub async fn skill_search(&self, query: &str, limit: usize) -> Result<Vec<Skill>, AgentDBError> {
411        let skills = self.skills.read().await;
412
413        // Simple text matching (in production, would use query embedding)
414        let query_lower = query.to_lowercase();
415        let mut matching: Vec<Skill> = skills
416            .iter()
417            .filter(|s| {
418                s.name.to_lowercase().contains(&query_lower)
419                    || s.description.to_lowercase().contains(&query_lower)
420            })
421            .cloned()
422            .collect();
423
424        // Sort by usage count and success rate
425        matching.sort_by(|a, b| {
426            let score_a = a.usage_count as f64 * a.success_rate;
427            let score_b = b.usage_count as f64 * b.success_rate;
428            score_b.partial_cmp(&score_a).unwrap()
429        });
430
431        matching.truncate(limit);
432        Ok(matching)
433    }
434
435    /// Searches for skills using embedding similarity
436    pub async fn skill_search_by_embedding(
437        &self,
438        query_embedding: &Embedding,
439        limit: usize,
440    ) -> Result<Vec<(Skill, f64)>, AgentDBError> {
441        if query_embedding.len() != self.config.dimension {
442            return Err(AgentDBError::QueryError(format!(
443                "Query embedding dimension {} does not match configured dimension {}",
444                query_embedding.len(),
445                self.config.dimension
446            )));
447        }
448
449        let skills = self.skills.read().await;
450        let mut results: Vec<(Skill, f64)> = skills
451            .iter()
452            .map(|skill| {
453                let similarity = cosine_similarity(query_embedding, &skill.embedding);
454                (skill.clone(), similarity)
455            })
456            .collect();
457
458        // Sort by similarity descending
459        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
460        results.truncate(limit);
461
462        Ok(results)
463    }
464
465    /// Updates skill usage statistics
466    pub async fn skill_update_stats(
467        &self,
468        skill_id: &str,
469        success: bool,
470    ) -> Result<(), AgentDBError> {
471        let mut skills = self.skills.write().await;
472        let skill = skills
473            .iter_mut()
474            .find(|s| s.id.as_ref() == Some(&skill_id.to_string()))
475            .ok_or_else(|| AgentDBError::NotFound(format!("Skill {} not found", skill_id)))?;
476
477        skill.usage_count += 1;
478        let new_successes = (skill.success_rate * (skill.usage_count - 1) as f64)
479            + if success { 1.0 } else { 0.0 };
480        skill.success_rate = new_successes / skill.usage_count as f64;
481
482        Ok(())
483    }
484
485    /// Gets a skill by ID
486    pub async fn skill_get(&self, skill_id: &str) -> Result<Skill, AgentDBError> {
487        let skills = self.skills.read().await;
488        skills
489            .iter()
490            .find(|s| s.id.as_ref() == Some(&skill_id.to_string()))
491            .cloned()
492            .ok_or_else(|| AgentDBError::NotFound(format!("Skill {} not found", skill_id)))
493    }
494
495    // ==================== Utility Methods ====================
496
497    /// Returns statistics about the database
498    pub async fn stats(&self) -> AgentDBStats {
499        let vector_index = self.vector_index.read().await;
500        let episodes = self.episodes.read().await;
501        let edges = self.causal_edges.read().await;
502        let skills = self.skills.read().await;
503
504        AgentDBStats {
505            vector_count: vector_index.len(),
506            episode_count: episodes.len(),
507            causal_edge_count: edges.len(),
508            skill_count: skills.len(),
509        }
510    }
511
512    /// Clears all data from the database
513    pub async fn clear(&self) -> Result<(), AgentDBError> {
514        let hnsw_config = HnswConfig {
515            ef_construction: self.config.hnsw_ef,
516            ef_search: self.config.hnsw_ef,
517            m: self.config.hnsw_m,
518        };
519
520        let mut vector_index = self.vector_index.write().await;
521        let mut episodes = self.episodes.write().await;
522        let mut edges = self.causal_edges.write().await;
523        let mut skills = self.skills.write().await;
524
525        *vector_index = HnswIndex::new(hnsw_config);
526        episodes.clear();
527        edges.clear();
528        skills.clear();
529
530        Ok(())
531    }
532}
533
534/// Statistics from reflexion analysis
535#[derive(Debug, Clone, Default, Serialize, Deserialize)]
536pub struct ReflexionStats {
537    pub total_episodes: usize,
538    pub successful_episodes: usize,
539    pub success_rate: f64,
540    pub avg_reward: f64,
541    pub avg_latency_ms: u64,
542    pub avg_tokens: u64,
543}
544
545/// Overall database statistics
546#[derive(Debug, Clone, Serialize, Deserialize)]
547pub struct AgentDBStats {
548    pub vector_count: usize,
549    pub episode_count: usize,
550    pub causal_edge_count: usize,
551    pub skill_count: usize,
552}
553
554/// Computes cosine similarity between two vectors using SIMD optimization
555fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
556    use simsimd::SpatialSimilarity;
557
558    if a.len() != b.len() {
559        return 0.0;
560    }
561
562    // SIMD-optimized: SimSIMD returns DISTANCE, convert to SIMILARITY
563    // distance: 0 = identical, 1 = orthogonal, 2 = opposite
564    // similarity: 1 = identical, 0 = orthogonal, -1 = opposite
565    match f32::cosine(a, b) {
566        Some(distance) => 1.0 - distance,
567        None => 0.0,
568    }
569}
570
571/// Error types for AgentDB operations
572#[derive(Debug, thiserror::Error)]
573pub enum AgentDBError {
574    #[error("Storage error: {0}")]
575    StorageError(String),
576    #[error("Query error: {0}")]
577    QueryError(String),
578    #[error("Not found: {0}")]
579    NotFound(String),
580}
581
582#[cfg(test)]
583mod tests {
584    use super::*;
585
586    #[tokio::test]
587    async fn test_vector_operations() {
588        let db = AgentDB::new(AgentDBConfig {
589            dimension: 128,
590            ..Default::default()
591        })
592        .await
593        .unwrap();
594
595        // Create test embedding
596        let embedding: Embedding = (0..128).map(|i| (i as f32) / 128.0).collect();
597        let metadata = serde_json::json!({"test": "data"});
598
599        // Store vector
600        let id = db.vector_store(embedding.clone(), metadata.clone()).await.unwrap();
601
602        // Retrieve vector
603        let (retrieved_emb, retrieved_meta) = db.vector_get(&id).await.unwrap();
604        assert_eq!(retrieved_emb.len(), 128);
605        assert_eq!(retrieved_meta, metadata);
606
607        // Search for similar vectors
608        let results = db.vector_search(&embedding, 1).await.unwrap();
609        assert_eq!(results.len(), 1);
610        assert!(results[0].similarity > 0.99);
611    }
612
613    #[tokio::test]
614    async fn test_reflexion_operations() {
615        let db = AgentDB::new(AgentDBConfig::default()).await.unwrap();
616
617        let episode = ReflexionEpisode {
618            id: None,
619            session_id: "session-1".to_string(),
620            task: "solve math problem".to_string(),
621            input: serde_json::json!({"problem": "2+2"}),
622            output: serde_json::json!({"answer": 4}),
623            reward: 1.0,
624            success: true,
625            critique: "Correct answer".to_string(),
626            latency_ms: 100,
627            tokens: 50,
628            timestamp: Utc::now(),
629            embedding: None,
630        };
631
632        let id = db.reflexion_store(episode).await.unwrap();
633        assert!(!id.is_empty());
634
635        let retrieved = db.reflexion_retrieve("math", 10).await.unwrap();
636        assert_eq!(retrieved.len(), 1);
637        assert_eq!(retrieved[0].task, "solve math problem");
638    }
639
640    #[tokio::test]
641    async fn test_causal_operations() {
642        let db = AgentDB::new(AgentDBConfig::default()).await.unwrap();
643
644        let edge = CausalEdge {
645            cause: "use_cache".to_string(),
646            effect: "faster_response".to_string(),
647            uplift: 0.5,
648            confidence: 0.95,
649            sample_size: 100,
650            first_observed: Utc::now(),
651            last_observed: Utc::now(),
652        };
653
654        db.causal_add_edge(edge).await.unwrap();
655
656        let effects = db.causal_query_effects("use_cache").await.unwrap();
657        assert_eq!(effects.len(), 1);
658        assert_eq!(effects[0].effect, "faster_response");
659    }
660
661    #[tokio::test]
662    async fn test_skill_operations() {
663        let db = AgentDB::new(AgentDBConfig {
664            dimension: 64,
665            ..Default::default()
666        })
667        .await
668        .unwrap();
669
670        let embedding: Embedding = (0..64).map(|i| (i as f32) / 64.0).collect();
671        let skill = Skill {
672            id: None,
673            name: "code_generation".to_string(),
674            description: "Generate Python code from natural language".to_string(),
675            embedding,
676            usage_count: 0,
677            success_rate: 0.0,
678            created_at: Utc::now(),
679        };
680
681        let id = db.skill_create(skill).await.unwrap();
682        assert!(!id.is_empty());
683
684        let results = db.skill_search("code", 10).await.unwrap();
685        assert_eq!(results.len(), 1);
686        assert_eq!(results[0].name, "code_generation");
687    }
688
689    #[test]
690    fn test_cosine_similarity() {
691        let a = vec![1.0, 0.0, 0.0];
692        let b = vec![1.0, 0.0, 0.0];
693        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
694
695        let c = vec![1.0, 0.0, 0.0];
696        let d = vec![0.0, 1.0, 0.0];
697        assert!(cosine_similarity(&c, &d).abs() < 0.001);
698
699        let e = vec![1.0, 1.0, 0.0];
700        let f = vec![1.0, 1.0, 0.0];
701        assert!((cosine_similarity(&e, &f) - 1.0).abs() < 0.001);
702    }
703
704    #[tokio::test]
705    async fn test_hnsw_vector_operations() {
706        let db = AgentDB::new(AgentDBConfig {
707            dimension: 128,
708            ..Default::default()
709        })
710        .await
711        .unwrap();
712
713        // Store test vectors
714        let emb1: Embedding = (0..128).map(|i| i as f32 / 128.0).collect();
715        let emb2: Embedding = (0..128).map(|i| (128 - i) as f32 / 128.0).collect();
716
717        let id1 = db.vector_store(emb1.clone(), serde_json::json!({"name": "v1"})).await.unwrap();
718        let id2 = db.vector_store(emb2.clone(), serde_json::json!({"name": "v2"})).await.unwrap();
719
720        // Retrieve specific vector
721        let (retrieved, meta) = db.vector_get(&id2).await.unwrap();
722        assert_eq!(retrieved.len(), 128);
723        assert_eq!(meta["name"], "v2");
724
725        // Delete a vector
726        db.vector_delete(&id2).await.unwrap();
727        assert!(db.vector_get(&id2).await.is_err());
728
729        // Stats should reflect deletion
730        let stats = db.stats().await;
731        assert_eq!(stats.vector_count, 1);
732    }
733
734    #[tokio::test]
735    async fn test_hnsw_large_dataset() {
736        let db = AgentDB::new(AgentDBConfig {
737            dimension: 64,
738            hnsw_m: 16,
739            hnsw_ef: 100,
740            ..Default::default()
741        })
742        .await
743        .unwrap();
744
745        // Insert 100 vectors
746        for i in 0..100 {
747            let embedding: Embedding = (0..64).map(|j| ((i * j) as f32) / 1000.0).collect();
748            db.vector_store(embedding, serde_json::json!({"index": i})).await.unwrap();
749        }
750
751        // Search for similar to vector 50
752        let query: Embedding = (0..64).map(|j| ((50 * j) as f32) / 1000.0).collect();
753        let results = db.vector_search(&query, 10).await.unwrap();
754
755        // HNSW is approximate - just verify we get meaningful results
756        assert!(!results.is_empty());
757        assert!(results.len() <= 10);
758
759        // Results should have reasonable similarity
760        assert!(results[0].similarity > 0.5, "Top result should have >50% similarity");
761
762        // Verify stats
763        let stats = db.stats().await;
764        assert_eq!(stats.vector_count, 100);
765    }
766
767    #[tokio::test]
768    async fn test_hnsw_empty_search() {
769        let db = AgentDB::new(AgentDBConfig {
770            dimension: 32,
771            ..Default::default()
772        })
773        .await
774        .unwrap();
775
776        let query: Embedding = vec![0.1; 32];
777        let results = db.vector_search(&query, 10).await.unwrap();
778        assert!(results.is_empty());
779    }
780
781    #[tokio::test]
782    async fn test_hnsw_stats() {
783        let db = AgentDB::new(AgentDBConfig {
784            dimension: 16,
785            ..Default::default()
786        })
787        .await
788        .unwrap();
789
790        let stats = db.stats().await;
791        assert_eq!(stats.vector_count, 0);
792
793        for i in 0..5 {
794            let emb: Embedding = vec![i as f32; 16];
795            db.vector_store(emb, serde_json::json!({})).await.unwrap();
796        }
797
798        let stats = db.stats().await;
799        assert_eq!(stats.vector_count, 5);
800
801        db.clear().await.unwrap();
802        let stats = db.stats().await;
803        assert_eq!(stats.vector_count, 0);
804    }
805
806    #[tokio::test]
807    async fn test_hnsw_dimension_validation() {
808        let db = AgentDB::new(AgentDBConfig {
809            dimension: 64,
810            ..Default::default()
811        })
812        .await
813        .unwrap();
814
815        // Try to store wrong dimension
816        let wrong_emb: Embedding = vec![1.0; 32];
817        let result = db.vector_store(wrong_emb, serde_json::json!({})).await;
818        assert!(result.is_err());
819
820        // Try to search with wrong dimension
821        let wrong_query: Embedding = vec![1.0; 32];
822        let result = db.vector_search(&wrong_query, 5).await;
823        assert!(result.is_err());
824    }
825}