use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use opi_ai::provider::{ModelInfo, Provider};
use crate::event::AgentEvent;
use crate::hooks::{
AfterToolCallContext, AfterToolCallResult, AgentHooks, BeforeToolCallContext,
BeforeToolCallResult, PrepareNextTurnContext, ShouldStopAfterTurnContext,
};
use crate::loop_types::{AgentError, AgentLoopTurnUpdate};
use crate::message::AgentMessage;
use crate::tool::{Tool, ToolResult};
#[derive(Debug, thiserror::Error)]
pub enum ExtensionError {
#[error("duplicate extension name: {0}")]
DuplicateName(String),
#[error("cannot register extensions after registry has been shared")]
RegistryLocked,
#[error("state serialization failed for extension '{name}': {reason}")]
StateSerialization { name: String, reason: String },
#[error("state restoration failed for extension '{name}': {reason}")]
StateRestoration { name: String, reason: String },
#[error("extension command error: {0}")]
CommandError(String),
#[error("{0}")]
Other(String),
}
/// Verdict an extension returns from [`Extension::on_before_tool_call`].
///
/// Derives `PartialEq`/`Eq` so callers can compare verdicts directly
/// (e.g. `result == ExtensionHookResult::Continue`).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtensionHookResult {
    /// Let the tool call proceed to the next extension (or to execution).
    Continue,
    /// Deny the tool call.
    Block {
        /// Reason reported for the denial.
        reason: String,
    },
}
/// A named command offered to the registered extensions.
///
/// Build one with [`ExtensionCommand::new`], optionally tag it with
/// [`ExtensionCommand::with_id`], and route it with
/// [`ExtensionRegistry::dispatch_command`]. `PartialEq` is derived so commands
/// can be compared (all fields support it).
#[derive(Debug, Clone, PartialEq)]
pub struct ExtensionCommand {
    /// Command name.
    pub name: String,
    /// Optional caller-supplied identifier; `None` unless set via `with_id`.
    pub id: Option<String>,
    /// Command arguments as arbitrary JSON.
    pub args: Value,
}
impl ExtensionCommand {
    /// Builds a command from a name and JSON arguments, with no id attached.
    pub fn new(name: impl Into<String>, args: Value) -> Self {
        let name: String = name.into();
        Self {
            args,
            id: None,
            name,
        }
    }

    /// Returns this command with `id` attached, replacing any previous id.
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }
}
/// A pluggable unit of agent functionality managed by [`ExtensionRegistry`].
///
/// Every method except [`name`](Extension::name) has a no-op default, so an
/// implementation overrides only what it needs. The async hooks return boxed
/// futures with an implicit `'static` bound: they cannot borrow `self` or the
/// hook arguments, so implementations must clone whatever the future needs.
pub trait Extension: Send + Sync {
    /// Unique name of this extension. The registry rejects duplicates and uses
    /// the name as the key for this extension's serialized state.
    fn name(&self) -> &str;

    /// Tools contributed by this extension. Defaults to none.
    fn tools(&self) -> Vec<Box<dyn Tool>> {
        Vec::new()
    }

    /// Model providers contributed by this extension. Defaults to none.
    fn providers(&self) -> Vec<Box<dyn Provider>> {
        Vec::new()
    }

    /// Model metadata overrides as `(String, ModelInfo)` pairs. Defaults to none.
    /// NOTE(review): the `String` is presumably a model identifier — confirm with
    /// the consumer of `ExtensionRegistry::collect_model_overrides`.
    fn model_overrides(&self) -> Vec<(String, ModelInfo)> {
        Vec::new()
    }

    /// Gate invoked (when installed via `ExtensionRegistry::wrap_hooks`) after
    /// the base hooks allowed a tool call. Returning `Block` denies the call and
    /// skips later extensions. Defaults to `Continue`.
    fn on_before_tool_call(
        &self,
        tool_name: &str,
        args: &Value,
    ) -> Pin<Box<dyn Future<Output = ExtensionHookResult> + Send>> {
        // The default implementation ignores its inputs.
        let _ = tool_name;
        let _ = args;
        Box::pin(std::future::ready(ExtensionHookResult::Continue))
    }

    /// Observes the effective result of a tool call (the base hooks'
    /// replacement, if any). It cannot alter the result. Defaults to a no-op.
    fn on_after_tool_call(
        &self,
        tool_name: &str,
        result: &ToolResult,
    ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        // The default implementation ignores its inputs.
        let _ = tool_name;
        let _ = result;
        Box::pin(std::future::ready(()))
    }

