Skip to main content

rig_core/test_utils/
pipeline.rs

1//! Pipeline helpers for deterministic tests.
2
3use crate::{
4    completion::{CompletionError, Prompt, PromptError},
5    message::{self, Message},
6    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndex, request::Filter},
7    wasm_compat::WasmCompatSend,
8};
9
10/// A prompt model that echoes user text with a stable prefix.
11pub struct MockPromptModel;
12
13impl Prompt for MockPromptModel {
14    #[allow(refining_impl_trait)]
15    async fn prompt(&self, prompt: impl Into<Message>) -> Result<String, PromptError> {
16        let msg = prompt.into();
17        let prompt = match msg {
18            Message::User { content } => match content.first() {
19                message::UserContent::Text(message::Text { text }) => text,
20                _ => {
21                    return Err(PromptError::CompletionError(CompletionError::RequestError(
22                        "mock prompt model only accepts text user messages".into(),
23                    )));
24                }
25            },
26            _ => {
27                return Err(PromptError::CompletionError(CompletionError::RequestError(
28                    "mock prompt model only accepts user messages".into(),
29                )));
30            }
31        };
32
33        Ok(format!("Mock response: {prompt}"))
34    }
35}
36
37/// A vector index that always returns one JSON document containing `{"foo":"bar"}`.
38pub struct MockVectorStoreIndex;
39
40impl VectorStoreIndex for MockVectorStoreIndex {
41    type Filter = Filter<serde_json::Value>;
42
43    async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
44        &self,
45        _req: VectorSearchRequest,
46    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
47        let doc = serde_json::from_value(serde_json::json!({
48            "foo": "bar",
49        }))?;
50
51        Ok(vec![(1.0, "doc1".to_string(), doc)])
52    }
53
54    async fn top_n_ids(
55        &self,
56        _req: VectorSearchRequest,
57    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
58        Ok(vec![(1.0, "doc1".to_string())])
59    }
60}
61
62/// Document fixture returned by [`MockVectorStoreIndex`] in pipeline tests.
63#[derive(Debug, serde::Deserialize, PartialEq)]
64pub struct Foo {
65    pub foo: String,
66}