Skip to main content

hirn_core/
embed.rs

1//! F-39: Embedder trait — pluggable embedding providers for semantic vector search.
2//!
3//! The `Embedder` trait abstracts over local (ONNX) and remote (`OpenAI`, Cohere, …)
4//! embedding models so that users can swap providers without changing application code.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9
10use crate::error::HirnResult;
11
12/// A single embedding result with its source model identifier.
13#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
14pub struct Embedding {
15    /// The embedding vector (f32 per dimension).
16    pub vector: Vec<f32>,
17    /// Identifier of the model that produced this embedding (e.g. `"text-embedding-3-small"`).
18    pub model_id: String,
19}
20
/// Token-level (multivector) embedding used for ColBERT-style late interaction.
#[derive(Debug, Clone, PartialEq)]
pub struct MultivectorEmbedding {
    /// Per-token (or per-sub-token) vectors, one inner `Vec` each.
    pub vectors: Vec<Vec<f32>>,
    /// Identifier of the model these vectors came from.
    pub model_id: String,
}
29
/// Outcome of reranking one document.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankResult {
    /// Position of the document within the original `documents` slice.
    pub index: usize,
    /// Relevance score the reranker assigned to that document.
    pub score: f32,
}
38
39/// An entity extracted from unstructured text.
40#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
41pub struct ExtractedEntity {
42    pub name: String,
43    pub entity_type: String,
44    pub confidence: f32,
45}
46
47/// A relation between two extracted entities.
48#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
49pub struct ExtractedRelation {
50    pub source: String,
51    pub target: String,
52    pub relation_type: String,
53    pub weight: f32,
54}
55
56// ── Embedder ─────────────────────────────────────────────────────────────
57
58/// Pluggable embedding provider (F-39).
59///
60/// Implementations live in the `hirn-provider` crate. The core crate only
61/// defines the contract so that `hirn-engine` can depend on `hirn-core`
62/// without pulling in heavy ML dependencies.
63#[async_trait]
64pub trait Embedder: Send + Sync {
65    /// Embed one or more texts, returning one [`Embedding`] per input.
66    ///
67    /// # Errors
68    /// Returns an error if the provider is unreachable or the input exceeds
69    /// the model's token limit.
70    async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>>;
71
72    /// Number of dimensions the model produces.
73    fn dimensions(&self) -> usize;
74
75    /// Stable model identifier stored alongside each memory for re-embedding detection.
76    fn model_id(&self) -> &str;
77
78    /// Maximum number of input tokens the model accepts per text.
79    fn max_input_tokens(&self) -> usize;
80
81    /// Produce token-level (multivector) embeddings for ColBERT-style late interaction.
82    ///
83    /// Returns one [`MultivectorEmbedding`] per input text, where each embedding
84    /// contains one vector per token. Default implementation returns an error,
85    /// indicating the model does not support multivector embeddings.
86    async fn embed_multivec(&self, _texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
87        Err(crate::error::HirnError::InvalidInput(
88            "this embedder does not support multivector embeddings".into(),
89        ))
90    }
91
92    /// Whether this embedder supports multivector (ColBERT-style) embeddings.
93    fn supports_multivec(&self) -> bool {
94        false
95    }
96}
97
98#[async_trait]
99impl<T: Embedder + ?Sized> Embedder for Arc<T> {
100    async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
101        self.as_ref().embed(texts).await
102    }
103
104    fn dimensions(&self) -> usize {
105        self.as_ref().dimensions()
106    }
107
108    fn model_id(&self) -> &str {
109        self.as_ref().model_id()
110    }
111
112    fn max_input_tokens(&self) -> usize {
113        self.as_ref().max_input_tokens()
114    }
115
116    async fn embed_multivec(&self, texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
117        self.as_ref().embed_multivec(texts).await
118    }
119
120    fn supports_multivec(&self) -> bool {
121        self.as_ref().supports_multivec()
122    }
123}
124
125// ── TokenCounter ─────────────────────────────────────────────────────────
126
/// Tokenizer-agnostic token counting (§11.5).
///
/// The THINK budget planner relies on this trait to measure context length
/// without being tied to a specific tokenizer backend.
pub trait TokenCounter: Send + Sync {
    /// Returns the number of tokens in `text`.
    fn count_tokens(&self, text: &str) -> usize;

