opi_agent/extension.rs
1//! Extension system for agent customization.
2//!
//! Provides the [`Extension`] trait for registering lifecycle hooks, custom
//! tools, custom providers and model metadata, custom commands, custom agent
//! messages, and scoped extension state.
5//!
6//! # Lifecycle Ordering
7//!
8//! Extension hooks are called in registration order after the base hooks:
9//!
10//! 1. Base [`AgentHooks::before_tool_call`] then extension
11//! [`Extension::on_before_tool_call`] for each extension (in registration
12//! order).
13//! 2. Base [`AgentHooks::after_tool_call`] then extension
14//! [`Extension::on_after_tool_call`] for each extension (in registration
15//! order).
16//! 3. Base [`AgentHooks::prepare_next_turn`] then extension
17//! [`Extension::prepare_next_turn`] for each extension (in registration
18//! order). Extra messages are appended in that order.
19//!
//! Only the `before_tool_call` chain can deny: if the base hook or any
//! extension returns a deny/block result, the chain stops and the denial
//! propagates. Extensions cannot override a denial from the base hooks or
//! from an earlier extension.
23//!
24//! # Hook Error/Blocking Semantics
25//!
26//! - `on_before_tool_call` returning [`ExtensionHookResult::Block`] prevents
27//! the tool from executing. The block reason is returned to the agent loop
28//! as a tool error.
29//! - `on_after_tool_call` is an observer callback; it cannot modify tool
30//! results. The base hook retains full control over result replacement via
31//! [`AfterToolCallResult::Replace`].
32//! - `on_command` errors propagate to the caller via
33//! [`ExtensionError::CommandError`]. If an extension returns an error, the
34//! dispatch stops and the error is returned.
35//!
36//! # State Serialization
37//!
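//! Each extension can serialize and restore its own state. Extension states
//! are keyed by extension name in the serialized map produced by
//! [`ExtensionRegistry::serialize_states`]; an extension whose
//! `serialize_state` returns `Ok(None)` is left out of the map, and
//! [`ExtensionRegistry::restore_states`] calls `restore_state` only for
//! extensions that have an entry. Extensions that don't need state
//! persistence can use the default (no-op) implementations of
//! [`Extension::serialize_state`] and [`Extension::restore_state`].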
43//!
44//! # Custom Tools
45//!
46//! Extensions provide tools via the [`Extension::tools`] method. These tools
47//! are collected during [`ExtensionRegistry::collect_tools`] and added to the
48//! agent's tool set alongside built-in tools. Extension tools follow the same
49//! [`Tool`] trait contract and validation rules as built-in tools.
50//!
51//! # Custom Commands
52//!
53//! Extensions handle custom commands via [`Extension::on_command`]. Commands
54//! are dispatched via [`ExtensionRegistry::dispatch_command`] to extensions in
55//! registration order. The first extension that returns `Ok(Some(value))`
56//! claims the command.
57//!
58//! # Unstable
59//!
60//! This module is part of the **unstable 0.x extension API**. Breaking changes
61//! may occur between minor versions without a major version bump.
62
63use std::future::Future;
64use std::pin::Pin;
65use std::sync::Arc;
66
67use serde_json::Value;
68use tokio_util::sync::CancellationToken;
69
70use opi_ai::provider::{ModelInfo, Provider};
71
72use crate::event::AgentEvent;
73use crate::hooks::{
74 AfterToolCallContext, AfterToolCallResult, AgentHooks, BeforeToolCallContext,
75 BeforeToolCallResult, PrepareNextTurnContext, ShouldStopAfterTurnContext,
76};
77use crate::loop_types::{AgentError, AgentLoopTurnUpdate};
78use crate::message::AgentMessage;
79use crate::tool::{Tool, ToolResult};
80
81// ---------------------------------------------------------------------------
82// Error types
83// ---------------------------------------------------------------------------
84
85/// Errors from extension operations.
86#[derive(Debug, thiserror::Error)]
87pub enum ExtensionError {
88 /// An extension with the same name is already registered.
89 #[error("duplicate extension name: {0}")]
90 DuplicateName(String),
91 /// The registry has already been shared with hooks or an event sink.
92 #[error("cannot register extensions after registry has been shared")]
93 RegistryLocked,
94 /// Extension state serialization failed.
95 #[error("state serialization failed for extension '{name}': {reason}")]
96 StateSerialization { name: String, reason: String },
97 /// Extension state restoration failed.
98 #[error("state restoration failed for extension '{name}': {reason}")]
99 StateRestoration { name: String, reason: String },
100 /// An extension command returned an error.
101 #[error("extension command error: {0}")]
102 CommandError(String),
103 /// A generic extension error.
104 #[error("{0}")]
105 Other(String),
106}
107
108// ---------------------------------------------------------------------------
109// Hook result
110// ---------------------------------------------------------------------------
111
112/// Result of an extension lifecycle hook.
113#[non_exhaustive]
114#[derive(Debug, Clone)]
115pub enum ExtensionHookResult {
116 /// Continue processing. The next hook in the chain (or the default
117 /// behavior) is invoked.
118 Continue,
119 /// Block further processing. For `on_before_tool_call`, this denies the
120 /// tool execution. The block reason is returned to the agent loop as a
121 /// tool error.
122 Block { reason: String },
123}
124
125// ---------------------------------------------------------------------------
126// Extension command
127// ---------------------------------------------------------------------------
128
129/// A custom command dispatched to extensions.
130#[derive(Debug, Clone)]
131pub struct ExtensionCommand {
132 /// The command name (e.g. `"todo/add"`).
133 pub name: String,
134 /// Optional ID for response correlation.
135 pub id: Option<String>,
136 /// Command arguments.
137 pub args: Value,
138}
139
140impl ExtensionCommand {
141 /// Create a new extension command.
142 pub fn new(name: impl Into<String>, args: Value) -> Self {
143 Self {
144 name: name.into(),
145 id: None,
146 args,
147 }
148 }
149
150 /// Add an ID for response correlation.
151 pub fn with_id(mut self, id: impl Into<String>) -> Self {
152 self.id = Some(id.into());
153 self
154 }
155}
156
157// ---------------------------------------------------------------------------
158// Extension trait
159// ---------------------------------------------------------------------------
160
161/// Extension trait for registering lifecycle hooks, custom tools,
162/// custom commands, and scoped extension state.
163///
164/// See the [module-level documentation](self) for lifecycle ordering,
165/// error/blocking semantics, and state serialization contracts.
166///
167/// # Unstable
168///
169/// This trait is part of the **unstable 0.x extension API**. Breaking changes
170/// may occur between minor versions without a major version bump.
171pub trait Extension: Send + Sync {
172 /// Unique name for this extension. Must be non-empty and unique within
173 /// the registry.
174 fn name(&self) -> &str;
175
176 /// Tools provided by this extension.
177 ///
178 /// Called once during [`ExtensionRegistry::collect_tools`] to gather
179 /// extension tools for the agent's tool set.
180 fn tools(&self) -> Vec<Box<dyn Tool>> {
181 vec![]
182 }
183
184 /// Custom providers provided by this extension.
185 ///
186 /// Called during [`ExtensionRegistry::collect_providers`] to gather
187 /// providers for registration with the provider registry. Extensions
188 /// should return new provider instances on each call since `Box<dyn
189 /// Provider>` is not `Clone`.
190 ///
191 /// Provider breadth should arrive through registration rather than core
192 /// provider additions.
193 fn providers(&self) -> Vec<Box<dyn Provider>> {
194 vec![]
195 }
196
197 /// Additional models to register for existing providers.
198 ///
199 /// Called during [`ExtensionRegistry::collect_model_overrides`] to gather
200 /// model metadata that supplements or overrides the models declared by
201 /// built-in providers. Each entry is `(provider_id, ModelInfo)`.
202 fn model_overrides(&self) -> Vec<(String, ModelInfo)> {
203 vec![]
204 }
205
206 /// Called before a tool is executed (after the base hook, in registration
207 /// order).
208 ///
209 /// Return [`ExtensionHookResult::Block`] to deny the tool execution.
210 fn on_before_tool_call(
211 &self,
212 tool_name: &str,
213 args: &Value,
214 ) -> Pin<Box<dyn Future<Output = ExtensionHookResult> + Send>> {
215 let _ = (tool_name, args);
216 Box::pin(async { ExtensionHookResult::Continue })
217 }
218
219 /// Called after a tool has been executed (after the base hook, in
220 /// registration order).
221 ///
222 /// This is an observer callback; it cannot modify the tool result.
223 fn on_after_tool_call(
224 &self,
225 tool_name: &str,
226 result: &ToolResult,
227 ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
228 let _ = (tool_name, result);
229 Box::pin(async {})
230 }
231
232 /// Prepare context before the next turn begins.
233 ///
234 /// Extensions may return extra messages to inject into the agent's next
235 /// turn. Composite hooks append these messages after the base hook's
236 /// messages and preserve extension registration order.
237 fn prepare_next_turn(
238 &self,
239 _ctx: &PrepareNextTurnContext,
240 ) -> Pin<Box<dyn Future<Output = Option<AgentLoopTurnUpdate>> + Send>> {
241 Box::pin(async { None })
242 }
243
244 /// Called for every agent event.
245 fn on_event(&self, _event: &AgentEvent) {}
246
247 /// Handle a custom command.
248 ///
249 /// Return `Ok(Some(value))` if the command was handled, `Ok(None)` if
250 /// the command is not recognized by this extension.
251 fn on_command(
252 &self,
253 _command: &ExtensionCommand,
254 ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, ExtensionError>> + Send>> {
255 Box::pin(async { Ok(None) })
256 }
257
258 /// Serialize extension state for session persistence.
259 ///
260 /// Return `Ok(Some(value))` with the serialized state, or `Ok(None)` if
261 /// the extension has no state to persist.
262 fn serialize_state(&self) -> Result<Option<Value>, ExtensionError> {
263 Ok(None)
264 }
265
266 /// Restore extension state from session persistence.
267 fn restore_state(&self, _state: Value) -> Result<(), ExtensionError> {
268 Ok(())
269 }
270}
271
272// ---------------------------------------------------------------------------
273// ExtensionRegistry
274// ---------------------------------------------------------------------------
275
276/// Registry that manages extensions and provides integration wrappers.
277///
278/// Extensions are registered before the agent loop starts. Once hooks or event
279/// sinks are wrapped via [`wrap_hooks`](Self::wrap_hooks) or
280/// [`wrap_event_sink`](Self::wrap_event_sink), the registry should not be
281/// modified further.
282pub struct ExtensionRegistry {
283 extensions: Arc<Vec<Box<dyn Extension>>>,
284}
285
286impl Clone for ExtensionRegistry {
287 fn clone(&self) -> Self {
288 Self {
289 extensions: self.extensions.clone(),
290 }
291 }
292}
293
294impl ExtensionRegistry {
295 /// Create an empty registry.
296 pub fn new() -> Self {
297 Self {
298 extensions: Arc::new(Vec::new()),
299 }
300 }
301
302 /// Register an extension.
303 ///
304 /// Returns an error if an extension with the same name already exists.
305 /// Returns [`ExtensionError::RegistryLocked`] if called after
306 /// `wrap_hooks()` or `wrap_event_sink()` has been called (i.e., the
307 /// extension list is shared).
308 pub fn register(&mut self, ext: Box<dyn Extension>) -> Result<(), ExtensionError> {
309 let name = ext.name().to_string();
310 if self.extensions.iter().any(|e| e.name() == name) {
311 return Err(ExtensionError::DuplicateName(name));
312 }
313 match Arc::get_mut(&mut self.extensions) {
314 Some(exts) => {
315 exts.push(ext);
316 }
317 None => {
318 return Err(ExtensionError::RegistryLocked);
319 }
320 }
321 Ok(())
322 }
323
324 /// Returns true if no extensions are registered.
325 pub fn is_empty(&self) -> bool {
326 self.extensions.is_empty()
327 }
328
329 /// Returns the number of registered extensions.
330 pub fn len(&self) -> usize {
331 self.extensions.len()
332 }
333
334 /// Return extension names in registration order.
335 pub fn names(&self) -> Vec<&str> {
336 self.extensions.iter().map(|e| e.name()).collect()
337 }
338
339 /// Look up an extension by name.
340 pub fn get(&self, name: &str) -> Option<&dyn Extension> {
341 self.extensions
342 .iter()
343 .find(|e| e.name() == name)
344 .map(|e| e.as_ref())
345 }
346
347 /// Collect all tools from all registered extensions.
348 pub fn collect_tools(&self) -> Vec<Box<dyn Tool>> {
349 self.extensions.iter().flat_map(|e| e.tools()).collect()
350 }
351
352 /// Collect all custom providers from all registered extensions.
353 ///
354 /// Each extension's [`providers`](Extension::providers) method is called
355 /// and the results are concatenated. Extensions should return fresh
356 /// provider instances since `Box<dyn Provider>` is not `Clone`.
357 pub fn collect_providers(&self) -> Vec<Box<dyn Provider>> {
358 self.extensions.iter().flat_map(|e| e.providers()).collect()
359 }
360
361 /// Collect all model overrides from all registered extensions.
362 ///
363 /// Each extension's [`model_overrides`](Extension::model_overrides) method
364 /// is called and the results are concatenated.
365 pub fn collect_model_overrides(&self) -> Vec<(String, ModelInfo)> {
366 self.extensions
367 .iter()
368 .flat_map(|e| e.model_overrides())
369 .collect()
370 }
371
372 /// Dispatch an event to all registered extensions.
373 pub fn dispatch_event(&self, event: &AgentEvent) {
374 for ext in self.extensions.iter() {
375 ext.on_event(event);
376 }
377 }
378
379 /// Dispatch a custom command to extensions in registration order.
380 ///
381 /// Returns the first `Some` response, or `None` if no extension handled
382 /// the command.
383 pub async fn dispatch_command(
384 &self,
385 command: &ExtensionCommand,
386 ) -> Result<Option<Value>, ExtensionError> {
387 for ext in self.extensions.iter() {
388 if let Some(value) = ext.on_command(command).await? {
389 return Ok(Some(value));
390 }
391 }
392 Ok(None)
393 }
394
395 /// Serialize all extension states into a JSON object keyed by extension
396 /// name.
397 pub fn serialize_states(&self) -> Result<Value, ExtensionError> {
398 let mut map = serde_json::Map::new();
399 for ext in self.extensions.iter() {
400 match ext.serialize_state() {
401 Ok(Some(state)) => {
402 map.insert(ext.name().to_string(), state);
403 }
404 Ok(None) => {}
405 Err(e) => return Err(e),
406 }
407 }
408 Ok(Value::Object(map))
409 }
410
411 /// Restore extension states from a JSON object keyed by extension name.
412 pub fn restore_states(&self, states: Value) -> Result<(), ExtensionError> {
413 let map = match states {
414 Value::Object(m) => m,
415 _ => return Ok(()),
416 };
417 for ext in self.extensions.iter() {
418 if let Some(state) = map.get(ext.name()) {
419 ext.restore_state(state.clone())?;
420 }
421 }
422 Ok(())
423 }
424
425 /// Create a composite [`AgentHooks`] that wraps the base hooks with
426 /// extension lifecycle callbacks.
427 ///
428 /// Extension hooks are called after the base hooks. If any extension
429 /// returns [`ExtensionHookResult::Block`], the chain stops and the block
430 /// propagates as a denial.
431 pub fn wrap_hooks(&self, base: Box<dyn AgentHooks>) -> Box<dyn AgentHooks> {
432 Box::new(CompositeHooks {
433 base: Arc::from(base),
434 extensions: self.extensions.clone(),
435 })
436 }
437
438 /// Wrap an event sink to dispatch events to all registered extensions
439 /// before forwarding to the base sink.
440 pub fn wrap_event_sink(
441 &self,
442 base_sink: crate::event::AgentEventSink,
443 ) -> crate::event::AgentEventSink {
444 let extensions = self.extensions.clone();
445 Box::new(move |event: AgentEvent| {
446 for ext in extensions.iter() {
447 ext.on_event(&event);
448 }
449 base_sink(event);
450 })
451 }
452}
453
454impl Default for ExtensionRegistry {
455 fn default() -> Self {
456 Self::new()
457 }
458}
459
460// ---------------------------------------------------------------------------
461// CompositeHooks
462// ---------------------------------------------------------------------------
463
464/// Internal type that chains extension hooks after base hooks.
465struct CompositeHooks {
466 base: Arc<dyn AgentHooks>,
467 extensions: Arc<Vec<Box<dyn Extension>>>,
468}
469
470impl AgentHooks for CompositeHooks {
471 fn convert_to_llm(
472 &self,
473 messages: &[AgentMessage],
474 ) -> Result<Vec<opi_ai::message::Message>, AgentError> {
475 self.base.convert_to_llm(messages)
476 }
477
478 fn transform_context(
479 &self,
480 messages: Vec<AgentMessage>,
481 signal: CancellationToken,
482 ) -> Pin<Box<dyn Future<Output = Result<Vec<AgentMessage>, AgentError>> + Send>> {
483 self.base.transform_context(messages, signal)
484 }
485
486 fn should_stop_after_turn(
487 &self,
488 ctx: ShouldStopAfterTurnContext,
489 ) -> Pin<Box<dyn Future<Output = bool> + Send>> {
490 self.base.should_stop_after_turn(ctx)
491 }
492
493 fn before_tool_call(
494 &self,
495 ctx: BeforeToolCallContext,
496 ) -> Pin<Box<dyn Future<Output = BeforeToolCallResult> + Send>> {
497 let base = self.base.clone();
498 let extensions = self.extensions.clone();
499 let tool_name = ctx.tool_name.clone();
500 let args = ctx.args.clone();
501 Box::pin(async move {
502 // Base hook decides first.
503 match base.before_tool_call(ctx).await {
504 BeforeToolCallResult::Allow => {}
505 BeforeToolCallResult::Deny { reason } => {
506 return BeforeToolCallResult::Deny { reason };
507 }
508 }
509
510 // Extension hooks in registration order.
511 for ext in extensions.iter() {
512 match ext.on_before_tool_call(&tool_name, &args).await {
513 ExtensionHookResult::Continue => {}
514 ExtensionHookResult::Block { reason } => {
515 return BeforeToolCallResult::Deny { reason };
516 }
517 }
518 }
519
520 BeforeToolCallResult::Allow
521 })
522 }
523
524 fn after_tool_call(
525 &self,
526 ctx: AfterToolCallContext,
527 ) -> Pin<Box<dyn Future<Output = AfterToolCallResult> + Send>> {
528 let base = self.base.clone();
529 let extensions = self.extensions.clone();
530 let tool_name = ctx.tool_name.clone();
531 let result_snapshot = ctx.result.clone();
532 Box::pin(async move {
533 // Base hook decides first (may keep or replace).
534 let base_result = base.after_tool_call(ctx).await;
535
536 // Determine the effective result for extension observation.
537 let effective: &ToolResult = match &base_result {
538 AfterToolCallResult::Keep => &result_snapshot,
539 AfterToolCallResult::Replace(r) => r,
540 };
541
542 // Notify extension observers (cannot modify result).
543 for ext in extensions.iter() {
544 ext.on_after_tool_call(&tool_name, effective).await;
545 }
546
547 base_result
548 })
549 }
550
551 fn prepare_next_turn(
552 &self,
553 ctx: PrepareNextTurnContext,
554 ) -> Pin<Box<dyn Future<Output = Option<AgentLoopTurnUpdate>> + Send>> {
555 let base = self.base.clone();
556 let extensions = self.extensions.clone();
557 let extension_ctx = PrepareNextTurnContext {
558 messages: ctx.messages.clone(),
559 turn: ctx.turn,
560 };
561 Box::pin(async move {
562 let mut extra_messages = base
563 .prepare_next_turn(ctx)
564 .await
565 .map(|update| update.extra_messages)
566 .unwrap_or_default();
567
568 for ext in extensions.iter() {
569 if let Some(update) = ext.prepare_next_turn(&extension_ctx).await {
570 extra_messages.extend(update.extra_messages);
571 }
572 }
573
574 if extra_messages.is_empty() {
575 None
576 } else {
577 Some(AgentLoopTurnUpdate { extra_messages })
578 }
579 })
580 }
581}