#![allow(unused)]
use crate::extensions::types::{
AgentEvent, AgentToolResult, BashEvent, BeforeProviderRequestEvent, Command, ContextEmitResult,
ContextEvent, ExtensionError, ExtensionErrorListener, ExtensionErrorRecord, ExtensionManifest,
ExtensionState, InputEvent, InputEventResult, ModelSelectEvent, ProviderRequestEmitResult,
SessionBeforeCompactEvent, SessionBeforeEmitResult, SessionBeforeForkEvent,
SessionBeforeSwitchEvent, SessionBeforeTreeEvent, SessionCompactEvent, SessionShutdownEvent,
SessionTreeEvent, ThinkingLevelSelectEvent, ToolCallEmitResult, ToolResultEmitResult,
};
use crate::extensions::context::ExtensionContext;
use crate::extensions::Extension;
use crate::CompactionContext;
use oxi_store::settings::Settings;
use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use std::path::PathBuf;
use std::sync::Arc;
type ToolType = dyn oxi_agent::AgentTool;
type ToolArc = Arc<ToolType>;
/// A registered extension together with its enabled flag.
struct LoadedExtension {
    /// `false` after `ExtensionRegistry::disable`; disabled entries are skipped by every emit.
    enabled: bool,
    /// The extension instance itself.
    extension: Arc<dyn Extension>,
}
/// Name-keyed collection of extensions plus a shared log of hook failures.
pub struct ExtensionRegistry {
    /// Error records (e.g. caught hook panics); shared via `Arc` so wrapped tools can hold it too.
    errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
    /// Extensions keyed by their `name()`; iteration order is unspecified (`HashMap`).
    entries: HashMap<String, LoadedExtension>,
}
impl Default for ExtensionRegistry {
fn default() -> Self {
Self::new()
}
}
impl ExtensionRegistry {
    /// Creates an empty registry with no extensions and an empty error log.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
            errors: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Registers `ext` under its `name()`, initially enabled.
    ///
    /// No lifecycle hook is invoked here; `on_load` must be delivered
    /// separately (e.g. through [`ExtensionRegistry::emit_load`]).
    ///
    /// If an extension with the same name is already registered it is
    /// replaced. The replaced instance does not receive `on_unload`, so a
    /// warning is logged to make the replacement visible.
    pub fn register(&mut self, ext: Arc<dyn Extension>) {
        let name = ext.name().to_string();
        tracing::info!(name = %name, "extension registered");
        let previous = self.entries.insert(
            name.clone(),
            LoadedExtension {
                extension: ext,
                enabled: true,
            },
        );
        if previous.is_some() {
            tracing::warn!(
                name = %name,
                "extension name was already registered; previous instance replaced"
            );
        }
    }

    /// Removes the extension registered under `name`.
    ///
    /// Returns `true` if an extension was removed, `false` if none was found.
    /// `on_unload` is invoked only when the extension was still enabled: a
    /// disabled extension already received `on_unload` from
    /// [`ExtensionRegistry::disable`], and calling it again would unload it
    /// twice.
    pub fn unregister(&mut self, name: &str) -> bool {
        match self.entries.remove(name) {
            Some(entry) => {
                if entry.enabled {
                    self.call_hook_safe(name, "on_unload", || {
                        entry.extension.on_unload();
                    });
                }
                tracing::info!(name = %name, "extension unregistered");
                true
            }
            None => false,
        }
    }

    /// Disables the extension `name` and invokes its `on_unload` hook.
    ///
    /// Disabling an already-disabled extension is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::NotFound`] if no extension is registered
    /// under `name`.
    pub fn disable(&mut self, name: &str) -> Result<(), ExtensionError> {
        // Clone the Arc inside a scope so the mutable borrow of `entries`
        // ends before `call_hook_safe` borrows `self` again.
        let ext = {
            let entry = self
                .entries
                .get_mut(name)
                .ok_or_else(|| ExtensionError::NotFound {
                    name: name.to_string(),
                })?;
            if !entry.enabled {
                return Ok(());
            }
            entry.enabled = false;
            Arc::clone(&entry.extension)
        };
        self.call_hook_safe(name, "on_unload", || {
            ext.on_unload();
        });
        tracing::info!(name = %name, "extension disabled");
        Ok(())
    }

