1use std::future::IntoFuture;
2
3use crate::{
4 completion::{self, CompletionModel},
5 extractor::{ExtractionError, Extractor},
6 message::Message,
7 vector_store::{self, request::VectorSearchRequest},
8 wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
11use super::Op;
12
/// Pipeline operation that turns its input into a query string and searches a
/// vector store index with it.
///
/// Built with [`lookup()`]. `In` is the input type (converted into a `String`
/// query) and `T` is the type each retrieved document is deserialized into.
pub struct Lookup<I, In, T> {
    /// Vector store index searched on every call.
    index: I,
    /// Sample count passed to each search request.
    n: usize,
    // Zero-sized markers: `In` and `T` are only used by the `Op` impl, so they
    // have to be anchored to the struct here.
    _in: std::marker::PhantomData<In>,
    _t: std::marker::PhantomData<T>,
}
19
20impl<I, In, T> Lookup<I, In, T>
21where
22 I: vector_store::VectorStoreIndex,
23{
24 pub(crate) fn new(index: I, n: usize) -> Self {
25 Self {
26 index,
27 n,
28 _in: std::marker::PhantomData,
29 _t: std::marker::PhantomData,
30 }
31 }
32}
33
34impl<I, In, T> Op for Lookup<I, In, T>
35where
36 I: vector_store::VectorStoreIndex,
37 In: Into<String> + WasmCompatSend + WasmCompatSync,
38 T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
39{
40 type Input = In;
41 type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
42
43 async fn call(&self, input: Self::Input) -> Self::Output {
44 let query: String = input.into();
45
46 let req = VectorSearchRequest::builder()
47 .query(query)
48 .samples(self.n as u64)
49 .build()?;
50
51 let docs = self.index.top_n::<T>(req).await?.into_iter().collect();
52
53 Ok(docs)
54 }
55}
56
57pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
62where
63 I: vector_store::VectorStoreIndex,
64 In: Into<String> + WasmCompatSend + WasmCompatSync,
65 T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
66{
67 Lookup::new(index, n)
68}
69
/// Pipeline operation that prompts a model with its input and yields the
/// model's `String` response.
///
/// Built with [`prompt()`]; `In` is the input type, converted into a `String`
/// before the model is prompted.
pub struct Prompt<P, In> {
    /// The value whose `prompt` method is invoked on every call.
    prompt: P,
    // Zero-sized marker anchoring the otherwise-unused `In` type.
    _in: std::marker::PhantomData<In>,
}
74
75impl<P, In> Prompt<P, In> {
76 pub(crate) fn new(prompt: P) -> Self {
77 Self {
78 prompt,
79 _in: std::marker::PhantomData,
80 }
81 }
82}
83
84impl<P, In> Op for Prompt<P, In>
85where
86 P: completion::Prompt + WasmCompatSend + WasmCompatSync,
87 In: Into<String> + WasmCompatSend + WasmCompatSync,
88{
89 type Input = In;
90 type Output = Result<String, completion::PromptError>;
91
92 fn call(
93 &self,
94 input: Self::Input,
95 ) -> impl std::future::Future<Output = Self::Output> + WasmCompatSend {
96 self.prompt.prompt(input.into()).into_future()
97 }
98}
99
100pub fn prompt<P, In>(model: P) -> Prompt<P, In>
104where
105 P: completion::Prompt,
106 In: Into<String> + WasmCompatSend + WasmCompatSync,
107{
108 Prompt::new(model)
109}
110
111pub struct Extract<M, Input, Output>
112where
113 M: CompletionModel,
114 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
115{
116 extractor: Extractor<M, Output>,
117 _in: std::marker::PhantomData<Input>,
118}
119
120impl<M, Input, Output> Extract<M, Input, Output>
121where
122 M: CompletionModel,
123 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
124{
125 pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
126 Self {
127 extractor,
128 _in: std::marker::PhantomData,
129 }
130 }
131}
132
133impl<M, Input, Output> Op for Extract<M, Input, Output>
134where
135 M: CompletionModel,
136 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
137 Input: Into<Message> + WasmCompatSend + WasmCompatSync,
138{
139 type Input = Input;
140 type Output = Result<Output, ExtractionError>;
141
142 async fn call(&self, input: Self::Input) -> Self::Output {
143 self.extractor.extract(input).await
144 }
145}
146
147pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
151where
152 M: CompletionModel,
153 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
154 Input: Into<String> + WasmCompatSend + WasmCompatSync,
155{
156 Extract::new(extractor)
157}
158
/// Unit tests and mock fixtures for the ops in this module.
// NOTE(review): the module and its mocks are `pub`, presumably so other test
// modules can reuse `MockModel`, `MockIndex` and `Foo` — confirm before
// making anything private.
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::message;
    use completion::{Prompt, PromptError};
    use vector_store::{VectorStoreError, VectorStoreIndex};

    /// Fake model that answers with `"Mock response: "` followed by the text
    /// of the prompt.
    pub struct MockModel;

    impl Prompt for MockModel {
        // The impl's return type is more specific than the trait's declared
        // one, which this lint would otherwise flag.
        #[allow(refining_impl_trait)]
        async fn prompt(&self, prompt: impl Into<message::Message>) -> Result<String, PromptError> {
            let msg: message::Message = prompt.into();
            // Only a user message whose first content item is text is
            // supported; any other shape panics via `unreachable!`.
            let prompt = match msg {
                message::Message::User { content } => match content.first() {
                    message::UserContent::Text(message::Text { text }) => text,
                    _ => unreachable!(),
                },
                _ => unreachable!(),
            };
            Ok(format!("Mock response: {prompt}"))
        }
    }

    /// Fake vector index that ignores the request and always returns a single
    /// hit with value `1.0` and id `"doc1"`.
    pub struct MockIndex;

    impl VectorStoreIndex for MockIndex {
        // Deserializes the fixed JSON object `{"foo": "bar"}` into `T`; panics
        // (via `unwrap`) if `T` cannot be built from it. [`Foo`] matches it.
        async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            let doc = serde_json::from_value(serde_json::json!({
                "foo": "bar",
            }))
            .unwrap();

            Ok(vec![(1.0, "doc1".to_string(), doc)])
        }

        // Same single hit as `top_n`, without the document payload.
        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(vec![(1.0, "doc1".to_string())])
        }
    }

    /// Document type matching the JSON object returned by [`MockIndex`].
    #[derive(Debug, serde::Deserialize, PartialEq)]
    pub struct Foo {
        pub foo: String,
    }

    // `lookup` should pass the mock index's single hit through unchanged.
    #[tokio::test]
    async fn test_lookup() {
        let index = MockIndex;
        let lookup = lookup::<MockIndex, String, Foo>(index, 1);

        let result = lookup.call("query".to_string()).await.unwrap();
        assert_eq!(
            result,
            vec![(
                1.0,
                "doc1".to_string(),
                Foo {
                    foo: "bar".to_string()
                }
            )]
        );
    }

    // `prompt` should forward the input text to the model and return its reply.
    #[tokio::test]
    async fn test_prompt() {
        let model = MockModel;
        let prompt = prompt::<MockModel, String>(model);

        let result = prompt.call("hello".to_string()).await.unwrap();
        assert_eq!(result, "Mock response: hello");
    }
}