use anyhow::{bail, Context, Result};
use libloading::{Library, Symbol};
use oxi_agent::{AgentEvent, AgentTool, AgentToolResult};
use oxi_ai::Message;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// A capability an extension declares in its [`ExtensionManifest`].
///
/// Serialized in `snake_case` (e.g. `FileRead` <-> `"file_read"`), matching the `Display` impl.
/// NOTE(review): nothing in this file enforces these permissions; they are declarative only
/// as far as this module shows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExtensionPermission {
    /// Reading files.
    FileRead,
    /// Writing files.
    FileWrite,
    /// Running shell commands.
    Bash,
    /// Network access.
    Network,
}
/// Renders the permission with the same `snake_case` spelling used by its serde form.
impl fmt::Display for ExtensionPermission {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ExtensionPermission::FileRead => "file_read",
            ExtensionPermission::FileWrite => "file_write",
            ExtensionPermission::Bash => "bash",
            ExtensionPermission::Network => "network",
        };
        f.write_str(label)
    }
}
/// Static metadata describing an extension.
///
/// Only `name` and `version` are required when deserializing; the remaining fields fall back
/// to their `Default` (empty string, empty list, `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionManifest {
    /// Extension name.
    pub name: String,
    /// Version string (free-form; not parsed here).
    pub version: String,
    #[serde(default)]
    pub description: String,
    #[serde(default)]
    pub author: String,
    /// Declared capabilities; the builder API keeps this free of duplicates.
    #[serde(default)]
    pub permissions: Vec<ExtensionPermission>,
    /// Optional schema for the extension's configuration.
    /// NOTE(review): presumably a JSON Schema document; not validated in this file.
    #[serde(default)]
    pub config_schema: Option<Value>,
}
impl ExtensionManifest {
pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
Self {
name: name.into(),
version: version.into(),
description: String::new(),
author: String::new(),
permissions: Vec::new(),
config_schema: None,
}
}
pub fn with_description(mut self, desc: impl Into<String>) -> Self {
self.description = desc.into();
self
}
pub fn with_author(mut self, author: impl Into<String>) -> Self {
self.author = author.into();
self
}
pub fn with_permission(mut self, perm: ExtensionPermission) -> Self {
if !self.permissions.contains(&perm) {
self.permissions.push(perm);
}
self
}
pub fn with_config_schema(mut self, schema: Value) -> Self {
self.config_schema = Some(schema);
self
}
pub fn has_permission(&self, perm: ExtensionPermission) -> bool {
self.permissions.contains(&perm)
}
}
/// Errors produced by extension management (lookup, loading, enabling, hot-reload, config).
///
/// The `#[error(...)]` strings are the `Display` output of each variant.
#[derive(Debug, thiserror::Error)]
pub enum ExtensionError {
    /// No extension with this name is registered.
    #[error("Extension '{name}' not found")]
    NotFound { name: String },
    /// Loading the extension failed for the given reason.
    #[error("Failed to load extension '{name}': {reason}")]
    LoadFailed { name: String, reason: String },
    /// A named hook of the extension returned or raised an error.
    #[error("Extension '{name}' hook '{hook}' failed: {error}")]
    HookFailed {
        name: String,
        hook: String,
        error: String,
    },
    /// The extension needs a permission it was not granted.
    #[error("Extension '{name}' requires permission '{permission}'")]
    PermissionDenied {
        name: String,
        permission: ExtensionPermission,
    },
    /// The extension exists but is disabled.
    #[error("Extension '{name}' is disabled")]
    Disabled { name: String },
    /// Re-loading the extension from its source path failed.
    #[error("Hot-reload of extension '{name}' failed: {reason}")]
    HotReloadFailed { name: String, reason: String },
    /// The extension's configuration was rejected.
    #[error("Invalid configuration for extension '{name}': {reason}")]
    InvalidConfig { name: String, reason: String },
}
/// A recorded failure of an extension hook, kept for later inspection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionErrorRecord {
    /// Name of the extension that failed.
    pub extension_name: String,
    /// Hook or event name during which the failure happened (e.g. `"on_load"`).
    pub event: String,
    /// Error or panic message.
    pub error: String,
    /// Optional stack trace; `ExtensionErrorRecord::new` always leaves this `None`.
    #[serde(default)]
    pub stack: Option<String>,
    /// Creation time in milliseconds since the Unix epoch (UTC).
    pub timestamp: i64,
}
impl ExtensionErrorRecord {
pub fn new(
extension_name: impl Into<String>,
event: impl Into<String>,
error: impl Into<String>,
) -> Self {
Self {
extension_name: extension_name.into(),
event: event.into(),
error: error.into(),
stack: None,
timestamp: chrono::Utc::now().timestamp_millis(),
}
}
}
/// Why a session switch is about to happen (serialized as `snake_case`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionSwitchReason {
    /// Switching to a freshly created session.
    New,
    /// Switching to an existing session.
    Resume,
}
/// Why the current session is shutting down (serialized as `snake_case`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionShutdownReason {
    Quit,
    Reload,
    New,
    Resume,
    Fork,
}
/// How a model selection was triggered (serialized as `snake_case`).
/// NOTE(review): exact meaning of each variant is defined by the caller, not this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelSelectSource {
    Set,
    Cycle,
    Restore,
}
/// Origin of a piece of user input (serialized as `snake_case`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputSource {
    Interactive,
    Rpc,
    Extension,
}
/// Outcome of an [`Extension::input`] hook.
///
/// `ExtensionRegistry::emit_input` keeps the first non-`Continue` result it sees.
#[derive(Debug, Clone)]
pub enum InputEventResult {
    /// The extension does not act on this input.
    Continue,
    /// Presumably: replace the input with `text` — NOTE(review): confirm with the caller.
    Transform { text: String },
    /// Presumably: the extension consumed the input — NOTE(review): confirm with the caller.
    Handled,
}
/// Payload for [`Extension::session_before_switch`].
#[derive(Debug, Clone)]
pub struct SessionBeforeSwitchEvent {
    pub reason: SessionSwitchReason,
    /// Session file being switched to, if one is known.
    pub target_session_file: Option<String>,
}
/// Payload for [`Extension::session_before_fork`].
#[derive(Debug, Clone)]
pub struct SessionBeforeForkEvent {
    /// Identifier of the session entry the fork starts from.
    pub entry_id: String,
    /// NOTE(review): format/meaning of `position` is not defined in this file — confirm.
    pub position: String,
}
/// Payload for [`Extension::session_before_compact`].
#[derive(Debug, Clone)]
pub struct SessionBeforeCompactEvent {
    pub messages_count: usize,
    pub tokens_before: usize,
    pub target_tokens: usize,
    /// Extra compaction instructions, if any were provided.
    pub custom_instructions: Option<String>,
}
/// Payload for [`Extension::session_compact`].
#[derive(Debug, Clone)]
pub struct SessionCompactEvent {
    pub messages_count: usize,
    pub tokens_after: usize,
    /// Whether the compaction was requested by an extension.
    pub from_extension: bool,
}
/// Payload for [`Extension::session_shutdown`].
#[derive(Debug, Clone)]
pub struct SessionShutdownEvent {
    pub reason: SessionShutdownReason,
    /// Session file that will be opened next, if known.
    pub target_session_file: Option<String>,
}
/// Payload for [`Extension::session_before_tree`].
#[derive(Debug, Clone)]
pub struct SessionBeforeTreeEvent {
    /// Entry the session tree is about to move to.
    pub target_id: String,
    /// Current leaf before the move, if any.
    pub old_leaf_id: Option<String>,
}
/// Payload for [`Extension::session_tree`], sent after the session tree changed.
#[derive(Debug, Clone)]
pub struct SessionTreeEvent {
    pub new_leaf_id: Option<String>,
    pub old_leaf_id: Option<String>,
    /// Whether the change was requested by an extension.
    pub from_extension: bool,
}
/// Payload for [`Extension::context`]; hooks receive it mutably and may edit `messages`.
#[derive(Debug, Clone)]
pub struct ContextEvent {
    pub messages: Vec<Message>,
}
/// Payload for [`Extension::before_provider_request`]; hooks may edit `payload` in place.
#[derive(Debug, Clone)]
pub struct BeforeProviderRequestEvent {
    /// Request body as JSON. NOTE(review): exact schema depends on the provider.
    pub payload: Value,
}
/// Payload for [`Extension::after_provider_response`].
#[derive(Debug, Clone)]
pub struct AfterProviderResponseEvent {
    /// Response status code (presumably HTTP).
    pub status: u16,
    pub headers: HashMap<String, String>,
}
/// Payload for [`Extension::model_select`].
#[derive(Debug, Clone)]
pub struct ModelSelectEvent {
    pub model: String,
    /// Model active before the change, if any.
    pub previous_model: Option<String>,
    pub source: ModelSelectSource,
}
/// Payload for [`Extension::thinking_level_select`].
#[derive(Debug, Clone)]
pub struct ThinkingLevelSelectEvent {
    pub level: String,
    pub previous_level: String,
}
/// Payload for [`Extension::bash`].
#[derive(Debug, Clone)]
pub struct BashEvent {
    pub command: String,
    /// NOTE(review): presumably true when the command output is kept out of the model context.
    pub exclude_from_context: bool,
    /// Working directory for the command.
    pub cwd: PathBuf,
}
/// Payload for [`Extension::input`].
#[derive(Debug, Clone)]
pub struct InputEvent {
    pub text: String,
    pub source: InputSource,
}
/// Host-side handle given to extensions in `on_load`.
///
/// Most capabilities are injected as shared closures, so the host decides what each
/// operation actually does; unset closures are no-ops (see `ExtensionContext::new` and
/// `ExtensionContextBuilder::build`).
pub struct ExtensionContext {
    /// Working directory; `read_file` resolves paths against it.
    pub cwd: PathBuf,
    settings: Arc<RwLock<crate::settings::Settings>>,
    /// Extension configuration, queried with dotted paths via `config_get`.
    pub config: Value,
    pub session_id: Option<String>,
    // Shared idle flag read by `is_idle`; written elsewhere by the host.
    idle: Arc<RwLock<bool>>,
    tool_registrar: Arc<dyn Fn(Arc<dyn AgentTool>) + Send + Sync>,
    message_sender: Arc<dyn Fn(&str) + Send + Sync>,
    // Error log that may be shared with other contexts/the host.
    errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
    tool_getter: Arc<dyn Fn() -> Vec<Arc<dyn AgentTool>> + Send + Sync>,
    tool_setter: Arc<dyn Fn(Vec<Arc<dyn AgentTool>>) + Send + Sync>,
    model_setter: Arc<dyn Fn(&str) + Send + Sync>,
    thinking_level_setter: Arc<dyn Fn(&str) + Send + Sync>,
    system_prompt_appender: Arc<dyn Fn(&str) + Send + Sync>,
    session_name_setter: Arc<dyn Fn(&str) + Send + Sync>,
    session_entries_getter: Arc<dyn Fn() -> Vec<Value> + Send + Sync>,
    session_fork: Arc<dyn Fn(&str) -> Result<String> + Send + Sync>,
}
/// Shows only the plain data fields; the injected closures are not printable.
impl fmt::Debug for ExtensionContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Copy the flag out so the read lock is released before formatting.
        let idle: bool = *self.idle.read();
        let mut out = f.debug_struct("ExtensionContext");
        out.field("cwd", &self.cwd);
        out.field("session_id", &self.session_id);
        out.field("idle", &idle);
        out.finish()
    }
}
impl ExtensionContext {
    /// Creates a context wired to the given core callbacks.
    ///
    /// Every other host capability (tool get/set, model/thinking-level setters, system-prompt
    /// appender, session name/entries) is a no-op. `fork_session` returns an error.
    /// Prefer [`ExtensionContextBuilder`] to configure those.
    // Eight positional parameters exceed clippy's default limit. They are kept for backward
    // compatibility; the builder is the ergonomic alternative.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cwd: PathBuf,
        settings: Arc<RwLock<crate::settings::Settings>>,
        config: Value,
        session_id: Option<String>,
        idle: Arc<RwLock<bool>>,
        tool_registrar: Arc<dyn Fn(Arc<dyn AgentTool>) + Send + Sync>,
        message_sender: Arc<dyn Fn(&str) + Send + Sync>,
        errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
    ) -> Self {
        Self {
            cwd,
            settings,
            config,
            session_id,
            idle,
            tool_registrar,
            message_sender,
            errors,
            tool_getter: Arc::new(|| vec![]),
            tool_setter: Arc::new(|_| {}),
            model_setter: Arc::new(|_| {}),
            thinking_level_setter: Arc::new(|_| {}),
            system_prompt_appender: Arc::new(|_| {}),
            session_name_setter: Arc::new(|_| {}),
            session_entries_getter: Arc::new(|| vec![]),
            session_fork: Arc::new(|_| bail!("session fork not configured")),
        }
    }

    /// Returns a snapshot (clone) of the current settings.
    pub fn settings(&self) -> crate::settings::Settings {
        self.settings.read().clone()
    }

    /// Reads the shared idle flag.
    pub fn is_idle(&self) -> bool {
        *self.idle.read()
    }

    /// Hands `tool` to the host's tool registrar callback.
    pub fn register_tool(&self, tool: Arc<dyn AgentTool>) {
        (self.tool_registrar)(tool);
    }

    /// Hands `text` to the host's message sender callback.
    pub fn send_message(&self, text: &str) {
        (self.message_sender)(text);
    }

    /// Logs a warning and appends an [`ExtensionErrorRecord`] to the shared error list.
    pub fn record_error(&self, extension_name: &str, event: &str, error: &str) {
        let record = ExtensionErrorRecord::new(extension_name, event, error);
        tracing::warn!(
            extension = extension_name,
            event = event,
            error = error,
            "Extension error recorded"
        );
        self.errors.write().push(record);
    }

    /// Returns a copy of all recorded errors.
    pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
        self.errors.read().clone()
    }

    /// Removes all recorded errors from the shared list.
    pub fn clear_errors(&self) {
        self.errors.write().clear();
    }

    /// Looks up a value in `config` by a dot-separated path, e.g. `"server.port"`.
    ///
    /// Object segments are matched by key. When the current value is an array, the segment
    /// is parsed as a zero-based index (e.g. `"servers.0.url"`).
    /// Returns `None` if any segment is missing, is not a valid index, or traverses a scalar.
    pub fn config_get(&self, path: &str) -> Option<Value> {
        let mut current = &self.config;
        for key in path.split('.') {
            current = match current {
                Value::Object(map) => map.get(key)?,
                // Previously arrays always yielded `None`; indexing them is a pure extension.
                Value::Array(items) => items.get(key.parse::<usize>().ok()?)?,
                _ => return None,
            };
        }
        Some(current.clone())
    }

    /// Reads a UTF-8 file resolved against `cwd`.
    ///
    /// Resolution uses `Path::join`, so an absolute `relative_path` replaces `cwd` entirely
    /// and `..` components are not rejected.
    ///
    /// # Errors
    /// Returns the I/O error with the full path attached as context.
    pub fn read_file(&self, relative_path: &Path) -> Result<String> {
        let full_path = self.cwd.join(relative_path);
        std::fs::read_to_string(&full_path)
            .with_context(|| format!("Failed to read file: {}", full_path.display()))
    }

    /// Returns the host's current tool list (empty unless a getter was configured).
    pub fn get_tools(&self) -> Vec<Arc<dyn AgentTool>> {
        (self.tool_getter)()
    }

    /// Replaces the host's tool list (no-op unless a setter was configured).
    pub fn set_tools(&self, tools: Vec<Arc<dyn AgentTool>>) {
        (self.tool_setter)(tools);
    }

    /// Requests a model change (no-op unless configured).
    pub fn set_model(&self, model: &str) {
        (self.model_setter)(model);
    }

    /// Requests a thinking-level change (no-op unless configured).
    pub fn set_thinking_level(&self, level: &str) {
        (self.thinking_level_setter)(level);
    }

    /// Appends text to the system prompt (no-op unless configured).
    pub fn append_system_prompt(&self, text: &str) {
        (self.system_prompt_appender)(text);
    }

    /// Renames the current session (no-op unless configured).
    pub fn set_session_name(&self, name: &str) {
        (self.session_name_setter)(name);
    }

    /// Returns the session entries as JSON values (empty unless configured).
    pub fn get_session_entries(&self) -> Vec<Value> {
        (self.session_entries_getter)()
    }

    /// Forks the session at `entry_id` via the host callback.
    ///
    /// # Errors
    /// Fails with "session fork not configured" when no fork callback was supplied, or with
    /// whatever error the host callback returns.
    pub fn fork_session(&self, entry_id: &str) -> Result<String> {
        (self.session_fork)(entry_id)
    }
}
/// Builder for [`ExtensionContext`].
///
/// Every `Option` field left as `None` is replaced by a no-op (or error-returning) default
/// in `build`. Settings default to `Settings::default()` and errors to a fresh empty list.
pub struct ExtensionContextBuilder {
    cwd: PathBuf,
    settings: Option<Arc<RwLock<crate::settings::Settings>>>,
    config: Value,
    session_id: Option<String>,
    // Defaults to a new flag initialised to `true` (idle).
    idle: Arc<RwLock<bool>>,
    tool_registrar: Option<Arc<dyn Fn(Arc<dyn AgentTool>) + Send + Sync>>,
    message_sender: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    errors: Option<Arc<RwLock<Vec<ExtensionErrorRecord>>>>,
    tool_getter: Option<Arc<dyn Fn() -> Vec<Arc<dyn AgentTool>> + Send + Sync>>,
    tool_setter: Option<Arc<dyn Fn(Vec<Arc<dyn AgentTool>>) + Send + Sync>>,
    model_setter: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    thinking_level_setter: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    system_prompt_appender: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    session_name_setter: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    session_entries_getter: Option<Arc<dyn Fn() -> Vec<Value> + Send + Sync>>,
    session_fork: Option<Arc<dyn Fn(&str) -> Result<String> + Send + Sync>>,
}
impl ExtensionContextBuilder {
pub fn new(cwd: PathBuf) -> Self {
Self {
cwd,
settings: None,
config: Value::Null,
session_id: None,
idle: Arc::new(RwLock::new(true)),
tool_registrar: None,
message_sender: None,
errors: None,
tool_getter: None,
tool_setter: None,
model_setter: None,
thinking_level_setter: None,
system_prompt_appender: None,
session_name_setter: None,
session_entries_getter: None,
session_fork: None,
}
}
pub fn settings(mut self, settings: Arc<RwLock<crate::settings::Settings>>) -> Self {
self.settings = Some(settings);
self
}
pub fn config(mut self, config: Value) -> Self {
self.config = config;
self
}
pub fn session_id(mut self, id: impl Into<String>) -> Self {
self.session_id = Some(id.into());
self
}
pub fn idle(mut self, idle: Arc<RwLock<bool>>) -> Self {
self.idle = idle;
self
}
pub fn tool_registrar(
mut self,
registrar: Arc<dyn Fn(Arc<dyn AgentTool>) + Send + Sync>,
) -> Self {
self.tool_registrar = Some(registrar);
self
}
pub fn message_sender(mut self, sender: Arc<dyn Fn(&str) + Send + Sync>) -> Self {
self.message_sender = Some(sender);
self
}
pub fn errors(mut self, errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>) -> Self {
self.errors = Some(errors);
self
}
pub fn tool_getter(
mut self,
getter: Arc<dyn Fn() -> Vec<Arc<dyn AgentTool>> + Send + Sync>,
) -> Self {
self.tool_getter = Some(getter);
self
}
pub fn tool_setter(
mut self,
setter: Arc<dyn Fn(Vec<Arc<dyn AgentTool>>) + Send + Sync>,
) -> Self {
self.tool_setter = Some(setter);
self
}
pub fn model_setter(mut self, setter: Arc<dyn Fn(&str) + Send + Sync>) -> Self {
self.model_setter = Some(setter);
self
}
pub fn thinking_level_setter(mut self, setter: Arc<dyn Fn(&str) + Send + Sync>) -> Self {
self.thinking_level_setter = Some(setter);
self
}
pub fn system_prompt_appender(mut self, appender: Arc<dyn Fn(&str) + Send + Sync>) -> Self {
self.system_prompt_appender = Some(appender);
self
}
pub fn session_name_setter(mut self, setter: Arc<dyn Fn(&str) + Send + Sync>) -> Self {
self.session_name_setter = Some(setter);
self
}
pub fn session_entries_getter(
mut self,
getter: Arc<dyn Fn() -> Vec<Value> + Send + Sync>,
) -> Self {
self.session_entries_getter = Some(getter);
self
}
pub fn session_fork(
mut self,
fork: Arc<dyn Fn(&str) -> Result<String> + Send + Sync>,
) -> Self {
self.session_fork = Some(fork);
self
}
pub fn build(self) -> ExtensionContext {
ExtensionContext {
cwd: self.cwd,
settings: self
.settings
.unwrap_or_else(|| Arc::new(RwLock::new(crate::settings::Settings::default()))),
config: self.config,
session_id: self.session_id,
idle: self.idle,
tool_registrar: self.tool_registrar.unwrap_or_else(|| {
Arc::new(|_tool| {
tracing::debug!("Tool registration attempted with no registrar");
})
}),
message_sender: self.message_sender.unwrap_or_else(|| {
Arc::new(|_msg| {
tracing::debug!("Message send attempted with no sender");
})
}),
errors: self.errors.unwrap_or_default(),
tool_getter: self.tool_getter.unwrap_or_else(|| Arc::new(Vec::new)),
tool_setter: self.tool_setter.unwrap_or_else(|| Arc::new(|_| {})),
model_setter: self.model_setter.unwrap_or_else(|| Arc::new(|_| {})),
thinking_level_setter: self.thinking_level_setter.unwrap_or_else(|| Arc::new(|_| {})),
system_prompt_appender: self
.system_prompt_appender
.unwrap_or_else(|| Arc::new(|_| {})),
session_name_setter: self.session_name_setter.unwrap_or_else(|| Arc::new(|_| {})),
session_entries_getter: self
.session_entries_getter
.unwrap_or_else(|| Arc::new(Vec::new)),
session_fork: self.session_fork.unwrap_or_else(|| {
Arc::new(|_| bail!("session fork not configured"))
}),
}
}
}
/// A slash-style command contributed by an extension.
#[derive(Debug, Clone)]
pub struct Command {
    pub name: String,
    pub description: String,
    pub usage: String,
}

