Skip to main content

rs_adk/plugin/
reflect_retry.rs

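//! Reflect-retry plugin — lets the model retry failed tool calls after reflecting on the error.
//!
//! Mirrors ADK-Python's `reflect_retry_tool_plugin`. When a tool call
//! fails, the plugin short-circuits with a JSON payload describing the
//! failure so the model can reflect on it and retry; the runtime drives
//! the actual retry loop.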
5
6use async_trait::async_trait;
7
8use rs_genai::prelude::FunctionCall;
9
10use super::{Plugin, PluginResult};
11use crate::context::InvocationContext;
12
/// Plugin that turns tool failures into reflection payloads for the model.
///
/// When a tool call fails, [`Plugin::on_tool_error`] short-circuits with a
/// JSON description of the failure (including the tool name, the error text
/// and the retry budget) so the model can reconsider and try a different
/// approach.
///
/// The plugin holds no per-call state: its only field is the configured retry
/// budget, and counting attempts is left to the runtime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReflectRetryToolPlugin {
    /// Maximum number of retries per tool call. This value is reported to the
    /// runtime in the reflection payload; the plugin does not enforce it itself.
    max_retries: u32,
}
22
23impl ReflectRetryToolPlugin {
24    /// Create a new reflect-retry plugin with the given max retries.
25    pub fn new(max_retries: u32) -> Self {
26        Self { max_retries }
27    }
28
29    /// Returns the maximum retry count.
30    pub fn max_retries(&self) -> u32 {
31        self.max_retries
32    }
33}
34
35impl Default for ReflectRetryToolPlugin {
36    fn default() -> Self {
37        Self::new(2)
38    }
39}
40
41#[async_trait]
42impl Plugin for ReflectRetryToolPlugin {
43    fn name(&self) -> &str {
44        "reflect_retry_tool"
45    }
46
47    async fn on_tool_error(
48        &self,
49        call: &FunctionCall,
50        error: &str,
51        _ctx: &InvocationContext,
52    ) -> PluginResult {
53        // Signal that the error should be reflected back to the model
54        // for retry. The runtime handles the actual retry loop.
55        let reflection = serde_json::json!({
56            "tool_error": {
57                "tool_name": call.name,
58                "error": error,
59                "action": "reflect_and_retry",
60                "max_retries": self.max_retries
61            }
62        });
63
64        PluginResult::ShortCircuit(reflection)
65    }
66}
67
#[cfg(test)]
mod tests {
    //! Unit tests for the reflect-retry plugin's configuration and its
    //! tool-error hook.

    use super::*;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// `Default` yields a retry budget of two.
    #[test]
    fn default_retries() {
        assert_eq!(ReflectRetryToolPlugin::default().max_retries(), 2);
    }

    /// `new` stores the requested retry budget unchanged.
    #[test]
    fn custom_retries() {
        assert_eq!(ReflectRetryToolPlugin::new(5).max_retries(), 5);
    }

    /// The plugin reports its fixed identifier.
    #[test]
    fn plugin_name() {
        assert_eq!(
            ReflectRetryToolPlugin::default().name(),
            "reflect_retry_tool"
        );
    }

    /// A tool failure is answered with a short-circuit result.
    #[tokio::test]
    async fn on_tool_error_returns_short_circuit() {
        let (events, _) = broadcast::channel(16);
        let session_writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let ctx = InvocationContext::new(crate::agent_session::AgentSession::from_writer(
            session_writer,
            events,
        ));

        let failed_call = FunctionCall {
            name: "search".into(),
            args: serde_json::json!({"query": "test"}),
            id: None,
        };

        let outcome = ReflectRetryToolPlugin::new(3)
            .on_tool_error(&failed_call, "connection timeout", &ctx)
            .await;
        assert!(outcome.is_short_circuit());
    }
}