1use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use thiserror::Error;
14
/// Configuration for a [`RuVectorIndex`].
///
/// Only `dimension`, `gnn_enabled`, `gnn_learning_rate` and `graph_enabled`
/// are read by the code in this file; the remaining fields are stored but not
/// consulted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuVectorConfig {
    /// Required length of every inserted embedding and every search query.
    pub dimension: usize,
    /// HNSW parameter (presumably the max connections per node — TODO confirm).
    pub hnsw_m: usize,
    /// HNSW parameter (presumably the candidate-list size at build time — TODO confirm).
    pub ef_construction: usize,
    /// HNSW parameter (presumably the candidate-list size at query time — TODO confirm).
    pub ef_search: usize,
    /// When `true`, the index creates GNN layers and attaches a GNN score to results.
    pub gnn_enabled: bool,
    /// Learning rate handed to every GNN layer the index creates.
    pub gnn_learning_rate: f64,
    /// When `false`, `graph_query` and `add_relationship` return a `GraphError`.
    pub graph_enabled: bool,
    /// Requested SIMD level; not consulted by the code in this file.
    pub simd_level: SimdLevel,
}
35
/// SIMD instruction-set selection stored in [`RuVectorConfig`].
///
/// NOTE(review): nothing in this file branches on the chosen level; the
/// per-variant meanings below are inferred from the names — confirm against
/// whatever consumes this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SimdLevel {
    /// No SIMD (presumably a scalar code path).
    None,
    /// Presumably SSE4.2.
    SSE42,
    /// Presumably AVX2.
    AVX2,
    /// Presumably AVX-512.
    AVX512,
    /// Presumably "pick automatically"; this is the `Default` value.
    Auto,
}
50
51impl Default for SimdLevel {
52 fn default() -> Self {
53 Self::Auto
54 }
55}
56
57impl Default for RuVectorConfig {
58 fn default() -> Self {
59 Self {
60 dimension: 1024,
61 hnsw_m: 32,
62 ef_construction: 200,
63 ef_search: 100,
64 gnn_enabled: true,
65 gnn_learning_rate: 0.001,
66 graph_enabled: true,
67 simd_level: SimdLevel::Auto,
68 }
69 }
70}
71
/// Errors returned by the index API in this file.
///
/// Of these, only `QueryError`, `GraphError` and `DimensionMismatch` are
/// constructed by the code visible here.
#[derive(Debug, Error)]
pub enum RuVectorError {
    /// Index-level failure with a free-form message.
    #[error("Index error: {0}")]
    IndexError(String),

    /// A query could not be executed (e.g. an unsupported graph query string).
    #[error("Query error: {0}")]
    QueryError(String),

    /// GNN-related failure with a free-form message.
    #[error("GNN error: {0}")]
    GNNError(String),

    /// Graph feature failure, e.g. the graph feature is disabled in the config.
    #[error("Graph error: {0}")]
    GraphError(String),

