1#![allow(unused)]
8
9use crate::extensions::types::{
10 AgentEvent, AgentToolResult, BashEvent, BeforeProviderRequestEvent, Command, ContextEmitResult,
11 ContextEvent, ExtensionError, ExtensionErrorListener, ExtensionErrorRecord, ExtensionManifest,
12 ExtensionState, InputEvent, InputEventResult, ModelSelectEvent, ProviderRequestEmitResult,
13 SessionBeforeCompactEvent, SessionBeforeEmitResult, SessionBeforeForkEvent,
14 SessionBeforeSwitchEvent, SessionBeforeTreeEvent, SessionCompactEvent, SessionShutdownEvent,
15 SessionTreeEvent, ThinkingLevelSelectEvent, ToolCallEmitResult, ToolResultEmitResult,
16};
17
18use crate::extensions::context::ExtensionContext;
19use crate::extensions::Extension;
20use crate::CompactionContext;
21use oxi_store::settings::Settings;
22
23use parking_lot::RwLock;
24use serde_json::Value;
25use std::collections::HashMap;
26use std::fmt;
27use std::path::PathBuf;
28use std::sync::Arc;
29
30type ToolType = dyn oxi_agent::AgentTool;
31type ToolArc = Arc<ToolType>;
32
33struct LoadedExtension {
38 extension: Arc<dyn Extension>,
39 enabled: bool,
40}
41
42pub struct ExtensionRegistry {
44 entries: HashMap<String, LoadedExtension>,
45 errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
46}
47
48impl Default for ExtensionRegistry {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl ExtensionRegistry {
55 pub fn new() -> Self {
57 Self {
58 entries: HashMap::new(),
59 errors: Arc::new(RwLock::new(Vec::new())),
60 }
61 }
62
63 pub fn register(&mut self, ext: Arc<dyn Extension>) {
65 let name = ext.name().to_string();
66 tracing::info!(name = %name, "extension registered");
67 self.entries.insert(
68 name,
69 LoadedExtension {
70 extension: ext,
71 enabled: true,
72 },
73 );
74 }
75
76 pub fn unregister(&mut self, name: &str) -> bool {
78 if let Some(entry) = self.entries.remove(name) {
79 self.call_hook_safe(name, "on_unload", || {
80 entry.extension.on_unload();
81 });
82 tracing::info!(name = %name, "extension unregistered");
83 true
84 } else {
85 false
86 }
87 }
88
89 pub fn disable(&mut self, name: &str) -> Result<(), ExtensionError> {
91 let ext = {
92 let entry = self
93 .entries
94 .get_mut(name)
95 .ok_or_else(|| ExtensionError::NotFound {
96 name: name.to_string(),
97 })?;
98 if !entry.enabled {
99 return Ok(());
100 }
101 entry.enabled = false;
102 Arc::clone(&entry.extension)
103 };
104 self.call_hook_safe(name, "on_unload", || {
105 ext.on_unload();
106 });
107 tracing::info!(name = %name, "extension disabled");
108 Ok(())
109 }
110
111 pub fn enable(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
113 let ext = {
114 let entry = self
115 .entries
116 .get_mut(name)
117 .ok_or_else(|| ExtensionError::NotFound {
118 name: name.to_string(),
119 })?;
120 if entry.enabled {
121 return Ok(());
122 }
123 entry.enabled = true;
124 Arc::clone(&entry.extension)
125 };
126 self.call_hook_safe(name, "on_load", || {
127 ext.on_load(ctx);
128 });
129 tracing::info!(name = %name, "extension enabled");
130 Ok(())
131 }
132
133 pub fn is_enabled(&self, name: &str) -> bool {
135 self.entries.get(name).map(|e| e.enabled).unwrap_or(false)
136 }
137
138 pub fn all_tools(&self) -> Vec<ToolArc> {
140 self.entries
141 .values()
142 .filter(|e| e.enabled)
143 .flat_map(|e| e.extension.register_tools())
144 .collect()
145 }
146
147 pub fn all_commands(&self) -> Vec<Command> {
149 self.entries
150 .values()
151 .filter(|e| e.enabled)
152 .flat_map(|e| e.extension.register_commands())
153 .collect()
154 }
155
156 pub fn emit_load(&self, ctx: &ExtensionContext) {
158 for entry in self.entries.values().filter(|e| e.enabled) {
159 let name = entry.extension.name();
160 self.call_hook_safe(name, "on_load", || {
161 entry.extension.on_load(ctx);
162 });
163 }
164 }
165
166 pub fn emit_unload(&self) {
168 for entry in self.entries.values().filter(|e| e.enabled) {
169 let name = entry.extension.name();
170 self.call_hook_safe(name, "on_unload", || {
171 entry.extension.on_unload();
172 });
173 }
174 }
175
176 pub fn emit_message_sent(&self, msg: &str) {
178 for entry in self.entries.values().filter(|e| e.enabled) {
179 let name = entry.extension.name();
180 self.call_hook_safe(name, "on_message_sent", || {
181 entry.extension.on_message_sent(msg);
182 });
183 }
184 }
185
186 pub fn emit_message_received(&self, msg: &str) {
188 for entry in self.entries.values().filter(|e| e.enabled) {
189 let name = entry.extension.name();
190 self.call_hook_safe(name, "on_message_received", || {
191 entry.extension.on_message_received(msg);
192 });
193 }
194 }
195
196 pub fn emit_tool_call(&self, tool: &str, params: &Value) {
198 for entry in self.entries.values().filter(|e| e.enabled) {
199 let name = entry.extension.name();
200 self.call_hook_safe(name, "on_tool_call", || {
201 entry.extension.on_tool_call(tool, params);
202 });
203 }
204 }
205
206 pub fn emit_tool_result(&self, tool: &str, result: &AgentToolResult) {
208 for entry in self.entries.values().filter(|e| e.enabled) {
209 let name = entry.extension.name();
210 self.call_hook_safe(name, "on_tool_result", || {
211 entry.extension.on_tool_result(tool, result);
212 });
213 }
214 }
215
216 pub fn emit_session_start(&self, session_id: &str) {
218 for entry in self.entries.values().filter(|e| e.enabled) {
219 let name = entry.extension.name();
220 self.call_hook_safe(name, "on_session_start", || {
221 entry.extension.on_session_start(session_id);
222 });
223 }
224 }
225
226 pub fn emit_session_end(&self, session_id: &str) {
228 for entry in self.entries.values().filter(|e| e.enabled) {
229 let name = entry.extension.name();
230 self.call_hook_safe(name, "on_session_end", || {
231 entry.extension.on_session_end(session_id);
232 });
233 }
234 }
235
236 pub fn emit_settings_changed(&self, settings: &Settings) {
238 for entry in self.entries.values().filter(|e| e.enabled) {
239 let name = entry.extension.name();
240 self.call_hook_safe(name, "on_settings_changed", || {
241 entry.extension.on_settings_changed(settings);
242 });
243 }
244 }
245
246 pub fn emit_event(&self, event: &AgentEvent) {
248 for entry in self.entries.values().filter(|e| e.enabled) {
249 let name = entry.extension.name();
250 self.call_hook_safe(name, "on_event", || {
251 entry.extension.on_event(event);
252 });
253 }
254 }
255
256 pub fn emit_error(&self, error: &anyhow::Error) -> Vec<(String, anyhow::Error)> {
258 let mut errors = Vec::new();
259 for entry in self.entries.values().filter(|e| e.enabled) {
260 let name = entry.extension.name();
261 if let Err(e) = entry.extension.on_error(error) {
262 tracing::warn!(extension = name, error = %e, "on_error hook failed");
263 errors.push((name.to_string(), e));
264 }
265 }
266 errors
267 }
268
269 pub fn emit_session_shutdown(&self, event: &SessionShutdownEvent) {
271 for entry in self.entries.values().filter(|e| e.enabled) {
272 let name = entry.extension.name();
273 self.call_hook_safe(name, "session_shutdown", || {
274 entry.extension.session_shutdown(event);
275 });
276 }
277 }
278
279 pub fn get(&self, name: &str) -> Option<Arc<dyn Extension>> {
281 self.entries.get(name).map(|e| Arc::clone(&e.extension))
282 }
283
284 pub fn names(&self) -> impl Iterator<Item = &str> {
286 self.entries.keys().map(|s| s.as_str())
287 }
288
289 pub fn extensions(&self) -> impl Iterator<Item = &Arc<dyn Extension>> {
291 self.entries.values().map(|e| &e.extension)
292 }
293
294 pub fn manifest(&self, name: &str) -> Option<ExtensionManifest> {
296 self.entries.get(name).map(|e| e.extension.manifest())
297 }
298
299 pub fn len(&self) -> usize {
301 self.entries.len()
302 }
303 pub fn is_empty(&self) -> bool {
305 self.entries.is_empty()
306 }
307 pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
309 self.errors.read().clone()
310 }
311 pub fn clear_errors(&self) {
313 self.errors.write().clear();
314 }
315
316 fn call_hook_safe<F>(&self, ext_name: &str, hook: &str, f: F)
317 where
318 F: FnOnce(),
319 {
320 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
321 if let Err(payload) = result {
322 let msg = if let Some(s) = payload.downcast_ref::<&str>() {
323 s.to_string()
324 } else if let Some(s) = payload.downcast_ref::<String>() {
325 s.clone()
326 } else {
327 "unknown panic".to_string()
328 };
329 tracing::error!(extension = ext_name, hook = hook, error = %msg, "Extension hook panicked");
330 self.errors.write().push(ExtensionErrorRecord::new(
331 ext_name,
332 hook,
333 format!("panic: {}", msg),
334 ));
335 }
336 }
337}
338
339impl fmt::Debug for ExtensionRegistry {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 f.debug_struct("ExtensionRegistry")
342 .field("count", &self.entries.len())
343 .finish()
344 }
345}
346
347pub struct ExtensionRunner {
353 registry: ExtensionRegistry,
354 states: HashMap<String, ExtensionState>,
355 order: Vec<String>,
356 error_listeners: Vec<Arc<ExtensionErrorListener>>,
357 cwd: PathBuf,
358}
359
360impl Default for ExtensionRunner {
361 fn default() -> Self {
362 Self::new(PathBuf::from("."))
363 }
364}
365
366impl ExtensionRunner {
367 pub fn new(cwd: PathBuf) -> Self {
369 Self {
370 registry: ExtensionRegistry::new(),
371 states: HashMap::new(),
372 order: Vec::new(),
373 error_listeners: Vec::new(),
374 cwd,
375 }
376 }
377
378 pub fn on_error<F>(&mut self, listener: F) -> ExtensionErrorHandle
380 where
381 F: Fn(&ExtensionErrorRecord) + Send + Sync + 'static,
382 {
383 let arc: Arc<ExtensionErrorListener> = Arc::new(listener);
384 self.error_listeners.push(Arc::clone(&arc));
385 ExtensionErrorHandle {
386 listener: Some(arc),
387 }
388 }
389
390 pub fn emit_error_record(&self, record: ExtensionErrorRecord) {
392 for listener in &self.error_listeners {
393 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
394 listener(&record);
395 }));
396 }
397 self.registry.errors.write().push(record);
398 }
399
400 pub fn register(&mut self, ext: Arc<dyn Extension>, ctx: &ExtensionContext) {
402 let name = ext.name().to_string();
403 self.registry.register(ext);
404 self.set_state(&name, ExtensionState::Active);
405 self.registry.call_hook_safe(&name, "on_load", || {
406 if let Some(e) = self.registry.get(&name) {
407 e.on_load(ctx);
408 }
409 });
410 }
411
412 pub fn unload_extension(&mut self, name: &str) -> bool {
414 let had = self.registry.unregister(name);
415 if had {
416 self.set_state(name, ExtensionState::Unloaded);
417 tracing::info!(name = %name, "extension unloaded");
418 }
419 had
420 }
421
422 fn set_state(&mut self, name: &str, state: ExtensionState) {
423 self.states.insert(name.to_string(), state);
424 if state == ExtensionState::Active && !self.order.contains(&name.to_string()) {
425 self.order.push(name.to_string());
426 }
427 if state == ExtensionState::Unloaded {
428 self.order.retain(|n| n != name);
429 }
430 }
431
432 pub fn state(&self, name: &str) -> ExtensionState {
434 self.states
435 .get(name)
436 .copied()
437 .unwrap_or(ExtensionState::Unloaded)
438 }
439 pub fn states(&self) -> &HashMap<String, ExtensionState> {
441 &self.states
442 }
443 pub fn extension_order(&self) -> &[String] {
445 &self.order
446 }
447
448 pub fn disable(&mut self, name: &str) -> Result<(), ExtensionError> {
450 self.registry.disable(name)?;
451 self.set_state(name, ExtensionState::Disabled);
452 Ok(())
453 }
454 pub fn enable(&mut self, name: &str, ctx: &ExtensionContext) -> Result<(), ExtensionError> {
456 self.registry.enable(name, ctx)?;
457 self.set_state(name, ExtensionState::Active);
458 Ok(())
459 }
460 pub fn is_enabled(&self, name: &str) -> bool {
462 self.registry.is_enabled(name)
463 }
464
465 pub fn has_handlers(&self, _event_type: &str) -> bool {
467 self.has_enabled_extensions()
468 }
469 pub fn has_enabled_extensions(&self) -> bool {
471 self.registry.extensions().any(|_| true)
472 && self
473 .order
474 .iter()
475 .any(|name| self.state(name) == ExtensionState::Active)
476 }
477
478 pub fn all_tools(&self) -> Vec<ToolArc> {
480 let mut tools = Vec::new();
481 for name in &self.order {
482 if self.state(name) != ExtensionState::Active {
483 continue;
484 }
485 if let Some(ext) = self.registry.get(name) {
486 tools.extend(ext.register_tools());
487 }
488 }
489 tools
490 }
491
492 pub fn all_commands(&self) -> Vec<Command> {
494 let mut commands = Vec::new();
495 for name in &self.order {
496 if self.state(name) != ExtensionState::Active {
497 continue;
498 }
499 if let Some(ext) = self.registry.get(name) {
500 commands.extend(ext.register_commands());
501 }
502 }
503 commands
504 }
505
506 pub fn wrap_tool(&self, tool: ToolArc) -> ToolArc {
508 Arc::new(WrappedTool {
509 inner: tool,
510 runner_state: Arc::new(RwLock::new(RunnerState {
511 errors: self.registry.errors.clone(),
512 error_listeners: self.error_listeners.clone(),
513 })),
514 })
515 }
516
517 pub fn wrap_tools(&self, tools: Vec<ToolArc>) -> Vec<ToolArc> {
519 tools.into_iter().map(|t| self.wrap_tool(t)).collect()
520 }
521
522 pub fn emit_tool_call(&self, tool_name: &str, params: &Value) -> ToolCallEmitResult {
524 let mut result = ToolCallEmitResult::default();
525 for name in &self.order {
526 if self.state(name) != ExtensionState::Active {
527 continue;
528 }
529 if let Some(ext) = self.registry.get(name) {
530 if let Err(e) = ext.on_before_tool_call(tool_name, params) {
531 let err_str = e.to_string();
532 result.errors.push((name.clone(), err_str.clone()));
533 self.emit_error_record(ExtensionErrorRecord::new(
534 name,
535 "on_before_tool_call",
536 &err_str,
537 ));
538 }
539 self.registry.call_hook_safe(name, "on_tool_call", || {
540 ext.on_tool_call(tool_name, params);
541 });
542 }
543 }
544 result
545 }
546
547 pub fn emit_tool_result_event(
549 &self,
550 tool_name: &str,
551 tool_result: &AgentToolResult,
552 ) -> ToolResultEmitResult {
553 let mut result = ToolResultEmitResult::default();
554 for name in &self.order {
555 if self.state(name) != ExtensionState::Active {
556 continue;
557 }
558 if let Some(ext) = self.registry.get(name) {
559 if let Err(e) = ext.on_after_tool_call(tool_name, tool_result) {
560 result.errors.push((name.clone(), e.to_string()));
561 }
562 self.registry.call_hook_safe(name, "on_tool_result", || {
563 ext.on_tool_result(tool_name, tool_result);
564 });
565 }
566 }
567 result
568 }
569
570 pub fn emit_input_event(&self, event: &mut InputEvent) -> InputEventResult {
572 let mut final_result = InputEventResult::Continue;
573 for name in &self.order {
574 if self.state(name) != ExtensionState::Active {
575 continue;
576 }
577 if let Some(ext) = self.registry.get(name) {
578 let result =
579 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| ext.input(event)));
580 match result {
581 Ok(InputEventResult::Handled) => return InputEventResult::Handled,
582 Ok(InputEventResult::Transform { text }) => {
583 event.text = text.clone();
584 final_result = InputEventResult::Transform { text };
585 }
586 Ok(InputEventResult::Continue) => {}
587 Err(payload) => {
588 let msg = if let Some(s) = payload.downcast_ref::<&str>() {
589 s.to_string()
590 } else if let Some(s) = payload.downcast_ref::<String>() {
591 s.clone()
592 } else {
593 "unknown panic".to_string()
594 };
595 self.emit_error_record(ExtensionErrorRecord::new(
596 name,
597 "input",
598 format!("panic: {}", msg),
599 ));
600 }
601 }
602 }
603 }
604 final_result
605 }
606
607 pub fn emit_context_event(&self, messages: Vec<oxi_ai::Message>) -> ContextEmitResult {
609 let mut current_messages = messages;
610 let mut errors = Vec::new();
611 let mut modified = false;
612 for name in &self.order {
613 if self.state(name) != ExtensionState::Active {
614 continue;
615 }
616 if let Some(ext) = self.registry.get(name) {
617 let prev_len = current_messages.len();
618 let mut event = ContextEvent {
619 messages: current_messages.clone(),
620 };
621 if let Err(e) = ext.context(&mut event) {
622 errors.push((name.clone(), e.to_string()));
623 } else if event.messages.len() != prev_len {
624 current_messages = event.messages;
625 modified = true;
626 }
627 }
628 }
629 ContextEmitResult {
630 modified,
631 messages: current_messages,
632 errors,
633 }
634 }
635
636 pub fn emit_before_provider_request_event(&self, payload: Value) -> ProviderRequestEmitResult {
638 let mut current_payload = payload;
639 let mut modified = false;
640 let mut errors = Vec::new();
641 for name in &self.order {
642 if self.state(name) != ExtensionState::Active {
643 continue;
644 }
645 if let Some(ext) = self.registry.get(name) {
646 let mut event = BeforeProviderRequestEvent {
647 payload: current_payload.clone(),
648 };
649 if let Err(e) = ext.before_provider_request(&mut event) {
650 errors.push((name.clone(), e.to_string()));
651 } else if event.payload != current_payload {
652 current_payload = event.payload;
653 modified = true;
654 }
655 }
656 }
657 ProviderRequestEmitResult {
658 modified,
659 payload: current_payload,
660 errors,
661 }
662 }
663
664 pub fn emit_session_before_switch_event(
666 &self,
667 event: &SessionBeforeSwitchEvent,
668 ) -> SessionBeforeEmitResult {
669 let mut result = SessionBeforeEmitResult::default();
670 for name in &self.order {
671 if self.state(name) != ExtensionState::Active {
672 continue;
673 }
674 if let Some(ext) = self.registry.get(name) {
675 if let Err(e) = ext.session_before_switch(event) {
676 result.cancelled = true;
677 result.cancelled_by = Some(name.clone());
678 result.errors.push((name.clone(), e.to_string()));
679 return result;
680 }
681 }
682 }
683 result
684 }
685
686 pub fn emit_session_before_fork_event(
688 &self,
689 event: &SessionBeforeForkEvent,
690 ) -> SessionBeforeEmitResult {
691 let mut result = SessionBeforeEmitResult::default();
692 for name in &self.order {
693 if self.state(name) != ExtensionState::Active {
694 continue;
695 }
696 if let Some(ext) = self.registry.get(name) {
697 if let Err(e) = ext.session_before_fork(event) {
698 result.cancelled = true;
699 result.cancelled_by = Some(name.clone());
700 result.errors.push((name.clone(), e.to_string()));
701 return result;
702 }
703 }
704 }
705 result
706 }
707
708 pub fn emit_session_before_compact_event(
710 &self,
711 event: &SessionBeforeCompactEvent,
712 ) -> SessionBeforeEmitResult {
713 let mut result = SessionBeforeEmitResult::default();
714 for name in &self.order {
715 if self.state(name) != ExtensionState::Active {
716 continue;
717 }
718 if let Some(ext) = self.registry.get(name) {
719 if let Err(e) = ext.session_before_compact(event) {
720 result.cancelled = true;
721 result.cancelled_by = Some(name.clone());
722 result.errors.push((name.clone(), e.to_string()));
723 return result;
724 }
725 }
726 }
727 result
728 }
729
730 pub fn emit_session_before_tree_event(
732 &self,
733 event: &SessionBeforeTreeEvent,
734 ) -> SessionBeforeEmitResult {
735 let mut result = SessionBeforeEmitResult::default();
736 for name in &self.order {
737 if self.state(name) != ExtensionState::Active {
738 continue;
739 }
740 if let Some(ext) = self.registry.get(name) {
741 if let Err(e) = ext.session_before_tree(event) {
742 result.cancelled = true;
743 result.cancelled_by = Some(name.clone());
744 result.errors.push((name.clone(), e.to_string()));
745 return result;
746 }
747 }
748 }
749 result
750 }
751
752 pub fn emit_session_shutdown_event(&self, event: &SessionShutdownEvent) -> bool {
754 if !self.has_enabled_extensions() {
755 return false;
756 }
757 self.registry.emit_session_shutdown(event);
758 true
759 }
760
761 pub fn emit_event(&self, event: &AgentEvent) {
763 self.registry.emit_event(event);
764 }
765
766 pub fn registry(&self) -> &ExtensionRegistry {
768 &self.registry
769 }
770 pub fn registry_mut(&mut self) -> &mut ExtensionRegistry {
772 &mut self.registry
773 }
774 pub fn get(&self, name: &str) -> Option<Arc<dyn Extension>> {
776 self.registry.get(name)
777 }
778 pub fn names(&self) -> impl Iterator<Item = &str> {
780 self.order.iter().map(|s| s.as_str())
781 }
782 pub fn len(&self) -> usize {
784 self.order.len()
785 }
786 pub fn is_empty(&self) -> bool {
788 self.order.is_empty()
789 }
790 pub fn errors(&self) -> Vec<ExtensionErrorRecord> {
792 self.registry.errors()
793 }
794 pub fn clear_errors(&self) {
796 self.registry.clear_errors();
797 }
798}
799
800impl fmt::Debug for ExtensionRunner {
801 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
802 f.debug_struct("ExtensionRunner")
803 .field("cwd", &self.cwd)
804 .field("extensions", &self.order)
805 .finish()
806 }
807}
808
809pub struct ExtensionErrorHandle {
811 listener: Option<Arc<ExtensionErrorListener>>,
812}
813impl ExtensionErrorHandle {
814 pub fn unregister(&mut self) -> Option<Arc<ExtensionErrorListener>> {
816 self.listener.take()
817 }
818}
819impl Drop for ExtensionErrorHandle {
820 fn drop(&mut self) {}
821}
822
823struct RunnerState {
824 errors: Arc<RwLock<Vec<ExtensionErrorRecord>>>,
825 error_listeners: Vec<Arc<ExtensionErrorListener>>,
826}
827struct WrappedTool {
828 inner: ToolArc,
829 runner_state: Arc<RwLock<RunnerState>>,
830}
831
832#[async_trait::async_trait]
833impl oxi_agent::AgentTool for WrappedTool {
834 fn name(&self) -> &str {
835 self.inner.name()
836 }
837 fn label(&self) -> &str {
838 self.inner.label()
839 }
840 fn description(&self) -> &str {
841 self.inner.description()
842 }
843 fn parameters_schema(&self) -> Value {
844 self.inner.parameters_schema()
845 }
846 async fn execute(
847 &self,
848 tool_call_id: &str,
849 params: Value,
850 signal: Option<tokio::sync::oneshot::Receiver<()>>,
851 ctx: &oxi_agent::ToolContext,
852 ) -> Result<AgentToolResult, String> {
853 self.inner.execute(tool_call_id, params, signal, ctx).await
854 }
855}