    /// Returns one token count per entry of `texts`, in input order.
    ///
    /// The default implementation calls [`count_tokens`](Self::count_tokens)
    /// once for each text.
    fn count_tokens_batch(&self, texts: &[&str]) -> Vec<usize> {
        let mut counts = Vec::with_capacity(texts.len());
        for text in texts {
            counts.push(self.count_tokens(text));
        }
        counts
    }
}

/// Dependency-free fallback counter that estimates `ceil(len / 4)` tokens.
///
/// `len` is the UTF-8 byte length of the text (`str::len`), so multi-byte
/// characters count for more than one unit.
#[derive(Debug, Clone, Copy)]
pub struct CharEstimateCounter;

impl TokenCounter for CharEstimateCounter {
    fn count_tokens(&self, text: &str) -> usize {
        // Heuristic ratio used by the estimate: four bytes per token.
        const BYTES_PER_TOKEN: usize = 4;
        text.len().div_ceil(BYTES_PER_TOKEN)
    }
}
150
151// ── Reranker ─────────────────────────────────────────────────────────────
152
153/// Two-stage reranker (§12.4.2): cross-encoder precision after bi-encoder recall.
154#[async_trait]
155pub trait Reranker: Send + Sync {
156    /// Rerank `documents` by relevance to `query`, returning the top-k results
157    /// sorted by descending score.
158    async fn rerank(
159        &self,
160        query: &str,
161        documents: &[&str],
162        top_k: usize,
163    ) -> HirnResult<Vec<RerankResult>>;
164}
165
166/// Identity reranker — returns all documents in original order.
167#[derive(Debug, Clone, Copy)]
168pub struct NoopReranker;
169
170#[async_trait]
171impl Reranker for NoopReranker {
172    async fn rerank(
173        &self,
174        _query: &str,
175        documents: &[&str],
176        top_k: usize,
177    ) -> HirnResult<Vec<RerankResult>> {
178        Ok(documents
179            .iter()
180            .enumerate()
181            .take(top_k)
182            .map(|(i, _)| RerankResult {
183                index: i,
184                score: 1.0 - (i as f32 / documents.len().max(1) as f32),
185            })
186            .collect())
187    }
188}
189
190// ── EntityExtractor ──────────────────────────────────────────────────────
191
192/// Entity and relation extraction from unstructured text (F-41).
193#[async_trait]
194pub trait EntityExtractor: Send + Sync {
195    /// Extract named entities from `text`, optionally filtering by `entity_types`.
196    async fn extract_entities(
197        &self,
198        text: &str,
199        entity_types: &[&str],
200    ) -> HirnResult<Vec<ExtractedEntity>>;
201
202    /// Extract relations between previously extracted entities.
203    async fn extract_relations(
204        &self,
205        text: &str,
206        entities: &[ExtractedEntity],
207    ) -> HirnResult<Vec<ExtractedRelation>>;
208}
209
210// ── LlmProvider ──────────────────────────────────────────────────────────
211
212/// Chat message for LLM generation.
213#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
214pub struct ChatMessage {
215    pub role: String,
216    pub content: String,
217}
218
/// Output format requested from the LLM.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum ResponseFormat {
    /// Unconstrained text; this is the default variant.
    #[default]
    Text,
    /// A valid JSON object, without any schema constraint.
    JsonObject,
    /// JSON that conforms to the JSON-Schema carried as a string.
    JsonSchema(String),
}
230
/// Token usage reported by the provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (input side).
    pub prompt_tokens: u32,
    /// Tokens produced in the completion (output side).
    pub completion_tokens: u32,
}