    /// Re-enables the extension `name` and invokes its `on_load` hook with `ctx`.
    ///
    /// Enabling an already-enabled extension is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::NotFound`] if no extension is registered
    /// under `name`.
    pub fn enable(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
        let ext = {
            let entry = self
                .entries
                .get_mut(name)
                .ok_or_else(|| ExtensionError::NotFound {
                    name: name.to_string(),
                })?;
            if entry.enabled {
                return Ok(());
            }
            entry.enabled = true;
            Arc::clone(&entry.extension)
        };
        self.call_hook_safe(name, "on_load", || {
            ext.on_load(ctx);
        });
        tracing::info!(name = %name, "extension enabled");
        Ok(())
    }

    /// Returns `true` if `name` is registered and enabled.
    pub fn is_enabled(&self, name: &str) -> bool {
        self.entries.get(name).map(|e| e.enabled).unwrap_or(false)
    }

    /// Collects the tools of every enabled extension (in `HashMap` order).
    pub fn all_tools(&self) -> Vec<ToolArc> {
        self.entries
            .values()
            .filter(|e| e.enabled)
            .flat_map(|e| e.extension.register_tools())
            .collect()
    }

    /// Collects the commands of every enabled extension (in `HashMap` order).
    pub fn all_commands(&self) -> Vec<Command> {
        self.entries
            .values()
            .filter(|e| e.enabled)
            .flat_map(|e| e.extension.register_commands())
            .collect()
    }

    /// Runs `f` for every enabled extension under the panic guard, recording
    /// any panic against `hook`.
    fn for_each_enabled<F>(&self, hook: &str, mut f: F)
    where
        F: FnMut(&Arc<dyn Extension>),
    {
        for entry in self.entries.values().filter(|e| e.enabled) {
            let ext = &entry.extension;
            self.call_hook_safe(ext.name(), hook, || f(ext));
        }
    }

    /// Delivers `on_load(ctx)` to every enabled extension.
    pub fn emit_load(&self, ctx: &ExtensionContext) {
        self.for_each_enabled("on_load", |ext| {
            ext.on_load(ctx);
        });
    }

    /// Delivers `on_unload` to every enabled extension.
    pub fn emit_unload(&self) {
        self.for_each_enabled("on_unload", |ext| {
            ext.on_unload();
        });
    }

    /// Delivers `on_message_sent(msg)` to every enabled extension.
    pub fn emit_message_sent(&self, msg: &str) {
        self.for_each_enabled("on_message_sent", |ext| {
            ext.on_message_sent(msg);
        });
    }

    /// Delivers `on_message_received(msg)` to every enabled extension.
    pub fn emit_message_received(&self, msg: &str) {
        self.for_each_enabled("on_message_received", |ext| {
            ext.on_message_received(msg);
        });
    }

    /// Delivers `on_tool_call(tool, params)` to every enabled extension.
    pub fn emit_tool_call(&self, tool: &str, params: &Value) {
        self.for_each_enabled("on_tool_call", |ext| {
            ext.on_tool_call(tool, params);
        });
    }

    /// Delivers `on_tool_result(tool, result)` to every enabled extension.
    pub fn emit_tool_result(&self, tool: &str, result: &AgentToolResult) {
        self.for_each_enabled("on_tool_result", |ext| {
            ext.on_tool_result(tool, result);
        });
    }

    /// Delivers `on_session_start(session_id)` to every enabled extension.
    pub fn emit_session_start(&self, session_id: &str) {
        self.for_each_enabled("on_session_start", |ext| {
            ext.on_session_start(session_id);
        });
    }

