1use std::sync::Arc;
42
43use adk_core::{
44 BeforeModelResult, CallbackContext, LlmRequest, LlmResponse, Result, Tool, async_trait,
45};
46use serde_json::Value;
47
48use crate::context::PluginContext;
49use crate::enhanced_plugin::EnhancedPlugin;
50use crate::hook_result::{
51 AfterModelCallResult, AfterToolCallResult, BeforeModelCallResult, BeforeToolCallResult,
52};
53use crate::plugin::Plugin;
54
55pub struct AdaptedPlugin {
65 inner: Plugin,
66 priority: i32,
67}
68
69impl AdaptedPlugin {
70 pub fn new(plugin: Plugin, priority: i32) -> Self {
91 Self { inner: plugin, priority }
92 }
93}
94
95#[async_trait]
96impl EnhancedPlugin for AdaptedPlugin {
97 fn name(&self) -> &str {
98 self.inner.name()
99 }
100
101 fn priority(&self) -> i32 {
102 self.priority
103 }
104
105 async fn before_tool_call(
106 &self,
107 _tool: Arc<dyn Tool>,
108 args: Value,
109 ctx: Arc<dyn CallbackContext>,
110 _plugin_ctx: &PluginContext,
111 ) -> Result<BeforeToolCallResult> {
112 if let Some(callback) = self.inner.before_tool() {
116 let _ = callback(ctx).await?;
119 }
120 Ok(BeforeToolCallResult::Continue(args))
121 }
122
123 async fn after_tool_call(
124 &self,
125 _tool: Arc<dyn Tool>,
126 _args: &Value,
127 result: Value,
128 ctx: Arc<dyn CallbackContext>,
129 _plugin_ctx: &PluginContext,
130 ) -> Result<AfterToolCallResult> {
131 if let Some(callback) = self.inner.after_tool() {
135 let _ = callback(ctx).await?;
136 }
137 Ok(AfterToolCallResult::Continue(result))
138 }
139
140 async fn before_model_call(
141 &self,
142 request: LlmRequest,
143 ctx: Arc<dyn CallbackContext>,
144 _plugin_ctx: &PluginContext,
145 ) -> Result<BeforeModelCallResult> {
146 if let Some(callback) = self.inner.before_model() {
150 let legacy_result = callback(ctx, request).await?;
151 match legacy_result {
152 BeforeModelResult::Continue(req) => Ok(BeforeModelCallResult::Continue(req)),
153 BeforeModelResult::Skip(response) => {
154 Ok(BeforeModelCallResult::ShortCircuit(response))
155 }
156 }
157 } else {
158 Ok(BeforeModelCallResult::Continue(request))
159 }
160 }
161
162 async fn after_model_call(
163 &self,
164 response: LlmResponse,
165 ctx: Arc<dyn CallbackContext>,
166 _plugin_ctx: &PluginContext,
167 ) -> Result<AfterModelCallResult> {
168 if let Some(callback) = self.inner.after_model() {
171 let result = callback(ctx, response.clone()).await?;
172 match result {
173 Some(modified_response) => Ok(AfterModelCallResult::Continue(modified_response)),
174 None => Ok(AfterModelCallResult::Continue(response)),
175 }
176 } else {
177 Ok(AfterModelCallResult::Continue(response))
178 }
179 }
180
181 async fn close(&self) {
182 self.inner.close().await;
183 }
184}
185
#[cfg(test)]
mod tests {
    // Unit tests for `AdaptedPlugin`. Most tests register one legacy callback on a `Plugin`,
    // wrap it with `AdaptedPlugin::new`, call the matching enhanced hook, and check how the
    // legacy outcome was translated.

    use super::*;
    use crate::{PluginConfig, plugin::Plugin};
    use adk_core::{BeforeModelResult, Content, LlmRequest, LlmResponse, Part};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Minimal `CallbackContext` with fixed identifiers. Every callback in these tests
    /// ignores its context (`_ctx`), so the values are never inspected.
    struct MockCallbackContext;

    impl adk_core::ReadonlyContext for MockCallbackContext {
        fn invocation_id(&self) -> &str {
            "test-invocation"
        }

        fn agent_name(&self) -> &str {
            "test-agent"
        }

