1use crate::{
2 completion::{self, CompletionModel},
3 extractor::{ExtractionError, Extractor},
4 vector_store,
5};
6
7use super::Op;
8
9pub struct Lookup<I, In, T> {
10 index: I,
11 n: usize,
12 _in: std::marker::PhantomData<In>,
13 _t: std::marker::PhantomData<T>,
14}
15
16impl<I, In, T> Lookup<I, In, T>
17where
18 I: vector_store::VectorStoreIndex,
19{
20 pub(crate) fn new(index: I, n: usize) -> Self {
21 Self {
22 index,
23 n,
24 _in: std::marker::PhantomData,
25 _t: std::marker::PhantomData,
26 }
27 }
28}
29
30impl<I, In, T> Op for Lookup<I, In, T>
31where
32 I: vector_store::VectorStoreIndex,
33 In: Into<String> + Send + Sync,
34 T: Send + Sync + for<'a> serde::Deserialize<'a>,
35{
36 type Input = In;
37 type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
38
39 async fn call(&self, input: Self::Input) -> Self::Output {
40 let query: String = input.into();
41
42 let docs = self
43 .index
44 .top_n::<T>(&query, self.n)
45 .await?
46 .into_iter()
47 .collect();
48
49 Ok(docs)
50 }
51}
52
53pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
58where
59 I: vector_store::VectorStoreIndex,
60 In: Into<String> + Send + Sync,
61 T: Send + Sync + for<'a> serde::Deserialize<'a>,
62{
63 Lookup::new(index, n)
64}
65
/// Pipeline operation that forwards its input to a prompt-capable value.
///
/// * `P`  - the wrapped value the prompt is sent to.
/// * `In` - the input type accepted by the operation (never stored).
pub struct Prompt<P, In> {
    /// Value that receives the prompt text.
    prompt: P,
    /// Marker tying the input type to this op without holding a value of it.
    _in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
    /// Wraps `prompt` in a new operation.
    pub(crate) fn new(prompt: P) -> Self {
        use std::marker::PhantomData;

        Self {
            prompt,
            _in: PhantomData,
        }
    }
}
79
80impl<P, In> Op for Prompt<P, In>
81where
82 P: completion::Prompt,
83 In: Into<String> + Send + Sync,
84{
85 type Input = In;
86 type Output = Result<String, completion::PromptError>;
87
88 async fn call(&self, input: Self::Input) -> Self::Output {
89 let prompt: String = input.into();
90 self.prompt.prompt(&prompt).await
91 }
92}
93
94pub fn prompt<P, In>(model: P) -> Prompt<P, In>
98where
99 P: completion::Prompt,
100 In: Into<String> + Send + Sync,
101{
102 Prompt::new(model)
103}
104
105pub struct Extract<M, Input, Output>
106where
107 M: CompletionModel,
108 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
109{
110 extractor: Extractor<M, Output>,
111 _in: std::marker::PhantomData<Input>,
112}
113
114impl<M, Input, Output> Extract<M, Input, Output>
115where
116 M: CompletionModel,
117 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
118{
119 pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
120 Self {
121 extractor,
122 _in: std::marker::PhantomData,
123 }
124 }
125}
126
127impl<M, Input, Output> Op for Extract<M, Input, Output>
128where
129 M: CompletionModel,
130 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
131 Input: Into<String> + Send + Sync,
132{
133 type Input = Input;
134 type Output = Result<Output, ExtractionError>;
135
136 async fn call(&self, input: Self::Input) -> Self::Output {
137 self.extractor.extract(&input.into()).await
138 }
139}
140
141pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
145where
146 M: CompletionModel,
147 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
148 Input: Into<String> + Send + Sync,
149{
150 Extract::new(extractor)
151}
152
#[cfg(test)]
pub mod tests {
    use super::*;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Prompt stub that echoes the prompt back behind a fixed prefix.
    pub struct MockModel;

    impl Prompt for MockModel {
        async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
            let reply = format!("Mock response: {}", prompt);
            Ok(reply)
        }
    }

    /// Index stub that ignores the query and limit and always returns the
    /// same single hit.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Panics if the requested `T` cannot be deserialized from this
            // fixed payload; acceptable for a test double.
            let payload = serde_json::json!({
                "foo": "bar",
            });
            let doc: T = serde_json::from_value(payload).unwrap();

            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, String::from("doc1"))])
        }
    }

    /// Document shape matching the payload served by [`MockIndex`].
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    /// The lookup op returns exactly what the index yields.
    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let hits = op.call(String::from("query")).await.unwrap();

        let expected = vec![(
            1.0,
            String::from("doc1"),
            Foo {
                foo: String::from("bar"),
            },
        )];
        assert_eq!(hits, expected);
    }

    /// The prompt op forwards its input to the model and returns the reply.
    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call(String::from("hello")).await.unwrap();

        assert_eq!(reply, "Mock response: hello");
    }
}