    /// Delivers `on_session_end(session_id)` to every enabled extension.
    pub fn emit_session_end(&self, session_id: &str) {
        self.for_each_enabled("on_session_end", |ext| {
            ext.on_session_end(session_id);
        });
    }

    /// Delivers `on_settings_changed(settings)` to every enabled extension.
    pub fn emit_settings_changed(&self, settings: &Settings) {
        self.for_each_enabled("on_settings_changed", |ext| {
            ext.on_settings_changed(settings);
        });
    }

    /// Delivers `on_event(event)` to every enabled extension.
    pub fn emit_event(&self, event: &AgentEvent) {
        self.for_each_enabled("on_event", |ext| {
            ext.on_event(event);
        });
    }

    /// Delivers `on_error(error)` to every enabled extension.
    ///
    /// Returns `(extension name, error)` for each hook that returned `Err`.
    /// A hook that panics is caught like every other hook: the panic is
    /// recorded in the error log and that extension does not appear in the
    /// returned list.
    pub fn emit_error(&self, error: &anyhow::Error) -> Vec<(String, anyhow::Error)> {
        let mut failures = Vec::new();
        for entry in self.entries.values().filter(|e| e.enabled) {
            let name = entry.extension.name();
            // Stays `None` if the hook panicked.
            let mut outcome = None;
            self.call_hook_safe(name, "on_error", || {
                outcome = Some(entry.extension.on_error(error));
            });
            if let Some(Err(e)) = outcome {
                tracing::warn!(extension = name, error = %e, "on_error hook failed");
                failures.push((name.to_string(), e));
            }
        }
        failures
    }

    /// Delivers `session_shutdown(event)` to every enabled extension.
    pub fn emit_session_shutdown(&self, event: &SessionShutdownEvent) {
        self.for_each_enabled("session_shutdown", |ext| {
            ext.session_shutdown(event);
        });
    }

    /// Returns a shared handle to the extension `name`, enabled or not.
    pub fn get(&self, name: &str) -> Option<Arc<dyn Extension>> {
        self.entries.get(name).map(|e| Arc::clone(&e.extension))
    }