        fn user_id(&self) -> &str {
            "test-user"
        }

        fn app_name(&self) -> &str {
            "test-app"
        }

        fn session_id(&self) -> &str {
            "test-session"
        }

        fn branch(&self) -> &str {
            "main"
        }

        fn user_content(&self) -> &Content {
            // The trait hands out a reference and the mock is a unit struct with no fields,
            // so back it with a lazily initialised static.
            static CONTENT: std::sync::OnceLock<Content> = std::sync::OnceLock::new();
            CONTENT.get_or_init(|| Content::new("user"))
        }
    }

    #[async_trait]
    impl CallbackContext for MockCallbackContext {
        // No artifact store is needed by any test.
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    /// Inert tool. The adapter's tool hooks ignore their `_tool` argument, so this only
    /// exists to provide an `Arc<dyn Tool>` to pass in.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock-tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        async fn execute(
            &self,
            _ctx: Arc<dyn adk_core::ToolContext>,
            _args: Value,
        ) -> adk_core::Result<Value> {
            Ok(Value::Null)
        }
    }

    // The adapter's name comes straight from the wrapped plugin's config.
    #[tokio::test]
    async fn test_name_delegates_to_inner() {
        let plugin = Plugin::new(PluginConfig {
            name: "my-legacy-plugin".to_string(),
            ..Default::default()
        });
        let adapted = AdaptedPlugin::new(plugin, 100);
        assert_eq!(adapted.name(), "my-legacy-plugin");
    }

    // The priority passed to `AdaptedPlugin::new` is reported verbatim.
    #[tokio::test]
    async fn test_priority_uses_configured_value() {
        let plugin = Plugin::new(PluginConfig { name: "test".to_string(), ..Default::default() });
        let adapted = AdaptedPlugin::new(plugin, 42);
        assert_eq!(adapted.priority(), 42);
    }

    // The legacy before_tool callback runs (the flag flips) and the args come back unchanged.
    #[tokio::test]
    async fn test_before_tool_call_invokes_legacy_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_tool: Some(Box::new(move |_ctx| {
                // Clone per invocation so the returned future owns its own handle.
                let flag = called_clone.clone();
                Box::pin(async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok(None)
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);
        let args = serde_json::json!({"key": "value"});

        let result = adapted.before_tool_call(tool, args.clone(), ctx, &plugin_ctx).await.unwrap();

        assert!(called.load(Ordering::SeqCst));
        match result {
            BeforeToolCallResult::Continue(returned_args) => {
                assert_eq!(returned_args, args);
            }
            _ => panic!("expected Continue"),
        }
    }

    // The legacy after_tool callback runs and the tool result comes back unchanged.
    #[tokio::test]
    async fn test_after_tool_call_invokes_legacy_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            after_tool: Some(Box::new(move |_ctx| {
                let flag = called_clone.clone();
                Box::pin(async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok(None)
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);
        let args = serde_json::json!({"input": "test"});
        let result_val = serde_json::json!({"output": "done"});

        let result = adapted
            .after_tool_call(tool, &args, result_val.clone(), ctx, &plugin_ctx)
            .await
            .unwrap();

        assert!(called.load(Ordering::SeqCst));
        match result {
            AfterToolCallResult::Continue(returned_result) => {
                assert_eq!(returned_result, result_val);
            }
        }
    }

    // A legacy `Continue` maps to an enhanced `Continue` carrying the same request.
    #[tokio::test]
    async fn test_before_model_call_maps_continue() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_model: Some(Box::new(|_ctx, request| {
                Box::pin(async move { Ok(BeforeModelResult::Continue(request)) })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let request = LlmRequest::new("test-model", vec![]);

        let result = adapted.before_model_call(request, ctx, &plugin_ctx).await.unwrap();

        match result {
            BeforeModelCallResult::Continue(req) => {
                assert_eq!(req.model, "test-model");
            }
            _ => panic!("expected Continue"),
        }
    }

    // A legacy `Skip(response)` maps to an enhanced `ShortCircuit(response)`.
    #[tokio::test]
    async fn test_before_model_call_maps_skip_to_short_circuit() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_model: Some(Box::new(|_ctx, _request| {
                Box::pin(async move {
                    let response = LlmResponse {
                        content: Some(Content::new("model").with_text("cached")),
                        ..Default::default()
                    };
                    Ok(BeforeModelResult::Skip(response))
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let request = LlmRequest::new("model", vec![]);

        let result = adapted.before_model_call(request, ctx, &plugin_ctx).await.unwrap();

        match result {
            BeforeModelCallResult::ShortCircuit(resp) => {
                assert!(resp.content.is_some());
            }
            _ => panic!("expected ShortCircuit"),
        }
    }

    // When after_model returns `Some(modified)`, the modified response replaces the input.
    #[tokio::test]
    async fn test_after_model_call_maps_some_to_continue_modified() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            after_model: Some(Box::new(|_ctx, _response| {
                Box::pin(async move {
                    let modified = LlmResponse {
                        content: Some(Content::new("model").with_text("modified")),
                        ..Default::default()
                    };
                    Ok(Some(modified))
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let response = LlmResponse::default();

        let result = adapted.after_model_call(response, ctx, &plugin_ctx).await.unwrap();

        match result {
            AfterModelCallResult::Continue(resp) => {
                let content = resp.content.unwrap();
                assert!(
                    content
                        .parts
                        .iter()
                        .any(|p| matches!(p, Part::Text { text } if text == "modified"))
                );
            }
        }
    }

    // When after_model returns `None`, the original response is passed through.
    #[tokio::test]
    async fn test_after_model_call_maps_none_to_continue_unchanged() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            after_model: Some(Box::new(|_ctx, _response| Box::pin(async move { Ok(None) }))),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("original")),
            ..Default::default()
        };

        let result = adapted.after_model_call(response, ctx, &plugin_ctx).await.unwrap();

        match result {
            AfterModelCallResult::Continue(resp) => {
                let content = resp.content.unwrap();
                assert!(
                    content
                        .parts
                        .iter()
                        .any(|p| matches!(p, Part::Text { text } if text == "original"))
                );
            }
        }
    }

    // `close` on the adapter reaches the legacy plugin's `close_fn`.
    #[tokio::test]
    async fn test_close_delegates_to_inner() {
        let closed = Arc::new(AtomicBool::new(false));
        let closed_clone = closed.clone();

        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            close_fn: Some(Box::new(move || {
                let flag = closed_clone.clone();
                Box::pin(async move {
                    flag.store(true, Ordering::SeqCst);
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        adapted.close().await;

        assert!(closed.load(Ordering::SeqCst));
    }

    // With no legacy callbacks registered, every hook passes its input through unchanged.
    #[tokio::test]
    async fn test_no_callbacks_returns_continue_unchanged() {
        let plugin = Plugin::new(PluginConfig { name: "empty".to_string(), ..Default::default() });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);

        // before_tool_call: args unchanged.
        let args = serde_json::json!({"x": 1});
        let result = adapted
            .before_tool_call(tool.clone(), args.clone(), ctx.clone(), &plugin_ctx)
            .await
            .unwrap();
        match result {
            BeforeToolCallResult::Continue(v) => assert_eq!(v, args),
            _ => panic!("expected Continue"),
        }

        // after_tool_call: tool result unchanged.
        let res_val = serde_json::json!({"y": 2});
        let result = adapted
            .after_tool_call(tool.clone(), &args, res_val.clone(), ctx.clone(), &plugin_ctx)
            .await
            .unwrap();
        match result {
            AfterToolCallResult::Continue(v) => assert_eq!(v, res_val),
        }

        // before_model_call: request unchanged.
        let request = LlmRequest::new("m", vec![]);
        let result = adapted.before_model_call(request, ctx.clone(), &plugin_ctx).await.unwrap();
        match result {
            BeforeModelCallResult::Continue(req) => assert_eq!(req.model, "m"),
            _ => panic!("expected Continue"),
        }

        // after_model_call: response unchanged.
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hi")),
            ..Default::default()
        };
        let result = adapted.after_model_call(response, ctx, &plugin_ctx).await.unwrap();
        match result {
            AfterModelCallResult::Continue(resp) => {
                assert!(resp.content.is_some());
            }
        }
    }
}