rig/pipeline/
agent_ops.rs

1use crate::{
2    completion::{self, CompletionModel},
3    extractor::{ExtractionError, Extractor},
4    vector_store,
5};
6
7use super::Op;
8
9pub struct Lookup<I, In, T> {
10    index: I,
11    n: usize,
12    _in: std::marker::PhantomData<In>,
13    _t: std::marker::PhantomData<T>,
14}
15
16impl<I, In, T> Lookup<I, In, T>
17where
18    I: vector_store::VectorStoreIndex,
19{
20    pub(crate) fn new(index: I, n: usize) -> Self {
21        Self {
22            index,
23            n,
24            _in: std::marker::PhantomData,
25            _t: std::marker::PhantomData,
26        }
27    }
28}
29
30impl<I, In, T> Op for Lookup<I, In, T>
31where
32    I: vector_store::VectorStoreIndex,
33    In: Into<String> + Send + Sync,
34    T: Send + Sync + for<'a> serde::Deserialize<'a>,
35{
36    type Input = In;
37    type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
38
39    async fn call(&self, input: Self::Input) -> Self::Output {
40        let query: String = input.into();
41
42        let docs = self
43            .index
44            .top_n::<T>(&query, self.n)
45            .await?
46            .into_iter()
47            .collect();
48
49        Ok(docs)
50    }
51}
52
53/// Create a new lookup operation.
54///
55/// The op will perform semantic search on the provided index and return the top `n`
56/// results closest results to the input.
57pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
58where
59    I: vector_store::VectorStoreIndex,
60    In: Into<String> + Send + Sync,
61    T: Send + Sync + for<'a> serde::Deserialize<'a>,
62{
63    Lookup::new(index, n)
64}
65
/// Pipeline op that forwards its input to a [`completion::Prompt`] implementor.
///
/// `In` is the op's input type; it is only tracked through a `PhantomData`
/// marker, so no value of that type is stored in the op.
pub struct Prompt<P, In> {
    /// Value whose `prompt` method is invoked when the op runs.
    prompt: P,
    _in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
    /// Wraps `prompt` in a new prompt op.
    pub(crate) fn new(prompt: P) -> Self {
        Self {
            _in: Default::default(),
            prompt,
        }
    }
}
79
80impl<P, In> Op for Prompt<P, In>
81where
82    P: completion::Prompt,
83    In: Into<String> + Send + Sync,
84{
85    type Input = In;
86    type Output = Result<String, completion::PromptError>;
87
88    async fn call(&self, input: Self::Input) -> Self::Output {
89        self.prompt.prompt(input.into()).await
90    }
91}
92
93/// Create a new prompt operation.
94///
95/// The op will prompt the `model` with the input and return the response.
96pub fn prompt<P, In>(model: P) -> Prompt<P, In>
97where
98    P: completion::Prompt,
99    In: Into<String> + Send + Sync,
100{
101    Prompt::new(model)
102}
103
104pub struct Extract<M, Input, Output>
105where
106    M: CompletionModel,
107    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
108{
109    extractor: Extractor<M, Output>,
110    _in: std::marker::PhantomData<Input>,
111}
112
113impl<M, Input, Output> Extract<M, Input, Output>
114where
115    M: CompletionModel,
116    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
117{
118    pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
119        Self {
120            extractor,
121            _in: std::marker::PhantomData,
122        }
123    }
124}
125
126impl<M, Input, Output> Op for Extract<M, Input, Output>
127where
128    M: CompletionModel,
129    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
130    Input: Into<String> + Send + Sync,
131{
132    type Input = Input;
133    type Output = Result<Output, ExtractionError>;
134
135    async fn call(&self, input: Self::Input) -> Self::Output {
136        self.extractor.extract(&input.into()).await
137    }
138}
139
140/// Create a new extract operation.
141///
142/// The op will extract the structured data from the input using the provided `extractor`.
143pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
144where
145    M: CompletionModel,
146    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
147    Input: Into<String> + Send + Sync,
148{
149    Extract::new(extractor)
150}
151
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Model stub that echoes the text of a single user message back, prefixed
    /// with `"Mock response: "`. Any other message shape is treated as unreachable.
    pub struct MockModel;

    impl Prompt for MockModel {
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();
            let message::Message::User { content } = msg else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text }) = content.first() else {
                unreachable!()
            };
            Ok(format!("Mock response: {}", text))
        }
    }

    /// Index stub that ignores its arguments and always returns the single
    /// entry `(1.0, "doc1", {"foo": "bar"})`.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            let payload = serde_json::json!({ "foo": "bar" });
            let doc: T = serde_json::from_value(payload).unwrap();
            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, String::from("doc1"))])
        }
    }

    /// Document type that the mock index's `{"foo": "bar"}` payload deserializes into.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);
        let expected = vec![(
            1.0,
            String::from("doc1"),
            Foo {
                foo: String::from("bar"),
            },
        )];

        let docs = op.call(String::from("query")).await.unwrap();
        assert_eq!(docs, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let response = op.call(String::from("hello")).await.unwrap();
        assert_eq!(response, "Mock response: hello");
    }
}