// agent_chain_core/tools/retriever.rs

//! Retriever tool.
//!
//! This module provides utilities for creating tools from retrievers,
//! mirroring `langchain_core.tools.retriever`.

6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13use crate::documents::Document;
14use crate::error::Result;
15use crate::retrievers::BaseRetriever;
16
17use super::base::{ArgsSchema, ResponseFormat};
18use super::structured::StructuredTool;
19
20/// Input schema for retriever tools.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct RetrieverInput {
23    /// The query to look up in the retriever.
24    pub query: String,
25}
26
27impl RetrieverInput {
28    /// Create a new RetrieverInput.
29    pub fn new(query: impl Into<String>) -> Self {
30        Self {
31            query: query.into(),
32        }
33    }
34}
35
36/// Get the default args schema for retriever tools.
37fn retriever_args_schema() -> ArgsSchema {
38    ArgsSchema::JsonSchema(serde_json::json!({
39        "type": "object",
40        "title": "RetrieverInput",
41        "description": "Input to the retriever",
42        "properties": {
43            "query": {
44                "type": "string",
45                "description": "query to look up in retriever"
46            }
47        },
48        "required": ["query"]
49    }))
50}
51
52/// Create a tool to do retrieval of documents.
53///
54/// # Arguments
55///
56/// * `retriever` - The retriever to use for the retrieval.
57/// * `name` - The name for the tool. This will be passed to the language model,
58///   so should be unique and somewhat descriptive.
59/// * `description` - The description for the tool. This will be passed to the
60///   language model, so should be descriptive.
61///
62/// # Returns
63///
64/// A StructuredTool configured for document retrieval.
65pub fn create_retriever_tool<R>(
66    retriever: Arc<R>,
67    name: impl Into<String>,
68    description: impl Into<String>,
69) -> StructuredTool
70where
71    R: BaseRetriever + Send + Sync + 'static,
72{
73    create_retriever_tool_with_options(
74        retriever,
75        name,
76        description,
77        None,
78        "\n\n",
79        ResponseFormat::Content,
80    )
81}
82
83/// Create a retriever tool with additional options.
84///
85/// # Arguments
86///
87/// * `retriever` - The retriever to use.
88/// * `name` - The tool name.
89/// * `description` - The tool description.
90/// * `document_prompt` - Optional template for formatting documents.
91/// * `document_separator` - Separator between documents (default: "\n\n").
92/// * `response_format` - The tool response format.
93///
94/// # Returns
95///
96/// A StructuredTool configured for document retrieval.
97pub fn create_retriever_tool_with_options<R>(
98    retriever: Arc<R>,
99    name: impl Into<String>,
100    description: impl Into<String>,
101    _document_prompt: Option<String>,
102    document_separator: &str,
103    response_format: ResponseFormat,
104) -> StructuredTool
105where
106    R: BaseRetriever + Send + Sync + 'static,
107{
108    let name = name.into();
109    let description = description.into();
110    let separator = document_separator.to_string();
111
112    let retriever_clone = retriever.clone();
113    let separator_clone = separator.clone();
114    let response_format_clone = response_format;
115
116    // Sync function
117    let func = {
118        let _retriever = retriever_clone.clone();
119        let separator = separator_clone.clone();
120        move |args: HashMap<String, Value>| -> Result<Value> {
121            let _query = args
122                .get("query")
123                .and_then(|v| v.as_str())
124                .unwrap_or("")
125                .to_string();
126
127            // Note: In a real implementation, we'd call the retriever synchronously
128            // For now, we return a placeholder since retrievers are typically async
129            let docs: Vec<Document> = Vec::new();
130            let content = format_documents(&docs, &separator);
131
132            match response_format_clone {
133                ResponseFormat::Content => Ok(Value::String(content)),
134                ResponseFormat::ContentAndArtifact => {
135                    let docs_json: Vec<Value> = docs
136                        .iter()
137                        .map(|d| {
138                            serde_json::json!({
139                                "page_content": d.page_content,
140                                "metadata": d.metadata
141                            })
142                        })
143                        .collect();
144                    Ok(serde_json::json!([content, docs_json]))
145                }
146            }
147        }
148    };
149
150    StructuredTool::from_function(func, name.clone(), description, retriever_args_schema())
151        .with_response_format(response_format)
152}
153
154/// Format documents into a single string.
155fn format_documents(docs: &[Document], separator: &str) -> String {
156    docs.iter()
157        .map(|doc| doc.page_content.clone())
158        .collect::<Vec<_>>()
159        .join(separator)
160}
161
162/// Create a retriever tool with async support.
163///
164/// This version properly supports async retrieval.
165pub fn create_async_retriever_tool<R, F, Fut>(
166    retriever: Arc<R>,
167    retrieve_fn: F,
168    name: impl Into<String>,
169    description: impl Into<String>,
170) -> StructuredTool
171where
172    R: Send + Sync + 'static,
173    F: Fn(Arc<R>, String) -> Fut + Send + Sync + 'static,
174    Fut: Future<Output = Result<Vec<Document>>> + Send + 'static,
175{
176    let name = name.into();
177    let description = description.into();
178
179    let _retriever_clone = retriever.clone();
180    let _retrieve_fn = Arc::new(retrieve_fn);
181
182    // Async version - wrapped as a sync function that returns a placeholder
183    // In practice, you'd use the async invoke
184    let func = move |args: HashMap<String, Value>| -> Result<Value> {
185        let query = args
186            .get("query")
187            .and_then(|v| v.as_str())
188            .unwrap_or("")
189            .to_string();
190
191        // Return the query as placeholder - actual retrieval happens async
192        Ok(Value::String(format!(
193            "Retrieval for query '{}' (use async invoke for actual results)",
194            query
195        )))
196    };
197
198    StructuredTool::from_function(func, name, description, retriever_args_schema())
199}
200
201/// Builder for creating retriever tools with full configuration.
202pub struct RetrieverToolBuilder<R>
203where
204    R: BaseRetriever + Send + Sync + 'static,
205{
206    retriever: Arc<R>,
207    name: Option<String>,
208    description: Option<String>,
209    document_prompt: Option<String>,
210    document_separator: String,
211    response_format: ResponseFormat,
212}
213
214impl<R> RetrieverToolBuilder<R>
215where
216    R: BaseRetriever + Send + Sync + 'static,
217{
218    /// Create a new RetrieverToolBuilder.
219    pub fn new(retriever: Arc<R>) -> Self {
220        Self {
221            retriever,
222            name: None,
223            description: None,
224            document_prompt: None,
225            document_separator: "\n\n".to_string(),
226            response_format: ResponseFormat::Content,
227        }
228    }
229
230    /// Set the tool name.
231    pub fn name(mut self, name: impl Into<String>) -> Self {
232        self.name = Some(name.into());
233        self
234    }
235
236    /// Set the tool description.
237    pub fn description(mut self, description: impl Into<String>) -> Self {
238        self.description = Some(description.into());
239        self
240    }
241
242    /// Set the document prompt template.
243    pub fn document_prompt(mut self, prompt: impl Into<String>) -> Self {
244        self.document_prompt = Some(prompt.into());
245        self
246    }
247
248    /// Set the document separator.
249    pub fn document_separator(mut self, separator: impl Into<String>) -> Self {
250        self.document_separator = separator.into();
251        self
252    }
253
254    /// Set the response format.
255    pub fn response_format(mut self, format: ResponseFormat) -> Self {
256        self.response_format = format;
257        self
258    }
259
260    /// Build the retriever tool.
261    pub fn build(self) -> Result<StructuredTool> {
262        let name = self.name.ok_or_else(|| {
263            crate::error::Error::InvalidConfig("Retriever tool name is required".to_string())
264        })?;
265
266        let description = self.description.ok_or_else(|| {
267            crate::error::Error::InvalidConfig("Retriever tool description is required".to_string())
268        })?;
269
270        Ok(create_retriever_tool_with_options(
271            self.retriever,
272            name,
273            description,
274            self.document_prompt,
275            &self.document_separator,
276            self.response_format,
277        ))
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retriever_input() {
        let input = RetrieverInput::new("test query");
        assert_eq!(input.query, "test query");
    }

    #[test]
    fn test_retriever_input_from_json() {
        let input: RetrieverInput =
            serde_json::from_value(serde_json::json!({ "query": "test query" }))
                .expect("valid RetrieverInput JSON");
        assert_eq!(input.query, "test query");
    }

    #[test]
    fn test_retriever_args_schema() {
        let schema = retriever_args_schema();
        let json = schema.to_json_schema();

        assert_eq!(json["type"], "object");
        assert!(json["properties"]["query"].is_object());
    }

    #[test]
    fn test_retriever_args_schema_requires_query_string() {
        let schema = retriever_args_schema();
        let json = schema.to_json_schema();

        assert_eq!(json["title"], "RetrieverInput");
        assert_eq!(json["properties"]["query"]["type"], "string");
        assert_eq!(json["required"], serde_json::json!(["query"]));
    }

    #[test]
    fn test_format_documents() {
        let docs = vec![
            Document::new("First document"),
            Document::new("Second document"),
        ];

        let formatted = format_documents(&docs, "\n\n");
        assert_eq!(formatted, "First document\n\nSecond document");
    }

    #[test]
    fn test_format_documents_custom_separator() {
        let docs = vec![
            Document::new("Doc 1"),
            Document::new("Doc 2"),
            Document::new("Doc 3"),
        ];

        let formatted = format_documents(&docs, " | ");
        assert_eq!(formatted, "Doc 1 | Doc 2 | Doc 3");
    }

    #[test]
    fn test_format_documents_empty() {
        let docs: Vec<Document> = Vec::new();
        assert_eq!(format_documents(&docs, "\n\n"), "");
    }

    #[test]
    fn test_format_documents_single() {
        let docs = vec![Document::new("Only document")];
        assert_eq!(format_documents(&docs, " | "), "Only document");
    }
}