    /// May contribute extra messages for the next turn. They are appended after
    /// the base hooks' messages; only `extra_messages` of the update is used.
    /// Defaults to `None`.
    fn prepare_next_turn(
        &self,
        _ctx: &PrepareNextTurnContext,
    ) -> Pin<Box<dyn Future<Output = Option<AgentLoopTurnUpdate>> + Send>> {
        Box::pin(std::future::ready(None))
    }

    /// Synchronously observes an agent event. Defaults to a no-op.
    fn on_event(&self, _event: &AgentEvent) {}

    /// Handles a command. `Ok(Some(value))` claims it and stops dispatch;
    /// `Ok(None)` lets later extensions see it; an `Err` aborts dispatch.
    /// Defaults to `Ok(None)`.
    fn on_command(
        &self,
        _command: &ExtensionCommand,
    ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, ExtensionError>> + Send>> {
        Box::pin(std::future::ready(Ok(None)))
    }

    /// Snapshot of persistent state; `Ok(None)` means there is nothing to save.
    /// Defaults to `Ok(None)`.
    fn serialize_state(&self) -> Result<Option<Value>, ExtensionError> {
        Ok(None)
    }

    /// Restores state previously produced by `serialize_state`. Because this
    /// takes `&self`, storing the state requires interior mutability.
    /// Defaults to ignoring the state.
    fn restore_state(&self, _state: Value) -> Result<(), ExtensionError> {
        Ok(())
    }
}
/// Ordered collection of [`Extension`]s with unique names.
///
/// Extensions can only be registered while the registry holds the sole
/// reference to its list. Cloning the registry, or wrapping hooks/event sinks
/// with it, shares the list and locks further registration.
pub struct ExtensionRegistry {
    /// Registered extensions, in registration order.
    extensions: Arc<Vec<Box<dyn Extension>>>,
}
impl Clone for ExtensionRegistry {
fn clone(&self) -> Self {
Self {
extensions: self.extensions.clone(),
}
}
}
impl ExtensionRegistry {
    /// Creates an empty registry that accepts registrations.
    pub fn new() -> Self {
        Self {
            extensions: Arc::new(Vec::new()),
        }
    }

    /// Appends `ext` to the registry. Hooks and events reach extensions in
    /// registration order.
    ///
    /// # Errors
    ///
    /// * [`ExtensionError::DuplicateName`] if an extension with the same name is
    ///   already registered. This is checked first.
    /// * [`ExtensionError::RegistryLocked`] if the extension list is shared —
    ///   with a clone of this registry, or with hooks/sinks returned by
    ///   [`wrap_hooks`](Self::wrap_hooks) or [`wrap_event_sink`](Self::wrap_event_sink).
    pub fn register(&mut self, ext: Box<dyn Extension>) -> Result<(), ExtensionError> {
        let name = ext.name();
        if self.extensions.iter().any(|e| e.name() == name) {
            return Err(ExtensionError::DuplicateName(name.to_string()));
        }
        // `Arc::get_mut` succeeds only while this handle owns the sole reference.
        let exts = Arc::get_mut(&mut self.extensions).ok_or(ExtensionError::RegistryLocked)?;
        exts.push(ext);
        Ok(())
    }

    /// Returns `true` if no extensions are registered.
    pub fn is_empty(&self) -> bool {
        self.extensions.is_empty()
    }

    /// Number of registered extensions.
    pub fn len(&self) -> usize {
        self.extensions.len()
    }

    /// Names of the registered extensions, in registration order.
    pub fn names(&self) -> Vec<&str> {
        self.extensions.iter().map(|e| e.name()).collect()
    }

    /// Looks up a registered extension by name.
    pub fn get(&self, name: &str) -> Option<&dyn Extension> {
        self.extensions
            .iter()
            .find(|e| e.name() == name)
            .map(|e| e.as_ref())
    }

    /// Tools contributed by all extensions, concatenated in registration order.
    pub fn collect_tools(&self) -> Vec<Box<dyn Tool>> {
        self.extensions.iter().flat_map(|e| e.tools()).collect()
    }

    /// Providers contributed by all extensions, concatenated in registration order.
    pub fn collect_providers(&self) -> Vec<Box<dyn Provider>> {
        self.extensions.iter().flat_map(|e| e.providers()).collect()
    }

