1use crate::{
2 completion::{self, CompletionModel},
3 extractor::{ExtractionError, Extractor},
4 vector_store,
5};
6
7use super::Op;
8
9pub struct Lookup<I, In, T> {
10 index: I,
11 n: usize,
12 _in: std::marker::PhantomData<In>,
13 _t: std::marker::PhantomData<T>,
14}
15
16impl<I, In, T> Lookup<I, In, T>
17where
18 I: vector_store::VectorStoreIndex,
19{
20 pub(crate) fn new(index: I, n: usize) -> Self {
21 Self {
22 index,
23 n,
24 _in: std::marker::PhantomData,
25 _t: std::marker::PhantomData,
26 }
27 }
28}
29
30impl<I, In, T> Op for Lookup<I, In, T>
31where
32 I: vector_store::VectorStoreIndex,
33 In: Into<String> + Send + Sync,
34 T: Send + Sync + for<'a> serde::Deserialize<'a>,
35{
36 type Input = In;
37 type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
38
39 async fn call(&self, input: Self::Input) -> Self::Output {
40 let query: String = input.into();
41
42 let docs = self
43 .index
44 .top_n::<T>(&query, self.n)
45 .await?
46 .into_iter()
47 .collect();
48
49 Ok(docs)
50 }
51}
52
53pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
58where
59 I: vector_store::VectorStoreIndex,
60 In: Into<String> + Send + Sync,
61 T: Send + Sync + for<'a> serde::Deserialize<'a>,
62{
63 Lookup::new(index, n)
64}
65
/// Op that forwards its input to a wrapped prompter and yields the reply.
///
/// Build one with [`prompt`]. `In` is the input type, which is converted to a `String` when the
/// op is called.
pub struct Prompt<P, In> {
    prompt: P,
    _in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
    /// Wrap `prompt` in a new op.
    pub(crate) fn new(prompt: P) -> Self {
        let _in = std::marker::PhantomData;
        Prompt { prompt, _in }
    }
}
79
80impl<P, In> Op for Prompt<P, In>
81where
82 P: completion::Prompt,
83 In: Into<String> + Send + Sync,
84{
85 type Input = In;
86 type Output = Result<String, completion::PromptError>;
87
88 async fn call(&self, input: Self::Input) -> Self::Output {
89 self.prompt.prompt(input.into()).await
90 }
91}
92
93pub fn prompt<P, In>(model: P) -> Prompt<P, In>
97where
98 P: completion::Prompt,
99 In: Into<String> + Send + Sync,
100{
101 Prompt::new(model)
102}
103
104pub struct Extract<M, Input, Output>
105where
106 M: CompletionModel,
107 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
108{
109 extractor: Extractor<M, Output>,
110 _in: std::marker::PhantomData<Input>,
111}
112
113impl<M, Input, Output> Extract<M, Input, Output>
114where
115 M: CompletionModel,
116 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
117{
118 pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
119 Self {
120 extractor,
121 _in: std::marker::PhantomData,
122 }
123 }
124}
125
126impl<M, Input, Output> Op for Extract<M, Input, Output>
127where
128 M: CompletionModel,
129 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
130 Input: Into<String> + Send + Sync,
131{
132 type Input = Input;
133 type Output = Result<Output, ExtractionError>;
134
135 async fn call(&self, input: Self::Input) -> Self::Output {
136 self.extractor.extract(&input.into()).await
137 }
138}
139
140pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
144where
145 M: CompletionModel,
146 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
147 Input: Into<String> + Send + Sync,
148{
149 Extract::new(extractor)
150}
151
#[cfg(test)]
pub mod tests {
    //! Mock model/index fixtures and unit tests for the ops defined in this module.

    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Mock prompter that replies with `"Mock response: <text>"`.
    ///
    /// Panics (via `unreachable!`) unless the prompt is a user message whose first content item
    /// is text.
    pub struct MockModel;

    impl Prompt for MockModel {
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();
            let message::Message::User { content } = msg else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text }) = content.first() else {
                unreachable!()
            };
            Ok(format!("Mock response: {}", text))
        }
    }

    /// Mock index that ignores the query and always returns one hit, `"doc1"`, scored `1.0`.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // `unwrap` panics if `T` cannot be deserialized from `{"foo": "bar"}`.
            let payload = serde_json::json!({ "foo": "bar" });
            let doc: T = serde_json::from_value(payload).unwrap();
            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            let hit = (1.0, String::from("doc1"));
            Ok(vec![hit])
        }
    }

    /// Document type matching the payload produced by [`MockIndex`].
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);
        let expected: Vec<(f64, String, Foo)> = vec![(
            1.0,
            String::from("doc1"),
            Foo {
                foo: String::from("bar"),
            },
        )];

        let docs = op.call(String::from("query")).await.unwrap();
        assert_eq!(docs, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call(String::from("hello")).await.unwrap();
        assert_eq!(reply, "Mock response: hello");
    }
}