Skip to main content

rig_core/test_utils/
embeddings.rs

1//! Embedding helpers for deterministic tests.
2
3use crate::{
4    Embed,
5    client::Nothing,
6    embeddings::{
7        Embedding, EmbeddingError, EmbeddingModel,
8        embed::{EmbedError, TextEmbedder},
9    },
10    wasm_compat::WasmCompatSend,
11};
12
/// Stand-in [`EmbeddingModel`] for tests: every input document is mapped to the same
/// hard-coded vector, independent of the document's text.
#[derive(Debug, Default, Clone)]
pub struct MockEmbeddingModel;
16
17impl EmbeddingModel for MockEmbeddingModel {
18    const MAX_DOCUMENTS: usize = 5;
19
20    type Client = Nothing;
21
22    fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
23        Self
24    }
25
26    fn ndims(&self) -> usize {
27        10
28    }
29
30    async fn embed_texts(
31        &self,
32        documents: impl IntoIterator<Item = String> + WasmCompatSend,
33    ) -> Result<Vec<Embedding>, EmbeddingError> {
34        Ok(documents
35            .into_iter()
36            .map(|document| Embedding {
37                document,
38                vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
39            })
40            .collect())
41    }
42}
43
/// Test fixture whose `Embed` implementation yields exactly one text fragment.
#[derive(Debug, Clone)]
pub struct MockTextDocument {
    /// Identifier that tests use to recognise this document; it is not itself embedded.
    pub id: String,
    /// The single fragment handed to the embedder.
    pub text: String,
}
52
53impl MockTextDocument {
54    /// Create a single-text embedding fixture.
55    pub fn new(id: impl Into<String>, text: impl Into<String>) -> Self {
56        Self {
57            id: id.into(),
58            text: text.into(),
59        }
60    }
61}
62
63impl Embed for MockTextDocument {
64    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
65        embedder.embed(self.text.clone());
66        Ok(())
67    }
68}
69
/// Test fixture whose `Embed` implementation yields one fragment per stored text.
#[derive(Debug, Clone)]
pub struct MockMultiTextDocument {
    /// Identifier that tests use to recognise this document; it is not itself embedded.
    pub id: String,
    /// Fragments handed to the embedder, in this order.
    pub texts: Vec<String>,
}
78
79impl MockMultiTextDocument {
80    /// Create a multi-text embedding fixture.
81    pub fn new(id: impl Into<String>, texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
82        Self {
83            id: id.into(),
84            texts: texts.into_iter().map(Into::into).collect(),
85        }
86    }
87}
88
89impl Embed for MockMultiTextDocument {
90    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
91        for text in &self.texts {
92            embedder.embed(text.clone());
93        }
94        Ok(())
95    }
96}