impl Command {
    /// Builds a command from anything convertible into `String`.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        usage: impl Into<String>,
    ) -> Self {
        let name: String = name.into();
        let description: String = description.into();
        let usage: String = usage.into();
        Command {
            name,
            description,
            usage,
        }
    }
}
/// A plugin that contributes tools/commands and observes or intercepts agent events.
///
/// Every hook has a default no-op implementation, so implementors override only what they
/// need. `ExtensionRegistry` calls hooks only on enabled extensions. Panics in the `()`-returning
/// hooks it dispatches are caught and recorded rather than propagated. Hooks returning
/// `Result` have their errors collected by the registry as `(extension_name, error)` pairs.
pub trait Extension: Send + Sync {
    /// Unique name; the registry keys extensions by this value.
    fn name(&self) -> &str;
    fn description(&self) -> &str;
    /// Default manifest: `name`, version `"0.0.0"`, and `description`, with no permissions.
    fn manifest(&self) -> ExtensionManifest {
        ExtensionManifest::new(self.name(), "0.0.0").with_description(self.description())
    }
    /// Tools this extension provides; called each time `ExtensionRegistry::all_tools` runs.
    fn register_tools(&self) -> Vec<Arc<dyn AgentTool>> {
        vec![]
    }
    /// Commands this extension provides; called by `ExtensionRegistry::all_commands`.
    fn register_commands(&self) -> Vec<Command> {
        vec![]
    }

    // --- Lifecycle and observation hooks (no return value) ---
    fn on_load(&self, _ctx: &ExtensionContext) {}
    fn on_unload(&self) {}
    fn on_message_sent(&self, _msg: &str) {}
    fn on_message_received(&self, _msg: &str) {}
    fn on_tool_call(&self, _tool: &str, _params: &Value) {}
    fn on_tool_result(&self, _tool: &str, _result: &AgentToolResult) {}
    fn on_session_start(&self, _session_id: &str) {}
    fn on_session_end(&self, _session_id: &str) {}
    fn on_settings_changed(&self, _settings: &crate::settings::Settings) {}
    fn on_event(&self, _event: &AgentEvent) {}

    // --- Fallible hooks: an `Err` is reported back to the emitter ---
    fn on_before_tool_call(&self, _tool: &str, _args: &Value) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn on_after_tool_call(&self, _tool: &str, _result: &AgentToolResult) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn on_before_compaction(&self, _ctx: &crate::CompactionContext) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn on_after_compaction(&self, _summary: &str) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn on_error(&self, _error: &anyhow::Error) -> Result<(), anyhow::Error> {
        Ok(())
    }

