Skip to main content

sphereql_embed/
text_embedder.rs

1//! Pluggable text-to-embedding hook.
2//!
3//! Consumers that want to query the pipeline from natural-language text
4//! (e.g. the GraphQL crate's `drillDown(queryText: ...)`, a REPL, a Python
5//! shell helper) need a way to turn a `&str` into an [`Embedding`] without
6//! this crate taking a dependency on any specific embedder. `TextEmbedder`
7//! is that hook.
8//!
9//! The trait is deliberately minimal: one fallible method, `Send + Sync`,
10//! no lifetime parameters. Users implement it on their own newtype — a
11//! thin wrapper around `sentence-transformers`, an OpenAI client, a local
12//! ONNX model, or a deterministic hash — and hand an `Arc<dyn TextEmbedder>`
13//! to whatever schema / pipeline / REPL consumes it.
14//!
15//! Convenience types:
16//!
17//! - [`NoEmbedder`] — the default; returns an error on `embed()`. Wired
18//!   into GraphQL schemas that haven't configured a real embedder so text
19//!   query resolvers fail with a clear message rather than panicking.
20//! - [`FnEmbedder`] — a zero-cost wrapper that lifts a closure into the
21//!   trait, for quick wiring in tests and examples.
22//!
23//! # Example
24//!
25//! ```ignore
26//! use std::sync::Arc;
27//! use sphereql_embed::text_embedder::{TextEmbedder, FnEmbedder, EmbedderError};
28//! use sphereql_embed::types::Embedding;
29//!
30//! let embedder: Arc<dyn TextEmbedder> = Arc::new(FnEmbedder::new(|text: &str| {
31//!     let len = text.len() as f64;
32//!     Ok(Embedding::new(vec![len, len.sqrt(), len.ln().max(0.0)]))
33//! }));
34//!
35//! let vec = embedder.embed("hello").unwrap();
36//! assert_eq!(vec.dimension(), 3);
37//! ```
38
39use thiserror::Error;
40
41use crate::types::Embedding;
42
43/// Errors surfaced from a [`TextEmbedder`] implementation.
44///
45/// Kept intentionally small: the single stringly-typed variant lets any
46/// backend (HTTP, ONNX runtime, pure-Rust model) funnel its native error
47/// into the trait without forcing this crate to enumerate every failure
48/// mode. Consumers that need typed errors can match on `.to_string()` or
49/// wrap this in their own error enum.
50#[derive(Debug, Error)]
51pub enum EmbedderError {
52    /// Embedder implementation failed while embedding the input.
53    #[error("embedder failed: {0}")]
54    Embedding(String),
55
56    /// Input was rejected before reaching the model (empty string,
57    /// too long, invalid UTF-8 boundary, etc.).
58    #[error("invalid input: {0}")]
59    InvalidInput(String),
60}
61
62impl EmbedderError {
63    /// Construct an [`EmbedderError::Embedding`] from any displayable error.
64    pub fn embedding<E: std::fmt::Display>(err: E) -> Self {
65        EmbedderError::Embedding(err.to_string())
66    }
67
68    /// Construct an [`EmbedderError::InvalidInput`] from any displayable error.
69    pub fn invalid_input<E: std::fmt::Display>(err: E) -> Self {
70        EmbedderError::InvalidInput(err.to_string())
71    }
72}
73
/// Turns free-form text into an [`Embedding`] suitable for projection
/// through the sphereQL pipeline.
///
/// Implement on a newtype that owns the embedder's state (HTTP client,
/// tokenizer, model weights). The trait is `Send + Sync` so a single
/// embedder can be shared across async request handlers via
/// `Arc<dyn TextEmbedder>`.
///
/// The only method takes `&self`, so an implementation that needs to
/// mutate state while embedding (a cache, a connection pool) has to use
/// interior mutability.
pub trait TextEmbedder: Send + Sync {
    /// Embed a text query.
    ///
    /// Returns an [`Embedding`] whose dimensionality must match whatever
    /// pipeline the result will be fed into. Implementations should
    /// surface upstream failures as [`EmbedderError::Embedding`] rather
    /// than panicking.
    ///
    /// # Errors
    ///
    /// Returns an [`EmbedderError`] when `text` cannot be embedded:
    /// [`EmbedderError::Embedding`] for failures inside the embedder, and
    /// [`EmbedderError::InvalidInput`] for input rejected before it
    /// reaches the model.
    fn embed(&self, text: &str) -> Result<Embedding, EmbedderError>;
}
90
91impl<T: TextEmbedder + ?Sized> TextEmbedder for std::sync::Arc<T> {
92    fn embed(&self, text: &str) -> Result<Embedding, EmbedderError> {
93        (**self).embed(text)
94    }
95}
96
97impl<T: TextEmbedder + ?Sized> TextEmbedder for Box<T> {
98    fn embed(&self, text: &str) -> Result<Embedding, EmbedderError> {
99        (**self).embed(text)
100    }
101}
102
103/// Default embedder that always fails with a descriptive error.
104///
105/// Wired into schemas that haven't been given a real embedder, so a text
106/// query hits an actionable "no TextEmbedder configured" error rather
107/// than panicking or silently returning empty results.
108#[derive(Debug, Default, Clone, Copy)]
109pub struct NoEmbedder;
110
111impl TextEmbedder for NoEmbedder {
112    fn embed(&self, _text: &str) -> Result<Embedding, EmbedderError> {
113        Err(EmbedderError::Embedding(
114            "no TextEmbedder configured — supply one to build_category_schema \
115             (or equivalent) before issuing text queries"
116                .into(),
117        ))
118    }
119}
120
/// Zero-cost wrapper that lifts a closure into [`TextEmbedder`].
///
/// Useful for tests and examples where the full trait-impl-on-newtype
/// ceremony is overkill:
///
/// ```ignore
/// use sphereql_embed::text_embedder::FnEmbedder;
/// use sphereql_embed::types::Embedding;
///
/// let embedder = FnEmbedder::new(|text: &str| {
///     Ok(Embedding::new(vec![text.len() as f64; 128]))
/// });
/// ```
///
/// The wrapper implements [`std::fmt::Debug`] for every `F` so it can sit
/// inside `#[derive(Debug)]` types; the wrapped closure itself is not
/// printed.
pub struct FnEmbedder<F> {
    inner: F,
}

