Skip to main content

opi_agent/
extension.rs

1//! Extension system for agent customization.
2//!
//! Provides the [`Extension`] trait for registering lifecycle hooks, custom
//! tools, custom providers and model overrides, custom commands, extra agent
//! messages injected before the next turn, and scoped extension state.
5//!
6//! # Lifecycle Ordering
7//!
8//! Extension hooks are called in registration order after the base hooks:
9//!
10//! 1. Base [`AgentHooks::before_tool_call`] then extension
11//!    [`Extension::on_before_tool_call`] for each extension (in registration
12//!    order).
13//! 2. Base [`AgentHooks::after_tool_call`] then extension
14//!    [`Extension::on_after_tool_call`] for each extension (in registration
15//!    order).
16//! 3. Base [`AgentHooks::prepare_next_turn`] then extension
17//!    [`Extension::prepare_next_turn`] for each extension (in registration
18//!    order). Extra messages are appended in that order.
19//!
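//! Only the `before_tool_call` chain can deny: if the base hook or any
//! extension returns a deny/block result, the chain stops immediately, later
//! extensions are not consulted, and the denial propagates. Extensions cannot
//! override a denial from the base hooks or from an earlier extension.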
23//!
24//! # Hook Error/Blocking Semantics
25//!
26//! - `on_before_tool_call` returning [`ExtensionHookResult::Block`] prevents
27//!   the tool from executing. The block reason is returned to the agent loop
28//!   as a tool error.
29//! - `on_after_tool_call` is an observer callback; it cannot modify tool
30//!   results. The base hook retains full control over result replacement via
31//!   [`AfterToolCallResult::Replace`].
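//! - `on_command` errors are returned unchanged to the caller of
//!   [`ExtensionRegistry::dispatch_command`]; extensions should normally
//!   report command failures as [`ExtensionError::CommandError`]. If an
//!   extension returns an error, dispatch stops and later extensions are not
//!   consulted.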
35//!
36//! # State Serialization
37//!
38//! Each extension can serialize and restore its own state. Extension states
39//! are keyed by extension name in the serialized map produced by
40//! [`ExtensionRegistry::serialize_states`]. Extensions that don't need state
41//! persistence can use the default (no-op) implementations of
42//! [`Extension::serialize_state`] and [`Extension::restore_state`].
43//!
44//! # Custom Tools
45//!
46//! Extensions provide tools via the [`Extension::tools`] method. These tools
47//! are collected during [`ExtensionRegistry::collect_tools`] and added to the
48//! agent's tool set alongside built-in tools. Extension tools follow the same
49//! [`Tool`] trait contract and validation rules as built-in tools.
50//!
51//! # Custom Commands
52//!
53//! Extensions handle custom commands via [`Extension::on_command`]. Commands
54//! are dispatched via [`ExtensionRegistry::dispatch_command`] to extensions in
55//! registration order. The first extension that returns `Ok(Some(value))`
56//! claims the command.
57//!
58//! # Unstable
59//!
60//! This module is part of the **unstable 0.x extension API**. Breaking changes
61//! may occur between minor versions without a major version bump.
62
63use std::future::Future;
64use std::pin::Pin;
65use std::sync::Arc;
66
67use serde_json::Value;
68use tokio_util::sync::CancellationToken;
69
70use opi_ai::provider::{ModelInfo, Provider};
71
72use crate::event::AgentEvent;
73use crate::hooks::{
74    AfterToolCallContext, AfterToolCallResult, AgentHooks, BeforeToolCallContext,
75    BeforeToolCallResult, PrepareNextTurnContext, ShouldStopAfterTurnContext,
76};
77use crate::loop_types::{AgentError, AgentLoopTurnUpdate};
78use crate::message::AgentMessage;
79use crate::tool::{Tool, ToolResult};
80
81// ---------------------------------------------------------------------------
82// Error types
83// ---------------------------------------------------------------------------
84
85/// Errors from extension operations.
86#[derive(Debug, thiserror::Error)]
87pub enum ExtensionError {
88    /// An extension with the same name is already registered.
89    #[error("duplicate extension name: {0}")]
90    DuplicateName(String),
91    /// The registry has already been shared with hooks or an event sink.
92    #[error("cannot register extensions after registry has been shared")]
93    RegistryLocked,
94    /// Extension state serialization failed.
95    #[error("state serialization failed for extension '{name}': {reason}")]
96    StateSerialization { name: String, reason: String },
97    /// Extension state restoration failed.
98    #[error("state restoration failed for extension '{name}': {reason}")]
99    StateRestoration { name: String, reason: String },
100    /// An extension command returned an error.
101    #[error("extension command error: {0}")]
102    CommandError(String),
103    /// A generic extension error.
104    #[error("{0}")]
105    Other(String),
106}
107
108// ---------------------------------------------------------------------------
109// Hook result
110// ---------------------------------------------------------------------------
111
/// Verdict returned by an extension's gating hook.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ExtensionHookResult {
    /// Let processing proceed to the next hook in the chain (or the default
    /// behavior).
    Continue,
    /// Stop processing. From `on_before_tool_call` this denies the tool call,
    /// and `reason` is returned to the agent loop as a tool error.
    Block {
        /// Explanation surfaced to the agent loop.
        reason: String,
    },
}
124
125// ---------------------------------------------------------------------------
126// Extension command
127// ---------------------------------------------------------------------------
128
129/// A custom command dispatched to extensions.
130#[derive(Debug, Clone)]
131pub struct ExtensionCommand {
132    /// The command name (e.g. `"todo/add"`).
133    pub name: String,
134    /// Optional ID for response correlation.
135    pub id: Option<String>,
136    /// Command arguments.
137    pub args: Value,
138}
139
140impl ExtensionCommand {
141    /// Create a new extension command.
142    pub fn new(name: impl Into<String>, args: Value) -> Self {
143        Self {
144            name: name.into(),
145            id: None,
146            args,
147        }
148    }
149
150    /// Add an ID for response correlation.
151    pub fn with_id(mut self, id: impl Into<String>) -> Self {
152        self.id = Some(id.into());
153        self
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Extension trait
159// ---------------------------------------------------------------------------
160
161/// Extension trait for registering lifecycle hooks, custom tools,
162/// custom commands, and scoped extension state.
163///
164/// See the [module-level documentation](self) for lifecycle ordering,
165/// error/blocking semantics, and state serialization contracts.
166///
167/// # Unstable
168///
169/// This trait is part of the **unstable 0.x extension API**. Breaking changes
170/// may occur between minor versions without a major version bump.
171pub trait Extension: Send + Sync {
172    /// Unique name for this extension. Must be non-empty and unique within
173    /// the registry.
174    fn name(&self) -> &str;
175
176    /// Tools provided by this extension.
177    ///
178    /// Called once during [`ExtensionRegistry::collect_tools`] to gather
179    /// extension tools for the agent's tool set.
180    fn tools(&self) -> Vec<Box<dyn Tool>> {
181        vec![]
182    }
183
184    /// Custom providers provided by this extension.
185    ///
186    /// Called during [`ExtensionRegistry::collect_providers`] to gather
187    /// providers for registration with the provider registry. Extensions
188    /// should return new provider instances on each call since `Box<dyn
189    /// Provider>` is not `Clone`.
190    ///
191    /// Provider breadth should arrive through registration rather than core
192    /// provider additions.
193    fn providers(&self) -> Vec<Box<dyn Provider>> {
194        vec![]
195    }
196
197    /// Additional models to register for existing providers.
198    ///
199    /// Called during [`ExtensionRegistry::collect_model_overrides`] to gather
200    /// model metadata that supplements or overrides the models declared by
201    /// built-in providers. Each entry is `(provider_id, ModelInfo)`.
202    fn model_overrides(&self) -> Vec<(String, ModelInfo)> {
203        vec![]
204    }
205
206    /// Called before a tool is executed (after the base hook, in registration
207    /// order).
208    ///
209    /// Return [`ExtensionHookResult::Block`] to deny the tool execution.
210    fn on_before_tool_call(
211        &self,
212        tool_name: &str,
213        args: &Value,
214    ) -> Pin<Box<dyn Future<Output = ExtensionHookResult> + Send>> {
215        let _ = (tool_name, args);
216        Box::pin(async { ExtensionHookResult::Continue })
217    }
218
219    /// Called after a tool has been executed (after the base hook, in
220    /// registration order).
221    ///
222    /// This is an observer callback; it cannot modify the tool result.
223    fn on_after_tool_call(
224        &self,
225        tool_name: &str,
226        result: &ToolResult,
227    ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
228        let _ = (tool_name, result);
229        Box::pin(async {})
230    }
231
232    /// Prepare context before the next turn begins.
233    ///
234    /// Extensions may return extra messages to inject into the agent's next
235    /// turn. Composite hooks append these messages after the base hook's
236    /// messages and preserve extension registration order.
237    fn prepare_next_turn(
238        &self,
239        _ctx: &PrepareNextTurnContext,
240    ) -> Pin<Box<dyn Future<Output = Option<AgentLoopTurnUpdate>> + Send>> {
241        Box::pin(async { None })
242    }
243
244    /// Called for every agent event.
245    fn on_event(&self, _event: &AgentEvent) {}
246
247    /// Handle a custom command.
248    ///
249    /// Return `Ok(Some(value))` if the command was handled, `Ok(None)` if
250    /// the command is not recognized by this extension.
251    fn on_command(
252        &self,
253        _command: &ExtensionCommand,
254    ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, ExtensionError>> + Send>> {
255        Box::pin(async { Ok(None) })
256    }
257
258    /// Serialize extension state for session persistence.
259    ///
260    /// Return `Ok(Some(value))` with the serialized state, or `Ok(None)` if
261    /// the extension has no state to persist.
262    fn serialize_state(&self) -> Result<Option<Value>, ExtensionError> {
263        Ok(None)
264    }
265
266    /// Restore extension state from session persistence.
267    fn restore_state(&self, _state: Value) -> Result<(), ExtensionError> {
268        Ok(())
269    }
270}
271
272// ---------------------------------------------------------------------------
273// ExtensionRegistry
274// ---------------------------------------------------------------------------
275
276/// Registry that manages extensions and provides integration wrappers.
277///
278/// Extensions are registered before the agent loop starts. Once hooks or event
279/// sinks are wrapped via [`wrap_hooks`](Self::wrap_hooks) or
280/// [`wrap_event_sink`](Self::wrap_event_sink), the registry should not be
281/// modified further.
282pub struct ExtensionRegistry {
283    extensions: Arc<Vec<Box<dyn Extension>>>,
284}
285
286impl Clone for ExtensionRegistry {
287    fn clone(&self) -> Self {
288        Self {
289            extensions: self.extensions.clone(),
290        }
291    }
292}
293
294impl ExtensionRegistry {
295    /// Create an empty registry.
296    pub fn new() -> Self {
297        Self {
298            extensions: Arc::new(Vec::new()),
299        }
300    }
301
302    /// Register an extension.
303    ///
304    /// Returns an error if an extension with the same name already exists.
305    /// Returns [`ExtensionError::RegistryLocked`] if called after
306    /// `wrap_hooks()` or `wrap_event_sink()` has been called (i.e., the
307    /// extension list is shared).
308    pub fn register(&mut self, ext: Box<dyn Extension>) -> Result<(), ExtensionError> {
309        let name = ext.name().to_string();
310        if self.extensions.iter().any(|e| e.name() == name) {
311            return Err(ExtensionError::DuplicateName(name));
312        }
313        match Arc::get_mut(&mut self.extensions) {
314            Some(exts) => {
315                exts.push(ext);
316            }
317            None => {
318                return Err(ExtensionError::RegistryLocked);
319            }
320        }
321        Ok(())
322    }
323
324    /// Returns true if no extensions are registered.
325    pub fn is_empty(&self) -> bool {
326        self.extensions.is_empty()
327    }
328
329    /// Returns the number of registered extensions.
330    pub fn len(&self) -> usize {
331        self.extensions.len()
332    }
333
334    /// Return extension names in registration order.
335    pub fn names(&self) -> Vec<&str> {
336        self.extensions.iter().map(|e| e.name()).collect()
337    }
338
339    /// Look up an extension by name.
340    pub fn get(&self, name: &str) -> Option<&dyn Extension> {
341        self.extensions
342            .iter()
343            .find(|e| e.name() == name)
344            .map(|e| e.as_ref())
345    }
346
347    /// Collect all tools from all registered extensions.
348    pub fn collect_tools(&self) -> Vec<Box<dyn Tool>> {
349        self.extensions.iter().flat_map(|e| e.tools()).collect()
350    }
351
352    /// Collect all custom providers from all registered extensions.
353    ///
354    /// Each extension's [`providers`](Extension::providers) method is called
355    /// and the results are concatenated. Extensions should return fresh
356    /// provider instances since `Box<dyn Provider>` is not `Clone`.
357    pub fn collect_providers(&self) -> Vec<Box<dyn Provider>> {
358        self.extensions.iter().flat_map(|e| e.providers()).collect()
359    }
360
361    /// Collect all model overrides from all registered extensions.
362    ///
363    /// Each extension's [`model_overrides`](Extension::model_overrides) method
364    /// is called and the results are concatenated.
365    pub fn collect_model_overrides(&self) -> Vec<(String, ModelInfo)> {
366        self.extensions
367            .iter()
368            .flat_map(|e| e.model_overrides())
369            .collect()
370    }
371
372    /// Dispatch an event to all registered extensions.
373    pub fn dispatch_event(&self, event: &AgentEvent) {
374        for ext in self.extensions.iter() {
375            ext.on_event(event);
376        }
377    }
378
379    /// Dispatch a custom command to extensions in registration order.
380    ///
381    /// Returns the first `Some` response, or `None` if no extension handled
382    /// the command.
383    pub async fn dispatch_command(
384        &self,
385        command: &ExtensionCommand,
386    ) -> Result<Option<Value>, ExtensionError> {
387        for ext in self.extensions.iter() {
388            if let Some(value) = ext.on_command(command).await? {
389                return Ok(Some(value));
390            }
391        }
392        Ok(None)
393    }
394
395    /// Serialize all extension states into a JSON object keyed by extension
396    /// name.
397    pub fn serialize_states(&self) -> Result<Value, ExtensionError> {
398        let mut map = serde_json::Map::new();
399        for ext in self.extensions.iter() {
400            match ext.serialize_state() {
401                Ok(Some(state)) => {
402                    map.insert(ext.name().to_string(), state);
403                }
404                Ok(None) => {}
405                Err(e) => return Err(e),
406            }
407        }
408        Ok(Value::Object(map))
409    }
410
411    /// Restore extension states from a JSON object keyed by extension name.
412    pub fn restore_states(&self, states: Value) -> Result<(), ExtensionError> {
413        let map = match states {
414            Value::Object(m) => m,
415            _ => return Ok(()),
416        };
417        for ext in self.extensions.iter() {
418            if let Some(state) = map.get(ext.name()) {
419                ext.restore_state(state.clone())?;
420            }
421        }
422        Ok(())
423    }
424
425    /// Create a composite [`AgentHooks`] that wraps the base hooks with
426    /// extension lifecycle callbacks.
427    ///
428    /// Extension hooks are called after the base hooks. If any extension
429    /// returns [`ExtensionHookResult::Block`], the chain stops and the block
430    /// propagates as a denial.
431    pub fn wrap_hooks(&self, base: Box<dyn AgentHooks>) -> Box<dyn AgentHooks> {
432        Box::new(CompositeHooks {
433            base: Arc::from(base),
434            extensions: self.extensions.clone(),
435        })
436    }
437
438    /// Wrap an event sink to dispatch events to all registered extensions
439    /// before forwarding to the base sink.
440    pub fn wrap_event_sink(
441        &self,
442        base_sink: crate::event::AgentEventSink,
443    ) -> crate::event::AgentEventSink {
444        let extensions = self.extensions.clone();
445        Box::new(move |event: AgentEvent| {
446            for ext in extensions.iter() {
447                ext.on_event(&event);
448            }
449            base_sink(event);
450        })
451    }
452}
453
454impl Default for ExtensionRegistry {
455    fn default() -> Self {
456        Self::new()
457    }
458}
459
460// ---------------------------------------------------------------------------
461// CompositeHooks
462// ---------------------------------------------------------------------------
463
464/// Internal type that chains extension hooks after base hooks.
465struct CompositeHooks {
466    base: Arc<dyn AgentHooks>,
467    extensions: Arc<Vec<Box<dyn Extension>>>,
468}
469
470impl AgentHooks for CompositeHooks {
471    fn convert_to_llm(
472        &self,
473        messages: &[AgentMessage],
474    ) -> Result<Vec<opi_ai::message::Message>, AgentError> {
475        self.base.convert_to_llm(messages)
476    }
477
478    fn transform_context(
479        &self,
480        messages: Vec<AgentMessage>,
481        signal: CancellationToken,
482    ) -> Pin<Box<dyn Future<Output = Result<Vec<AgentMessage>, AgentError>> + Send>> {
483        self.base.transform_context(messages, signal)
484    }
485
486    fn should_stop_after_turn(
487        &self,
488        ctx: ShouldStopAfterTurnContext,
489    ) -> Pin<Box<dyn Future<Output = bool> + Send>> {
490        self.base.should_stop_after_turn(ctx)
491    }
492
493    fn before_tool_call(
494        &self,
495        ctx: BeforeToolCallContext,
496    ) -> Pin<Box<dyn Future<Output = BeforeToolCallResult> + Send>> {
497        let base = self.base.clone();
498        let extensions = self.extensions.clone();
499        let tool_name = ctx.tool_name.clone();
500        let args = ctx.args.clone();
501        Box::pin(async move {
502            // Base hook decides first.
503            match base.before_tool_call(ctx).await {
504                BeforeToolCallResult::Allow => {}
505                BeforeToolCallResult::Deny { reason } => {
506                    return BeforeToolCallResult::Deny { reason };
507                }
508            }
509
510            // Extension hooks in registration order.
511            for ext in extensions.iter() {
512                match ext.on_before_tool_call(&tool_name, &args).await {
513                    ExtensionHookResult::Continue => {}
514                    ExtensionHookResult::Block { reason } => {
515                        return BeforeToolCallResult::Deny { reason };
516                    }
517                }
518            }
519
520            BeforeToolCallResult::Allow
521        })
522    }
523
524    fn after_tool_call(
525        &self,
526        ctx: AfterToolCallContext,
527    ) -> Pin<Box<dyn Future<Output = AfterToolCallResult> + Send>> {
528        let base = self.base.clone();
529        let extensions = self.extensions.clone();
530        let tool_name = ctx.tool_name.clone();
531        let result_snapshot = ctx.result.clone();
532        Box::pin(async move {
533            // Base hook decides first (may keep or replace).
534            let base_result = base.after_tool_call(ctx).await;
535
536            // Determine the effective result for extension observation.
537            let effective: &ToolResult = match &base_result {
538                AfterToolCallResult::Keep => &result_snapshot,
539                AfterToolCallResult::Replace(r) => r,
540            };
541
542            // Notify extension observers (cannot modify result).
543            for ext in extensions.iter() {
544                ext.on_after_tool_call(&tool_name, effective).await;
545            }
546
547            base_result
548        })
549    }
550
551    fn prepare_next_turn(
552        &self,
553        ctx: PrepareNextTurnContext,
554    ) -> Pin<Box<dyn Future<Output = Option<AgentLoopTurnUpdate>> + Send>> {
555        let base = self.base.clone();
556        let extensions = self.extensions.clone();
557        let extension_ctx = PrepareNextTurnContext {
558            messages: ctx.messages.clone(),
559            turn: ctx.turn,
560        };
561        Box::pin(async move {
562            let mut extra_messages = base
563                .prepare_next_turn(ctx)
564                .await
565                .map(|update| update.extra_messages)
566                .unwrap_or_default();
567
568            for ext in extensions.iter() {
569                if let Some(update) = ext.prepare_next_turn(&extension_ctx).await {
570                    extra_messages.extend(update.extra_messages);
571                }
572            }
573
574            if extra_messages.is_empty() {
575                None
576            } else {
577                Some(AgentLoopTurnUpdate { extra_messages })
578            }
579        })
580    }
581}