1use std::collections::HashMap;
7use std::future::Future;
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13use crate::documents::Document;
14use crate::error::Result;
15use crate::retrievers::BaseRetriever;
16
17use super::base::{ArgsSchema, ResponseFormat};
18use super::structured::StructuredTool;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct RetrieverInput {
23 pub query: String,
25}
26
27impl RetrieverInput {
28 pub fn new(query: impl Into<String>) -> Self {
30 Self {
31 query: query.into(),
32 }
33 }
34}
35
36fn retriever_args_schema() -> ArgsSchema {
38 ArgsSchema::JsonSchema(serde_json::json!({
39 "type": "object",
40 "title": "RetrieverInput",
41 "description": "Input to the retriever",
42 "properties": {
43 "query": {
44 "type": "string",
45 "description": "query to look up in retriever"
46 }
47 },
48 "required": ["query"]
49 }))
50}
51
52pub fn create_retriever_tool<R>(
66 retriever: Arc<R>,
67 name: impl Into<String>,
68 description: impl Into<String>,
69) -> StructuredTool
70where
71 R: BaseRetriever + Send + Sync + 'static,
72{
73 create_retriever_tool_with_options(
74 retriever,
75 name,
76 description,
77 None,
78 "\n\n",
79 ResponseFormat::Content,
80 )
81}
82
83pub fn create_retriever_tool_with_options<R>(
98 retriever: Arc<R>,
99 name: impl Into<String>,
100 description: impl Into<String>,
101 _document_prompt: Option<String>,
102 document_separator: &str,
103 response_format: ResponseFormat,
104) -> StructuredTool
105where
106 R: BaseRetriever + Send + Sync + 'static,
107{
108 let name = name.into();
109 let description = description.into();
110 let separator = document_separator.to_string();
111
112 let retriever_clone = retriever.clone();
113 let separator_clone = separator.clone();
114 let response_format_clone = response_format;
115
116 let func = {
118 let _retriever = retriever_clone.clone();
119 let separator = separator_clone.clone();
120 move |args: HashMap<String, Value>| -> Result<Value> {
121 let _query = args
122 .get("query")
123 .and_then(|v| v.as_str())
124 .unwrap_or("")
125 .to_string();
126
127 let docs: Vec<Document> = Vec::new();
130 let content = format_documents(&docs, &separator);
131
132 match response_format_clone {
133 ResponseFormat::Content => Ok(Value::String(content)),
134 ResponseFormat::ContentAndArtifact => {
135 let docs_json: Vec<Value> = docs
136 .iter()
137 .map(|d| {
138 serde_json::json!({
139 "page_content": d.page_content,
140 "metadata": d.metadata
141 })
142 })
143 .collect();
144 Ok(serde_json::json!([content, docs_json]))
145 }
146 }
147 }
148 };
149
150 StructuredTool::from_function(func, name.clone(), description, retriever_args_schema())
151 .with_response_format(response_format)
152}
153
154fn format_documents(docs: &[Document], separator: &str) -> String {
156 docs.iter()
157 .map(|doc| doc.page_content.clone())
158 .collect::<Vec<_>>()
159 .join(separator)
160}
161
162pub fn create_async_retriever_tool<R, F, Fut>(
166 retriever: Arc<R>,
167 retrieve_fn: F,
168 name: impl Into<String>,
169 description: impl Into<String>,
170) -> StructuredTool
171where
172 R: Send + Sync + 'static,
173 F: Fn(Arc<R>, String) -> Fut + Send + Sync + 'static,
174 Fut: Future<Output = Result<Vec<Document>>> + Send + 'static,
175{
176 let name = name.into();
177 let description = description.into();
178
179 let _retriever_clone = retriever.clone();
180 let _retrieve_fn = Arc::new(retrieve_fn);
181
182 let func = move |args: HashMap<String, Value>| -> Result<Value> {
185 let query = args
186 .get("query")
187 .and_then(|v| v.as_str())
188 .unwrap_or("")
189 .to_string();
190
191 Ok(Value::String(format!(
193 "Retrieval for query '{}' (use async invoke for actual results)",
194 query
195 )))
196 };
197
198 StructuredTool::from_function(func, name, description, retriever_args_schema())
199}
200
201pub struct RetrieverToolBuilder<R>
203where
204 R: BaseRetriever + Send + Sync + 'static,
205{
206 retriever: Arc<R>,
207 name: Option<String>,
208 description: Option<String>,
209 document_prompt: Option<String>,
210 document_separator: String,
211 response_format: ResponseFormat,
212}
213
214impl<R> RetrieverToolBuilder<R>
215where
216 R: BaseRetriever + Send + Sync + 'static,
217{
218 pub fn new(retriever: Arc<R>) -> Self {
220 Self {
221 retriever,
222 name: None,
223 description: None,
224 document_prompt: None,
225 document_separator: "\n\n".to_string(),
226 response_format: ResponseFormat::Content,
227 }
228 }
229
230 pub fn name(mut self, name: impl Into<String>) -> Self {
232 self.name = Some(name.into());
233 self
234 }
235
236 pub fn description(mut self, description: impl Into<String>) -> Self {
238 self.description = Some(description.into());
239 self
240 }
241
242 pub fn document_prompt(mut self, prompt: impl Into<String>) -> Self {
244 self.document_prompt = Some(prompt.into());
245 self
246 }
247
248 pub fn document_separator(mut self, separator: impl Into<String>) -> Self {
250 self.document_separator = separator.into();
251 self
252 }
253
254 pub fn response_format(mut self, format: ResponseFormat) -> Self {
256 self.response_format = format;
257 self
258 }
259
260 pub fn build(self) -> Result<StructuredTool> {
262 let name = self.name.ok_or_else(|| {
263 crate::error::Error::InvalidConfig("Retriever tool name is required".to_string())
264 })?;
265
266 let description = self.description.ok_or_else(|| {
267 crate::error::Error::InvalidConfig("Retriever tool description is required".to_string())
268 })?;
269
270 Ok(create_retriever_tool_with_options(
271 self.retriever,
272 name,
273 description,
274 self.document_prompt,
275 &self.document_separator,
276 self.response_format,
277 ))
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// `RetrieverInput::new` stores the query it was given.
    #[test]
    fn test_retriever_input() {
        let RetrieverInput { query } = RetrieverInput::new("test query");
        assert_eq!(query, "test query");
    }

    /// The argument schema is an object with a `query` property.
    #[test]
    fn test_retriever_args_schema() {
        let json = retriever_args_schema().to_json_schema();

        assert_eq!(json["type"], "object");
        let query_property = &json["properties"]["query"];
        assert!(query_property.is_object());
    }

    /// Documents are joined with the given separator.
    #[test]
    fn test_format_documents() {
        let docs = [
            Document::new("First document"),
            Document::new("Second document"),
        ];

        assert_eq!(
            format_documents(&docs, "\n\n"),
            "First document\n\nSecond document"
        );
    }

    /// A custom separator appears between every pair of documents.
    #[test]
    fn test_format_documents_custom_separator() {
        let docs = [
            Document::new("Doc 1"),
            Document::new("Doc 2"),
            Document::new("Doc 3"),
        ];

        assert_eq!(format_documents(&docs, " | "), "Doc 1 | Doc 2 | Doc 3");
    }
}