    /// Model overrides from all extensions, concatenated in registration order.
    pub fn collect_model_overrides(&self) -> Vec<(String, ModelInfo)> {
        self.extensions
            .iter()
            .flat_map(|e| e.model_overrides())
            .collect()
    }

    /// Delivers `event` to every extension's `on_event`, in registration order.
    pub fn dispatch_event(&self, event: &AgentEvent) {
        for ext in self.extensions.iter() {
            ext.on_event(event);
        }
    }

    /// Offers `command` to each extension in registration order and returns the
    /// first `Some` value produced. Returns `Ok(None)` if no extension claims it.
    ///
    /// # Errors
    ///
    /// The first `Err` from an extension is returned immediately; later
    /// extensions are not consulted.
    pub async fn dispatch_command(
        &self,
        command: &ExtensionCommand,
    ) -> Result<Option<Value>, ExtensionError> {
        for ext in self.extensions.iter() {
            if let Some(value) = ext.on_command(command).await? {
                return Ok(Some(value));
            }
        }
        Ok(None)
    }

    /// Collects every extension's persistent state into a JSON object keyed by
    /// extension name. Extensions that return `Ok(None)` are omitted.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an extension's `serialize_state`.
    pub fn serialize_states(&self) -> Result<Value, ExtensionError> {
        let mut map = serde_json::Map::new();
        for ext in self.extensions.iter() {
            if let Some(state) = ext.serialize_state()? {
                map.insert(ext.name().to_string(), state);
            }
        }
        Ok(Value::Object(map))
    }

    /// Hands each extension its entry from a JSON object keyed by extension name
    /// (the shape produced by [`serialize_states`](Self::serialize_states)).
    /// Extensions without an entry are skipped. Non-object input (e.g. `Null`)
    /// is treated as "no saved state" and ignored.
    ///
    /// # Errors
    ///
    /// Stops at the first `restore_state` error and returns it. Extensions
    /// earlier in registration order have already been restored at that point.
    pub fn restore_states(&self, states: Value) -> Result<(), ExtensionError> {
        let mut map = match states {
            Value::Object(m) => m,
            _ => return Ok(()),
        };
        for ext in self.extensions.iter() {
            // Names are unique (enforced by `register`), so each entry is consumed
            // at most once and can be moved out instead of deep-cloned.
            if let Some(state) = map.remove(ext.name()) {
                ext.restore_state(state)?;
            }
        }
        Ok(())
    }

    /// Wraps `base` so the registered extensions take part in tool-call gating,
    /// tool-result observation, and next-turn preparation. The returned hooks
    /// share this registry's extension list. While they are alive, `register`
    /// fails with [`ExtensionError::RegistryLocked`].
    pub fn wrap_hooks(&self, base: Box<dyn AgentHooks>) -> Box<dyn AgentHooks> {
        Box::new(CompositeHooks {
            base: Arc::from(base),
            extensions: self.extensions.clone(),
        })
    }

    /// Returns a sink that passes each event to every extension's `on_event`
    /// (in registration order) before forwarding it to `base_sink`. The sink
    /// shares this registry's extension list, which locks further registration.
    pub fn wrap_event_sink(
        &self,
        base_sink: crate::event::AgentEventSink,
    ) -> crate::event::AgentEventSink {
        let extensions = self.extensions.clone();
        Box::new(move |event: AgentEvent| {
            for ext in extensions.iter() {
                ext.on_event(&event);
            }
            base_sink(event);
        })
    }
}
impl Default for ExtensionRegistry {
fn default() -> Self {
Self::new()
}
}
/// [`AgentHooks`] adapter, built by `ExtensionRegistry::wrap_hooks`, that
/// layers the registered extensions on top of a base hooks implementation.
struct CompositeHooks {
    /// The wrapped hooks; always invoked before any extension.
    base: Arc<dyn AgentHooks>,
    /// Extension list shared with the registry that created this adapter.
    extensions: Arc<Vec<Box<dyn Extension>>>,
}
impl AgentHooks for CompositeHooks {
    /// Delegates to the base hooks; extensions do not take part in conversion.
    fn convert_to_llm(
        &self,
        messages: &[AgentMessage],
    ) -> Result<Vec<opi_ai::message::Message>, AgentError> {
        self.base.convert_to_llm(messages)
    }

