harness_core/embed.rs
1//! Optional embeddings trait. **Strictly opt-in** — nothing in `Model`,
2//! `AgentLoop`, `Hook`, `Guide`, `Sensor`, or `Memory` references this. Code
3//! that wants semantic search / vector recall holds an `Arc<dyn Embedder>`
4//! explicitly; everything else compiles without ever touching this module.
5//!
6//! Implementations live in `harness-models` (e.g. `GeminiEmbed`,
7//! `OpenAiEmbed`). Local/embedded backends (BGE via fastembed-rs etc.) can
8//! be added later without changing this trait.
9//!
10//! Output convention: each input string maps 1:1 to a `Vec<f32>` of length
11//! `dim()`. Vectors are returned **unnormalised**; callers that want cosine
12//! similarity should L2-normalise both sides themselves (one pass over the
13//! vector), or use a helper.
14
15use async_trait::async_trait;
16use std::fmt;
17
/// Error returned by a failed `Embedder::embed` call.
///
/// Deliberately distinct from `ModelError`: the embedding surface has no
/// thinking, no tools and no streaming, and adapters should not reach across
/// modules. Every variant carries a free-form, human-readable detail string.
#[derive(Debug)]
#[non_exhaustive]
pub enum EmbedError {
    /// The request never completed: network, DNS, TLS or timeout failures
    /// (anything reqwest can throw).
    Transport(String),
    /// The provider answered with a non-2xx status or a malformed body.
    Provider(String),
    /// The caller's input cannot be embedded, e.g. a batch larger than the
    /// provider accepts. Reported as an error rather than truncated silently.
    ///
    /// NOTE(review): the `Embedder` contract documents an empty `inputs`
    /// slice as yielding `Ok(Vec::new())`, so an empty batch should not
    /// surface as this variant — confirm adapters agree.
    BadInput(String),
}

/// Renders as `embed <category>: <detail>`, where the category is
/// `transport`, `provider` or `bad input`.
impl fmt::Display for EmbedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the category label and borrow the detail, then format once.
        let (category, detail) = match self {
            Self::Transport(detail) => ("transport", detail),
            Self::Provider(detail) => ("provider", detail),
            Self::BadInput(detail) => ("bad input", detail),
        };
        write!(f, "embed {category}: {detail}")
    }
}

impl std::error::Error for EmbedError {}
44
45/// Producer of fixed-dimension float vectors for input text. Batched.
46///
47/// Adapters MUST:
48/// - Return exactly `inputs.len()` vectors, in the same order.
49/// - Each vector MUST be exactly `dim()` long.
50/// - Treat empty `inputs` as `Ok(Vec::new())` (no provider call).
51#[async_trait]
52pub trait Embedder: Send + Sync + 'static {
53 /// Embed a batch of strings. Empty input → empty output, no provider call.
54 async fn embed(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError>;
55
56 /// Output dimensionality. Constant per adapter instance.
57 fn dim(&self) -> usize;
58
59 /// Human-readable identifier, e.g. `"gemini:text-embedding-004"`. Used
60 /// in logs and to tag stored vectors so the schema can detect a dim
61 /// change after a model swap.
62 fn handle(&self) -> &str;
63}
64
65/// Convenience: one-shot single-string embed. Default impl wraps `embed`.
66#[async_trait]
67pub trait EmbedderExt: Embedder {
68 async fn embed_one(&self, input: &str) -> Result<Vec<f32>, EmbedError> {
69 let mut out = self.embed(&[input]).await?;
70 out.pop()
71 .ok_or_else(|| EmbedError::Provider("empty result for single input".into()))
72 }
73}
74
75impl<T: Embedder + ?Sized> EmbedderExt for T {}
76
/// Scales `v` in place to unit Euclidean (L2) length.
///
/// The squared norm is accumulated in `f64`, so vectors whose components are
/// very large or very small by `f32` standards are still normalised instead
/// of overflowing to infinity (which would zero the vector) or underflowing
/// to zero (which would skip normalisation).
///
/// `v` is left untouched when no unit vector exists for it:
/// - the norm is zero (all-zero vector or empty slice), or
/// - the norm is not finite (some component is NaN or infinite).
///
/// Callers that want cosine similarity should normalise both query and
/// corpus once, then use a plain dot product.
pub fn l2_normalize(v: &mut [f32]) {
    // Each f32 square is exact in f64 and the running sum cannot overflow
    // for any realistic slice length.
    let sum_sq: f64 = v
        .iter()
        .map(|&x| {
            let wide = f64::from(x);
            wide * wide
        })
        .sum();

    // Zero norm: nothing to scale. NaN/inf norm: scaling would only spread
    // garbage across the vector, so leave the input as the caller gave it.
    if sum_sq <= 0.0 || !sum_sq.is_finite() {
        return;
    }

    let inv_norm = 1.0 / sum_sq.sqrt();
    for x in v.iter_mut() {
        // Every result has magnitude <= 1, so narrowing back to f32 only rounds.
        *x = (f64::from(*x) * inv_norm) as f32;
    }
}
93
/// Dot product of `a` and `b`, accumulated left to right in `f32`.
///
/// If the slices differ in length, only the leading elements they have in
/// common are used; the excess of the longer slice is ignored.
///
/// For two L2-normalised vectors the result is their cosine similarity and
/// lies within ±1; for raw vectors it can fall outside that range.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    // `zip` stops at the shorter slice; folding from 0.0 keeps the original
    // summation order and the +0.0 result for empty input.
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length, computed without going through `dot`.
    fn euclidean_len(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// A 3-4-5 triangle vector must come out with length 1.
    #[test]
    fn normalize_unit_length() {
        let mut sides = [3.0f32, 4.0];
        l2_normalize(&mut sides);
        assert!((euclidean_len(&sides) - 1.0).abs() < 1e-6);
    }

    /// The zero vector has no direction and must be left as-is.
    #[test]
    fn normalize_zero_noop() {
        let mut zeros = [0.0f32; 2];
        l2_normalize(&mut zeros);
        assert_eq!(zeros, [0.0f32; 2]);
    }

    /// `dot` agrees with the hand-computed sum of products.
    #[test]
    fn dot_matches_naive() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6
        let expected: f32 = 32.0;
        assert!((dot(&lhs, &rhs) - expected).abs() < 1e-6);
    }

    /// Normalising both sides turns `dot` into cosine similarity.
    #[test]
    fn cosine_via_normalized_dot() {
        let mut p = [3.0f32, 4.0];
        let mut q = [4.0f32, 3.0];
        l2_normalize(&mut p);
        l2_normalize(&mut q);
        // (3*4 + 4*3) / (5 * 5) = 24/25 = 0.96
        let cosine = dot(&p, &q);
        assert!((cosine - 0.96).abs() < 1e-4);
    }
}
141}