1use std::future::IntoFuture;
2
3use crate::{
4 completion::{self, CompletionModel},
5 extractor::{ExtractionError, Extractor},
6 message::Message,
7 vector_store,
8};
9
10use super::Op;
11
/// Pipeline operation that queries a vector store index.
///
/// The operation's input is converted to a `String` and handed to
/// `VectorStoreIndex::top_n` together with the stored `n`.
///
/// Type parameters:
/// * `I`  - the index being queried.
/// * `In` - the input type accepted by the operation.
/// * `T`  - the type each returned document is deserialized into.
pub struct Lookup<I, In, T> {
    /// Index that every call is forwarded to.
    index: I,
    /// Value passed as the `n` argument of `top_n`.
    n: usize,
    /// Marker for the input type; no `In` value is stored.
    _in: std::marker::PhantomData<In>,
    /// Marker for the document type; no `T` value is stored.
    _t: std::marker::PhantomData<T>,
}
18
19impl<I, In, T> Lookup<I, In, T>
20where
21 I: vector_store::VectorStoreIndex,
22{
23 pub(crate) fn new(index: I, n: usize) -> Self {
24 Self {
25 index,
26 n,
27 _in: std::marker::PhantomData,
28 _t: std::marker::PhantomData,
29 }
30 }
31}
32
33impl<I, In, T> Op for Lookup<I, In, T>
34where
35 I: vector_store::VectorStoreIndex,
36 In: Into<String> + Send + Sync,
37 T: Send + Sync + for<'a> serde::Deserialize<'a>,
38{
39 type Input = In;
40 type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
41
42 async fn call(&self, input: Self::Input) -> Self::Output {
43 let query: String = input.into();
44
45 let docs = self
46 .index
47 .top_n::<T>(&query, self.n)
48 .await?
49 .into_iter()
50 .collect();
51
52 Ok(docs)
53 }
54}
55
56pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
61where
62 I: vector_store::VectorStoreIndex,
63 In: Into<String> + Send + Sync,
64 T: Send + Sync + for<'a> serde::Deserialize<'a>,
65{
66 Lookup::new(index, n)
67}
68
/// Pipeline operation that forwards its input to a prompt-capable value.
///
/// The input is converted to a `String` and passed to
/// `completion::Prompt::prompt` on the wrapped `P`.
pub struct Prompt<P, In> {
    /// The value whose `prompt` method is invoked on every call.
    prompt: P,
    /// Marker for the input type; no `In` value is stored.
    _in: std::marker::PhantomData<In>,
}
73
74impl<P, In> Prompt<P, In> {
75 pub(crate) fn new(prompt: P) -> Self {
76 Self {
77 prompt,
78 _in: std::marker::PhantomData,
79 }
80 }
81}
82
83impl<P, In> Op for Prompt<P, In>
84where
85 P: completion::Prompt + Send + Sync,
86 In: Into<String> + Send + Sync,
87{
88 type Input = In;
89 type Output = Result<String, completion::PromptError>;
90
91 fn call(&self, input: Self::Input) -> impl std::future::Future<Output = Self::Output> + Send {
92 self.prompt.prompt(input.into()).into_future()
93 }
94}
95
96pub fn prompt<P, In>(model: P) -> Prompt<P, In>
100where
101 P: completion::Prompt,
102 In: Into<String> + Send + Sync,
103{
104 Prompt::new(model)
105}
106
/// Pipeline operation that runs an [`Extractor`] on its input.
///
/// Type parameters:
/// * `M`      - the completion model used by the extractor.
/// * `Input`  - the input type accepted by the operation.
/// * `Output` - the type the extractor produces.
pub struct Extract<M, Input, Output>
where
    M: CompletionModel,
    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
{
    /// Extractor whose `extract` method is invoked on every call.
    extractor: Extractor<M, Output>,
    /// Marker for the input type; no `Input` value is stored.
    _in: std::marker::PhantomData<Input>,
}
115
116impl<M, Input, Output> Extract<M, Input, Output>
117where
118 M: CompletionModel,
119 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
120{
121 pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
122 Self {
123 extractor,
124 _in: std::marker::PhantomData,
125 }
126 }
127}
128
129impl<M, Input, Output> Op for Extract<M, Input, Output>
130where
131 M: CompletionModel,
132 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
133 Input: Into<Message> + Send + Sync,
134{
135 type Input = Input;
136 type Output = Result<Output, ExtractionError>;
137
138 async fn call(&self, input: Self::Input) -> Self::Output {
139 self.extractor.extract(input).await
140 }
141}
142
143pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
147where
148 M: CompletionModel,
149 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
150 Input: Into<String> + Send + Sync,
151{
152 Extract::new(extractor)
153}
154
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Prompt stub that echoes back the text of the user message it is given.
    pub struct MockModel;

    impl Prompt for MockModel {
        // NOTE(review): presumably required because this `async fn` refines the
        // return type declared by the trait - confirm against `completion::Prompt`.
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();

            // The tests only ever send a single-text user message; anything
            // else indicates a broken test setup.
            let message::Message::User { content } = msg else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text: prompt }) = content.first() else {
                unreachable!()
            };

            Ok(format!("Mock response: {prompt}"))
        }
    }

    /// Index stub that ignores the query and always returns one fixed hit.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        /// Returns a single hit whose document is deserialized from
        /// `{"foo": "bar"}`.
        async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            let payload = serde_json::json!({
                "foo": "bar",
            });
            let doc: T = serde_json::from_value(payload).unwrap();

            Ok(vec![(1.0, "doc1".to_string(), doc)])
        }

        /// Returns the same single hit as `top_n`, without the document.
        async fn top_n_ids(
            &self,
            _query: &str,
            _n: usize,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, "doc1".to_string())])
        }
    }

    /// Document shape produced by `MockIndex::top_n` in these tests.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    /// `lookup` should surface exactly what the index returns.
    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let hits = op.call("query".to_string()).await.unwrap();

        let expected: Vec<(f64, String, Foo)> = vec![(
            1.0,
            "doc1".to_string(),
            Foo {
                foo: "bar".to_string(),
            },
        )];
        assert_eq!(hits, expected);
    }

    /// `prompt` should pass the input through to the model unchanged.
    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call("hello".to_string()).await.unwrap();

        assert_eq!(reply, "Mock response: hello");
    }
}