1use std::future::IntoFuture;
2
3use crate::{
4 completion::{self, CompletionModel},
5 extractor::{ExtractionError, Extractor},
6 message::Message,
7 vector_store::{self, request::VectorSearchRequest},
8 wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
11use super::Op;
12
13pub struct Lookup<I, In, T> {
14 index: I,
15 n: usize,
16 _in: std::marker::PhantomData<In>,
17 _t: std::marker::PhantomData<T>,
18}
19
20impl<I, In, T> Lookup<I, In, T>
21where
22 I: vector_store::VectorStoreIndex,
23{
24 pub(crate) fn new(index: I, n: usize) -> Self {
25 Self {
26 index,
27 n,
28 _in: std::marker::PhantomData,
29 _t: std::marker::PhantomData,
30 }
31 }
32}
33
34impl<I, In, T> Op for Lookup<I, In, T>
35where
36 I: vector_store::VectorStoreIndex,
37 In: Into<String> + WasmCompatSend + WasmCompatSync,
38 T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
39{
40 type Input = In;
41 type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
42
43 async fn call(&self, input: Self::Input) -> Self::Output {
44 let query: String = input.into();
45
46 let req = VectorSearchRequest::builder()
47 .query(query)
48 .samples(self.n as u64)
49 .build()?;
50
51 let docs = self.index.top_n::<T>(req).await?.into_iter().collect();
52
53 Ok(docs)
54 }
55}
56
57pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
62where
63 I: vector_store::VectorStoreIndex,
64 In: Into<String> + WasmCompatSend + WasmCompatSync,
65 T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
66{
67 Lookup::new(index, n)
68}
69
/// Pipeline op that forwards its input to a prompt implementation.
pub struct Prompt<P, In> {
    /// Prompt implementation invoked on every call.
    prompt: P,
    // Marker for the input type; no value of `In` is stored.
    _in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
    /// Wraps `prompt` in an op whose input type `In` is chosen by the caller.
    pub(crate) fn new(prompt: P) -> Self {
        Self {
            _in: Default::default(),
            prompt,
        }
    }
}
83
84impl<P, In> Op for Prompt<P, In>
85where
86 P: completion::Prompt + WasmCompatSend + WasmCompatSync,
87 In: Into<String> + WasmCompatSend + WasmCompatSync,
88{
89 type Input = In;
90 type Output = Result<String, completion::PromptError>;
91
92 fn call(
93 &self,
94 input: Self::Input,
95 ) -> impl std::future::Future<Output = Self::Output> + WasmCompatSend {
96 self.prompt.prompt(input.into()).into_future()
97 }
98}
99
100pub fn prompt<P, In>(model: P) -> Prompt<P, In>
104where
105 P: completion::Prompt,
106 In: Into<String> + WasmCompatSend + WasmCompatSync,
107{
108 Prompt::new(model)
109}
110
111pub struct Extract<M, Input, Output>
112where
113 M: CompletionModel,
114 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
115{
116 extractor: Extractor<M, Output>,
117 _in: std::marker::PhantomData<Input>,
118}
119
120impl<M, Input, Output> Extract<M, Input, Output>
121where
122 M: CompletionModel,
123 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
124{
125 pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
126 Self {
127 extractor,
128 _in: std::marker::PhantomData,
129 }
130 }
131}
132
133impl<M, Input, Output> Op for Extract<M, Input, Output>
134where
135 M: CompletionModel,
136 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
137 Input: Into<Message> + WasmCompatSend + WasmCompatSync,
138{
139 type Input = Input;
140 type Output = Result<Output, ExtractionError>;
141
142 async fn call(&self, input: Self::Input) -> Self::Output {
143 self.extractor.extract(input).await
144 }
145}
146
147pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
151where
152 M: CompletionModel,
153 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
154 Input: Into<String> + WasmCompatSend + WasmCompatSync,
155{
156 Extract::new(extractor)
157}
158
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex, request::Filter};

    /// Test double that answers a user text prompt with `"Mock response: <text>"`.
    pub struct MockModel;

    impl Prompt for MockModel {
        // This `async fn` returns a concrete future type, which is more specific than the
        // trait's declared return type; the refinement is intentional for this mock.
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();
            // The mock only supports a user message whose first content item is text.
            let message::Message::User { content } = msg else {
                unreachable!()
            };
            let message::UserContent::Text(message::Text { text }) = content.first() else {
                unreachable!()
            };
            Ok(format!("Mock response: {text}"))
        }
    }

    /// Test double whose searches always yield one hit with id `doc1` and body `{"foo": "bar"}`.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            let payload = serde_json::json!({
                "foo": "bar",
            });
            // Panics if `T` cannot be built from the fixed payload, which is acceptable in a
            // test double.
            let doc: T = serde_json::from_value(payload).unwrap();

            Ok(vec![(1.0, String::from("doc1"), doc)])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            let hit = (1.0, String::from("doc1"));
            Ok(vec![hit])
        }
    }

    /// Document shape matching the payload served by `MockIndex`.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockIndex, String, Foo>(MockIndex, 1);

        let hits = op.call("query".to_string()).await.unwrap();
        let expected = vec![(
            1.0,
            "doc1".to_string(),
            Foo {
                foo: "bar".to_string(),
            },
        )];
        assert_eq!(hits, expected);
    }

    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockModel, String>(MockModel);

        let reply = op.call("hello".to_string()).await.unwrap();
        assert_eq!(reply, "Mock response: hello");
    }
}
239}