Skip to main content

rig_core/pipeline/
agent_ops.rs

1use std::future::IntoFuture;
2
3use crate::{
4    completion::{self, CompletionModel},
5    extractor::{ExtractionError, Extractor},
6    message::Message,
7    vector_store::{self, request::VectorSearchRequest},
8    wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
11use super::Op;
12
13pub struct Lookup<I, In, T> {
14    index: I,
15    n: usize,
16    _in: std::marker::PhantomData<In>,
17    _t: std::marker::PhantomData<T>,
18}
19
20impl<I, In, T> Lookup<I, In, T>
21where
22    I: vector_store::VectorStoreIndex,
23{
24    pub(crate) fn new(index: I, n: usize) -> Self {
25        Self {
26            index,
27            n,
28            _in: std::marker::PhantomData,
29            _t: std::marker::PhantomData,
30        }
31    }
32}
33
34impl<I, In, T> Op for Lookup<I, In, T>
35where
36    I: vector_store::VectorStoreIndex,
37    In: Into<String> + WasmCompatSend + WasmCompatSync,
38    T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
39{
40    type Input = In;
41    type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;
42
43    async fn call(&self, input: Self::Input) -> Self::Output {
44        let query: String = input.into();
45
46        let req = VectorSearchRequest::builder()
47            .query(query)
48            .samples(self.n as u64)
49            .build();
50
51        let docs = self.index.top_n::<T>(req).await?.into_iter().collect();
52
53        Ok(docs)
54    }
55}
56
57/// Create a new lookup operation.
58///
59/// The op will perform semantic search on the provided index and return the top `n`
60/// results closest results to the input.
61pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
62where
63    I: vector_store::VectorStoreIndex,
64    In: Into<String> + WasmCompatSend + WasmCompatSync,
65    T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
66{
67    Lookup::new(index, n)
68}
69
/// Pipeline op that forwards its input to a prompt-capable value `P`.
///
/// `In` is the input type accepted by the op.
pub struct Prompt<P, In> {
    /// The value that is prompted on every call.
    prompt: P,
    /// Zero-sized marker: `In` is only named by the type, never stored.
    _in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
    /// Wraps `prompt` in a prompt op.
    pub(crate) fn new(prompt: P) -> Self {
        Prompt {
            prompt,
            _in: std::marker::PhantomData::<In>,
        }
    }
}
83
84impl<P, In> Op for Prompt<P, In>
85where
86    P: completion::Prompt + WasmCompatSend + WasmCompatSync,
87    In: Into<String> + WasmCompatSend + WasmCompatSync,
88{
89    type Input = In;
90    type Output = Result<String, completion::PromptError>;
91
92    fn call(
93        &self,
94        input: Self::Input,
95    ) -> impl std::future::Future<Output = Self::Output> + WasmCompatSend {
96        self.prompt.prompt(input.into()).into_future()
97    }
98}
99
100/// Create a new prompt operation.
101///
102/// The op will prompt the `model` with the input and return the response.
103pub fn prompt<P, In>(model: P) -> Prompt<P, In>
104where
105    P: completion::Prompt,
106    In: Into<String> + WasmCompatSend + WasmCompatSync,
107{
108    Prompt::new(model)
109}
110
111pub struct Extract<M, Input, Output>
112where
113    M: CompletionModel,
114    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
115{
116    extractor: Extractor<M, Output>,
117    _in: std::marker::PhantomData<Input>,
118}
119
120impl<M, Input, Output> Extract<M, Input, Output>
121where
122    M: CompletionModel,
123    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
124{
125    pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
126        Self {
127            extractor,
128            _in: std::marker::PhantomData,
129        }
130    }
131}
132
133impl<M, Input, Output> Op for Extract<M, Input, Output>
134where
135    M: CompletionModel,
136    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
137    Input: Into<Message> + WasmCompatSend + WasmCompatSync,
138{
139    type Input = Input;
140    type Output = Result<Output, ExtractionError>;
141
142    async fn call(&self, input: Self::Input) -> Self::Output {
143        self.extractor.extract(input).await
144    }
145}
146
147/// Create a new extract operation.
148///
149/// The op will extract the structured data from the input using the provided `extractor`.
150pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
151where
152    M: CompletionModel,
153    Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync,
154    Input: Into<String> + WasmCompatSend + WasmCompatSync,
155{
156    Extract::new(extractor)
157}
158
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::test_utils::{Foo, MockPromptModel, MockVectorStoreIndex};

    /// A lookup over the mock index with `n = 1` must yield the single
    /// expected `(1.0, "doc1", Foo { foo: "bar" })` entry.
    #[tokio::test]
    async fn test_lookup() {
        let op = lookup::<MockVectorStoreIndex, String, Foo>(MockVectorStoreIndex, 1);

        let hits = op.call(String::from("query")).await.unwrap();

        let expected = vec![(
            1.0,
            String::from("doc1"),
            Foo {
                foo: String::from("bar"),
            },
        )];
        assert_eq!(hits, expected);
    }

    /// Prompting the mock model with "hello" must return the mock's reply.
    #[tokio::test]
    async fn test_prompt() {
        let op = prompt::<MockPromptModel, String>(MockPromptModel);

        let reply = op.call(String::from("hello")).await.unwrap();

        assert_eq!(reply, "Mock response: hello");
    }
}