Skip to main content

adk_rag/
tool.rs

1//! Agentic retrieval tool for ADK agents.
2//!
3//! The [`RagTool`] wraps a [`RagPipeline`] as an
4//! [`adk_core::Tool`] so that agents can perform retrieval as a tool call.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use std::sync::Arc;
10//! use adk_rag::{RagPipeline, RagTool};
11//!
12//! let pipeline = Arc::new(build_pipeline()?);
13//! let tool = RagTool::new(pipeline, "my_docs");
14//!
15//! // The agent calls the tool with:
16//! // { "query": "How do I configure X?", "collection": "faq", "top_k": 5 }
17//! ```
18
19use std::sync::Arc;
20
21use adk_core::{AdkError, Tool, ToolContext};
22use async_trait::async_trait;
23use serde_json::{Value, json};
24use tracing::{error, info};
25
26use crate::pipeline::RagPipeline;
27
28/// A retrieval tool that wraps a [`RagPipeline`] for agentic use.
29///
30/// Implements [`adk_core::Tool`] so it can be attached to any ADK agent.
31/// The tool accepts a required `query` string and optional `collection`
32/// and `top_k` parameters.
33pub struct RagTool {
34    pipeline: Arc<RagPipeline>,
35    default_collection: String,
36}
37
38impl RagTool {
39    /// Create a new `RagTool` backed by the given pipeline.
40    ///
41    /// The `default_collection` is used when the agent does not specify
42    /// a collection in the tool call arguments.
43    pub fn new(pipeline: Arc<RagPipeline>, default_collection: impl Into<String>) -> Self {
44        Self { pipeline, default_collection: default_collection.into() }
45    }
46}
47
48#[async_trait]
49impl Tool for RagTool {
50    fn name(&self) -> &str {
51        "rag_search"
52    }
53
54    fn description(&self) -> &str {
55        "Search a knowledge base for relevant documents given a query"
56    }
57
58    fn parameters_schema(&self) -> Option<Value> {
59        Some(json!({
60            "type": "object",
61            "properties": {
62                "query": {
63                    "type": "string",
64                    "description": "The search query to find relevant documents"
65                },
66                "collection": {
67                    "type": "string",
68                    "description": "The name of the collection to search. Uses the default collection if omitted."
69                },
70                "top_k": {
71                    "type": "integer",
72                    "description": "Maximum number of results to return. Uses the pipeline default if omitted."
73                }
74            },
75            "required": ["query"]
76        }))
77    }
78
79    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
80        let query = args
81            .get("query")
82            .and_then(|v| v.as_str())
83            .ok_or_else(|| AdkError::tool("missing required 'query' parameter"))?;
84
85        let collection =
86            args.get("collection").and_then(|v| v.as_str()).unwrap_or(&self.default_collection);
87
88        let top_k_override = args.get("top_k").and_then(|v| v.as_u64()).map(|v| v as usize);
89
90        info!(query, collection, top_k_override, "rag_search tool called");
91
92        let results = if let Some(top_k) = top_k_override {
93            // Override top_k: embed query, search with custom top_k, rerank, filter
94            self.query_with_top_k(collection, query, top_k).await
95        } else {
96            self.pipeline.query(collection, query).await
97        };
98
99        let results = results.map_err(|e| {
100            error!(error = %e, "rag_search failed");
101            AdkError::tool(format!("RAG search failed: {e}"))
102        })?;
103
104        serde_json::to_value(&results).map_err(|e| {
105            error!(error = %e, "failed to serialize search results");
106            AdkError::tool(format!("failed to serialize results: {e}"))
107        })
108    }
109}
110
111impl RagTool {
112    /// Query with a custom `top_k`, bypassing the pipeline's configured value.
113    async fn query_with_top_k(
114        &self,
115        collection: &str,
116        query: &str,
117        top_k: usize,
118    ) -> crate::error::Result<Vec<crate::document::SearchResult>> {
119        // 1. Embed the query
120        let query_embedding =
121            self.pipeline.embedding_provider().embed(query).await.map_err(|e| {
122                crate::error::RagError::PipelineError(format!("query embedding failed: {e}"))
123            })?;
124
125        // 2. Search with the overridden top_k
126        let results = self
127            .pipeline
128            .vector_store()
129            .search(collection, &query_embedding, top_k)
130            .await
131            .map_err(|e| {
132                crate::error::RagError::PipelineError(format!(
133                    "search failed in collection '{collection}': {e}"
134                ))
135            })?;
136
137        // 3. Filter by similarity threshold
138        let threshold = self.pipeline.config().similarity_threshold;
139        let filtered = results.into_iter().filter(|r| r.score >= threshold).collect();
140
141        Ok(filtered)
142    }
143}