    /// A vector's length differs from the configured `dimension`.
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },

    /// A requested item does not exist.
    #[error("Not found: {0}")]
    NotFound(String),
}
93
/// Result alias whose error type is fixed to [`RuVectorError`].
pub type Result<T> = std::result::Result<T, RuVectorError>;
95
/// One stored vector together with its metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
    /// Key under which the entry is stored in the index.
    pub id: String,
    /// The vector itself; its length must equal the index's configured `dimension`.
    pub embedding: Vec<f32>,
    /// Arbitrary JSON payload; copied verbatim into search results.
    pub metadata: serde_json::Value,
    /// Ids of entries this one points to; extended by `RuVectorIndex::add_relationship`.
    pub connections: Vec<String>,
    /// Caller-supplied timestamp; not interpreted by this file (unit unknown — TODO confirm).
    pub timestamp: u64,
}
110
/// One hit returned by `RuVectorIndex::search` / `search_gnn`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuVectorResult {
    /// Id of the matching entry.
    pub id: String,
    /// Cosine similarity to the query; `search_gnn` overwrites it with a
    /// 0.7 / 0.3 blend of cosine similarity and `gnn_score`.
    pub similarity: f64,
    /// `1.0 - cosine similarity`, as computed by `search`; `search_gnn` leaves it untouched.
    pub distance: f64,
    /// Copy of the entry's metadata.
    pub metadata: serde_json::Value,
    /// GNN-derived score in `[0, 1]`; `None` when the GNN is disabled or has no layers.
    pub gnn_score: Option<f64>,
}
125
/// One match returned by `RuVectorIndex::graph_query`.
///
/// As produced in this file, each result describes exactly one stored edge.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphQueryResult {
    /// Node ids involved in the match (`[source, target]` of the edge).
    pub nodes: Vec<String>,
    /// Path through the graph, when one applies (`Some([source, target])` here).
    pub path: Option<Vec<String>>,
    /// Relationship labels along the match (the single edge's label here).
    pub relationships: Vec<String>,
    /// Weight of the matched edge.
    pub weight: f64,
}
138
/// A dense layer with ReLU activation and a simple online update rule.
///
/// Despite the name, the layer does no neighbour aggregation itself: it is an
/// affine map `W * x + b` followed by `max(0, ·)`.
#[derive(Debug, Clone)]
pub struct GNNLayer {
    /// Number of input features each weight row holds.
    input_dim: usize,
    /// Number of output units (rows of `weights`, entries of `bias`).
    output_dim: usize,
    /// `output_dim` rows of `input_dim` weights each.
    weights: Vec<Vec<f64>>,
    /// One bias per output unit.
    bias: Vec<f64>,
    /// Step size used by [`GNNLayer::update`].
    learning_rate: f64,
    /// Number of `update` calls performed so far.
    update_count: u64,
}

impl GNNLayer {
    /// Creates a layer with deterministic, Xavier-scaled weights and zero biases.
    ///
    /// Weights are `sin(0.1 * flat_index) * sqrt(2 / (input_dim + output_dim))`
    /// where `flat_index = row * input_dim + col`. Seeding from the flat index
    /// (rather than the column alone) gives every output unit a different
    /// weight row; identical rows would make all units compute the same value
    /// and receive the same update whenever their targets are equal,
    /// collapsing the layer to one effective unit.
    pub fn new(input_dim: usize, output_dim: usize, learning_rate: f64) -> Self {
        let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();

        let weights: Vec<Vec<f64>> = (0..output_dim)
            .map(|row| {
                (0..input_dim)
                    .map(|col| {
                        // Computed in f64 so very large layers cannot overflow `usize`.
                        let flat = row as f64 * input_dim as f64 + col as f64;
                        (flat * 0.1).sin() * scale
                    })
                    .collect()
            })
            .collect();

        Self {
            input_dim,
            output_dim,
            weights,
            bias: vec![0.0; output_dim],
            learning_rate,
            update_count: 0,
        }
    }