    // --- Session hooks ---
    fn session_before_switch(&self, _event: &SessionBeforeSwitchEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn session_before_fork(&self, _event: &SessionBeforeForkEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn session_before_compact(&self, _event: &SessionBeforeCompactEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn session_compact(&self, _event: &SessionCompactEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn session_shutdown(&self, _event: &SessionShutdownEvent) {}
    fn session_before_tree(&self, _event: &SessionBeforeTreeEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn session_tree(&self, _event: &SessionTreeEvent) {}

    // --- Request pipeline hooks; the `&mut` events may be edited in place ---
    fn context(&self, _event: &mut ContextEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn before_provider_request(&self, _event: &mut BeforeProviderRequestEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }
    fn after_provider_response(&self, _event: &AfterProviderResponseEvent) -> Result<(), anyhow::Error> {
        Ok(())
    }

    // --- Notifications ---
    fn model_select(&self, _event: &ModelSelectEvent) {}
    fn thinking_level_select(&self, _event: &ThinkingLevelSelectEvent) {}
    fn bash(&self, _event: &BashEvent) {}
    /// Inspects user input; the registry keeps the first non-`Continue` result.
    fn input(&self, _event: &InputEvent) -> InputEventResult {
        InputEventResult::Continue
    }
}
/// Registry bookkeeping for one extension.
struct LoadedExtension {
    extension: Arc<dyn Extension>,
    /// Disabled extensions stay registered but receive no hooks and contribute no tools.
    enabled: bool,
    /// Library path for dynamically loaded extensions; `None` for in-memory ones, which
    /// therefore cannot be hot-reloaded.
    source_path: Option<PathBuf>,
}
/// Owns registered extensions, dispatches hooks to them, and records hook failures.
pub struct ExtensionRegistry {
    // NOTE: field order matters. Struct fields drop in declaration order, so `entries`
    // (whose trait objects may point into dynamically loaded code) is dropped before
    // `libraries` unloads that code. Do not reorder.
    entries: HashMap<String, LoadedExtension>,
    /// Panics caught in hooks are appended here.
    errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
    // Never read: held only to keep dynamic libraries loaded. Libraries are never removed,
    // even on unregister/hot-reload.
    #[allow(dead_code)]
    libraries: Vec<Library>,
}
impl Default for ExtensionRegistry {
fn default() -> Self {
Self::new()
}
}
impl ExtensionRegistry {
pub fn new() -> Self {
Self {
entries: HashMap::new(),
errors: Arc::new(RwLock::new(Vec::new())),
libraries: Vec::new(),
}
}
pub fn register(&mut self, ext: Arc<dyn Extension>) {
let name = ext.name().to_string();
tracing::info!(name = %name, "extension registered");
self.entries.insert(
name,
LoadedExtension {
extension: ext,
enabled: true,
source_path: None,
},
);
}
pub fn register_with_library(
&mut self,
ext: Arc<dyn Extension>,
source_path: PathBuf,
library: Library,
) {
let name = ext.name().to_string();
tracing::info!(name = %name, path = %source_path.display(), "extension registered (dynamic)");
self.libraries.push(library);
self.entries.insert(
name,
LoadedExtension {
extension: ext,
enabled: true,
source_path: Some(source_path),
},
);
}
pub fn unregister(&mut self, name: &str) -> bool {
if let Some(entry) = self.entries.remove(name) {
self.call_hook_safe(name, "on_unload", || {
entry.extension.on_unload();
});
tracing::info!(name = %name, "extension unregistered");
true
} else {
false
}
}
pub fn disable(&mut self, name: &str) -> Result<(), ExtensionError> {
let ext = {
let entry = self
.entries
.get_mut(name)
.ok_or_else(|| ExtensionError::NotFound {
name: name.to_string(),
})?;
if !entry.enabled {
return Ok(());
}
entry.enabled = false;
Arc::clone(&entry.extension)
};
self.call_hook_safe(name, "on_unload", || {
ext.on_unload();
});
tracing::info!(name = %name, "extension disabled");
Ok(())
}
pub fn enable(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
let ext = {
let entry = self
.entries
.get_mut(name)
.ok_or_else(|| ExtensionError::NotFound {
name: name.to_string(),
})?;
if entry.enabled {
return Ok(());
}
entry.enabled = true;
Arc::clone(&entry.extension)
};
self.call_hook_safe(name, "on_load", || {
ext.on_load(ctx);
});
tracing::info!(name = %name, "extension enabled");
Ok(())
}
pub fn is_enabled(&self, name: &str) -> bool {
self.entries.get(name).map(|e| e.enabled).unwrap_or(false)
}
pub fn hot_reload(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
let source_path = {
let entry = self
.entries
.get(name)
.ok_or_else(|| ExtensionError::NotFound {
name: name.to_string(),
})?;
entry.source_path.clone()
};
let source_path = source_path.ok_or_else(|| ExtensionError::HotReloadFailed {
name: name.to_string(),
reason: "no source path recorded (in-memory extension)".to_string(),
})?;
self.unregister(name);
let new_ext =
load_extension(&source_path).map_err(|e| ExtensionError::HotReloadFailed {
name: name.to_string(),
reason: e.to_string(),
})?;
let library = unsafe {
Library::new(&source_path).map_err(|e| ExtensionError::HotReloadFailed {
name: name.to_string(),
reason: format!("Failed to re-open library: {}", e),
})?
};
self.call_hook_safe(name, "on_load", || {
new_ext.on_load(ctx);
});
self.register_with_library(new_ext, source_path, library);
tracing::info!(name = %name, "extension hot-reloaded");
Ok(())
}
pub fn all_tools(&self) -> Vec<Arc<dyn AgentTool>> {
self.entries
.values()
.filter(|e| e.enabled)
.flat_map(|e| e.extension.register_tools())
.collect()
}
pub fn all_commands(&self) -> Vec<Command> {
self.entries
.values()
.filter(|e| e.enabled)
.flat_map(|e| e.extension.register_commands())
.collect()
}
pub fn emit_load(&self, ctx: &ExtensionContext) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_load", || {
entry.extension.on_load(ctx);
});
}
}
pub fn emit_unload(&self) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_unload", || {
entry.extension.on_unload();
});
}
}
pub fn emit_message_sent(&self, msg: &str) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_message_sent", || {
entry.extension.on_message_sent(msg);
});
}
}
pub fn emit_message_received(&self, msg: &str) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_message_received", || {
entry.extension.on_message_received(msg);
});
}
}
pub fn emit_tool_call(&self, tool: &str, params: &Value) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_tool_call", || {
entry.extension.on_tool_call(tool, params);
});
}
}
pub fn emit_tool_result(&self, tool: &str, result: &AgentToolResult) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_tool_result", || {
entry.extension.on_tool_result(tool, result);
});
}
}
pub fn emit_session_start(&self, session_id: &str) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_session_start", || {
entry.extension.on_session_start(session_id);
});
}
}
pub fn emit_session_end(&self, session_id: &str) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_session_end", || {
entry.extension.on_session_end(session_id);
});
}
}
pub fn emit_settings_changed(&self, settings: &crate::settings::Settings) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_settings_changed", || {
entry.extension.on_settings_changed(settings);
});
}
}
pub fn emit_event(&self, event: &AgentEvent) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "on_event", || {
entry.extension.on_event(event);
});
}
}
pub fn emit_before_tool_call(
&self,
tool: &str,
args: &Value,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.on_before_tool_call(tool, args) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, tool = tool, error = %e, "on_before_tool_call failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_after_tool_call(
&self,
tool: &str,
result: &AgentToolResult,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.on_after_tool_call(tool, result) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, tool = tool, error = %e, "on_after_tool_call failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_before_compaction(
&self,
ctx: &crate::CompactionContext,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.on_before_compaction(ctx) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "on_before_compaction failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_after_compaction(
&self,
summary: &str,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.on_after_compaction(summary) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "on_after_compaction failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_error(
&self,
error: &anyhow::Error,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.on_error(error) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "on_error hook failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_session_before_switch(
&self,
event: &SessionBeforeSwitchEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.session_before_switch(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "session_before_switch failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_session_before_fork(
&self,
event: &SessionBeforeForkEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.session_before_fork(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "session_before_fork failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_session_before_compact(
&self,
event: &SessionBeforeCompactEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.session_before_compact(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "session_before_compact failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_session_compact(
&self,
event: &SessionCompactEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.session_compact(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "session_compact failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_session_shutdown(&self, event: &SessionShutdownEvent) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "session_shutdown", || {
entry.extension.session_shutdown(event);
});
}
}
pub fn emit_session_before_tree(
&self,
event: &SessionBeforeTreeEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.session_before_tree(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "session_before_tree failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_session_tree(&self, event: &SessionTreeEvent) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "session_tree", || {
entry.extension.session_tree(event);
});
}
}
pub fn emit_context(
&self,
event: &mut ContextEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.context(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "context hook failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_before_provider_request(
&self,
event: &mut BeforeProviderRequestEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.before_provider_request(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "before_provider_request failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_after_provider_response(
&self,
event: &AfterProviderResponseEvent,
) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
match entry.extension.after_provider_response(event) {
Ok(()) => {}
Err(e) => {
tracing::warn!(extension = name, error = %e, "after_provider_response failed");
errors.push((name.to_string(), e));
}
}
}
errors
}
pub fn emit_model_select(&self, event: &ModelSelectEvent) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "model_select", || {
entry.extension.model_select(event);
});
}
}
pub fn emit_thinking_level_select(&self, event: &ThinkingLevelSelectEvent) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "thinking_level_select", || {
entry.extension.thinking_level_select(event);
});
}
}
pub fn emit_bash(&self, event: &BashEvent) {
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "bash", || {
entry.extension.bash(event);
});
}
}
pub fn emit_input(&self, event: &InputEvent) -> InputEventResult {
let mut final_result = InputEventResult::Continue;
for entry in self.entries.values().filter(|e| e.enabled) {
let name = entry.extension.name();
self.call_hook_safe(name, "input", || {
let result = entry.extension.input(event);
if matches!(final_result, InputEventResult::Continue) {
final_result = result;
}
});
}
final_result
}
pub fn get(&self, name: &str) -> Option<Arc<dyn Extension>> {
self.entries.get(name).map(|e| Arc::clone(&e.extension))
}
pub fn names(&self) -> impl Iterator<Item = &str> {
self.entries.keys().map(|s| s.as_str())
}
pub fn extensions(&self) -> impl Iterator<Item = &Arc<dyn Extension>> {
self.entries.values().map(|e| &e.extension)
}
pub fn manifest(&self, name: &str) -> Option<ExtensionManifest> {
self.entries.get(name).map(|e| e.extension.manifest())
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
self.errors.read().clone()
}
pub fn clear_errors(&self) {
self.errors.write().clear();
}
fn call_hook_safe<F>(&self, ext_name: &str, hook: &str, f: F)
where
F: FnOnce(),
{
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
if let Err(payload) = result {
let msg = if let Some(s) = payload.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = payload.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic".to_string()
};
tracing::error!(
extension = ext_name,
hook = hook,
error = %msg,
"Extension hook panicked — graceful degradation"
);
self.errors.write().push(ExtensionErrorRecord::new(
ext_name,
hook,
&format!("panic: {}", msg),
));
}
}
}
/// Summarises the registry as its extension count and names; libraries and errors are omitted.
impl fmt::Debug for ExtensionRegistry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.entries.len();
        let names: Vec<String> = self.entries.keys().map(String::clone).collect();
        let mut out = f.debug_struct("ExtensionRegistry");
        out.field("count", &count);
        out.field("names", &names);
        out.finish()
    }
}
/// Lifecycle state of an extension (serialized and displayed as `snake_case`).
/// NOTE(review): transitions between states are managed outside this chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExtensionState {
    Pending,
    Active,
    Disabled,
    Failed,
    Unloaded,
}
/// Renders the state using the same lowercase spelling as its serde form.
impl fmt::Display for ExtensionState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ExtensionState::Pending => "pending",
            ExtensionState::Active => "active",
            ExtensionState::Disabled => "disabled",
            ExtensionState::Failed => "failed",
            ExtensionState::Unloaded => "unloaded",
        };
        f.write_str(label)
    }
}
/// Aggregated outcome of dispatching a pending tool call to the active
/// extensions.
///
/// `Default` yields "not blocked, no reason, no errors"; it is derived because
/// the previous hand-written impl produced exactly the derived value.
#[derive(Debug, Default)]
pub struct ToolCallEmitResult {
    /// Whether the call should be blocked.
    /// NOTE(review): `ExtensionRunner::emit_tool_call` never sets this.
    pub blocked: bool,
    /// Human-readable reason accompanying `blocked`, if any.
    pub block_reason: Option<String>,
    /// `(extension name, error message)` for each hook that returned an error.
    pub errors: Vec<(String, String)>,
}
/// Aggregated outcome of reporting a finished tool call to the active
/// extensions.
///
/// `Default` (derived; identical to the former manual impl) leaves both
/// overrides unset and the error list empty.
#[derive(Debug, Default)]
pub struct ToolResultEmitResult {
    /// Replacement tool output, if an extension supplied one.
    /// NOTE(review): not set by `ExtensionRunner::emit_tool_result_event`.
    pub output: Option<String>,
    /// Replacement success flag, if an extension supplied one.
    pub success: Option<bool>,
    /// `(extension name, error message)` for each hook that returned an error.
    pub errors: Vec<(String, String)>,
}
/// Result of `ExtensionRunner::emit_context_event`.
#[derive(Debug)]
pub struct ContextEmitResult {
/// Whether the emitter judged that a hook changed the message list (the exact
/// check lives in `emit_context_event`).
pub modified: bool,
/// Message list as threaded through the context hooks.
pub messages: Vec<Message>,
/// `(extension name, error message)` for each hook that returned an error.
pub errors: Vec<(String, String)>,
}
/// Result of `ExtensionRunner::emit_before_provider_request_event`.
#[derive(Debug)]
pub struct ProviderRequestEmitResult {
/// `true` when at least one hook produced a payload different from its input.
pub modified: bool,
/// Payload after every successful hook has run.
pub payload: Value,
/// `(extension name, error message)` for each hook that returned an error.
pub errors: Vec<(String, String)>,
}
/// Outcome of a cancellable "session before …" dispatch.
///
/// `Default` (derived; identical to the former manual impl) means "not
/// cancelled, no canceller, no errors".
#[derive(Debug, Default)]
pub struct SessionBeforeEmitResult {
    /// `true` when an extension's hook returned an error, which vetoes the action.
    pub cancelled: bool,
    /// Name of the extension that vetoed, when `cancelled` is set.
    pub cancelled_by: Option<String>,
    /// `(extension name, error message)` for the vetoing hook.
    pub errors: Vec<(String, String)>,
}
/// Callback type for error listeners registered through `ExtensionRunner::on_error`.
pub type ExtensionErrorListener = dyn Fn(&ExtensionErrorRecord) + Send + Sync;
/// Loads extensions, tracks their lifecycle state and dispatches events to
/// them in load order.
pub struct ExtensionRunner {
/// Underlying registry; also keeps the `Library` for extensions loaded from disk.
registry: ExtensionRegistry,
/// Last lifecycle state recorded for each extension name.
states: HashMap<String, ExtensionState>,
/// Dispatch order: a name is appended on its first activation and removed on unload.
order: Vec<String>,
/// Callbacks invoked by `broadcast_error`.
error_listeners: Vec<Arc<ExtensionErrorListener>>,
/// Working directory the runner was created with.
cwd: PathBuf,
/// `(path, reason)` for library loads that failed after the path/format checks.
load_errors: Vec<(PathBuf, String)>,
}
impl Default for ExtensionRunner {
    /// Builds a runner whose working directory is the relative path `"."`.
    fn default() -> Self {
        let cwd: PathBuf = ".".into();
        Self::new(cwd)
    }
}
impl ExtensionRunner {
/// Creates an empty runner rooted at `cwd`: no extensions, states,
/// listeners or load errors.
pub fn new(cwd: PathBuf) -> Self {
    let registry = ExtensionRegistry::new();
    Self {
        cwd,
        registry,
        states: HashMap::new(),
        order: Vec::new(),
        error_listeners: Vec::new(),
        load_errors: Vec::new(),
    }
}
/// Registers `listener` to be called with every error record this runner
/// broadcasts.
///
/// The returned handle can be passed to [`Self::remove_error_listener`] to
/// detach the listener again. Dropping the handle alone does not detach it.
pub fn on_error<F>(&mut self, listener: F) -> ExtensionErrorHandle
where
    F: Fn(&ExtensionErrorRecord) + Send + Sync + 'static,
{
    let arc: Arc<ExtensionErrorListener> = Arc::new(listener);
    self.error_listeners.push(Arc::clone(&arc));
    ExtensionErrorHandle { listener: Some(arc) }
}
/// Removes the listener referenced by `handle` from this runner.
///
/// Listeners are matched by pointer identity, so only the exact listener
/// registered by [`Self::on_error`] is removed.
///
/// Returns `true` when it was found and removed; the handle is then emptied.
/// Returns `false`, leaving the handle untouched, when the handle is already
/// empty or its listener is not registered on this runner.
pub fn remove_error_listener(&mut self, handle: &mut ExtensionErrorHandle) -> bool {
    let removed = match handle.listener.as_ref() {
        Some(target) => {
            let before = self.error_listeners.len();
            self.error_listeners
                .retain(|registered| !Arc::ptr_eq(registered, target));
            self.error_listeners.len() != before
        }
        None => false,
    };
    if removed {
        handle.listener = None;
    }
    removed
}
fn broadcast_error(&self, record: &ExtensionErrorRecord) {
for listener in &self.error_listeners {
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
listener(record);
}));
}
}
/// Broadcasts `record` to the error listeners first, then appends it to the
/// registry's shared error log.
pub fn emit_error_record(&self, record: ExtensionErrorRecord) {
    self.broadcast_error(&record);
    let mut log = self.registry.errors.write();
    log.push(record);
}
/// Loads a shared-library extension from `path` and registers it.
///
/// On success the extension is registered together with its `Library`,
/// marked `Active`, and its `on_load` hook is run. A panic in `on_load` is
/// caught and recorded by the registry.
///
/// # Errors
///
/// Returns [`ExtensionError::LoadFailed`] when:
/// - `path` does not exist,
/// - its file extension is not `so`, `dylib` or `dll`,
/// - the library cannot be opened,
/// - it does not export `oxi_extension_create`, or
/// - that constructor returns a null pointer.
///
/// The last three failures are also appended to [`Self::load_errors`] and
/// broadcast as a `"load"` error record.
pub fn load_extension(
    &mut self,
    path: &Path,
    ctx: &ExtensionContext,
) -> Result<(), ExtensionError> {
    let path_display = path.display().to_string();
    if !path.exists() {
        return Err(ExtensionError::LoadFailed {
            name: path_display,
            reason: "File not found".to_string(),
        });
    }
    let ext_os = path.extension().and_then(OsStr::to_str).unwrap_or("");
    if !matches!(ext_os, "so" | "dylib" | "dll") {
        return Err(ExtensionError::LoadFailed {
            name: path_display,
            reason: format!("Unsupported extension file format: .{}", ext_os),
        });
    }
    // SAFETY: opening a library runs its initialisers. Only files the user put
    // in an extension directory or configured explicitly reach this point,
    // and they are trusted accordingly.
    let library = match unsafe { Library::new(path) } {
        Ok(lib) => lib,
        Err(e) => {
            let reason = format!("Failed to load library: {}", e);
            return Err(self.record_load_failure(path, path_display, reason));
        }
    };
    let raw_ptr = {
        // SAFETY: `CreateFn` must match the real signature of the exported
        // symbol. That is the plugin ABI contract and cannot be checked here.
        let create: Symbol<CreateFn> = match unsafe { library.get(ENTRY_SYMBOL) } {
            Ok(sym) => sym,
            Err(e) => {
                let reason = format!("Symbol not found: {}", e);
                return Err(self.record_load_failure(path, path_display, reason));
            }
        };
        // SAFETY: `library` is still loaded, so the constructor's code is mapped.
        unsafe { create() }
    };
    if raw_ptr.is_null() {
        // Reported exactly like the other post-validation failures,
        // including the error-record broadcast.
        let reason = "oxi_extension_create returned null".to_string();
        return Err(self.record_load_failure(path, path_display, reason));
    }
    // SAFETY: non-null pointer handed over by the plugin constructor. The
    // plugin contract (assumed) is that it comes from `Box::into_raw` and
    // ownership is transferred exactly once.
    let boxed: Box<dyn Extension> = unsafe { Box::from_raw(raw_ptr) };
    let ext: Arc<dyn Extension> = Arc::from(boxed);
    let name = ext.name().to_string();
    // `library` is handed to the registry together with the extension, whose
    // code lives inside it.
    self.registry
        .register_with_library(ext, path.to_path_buf(), library);
    self.set_state(&name, ExtensionState::Active);
    self.registry.call_hook_safe(&name, "on_load", || {
        if let Some(e) = self.registry.get(&name) {
            e.on_load(ctx);
        }
    });
    tracing::info!(name = %name, path = %path_display, "extension loaded");
    Ok(())
}
/// Records a post-validation load failure and builds the error to return.
///
/// The failure is appended to `load_errors`, broadcast as a `"load"` error
/// record, and returned as the matching `LoadFailed` error.
fn record_load_failure(
    &mut self,
    path: &Path,
    path_display: String,
    reason: String,
) -> ExtensionError {
    self.load_errors.push((path.to_path_buf(), reason.clone()));
    self.emit_error_record(ExtensionErrorRecord::new(
        &path_display,
        "load",
        &reason,
    ));
    ExtensionError::LoadFailed {
        name: path_display,
        reason,
    }
}
/// Loads every path in `paths`, in order, via [`Self::load_extension`].
///
/// Failures do not stop the remaining loads. Returns one error (rendered from
/// the `ExtensionError` message) per failed path.
pub fn load_extensions_from_paths(
    &mut self,
    paths: &[PathBuf],
    ctx: &ExtensionContext,
) -> Vec<anyhow::Error> {
    paths
        .iter()
        .filter_map(|path| self.load_extension(path, ctx).err())
        .map(|err| anyhow::anyhow!("{}", err))
        .collect()
}
/// Unregisters `name` from the registry.
///
/// When it was registered, it is marked `Unloaded` (which drops it from the
/// dispatch order) and logged. Returns whether it was registered.
pub fn unload_extension(&mut self, name: &str) -> bool {
    if !self.registry.unregister(name) {
        return false;
    }
    self.set_state(name, ExtensionState::Unloaded);
    tracing::info!(name = %name, "extension unloaded");
    true
}
/// Hot-reloads `name` through the registry, then marks it `Active` again.
///
/// # Errors
///
/// Propagates the registry's hot-reload error; the recorded state is left
/// unchanged in that case.
pub fn reload_extension(
    &mut self,
    name: &str,
    ctx: &ExtensionContext,
) -> Result<(), ExtensionError> {
    self.registry.hot_reload(name, ctx)?;
    self.set_state(name, ExtensionState::Active);
    tracing::info!(name = %name, "extension reloaded");
    Ok(())
}
/// Records `state` for `name` and keeps the dispatch order in sync.
///
/// - The first transition to `Active` appends the name.
/// - `Unloaded` removes it.
/// - All other states leave the order untouched.
fn set_state(&mut self, name: &str, state: ExtensionState) {
    self.states.insert(name.to_string(), state);
    match state {
        ExtensionState::Active => {
            // Compare borrowed strings; no temporary `String` needed.
            if !self.order.iter().any(|n| n == name) {
                self.order.push(name.to_string());
            }
        }
        ExtensionState::Unloaded => self.order.retain(|n| n != name),
        ExtensionState::Pending | ExtensionState::Disabled | ExtensionState::Failed => {}
    }
}
/// Returns the recorded state for `name`; unknown names report `Unloaded`.
pub fn state(&self, name: &str) -> ExtensionState {
    self.states
        .get(name)
        .map_or(ExtensionState::Unloaded, |recorded| *recorded)
}
/// Every recorded `(name, state)` pair.
pub fn states(&self) -> &HashMap<String, ExtensionState> {
    &self.states
}
/// Extension names in dispatch (load) order.
pub fn extension_order(&self) -> &[String] {
    self.order.as_slice()
}
/// `(path, reason)` for every library load that failed after validation.
pub fn load_errors(&self) -> &[(PathBuf, String)] {
    self.load_errors.as_slice()
}
/// Disables `name` in the registry and records it as `Disabled`.
///
/// # Errors
///
/// Propagates the registry error (e.g. unknown name); state is untouched then.
pub fn disable(&mut self, name: &str) -> Result<(), ExtensionError> {
    self.registry.disable(name)?;
    self.set_state(name, ExtensionState::Disabled);
    Ok(())
}
/// Re-enables `name` in the registry with `ctx` and records it as `Active`.
///
/// # Errors
///
/// Propagates the registry error; state is untouched then.
pub fn enable(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
    self.registry.enable(name, ctx)?;
    self.set_state(name, ExtensionState::Active);
    Ok(())
}
/// Whether the registry currently has `name` enabled.
pub fn is_enabled(&self, name: &str) -> bool {
    let registry = &self.registry;
    registry.is_enabled(name)
}
/// Reports whether any extension could handle an event.
///
/// `_event_type` is currently ignored; this is exactly
/// [`Self::has_enabled_extensions`].
pub fn has_handlers(&self, _event_type: &str) -> bool {
    self.has_enabled_extensions()
}
/// `true` when the registry is non-empty and at least one extension in the
/// dispatch order is recorded as `Active`.
pub fn has_enabled_extensions(&self) -> bool {
    // `!is_empty()` states the intent of the former `extensions().any(|_| true)`.
    !self.registry.is_empty()
        && self
            .order
            .iter()
            .any(|name| self.state(name) == ExtensionState::Active)
}
pub fn all_tools(&self) -> Vec<Arc<dyn AgentTool>> {
let mut tools = Vec::new();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
tools.extend(ext.register_tools());
}
}
tools
}
pub fn all_commands(&self) -> Vec<Command> {
let mut commands = Vec::new();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
commands.extend(ext.register_commands());
}
}
commands
}
/// Wraps `tool` in a `WrappedTool`.
///
/// The wrapper snapshots this runner's shared error log and its current error
/// listeners; every tool call is forwarded to `tool` unchanged.
pub fn wrap_tool(&self, tool: Arc<dyn AgentTool>) -> Arc<dyn AgentTool> {
    let snapshot = RunnerState {
        errors: Arc::clone(&self.registry.errors),
        error_listeners: self.error_listeners.clone(),
    };
    Arc::new(WrappedTool {
        inner: tool,
        runner_state: Arc::new(RwLock::new(snapshot)),
    })
}
/// Applies [`Self::wrap_tool`] to each tool, preserving order.
pub fn wrap_tools(&self, tools: Vec<Arc<dyn AgentTool>>) -> Vec<Arc<dyn AgentTool>> {
    let mut wrapped = Vec::with_capacity(tools.len());
    wrapped.extend(tools.into_iter().map(|tool| self.wrap_tool(tool)));
    wrapped
}
/// Dispatches a pending tool call to every `Active` extension, in load order.
///
/// For each extension:
/// 1. `on_before_tool_call` runs first. An error from it is logged, collected
///    into the result and broadcast as an error record; it does not stop
///    dispatch.
/// 2. `on_tool_call` then runs under panic protection.
///
/// NOTE(review): `blocked` / `block_reason` are never set here, so callers
/// always see an unblocked call. Confirm whether hook errors should block.
pub fn emit_tool_call(
    &self,
    tool_name: &str,
    params: &Value,
) -> ToolCallEmitResult {
    let mut outcome = ToolCallEmitResult::default();
    let active = self
        .order
        .iter()
        .filter(|n| self.state(n) == ExtensionState::Active);
    for ext_name in active {
        let ext = match self.registry.get(ext_name) {
            Some(ext) => ext,
            None => continue,
        };
        if let Err(err) = ext.on_before_tool_call(tool_name, params) {
            let message = err.to_string();
            tracing::warn!(
                extension = ext_name,
                tool = tool_name,
                error = %message,
                "on_before_tool_call failed"
            );
            outcome.errors.push((ext_name.clone(), message.clone()));
            self.emit_error_record(ExtensionErrorRecord::new(
                ext_name,
                "on_before_tool_call",
                &message,
            ));
        }
        self.registry.call_hook_safe(ext_name, "on_tool_call", || {
            ext.on_tool_call(tool_name, params);
        });
    }
    outcome
}
/// Reports a finished tool call to every `Active` extension, in load order.
///
/// - `on_after_tool_call` errors are logged and collected into the result;
///   they are not broadcast as error records.
/// - `on_tool_result` runs under panic protection.
/// - `output` / `success` on the result are never set here.
pub fn emit_tool_result_event(
    &self,
    tool_name: &str,
    tool_result: &AgentToolResult,
) -> ToolResultEmitResult {
    let mut outcome = ToolResultEmitResult::default();
    let active = self
        .order
        .iter()
        .filter(|n| self.state(n) == ExtensionState::Active);
    for ext_name in active {
        let ext = match self.registry.get(ext_name) {
            Some(ext) => ext,
            None => continue,
        };
        if let Err(err) = ext.on_after_tool_call(tool_name, tool_result) {
            let message = err.to_string();
            tracing::warn!(
                extension = ext_name,
                tool = tool_name,
                error = %message,
                "on_after_tool_call failed"
            );
            outcome.errors.push((ext_name.clone(), message));
        }
        self.registry.call_hook_safe(ext_name, "on_tool_result", || {
            ext.on_tool_result(tool_name, tool_result);
        });
    }
    outcome
}
/// Offers user input to every `Active` extension, in load order.
///
/// - `Handled` stops dispatch immediately and is returned.
/// - `Transform` rewrites `event.text` for the following extensions and
///   becomes the pending result.
/// - `Continue` leaves everything as is.
/// - A panicking hook is caught, logged and broadcast as a `panic: …` error
///   record; dispatch then continues.
pub fn emit_input_event(&self, event: &mut InputEvent) -> InputEventResult {
    let mut outcome = InputEventResult::Continue;
    let active = self
        .order
        .iter()
        .filter(|n| self.state(n) == ExtensionState::Active);
    for ext_name in active {
        let ext = match self.registry.get(ext_name) {
            Some(ext) => ext,
            None => continue,
        };
        let caught =
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| ext.input(event)));
        match caught {
            Ok(InputEventResult::Handled) => return InputEventResult::Handled,
            Ok(InputEventResult::Transform { text }) => {
                event.text = text.clone();
                outcome = InputEventResult::Transform { text };
            }
            Ok(InputEventResult::Continue) => {}
            Err(payload) => {
                let msg = payload
                    .downcast_ref::<&str>()
                    .map(|s| s.to_string())
                    .or_else(|| payload.downcast_ref::<String>().cloned())
                    .unwrap_or_else(|| "unknown panic".to_string());
                tracing::error!(
                    extension = ext_name,
                    error = %msg,
                    "Extension input hook panicked"
                );
                self.emit_error_record(ExtensionErrorRecord::new(
                    ext_name,
                    "input",
                    &format!("panic: {}", msg),
                ));
            }
        }
    }
    outcome
}
/// Runs every `Active` extension's `context` hook over `messages`, in load
/// order, feeding each hook's output into the next.
///
/// Each hook receives a copy of the current list:
/// - On `Ok`, its (possibly edited) list becomes the current one. This
///   includes same-length rewrites such as redacting a message's content.
/// - On `Err`, its edits are discarded and the error is logged and collected.
///
/// `modified` is set when a successful hook changed the number of messages.
///
/// NOTE(review): `Message` is not assumed to implement `PartialEq`, so
/// same-length rewrites are adopted into `messages` but do not set
/// `modified`. Callers should use `messages` unconditionally. If `Message` is
/// comparable, switch the check to a full equality test.
pub fn emit_context_event(
    &self,
    messages: Vec<Message>,
) -> ContextEmitResult {
    let mut current_messages = messages;
    let mut errors = Vec::new();
    let mut modified = false;
    for name in &self.order {
        if self.state(name) != ExtensionState::Active {
            continue;
        }
        if let Some(ext) = self.registry.get(name) {
            // Hand over a copy so a failing hook cannot leave half-applied edits.
            let mut event = ContextEvent {
                messages: current_messages.clone(),
            };
            match ext.context(&mut event) {
                Ok(()) => {
                    if event.messages.len() != current_messages.len() {
                        modified = true;
                    }
                    // Adopt unconditionally: an equal length does not mean
                    // the hook left the content untouched.
                    current_messages = event.messages;
                }
                Err(e) => {
                    let err_str = e.to_string();
                    tracing::warn!(
                        extension = name,
                        error = %err_str,
                        "context hook failed"
                    );
                    errors.push((name.clone(), err_str));
                }
            }
        }
    }
    ContextEmitResult {
        modified,
        messages: current_messages,
        errors,
    }
}
/// Lets every `Active` extension rewrite the outgoing provider payload, in
/// load order, each one seeing the previous one's output.
///
/// - A hook's payload is adopted only when it returns `Ok` and the payload
///   differs from its input; that also sets `modified`.
/// - Hook errors are logged and collected, and that hook's edits are
///   discarded.
pub fn emit_before_provider_request_event(
    &self,
    payload: Value,
) -> ProviderRequestEmitResult {
    let mut result = ProviderRequestEmitResult {
        modified: false,
        payload,
        errors: Vec::new(),
    };
    let active = self
        .order
        .iter()
        .filter(|n| self.state(n) == ExtensionState::Active);
    for ext_name in active {
        let ext = match self.registry.get(ext_name) {
            Some(ext) => ext,
            None => continue,
        };
        let mut event = BeforeProviderRequestEvent {
            payload: result.payload.clone(),
        };
        match ext.before_provider_request(&mut event) {
            Ok(()) => {
                if event.payload != result.payload {
                    result.payload = event.payload;
                    result.modified = true;
                }
            }
            Err(err) => {
                let message = err.to_string();
                tracing::warn!(
                    extension = ext_name,
                    error = %message,
                    "before_provider_request failed"
                );
                result.errors.push((ext_name.clone(), message));
            }
        }
    }
    result
}
pub fn emit_session_before_switch_event(
&self,
event: &SessionBeforeSwitchEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
match ext.session_before_switch(event) {
Ok(()) => {}
Err(e) => {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
}
result
}
pub fn emit_session_before_fork_event(
&self,
event: &SessionBeforeForkEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
match ext.session_before_fork(event) {
Ok(()) => {}
Err(e) => {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
}
result
}
pub fn emit_session_before_compact_event(
&self,
event: &SessionBeforeCompactEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
match ext.session_before_compact(event) {
Ok(()) => {}
Err(e) => {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
}
result
}
pub fn emit_session_before_tree_event(
&self,
event: &SessionBeforeTreeEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
match ext.session_before_tree(event) {
Ok(()) => {}
Err(e) => {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
}
result
}
/// Forwards a session shutdown to the registry, but only when at least one
/// extension is enabled. Returns whether the event was forwarded.
pub fn emit_session_shutdown_event(&self, event: &SessionShutdownEvent) -> bool {
    let forward = self.has_enabled_extensions();
    if forward {
        self.registry.emit_session_shutdown(event);
    }
    forward
}
/// Forwards `event` to the registry's broadcast.
pub fn emit_event(&self, event: &AgentEvent) {
    let registry = &self.registry;
    registry.emit_event(event);
}
/// Shared access to the underlying registry.
pub fn registry(&self) -> &ExtensionRegistry {
    &self.registry
}
/// Exclusive access to the underlying registry. Changes made here bypass the
/// runner's state tracking.
pub fn registry_mut(&mut self) -> &mut ExtensionRegistry {
    &mut self.registry
}
/// Looks up a registered extension by name.
pub fn get(&self, name: &str) -> Option<Arc<dyn Extension>> {
    self.registry.get(name)
}
/// Names in dispatch order.
pub fn names(&self) -> impl Iterator<Item = &str> {
    self.order.iter().map(String::as_str)
}
/// Number of names in the dispatch order.
pub fn len(&self) -> usize {
    self.order.len()
}
/// `true` when the dispatch order is empty.
pub fn is_empty(&self) -> bool {
    self.order.is_empty()
}
/// Snapshot of the registry's recorded errors.
pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
    self.registry.errors()
}
/// Clears the registry's recorded errors.
pub fn clear_errors(&self) {
    self.registry.clear_errors();
}
}
impl fmt::Debug for ExtensionRunner {
    /// Shows the working directory, dispatch order and recorded states.
    /// The registry, listeners and load errors are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ExtensionRunner");
        out.field("cwd", &self.cwd);
        out.field("extensions", &self.order);
        out.field("states", &self.states);
        out.finish()
    }
}
/// Handle returned by `ExtensionRunner::on_error`, holding a clone of the
/// registered listener.
///
/// NOTE: the runner keeps its own clone of the listener. Neither dropping
/// this handle nor calling [`Self::unregister`] removes the listener from the
/// runner; it keeps being invoked. (The former empty `Drop` impl suggested
/// otherwise and has been removed.)
pub struct ExtensionErrorHandle {
    listener: Option<Arc<ExtensionErrorListener>>,
}
impl ExtensionErrorHandle {
    /// Takes the listener out of this handle.
    ///
    /// Returns it on the first call and `None` afterwards. This only empties
    /// the handle; it does not detach the listener from the runner.
    pub fn unregister(&mut self) -> Option<Arc<ExtensionErrorListener>> {
        self.listener.take()
    }
}
/// Runner error sinks captured when a tool is wrapped.
// Nothing reads these fields yet (`WrappedTool::execute` only delegates),
// hence the `dead_code` allowance.
#[allow(dead_code)]
struct RunnerState {
    /// Shared error log (the registry's `Arc`, cloned in `wrap_tool`).
    errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
    /// Error listeners registered on the runner at wrap time.
    error_listeners: Vec<Arc<ExtensionErrorListener>>,
}
/// Tool adapter returned by `ExtensionRunner::wrap_tool`; every `AgentTool`
/// method forwards to `inner` unchanged.
struct WrappedTool {
    inner: Arc<dyn AgentTool>,
    // Stored but not read by any method yet.
    #[allow(dead_code)]
    runner_state: Arc<RwLock<RunnerState>>,
}
#[async_trait::async_trait]
impl AgentTool for WrappedTool {
    fn name(&self) -> &str {
        self.inner.name()
    }
    fn label(&self) -> &str {
        self.inner.label()
    }
    fn description(&self) -> &str {
        self.inner.description()
    }
    fn parameters_schema(&self) -> Value {
        self.inner.parameters_schema()
    }
    /// Delegates execution to the wrapped tool and returns its result as is.
    async fn execute(
        &self,
        tool_call_id: &str,
        params: Value,
        signal: Option<tokio::sync::oneshot::Receiver<()>>,
    ) -> Result<AgentToolResult, String> {
        self.inner.execute(tool_call_id, params, signal).await
    }
}
/// Shared-library file extensions recognised by directory discovery on the
/// current target: `dylib` on macOS, `dll` on Windows, `so` elsewhere.
#[cfg(target_os = "macos")]
const SHARED_LIB_EXTENSIONS: &[&str] = &["dylib"];
#[cfg(target_os = "windows")]
const SHARED_LIB_EXTENSIONS: &[&str] = &["dll"];
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
const SHARED_LIB_EXTENSIONS: &[&str] = &["so"];
/// Returns `true` when `name` ends in `.<ext>` for one of
/// [`SHARED_LIB_EXTENSIONS`]. The comparison is case-sensitive.
fn is_shared_library(name: &str) -> bool {
    match name.rsplit_once('.') {
        Some((_, suffix)) => SHARED_LIB_EXTENSIONS.contains(&suffix),
        None => false,
    }
}
/// Scans the immediate children of `dir` for extension libraries.
///
/// - A non-directory entry whose file name ends in one of
///   `SHARED_LIB_EXTENSIONS` is taken directly.
/// - A subdirectory contributes `index.<ext>`, trying the extensions in
///   `SHARED_LIB_EXTENSIONS` order and taking the first that exists.
///
/// A missing or unreadable `dir` yields an empty list; unreadable entries
/// are skipped. The result is sorted because `read_dir` order is
/// platform-dependent, and load order decides hook dispatch order.
pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        // Also covers a `dir` that does not exist.
        Err(_) => return Vec::new(),
    };
    let mut discovered = Vec::new();
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            let index = SHARED_LIB_EXTENSIONS
                .iter()
                .map(|ext| path.join(format!("index.{}", ext)))
                .find(|candidate| candidate.exists());
            if let Some(index_path) = index {
                discovered.push(index_path);
            }
        } else {
            let file_name = path
                .file_name()
                .and_then(|n| n.to_str())
                .unwrap_or("");
            if is_shared_library(file_name) {
                discovered.push(path);
            }
        }
    }
    discovered.sort();
    discovered
}
/// Collects extension library paths, in priority order, from:
/// 1. `<cwd>/.oxi/extensions`,
/// 2. `~/.oxi/extensions` (when a home directory is known),
/// 3. each configured path, with relative ones resolved against `cwd`.
///    Directories are scanned; existing files are taken as is.
///
/// A path is kept only the first time it is seen. Paths are compared by
/// canonical form (falling back to the literal path when canonicalization
/// fails). This way the same file reached through different spellings is
/// returned once, e.g. a relative `cwd` that is the home directory. The
/// returned paths are the original, non-canonicalized ones.
pub fn discover_extensions(
    cwd: &Path,
    configured_paths: &[PathBuf],
) -> Vec<PathBuf> {
    let mut all_paths = Vec::new();
    let mut seen: std::collections::HashSet<PathBuf> = std::collections::HashSet::new();
    // Exact set membership on canonical paths; replaces the former u64 hash
    // digests, which could collide and silently drop a distinct path.
    let mut add_unique = |path: PathBuf| {
        let key = std::fs::canonicalize(&path).unwrap_or_else(|_| path.clone());
        if seen.insert(key) {
            all_paths.push(path);
        }
    };
    for path in discover_extensions_in_dir(&cwd.join(".oxi").join("extensions")) {
        add_unique(path);
    }
    if let Some(home) = dirs::home_dir() {
        for path in discover_extensions_in_dir(&home.join(".oxi").join("extensions")) {
            add_unique(path);
        }
    }
    for p in configured_paths {
        let resolved = if p.is_absolute() {
            p.clone()
        } else {
            cwd.join(p)
        };
        if resolved.is_dir() {
            for path in discover_extensions_in_dir(&resolved) {
                add_unique(path);
            }
        } else if resolved.exists() {
            add_unique(resolved);
        }
    }
    all_paths
}
/// NUL-terminated name of the constructor every extension library must export.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";
/// Signature expected for the exported `oxi_extension_create` constructor.
///
/// It returns an owned extension (reclaimed by the loaders with
/// `Box::from_raw`) or null on failure.
///
/// NOTE(review): this is a Rust-ABI function returning a fat `dyn` pointer.
/// Host and plugin must therefore agree on compiler and trait layout.
type CreateFn = unsafe fn() -> *mut dyn Extension;
/// Loads one extension library from `path` without registering it anywhere.
///
/// The underlying library is deliberately never unloaded, because the
/// returned extension's code and vtable live inside it.
///
/// # Errors
///
/// Fails when `path`:
/// - has an unsupported file extension,
/// - does not exist,
/// - cannot be opened as a library,
/// - lacks `oxi_extension_create`, or
/// - has a constructor that returns null.
pub fn load_extension(path: &Path) -> Result<Arc<dyn Extension>> {
    load_extension_inner(path)
}
/// Implementation of [`load_extension`].
fn load_extension_inner(path: &Path) -> Result<Arc<dyn Extension>> {
    let ext = path.extension().and_then(OsStr::to_str).unwrap_or("");
    if !matches!(ext, "so" | "dylib" | "dll") {
        bail!(
            "Unsupported extension file format: .{}. Expected .so, .dylib, or .dll",
            ext
        );
    }
    if !path.exists() {
        bail!("Extension file not found: {}", path.display());
    }
    // SAFETY: opening a library runs its initialisers. Only user-provided
    // extension files reach this point and they are trusted.
    let library = unsafe { Library::new(path) }
        .with_context(|| format!("Failed to load library: {}", path.display()))?;
    let raw_ptr = {
        // SAFETY: `CreateFn` must match the exported symbol's real signature;
        // that is the plugin ABI contract.
        let create: Symbol<CreateFn> = unsafe { library.get(ENTRY_SYMBOL) }.with_context(|| {
            format!(
                "Symbol `oxi_extension_create` not found in {}",
                path.display()
            )
        })?;
        // SAFETY: `library` is loaded, so the constructor's code is mapped.
        unsafe { create() }
    };
    if raw_ptr.is_null() {
        bail!("oxi_extension_create returned null in {}", path.display());
    }
    // SAFETY: non-null pointer from the plugin constructor. The plugin
    // contract (assumed) is `Box::into_raw` with ownership passed to us once.
    let boxed: Box<dyn Extension> = unsafe { Box::from_raw(raw_ptr) };
    // The extension's methods and vtable live inside `library`. Letting the
    // `Library` drop here may unload it while the returned `Arc` still points
    // into it, turning every later call into a use-after-unload. Nothing on
    // this path can own both, so keep the library mapped for the rest of the
    // process.
    std::mem::forget(library);
    Ok(Arc::from(boxed))
}
/// Loads each path with [`load_extension`], in order, and returns the
/// successes and the failures separately.
///
/// Each failure gains a `Failed to load extension: <path>` context layer.
pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
    let mut loaded = Vec::with_capacity(paths.len());
    let mut errors = Vec::new();
    for path in paths.iter().copied() {
        match load_extension(path) {
            Err(err) => {
                let context = format!("Failed to load extension: {}", path.display());
                errors.push(err.context(context));
            }
            Ok(ext) => loaded.push(ext),
        }
    }
    (loaded, errors)
}
/// Built-in extension that supplies only a name and a description and keeps
/// the `Extension` trait's defaults for every hook.
pub struct NoopExtension;
impl Extension for NoopExtension {
    fn description(&self) -> &str {
        "Built-in no-op extension"
    }
    fn name(&self) -> &str {
        "noop"
    }
}
/// Test double that logs a short description of every hook invocation.
#[cfg(test)]
pub struct RecordingExtension {
    /// Identity reported through `Extension::name`.
    pub name: String,
    /// Ordered log of hook calls, e.g. `"on_tool_call(bash)"`.
    pub calls: std::sync::Mutex<Vec<String>>,
}
#[cfg(test)]
impl RecordingExtension {
    /// Creates a recorder named `name` with an empty call log.
    pub fn new(name: impl Into<String>) -> Self {
        let calls = std::sync::Mutex::new(Vec::new());
        Self {
            name: name.into(),
            calls,
        }
    }
    /// Appends `call` verbatim to the log.
    pub fn push(&self, call: &str) {
        let mut log = self.calls.lock().unwrap();
        log.push(call.to_owned());
    }
    /// Returns a copy of the log.
    pub fn calls(&self) -> Vec<String> {
        let log = self.calls.lock().unwrap();
        log.to_vec()
    }
    /// Appends `hook(arg)` to the log.
    fn push_with_arg(&self, hook: &str, arg: &str) {
        self.push(&format!("{}({})", hook, arg));
    }
}
#[cfg(test)]
impl Extension for RecordingExtension {
    fn name(&self) -> &str {
        self.name.as_str()
    }
    fn description(&self) -> &str {
        "recording test extension"
    }
    fn on_load(&self, _ctx: &ExtensionContext) {
        self.push("on_load");
    }
    fn on_unload(&self) {
        self.push("on_unload");
    }
    fn on_message_sent(&self, msg: &str) {
        self.push_with_arg("on_message_sent", msg);
    }
    fn on_message_received(&self, msg: &str) {
        self.push_with_arg("on_message_received", msg);
    }
    fn on_tool_call(&self, tool: &str, _params: &Value) {
        self.push_with_arg("on_tool_call", tool);
    }
    fn on_tool_result(&self, tool: &str, _result: &AgentToolResult) {
        self.push_with_arg("on_tool_result", tool);
    }
    fn on_session_start(&self, session_id: &str) {
        self.push_with_arg("on_session_start", session_id);
    }
    fn on_session_end(&self, session_id: &str) {
        self.push_with_arg("on_session_end", session_id);
    }
    fn on_settings_changed(&self, _settings: &crate::settings::Settings) {
        self.push("on_settings_changed");
    }
    fn on_event(&self, _event: &AgentEvent) {
        self.push("on_event");
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::settings::Settings;
#[test]
fn test_manifest_builder() {
let manifest = ExtensionManifest::new("my-ext", "1.0.0")
.with_description("A test extension")
.with_author("test-author")
.with_permission(ExtensionPermission::FileRead)
.with_permission(ExtensionPermission::Bash)
.with_config_schema(serde_json::json!({
"type": "object",
"properties": {
"api_key": { "type": "string" }
}
}));
assert_eq!(manifest.name, "my-ext");
assert_eq!(manifest.version, "1.0.0");
assert_eq!(manifest.description, "A test extension");
assert_eq!(manifest.author, "test-author");
assert!(manifest.has_permission(ExtensionPermission::FileRead));
assert!(manifest.has_permission(ExtensionPermission::Bash));
assert!(!manifest.has_permission(ExtensionPermission::Network));
assert!(manifest.config_schema.is_some());
}
#[test]
fn test_manifest_serialization() {
let manifest =
ExtensionManifest::new("test", "0.1.0").with_permission(ExtensionPermission::Network);
let json = serde_json::to_string(&manifest).unwrap();
let parsed: ExtensionManifest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.name, "test");
assert_eq!(parsed.version, "0.1.0");
assert!(parsed.has_permission(ExtensionPermission::Network));
}
#[test]
fn test_permission_display() {
assert_eq!(ExtensionPermission::FileRead.to_string(), "file_read");
assert_eq!(ExtensionPermission::FileWrite.to_string(), "file_write");
assert_eq!(ExtensionPermission::Bash.to_string(), "bash");
assert_eq!(ExtensionPermission::Network.to_string(), "network");
}
#[test]
fn test_extension_error_display() {
let err = ExtensionError::NotFound {
name: "test".to_string(),
};
assert!(err.to_string().contains("test"));
assert!(err.to_string().contains("not found"));
let err = ExtensionError::LoadFailed {
name: "bad".to_string(),
reason: "missing symbol".to_string(),
};
assert!(err.to_string().contains("bad"));
assert!(err.to_string().contains("missing symbol"));
let err = ExtensionError::HookFailed {
name: "ext".to_string(),
hook: "on_load".to_string(),
error: "boom".to_string(),
};
assert!(err.to_string().contains("on_load"));
let err = ExtensionError::PermissionDenied {
name: "ext".to_string(),
permission: ExtensionPermission::Network,
};
assert!(err.to_string().contains("network"));
let err = ExtensionError::Disabled {
name: "ext".to_string(),
};
assert!(err.to_string().contains("disabled"));
let err = ExtensionError::HotReloadFailed {
name: "ext".to_string(),
reason: "no path".to_string(),
};
assert!(err.to_string().contains("Hot-reload"));
}
#[test]
fn test_error_record() {
let record = ExtensionErrorRecord::new("my-ext", "on_load", "something broke");
assert_eq!(record.extension_name, "my-ext");
assert_eq!(record.event, "on_load");
assert_eq!(record.error, "something broke");
assert!(record.timestamp > 0);
}
#[test]
fn test_error_record_serialization() {
let record = ExtensionErrorRecord::new("ext", "hook", "err");
let json = serde_json::to_string(&record).unwrap();
let parsed: ExtensionErrorRecord = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.extension_name, "ext");
assert_eq!(parsed.event, "hook");
}
#[test]
fn test_context_builder_minimal() {
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
assert_eq!(ctx.cwd, PathBuf::from("/tmp"));
assert!(ctx.session_id.is_none());
assert!(ctx.is_idle());
}
#[test]
fn test_context_builder_full() {
let settings = Arc::new(RwLock::new(Settings::default()));
let errors = Arc::new(RwLock::new(Vec::new()));
let tools_registered = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
let messages_sent = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
let tools_ref = tools_registered.clone();
let msgs_ref = messages_sent.clone();
let ctx = ExtensionContextBuilder::new(PathBuf::from("/home"))
.settings(settings)
.config(serde_json::json!({"key": "value"}))
.session_id("sess-123")
.errors(errors)
.tool_registrar(Arc::new(move |tool: Arc<dyn AgentTool>| {
tools_ref.lock().unwrap().push(tool.name().to_string());
}))
.message_sender(Arc::new(move |msg: &str| {
msgs_ref.lock().unwrap().push(msg.to_string());
}))
.build();
assert_eq!(ctx.cwd, PathBuf::from("/home"));
assert_eq!(ctx.session_id, Some("sess-123".to_string()));
assert!(ctx.is_idle());
assert_eq!(ctx.config_get("key"), Some(serde_json::json!("value")));
assert_eq!(ctx.config_get("missing"), None);
}
#[test]
fn test_context_config_nested() {
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp"))
.config(serde_json::json!({
"database": {
"host": "localhost",
"port": 5432
}
}))
.build();
assert_eq!(
ctx.config_get("database.host"),
Some(serde_json::json!("localhost"))
);
assert_eq!(
ctx.config_get("database.port"),
Some(serde_json::json!(5432))
);
assert_eq!(ctx.config_get("database.missing"), None);
}
#[test]
fn test_context_tool_registration() {
let registered = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
let reg_ref = registered.clone();
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp"))
.tool_registrar(Arc::new(move |tool: Arc<dyn AgentTool>| {
reg_ref.lock().unwrap().push(tool.name().to_string());
}))
.build();
ctx.register_tool(Arc::new(oxi_agent::ReadTool::new()));
assert_eq!(registered.lock().unwrap()[0], "read");
}
#[test]
fn test_context_message_sending() {
let sent = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
let sent_ref = sent.clone();
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp"))
.message_sender(Arc::new(move |msg: &str| {
sent_ref.lock().unwrap().push(msg.to_string());
}))
.build();
ctx.send_message("hello");
ctx.send_message("world");
assert_eq!(*sent.lock().unwrap(), vec!["hello", "world"]);
}
#[test]
fn test_context_error_recording() {
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
assert!(ctx.errors().is_empty());
ctx.record_error("ext1", "on_load", "fail");
ctx.record_error("ext2", "on_tool_call", "oops");
let errs = ctx.errors();
assert_eq!(errs.len(), 2);
assert_eq!(errs[0].extension_name, "ext1");
assert_eq!(errs[1].extension_name, "ext2");
ctx.clear_errors();
assert!(ctx.errors().is_empty());
}
#[test]
fn test_context_settings() {
let settings = Arc::new(RwLock::new(Settings::default()));
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp"))
.settings(settings.clone())
.build();
let s = ctx.settings();
assert_eq!(s.version, Settings::default().version);
}
#[test]
fn test_context_noop_callbacks() {
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
ctx.register_tool(Arc::new(oxi_agent::ReadTool::new()));
ctx.send_message("test");
}
#[test]
fn test_registry_register_and_collect() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
assert_eq!(reg.len(), 1);
assert!(!reg.is_empty());
assert!(reg.all_tools().is_empty());
assert!(reg.all_commands().is_empty());
}
#[test]
fn test_registry_names() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
let names: Vec<&str> = reg.names().collect();
assert_eq!(names, vec!["noop"]);
}
#[test]
fn test_registry_get() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
assert!(reg.get("noop").is_some());
assert!(reg.get("nonexistent").is_none());
}
#[test]
fn test_registry_manifest() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
let m = reg.manifest("noop").unwrap();
assert_eq!(m.name, "noop");
assert!(reg.manifest("missing").is_none());
}
#[test]
fn test_registry_unregister() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
assert_eq!(reg.len(), 1);
assert!(reg.unregister("noop"));
assert!(reg.is_empty());
assert!(!reg.unregister("noop")); }
#[test]
fn test_registry_enable_disable() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext);
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
assert!(reg.is_enabled("rec"));
reg.disable("rec").unwrap();
assert!(!reg.is_enabled("rec"));
assert!(reg.all_tools().is_empty());
reg.enable("rec", &ctx).unwrap();
assert!(reg.is_enabled("rec"));
}
#[test]
fn test_registry_disable_not_found() {
let mut reg = ExtensionRegistry::new();
let result = reg.disable("nonexistent");
assert!(result.is_err());
match result {
Err(ExtensionError::NotFound { name }) => assert_eq!(name, "nonexistent"),
_ => panic!("Expected NotFound error"),
}
}
#[test]
fn test_registry_enable_not_found() {
let mut reg = ExtensionRegistry::new();
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
let result = reg.enable("nonexistent", &ctx);
assert!(result.is_err());
}
#[test]
fn test_registry_disable_already_disabled() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
reg.disable("noop").unwrap();
reg.disable("noop").unwrap();
assert!(!reg.is_enabled("noop"));
}
#[test]
fn test_registry_enable_already_enabled() {
let mut reg = ExtensionRegistry::new();
reg.register(Arc::new(NoopExtension));
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
reg.enable("noop", &ctx).unwrap();
assert!(reg.is_enabled("noop"));
}
#[test]
fn test_emit_load() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
reg.emit_load(&ctx);
assert_eq!(ext.calls(), vec!["on_load"]);
}
#[test]
fn test_emit_unload() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
reg.emit_unload();
assert_eq!(ext.calls(), vec!["on_unload"]);
}
#[test]
fn test_emit_message_sent() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
reg.emit_message_sent("hello");
assert_eq!(ext.calls(), vec!["on_message_sent(hello)"]);
}
#[test]
fn test_emit_message_received() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
reg.emit_message_received("world");
assert_eq!(ext.calls(), vec!["on_message_received(world)"]);
}
#[test]
fn test_emit_tool_call() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
reg.emit_tool_call("bash", &serde_json::json!({"command": "ls"}));
assert_eq!(ext.calls(), vec!["on_tool_call(bash)"]);
}
#[test]
fn test_emit_tool_result() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
let result = AgentToolResult::success("done");
reg.emit_tool_result("bash", &result);
assert_eq!(ext.calls(), vec!["on_tool_result(bash)"]);
}
#[test]
fn test_emit_session_start() {
let mut reg = ExtensionRegistry::new();
let ext = Arc::new(RecordingExtension::new("rec"));
reg.register(ext.clone());
reg.emit_session_start("sess-1");
assert_eq!(ext.calls(), vec!["on_session_start(sess-1)"]);
}
/// `emit_session_end` forwards the session id to `on_session_end`.
#[test]
fn test_emit_session_end() {
    let mut registry = ExtensionRegistry::new();
    let recorder = Arc::new(RecordingExtension::new("rec"));
    registry.register(recorder.clone());

    registry.emit_session_end("sess-1");

    assert_eq!(recorder.calls(), vec!["on_session_end(sess-1)"]);
}
/// `emit_settings_changed` invokes `on_settings_changed` once.
#[test]
fn test_emit_settings_changed() {
    let mut registry = ExtensionRegistry::new();
    let recorder = Arc::new(RecordingExtension::new("rec"));
    registry.register(recorder.clone());

    let defaults = Settings::default();
    registry.emit_settings_changed(&defaults);

    assert_eq!(recorder.calls(), vec!["on_settings_changed"]);
}
/// `emit_event` forwards an agent event to `on_event`.
#[test]
fn test_emit_event() {
    let mut registry = ExtensionRegistry::new();
    let recorder = Arc::new(RecordingExtension::new("rec"));
    registry.register(recorder.clone());

    let event = AgentEvent::Thinking;
    registry.emit_event(&event);

    assert_eq!(recorder.calls(), vec!["on_event"]);
}
/// A disabled extension must not observe any of the broadcast hooks.
#[test]
fn test_disabled_extension_skips_broadcasts() {
    let mut registry = ExtensionRegistry::new();
    let recorder = Arc::new(RecordingExtension::new("rec"));
    registry.register(recorder.clone());
    registry.disable("rec").unwrap();

    // Forget anything recorded up to this point so that only the
    // broadcasts issued while disabled are checked below.
    recorder.calls.lock().unwrap().clear();

    registry.emit_message_sent("hello");
    registry.emit_event(&AgentEvent::Thinking);
    registry.emit_session_start("s1");

    assert!(recorder.calls().is_empty());
}
/// Panics inside extension hooks are caught by the registry and recorded as
/// error entries (one per failing hook, in call order) instead of propagating.
#[test]
fn test_graceful_degradation_on_panic() {
    /// Test double whose `on_load` and `on_message_sent` hooks always panic.
    struct Panicker;

    impl Extension for Panicker {
        fn name(&self) -> &str {
            "panicker"
        }

        fn description(&self) -> &str {
            "Panics"
        }

        fn on_load(&self, _ctx: &ExtensionContext) {
            panic!("intentional panic in on_load");
        }

        fn on_message_sent(&self, _msg: &str) {
            panic!("intentional panic in on_message_sent");
        }
    }

    let mut registry = ExtensionRegistry::new();
    registry.register(Arc::new(Panicker));

    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
    registry.emit_load(&context);
    registry.emit_message_sent("hello");

    let recorded = registry.errors();
    assert_eq!(recorded.len(), 2);
    let (first, second) = (&recorded[0], &recorded[1]);
    assert_eq!(first.event, "on_load");
    assert!(first.error.contains("intentional panic"));
    assert_eq!(second.event, "on_message_sent");
}
/// `Command::new` stores name, description and usage verbatim.
#[test]
fn test_command_new() {
    let deploy = Command::new("deploy", "Deploy the project", "/deploy <target>");

    assert_eq!(deploy.name, "deploy");
    assert_eq!(deploy.description, "Deploy the project");
    assert_eq!(deploy.usage, "/deploy <target>");
}
/// Loading a path that does not exist must fail.
#[test]
fn test_load_extension_missing_file() {
    let missing = Path::new("/nonexistent/extension.so");
    assert!(load_extension(missing).is_err());
}
/// A file whose extension is not a native library must be rejected with an
/// "Unsupported extension file format" error.
#[test]
fn test_load_extension_wrong_extension() {
    let outcome = load_extension(Path::new("something.txt"));
    assert!(outcome.is_err());

    let message = if let Err(err) = outcome {
        err.to_string()
    } else {
        panic!("Expected error")
    };
    assert!(message.contains("Unsupported extension file format"));
}
/// Batch loading reports one error per failing path and loads nothing.
#[test]
fn test_load_extensions_collects_errors() {
    let candidates: Vec<&Path> = vec![Path::new("/nonexistent1.so"), Path::new("/nonexistent2.so")];

    let (succeeded, failures) = load_extensions(&candidates);

    assert!(succeeded.is_empty());
    assert_eq!(failures.len(), 2);
}
/// The registry's `Debug` output mentions a `count` field.
#[test]
fn test_registry_debug() {
    let rendered = format!("{:?}", ExtensionRegistry::new());
    assert!(rendered.contains("count"));
}
/// Hot-reloading an extension that was registered directly (so it has no
/// source path) must fail with `ExtensionError::HotReloadFailed`.
#[test]
fn test_hot_reload_no_source_path() {
    let mut registry = ExtensionRegistry::new();
    registry.register(Arc::new(NoopExtension));
    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();

    let outcome = registry.hot_reload("noop", &context);
    assert!(outcome.is_err());

    if let Err(ExtensionError::HotReloadFailed { name, reason }) = outcome {
        assert_eq!(name, "noop");
        assert!(reason.contains("no source path"));
    } else {
        panic!("Expected HotReloadFailed error");
    }
}
/// Hot-reloading an unknown extension name must fail.
#[test]
fn test_hot_reload_not_found() {
    let mut registry = ExtensionRegistry::new();
    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();

    assert!(registry.hot_reload("nonexistent", &context).is_err());
}
/// A single broadcast reaches every registered extension.
#[test]
fn test_broadcast_to_multiple_extensions() {
    let mut registry = ExtensionRegistry::new();
    let recorders = [
        Arc::new(RecordingExtension::new("ext1")),
        Arc::new(RecordingExtension::new("ext2")),
    ];
    for recorder in &recorders {
        registry.register(recorder.clone());
    }

    registry.emit_message_sent("hello");

    let expected = "on_message_sent(hello)".to_string();
    for recorder in &recorders {
        assert!(recorder.calls().contains(&expected));
    }
}
/// Unregistering an extension triggers its `on_unload` hook.
#[test]
fn test_unregister_calls_on_unload() {
    let mut registry = ExtensionRegistry::new();
    let recorder = Arc::new(RecordingExtension::new("rec"));
    registry.register(recorder.clone());

    registry.unregister("rec");

    assert_eq!(recorder.calls(), vec!["on_unload"]);
}
/// A fresh registry has an empty error log, and `clear_errors` empties a log
/// that actually contains entries.
#[test]
fn test_registry_errors() {
    /// Test double whose `on_load` always panics, producing one error record.
    struct FailingLoad;

    impl Extension for FailingLoad {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "Panics on load"
        }

        fn on_load(&self, _ctx: &ExtensionContext) {
            panic!("intentional panic in on_load");
        }
    }

    let mut reg = ExtensionRegistry::new();
    assert!(reg.errors().is_empty());

    // Populate the log first; clearing an already-empty log proves nothing.
    reg.register(Arc::new(FailingLoad));
    let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
    reg.emit_load(&ctx);
    assert_eq!(reg.errors().len(), 1);

    reg.clear_errors();
    assert!(reg.errors().is_empty());
}
/// Broadcasting an event to an extension with default hooks must not panic.
#[test]
fn test_emit_event_does_not_panic() {
    let mut registry = ExtensionRegistry::new();
    registry.register(Arc::new(NoopExtension));

    let event = AgentEvent::Thinking;
    registry.emit_event(&event);
}
/// Driving a full lifecycle through the registry delivers every hook exactly
/// once, in the order the events were emitted.
#[test]
fn test_multiple_lifecycle_hooks() {
    let mut reg = ExtensionRegistry::new();
    let ext = Arc::new(RecordingExtension::new("rec"));
    reg.register(ext.clone());

    let ctx = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
    reg.emit_load(&ctx);
    reg.emit_session_start("s1");
    reg.emit_message_sent("hello");
    reg.emit_tool_call("bash", &serde_json::json!({}));
    let result = AgentToolResult::success("ok");
    reg.emit_tool_result("bash", &result);
    reg.emit_message_received("response");
    reg.emit_session_end("s1");
    reg.emit_unload();

    // An exact comparison (rather than per-entry `contains`) also catches
    // reordered, duplicated or unexpected hook invocations.
    assert_eq!(
        ext.calls(),
        vec![
            "on_load",
            "on_session_start(s1)",
            "on_message_sent(hello)",
            "on_tool_call(bash)",
            "on_tool_result(bash)",
            "on_message_received(response)",
            "on_session_end(s1)",
            "on_unload",
        ]
    );
}
/// A freshly constructed runner holds no extensions.
#[test]
fn test_runner_new() {
    let runner = ExtensionRunner::new(PathBuf::from("/tmp"));

    assert!(runner.is_empty());
    assert_eq!(runner.len(), 0);
    assert!(runner.names().next().is_none());
}
/// `ExtensionRunner::default()` starts out empty.
#[test]
fn test_runner_default() {
    assert!(ExtensionRunner::default().is_empty());
}
/// An extension wired into the runner by hand is counted and reported Active.
#[test]
fn test_runner_register_in_memory() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let recorder = Arc::new(RecordingExtension::new("test-ext"));
    runner.registry_mut().register(recorder.clone());

    let key = "test-ext".to_string();
    runner.states.insert(key.clone(), ExtensionState::Active);
    runner.order.push(key);

    assert_eq!(runner.len(), 1);
    assert!(!runner.is_empty());
    assert_eq!(runner.state("test-ext"), ExtensionState::Active);
}
/// The runner reports Active/Disabled across disable/enable, and Unloaded for
/// names it has never seen.
#[test]
fn test_runner_state_tracking() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    assert_eq!(runner.state("ext1"), ExtensionState::Active);
    assert_eq!(runner.state("nonexistent"), ExtensionState::Unloaded);

    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();

    runner.disable("ext1").unwrap();
    assert_eq!(runner.state("ext1"), ExtensionState::Disabled);

    runner.enable("ext1", &context).unwrap();
    assert_eq!(runner.state("ext1"), ExtensionState::Active);
}
/// Disabling is idempotent, and re-enabling restores the Active state.
#[test]
fn test_runner_enable_disable() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    assert!(runner.is_enabled("ext1"));

    runner.disable("ext1").unwrap();
    assert!(!runner.is_enabled("ext1"));
    assert_eq!(runner.state("ext1"), ExtensionState::Disabled);

    // A second disable on an already-disabled extension must still succeed.
    runner.disable("ext1").unwrap();

    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
    runner.enable("ext1", &context).unwrap();
    assert!(runner.is_enabled("ext1"));
    assert_eq!(runner.state("ext1"), ExtensionState::Active);
}
/// Enabling or disabling an unknown extension returns an error.
#[test]
fn test_runner_enable_disable_not_found() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let missing = "nonexistent";

    assert!(runner.disable(missing).is_err());

    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();
    assert!(runner.enable(missing, &context).is_err());
}
/// Unloading removes the extension; a second unload reports `false`.
#[test]
fn test_runner_unload() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    assert!(runner.unload_extension("ext1"));
    assert_eq!(runner.state("ext1"), ExtensionState::Unloaded);
    assert!(runner.is_empty());

    let unloaded_again = runner.unload_extension("ext1");
    assert!(!unloaded_again);
}
/// Handler/enabled queries flip from false to true once an Active extension
/// is present.
#[test]
fn test_runner_has_handlers() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    assert!(!runner.has_handlers("any_event"));
    assert!(!runner.has_enabled_extensions());

    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    assert!(runner.has_handlers("any_event"));
    assert!(runner.has_enabled_extensions());
}
/// `extension_order` preserves insertion order.
#[test]
fn test_runner_extension_order() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    for &name in &["ext1", "ext2", "ext3"] {
        let recorder = Arc::new(RecordingExtension::new(name.to_string()));
        runner.registry_mut().register(recorder.clone());
        runner.states.insert(name.to_string(), ExtensionState::Active);
        runner.order.push(name.to_string());
    }

    assert_eq!(runner.extension_order(), &["ext1", "ext2", "ext3"]);
    assert_eq!(runner.len(), 3);
}
/// A listener registered via `on_error` receives emitted error records.
#[test]
fn test_runner_error_listener() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let seen = Arc::new(std::sync::Mutex::new(Vec::<ExtensionErrorRecord>::new()));
    let sink = Arc::clone(&seen);

    // Keep the returned handle bound for the whole test; binding it to `_`
    // would drop it immediately.
    let _subscription = runner.on_error(move |record| {
        sink.lock().unwrap().push(record.clone());
    });

    runner.emit_error_record(ExtensionErrorRecord::new("test-ext", "test_event", "test error"));

    let delivered = seen.lock().unwrap();
    assert_eq!(delivered.len(), 1);
    assert_eq!(delivered[0].extension_name, "test-ext");
    assert_eq!(delivered[0].event, "test_event");
}
/// A tool call broadcast to a non-blocking extension is neither blocked nor
/// produces errors.
#[test]
fn test_runner_emit_tool_call() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    let arguments = serde_json::json!({"cmd": "ls"});
    let outcome = runner.emit_tool_call("bash", &arguments);

    assert!(!outcome.blocked);
    assert!(outcome.errors.is_empty());
}
/// A tool result broadcast leaves the output untouched and records no errors.
#[test]
fn test_runner_emit_tool_result() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    let finished = AgentToolResult::success("done");
    let outcome = runner.emit_tool_result_event("bash", &finished);

    assert!(outcome.output.is_none());
    assert!(outcome.errors.is_empty());
}
/// Interactive input passes through an extension that does not intercept it.
#[test]
fn test_runner_emit_input_continue() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    let mut input = InputEvent {
        text: "hello".to_string(),
        source: InputSource::Interactive,
    };

    assert!(matches!(
        runner.emit_input_event(&mut input),
        InputEventResult::Continue
    ));
}
/// A session switch is not cancelled when no extension objects.
#[test]
fn test_runner_emit_session_before_switch() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    let switch = SessionBeforeSwitchEvent {
        reason: SessionSwitchReason::New,
        target_session_file: None,
    };
    let outcome = runner.emit_session_before_switch_event(&switch);

    assert!(!outcome.cancelled);
    assert!(outcome.cancelled_by.is_none());
}
/// With an Active extension present, the shutdown event reports as handled.
#[test]
fn test_runner_emit_session_shutdown() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let tracked = Arc::new(RecordingExtension::new("ext1"));
    runner.registry_mut().register(tracked.clone());
    runner.states.insert("ext1".to_string(), ExtensionState::Active);
    runner.order.push("ext1".to_string());

    let shutdown = SessionShutdownEvent {
        reason: SessionShutdownReason::Quit,
        target_session_file: None,
    };

    assert!(runner.emit_session_shutdown_event(&shutdown));
}
/// With no extensions loaded, the shutdown event reports as not handled.
#[test]
fn test_runner_emit_session_shutdown_no_extensions() {
    let runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let shutdown = SessionShutdownEvent {
        reason: SessionShutdownReason::Quit,
        target_session_file: None,
    };

    assert!(!runner.emit_session_shutdown_event(&shutdown));
}
/// A failed load returns an error and is also recorded in `load_errors`.
#[test]
fn test_runner_load_extension_missing_file() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();

    let outcome = runner.load_extension(Path::new("/nonexistent.so"), &context);

    assert!(outcome.is_err());
    assert!(!runner.load_errors().is_empty());
}
/// An existing file that is not a native library cannot be loaded.
#[test]
fn test_runner_load_extension_wrong_format() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    let context = ExtensionContextBuilder::new(PathBuf::from("/tmp")).build();

    let scratch = tempfile::tempdir().unwrap();
    let text_file = scratch.path().join("bad.txt");
    std::fs::write(&text_file, "not a library").unwrap();

    assert!(runner.load_extension(&text_file, &context).is_err());
}
/// `all_tools` over several Active extensions that contribute no tools yields
/// an empty list.
#[test]
fn test_runner_all_tools_in_order() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    for name in &["ext1", "ext2"] {
        // Register each extension under the same name used for its state and
        // order entries. The previous version registered `NoopExtension`
        // ("noop") twice, so the registry never matched the runner's
        // bookkeeping and the test was largely vacuous.
        let ext = Arc::new(RecordingExtension::new(name.to_string()));
        runner.registry_mut().register(ext);
        runner.states.insert(name.to_string(), ExtensionState::Active);
        runner.order.push(name.to_string());
    }
    assert_eq!(runner.len(), 2);

    // NOTE(review): assumes RecordingExtension exposes no tools — confirm.
    let tools = runner.all_tools();
    assert!(tools.is_empty());
}
/// `get` and `names` delegate to the underlying registry.
#[test]
fn test_runner_delegation() {
    let mut runner = ExtensionRunner::new(PathBuf::from("/tmp"));
    runner.registry_mut().register(Arc::new(NoopExtension));
    runner.states.insert("noop".to_string(), ExtensionState::Active);
    runner.order.push("noop".to_string());

    assert!(runner.get("noop").is_some());
    assert!(runner.get("missing").is_none());

    let listed: Vec<_> = runner.names().collect();
    assert_eq!(listed, vec!["noop"]);
}
/// The runner's `Debug` output names the type and includes its base path.
#[test]
fn test_runner_debug() {
    let rendered = format!("{:?}", ExtensionRunner::new(PathBuf::from("/tmp")));
    for &needle in &["ExtensionRunner", "/tmp"] {
        assert!(rendered.contains(needle));
    }
}
/// Discovery in an empty directory finds nothing.
#[test]
fn test_discover_extensions_empty_dir() {
    let scratch = tempfile::tempdir().unwrap();
    assert!(discover_extensions_in_dir(scratch.path()).is_empty());
}
/// Discovery in a missing directory finds nothing rather than failing.
#[test]
fn test_discover_extensions_nonexistent_dir() {
    let missing = Path::new("/nonexistent");
    assert!(discover_extensions_in_dir(missing).is_empty());
}
/// A top-level file with the platform's shared-library extension is discovered.
#[test]
fn test_discover_extensions_finds_shared_lib() {
    let scratch = tempfile::tempdir().unwrap();
    // `std::env::consts::OS` carries the same value as `target_os`.
    let suffix = match std::env::consts::OS {
        "macos" => "dylib",
        "windows" => "dll",
        _ => "so",
    };
    let library = scratch.path().join("my_ext").with_extension(suffix);
    std::fs::write(&library, b"fake lib").unwrap();

    let found = discover_extensions_in_dir(scratch.path());

    assert_eq!(found.len(), 1);
    assert_eq!(found[0], library);
}
/// Files without a shared-library extension are ignored by discovery.
#[test]
fn test_discover_extensions_ignores_non_libs() {
    let scratch = tempfile::tempdir().unwrap();
    for &(file_name, contents) in &[("readme.txt", "text"), ("script.sh", "bash")] {
        std::fs::write(scratch.path().join(file_name), contents).unwrap();
    }

    assert!(discover_extensions_in_dir(scratch.path()).is_empty());
}
/// A subdirectory containing an `index.<libext>` file is discovered via that
/// index file.
#[test]
fn test_discover_extensions_subdirectory_index() {
    /// Shared-library file extension for the compilation target.
    fn native_lib_suffix() -> &'static str {
        if cfg!(target_os = "macos") {
            "dylib"
        } else if cfg!(target_os = "windows") {
            "dll"
        } else {
            "so"
        }
    }

    let scratch = tempfile::tempdir().unwrap();
    let package_dir = scratch.path().join("my_ext");
    std::fs::create_dir(&package_dir).unwrap();

    let index_file = package_dir.join(format!("index.{}", native_lib_suffix()));
    std::fs::write(&index_file, b"fake lib").unwrap();

    let found = discover_extensions_in_dir(scratch.path());
    assert_eq!(found.len(), 1);
    assert_eq!(found[0], index_file);
}
/// Libraries under `<cwd>/.oxi/extensions` are picked up by `discover_extensions`.
#[test]
fn test_discover_extensions_from_cwd() {
    let workspace = tempfile::tempdir().unwrap();
    let extensions_dir = workspace.path().join(".oxi").join("extensions");
    std::fs::create_dir_all(&extensions_dir).unwrap();

    // `std::env::consts::OS` carries the same value as `target_os`.
    let suffix = match std::env::consts::OS {
        "macos" => "dylib",
        "windows" => "dll",
        _ => "so",
    };
    std::fs::write(extensions_dir.join(format!("test.{}", suffix)), b"fake").unwrap();

    assert_eq!(discover_extensions(workspace.path(), &[]).len(), 1);
}
/// Every `ExtensionState` variant renders as its lowercase name.
#[test]
fn test_extension_state_display() {
    let cases = [
        (ExtensionState::Pending, "pending"),
        (ExtensionState::Active, "active"),
        (ExtensionState::Disabled, "disabled"),
        (ExtensionState::Failed, "failed"),
        (ExtensionState::Unloaded, "unloaded"),
    ];
    for (state, expected) in cases.iter() {
        assert_eq!(state.to_string(), *expected);
    }
}
/// `ExtensionState` serializes to a lowercase JSON string and round-trips.
#[test]
fn test_extension_state_serialization() {
    let encoded = serde_json::to_string(&ExtensionState::Active).unwrap();
    assert_eq!(encoded, "\"active\"");

    let decoded: ExtensionState = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded, ExtensionState::Active);
}
/// The default tool-call result is unblocked, with no reason and no errors.
#[test]
fn test_tool_call_emit_result_default() {
    let defaults = ToolCallEmitResult::default();

    assert!(!defaults.blocked);
    assert!(defaults.block_reason.is_none());
    assert!(defaults.errors.is_empty());
}
/// The default tool-result outcome carries no overrides and no errors.
#[test]
fn test_tool_result_emit_result_default() {
    let defaults = ToolResultEmitResult::default();

    assert!(defaults.output.is_none());
    assert!(defaults.success.is_none());
    assert!(defaults.errors.is_empty());
}
/// The default session-before outcome is not cancelled and has no errors.
#[test]
fn test_session_before_emit_result_default() {
    let defaults = SessionBeforeEmitResult::default();

    assert!(!defaults.cancelled);
    assert!(defaults.cancelled_by.is_none());
    assert!(defaults.errors.is_empty());
}
}