impl TokenUsage {
    /// Total tokens consumed: prompt plus completion.
    ///
    /// The sum saturates at `u32::MAX` instead of overflowing, so counters
    /// reported by a provider can never trigger a debug-build panic or wrap
    /// around to a small number in release builds.
    pub const fn total(&self) -> u32 {
        self.prompt_tokens.saturating_add(self.completion_tokens)
    }
}
244
245/// Full response from the LLM provider.
246#[derive(Debug, Clone, PartialEq, Eq)]
247pub struct LlmResponse {
248    pub content: String,
249    pub usage: Option<TokenUsage>,
250}
251
252/// A single chunk from a streaming LLM response.
253#[derive(Debug, Clone, PartialEq, Eq)]
254pub struct LlmChunk {
255    /// Incremental text delta.
256    pub delta: String,
257    /// Optional cumulative usage snapshot reported during streaming.
258    pub usage: Option<TokenUsage>,
259}
260
261/// Options for LLM generation requests.
262#[derive(Debug, Clone, PartialEq)]
263pub struct LlmOptions {
264    pub model_override: Option<String>,
265    pub temperature: f32,
266    pub max_tokens: u32,
267    pub response_format: ResponseFormat,
268}
269
270impl Default for LlmOptions {
271    fn default() -> Self {
272        Self {
273            model_override: None,
274            temperature: 0.0,
275            max_tokens: 1024,
276            response_format: ResponseFormat::Text,
277        }
278    }
279}
280
281/// Stream of LLM chunks.
282pub type LlmStream = std::pin::Pin<Box<dyn futures::Stream<Item = HirnResult<LlmChunk>> + Send>>;
283
284/// Pluggable LLM provider for structured extraction and generation (F-41, §12).
285#[async_trait]
286pub trait LlmProvider: Send + Sync {
287    /// Generate a text response.
288    async fn generate_text(
289        &self,
290        messages: &[ChatMessage],
291        options: &LlmOptions,
292    ) -> HirnResult<String>;
293
294    /// Generate a full response including usage metadata.
295    ///
296    /// The default implementation delegates to [`generate_text`](Self::generate_text).
297    async fn generate(
298        &self,
299        messages: &[ChatMessage],
300        options: &LlmOptions,
301    ) -> HirnResult<LlmResponse> {
302        let content = self.generate_text(messages, options).await?;
303        Ok(LlmResponse {
304            content,
305            usage: None,
306        })
307    }
308
309    /// Stream a response chunk-by-chunk.
310    ///
311    /// The default implementation collects the full response into a single chunk.
312    async fn generate_stream(
313        &self,
314        messages: &[ChatMessage],
315        options: &LlmOptions,
316    ) -> HirnResult<LlmStream> {
317        let text = self.generate_text(messages, options).await?;
318        let chunk = LlmChunk {
319            delta: text,
320            usage: None,
321        };
322        Ok(Box::pin(futures::stream::once(async { Ok(chunk) })))
323    }
324
325    /// Stable model identifier.
326    fn model_id(&self) -> &str;
327}
328
329// ── AsymmetricEmbedder ───────────────────────────────────────────────────
330
331/// Asymmetric embedding provider that separates source (ingest) and query (search)
332/// embedding spaces.
333///
334/// Some models (e.g. E5, GTR, asymmetric Cohere) produce different embeddings
335/// for documents vs. queries. This trait captures that distinction. The default
336/// implementation of [`embed_query`](Self::embed_query) delegates to
337/// [`embed_source`](Self::embed_source) for symmetric models.
338#[async_trait]
339pub trait AsymmetricEmbedder: Send + Sync {
340    /// Embed texts for storage (source / document embedding).
341    async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>>;
342
343    /// Embed texts for search (query embedding).
344    ///
345    /// Default: delegates to [`embed_source`](Self::embed_source).
346    async fn embed_query(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
347        self.embed_source(texts).await
348    }
349
350    /// Model name used as registry key and cache key.
351    fn name(&self) -> &str;
352
353    /// Output embedding dimensionality.
354    fn dims(&self) -> usize;
355}
356
357/// Adapter that wraps any [`Embedder`] as an [`AsymmetricEmbedder`].
358///
359/// Both `embed_source` and `embed_query` delegate to the underlying
360/// [`Embedder::embed`], making this a symmetric adapter.
361pub struct EmbedderAdapter<E: Embedder> {
362    inner: E,
363}
364
365impl<E: Embedder> EmbedderAdapter<E> {
366    /// Wrap an existing embedder.
367    pub fn new(inner: E) -> Self {
368        Self { inner }
369    }
370}
371
372#[async_trait]
373impl<E: Embedder> AsymmetricEmbedder for EmbedderAdapter<E> {
374    async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
375        self.inner.embed(texts).await
376    }
377
378    fn name(&self) -> &str {
379        self.inner.model_id()
380    }
381
382    fn dims(&self) -> usize {
383        self.inner.dimensions()
384    }
385}
386
387// ── Matryoshka helper ────────────────────────────────────────────────────
388
/// Truncate a Matryoshka-trained embedding to `target_dims` and re-normalize
/// the result to unit L2 length.
///
/// Returns `None` if `embedding` is shorter than `target_dims`.
///
/// The norm is accumulated in `f64`. Squaring in `f32` overflows to infinity
/// for large components (which would collapse the output to zeros) and
/// underflows to zero for tiny ones (which would silently skip
/// normalization); the wider accumulator avoids both.
///
/// If the truncated prefix has no usable norm — it is all zeros, or contains
/// NaN or infinite components — the prefix is returned unchanged.
#[must_use]
pub fn truncate_matryoshka(embedding: &[f32], target_dims: usize) -> Option<Vec<f32>> {
    // Checked slicing: yields `None` when `target_dims > embedding.len()`.
    let truncated = embedding.get(..target_dims)?;

    let norm = truncated
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();

    if norm > 0.0 && norm.is_finite() {
        // Each quotient lies in [-1, 1], so narrowing back to `f32` is safe.
        Some(
            truncated
                .iter()
                .map(|&x| (f64::from(x) / norm) as f32)
                .collect(),
        )
    } else {
        Some(truncated.to_vec())
    }
}
404
#[cfg(test)]
mod tests {
    //! Unit tests for the token counter, the no-op reranker, the Matryoshka
    //! helper and the `Embedder` → `AsymmetricEmbedder` adapter.

