rs_adk/plugin/
reflect_retry.rs1use async_trait::async_trait;
7
8use rs_genai::prelude::FunctionCall;
9
10use super::{Plugin, PluginResult};
11use crate::context::InvocationContext;
12
/// Plugin that reacts to tool errors by short-circuiting with a
/// `reflect_and_retry` payload (see its `Plugin::on_tool_error` implementation).
///
/// The type stores only a retry budget, which is reported verbatim in that
/// payload. Nothing in this type counts attempts.
#[derive(Debug, Clone)]
pub struct ReflectRetryToolPlugin {
    /// Retry budget advertised in the reflection payload.
    max_retries: u32,
}

impl ReflectRetryToolPlugin {
    /// Retry budget used by [`Default::default`].
    pub const DEFAULT_MAX_RETRIES: u32 = 2;

    /// Creates a plugin that advertises `max_retries` as its retry budget.
    #[must_use]
    pub const fn new(max_retries: u32) -> Self {
        Self { max_retries }
    }

    /// Returns the configured retry budget.
    #[must_use]
    pub const fn max_retries(&self) -> u32 {
        self.max_retries
    }
}

impl Default for ReflectRetryToolPlugin {
    /// Creates a plugin with [`ReflectRetryToolPlugin::DEFAULT_MAX_RETRIES`] retries.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_RETRIES)
    }
}
40
41#[async_trait]
42impl Plugin for ReflectRetryToolPlugin {
43 fn name(&self) -> &str {
44 "reflect_retry_tool"
45 }
46
47 async fn on_tool_error(
48 &self,
49 call: &FunctionCall,
50 error: &str,
51 _ctx: &InvocationContext,
52 ) -> PluginResult {
53 let reflection = serde_json::json!({
56 "tool_error": {
57 "tool_name": call.name,
58 "error": error,
59 "action": "reflect_and_retry",
60 "max_retries": self.max_retries
61 }
62 });
63
64 PluginResult::ShortCircuit(reflection)
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_retries() {
        let plugin = ReflectRetryToolPlugin::default();
        assert_eq!(plugin.max_retries(), 2);
    }

    #[test]
    fn custom_retries() {
        let plugin = ReflectRetryToolPlugin::new(5);
        assert_eq!(plugin.max_retries(), 5);
    }

    #[test]
    fn plugin_name() {
        let plugin = ReflectRetryToolPlugin::default();
        assert_eq!(plugin.name(), "reflect_retry_tool");
    }

    #[tokio::test]
    async fn on_tool_error_returns_short_circuit() {
        use tokio::sync::broadcast;

        let plugin = ReflectRetryToolPlugin::new(3);

        let (evt_tx, _) = broadcast::channel(16);
        let writer: std::sync::Arc<dyn rs_genai::session::SessionWriter> =
            std::sync::Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx);
        let ctx = InvocationContext::new(session);

        let call = FunctionCall {
            name: "search".into(),
            args: serde_json::json!({"query": "test"}),
            id: None,
        };

        let result = plugin
            .on_tool_error(&call, "connection timeout", &ctx)
            .await;

        // Pin the payload contents, not just the variant, so a change to any
        // key or value the plugin emits is caught here.
        let PluginResult::ShortCircuit(payload) = &result else {
            panic!("on_tool_error should short-circuit");
        };
        let details = &payload["tool_error"];
        assert_eq!(details["tool_name"], "search");
        assert_eq!(details["error"], "connection timeout");
        assert_eq!(details["action"], "reflect_and_retry");
        assert_eq!(details["max_retries"], 3_u32);

        // Checked last so this works whether the helper borrows or consumes.
        assert!(result.is_short_circuit());
    }
}