Skip to main content

memoir_core/embedding/
mod.rs

1//! Text-to-vector embedding primitive.
2//!
3//! Defines [`EmbeddingModel`], implemented by [`OnnxEmbedding`] and by callers
4//! who want to plug in their own embedder.
5
6mod error;
7pub mod onnx;
8
9pub use error::EmbeddingError;
10pub use onnx::OnnxEmbedding;
11
12/// Produces a fixed-dimension float vector from a text input.
13///
14/// Implementations must be deterministic: embedding the same input twice
15/// returns identical vectors. The returned vector's length must equal
16/// [`Self::dimensions`].
17pub trait EmbeddingModel: Send + Sync + 'static {
18    /// Embeds `text` into a [`Self::dimensions`]-length vector.
19    ///
20    /// # Errors
21    ///
22    /// Returns [`EmbeddingError::Embed`] when inference fails. Init-time
23    /// failures surface from the implementation's constructor.
24    fn embed(&self, text: &str) -> impl std::future::Future<Output = Result<Vec<f32>, EmbeddingError>> + Send;
25
26    /// Returns the dimension of vectors produced by [`Self::embed`].
27    fn dimensions(&self) -> usize;
28}
29
30impl<T: EmbeddingModel> EmbeddingModel for std::sync::Arc<T> {
31    fn embed(&self, text: &str) -> impl std::future::Future<Output = Result<Vec<f32>, EmbeddingError>> + Send {
32        (**self).embed(text)
33    }
34
35    fn dimensions(&self) -> usize {
36        (**self).dimensions()
37    }
38}
39
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Minimal in-test model that always returns a vector of `dim` copies of
    /// `0.1`, regardless of the input text.
    struct StubEmbedding {
        dim: usize,
    }

    impl EmbeddingModel for StubEmbedding {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
            Ok(vec![0.1; self.dim])
        }

        fn dimensions(&self) -> usize {
            self.dim
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_implement_trait_with_in_test_stub() {
        let model = StubEmbedding { dim: 4 };

        let vector = model.embed("hello").await.unwrap();

        assert_eq!(vector.len(), 4);
        assert_eq!(model.dimensions(), 4);
    }

    /// Covers the `Arc<T>` forwarding impl. The fully qualified calls make
    /// sure the `Arc` impl itself is exercised, not the inner type's impl
    /// reached by auto-deref.
    #[tokio::test(flavor = "current_thread")]
    async fn should_forward_embed_and_dimensions_through_arc() {
        let shared = Arc::new(StubEmbedding { dim: 3 });

        let via_arc = <Arc<StubEmbedding> as EmbeddingModel>::embed(&shared, "hello")
            .await
            .unwrap();
        let direct = shared.as_ref().embed("hello").await.unwrap();

        assert_eq!(via_arc, direct);
        assert_eq!(via_arc.len(), 3);
        assert_eq!(
            <Arc<StubEmbedding> as EmbeddingModel>::dimensions(&shared),
            3
        );
    }
}