ipfrs_semantic/utils.rs
//! Utility functions and helpers for common semantic search workflows
//!
//! This module provides convenience functions that combine multiple features
//! and simplify common patterns in semantic search applications.
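//!
//! # Example
//!
//! A minimal end-to-end sketch combining the helpers in this module
//! (assumes 768-dimensional embeddings; the CID is illustrative only):
//!
//! ```
//! use ipfrs_semantic::{SemanticRouter, utils::{index_with_quality_check, normalize_vectors, validate_embeddings}};
//! use ipfrs_core::Cid;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let router = SemanticRouter::with_defaults()?;
//!
//! let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
//! let mut embeddings = vec![vec![0.3; 768]];
//!
//! // Normalize so cosine comparisons reduce to dot products.
//! normalize_vectors(&mut embeddings);
//!
//! // Inspect per-embedding quality reports before indexing.
//! let reports = validate_embeddings(&embeddings);
//! println!("{} of {} embeddings are valid", reports.iter().filter(|r| r.is_valid).count(), reports.len());
//!
//! // Index with a quality gate; low-quality embeddings are reported as failures.
//! let items = vec![(cid, embeddings[0].clone())];
//! let result = index_with_quality_check(&router, &items, 0.5)?;
//! println!("Indexed {} of {}", result.indexed, result.stats.total);
//! # Ok(())
//! # }
//! ```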

use crate::{
    analyze_quality, diagnose_index, HybridIndex, Metadata, SemanticRouter, VectorIndex,
    VectorQuality,
};
use ipfrs_core::{Cid, Result};
use std::collections::HashMap;

/// Statistics for a batch embedding operation
#[derive(Debug, Clone)]
pub struct BatchEmbeddingStats {
    /// Total number of embeddings processed
    pub total: usize,
    /// Number of embeddings that passed the quality check and were indexed
    pub valid: usize,
    /// Number of embeddings that failed the quality check or could not be indexed
    pub invalid: usize,
    /// Average quality score
    pub avg_quality: f32,
    /// Minimum quality score
    pub min_quality: f32,
    /// Maximum quality score
    pub max_quality: f32,
}

/// Result of a batch index operation with statistics
#[derive(Debug)]
pub struct BatchIndexResult {
    /// Number of items successfully indexed
    pub indexed: usize,
    /// Number of items that failed indexing
    pub failed: usize,
    /// CIDs that failed (with error messages)
    pub failures: Vec<(Cid, String)>,
    /// Statistics about the batch
    pub stats: BatchEmbeddingStats,
}

/// Validate and index embeddings in a single operation with quality checking
///
/// This helper combines quality analysis with indexing, automatically
/// filtering out low-quality embeddings. An embedding is indexed only if it
/// passes the basic validity check and its quality score meets `min_quality`.
///
/// # Arguments
/// * `router` - The semantic router to index into
/// * `items` - Vector of (CID, embedding) pairs to index
/// * `min_quality` - Minimum quality score (0.0-1.0) required for indexing
///
/// # Returns
/// Statistics about the batch operation including success/failure counts
///
/// # Example
/// ```
/// use ipfrs_semantic::{SemanticRouter, utils::index_with_quality_check};
/// use ipfrs_core::Cid;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let router = SemanticRouter::with_defaults()?;
///
/// let items = vec![
///     ("bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse::<Cid>()?, vec![0.1; 768]),
///     ("bafybeihpjhkeuiq3k6nqa3fkgeigeri7iebtrsuyuey5y6vy36n345xmbi".parse::<Cid>()?, vec![0.2; 768]),
/// ];
///
/// let result = index_with_quality_check(&router, &items, 0.5)?;
/// println!("Indexed: {}, Failed: {}", result.indexed, result.failed);
/// println!("Average quality: {:.2}", result.stats.avg_quality);
/// # Ok(())
/// # }
/// ```
pub fn index_with_quality_check(
    router: &SemanticRouter,
    items: &[(Cid, Vec<f32>)],
    min_quality: f32,
) -> Result<BatchIndexResult> {
    let mut indexed = 0;
    let mut failed = 0;
    let mut failures = Vec::new();
    let mut qualities = Vec::new();

    for (cid, embedding) in items {
        let quality = analyze_quality(embedding);
        qualities.push(quality.quality_score);

        if quality.quality_score >= min_quality && quality.is_valid {
            match router.add(cid, embedding) {
                Ok(_) => indexed += 1,
                Err(e) => {
                    failed += 1;
                    failures.push((*cid, e.to_string()));
                }
            }
        } else {
            failed += 1;
            failures.push((
                *cid,
                format!("Quality check failed: score={:.2}", quality.quality_score),
            ));
        }
    }

    let stats = if qualities.is_empty() {
        BatchEmbeddingStats {
            total: items.len(),
            valid: 0,
            invalid: items.len(),
            avg_quality: 0.0,
            min_quality: 0.0,
            max_quality: 0.0,
        }
    } else {
        let avg = qualities.iter().sum::<f32>() / qualities.len() as f32;
        let min = qualities.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max = qualities.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        BatchEmbeddingStats {
            total: items.len(),
            valid: indexed,
            invalid: failed,
            avg_quality: avg,
            min_quality: min,
            max_quality: max,
        }
    };

    Ok(BatchIndexResult {
        indexed,
        failed,
        failures,
        stats,
    })
}

/// Validate a batch of embeddings and return quality reports
///
/// # Example
/// ```
/// use ipfrs_semantic::utils::validate_embeddings;
///
/// let embeddings = vec![
///     vec![0.1, 0.2, 0.3],
///     vec![0.4, 0.5, 0.6],
///     vec![f32::NAN, 0.1, 0.2], // Invalid (contains NaN)
/// ];
///
/// let reports = validate_embeddings(&embeddings);
/// assert_eq!(reports.len(), 3);
/// assert!(reports[0].is_valid);
/// assert!(reports[1].is_valid);
/// assert!(!reports[2].is_valid); // Contains NaN
/// ```
pub fn validate_embeddings(embeddings: &[Vec<f32>]) -> Vec<VectorQuality> {
    embeddings.iter().map(|e| analyze_quality(e)).collect()
}

/// Create a hybrid index with metadata extracted from a CID mapping
///
/// This is a convenience function for creating a hybrid index when you have
/// a mapping of CIDs to both embeddings and metadata.
///
/// # Example
/// ```
/// use ipfrs_semantic::{utils::create_hybrid_index_from_map, Metadata, MetadataValue};
/// use ipfrs_core::Cid;
/// use std::collections::HashMap;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut data = HashMap::new();
///
/// let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
/// let embedding = vec![0.5; 768];
/// let mut metadata = Metadata::new();
/// metadata.set("type", MetadataValue::String("document".to_string()));
///
/// data.insert(cid, (embedding, Some(metadata)));
///
/// let index = create_hybrid_index_from_map(768, data)?;
/// # Ok(())
/// # }
/// ```
pub fn create_hybrid_index_from_map(
    dimension: usize,
    data: HashMap<Cid, (Vec<f32>, Option<Metadata>)>,
) -> Result<HybridIndex> {
    let index = HybridIndex::new(crate::HybridConfig {
        dimension,
        ..Default::default()
    })?;

    for (cid, (embedding, metadata)) in data {
        index.insert(&cid, &embedding, metadata)?;
    }

    Ok(index)
}

/// Health check result for a vector index
#[derive(Debug)]
pub struct HealthCheckResult {
    /// Is the index healthy?
    pub is_healthy: bool,
    /// Number of vectors in the index
    pub vector_count: usize,
    /// Estimated memory usage in bytes
    pub memory_bytes: usize,
    /// Issues detected (if any)
    pub issues: Vec<String>,
    /// Recommendations for optimization
    pub recommendations: Vec<String>,
}

/// Perform a comprehensive health check on a vector index
///
/// # Example
/// ```
/// use ipfrs_semantic::{VectorIndex, utils::health_check};
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let index = VectorIndex::with_defaults(768)?;
/// let health = health_check(&index);
///
/// if health.is_healthy {
///     println!("Index is healthy with {} vectors", health.vector_count);
/// } else {
///     println!("Issues found: {:?}", health.issues);
/// }
/// # Ok(())
/// # }
/// ```
pub fn health_check(index: &VectorIndex) -> HealthCheckResult {
    let report = diagnose_index(index);

    HealthCheckResult {
        is_healthy: matches!(report.status, crate::diagnostics::HealthStatus::Healthy),
        vector_count: report.size,
        memory_bytes: report.memory_usage,
        issues: report
            .issues
            .iter()
            .map(|i| i.description.clone())
            .collect(),
        recommendations: report.recommendations,
    }
}

/// Normalize a vector to unit length (L2 norm = 1)
///
/// This is useful for cosine similarity searches: the dot product of two
/// unit-length vectors equals their cosine similarity, so normalized vectors
/// can be compared with a plain dot product.
///
/// Zero vectors are left unchanged.
///
/// # Example
/// ```
/// use ipfrs_semantic::utils::normalize_vector;
///
/// let mut vec = vec![3.0, 4.0];
/// normalize_vector(&mut vec);
///
/// let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
/// assert!((norm - 1.0).abs() < 1e-6);
/// ```
pub fn normalize_vector(vector: &mut [f32]) {
    let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in vector.iter_mut() {
            *x /= norm;
        }
    }
}

/// Normalize a batch of vectors in place
///
/// # Example
/// ```
/// use ipfrs_semantic::utils::normalize_vectors;
///
/// let mut vectors = vec![
///     vec![3.0, 4.0],
///     vec![1.0, 0.0],
/// ];
///
/// normalize_vectors(&mut vectors);
///
/// for vec in &vectors {
///     let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
///     assert!((norm - 1.0).abs() < 1e-6);
/// }
/// ```
pub fn normalize_vectors(vectors: &mut [Vec<f32>]) {
    for vector in vectors.iter_mut() {
        normalize_vector(vector);
    }
}

/// Calculate the average embedding from a set of embeddings
///
/// Useful for creating centroid embeddings or aggregate representations.
///
/// Returns None if the input is empty or embeddings have inconsistent dimensions.
///
/// # Example
/// ```
/// use ipfrs_semantic::utils::average_embedding;
///
/// let embeddings = vec![
///     vec![1.0, 2.0, 3.0],
///     vec![2.0, 3.0, 4.0],
///     vec![3.0, 4.0, 5.0],
/// ];
///
/// let avg = average_embedding(&embeddings).unwrap();
/// assert_eq!(avg, vec![2.0, 3.0, 4.0]);
/// ```
pub fn average_embedding(embeddings: &[Vec<f32>]) -> Option<Vec<f32>> {
    if embeddings.is_empty() {
        return None;
    }

    let dim = embeddings[0].len();
    if embeddings.iter().any(|e| e.len() != dim) {
        return None;
    }

    let mut result = vec![0.0; dim];
    for embedding in embeddings {
        for (i, &val) in embedding.iter().enumerate() {
            result[i] += val;
        }
    }

    let count = embeddings.len() as f32;
    for val in result.iter_mut() {
        *val /= count;
    }

    Some(result)
}

/// Result of a batch deletion operation
#[derive(Debug, Clone)]
pub struct BatchDeletionResult {
    /// Number of CIDs successfully deleted
    pub deleted: usize,
    /// Number of CIDs not found in the index
    pub not_found: usize,
    /// Number of CIDs that failed to delete
    pub failed: usize,
    /// CIDs that were not found
    pub not_found_cids: Vec<Cid>,
    /// CIDs that failed deletion (with error messages)
    pub failures: Vec<(Cid, String)>,
}

/// Delete multiple CIDs from a vector index in batch
///
/// This function efficiently deletes multiple CIDs and provides detailed
/// statistics about the operation.
///
/// # Arguments
/// * `index` - The vector index to delete from
/// * `cids` - List of CIDs to delete
///
/// # Returns
/// Statistics about the deletion operation
///
/// # Example
/// ```
/// use ipfrs_semantic::{VectorIndex, utils::batch_delete};
/// use ipfrs_core::Cid;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut index = VectorIndex::with_defaults(768)?;
///
/// // Add some vectors
/// let cid1: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
/// let cid2: Cid = "bafybeihpjhkeuiq3k6nqa3fkgeigeri7iebtrsuyuey5y6vy36n345xmbi".parse()?;
/// index.insert(&cid1, &vec![0.1; 768])?;
/// index.insert(&cid2, &vec![0.2; 768])?;
///
/// // Delete them in batch
/// let result = batch_delete(&mut index, &[cid1, cid2])?;
/// assert_eq!(result.deleted, 2);
/// assert_eq!(result.not_found, 0);
/// # Ok(())
/// # }
/// ```
pub fn batch_delete(index: &mut VectorIndex, cids: &[Cid]) -> Result<BatchDeletionResult> {
    let mut deleted = 0;
    let mut not_found = 0;
    let mut failed = 0;
    let mut not_found_cids = Vec::new();
    let mut failures = Vec::new();

    for cid in cids {
        if !index.contains(cid) {
            not_found += 1;
            not_found_cids.push(*cid);
            continue;
        }

        match index.delete(cid) {
            Ok(_) => deleted += 1,
            Err(e) => {
                failed += 1;
                failures.push((*cid, e.to_string()));
            }
        }
    }

    Ok(BatchDeletionResult {
        deleted,
        not_found,
        failed,
        not_found_cids,
        failures,
    })
}

/// Calculate cosine similarity between two embeddings
///
/// Returns a value between -1.0 and 1.0, where:
/// - 1.0 means vectors point in the same direction (most similar)
/// - 0.0 means vectors are orthogonal (no similarity)
/// - -1.0 means vectors point in opposite directions (most dissimilar)
///
/// Returns None if the embeddings have different dimensions or if either is a zero vector.
///
/// # Example
/// ```
/// use ipfrs_semantic::utils::cosine_similarity;
///
/// let embedding1 = vec![1.0, 2.0, 3.0];
/// let embedding2 = vec![2.0, 4.0, 6.0]; // Parallel to embedding1
///
/// let similarity = cosine_similarity(&embedding1, &embedding2).unwrap();
/// assert!((similarity - 1.0).abs() < 1e-6); // Should be 1.0 (perfectly similar)
/// ```
pub fn cosine_similarity(embedding1: &[f32], embedding2: &[f32]) -> Option<f32> {
    if embedding1.len() != embedding2.len() {
        return None;
    }

    let dot_product: f32 = embedding1
        .iter()
        .zip(embedding2.iter())
        .map(|(a, b)| a * b)
        .sum();

    let norm1: f32 = embedding1.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm2: f32 = embedding2.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm1 == 0.0 || norm2 == 0.0 {
        return None;
    }

    Some(dot_product / (norm1 * norm2))
}

/// Calculate pairwise cosine similarities between a query and multiple embeddings
///
/// This is useful for finding the most similar embeddings from a set without
/// indexing them first.
///
/// # Arguments
/// * `query` - Query embedding
/// * `embeddings` - List of embeddings to compare against
///
/// # Returns
/// Vector of (index, similarity) pairs, sorted by similarity (descending).
/// Embeddings whose similarity to the query is undefined (mismatched
/// dimensions or zero vectors) are skipped.
///
/// # Example
/// ```
/// use ipfrs_semantic::utils::pairwise_similarities;
///
/// let query = vec![1.0, 0.0, 0.0];
/// let embeddings = vec![
///     vec![1.0, 0.0, 0.0], // Same as query
///     vec![0.0, 1.0, 0.0], // Orthogonal
///     vec![0.7, 0.7, 0.0], // Partially similar
/// ];
///
/// let similarities = pairwise_similarities(&query, &embeddings);
/// assert_eq!(similarities.len(), 3);
/// assert_eq!(similarities[0].0, 0); // First embedding is most similar
/// assert!((similarities[0].1 - 1.0).abs() < 1e-6);
/// ```
pub fn pairwise_similarities(query: &[f32], embeddings: &[Vec<f32>]) -> Vec<(usize, f32)> {
    let mut results: Vec<(usize, f32)> = embeddings
        .iter()
        .enumerate()
        .filter_map(|(idx, emb)| cosine_similarity(query, emb).map(|sim| (idx, sim)))
        .collect();

    // Sort by similarity descending
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    results
}

/// JSON-serializable statistics snapshot for a vector index
///
/// Produced by [`export_index_stats`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct IndexStats {
    /// Embedding dimension
    pub dimension: usize,
    /// Number of vectors in index
    pub vector_count: usize,
    /// Distance metric used
    pub metric: String,
    /// Estimated memory usage in bytes
    pub memory_bytes: usize,
    /// Health status
    pub health_status: String,
    /// Issues detected
    pub issues: Vec<String>,
    /// Recommendations
    pub recommendations: Vec<String>,
}

/// Export index statistics to a JSON-serializable structure
///
/// This function extracts comprehensive statistics from an index for
/// monitoring, debugging, or export purposes.
///
/// # Example
/// ```
/// use ipfrs_semantic::{VectorIndex, utils::export_index_stats};
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut index = VectorIndex::with_defaults(768)?;
/// let cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
/// index.insert(&cid, &vec![0.5; 768])?;
///
/// let stats = export_index_stats(&index);
/// assert_eq!(stats.dimension, 768);
/// assert_eq!(stats.vector_count, 1);
/// # Ok(())
/// # }
/// ```
pub fn export_index_stats(index: &VectorIndex) -> IndexStats {
    let health = health_check(index);
    let metric = format!("{:?}", index.metric());

    IndexStats {
        dimension: index.dimension(),
        vector_count: index.len(),
        metric,
        memory_bytes: health.memory_bytes,
        health_status: if health.is_healthy {
            "Healthy".to_string()
        } else {
            "Issues Detected".to_string()
        },
        issues: health.issues,
        recommendations: health.recommendations,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_vector() {
        let mut vec = vec![3.0, 4.0];
        normalize_vector(&mut vec);

        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
        assert!((vec[0] - 0.6).abs() < 1e-6);
        assert!((vec[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut vec = vec![0.0, 0.0];
        normalize_vector(&mut vec);
        assert_eq!(vec, vec![0.0, 0.0]);
    }

    #[test]
    fn test_normalize_vectors() {
        let mut vectors = vec![vec![3.0, 4.0], vec![1.0, 0.0]];

        normalize_vectors(&mut vectors);

        for vec in &vectors {
            let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_average_embedding() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
        ];

        let avg = average_embedding(&embeddings).unwrap();
        assert_eq!(avg, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_average_embedding_empty() {
        let embeddings: Vec<Vec<f32>> = vec![];
        assert!(average_embedding(&embeddings).is_none());
    }

    #[test]
    fn test_average_embedding_inconsistent_dims() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]];
        assert!(average_embedding(&embeddings).is_none());
    }
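
    #[test]
    fn test_average_embedding_single() {
        // Sanity check: averaging a single embedding returns that embedding unchanged.
        let embeddings = vec![vec![1.5, -2.0, 0.25]];
        let avg = average_embedding(&embeddings).unwrap();
        assert_eq!(avg, vec![1.5, -2.0, 0.25]);
    }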

    #[test]
    fn test_validate_embeddings() {
        let embeddings = vec![
            vec![0.1, 0.2, 0.3],
            vec![0.4, 0.5, 0.6],
            vec![f32::NAN, 0.1, 0.2],
        ];

        let reports = validate_embeddings(&embeddings);
        assert_eq!(reports.len(), 3);
        assert!(reports[0].is_valid);
        assert!(reports[1].is_valid);
        assert!(!reports[2].is_valid); // NaN is invalid
    }
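
    #[test]
    fn test_index_with_quality_check_rejects_invalid() {
        // An embedding containing NaN fails the validity check, so it is
        // reported as a failure and never handed to the router. Assumes the
        // default router accepts 768-dimensional embeddings, as in the doc
        // example above.
        use multihash_codetable::{Code, MultihashDigest};

        let router = SemanticRouter::with_defaults().unwrap();

        let data = "invalid_embedding";
        let hash = Code::Sha2_256.digest(data.as_bytes());
        let cid = Cid::new_v1(0x55, hash);

        let items = vec![(cid, vec![f32::NAN; 768])];
        let result = index_with_quality_check(&router, &items, 0.5).unwrap();

        assert_eq!(result.indexed, 0);
        assert_eq!(result.failed, 1);
        assert_eq!(result.failures.len(), 1);
    }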

    #[test]
    fn test_health_check() {
        let index = VectorIndex::with_defaults(128).unwrap();
        let health = health_check(&index);

        // Empty index may not be healthy depending on implementation
        // At minimum, it should report 0 vectors
        assert_eq!(health.vector_count, 0);
    }

    #[test]
    fn test_batch_delete() {
        use multihash_codetable::{Code, MultihashDigest};

        let mut index = VectorIndex::with_defaults(768).unwrap();

        // Insert some test vectors
        let mut cids = Vec::new();
        for i in 0..5 {
            let data = format!("test_vector_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);
            index.insert(&cid, &vec![i as f32 * 0.1; 768]).unwrap();
            cids.push(cid);
        }

        // Delete first 3 CIDs
        let to_delete = &cids[0..3];
        let result = batch_delete(&mut index, to_delete).unwrap();

        assert_eq!(result.deleted, 3);
        assert_eq!(result.not_found, 0);
        assert_eq!(result.failed, 0);
        assert_eq!(index.len(), 2); // 2 remaining
    }

    #[test]
    fn test_batch_delete_not_found() {
        use multihash_codetable::{Code, MultihashDigest};

        let mut index = VectorIndex::with_defaults(768).unwrap();

        // Create a CID that's not in the index
        let data = "nonexistent";
        let hash = Code::Sha2_256.digest(data.as_bytes());
        let cid = Cid::new_v1(0x55, hash);

        let result = batch_delete(&mut index, &[cid]).unwrap();

        assert_eq!(result.deleted, 0);
        assert_eq!(result.not_found, 1);
        assert_eq!(result.not_found_cids.len(), 1);
    }
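
    #[test]
    fn test_batch_delete_empty_input() {
        // Deleting an empty CID list is a no-op: nothing deleted, nothing reported.
        let mut index = VectorIndex::with_defaults(768).unwrap();
        let result = batch_delete(&mut index, &[]).unwrap();

        assert_eq!(result.deleted, 0);
        assert_eq!(result.not_found, 0);
        assert_eq!(result.failed, 0);
        assert!(result.failures.is_empty());
    }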

    #[test]
    fn test_cosine_similarity() {
        // Test identical vectors
        let vec1 = vec![1.0, 2.0, 3.0];
        let vec2 = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&vec1, &vec2).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);

        // Test orthogonal vectors
        let vec3 = vec![1.0, 0.0, 0.0];
        let vec4 = vec![0.0, 1.0, 0.0];
        let sim2 = cosine_similarity(&vec3, &vec4).unwrap();
        assert!(sim2.abs() < 1e-6); // Should be ~0

        // Test parallel vectors (same direction, different magnitude)
        let vec5 = vec![1.0, 2.0, 3.0];
        let vec6 = vec![2.0, 4.0, 6.0];
        let sim3 = cosine_similarity(&vec5, &vec6).unwrap();
        assert!((sim3 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_different_dims() {
        let vec1 = vec![1.0, 2.0];
        let vec2 = vec![1.0, 2.0, 3.0];
        assert!(cosine_similarity(&vec1, &vec2).is_none());
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let vec1 = vec![0.0, 0.0, 0.0];
        let vec2 = vec![1.0, 2.0, 3.0];
        assert!(cosine_similarity(&vec1, &vec2).is_none());
    }

    #[test]
    fn test_pairwise_similarities() {
        let query = vec![1.0, 0.0, 0.0];
        let embeddings = vec![
            vec![1.0, 0.0, 0.0], // Same as query
            vec![0.0, 1.0, 0.0], // Orthogonal
            vec![0.7, 0.7, 0.0], // Partially similar
        ];

        let similarities = pairwise_similarities(&query, &embeddings);

        assert_eq!(similarities.len(), 3);
        assert_eq!(similarities[0].0, 0); // First embedding is most similar
        assert!((similarities[0].1 - 1.0).abs() < 1e-6);
        assert!(similarities[1].1 > similarities[2].1); // Orthogonal embedding ranks last
    }
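
    #[test]
    fn test_pairwise_similarities_skips_zero_vectors() {
        // cosine_similarity returns None for zero vectors, so pairwise_similarities
        // drops them from the results instead of reporting a score.
        let query = vec![1.0, 0.0];
        let embeddings = vec![vec![0.0, 0.0], vec![1.0, 0.0]];

        let similarities = pairwise_similarities(&query, &embeddings);

        assert_eq!(similarities.len(), 1);
        assert_eq!(similarities[0].0, 1);
        assert!((similarities[0].1 - 1.0).abs() < 1e-6);
    }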

    #[test]
    fn test_export_index_stats() {
        use multihash_codetable::{Code, MultihashDigest};

        let mut index = VectorIndex::with_defaults(768).unwrap();

        // Add a vector
        let data = "test_vector";
        let hash = Code::Sha2_256.digest(data.as_bytes());
        let cid = Cid::new_v1(0x55, hash);
        index.insert(&cid, &vec![0.5; 768]).unwrap();

        let stats = export_index_stats(&index);

        assert_eq!(stats.dimension, 768);
        assert_eq!(stats.vector_count, 1);
        assert!(!stats.metric.is_empty());
    }
}
757}