rig/pipeline/
agent_ops.rs

1use std::future::IntoFuture;
2
3use crate::{
4    completion::{self, CompletionModel},
5    extractor::{ExtractionError, Extractor},
6    message::Message,
7    vector_store,
8};
9
10use super::Op;
11
/// Pipeline op that queries a vector store index for the documents closest to its input.
///
/// Built with [`lookup`]; see its `Op` implementation for the exact output shape.
pub struct Lookup<I, In, T> {
    /// How many results to request from the index per call.
    n: usize,
    /// The vector store index that is searched.
    index: I,
    /// Zero-sized marker for the op's input type `In`.
    _in: std::marker::PhantomData<In>,
    /// Zero-sized marker for the document type `T` that results are deserialized into.
    _t: std::marker::PhantomData<T>,
}
18
19impl<I, In, T> Lookup<I, In, T>
20where
21    I: vector_store::VectorStoreIndex,
22{
23    pub(crate) fn new(index: I, n: usize) -> Self {
24        Self {
25            index,
26            n,
27            _in: std::marker::PhantomData,
28            _t: std::marker::PhantomData,
29        }
30    }
31}
32
33impl<I, In, T> Op for Lookup<I, In, T>
34where
35    I: vector_store::VectorStoreIndex,
36    In: Into<String> + Send + Sync,
37    T: Send + Sync + for<'a> serde::Deserialize<'a>,
38{
39    type Input = In;
40    type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
41
42    async fn call(&self, input: Self::Input) -> Self::Output {
43        let query: String = input.into();
44
45        let docs = self
46            .index
47            .top_n::<T>(&query, self.n)
48            .await?
49            .into_iter()
50            .collect();
51
52        Ok(docs)
53    }
54}
55
56/// Create a new lookup operation.
57///
58/// The op will perform semantic search on the provided index and return the top `n`
59/// results closest results to the input.
60pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
61where
62    I: vector_store::VectorStoreIndex,
63    In: Into<String> + Send + Sync,
64    T: Send + Sync + for<'a> serde::Deserialize<'a>,
65{
66    Lookup::new(index, n)
67}
68
/// Pipeline op that forwards its input to a prompt-capable model.
///
/// Built with [`prompt()`].
pub struct Prompt<P, In> {
    /// Zero-sized marker for the op's input type `In`.
    _in: std::marker::PhantomData<In>,
    /// The model (or agent) the input is sent to.
    prompt: P,
}
73
74impl<P, In> Prompt<P, In> {
75    pub(crate) fn new(prompt: P) -> Self {
76        Self {
77            prompt,
78            _in: std::marker::PhantomData,
79        }
80    }
81}
82
83impl<P, In> Op for Prompt<P, In>
84where
85    P: completion::Prompt + Send + Sync,
86    In: Into<String> + Send + Sync,
87{
88    type Input = In;
89    type Output = Result<String, completion::PromptError>;
90
91    fn call(&self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
92        self.prompt.prompt(input.into()).into_future()
93    }
94}
95
96/// Create a new prompt operation.
97///
98/// The op will prompt the `model` with the input and return the response.
99pub fn prompt<P, In>(model: P) -> Prompt<P, In>
100where
101    P: completion::Prompt,
102    In: Into<String> + Send + Sync,
103{
104    Prompt::new(model)
105}
106
107pub struct Extract<M, Input, Output>
108where
109    M: CompletionModel,
110    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
111{
112    extractor: Extractor<M, Output>,
113    _in: std::marker::PhantomData<Input>,
114}
115
116impl<M, Input, Output> Extract<M, Input, Output>
117where
118    M: CompletionModel,
119    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
120{
121    pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
122        Self {
123            extractor,
124            _in: std::marker::PhantomData,
125        }
126    }
127}
128
129impl<M, Input, Output> Op for Extract<M, Input, Output>
130where
131    M: CompletionModel,
132    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
133    Input: Into<Message> + Send + Sync,
134{
135    type Input = Input;
136    type Output = Result<Output, ExtractionError>;
137
138    async fn call(&self, input: Self::Input) -> Self::Output {
139        self.extractor.extract(input).await
140    }
141}
142
143/// Create a new extract operation.
144///
145/// The op will extract the structured data from the input using the provided `extractor`.
146pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
147where
148    M: CompletionModel,
149    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
150    Input: Into<String> + Send + Sync,
151{
152    Extract::new(extractor)
153}
154
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Prompt stub that answers a user message whose first content item is text
    /// with `"Mock response: <text>"`. Any other message shape hits `unreachable!()`.
    pub struct MockModel;

    impl Prompt for MockModel {
        // This impl's signature is more specific than the trait's declaration
        // (a refinement). That is acceptable for a test-only stub.
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();
            let message::Message::User { content } = msg else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text }) = content.first() else {
                unreachable!()
            };
            Ok(format!("Mock response: {}", text))
        }
    }

    /// Vector store stub that always returns a single hit, `"doc1"` with score `1.0`,
    /// ignoring the query and `n`.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            let payload = serde_json::json!({ "foo": "bar" });
            let doc: T = serde_json::from_value(payload).unwrap();
            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, String::from("doc1"))])
        }
    }

    /// Document shape that `MockIndex`'s JSON payload deserializes into.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let hits = op.call("query".to_string()).await.unwrap();
        let expected = vec![(
            1.0,
            "doc1".to_string(),
            Foo {
                foo: "bar".to_string(),
            },
        )];
        assert_eq!(hits, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call("hello".to_string()).await.unwrap();
        assert_eq!(reply, "Mock response: hello");
    }
}