rs-adk 0.5.0

Agent runtime for Gemini Live — tools, streaming, agent transfer, middleware
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
//! Plugin system — lifecycle hooks with control-flow capabilities.
//!
//! Plugins are a superset of middleware: they can observe AND control agent
//! execution. A plugin can deny a tool call, short-circuit with a custom
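//! response, or simply continue. The `PluginManager` runs `before_*` and `on_*`
//! hooks in registration order and `after_*` hooks in reverse registration
//! order; for every hook it stops at, and returns, the first non-`Continue` result.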

mod context_filter;
mod global_instruction;
mod logging;
mod reflect_retry;
mod security;

pub use context_filter::ContextFilterPlugin;
pub use global_instruction::GlobalInstructionPlugin;
pub use logging::LoggingPlugin;
pub use reflect_retry::ReflectRetryToolPlugin;
pub use security::{AllowAllPolicy, DenyListPolicy, PolicyEngine, PolicyOutcome, SecurityPlugin};

use std::sync::Arc;

use async_trait::async_trait;

use rs_genai::prelude::FunctionCall;

use crate::context::InvocationContext;
use crate::events::Event;

/// Outcome of a single plugin hook invocation.
///
/// `PluginManager` treats `Continue` as "keep going" and stops at the first
/// result that is anything else, returning that result to its caller.
#[derive(Clone, Debug)]
pub enum PluginResult {
    /// Proceed normally; the remaining plugins for this hook still run.
    Continue,
    /// Skip the normal action and use the carried JSON value instead
    /// (for example, a cached response).
    ShortCircuit(serde_json::Value),
    /// Refuse the action; the string explains why.
    Deny(String),
}

impl PluginResult {
    /// Returns true if this result is `Continue`.
    pub fn is_continue(&self) -> bool {
        matches!(self, Self::Continue)
    }

    /// Returns true if this result is `Deny`.
    pub fn is_deny(&self) -> bool {
        matches!(self, Self::Deny(_))
    }

    /// Returns true if this result is `ShortCircuit`.
    pub fn is_short_circuit(&self) -> bool {
        matches!(self, Self::ShortCircuit(_))
    }
}

/// Plugin trait — lifecycle hooks with control-flow capabilities.
///
/// Unlike `Middleware` (which is observe-only), plugins can deny or
/// short-circuit execution. All hooks default to `PluginResult::Continue`,
/// so an implementation only has to provide `name` plus the hooks it cares
/// about.
///
/// The `Send + Sync + 'static` bounds allow plugins to be stored and shared
/// as `Arc<dyn Plugin>`, which is how `PluginManager` holds them. The hooks
/// are async via `#[async_trait]`, so implementations must also be annotated
/// with `#[async_trait]`.
///
/// When hooks are dispatched through `PluginManager`, the first plugin that
/// returns a non-`Continue` result stops the remaining plugins for that hook,
/// and that result is handed back to the caller.
#[async_trait]
pub trait Plugin: Send + Sync + 'static {
    /// Plugin name for logging/debugging.
    fn name(&self) -> &str;

    /// Called before an agent starts execution.
    async fn before_agent(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called after an agent completes execution.
    async fn after_agent(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called before a tool is executed. Return `Deny` to prevent execution.
    ///
    /// `_call` is the model-issued function call about to be dispatched.
    async fn before_tool(&self, _call: &FunctionCall, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called after a tool completes. Can transform or deny the result.
    ///
    /// `_result` is the tool's output value. Transforming is presumably done
    /// by returning `ShortCircuit` with a replacement value — NOTE(review):
    /// confirm against the tool-dispatch caller, which is not in this module.
    async fn after_tool(
        &self,
        _call: &FunctionCall,
        _result: &serde_json::Value,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    /// Called when an event is emitted.
    async fn on_event(&self, _event: &Event, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called when a user message is received; `_message` is its text.
    async fn on_user_message(&self, _message: &str, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called before a run starts (before the agent loop).
    async fn before_run(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called after a run completes (after the agent loop).
    async fn after_run(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called before a model generation request; `_request` is the request
    /// about to be sent.
    async fn before_model(
        &self,
        _request: &crate::llm::LlmRequest,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    /// Called after a model generation response; `_response` is the response
    /// that was received.
    async fn after_model(
        &self,
        _response: &crate::llm::LlmResponse,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    /// Called when a model generation fails; `_error` is the error message text.
    async fn on_model_error(&self, _error: &str, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Called when a tool execution fails; `_call` is the failing call and
    /// `_error` the error message text.
    async fn on_tool_error(
        &self,
        _call: &FunctionCall,
        _error: &str,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }
}

/// Manages an ordered list of plugins, running them in sequence.
///
/// Plugins are kept in registration order (the order of `PluginManager::add`
/// calls). `before_*` and `on_*` hooks visit plugins in that order, while the
/// `after_*` hooks (`after_agent`, `after_tool`, `after_run`, `after_model`)
/// visit them in reverse. As a result, the first-registered plugin sees
/// `before_*` first and `after_*` last.
///
/// For every hook, the first non-`Continue` result stops iteration and is
/// returned; the plugins after it are not called for that hook.
///
/// Cloning a manager shares the same plugin instances (they are held in `Arc`s).
#[derive(Clone, Default)]
pub struct PluginManager {
    /// Registered plugins, in the order they were added.
    plugins: Vec<Arc<dyn Plugin>>,
}

impl PluginManager {
    /// Create an empty plugin manager.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append a plugin to the manager.
    ///
    /// The new plugin runs after all previously added plugins for `before_*`
    /// and `on_*` hooks, and before them for `after_*` hooks.
    pub fn add(&mut self, plugin: Arc<dyn Plugin>) {
        self.plugins.push(plugin);
    }

    /// Number of registered plugins.
    pub fn len(&self) -> usize {
        self.plugins.len()
    }

    /// Returns true if no plugins are registered.
    pub fn is_empty(&self) -> bool {
        self.plugins.is_empty()
    }

    /// Shared driver for every `run_*` method.
    ///
    /// Awaits `hook` for each plugin yielded by `plugins`, one at a time. It
    /// returns the first non-`Continue` result without calling the remaining
    /// plugins, or `Continue` if every plugin continued (including when there
    /// are none). The caller chooses the iteration order by passing either
    /// `&self.plugins` (registration order) or `self.plugins.iter().rev()`.
    async fn first_non_continue<'a, I, F, Fut>(plugins: I, mut hook: F) -> PluginResult
    where
        I: IntoIterator<Item = &'a Arc<dyn Plugin>>,
        F: FnMut(&'a Arc<dyn Plugin>) -> Fut,
        Fut: std::future::Future<Output = PluginResult>,
    {
        for plugin in plugins {
            let result = hook(plugin).await;
            if !result.is_continue() {
                return result;
            }
        }
        PluginResult::Continue
    }

    /// Run `before_agent` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_before_agent(&self, ctx: &InvocationContext) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.before_agent(ctx)).await
    }

    /// Run `after_agent` hooks in reverse registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_after_agent(&self, ctx: &InvocationContext) -> PluginResult {
        Self::first_non_continue(self.plugins.iter().rev(), |plugin| plugin.after_agent(ctx))
            .await
    }

    /// Run `before_tool` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_before_tool(
        &self,
        call: &FunctionCall,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.before_tool(call, ctx)).await
    }

    /// Run `after_tool` hooks in reverse registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_after_tool(
        &self,
        call: &FunctionCall,
        value: &serde_json::Value,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::first_non_continue(self.plugins.iter().rev(), |plugin| {
            plugin.after_tool(call, value, ctx)
        })
        .await
    }

    /// Run `on_event` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_on_event(&self, event: &Event, ctx: &InvocationContext) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.on_event(event, ctx)).await
    }

    /// Run `on_user_message` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_on_user_message(
        &self,
        message: &str,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.on_user_message(message, ctx))
            .await
    }

    /// Run `before_run` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_before_run(&self, ctx: &InvocationContext) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.before_run(ctx)).await
    }

    /// Run `after_run` hooks in reverse registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_after_run(&self, ctx: &InvocationContext) -> PluginResult {
        Self::first_non_continue(self.plugins.iter().rev(), |plugin| plugin.after_run(ctx)).await
    }

    /// Run `before_model` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_before_model(
        &self,
        request: &crate::llm::LlmRequest,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.before_model(request, ctx)).await
    }

    /// Run `after_model` hooks in reverse registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_after_model(
        &self,
        response: &crate::llm::LlmResponse,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::first_non_continue(self.plugins.iter().rev(), |plugin| {
            plugin.after_model(response, ctx)
        })
        .await
    }

    /// Run `on_model_error` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_on_model_error(&self, error: &str, ctx: &InvocationContext) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| plugin.on_model_error(error, ctx)).await
    }

    /// Run `on_tool_error` hooks in registration order.
    /// Returns the first non-`Continue` result, or `Continue`.
    pub async fn run_on_tool_error(
        &self,
        call: &FunctionCall,
        error: &str,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::first_non_continue(&self.plugins, |plugin| {
            plugin.on_tool_error(call, error, ctx)
        })
        .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds an `InvocationContext` over a mock session writer.
    ///
    /// The broadcast receiver is dropped right away because none of these
    /// tests read emitted events.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = tokio::sync::broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(session)
    }

    /// Builds a `FunctionCall` with the given name, empty JSON args, and no id.
    fn call_named(name: &'static str) -> FunctionCall {
        FunctionCall {
            name: name.into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    #[test]
    fn plugin_result_helpers() {
        let proceed = PluginResult::Continue;
        assert!(proceed.is_continue());
        assert!(!proceed.is_deny());
        assert!(!proceed.is_short_circuit());

        let refused = PluginResult::Deny("nope".into());
        assert!(refused.is_deny());
        assert!(!refused.is_continue());

        let cached = PluginResult::ShortCircuit(serde_json::json!({"cached": true}));
        assert!(cached.is_short_circuit());
    }

    #[test]
    fn plugin_manager_empty() {
        let manager = PluginManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn plugin_manager_add() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(LoggingPlugin::new()));
        assert_eq!(manager.len(), 1);
        assert!(!manager.is_empty());
    }

    #[test]
    fn plugin_is_object_safe() {
        // Naming `dyn Plugin` only compiles if the trait is object safe.
        let _: Option<&dyn Plugin> = None;
    }

    /// Denies every tool call.
    struct DenyPlugin;

    #[async_trait]
    impl Plugin for DenyPlugin {
        fn name(&self) -> &str {
            "deny"
        }

        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("blocked by policy".into())
        }
    }

    /// Counts how many times `before_tool` is invoked.
    struct CountPlugin {
        count: AtomicU32,
    }

    #[async_trait]
    impl Plugin for CountPlugin {
        fn name(&self) -> &str {
            "count"
        }

        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            self.count.fetch_add(1, Ordering::SeqCst);
            PluginResult::Continue
        }
    }

    // Hooks a plugin does not override must fall back to `Continue`.
    #[tokio::test]
    async fn new_hooks_default_to_continue() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(LoggingPlugin::new()));
        let ctx = test_ctx();

        assert!(manager.run_before_run(&ctx).await.is_continue());
        assert!(manager.run_after_run(&ctx).await.is_continue());
        assert!(manager.run_on_user_message("hello", &ctx).await.is_continue());

        let request = crate::llm::LlmRequest::from_text("test");
        assert!(manager.run_before_model(&request, &ctx).await.is_continue());

        assert!(manager.run_on_model_error("err", &ctx).await.is_continue());

        let call = call_named("t");
        assert!(manager.run_on_tool_error(&call, "err", &ctx).await.is_continue());
    }

    /// Denies every model request via `before_model`.
    struct ModelBlockerPlugin;

    #[async_trait]
    impl Plugin for ModelBlockerPlugin {
        fn name(&self) -> &str {
            "model-blocker"
        }

        async fn before_model(
            &self,
            _request: &crate::llm::LlmRequest,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("model calls blocked".into())
        }
    }

    #[tokio::test]
    async fn custom_before_model_plugin() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(ModelBlockerPlugin));
        let ctx = test_ctx();

        let request = crate::llm::LlmRequest::from_text("test");
        let outcome = manager.run_before_model(&request, &ctx).await;
        assert!(outcome.is_deny());
    }

    // A denying plugin must stop later plugins from running.
    #[tokio::test]
    async fn plugin_manager_deny_short_circuits() {
        let counter = Arc::new(CountPlugin {
            count: AtomicU32::new(0),
        });

        let mut manager = PluginManager::new();
        manager.add(Arc::new(DenyPlugin));
        manager.add(counter.clone());

        let ctx = test_ctx();
        let call = call_named("dangerous_tool");

        let outcome = manager.run_before_tool(&call, &ctx).await;
        assert!(outcome.is_deny());

        // The counting plugin sits behind the denier, so it never ran.
        assert_eq!(counter.count.load(Ordering::SeqCst), 0);
    }
}