// Written by hand rather than derived: closures do not implement `Debug`,
// so a derive (which would require `F: Debug`) would be unusable for the
// type's main purpose.
impl<F> std::fmt::Debug for FnEmbedder<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The closure is opaque; show only the type name with `..`.
        f.debug_struct("FnEmbedder").finish_non_exhaustive()
    }
}
137
138impl<F> FnEmbedder<F>
139where
140    F: Fn(&str) -> Result<Embedding, EmbedderError> + Send + Sync,
141{
142    /// Wrap a closure. The closure must be `Send + Sync` so the resulting
143    /// embedder can be shared across threads.
144    pub fn new(f: F) -> Self {
145        Self { inner: f }
146    }
147}
148
149impl<F> TextEmbedder for FnEmbedder<F>
150where
151    F: Fn(&str) -> Result<Embedding, EmbedderError> + Send + Sync,
152{
153    fn embed(&self, text: &str) -> Result<Embedding, EmbedderError> {
154        (self.inner)(text)
155    }
156}
157
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// The placeholder embedder must fail with a message that names the
    /// missing configuration.
    #[test]
    fn no_embedder_errors_descriptively() {
        let msg = NoEmbedder.embed("hello").unwrap_err().to_string();
        assert!(msg.contains("no TextEmbedder configured"), "got: {msg}");
    }

    /// A closure wrapped in `FnEmbedder` receives the text, and its result
    /// is handed back unchanged.
    #[test]
    fn fn_embedder_round_trips() {
        let embedder = FnEmbedder::new(|text: &str| {
            let len = text.len() as f64;
            Ok(Embedding::new(vec![len, 0.0, 0.0]))
        });
        let embedding = embedder.embed("hello world").unwrap();
        assert_eq!(embedding.dimension(), 3);
        assert_eq!(embedding.values[0], 11.0);
    }

    /// Calling `embed` through an `Arc<dyn TextEmbedder>` reaches the
    /// wrapped embedder.
    #[test]
    fn arc_forwards() {
        let shared: Arc<dyn TextEmbedder> =
            Arc::new(FnEmbedder::new(|_| Ok(Embedding::new(vec![1.0]))));
        let embedding = shared.embed("x").unwrap();
        assert_eq!(embedding.values, vec![1.0]);
    }

    /// The helper constructors produce the variant-specific message prefix.
    #[test]
    fn error_constructors_format() {
        let upstream = EmbedderError::embedding("upstream blew up");
        let rejected = EmbedderError::invalid_input("empty");
        assert_eq!(upstream.to_string(), "embedder failed: upstream blew up");
        assert_eq!(rejected.to_string(), "invalid input: empty");
    }
}