    /// Iterates the names of all registered extensions (unspecified order).
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.entries.keys().map(|s| s.as_str())
    }

    /// Iterates all registered extensions, enabled or not (unspecified order).
    pub fn extensions(&self) -> impl Iterator<Item = &Arc<dyn Extension>> {
        self.entries.values().map(|e| &e.extension)
    }

    /// Returns the manifest of extension `name`, if registered.
    pub fn manifest(&self, name: &str) -> Option<ExtensionManifest> {
        self.entries.get(name).map(|e| e.extension.manifest())
    }

    /// Number of registered extensions, including disabled ones.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no extension is registered.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns a snapshot of the recorded errors.
    pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
        self.errors.read().clone()
    }

    /// Discards all recorded errors.
    pub fn clear_errors(&self) {
        self.errors.write().clear();
    }

    /// Runs `f`, catching any panic so one misbehaving extension cannot take
    /// down the host.
    ///
    /// A caught panic is logged and appended to `self.errors` as
    /// `"panic: <message>"`. No error listeners are notified from here.
    fn call_hook_safe<F>(&self, ext_name: &str, hook: &str, f: F)
    where
        F: FnOnce(),
    {
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
        if let Err(payload) = result {
            // `panic!` payloads are `&str` for literals and `String` for formatted messages.
            let msg = if let Some(s) = payload.downcast_ref::<&str>() {
                s.to_string()
            } else if let Some(s) = payload.downcast_ref::<String>() {
                s.clone()
            } else {
                "unknown panic".to_string()
            };
            tracing::error!(extension = ext_name, hook = hook, error = %msg, "Extension hook panicked");
            self.errors.write().push(ExtensionErrorRecord::new(
                ext_name,
                hook,
                format!("panic: {}", msg),
            ));
        }
    }
}
impl fmt::Debug for ExtensionRegistry {
    /// Shows only how many extensions are registered.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.entries.len();
        let mut builder = f.debug_struct("ExtensionRegistry");
        builder.field("count", &count);
        builder.finish()
    }
}
/// Drives extensions in activation order, tracks their lifecycle state and
/// fans error records out to listeners.
pub struct ExtensionRunner {
    /// Working directory; in this file it is only read by the `Debug` impl.
    cwd: PathBuf,
    /// Underlying registry holding the extension instances and the shared error log.
    registry: ExtensionRegistry,
    /// Names in the order they first became active; removed on unload, kept while disabled.
    order: Vec<String>,
    /// Last lifecycle state recorded per name (entries remain after unload).
    states: HashMap<String, ExtensionState>,
    /// Callbacks invoked by `emit_error_record`.
    error_listeners: Vec<Arc<ExtensionErrorListener>>,
}
impl Default for ExtensionRunner {
fn default() -> Self {
Self::new(PathBuf::from("."))
}
}
impl ExtensionRunner {
pub fn new(cwd: PathBuf) -> Self {
Self {
registry: ExtensionRegistry::new(),
states: HashMap::new(),
order: Vec::new(),
error_listeners: Vec::new(),
cwd,
}
}
pub fn on_error<F>(&mut self, listener: F) -> ExtensionErrorHandle
where
F: Fn(&ExtensionErrorRecord) + Send + Sync + 'static,
{
let arc: Arc<ExtensionErrorListener> = Arc::new(listener);
self.error_listeners.push(Arc::clone(&arc));
ExtensionErrorHandle {
listener: Some(arc),
}
}
pub fn emit_error_record(&self, record: ExtensionErrorRecord) {
for listener in &self.error_listeners {
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
listener(&record);
}));
}
self.registry.errors.write().push(record);
}
pub fn register(&mut self, ext: Arc<dyn Extension>, ctx: &ExtensionContext) {
let name = ext.name().to_string();
self.registry.register(ext);
self.set_state(&name, ExtensionState::Active);
self.registry.call_hook_safe(&name, "on_load", || {
if let Some(e) = self.registry.get(&name) {
e.on_load(ctx);
}
});
}
pub fn unload_extension(&mut self, name: &str) -> bool {
let had = self.registry.unregister(name);
if had {
self.set_state(name, ExtensionState::Unloaded);
tracing::info!(name = %name, "extension unloaded");
}
had
}
fn set_state(&mut self, name: &str, state: ExtensionState) {
self.states.insert(name.to_string(), state);
if state == ExtensionState::Active && !self.order.contains(&name.to_string()) {
self.order.push(name.to_string());
}
if state == ExtensionState::Unloaded {
self.order.retain(|n| n != name);
}
}
pub fn state(&self, name: &str) -> ExtensionState {
self.states
.get(name)
.copied()
.unwrap_or(ExtensionState::Unloaded)
}
pub fn states(&self) -> &HashMap<String, ExtensionState> {
&self.states
}
pub fn extension_order(&self) -> &[String] {
&self.order
}
pub fn disable(&mut self, name: &str) -> Result<(), ExtensionError> {
self.registry.disable(name)?;
self.set_state(name, ExtensionState::Disabled);
Ok(())
}
pub fn enable(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
self.registry.enable(name, ctx)?;
self.set_state(name, ExtensionState::Active);
Ok(())
}
pub fn is_enabled(&self, name: &str) -> bool {
self.registry.is_enabled(name)
}
pub fn has_handlers(&self, _event_type: &str) -> bool {
self.has_enabled_extensions()
}
pub fn has_enabled_extensions(&self) -> bool {
self.registry.extensions().any(|_| true)
&& self
.order
.iter()
.any(|name| self.state(name) == ExtensionState::Active)
}
pub fn all_tools(&self) -> Vec<ToolArc> {
let mut tools = Vec::new();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
tools.extend(ext.register_tools());
}
}
tools
}
pub fn all_commands(&self) -> Vec<Command> {
let mut commands = Vec::new();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
commands.extend(ext.register_commands());
}
}
commands
}
pub fn wrap_tool(&self, tool: ToolArc) -> ToolArc {
Arc::new(WrappedTool {
inner: tool,
runner_state: Arc::new(RwLock::new(RunnerState {
errors: self.registry.errors.clone(),
error_listeners: self.error_listeners.clone(),
})),
})
}
pub fn wrap_tools(&self, tools: Vec<ToolArc>) -> Vec<ToolArc> {
tools.into_iter().map(|t| self.wrap_tool(t)).collect()
}
pub fn emit_tool_call(&self, tool_name: &str, params: &Value) -> ToolCallEmitResult {
let mut result = ToolCallEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
if let Err(e) = ext.on_before_tool_call(tool_name, params) {
let err_str = e.to_string();
result.errors.push((name.clone(), err_str.clone()));
self.emit_error_record(ExtensionErrorRecord::new(
name,
"on_before_tool_call",
&err_str,
));
}
self.registry.call_hook_safe(name, "on_tool_call", || {
ext.on_tool_call(tool_name, params);
});
}
}
result
}
pub fn emit_tool_result_event(
&self,
tool_name: &str,
tool_result: &AgentToolResult,
) -> ToolResultEmitResult {
let mut result = ToolResultEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
if let Err(e) = ext.on_after_tool_call(tool_name, tool_result) {
result.errors.push((name.clone(), e.to_string()));
}
self.registry.call_hook_safe(name, "on_tool_result", || {
ext.on_tool_result(tool_name, tool_result);
});
}
}
result
}
pub fn emit_input_event(&self, event: &mut InputEvent) -> InputEventResult {
let mut final_result = InputEventResult::Continue;
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
let result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| ext.input(event)));
match result {
Ok(InputEventResult::Handled) => return InputEventResult::Handled,
Ok(InputEventResult::Transform { text }) => {
event.text = text.clone();
final_result = InputEventResult::Transform { text };
}
Ok(InputEventResult::Continue) => {}
Err(payload) => {
let msg = if let Some(s) = payload.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = payload.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic".to_string()
};
self.emit_error_record(ExtensionErrorRecord::new(
name,
"input",
format!("panic: {}", msg),
));
}
}
}
}
final_result
}
pub fn emit_context_event(&self, messages: Vec<oxi_ai::Message>) -> ContextEmitResult {
let mut current_messages = messages;
let mut errors = Vec::new();
let mut modified = false;
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
let prev_len = current_messages.len();
let mut event = ContextEvent {
messages: current_messages.clone(),
};
if let Err(e) = ext.context(&mut event) {
errors.push((name.clone(), e.to_string()));
} else if event.messages.len() != prev_len {
current_messages = event.messages;
modified = true;
}
}
}
ContextEmitResult {
modified,
messages: current_messages,
errors,
}
}
pub fn emit_before_provider_request_event(&self, payload: Value) -> ProviderRequestEmitResult {
let mut current_payload = payload;
let mut modified = false;
let mut errors = Vec::new();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
let mut event = BeforeProviderRequestEvent {
payload: current_payload.clone(),
};
if let Err(e) = ext.before_provider_request(&mut event) {
errors.push((name.clone(), e.to_string()));
} else if event.payload != current_payload {
current_payload = event.payload;
modified = true;
}
}
}
ProviderRequestEmitResult {
modified,
payload: current_payload,
errors,
}
}
pub fn emit_session_before_switch_event(
&self,
event: &SessionBeforeSwitchEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
if let Err(e) = ext.session_before_switch(event) {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
result
}
pub fn emit_session_before_fork_event(
&self,
event: &SessionBeforeForkEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
if let Err(e) = ext.session_before_fork(event) {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
result
}
pub fn emit_session_before_compact_event(
&self,
event: &SessionBeforeCompactEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
if let Err(e) = ext.session_before_compact(event) {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
result
}
pub fn emit_session_before_tree_event(
&self,
event: &SessionBeforeTreeEvent,
) -> SessionBeforeEmitResult {
let mut result = SessionBeforeEmitResult::default();
for name in &self.order {
if self.state(name) != ExtensionState::Active {
continue;
}
if let Some(ext) = self.registry.get(name) {
if let Err(e) = ext.session_before_tree(event) {
result.cancelled = true;
result.cancelled_by = Some(name.clone());
result.errors.push((name.clone(), e.to_string()));
return result;
}
}
}
result
}
pub fn emit_session_shutdown_event(&self, event: &SessionShutdownEvent) -> bool {
if !self.has_enabled_extensions() {
return false;
}
self.registry.emit_session_shutdown(event);
true
}
pub fn emit_event(&self, event: &AgentEvent) {
self.registry.emit_event(event);
}
pub fn registry(&self) -> &ExtensionRegistry {
&self.registry
}
pub fn registry_mut(&mut self) -> &mut ExtensionRegistry {
&mut self.registry
}
pub fn get(&self, name: &str) -> Option<Arc<dyn Extension>> {
self.registry.get(name)
}
pub fn names(&self) -> impl Iterator<Item = &str> {
self.order.iter().map(|s| s.as_str())
}
pub fn len(&self) -> usize {
self.order.len()
}
pub fn is_empty(&self) -> bool {
self.order.is_empty()
}
pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
self.registry.errors()
}
pub fn clear_errors(&self) {
self.registry.clear_errors();
}
}
impl fmt::Debug for ExtensionRunner {
    /// Shows the working directory and the activation-ordered extension names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ExtensionRunner");
        out.field("cwd", &self.cwd);
        out.field("extensions", &self.order);
        out.finish()
    }
}
pub struct ExtensionErrorHandle {
listener: Option<Arc<ExtensionErrorListener>>,
}
impl ExtensionErrorHandle {
    /// Takes the listener reference out of this handle, leaving `None`.
    ///
    /// Returns the listener on the first call and `None` afterwards.
    ///
    /// NOTE(review): `ExtensionRunner::on_error` pushes its own clone of the
    /// listener into the runner, so this does **not** stop the listener from
    /// receiving records.
    pub fn unregister(&mut self) -> Option<Arc<ExtensionErrorListener>> {
        std::mem::take(&mut self.listener)
    }
}
impl Drop for ExtensionErrorHandle {
    /// Dropping the handle performs no cleanup; the runner's copy of the
    /// listener stays registered.
    fn drop(&mut self) {
        // Intentionally empty.
    }
}
/// Error-reporting state captured from an `ExtensionRunner` when a tool is wrapped.
struct RunnerState {
    /// Copy of the runner's listener list at wrap time; later registrations are not seen.
    error_listeners: Vec<Arc<ExtensionErrorListener>>,
    /// The registry's error log, shared through the same `Arc`.
    errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
}
/// Tool adapter produced by `ExtensionRunner::wrap_tool`; delegates to `inner`.
struct WrappedTool {
    /// Runner error-reporting snapshot taken at wrap time.
    runner_state: Arc<RwLock<RunnerState>>,
    /// The tool every call is forwarded to.
    inner: ToolArc,
}
#[async_trait::async_trait]
impl oxi_agent::AgentTool for WrappedTool {
fn name(&self) -> &str {
self.inner.name()
}
fn label(&self) -> &str {
self.inner.label()
}
fn description(&self) -> &str {
self.inner.description()
}
fn parameters_schema(&self) -> Value {
self.inner.parameters_schema()
}
async fn execute(
&self,
tool_call_id: &str,
params: Value,
signal: Option<tokio::sync::oneshot::Receiver<()>>,
ctx: &oxi_agent::ToolContext,
) -> Result<AgentToolResult, String> {
self.inner.execute(tool_call_id, params, signal, ctx).await
}
}