bep/pipeline/
agent_ops.rs

1use crate::{
2    completion::{self, CompletionModel},
3    extractor::{ExtractionError, Extractor},
4    vector_store,
5};
6
7use super::Op;
8
9pub struct Lookup<I, In, T> {
10    index: I,
11    n: usize,
12    _in: std::marker::PhantomData<In>,
13    _t: std::marker::PhantomData<T>,
14}
15
16impl<I, In, T> Lookup<I, In, T>
17where
18    I: vector_store::VectorStoreIndex,
19{
20    pub(crate) fn new(index: I, n: usize) -> Self {
21        Self {
22            index,
23            n,
24            _in: std::marker::PhantomData,
25            _t: std::marker::PhantomData,
26        }
27    }
28}
29
30impl<I, In, T> Op for Lookup<I, In, T>
31where
32    I: vector_store::VectorStoreIndex,
33    In: Into<String> + Send + Sync,
34    T: Send + Sync + for<'a> serde::Deserialize<'a>,
35{
36    type Input = In;
37    type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
38
39    async fn call(&self, input: Self::Input) -> Self::Output {
40        let query: String = input.into();
41
42        let docs = self
43            .index
44            .top_n::<T>(&query, self.n)
45            .await?
46            .into_iter()
47            .collect();
48
49        Ok(docs)
50    }
51}
52
53/// Create a new lookup operation.
54///
55/// The op will perform semantic search on the provided index and return the top `n`
56/// results closest results to the input.
57pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
58where
59    I: vector_store::VectorStoreIndex,
60    In: Into<String> + Send + Sync,
61    T: Send + Sync + for<'a> serde::Deserialize<'a>,
62{
63    Lookup::new(index, n)
64}
65
/// Operation that forwards its input to a wrapped prompt target.
///
/// `P` is the type being prompted and `In` is the input type accepted by the op.
pub struct Prompt<P, In> {
    /// Value that receives the prompt each time the op is called.
    prompt: P,
    /// Zero-sized marker tying the op to its input type `In`.
    _in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
    /// Wraps `prompt` in a new operation.
    pub(crate) fn new(prompt: P) -> Self {
        Self {
            prompt,
            _in: std::marker::PhantomData::<In>,
        }
    }
}
79
80impl<P, In> Op for Prompt<P, In>
81where
82    P: completion::Prompt,
83    In: Into<String> + Send + Sync,
84{
85    type Input = In;
86    type Output = Result<String, completion::PromptError>;
87
88    async fn call(&self, input: Self::Input) -> Self::Output {
89        let prompt: String = input.into();
90        self.prompt.prompt(&prompt).await
91    }
92}
93
94/// Create a new prompt operation.
95///
96/// The op will prompt the `model` with the input and return the response.
97pub fn prompt<P, In>(model: P) -> Prompt<P, In>
98where
99    P: completion::Prompt,
100    In: Into<String> + Send + Sync,
101{
102    Prompt::new(model)
103}
104
105pub struct Extract<M, Input, Output>
106where
107    M: CompletionModel,
108    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
109{
110    extractor: Extractor<M, Output>,
111    _in: std::marker::PhantomData<Input>,
112}
113
114impl<M, Input, Output> Extract<M, Input, Output>
115where
116    M: CompletionModel,
117    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
118{
119    pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
120        Self {
121            extractor,
122            _in: std::marker::PhantomData,
123        }
124    }
125}
126
127impl<M, Input, Output> Op for Extract<M, Input, Output>
128where
129    M: CompletionModel,
130    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
131    Input: Into<String> + Send + Sync,
132{
133    type Input = Input;
134    type Output = Result<Output, ExtractionError>;
135
136    async fn call(&self, input: Self::Input) -> Self::Output {
137        self.extractor.extract(&input.into()).await
138    }
139}
140
141/// Create a new extract operation.
142///
143/// The op will extract the structured data from the input using the provided `extractor`.
144pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
145where
146    M: CompletionModel,
147    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
148    Input: Into<String> + Send + Sync,
149{
150    Extract::new(extractor)
151}
152
#[cfg(test)]
pub mod tests {
    use super::*;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Prompt target that echoes the prompt back behind a fixed prefix.
    pub struct MockModel;

    impl Prompt for MockModel {
        async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
            let reply = format!("Mock response: {}", prompt);
            Ok(reply)
        }
    }

    /// Index that ignores the query and always reports a single fixed hit.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // The single document is a JSON object with one `foo` field, deserialized
            // into whatever `T` the caller asked for.
            let raw = serde_json::json!({
                "foo": "bar",
            });
            let doc: T = serde_json::from_value(raw).unwrap();

            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, String::from("doc1"))])
        }
    }

    /// Document type matching the JSON object returned by [`MockIndex`].
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let expected = vec![(
            1.0,
            String::from("doc1"),
            Foo {
                foo: String::from("bar"),
            },
        )];

        let found = op.call(String::from("query")).await.unwrap();
        assert_eq!(found, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call(String::from("hello")).await.unwrap();
        assert_eq!(reply, "Mock response: hello");
    }
}