rig_core/test_utils/
pipeline.rs1use crate::{
4 completion::{CompletionError, Prompt, PromptError},
5 message::{self, Message},
6 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndex, request::Filter},
7 wasm_compat::WasmCompatSend,
8};
9
/// Test double for [`Prompt`] that echoes the prompt text back as
/// `"Mock response: <text>"` instead of calling a real provider.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockPromptModel;

13impl Prompt for MockPromptModel {
14 #[allow(refining_impl_trait)]
15 async fn prompt(&self, prompt: impl Into<Message>) -> Result<String, PromptError> {
16 let msg = prompt.into();
17 let prompt = match msg {
18 Message::User { content } => match content.first() {
19 message::UserContent::Text(message::Text { text }) => text,
20 _ => {
21 return Err(PromptError::CompletionError(CompletionError::RequestError(
22 "mock prompt model only accepts text user messages".into(),
23 )));
24 }
25 },
26 _ => {
27 return Err(PromptError::CompletionError(CompletionError::RequestError(
28 "mock prompt model only accepts user messages".into(),
29 )));
30 }
31 };
32
33 Ok(format!("Mock response: {prompt}"))
34 }
35}
36
/// Test double for [`VectorStoreIndex`] that serves a fixed result set and
/// ignores the incoming search request.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockVectorStoreIndex;

40impl VectorStoreIndex for MockVectorStoreIndex {
41 type Filter = Filter<serde_json::Value>;
42
43 async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
44 &self,
45 _req: VectorSearchRequest,
46 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
47 let doc = serde_json::from_value(serde_json::json!({
48 "foo": "bar",
49 }))?;
50
51 Ok(vec![(1.0, "doc1".to_string(), doc)])
52 }
53
54 async fn top_n_ids(
55 &self,
56 _req: VectorSearchRequest,
57 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
58 Ok(vec![(1.0, "doc1".to_string())])
59 }
60}
61
62#[derive(Debug, serde::Deserialize, PartialEq)]
64pub struct Foo {
65 pub foo: String,
66}