1use std::sync::Arc;
20
21use adk_core::{AdkError, Tool, ToolContext};
22use async_trait::async_trait;
23use serde_json::{Value, json};
24use tracing::{error, info};
25
26use crate::pipeline::RagPipeline;
27
28pub struct RagTool {
34 pipeline: Arc<RagPipeline>,
35 default_collection: String,
36}
37
38impl RagTool {
39 pub fn new(pipeline: Arc<RagPipeline>, default_collection: impl Into<String>) -> Self {
44 Self { pipeline, default_collection: default_collection.into() }
45 }
46}
47
48#[async_trait]
49impl Tool for RagTool {
50 fn name(&self) -> &str {
51 "rag_search"
52 }
53
54 fn description(&self) -> &str {
55 "Search a knowledge base for relevant documents given a query"
56 }
57
58 fn parameters_schema(&self) -> Option<Value> {
59 Some(json!({
60 "type": "object",
61 "properties": {
62 "query": {
63 "type": "string",
64 "description": "The search query to find relevant documents"
65 },
66 "collection": {
67 "type": "string",
68 "description": "The name of the collection to search. Uses the default collection if omitted."
69 },
70 "top_k": {
71 "type": "integer",
72 "description": "Maximum number of results to return. Uses the pipeline default if omitted."
73 }
74 },
75 "required": ["query"]
76 }))
77 }
78
79 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
80 let query = args
81 .get("query")
82 .and_then(|v| v.as_str())
83 .ok_or_else(|| AdkError::tool("missing required 'query' parameter"))?;
84
85 let collection =
86 args.get("collection").and_then(|v| v.as_str()).unwrap_or(&self.default_collection);
87
88 let top_k_override = args.get("top_k").and_then(|v| v.as_u64()).map(|v| v as usize);
89
90 info!(query, collection, top_k_override, "rag_search tool called");
91
92 let results = if let Some(top_k) = top_k_override {
93 self.query_with_top_k(collection, query, top_k).await
95 } else {
96 self.pipeline.query(collection, query).await
97 };
98
99 let results = results.map_err(|e| {
100 error!(error = %e, "rag_search failed");
101 AdkError::tool(format!("RAG search failed: {e}"))
102 })?;
103
104 serde_json::to_value(&results).map_err(|e| {
105 error!(error = %e, "failed to serialize search results");
106 AdkError::tool(format!("failed to serialize results: {e}"))
107 })
108 }
109}
110
111impl RagTool {
112 async fn query_with_top_k(
114 &self,
115 collection: &str,
116 query: &str,
117 top_k: usize,
118 ) -> crate::error::Result<Vec<crate::document::SearchResult>> {
119 let query_embedding =
121 self.pipeline.embedding_provider().embed(query).await.map_err(|e| {
122 crate::error::RagError::PipelineError(format!("query embedding failed: {e}"))
123 })?;
124
125 let results = self
127 .pipeline
128 .vector_store()
129 .search(collection, &query_embedding, top_k)
130 .await
131 .map_err(|e| {
132 crate::error::RagError::PipelineError(format!(
133 "search failed in collection '{collection}': {e}"
134 ))
135 })?;
136
137 let threshold = self.pipeline.config().similarity_threshold;
139 let filtered = results.into_iter().filter(|r| r.score >= threshold).collect();
140
141 Ok(filtered)
142 }
143}