    use super::*;

    // ── TokenCounter tests ───────────────────────────────────────────────

    #[test]
    fn char_estimate_counter() {
        let c = CharEstimateCounter;
        assert_eq!(c.count_tokens(""), 0);
        assert_eq!(c.count_tokens("hi"), 1);
        assert_eq!(c.count_tokens("hello world"), 3); // 11/4 = 2.75 → ceil = 3
    }

    #[test]
    fn char_estimate_batch() {
        let c = CharEstimateCounter;
        let counts = c.count_tokens_batch(&["a", "abcdefgh"]);
        assert_eq!(counts, vec![1, 2]);
    }

    // ── Reranker tests ───────────────────────────────────────────────────

    // Uses `#[tokio::test]` like every other async test in this module
    // instead of building a runtime by hand.
    #[tokio::test]
    async fn noop_reranker_returns_descending() {
        let r = NoopReranker;
        let docs = ["alpha", "beta", "gamma"];
        let results = r.rerank("q", &docs, 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[1].index, 1);
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn noop_reranker_handles_empty_and_oversized_top_k() {
        let r = NoopReranker;
        let none: [&str; 0] = [];
        assert!(r.rerank("q", &none, 3).await.unwrap().is_empty());

        // `top_k` larger than the input returns every document once.
        let docs = ["alpha", "beta"];
        let results = r.rerank("q", &docs, 10).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    // ── Matryoshka tests ─────────────────────────────────────────────────

    #[test]
    fn matryoshka_truncate() {
        let emb = vec![3.0, 4.0, 0.0, 0.0];
        let t = truncate_matryoshka(&emb, 2).unwrap();
        assert_eq!(t.len(), 2);
        let norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn matryoshka_too_short() {
        assert!(truncate_matryoshka(&[1.0, 2.0], 5).is_none());
    }

    #[test]
    fn matryoshka_zero_prefix_is_unchanged() {
        // A zero-norm prefix cannot be normalized and is returned as-is.
        assert_eq!(
            truncate_matryoshka(&[0.0, 0.0, 1.0], 2),
            Some(vec![0.0, 0.0])
        );
    }

    // ── AsymmetricEmbedder tests ─────────────────────────────────────────

    /// Stub embedder for testing `EmbedderAdapter`.
    ///
    /// Every component of the produced vector equals the byte length of the
    /// input text, which makes delegation easy to verify.
    struct StubEmbedder {
        dim: usize,
        id: &'static str,
    }

    #[async_trait]
    impl Embedder for StubEmbedder {
        async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|t| Embedding {
                    vector: vec![t.len() as f32; self.dim],
                    model_id: self.id.to_string(),
                })
                .collect())
        }
        fn dimensions(&self) -> usize {
            self.dim
        }
        fn model_id(&self) -> &str {
            self.id
        }
        fn max_input_tokens(&self) -> usize {
            8192
        }
    }