    /// Computes `relu(W * input + b)` and returns one value per output unit.
    ///
    /// If `input` is shorter than `input_dim` the missing features are treated
    /// as absent (they contribute nothing); extra features are ignored.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, &b)| {
                // Start from the bias and accumulate left to right; `zip`
                // stops at the shorter of the weight row and the input.
                let pre_activation = row
                    .iter()
                    .zip(input)
                    .fold(b, |acc, (&w, &x)| acc + w * x);
                pre_activation.max(0.0)
            })
            .collect()
    }

    /// Performs one delta-rule step moving `forward(input)` toward `target`.
    ///
    /// For every output unit that has a target value, each weight changes by
    /// `learning_rate * (target - output) * input` and the bias by
    /// `learning_rate * (target - output)`. The ReLU derivative is not
    /// applied. Units beyond `target.len()` are left unchanged. The update
    /// counter is incremented once per call.
    pub fn update(&mut self, input: &[f64], target: &[f64]) {
        let output = self.forward(input);
        let lr = self.learning_rate;

        let units = self
            .weights
            .iter_mut()
            .zip(self.bias.iter_mut())
            .zip(output.iter())
            .zip(target.iter());

        for (((row, b), &out), &want) in units {
            let step = lr * (want - out);
            for (w, &x) in row.iter_mut().zip(input) {
                *w += step * x;
            }
            *b += step;
        }

        self.update_count += 1;
    }

    /// Number of times [`GNNLayer::update`] has been called.
    pub fn update_count(&self) -> u64 {
        self.update_count
    }

    /// Number of input features the layer was created with.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Number of output units the layer was created with.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }
}
231
/// In-memory vector index with optional GNN re-ranking and a simple edge store.
pub struct RuVectorIndex {
    /// Settings the index was created with.
    config: RuVectorConfig,
    /// All stored entries, keyed by entry id.
    vectors: HashMap<String, VectorEntry>,
    /// HNSW adjacency data; exposed read-only via `hnsw_graph()`.
    /// NOTE(review): nothing in this file ever writes to it.
    hnsw_graph: Vec<Vec<(String, f32)>>,
    /// GNN layers used for scoring and learning; empty when the GNN is disabled.
    gnn_layers: Vec<GNNLayer>,
    /// Directed edges keyed by `(source, target)`.
    graph_edges: HashMap<(String, String), GraphEdge>,
    /// Entry counter reported by `len()` / `is_empty()`.
    count: usize,
}
247
/// A directed, labelled edge stored by `RuVectorIndex::add_relationship`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Id of the node the edge starts at.
    pub source: String,
    /// Id of the node the edge points to.
    pub target: String,
    /// Relationship label (matched case-insensitively by graph queries).
    pub relationship: String,
    /// Edge weight, reported back in graph query results.
    pub weight: f64,
    /// Arbitrary JSON properties attached to the edge.
    pub properties: serde_json::Value,
}
262
263impl RuVectorIndex {
264 pub fn new(config: RuVectorConfig) -> Self {
266 let mut gnn_layers = Vec::new();
267 if config.gnn_enabled {
268 gnn_layers.push(GNNLayer::new(config.dimension, 128, config.gnn_learning_rate));
270 gnn_layers.push(GNNLayer::new(128, 64, config.gnn_learning_rate));
271 }
272
273 Self {
274 config,
275 vectors: HashMap::new(),
276 hnsw_graph: Vec::new(),
277 gnn_layers,
278 graph_edges: HashMap::new(),
279 count: 0,
280 }
281 }
282
283 pub fn insert(&mut self, entry: VectorEntry) -> Result<()> {
285 if entry.embedding.len() != self.config.dimension {
286 return Err(RuVectorError::DimensionMismatch {
287 expected: self.config.dimension,
288 got: entry.embedding.len(),
289 });
290 }
291
292 let id = entry.id.clone();
293 self.vectors.insert(id.clone(), entry);
294 self.count += 1;
295
296 self.update_hnsw_connections(&id)?;
298
299 Ok(())
300 }
301
302 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<RuVectorResult>> {
304 if query.len() != self.config.dimension {
305 return Err(RuVectorError::DimensionMismatch {
306 expected: self.config.dimension,
307 got: query.len(),
308 });
309 }
310
311 let mut results: Vec<RuVectorResult> = self
313 .vectors
314 .values()
315 .map(|v| {
316 let similarity = self.cosine_similarity(query, &v.embedding);
317 let distance = 1.0 - similarity;
318
319 let gnn_score = if self.config.gnn_enabled && !self.gnn_layers.is_empty() {
321 Some(self.compute_gnn_score(&v.embedding))
322 } else {
323 None
324 };
325
326 RuVectorResult {
327 id: v.id.clone(),
328 similarity,
329 distance,
330 metadata: v.metadata.clone(),
331 gnn_score,
332 }
333 })
334 .collect();
335
336 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
338 results.truncate(k);
339
340 Ok(results)
341 }
342
343 pub fn search_gnn(&self, query: &[f32], k: usize) -> Result<Vec<RuVectorResult>> {
345 if !self.config.gnn_enabled {
346 return self.search(query, k);
347 }
348
349 let mut results = self.search(query, k * 2)?; for result in &mut results {
353 if let Some(gnn_score) = result.gnn_score {
354 result.similarity = 0.7 * result.similarity + 0.3 * gnn_score;
356 }
357 }
358
359 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
360 results.truncate(k);
361
362 Ok(results)
363 }
364
365 pub fn graph_query(&self, query: &str) -> Result<Vec<GraphQueryResult>> {
367 if !self.config.graph_enabled {
368 return Err(RuVectorError::GraphError(
369 "Graph queries not enabled".to_string(),
370 ));
371 }
372
373 let results = self.execute_graph_query(query)?;
376 Ok(results)
377 }
378
379 pub fn add_relationship(
381 &mut self,
382 source: &str,
383 target: &str,
384 relationship: &str,
385 weight: f64,
386 properties: serde_json::Value,
387 ) -> Result<()> {
388 if !self.config.graph_enabled {
389 return Err(RuVectorError::GraphError(
390 "Graph queries not enabled".to_string(),
391 ));
392 }
393
394 let edge = GraphEdge {
395 source: source.to_string(),
396 target: target.to_string(),
397 relationship: relationship.to_string(),
398 weight,
399 properties,
400 };
401
402 self.graph_edges
403 .insert((source.to_string(), target.to_string()), edge);
404
405 if let Some(v) = self.vectors.get_mut(source) {
407 if !v.connections.contains(&target.to_string()) {
408 v.connections.push(target.to_string());
409 }
410 }
411
412 Ok(())
413 }
414
415 pub fn learn(&mut self, query: &[f32], relevant_ids: &[String]) -> Result<()> {
417 if !self.config.gnn_enabled || self.gnn_layers.is_empty() {
418 return Ok(());
419 }
420
421 let query_f64: Vec<f64> = query.iter().map(|&x| x as f64).collect();
422
423 for id in relevant_ids {
425 if let Some(entry) = self.vectors.get(id) {
426 let input: Vec<f64> = entry.embedding.iter().map(|&x| x as f64).collect();
427
428 let mut current = input.clone();
430 for layer in &mut self.gnn_layers {
431 let target = layer.forward(&query_f64);
432 layer.update(¤t, &target);
433 current = layer.forward(¤t);
434 }
435 }
436 }
437
438 Ok(())
439 }
440
441 pub fn get(&self, id: &str) -> Option<&VectorEntry> {
443 self.vectors.get(id)
444 }
445
446 pub fn remove(&mut self, id: &str) -> bool {
448 if self.vectors.remove(id).is_some() {
449 self.count -= 1;
450 true
451 } else {
452 false
453 }
454 }
455
456 pub fn len(&self) -> usize {
458 self.count
459 }
460
461 pub fn is_empty(&self) -> bool {
463 self.count == 0
464 }
465
466 pub fn gnn_stats(&self) -> GNNStats {
468 let total_updates: u64 = self.gnn_layers.iter().map(|l| l.update_count()).sum();
469 GNNStats {
470 layers: self.gnn_layers.len(),
471 total_updates,
472 enabled: self.config.gnn_enabled,
473 }
474 }
475
476 pub fn hnsw_graph(&self) -> &Vec<Vec<(String, f32)>> {
478 &self.hnsw_graph
479 }
480
481 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f64 {
484 if a.len() != b.len() {
485 return 0.0;
486 }
487
488 use simsimd::SpatialSimilarity;
490 match f32::cosine(a, b) {
491 Some(distance) => 1.0 - distance,
492 None => {
493 let mut dot = 0.0;
495 let mut norm_a = 0.0;
496 let mut norm_b = 0.0;
497 for (&x, &y) in a.iter().zip(b.iter()) {
498 dot += (x * y) as f64;
499 norm_a += (x * x) as f64;
500 norm_b += (y * y) as f64;
501 }
502 let denom = (norm_a * norm_b).sqrt();
503 if denom > 0.0 {
504 dot / denom
505 } else {
506 0.0
507 }
508 }
509 }
510 }
511
512 fn compute_gnn_score(&self, embedding: &[f32]) -> f64 {
513 if self.gnn_layers.is_empty() {
514 return 0.5;
515 }
516
517 let mut current: Vec<f64> = embedding.iter().map(|&x| x as f64).collect();
518 for layer in &self.gnn_layers {
519 current = layer.forward(¤t);
520 }
521
522 let score = current.iter().sum::<f64>() / current.len().max(1) as f64;
524 score.clamp(0.0, 1.0)
525 }
526
527 fn update_hnsw_connections(&mut self, _id: &str) -> Result<()> {
528 Ok(())
531 }
532
533 fn execute_graph_query(&self, query: &str) -> Result<Vec<GraphQueryResult>> {
534 let query_lower = query.to_lowercase();
536
537 if query_lower.contains("match") {
538 let results: Vec<GraphQueryResult> = self
540 .graph_edges
541 .values()
542 .filter(|e| {
543 query_lower.contains(&e.source.to_lowercase())
544 || query_lower.contains(&e.relationship.to_lowercase())
545 })
546 .map(|e| GraphQueryResult {
547 nodes: vec![e.source.clone(), e.target.clone()],
548 path: Some(vec![e.source.clone(), e.target.clone()]),
549 relationships: vec![e.relationship.clone()],
550 weight: e.weight,
551 })
552 .collect();
553
554 Ok(results)
555 } else {
556 Err(RuVectorError::QueryError(format!(
557 "Unsupported query: {}",
558 query
559 )))
560 }
561 }
562}
563
564impl Default for RuVectorIndex {
565 fn default() -> Self {
566 Self::new(RuVectorConfig::default())
567 }
568}
569
/// Snapshot of the index's GNN state, as returned by `RuVectorIndex::gnn_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNStats {
    /// Number of GNN layers the index holds.
    pub layers: usize,
    /// Sum of the per-layer update counters.
    pub total_updates: u64,
    /// Whether the GNN is enabled in the index configuration.
    pub enabled: bool,
}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry with no connections and a zero timestamp.
    fn make_entry(id: &str, embedding: Vec<f32>, metadata: serde_json::Value) -> VectorEntry {
        VectorEntry {
            id: id.to_string(),
            embedding,
            metadata,
            connections: vec![],
            timestamp: 0,
        }
    }

    // The default configuration uses 1024 dimensions with the GNN enabled.
    #[test]
    fn test_ruvector_config() {
        let defaults = RuVectorConfig::default();
        assert_eq!(defaults.dimension, 1024);
        assert!(defaults.gnn_enabled);
    }

    // A freshly created index holds no entries.
    #[test]
    fn test_ruvector_index_creation() {
        let index = RuVectorIndex::new(RuVectorConfig {
            dimension: 64,
            ..Default::default()
        });
        assert_eq!(index.len(), 0);
    }

    // An inserted vector is its own nearest neighbour.
    #[test]
    fn test_insert_and_search() {
        let mut index = RuVectorIndex::new(RuVectorConfig {
            dimension: 8,
            gnn_enabled: false,
            ..Default::default()
        });

        // Unit vector along the first axis.
        let mut unit_x = vec![0.0_f32; 8];
        unit_x[0] = 1.0;

        index
            .insert(make_entry(
                "test1",
                unit_x.clone(),
                serde_json::json!({"name": "test"}),
            ))
            .unwrap();

        let hits = index.search(&unit_x, 1).unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "test1");
        assert!(hits[0].similarity > 0.99);
    }

    // A layer produces one output per unit and counts its updates.
    #[test]
    fn test_gnn_layer() {
        let mut layer = GNNLayer::new(8, 4, 0.01);

        let ones = vec![1.0; 8];
        assert_eq!(layer.forward(&ones).len(), 4);

        let halves = vec![0.5; 4];
        layer.update(&ones, &halves);
        assert_eq!(layer.update_count(), 1);
    }

    // An added relationship is found by a MATCH query naming its label.
    #[test]
    fn test_graph_relationship() {
        let mut index = RuVectorIndex::new(RuVectorConfig {
            dimension: 4,
            graph_enabled: true,
            ..Default::default()
        });

        for i in 0..3 {
            let node = make_entry(
                &format!("node{}", i),
                vec![i as f32; 4],
                serde_json::json!({}),
            );
            index.insert(node).unwrap();
        }

        index
            .add_relationship("node0", "node1", "RELATES_TO", 1.0, serde_json::json!({}))
            .unwrap();

        let matches = index.graph_query("MATCH (a)-[r:RELATES_TO]->(b)").unwrap();
        assert!(!matches.is_empty());
    }

    // Learning from a relevant id performs at least one layer update.
    #[test]
    fn test_gnn_learning() {
        let mut index = RuVectorIndex::new(RuVectorConfig {
            dimension: 8,
            gnn_enabled: true,
            ..Default::default()
        });

        index
            .insert(make_entry("learn1", vec![0.5; 8], serde_json::json!({})))
            .unwrap();

        let query = vec![0.5; 8];
        index.learn(&query, &["learn1".to_string()]).unwrap();

        assert!(index.gnn_stats().total_updates > 0);
    }
}