    /// Delegates to the base hooks; extensions do not transform the context.
    fn transform_context(
        &self,
        messages: Vec<AgentMessage>,
        signal: CancellationToken,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<AgentMessage>, AgentError>> + Send>> {
        self.base.transform_context(messages, signal)
    }

    /// Delegates to the base hooks; extensions have no say in stopping.
    fn should_stop_after_turn(
        &self,
        ctx: ShouldStopAfterTurnContext,
    ) -> Pin<Box<dyn Future<Output = bool> + Send>> {
        self.base.should_stop_after_turn(ctx)
    }

    /// Gates a tool call. The base hooks decide first, then each extension in
    /// registration order. The first denial short-circuits everything after it.
    fn before_tool_call(
        &self,
        ctx: BeforeToolCallContext,
    ) -> Pin<Box<dyn Future<Output = BeforeToolCallResult> + Send>> {
        // With no extensions the verdict is exactly the base verdict, so skip the
        // wrapper future and the clones of the tool name and arguments.
        if self.extensions.is_empty() {
            return self.base.before_tool_call(ctx);
        }
        let base = self.base.clone();
        let extensions = self.extensions.clone();
        // `ctx` is moved into the base hook, so keep owned copies for the extensions.
        let tool_name = ctx.tool_name.clone();
        let args = ctx.args.clone();
        Box::pin(async move {
            match base.before_tool_call(ctx).await {
                BeforeToolCallResult::Allow => {}
                BeforeToolCallResult::Deny { reason } => {
                    return BeforeToolCallResult::Deny { reason };
                }
            }
            for ext in extensions.iter() {
                match ext.on_before_tool_call(&tool_name, &args).await {
                    ExtensionHookResult::Continue => {}
                    ExtensionHookResult::Block { reason } => {
                        return BeforeToolCallResult::Deny { reason };
                    }
                }
            }
            BeforeToolCallResult::Allow
        })
    }

    /// Runs the base hook, then lets every extension observe the effective
    /// result: the base's replacement if it returned one, otherwise the original.
    /// The base hook's decision is returned unchanged.
    fn after_tool_call(
        &self,
        ctx: AfterToolCallContext,
    ) -> Pin<Box<dyn Future<Output = AfterToolCallResult> + Send>> {
        // Nobody to notify: return the base future directly and skip cloning the result.
        if self.extensions.is_empty() {
            return self.base.after_tool_call(ctx);
        }
        let base = self.base.clone();
        let extensions = self.extensions.clone();
        let tool_name = ctx.tool_name.clone();
        let result_snapshot = ctx.result.clone();
        Box::pin(async move {
            let base_result = base.after_tool_call(ctx).await;
            let effective: &ToolResult = match &base_result {
                AfterToolCallResult::Keep => &result_snapshot,
                AfterToolCallResult::Replace(r) => r,
            };
            for ext in extensions.iter() {
                ext.on_after_tool_call(&tool_name, effective).await;
            }
            base_result
        })
    }

    /// Merges next-turn updates: the base hook's extra messages first, then each
    /// extension's in registration order. Returns `None` when the merged list is
    /// empty, including when the base returned `Some` with no messages.
    fn prepare_next_turn(
        &self,
        ctx: PrepareNextTurnContext,
    ) -> Pin<Box<dyn Future<Output = Option<AgentLoopTurnUpdate>> + Send>> {
        let base = self.base.clone();
        let extensions = self.extensions.clone();
        // Extensions get their own copy of the context (the base hook consumes
        // `ctx`). Copy the message history only when some extension will read it.
        let extension_ctx = if extensions.is_empty() {
            None
        } else {
            Some(PrepareNextTurnContext {
                messages: ctx.messages.clone(),
                turn: ctx.turn,
            })
        };
        Box::pin(async move {
            let mut extra_messages = base
                .prepare_next_turn(ctx)
                .await
                .map(|update| update.extra_messages)
                .unwrap_or_default();
            if let Some(extension_ctx) = extension_ctx {
                for ext in extensions.iter() {
                    if let Some(update) = ext.prepare_next_turn(&extension_ctx).await {
                        extra_messages.extend(update.extra_messages);
                    }
                }
            }
            if extra_messages.is_empty() {
                None
            } else {
                Some(AgentLoopTurnUpdate { extra_messages })
            }
        })
    }
}