1use std::future::IntoFuture;
2
3use crate::{
4 completion::{self, CompletionModel},
5 extractor::{ExtractionError, Extractor},
6 message::Message,
7 vector_store::{self, request::VectorSearchRequest},
8};
9
10use super::Op;
11
/// Pipeline op that looks documents up in a vector store index.
///
/// Type parameters:
/// - `I`: the index the op queries.
/// - `In`: the input type the op accepts.
/// - `T`: the document type produced for each hit.
pub struct Lookup<I, In, T> {
    /// Index that receives the search request.
    index: I,
    /// Number of samples requested from the index for each query.
    n: usize,
    /// Zero-sized marker binding the op to its input type.
    _in: std::marker::PhantomData<In>,
    /// Zero-sized marker binding the op to its document type.
    _t: std::marker::PhantomData<T>,
}
18
19impl<I, In, T> Lookup<I, In, T>
20where
21 I: vector_store::VectorStoreIndex,
22{
23 pub(crate) fn new(index: I, n: usize) -> Self {
24 Self {
25 index,
26 n,
27 _in: std::marker::PhantomData,
28 _t: std::marker::PhantomData,
29 }
30 }
31}
32
33impl<I, In, T> Op for Lookup<I, In, T>
34where
35 I: vector_store::VectorStoreIndex,
36 In: Into<String> + Send + Sync,
37 T: Send + Sync + for<'a> serde::Deserialize<'a>,
38{
39 type Input = In;
40 type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
41
42 async fn call(&self, input: Self::Input) -> Self::Output {
43 let query: String = input.into();
44
45 let req = VectorSearchRequest::builder()
46 .query(query)
47 .samples(self.n as u64)
48 .build()?;
49
50 let docs = self.index.top_n::<T>(req).await?.into_iter().collect();
51
52 Ok(docs)
53 }
54}
55
56pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
61where
62 I: vector_store::VectorStoreIndex,
63 In: Into<String> + Send + Sync,
64 T: Send + Sync + for<'a> serde::Deserialize<'a>,
65{
66 Lookup::new(index, n)
67}
68
/// Pipeline op that forwards its input to a prompt-capable value.
///
/// Type parameters:
/// - `P`: the value that is prompted.
/// - `In`: the input type the op accepts.
pub struct Prompt<P, In> {
    /// Value that receives each prompt.
    prompt: P,
    /// Zero-sized marker binding the op to its input type.
    _in: std::marker::PhantomData<In>,
}
73
74impl<P, In> Prompt<P, In> {
75 pub(crate) fn new(prompt: P) -> Self {
76 Self {
77 prompt,
78 _in: std::marker::PhantomData,
79 }
80 }
81}
82
83impl<P, In> Op for Prompt<P, In>
84where
85 P: completion::Prompt + Send + Sync,
86 In: Into<String> + Send + Sync,
87{
88 type Input = In;
89 type Output = Result<String, completion::PromptError>;
90
91 fn call(&self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
92 self.prompt.prompt(input.into()).into_future()
93 }
94}
95
96pub fn prompt<P, In>(model: P) -> Prompt<P, In>
100where
101 P: completion::Prompt,
102 In: Into<String> + Send + Sync,
103{
104 Prompt::new(model)
105}
106
107pub struct Extract<M, Input, Output>
108where
109 M: CompletionModel,
110 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
111{
112 extractor: Extractor<M, Output>,
113 _in: std::marker::PhantomData<Input>,
114}
115
116impl<M, Input, Output> Extract<M, Input, Output>
117where
118 M: CompletionModel,
119 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
120{
121 pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
122 Self {
123 extractor,
124 _in: std::marker::PhantomData,
125 }
126 }
127}
128
129impl<M, Input, Output> Op for Extract<M, Input, Output>
130where
131 M: CompletionModel,
132 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
133 Input: Into<Message> + Send + Sync,
134{
135 type Input = Input;
136 type Output = Result<Output, ExtractionError>;
137
138 async fn call(&self, input: Self::Input) -> Self::Output {
139 self.extractor.extract(input).await
140 }
141}
142
143pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
147where
148 M: CompletionModel,
149 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
150 Input: Into<String> + Send + Sync,
151{
152 Extract::new(extractor)
153}
154
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Model stand-in that echoes back the text of the prompt it is given.
    pub struct MockModel;

    impl Prompt for MockModel {
        // Implemented as an `async fn`, which appears to be narrower than the return
        // type the trait declares (the op above consumes it through `IntoFuture`);
        // the lint about refining that return type is silenced on purpose.
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();

            // The tests only ever send a user message whose first content is text.
            let message::Message::User { content } = msg else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text }) = content.first() else {
                unreachable!()
            };

            Ok(format!("Mock response: {text}"))
        }
    }

    /// Index stand-in that answers every request with one fixed document.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Deserialize the canned payload into whatever document type was requested.
            let payload = serde_json::json!({
                "foo": "bar",
            });
            let doc: T = serde_json::from_value(payload).unwrap();

            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, String::from("doc1"))])
        }
    }

    /// Document shape matching the payload served by [`MockIndex`].
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let hits = op.call(String::from("query")).await.unwrap();
        let expected = vec![(
            1.0_f64,
            String::from("doc1"),
            Foo {
                foo: String::from("bar"),
            },
        )];
        assert_eq!(hits, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call(String::from("hello")).await.unwrap();
        assert_eq!(reply, "Mock response: hello");
    }
}