rig/pipeline/
agent_ops.rs

1use std::future::IntoFuture;
2
3use crate::{
4    completion::{self, CompletionModel},
5    extractor::{ExtractionError, Extractor},
6    message::Message,
7    vector_store::{self, request::VectorSearchRequest},
8    wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
11use super::Op;
12
/// Pipeline op that queries a vector index once per input.
///
/// `In` is the input type the op accepts and `T` is the type each returned
/// document is deserialized into; neither is stored, so both are tracked with
/// zero-sized markers.
pub struct Lookup<I, In, T> {
    /// Index that is searched on every call.
    index: I,
    /// Number of samples requested from the index per search.
    n: usize,
    /// Marker tying the op to its document type.
    _t: std::marker::PhantomData<T>,
    /// Marker tying the op to its input type.
    _in: std::marker::PhantomData<In>,
}
19
20impl<I, In, T> Lookup<I, In, T>
21where
22    I: vector_store::VectorStoreIndex,
23{
24    pub(crate) fn new(index: I, n: usize) -> Self {
25        Self {
26            index,
27            n,
28            _in: std::marker::PhantomData,
29            _t: std::marker::PhantomData,
30        }
31    }
32}
33
34impl<I, In, T> Op for Lookup<I, In, T>
35where
36    I: vector_store::VectorStoreIndex,
37    In: Into<String> + WasmCompatSend + WasmCompatSync,
38    T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
39{
40    type Input = In;
41    type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
42
43    async fn call(&self, input: Self::Input) -> Self::Output {
44        let query: String = input.into();
45
46        let req = VectorSearchRequest::builder()
47            .query(query)
48            .samples(self.n as u64)
49            .build()?;
50
51        let docs = self.index.top_n::<T>(req).await?.into_iter().collect();
52
53        Ok(docs)
54    }
55}
56
57/// Create a new lookup operation.
58///
59/// The op will perform semantic search on the provided index and return the top `n`
60/// results closest results to the input.
61pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
62where
63    I: vector_store::VectorStoreIndex,
64    In: Into<String> + WasmCompatSend + WasmCompatSync,
65    T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
66{
67    Lookup::new(index, n)
68}
69
/// Pipeline op that forwards each input to a wrapped prompt target.
///
/// `In` is the input type the op accepts; it is not stored and is tracked with
/// a zero-sized marker.
pub struct Prompt<P, In> {
    /// Marker tying the op to its input type.
    _in: std::marker::PhantomData<In>,
    /// Value that is prompted on every call.
    prompt: P,
}
74
75impl<P, In> Prompt<P, In> {
76    pub(crate) fn new(prompt: P) -> Self {
77        Self {
78            prompt,
79            _in: std::marker::PhantomData,
80        }
81    }
82}
83
84impl<P, In> Op for Prompt<P, In>
85where
86    P: completion::Prompt + WasmCompatSend + WasmCompatSync,
87    In: Into<String> + WasmCompatSend + WasmCompatSync,
88{
89    type Input = In;
90    type Output = Result<String, completion::PromptError>;
91
92    fn call(
93        &self,
94        input: Self::Input,
95    ) -> impl std::future::Future<Output = Self::Output> + WasmCompatSend {
96        self.prompt.prompt(input.into()).into_future()
97    }
98}
99
100/// Create a new prompt operation.
101///
102/// The op will prompt the `model` with the input and return the response.
103pub fn prompt<P, In>(model: P) -> Prompt<P, In>
104where
105    P: completion::Prompt,
106    In: Into<String> + WasmCompatSend + WasmCompatSync,
107{
108    Prompt::new(model)
109}
110
111pub struct Extract<M, Input, Output>
112where
113    M: CompletionModel,
114    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
115{
116    extractor: Extractor<M, Output>,
117    _in: std::marker::PhantomData<Input>,
118}
119
120impl<M, Input, Output> Extract<M, Input, Output>
121where
122    M: CompletionModel,
123    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
124{
125    pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
126        Self {
127            extractor,
128            _in: std::marker::PhantomData,
129        }
130    }
131}
132
133impl<M, Input, Output> Op for Extract<M, Input, Output>
134where
135    M: CompletionModel,
136    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
137    Input: Into<Message> + WasmCompatSend + WasmCompatSync,
138{
139    type Input = Input;
140    type Output = Result<Output, ExtractionError>;
141
142    async fn call(&self, input: Self::Input) -> Self::Output {
143        self.extractor.extract(input).await
144    }
145}
146
147/// Create a new extract operation.
148///
149/// The op will extract the structured data from the input using the provided `extractor`.
150pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
151where
152    M: CompletionModel,
153    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
154    Input: Into<String> + WasmCompatSend + WasmCompatSync,
155{
156    Extract::new(extractor)
157}
158
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Prompt stub that echoes back the text of the user message it is given.
    pub struct MockModel;

    impl Prompt for MockModel {
        // NOTE(review): presumably required because this `async fn` refines the
        // return type declared by the trait — confirm against `completion::Prompt`.
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            // A non-user message, or a first content item that is not text, is
            // treated as unreachable in these tests.
            let message::Message::User { content } = prompt.into() else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text }) = content.first() else {
                unreachable!()
            };

            Ok(format!("Mock response: {text}"))
        }
    }

    /// Index stub that ignores the request and always returns one fixed hit.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // The fixed document is deserialized into whatever `T` the caller asked for.
            let payload = serde_json::json!({
                "foo": "bar",
            });
            let doc: T = serde_json::from_value(payload).unwrap();

            Ok(vec![(1.0, "doc1".to_string(), doc)])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            let hit = (1.0, "doc1".to_string());
            Ok(vec![hit])
        }
    }

    /// Document shape matching the JSON object returned by `MockIndex::top_n`.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let hits = op.call("query".to_string()).await.unwrap();

        let expected = vec![(
            1.0,
            "doc1".to_string(),
            Foo {
                foo: "bar".to_string(),
            },
        )];
        assert_eq!(hits, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let response = op.call("hello".to_string()).await.unwrap();

        assert_eq!(response, "Mock response: hello");
    }
}
237}