Skip to main content

coding_agent_search/search/
embedder.rs

1//! Embedder trait and types for semantic search.
2//!
//! This module re-exports frankensearch's [`SyncEmbed`](frankensearch::SyncEmbed)
//! trait under the name [`Embedder`], together with its error and result types
//! (`EmbedderError`, `EmbedderResult`). All embedding implementations must
//! satisfy `Embedder`, which provides a synchronous embedding interface
//! suitable for cass's sync call sites.
7//!
8//! The [`SyncEmbedderAdapter`](frankensearch::SyncEmbedderAdapter) can wrap any
9//! `Embedder` implementor into frankensearch's async `Embedder` trait when needed
10//! for the frankensearch search pipeline.
11//!
12//! # Implementations
13//!
//! - **Hash embedder**: FNV-1a feature hashing (always available; dimension chosen at construction, e.g. 256)
15//! - **ML embedder**: FastEmbed with the MiniLM model (requires model download, 384 dimensions)
16
17use std::fmt;
18
19pub use frankensearch::SearchError as EmbedderError;
20pub use frankensearch::SearchResult as EmbedderResult;
21pub use frankensearch::SyncEmbed as Embedder;
22
/// Metadata about an embedder for display and logging.
///
/// All fields are owned, so an `EmbedderInfo` is a self-contained snapshot
/// that does not borrow from the embedder it describes. Two infos are equal
/// when their id, dimension, and semantic flag all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EmbedderInfo {
    /// The embedder's unique identifier.
    pub id: String,
    /// The output dimension (length of each produced embedding vector).
    pub dimension: usize,
    /// Whether it's a semantic (ML) embedder rather than a lexical one.
    pub is_semantic: bool,
}
33
34impl EmbedderInfo {
35    /// Create info from an embedder instance.
36    pub fn from_embedder(embedder: &dyn Embedder) -> Self {
37        Self {
38            id: embedder.id().to_string(),
39            dimension: embedder.dimension(),
40            is_semantic: embedder.is_semantic(),
41        }
42    }
43}
44
45impl fmt::Display for EmbedderInfo {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        let kind = if self.is_semantic {
48            "semantic"
49        } else {
50            "lexical"
51        };
52        write!(f, "{} ({}, {} dims)", self.id, kind, self.dimension)
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::fastembed_embedder::FastEmbedder;
    use crate::search::hash_embedder::HashEmbedder;
    use std::path::PathBuf;

    /// Path to the int8 MiniLM model fixture under `tests/fixtures/models`.
    fn fastembed_fixture_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/models/xenova-paraphrase-minilm-l3-v2-int8")
    }

    /// Loads the FastEmbed fixture model, panicking if it cannot be loaded.
    fn load_fastembed_fixture() -> FastEmbedder {
        FastEmbedder::load_from_dir(&fastembed_fixture_dir())
            .expect("fastembed fixture should load")
    }

    #[test]
    fn test_embedder_trait_basic() {
        let embedder = HashEmbedder::new(256);
        let embedding = embedder.embed_sync("hello world").unwrap();
        assert_eq!(embedding.len(), 256);
        // The advertised dimension must match what the embedder actually produces.
        assert_eq!(embedder.dimension(), embedding.len());
        assert_eq!(embedder.id(), "fnv1a-256");
        assert!(!embedder.is_semantic());
    }

    #[test]
    fn test_embedder_trait_semantic() {
        let embedder = load_fastembed_fixture();
        assert_eq!(embedder.dimension(), 384);
        assert_eq!(embedder.id(), FastEmbedder::embedder_id_static());
        assert!(embedder.is_semantic());
    }

    #[test]
    fn test_embedder_batch() {
        let embedder = load_fastembed_fixture();
        let texts = &["hello", "world", "test"];
        let embeddings = embedder.embed_batch_sync(texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }

    #[test]
    fn test_embedder_empty_input_error() {
        let embedder = load_fastembed_fixture();
        let result = embedder.embed_sync("");
        assert!(result.is_err());
    }

    #[test]
    fn test_embedder_info() {
        let embedder = load_fastembed_fixture();
        let info = EmbedderInfo::from_embedder(&embedder);
        assert_eq!(info.id, FastEmbedder::embedder_id_static());
        assert_eq!(info.dimension, 384);
        assert!(info.is_semantic);

        let display = format!("{info}");
        // Check "384 dims" rather than bare "384" so a digit run inside the id
        // cannot satisfy the assertion by accident.
        for expected in [FastEmbedder::embedder_id_static(), "semantic", "384 dims"] {
            assert!(
                display.contains(expected),
                "display {display:?} should contain {expected:?}"
            );
        }
    }

    /// Covers the non-semantic path of `from_embedder` and the `lexical`
    /// branch of `EmbedderInfo`'s `Display` impl, which the ML-based test
    /// above cannot reach.
    #[test]
    fn test_embedder_info_lexical() {
        let embedder = HashEmbedder::new(256);
        let info = EmbedderInfo::from_embedder(&embedder);
        assert_eq!(info.id, "fnv1a-256");
        assert_eq!(info.dimension, 256);
        assert!(!info.is_semantic);

        let display = format!("{info}");
        for expected in ["fnv1a-256", "lexical", "256 dims"] {
            assert!(
                display.contains(expected),
                "display {display:?} should contain {expected:?}"
            );
        }
        assert!(
            !display.contains("semantic"),
            "display {display:?} should not claim to be semantic"
        );
    }

    #[test]
    fn test_embedder_error_display() {
        let err = EmbedderError::EmbedderUnavailable {
            model: "test".to_string(),
            reason: "model not downloaded".to_string(),
        };
        assert!(err.to_string().contains("model not downloaded"));

        let err = EmbedderError::EmbeddingFailed {
            model: "test".to_string(),
            source: Box::new(std::io::Error::other("inference error")),
        };
        assert!(err.to_string().contains("inference error"));
    }
}