    #[tokio::test]
    async fn embedder_adapter_delegates_embed_source() {
        let adapter = EmbedderAdapter::new(StubEmbedder {
            dim: 4,
            id: "stub-v1",
        });
        let result = adapter.embed_source(&["hello", "world"]).await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].vector.len(), 4);
        // "hello" has 5 chars → each dimension should be 5.0
        assert_eq!(result[0].vector, vec![5.0; 4]);
        // "world" also has 5 chars, so it yields the same vector.
        assert_eq!(result[1].vector, vec![5.0; 4]);
    }

    #[tokio::test]
    async fn embedder_adapter_name_and_dims() {
        let adapter = EmbedderAdapter::new(StubEmbedder {
            dim: 128,
            id: "my-model",
        });
        assert_eq!(adapter.name(), "my-model");
        assert_eq!(adapter.dims(), 128);
    }

    #[tokio::test]
    async fn default_embed_query_delegates_to_embed_source() {
        // The default `embed_query` should delegate to `embed_source`.
        let adapter = EmbedderAdapter::new(StubEmbedder { dim: 3, id: "sym" });
        let source = adapter.embed_source(&["test"]).await.unwrap();
        let query = adapter.embed_query(&["test"]).await.unwrap();
        assert_eq!(source, query);
    }

    /// Asymmetric embedder that returns different vectors for source vs query.
    struct AsymStub;

    #[async_trait]
    impl AsymmetricEmbedder for AsymStub {
        async fn embed_source(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|_| Embedding {
                    vector: vec![1.0, 0.0, 0.0],
                    model_id: "asym".to_string(),
                })
                .collect())
        }
        async fn embed_query(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|_| Embedding {
                    vector: vec![0.0, 1.0, 0.0],
                    model_id: "asym".to_string(),
                })
                .collect())
        }
        fn name(&self) -> &str {
            "asym"
        }
        fn dims(&self) -> usize {
            3
        }
    }

    #[tokio::test]
    async fn asymmetric_embedder_returns_different_vectors() {
        let e = AsymStub;
        let source = e.embed_source(&["hello"]).await.unwrap();
        let query = e.embed_query(&["hello"]).await.unwrap();
        assert_ne!(source[0].vector, query[0].vector);
        assert_eq!(source[0].vector, vec![1.0, 0.0, 0.0]);
        assert_eq!(query[0].vector, vec![0.0, 1.0, 0.0]);
    }
}
551        assert_ne!(source[0].vector, query[0].vector);
552        assert_eq!(source[0].vector, vec![1.0, 0.0, 0.0]);
553        assert_eq!(query[0].vector, vec![0.0, 1.0, 0.0]);
554    }
555}