1use std::io;
7use std::sync::Arc;
8
9use tokio::runtime::Runtime;
10use tokio::sync::mpsc;
11use tokio_util::sync::CancellationToken;
12
13use crate::controller::{
14 ControllerEvent, ControllerInputPayload, Executable, LLMController, LLMSessionConfig, LLMTool,
15 ListSkillsTool, PermissionRegistry, ToolRegistry, UserInteractionRegistry,
16};
17use crate::skills::{SkillDiscovery, SkillDiscoveryError, SkillRegistry, SkillReloadResult};
18
19use super::config::{AgentConfig, LLMRegistry, load_config};
20use super::error::AgentError;
21use super::logger::Logger;
22use super::messages::UiMessage;
23use super::messages::channels::DEFAULT_CHANNEL_SIZE;
24use super::router::InputRouter;
25
26pub type ToControllerTx = mpsc::Sender<ControllerInputPayload>;
28pub type ToControllerRx = mpsc::Receiver<ControllerInputPayload>;
30pub type FromControllerTx = mpsc::Sender<UiMessage>;
32pub type FromControllerRx = mpsc::Receiver<UiMessage>;
34
/// Top-level agent: owns the Tokio runtime, the [`LLMController`], the channels
/// between the controller and the frontend, and the tool/skill registries.
///
/// Fields are dropped in declaration order, so reordering them changes teardown order.
pub struct AgentAir {
    /// Logger created in `new`; it is never read after construction.
    // NOTE(review): presumably held only so logging stays active for the agent's
    // lifetime (hence the dead_code allowance) — confirm against `Logger`.
    #[allow(dead_code)]
    logger: Logger,

    /// Agent name taken from the config; used in log messages.
    name: String,

    /// Version string returned by `version()`; set to "0.1.0" by `new`.
    version: String,

    /// Dedicated runtime on which background tasks are spawned and blocking calls run.
    runtime: Runtime,

    /// Shared controller that creates, looks up and removes sessions and owns the tool registry.
    controller: Arc<LLMController>,

    /// Configured LLM providers; `None` once taken via `take_llm_registry`.
    llm_registry: Option<LLMRegistry>,

    /// Sender for frontend input destined for the controller.
    to_controller_tx: ToControllerTx,

    /// Receiver for controller input; taken when the input router is started.
    to_controller_rx: Option<ToControllerRx>,

    /// Sender for messages going to the frontend.
    from_controller_tx: FromControllerTx,

    /// Receiver for frontend-bound messages; taken by a frontend or `run_with_frontend`.
    from_controller_rx: Option<FromControllerRx>,

    /// Token cancelled by `shutdown`; clones are handed to the input router and callers.
    cancel_token: CancellationToken,

    /// Registry for tool-initiated user-interaction requests.
    user_interaction_registry: Arc<UserInteractionRegistry>,

    /// Registry for tool permission requests.
    permission_registry: Arc<PermissionRegistry>,

    /// Tool definitions installed into every session created through this agent.
    tool_definitions: Vec<LLMTool>,

    /// Optional message exposed via `error_no_session()`.
    error_no_session: Option<String>,

    /// Registered skills; rendered into session system prompts as XML.
    skill_registry: Arc<SkillRegistry>,

    /// Search paths used by `load_skills` and `reload_skills`.
    skill_discovery: SkillDiscovery,
}
131
132impl AgentAir {
133 pub fn new<C: AgentConfig>(config: &C) -> io::Result<Self> {
143 let logger = Logger::new(config.log_prefix())?;
144 tracing::info!("{} agent initialized", config.name());
145
146 let llm_registry = load_config(config);
148 if llm_registry.is_empty() {
149 tracing::warn!(
150 "No LLM providers configured. Set ANTHROPIC_API_KEY or create ~/{}",
151 config.config_path()
152 );
153 } else {
154 tracing::info!(
155 "Loaded {} LLM provider(s): {:?}",
156 llm_registry.providers().len(),
157 llm_registry.providers()
158 );
159 }
160
161 let runtime = Runtime::new()
163 .map_err(|e| io::Error::other(format!("Failed to create runtime: {}", e)))?;
164
165 let channel_size = config.channel_buffer_size().unwrap_or(DEFAULT_CHANNEL_SIZE);
167 tracing::debug!("Using channel buffer size: {}", channel_size);
168
169 let (to_controller_tx, to_controller_rx) =
171 mpsc::channel::<ControllerInputPayload>(channel_size);
172 let (from_controller_tx, from_controller_rx) = mpsc::channel::<UiMessage>(channel_size);
173
174 let (interaction_event_tx, mut interaction_event_rx) =
176 mpsc::channel::<ControllerEvent>(channel_size);
177
178 let user_interaction_registry =
180 Arc::new(UserInteractionRegistry::new(interaction_event_tx));
181
182 let ui_tx_for_interactions = from_controller_tx.clone();
185 runtime.spawn(async move {
186 while let Some(event) = interaction_event_rx.recv().await {
187 let msg = convert_controller_event_to_ui_message(event);
188 if let Err(e) = ui_tx_for_interactions.send(msg).await {
189 tracing::warn!("Failed to send user interaction event to UI: {}", e);
190 }
191 }
192 });
193
194 let (permission_event_tx, mut permission_event_rx) =
196 mpsc::channel::<ControllerEvent>(channel_size);
197
198 let permission_registry = Arc::new(PermissionRegistry::new(permission_event_tx));
200
201 let ui_tx_for_permissions = from_controller_tx.clone();
204 runtime.spawn(async move {
205 while let Some(event) = permission_event_rx.recv().await {
206 let msg = convert_controller_event_to_ui_message(event);
207 if let Err(e) = ui_tx_for_permissions.send(msg).await {
208 tracing::warn!("Failed to send permission event to UI: {}", e);
209 }
210 }
211 });
212
213 let controller = Arc::new(LLMController::new(
218 permission_registry.clone(),
219 Some(from_controller_tx.clone()),
220 Some(channel_size),
221 ));
222 let cancel_token = CancellationToken::new();
223
224 Ok(Self {
225 logger,
226 name: config.name().to_string(),
227 version: "0.1.0".to_string(),
228 runtime,
229 controller,
230 llm_registry: Some(llm_registry),
231 to_controller_tx,
232 to_controller_rx: Some(to_controller_rx),
233 from_controller_tx,
234 from_controller_rx: Some(from_controller_rx),
235 cancel_token,
236 user_interaction_registry,
237 permission_registry,
238 tool_definitions: Vec::new(),
239 error_no_session: None,
240 skill_registry: Arc::new(SkillRegistry::new()),
241 skill_discovery: SkillDiscovery::new(),
242 })
243 }
244
245 pub fn with_config(
266 name: impl Into<String>,
267 config_path: impl Into<String>,
268 system_prompt: impl Into<String>,
269 ) -> io::Result<Self> {
270 let config = super::config::SimpleConfig::new(name, config_path, system_prompt);
271 Self::new(&config)
272 }
273
274 pub fn set_error_no_session(&mut self, message: impl Into<String>) -> &mut Self {
284 self.error_no_session = Some(message.into());
285 self
286 }
287
288 pub fn error_no_session(&self) -> Option<&str> {
290 self.error_no_session.as_deref()
291 }
292
293 pub fn set_version(&mut self, version: impl Into<String>) {
295 self.version = version.into();
296 }
297
298 pub fn version(&self) -> &str {
300 &self.version
301 }
302
303 pub fn load_environment_context(&mut self) -> &mut Self {
321 if let Some(registry) = self.llm_registry.take() {
322 self.llm_registry = Some(registry.with_environment_context());
323 tracing::info!("Environment context loaded into system prompt");
324 }
325 self
326 }
327
328 pub fn register_tools<F>(&mut self, f: F) -> Result<(), AgentError>
341 where
342 F: FnOnce(
343 &Arc<ToolRegistry>,
344 &Arc<UserInteractionRegistry>,
345 &Arc<PermissionRegistry>,
346 ) -> Result<Vec<LLMTool>, String>,
347 {
348 let tool_defs = f(
349 self.controller.tool_registry(),
350 &self.user_interaction_registry,
351 &self.permission_registry,
352 )
353 .map_err(AgentError::ToolRegistration)?;
354 self.tool_definitions = tool_defs;
355 Ok(())
356 }
357
358 pub fn register_tools_async<F, Fut>(&mut self, f: F) -> Result<(), AgentError>
371 where
372 F: FnOnce(Arc<ToolRegistry>, Arc<UserInteractionRegistry>, Arc<PermissionRegistry>) -> Fut,
373 Fut: std::future::Future<Output = Result<Vec<LLMTool>, String>>,
374 {
375 let tool_defs = self
376 .runtime
377 .block_on(f(
378 self.controller.tool_registry().clone(),
379 self.user_interaction_registry.clone(),
380 self.permission_registry.clone(),
381 ))
382 .map_err(AgentError::ToolRegistration)?;
383 self.tool_definitions = tool_defs;
384 Ok(())
385 }
386
387 pub fn start_background_tasks(&mut self) {
392 tracing::info!("{} starting background tasks", self.name);
393
394 let controller = self.controller.clone();
396 self.runtime.spawn(async move {
397 controller.start().await;
398 });
399 tracing::info!("Controller started");
400
401 if let Some(to_controller_rx) = self.to_controller_rx.take() {
403 let router = InputRouter::new(
404 self.controller.clone(),
405 to_controller_rx,
406 self.cancel_token.clone(),
407 );
408 self.runtime.spawn(async move {
409 router.run().await;
410 });
411 tracing::info!("InputRouter started");
412 }
413 }
414
415 async fn create_session_internal(
417 controller: &Arc<LLMController>,
418 mut config: LLMSessionConfig,
419 tools: &[LLMTool],
420 skill_registry: &Arc<SkillRegistry>,
421 ) -> Result<i64, crate::client::error::LlmError> {
422 let skills_xml = skill_registry.to_prompt_xml();
424 if !skills_xml.is_empty() {
425 config.system_prompt = Some(match config.system_prompt {
426 Some(prompt) => format!("{}\n\n{}", prompt, skills_xml),
427 None => skills_xml,
428 });
429 }
430
431 let id = controller.create_session(config).await?;
432
433 if !tools.is_empty()
435 && let Some(session) = controller.get_session(id).await
436 {
437 session.set_tools(tools.to_vec()).await;
438 }
439
440 Ok(id)
441 }
442
443 pub fn create_initial_session(&mut self) -> Result<(i64, String, i32), AgentError> {
447 let registry = self
448 .llm_registry
449 .as_ref()
450 .ok_or_else(|| AgentError::NoConfiguration("No LLM registry available".to_string()))?;
451
452 let config = registry.get_default().ok_or_else(|| {
453 AgentError::NoConfiguration("No default LLM provider configured".to_string())
454 })?;
455
456 let model = config.model.clone();
457 let context_limit = config.context_limit;
458
459 let controller = self.controller.clone();
460 let tool_definitions = self.tool_definitions.clone();
461 let skill_registry = self.skill_registry.clone();
462
463 let session_id = self.runtime.block_on(Self::create_session_internal(
464 &controller,
465 config.clone(),
466 &tool_definitions,
467 &skill_registry,
468 ))?;
469
470 tracing::info!(
471 session_id = session_id,
472 model = %model,
473 "Created initial session"
474 );
475
476 Ok((session_id, model, context_limit))
477 }
478
479 pub fn create_session(&self, config: LLMSessionConfig) -> Result<i64, AgentError> {
483 let controller = self.controller.clone();
484 let tool_definitions = self.tool_definitions.clone();
485 let skill_registry = self.skill_registry.clone();
486
487 self.runtime
488 .block_on(Self::create_session_internal(
489 &controller,
490 config,
491 &tool_definitions,
492 &skill_registry,
493 ))
494 .map_err(AgentError::from)
495 }
496
497 pub fn shutdown(&self) {
499 tracing::info!("{} shutting down", self.name);
500 self.cancel_token.cancel();
501
502 let controller = self.controller.clone();
503 self.runtime.block_on(async move {
504 controller.shutdown().await;
505 });
506
507 tracing::info!("{} shutdown complete", self.name);
508 }
509
510 pub fn run_with_frontend<E, I, P>(
551 &mut self,
552 event_sink: E,
553 mut input_source: I,
554 permission_policy: P,
555 ) -> io::Result<()>
556 where
557 E: super::interface::EventSink,
558 I: super::interface::InputSource,
559 P: super::interface::PermissionPolicy,
560 {
561 use super::interface::PolicyDecision;
562 use crate::permissions::{BatchPermissionResponse, PermissionPanelResponse};
563 use std::sync::Arc;
564
565 tracing::info!("{} starting with custom frontend", self.name);
566
567 let sink = Arc::new(event_sink);
569 let policy = Arc::new(permission_policy);
570
571 let controller = self.controller.clone();
574 self.runtime.spawn(async move {
575 controller.start().await;
576 });
577 tracing::info!("Controller started");
578
579 if let Some(mut from_controller_rx) = self.from_controller_rx.take() {
582 let sink_clone = sink.clone();
583 let policy_clone = policy.clone();
584 let permission_registry = self.permission_registry.clone();
585 let user_interaction_registry = self.user_interaction_registry.clone();
586
587 self.runtime.spawn(async move {
588 while let Some(event) = from_controller_rx.recv().await {
589 match &event {
591 UiMessage::PermissionRequired {
592 tool_use_id,
593 request,
594 ..
595 } => {
596 match policy_clone.decide(request) {
597 PolicyDecision::AskUser => {
598 }
600 decision => {
601 let response = match decision {
602 PolicyDecision::Allow => PermissionPanelResponse {
603 granted: true,
604 grant: None,
605 message: None,
606 },
607 PolicyDecision::AllowWithGrant(grant) => {
608 PermissionPanelResponse {
609 granted: true,
610 grant: Some(grant),
611 message: None,
612 }
613 }
614 PolicyDecision::Deny { reason } => {
615 PermissionPanelResponse {
616 granted: false,
617 grant: None,
618 message: reason,
619 }
620 }
621 PolicyDecision::AskUser => unreachable!(),
622 };
623 if let Err(e) = permission_registry
624 .respond_to_request(tool_use_id, response)
625 .await
626 {
627 tracing::warn!(
628 "Failed to respond to permission request: {}",
629 e
630 );
631 }
632 continue; }
634 }
635 }
636 UiMessage::BatchPermissionRequired { batch, .. } => {
637 let mut all_handled = true;
639 let mut approved_grants = Vec::new();
640 let mut denied_ids = Vec::new();
641
642 for request in &batch.requests {
643 match policy_clone.decide(request) {
644 PolicyDecision::Allow => {
645 }
647 PolicyDecision::AllowWithGrant(grant) => {
648 approved_grants.push(grant);
649 }
650 PolicyDecision::Deny { .. } => {
651 denied_ids.push(request.id.clone());
652 }
653 PolicyDecision::AskUser => {
654 all_handled = false;
655 break;
656 }
657 }
658 }
659
660 if all_handled {
661 let response = if denied_ids.is_empty() {
663 BatchPermissionResponse::all_granted(
664 &batch.batch_id,
665 approved_grants,
666 )
667 } else {
668 BatchPermissionResponse::all_denied(&batch.batch_id, denied_ids)
669 };
670 if let Err(e) = permission_registry
671 .respond_to_batch(&batch.batch_id, response)
672 .await
673 {
674 tracing::warn!(
675 "Failed to respond to batch permission request: {}",
676 e
677 );
678 }
679 continue; }
681 }
683 UiMessage::UserInteractionRequired { tool_use_id, .. } => {
684 if !policy_clone.supports_interaction() {
685 if let Err(e) = user_interaction_registry.cancel(tool_use_id).await
687 {
688 tracing::warn!("Failed to cancel user interaction: {}", e);
689 }
690 tracing::debug!("Auto-cancelled user interaction in headless mode");
691 continue; }
693 }
695 _ => {}
696 }
697
698 if let Err(e) = sink_clone.send(event) {
700 tracing::warn!("Failed to send event to sink: {}", e);
701 }
702 }
703 });
704 }
705
706 match self.create_initial_session() {
708 Ok((session_id, model, _)) => {
709 tracing::info!(session_id, model = %model, "Created initial session");
710 }
711 Err(e) => {
712 tracing::warn!(error = %e, "No initial session created");
713 }
714 }
715
716 let to_controller_tx = self.to_controller_tx.clone();
718 self.runtime.block_on(async {
719 while let Some(input) = input_source.recv().await {
720 if let Err(e) = to_controller_tx.send(input).await {
721 tracing::error!(error = %e, "Failed to send input to controller");
722 break;
723 }
724 }
725 });
726
727 self.shutdown();
729 tracing::info!("{} stopped", self.name);
730
731 Ok(())
732 }
733
734 pub fn to_controller_tx(&self) -> ToControllerTx {
738 self.to_controller_tx.clone()
739 }
740
741 pub fn take_from_controller_rx(&mut self) -> Option<FromControllerRx> {
743 self.from_controller_rx.take()
744 }
745
746 pub fn controller(&self) -> &Arc<LLMController> {
748 &self.controller
749 }
750
751 pub fn runtime(&self) -> &Runtime {
753 &self.runtime
754 }
755
756 pub fn runtime_handle(&self) -> tokio::runtime::Handle {
758 self.runtime.handle().clone()
759 }
760
761 pub fn user_interaction_registry(&self) -> &Arc<UserInteractionRegistry> {
763 &self.user_interaction_registry
764 }
765
766 pub fn permission_registry(&self) -> &Arc<PermissionRegistry> {
768 &self.permission_registry
769 }
770
771 pub async fn remove_session(&self, session_id: i64) -> bool {
785 let removed = self.controller.remove_session(session_id).await;
787
788 self.permission_registry.cancel_session(session_id).await;
790
791 self.user_interaction_registry
793 .cancel_session(session_id)
794 .await;
795
796 self.controller
798 .tool_registry()
799 .cleanup_session(session_id)
800 .await;
801
802 if removed {
803 tracing::info!(session_id, "Session removed with full cleanup");
804 }
805
806 removed
807 }
808
809 pub fn llm_registry(&self) -> Option<&LLMRegistry> {
811 self.llm_registry.as_ref()
812 }
813
814 pub fn take_llm_registry(&mut self) -> Option<LLMRegistry> {
816 self.llm_registry.take()
817 }
818
819 pub fn cancel_token(&self) -> CancellationToken {
821 self.cancel_token.clone()
822 }
823
824 pub fn name(&self) -> &str {
826 &self.name
827 }
828
829 pub fn from_controller_tx(&self) -> FromControllerTx {
833 self.from_controller_tx.clone()
834 }
835
836 pub fn tool_definitions(&self) -> &[LLMTool] {
838 &self.tool_definitions
839 }
840
841 pub fn skill_registry(&self) -> &Arc<SkillRegistry> {
845 &self.skill_registry
846 }
847
848 pub fn register_list_skills_tool(&mut self) -> Result<LLMTool, AgentError> {
856 let tool = ListSkillsTool::new(self.skill_registry.clone());
857 let llm_tool = tool.to_llm_tool();
858
859 self.runtime
860 .block_on(async {
861 self.controller
862 .tool_registry()
863 .register(Arc::new(tool))
864 .await
865 })
866 .map_err(|e| AgentError::ToolRegistration(e.to_string()))?;
867
868 self.tool_definitions.push(llm_tool.clone());
869 tracing::info!("Registered list_skills tool");
870
871 Ok(llm_tool)
872 }
873
874 pub fn add_skill_path(&mut self, path: std::path::PathBuf) -> &mut Self {
879 self.skill_discovery.add_path(path);
880 self
881 }
882
883 pub fn load_skills(&mut self) -> (usize, Vec<SkillDiscoveryError>) {
890 let results = self.skill_discovery.discover();
891 self.register_discovered_skills(results)
892 }
893
894 pub fn load_skills_from(
903 &self,
904 paths: Vec<std::path::PathBuf>,
905 ) -> (usize, Vec<SkillDiscoveryError>) {
906 let mut discovery = SkillDiscovery::empty();
907 for path in paths {
908 discovery.add_path(path);
909 }
910
911 let results = discovery.discover();
912 self.register_discovered_skills(results)
913 }
914
915 fn register_discovered_skills(
919 &self,
920 results: Vec<Result<crate::skills::Skill, SkillDiscoveryError>>,
921 ) -> (usize, Vec<SkillDiscoveryError>) {
922 let mut errors = Vec::new();
923 let mut count = 0;
924
925 for result in results {
926 match result {
927 Ok(skill) => {
928 let skill_name = skill.metadata.name.clone();
929 let skill_path = skill.path.clone();
930 let replaced = self.skill_registry.register(skill);
931
932 if let Some(old_skill) = replaced {
933 tracing::warn!(
934 skill_name = %skill_name,
935 new_path = %skill_path.display(),
936 old_path = %old_skill.path.display(),
937 "Duplicate skill name detected - replaced existing skill"
938 );
939 }
940
941 tracing::info!(
942 skill_name = %skill_name,
943 skill_path = %skill_path.display(),
944 "Loaded skill"
945 );
946 count += 1;
947 }
948 Err(e) => {
949 tracing::warn!(
950 path = %e.path.display(),
951 error = %e.message,
952 "Failed to load skill"
953 );
954 errors.push(e);
955 }
956 }
957 }
958
959 tracing::info!("Loaded {} skill(s)", count);
960 (count, errors)
961 }
962
963 pub fn reload_skills(&mut self) -> SkillReloadResult {
972 let current_names: std::collections::HashSet<String> =
973 self.skill_registry.names().into_iter().collect();
974
975 let results = self.skill_discovery.discover();
976 let mut discovered_names = std::collections::HashSet::new();
977 let mut result = SkillReloadResult::default();
978
979 for discovery_result in results {
981 match discovery_result {
982 Ok(skill) => {
983 let name = skill.metadata.name.clone();
984 discovered_names.insert(name.clone());
985
986 if !current_names.contains(&name) {
987 tracing::info!(skill_name = %name, "Added new skill");
988 result.added.push(name);
989 }
990 self.skill_registry.register(skill);
991 }
992 Err(e) => {
993 tracing::warn!(
994 path = %e.path.display(),
995 error = %e.message,
996 "Failed to load skill during reload"
997 );
998 result.errors.push(e);
999 }
1000 }
1001 }
1002
1003 for name in ¤t_names {
1005 if !discovered_names.contains(name) {
1006 tracing::info!(skill_name = %name, "Removed skill");
1007 self.skill_registry.unregister(name);
1008 result.removed.push(name.clone());
1009 }
1010 }
1011
1012 tracing::info!(
1013 added = result.added.len(),
1014 removed = result.removed.len(),
1015 errors = result.errors.len(),
1016 "Skills reloaded"
1017 );
1018
1019 result
1020 }
1021
1022 pub fn skills_prompt_xml(&self) -> String {
1027 self.skill_registry.to_prompt_xml()
1028 }
1029
1030 pub async fn refresh_session_skills(&self, session_id: i64) -> Result<(), AgentError> {
1038 let skills_xml = self.skills_prompt_xml();
1039 if skills_xml.is_empty() {
1040 return Ok(());
1041 }
1042
1043 let session = self
1044 .controller
1045 .get_session(session_id)
1046 .await
1047 .ok_or(AgentError::SessionNotFound(session_id))?;
1048
1049 let current_prompt = session.system_prompt().await.unwrap_or_default();
1050
1051 let new_prompt = if current_prompt.contains("<available_skills>") {
1053 replace_skills_section(¤t_prompt, &skills_xml)
1055 } else if current_prompt.is_empty() {
1056 skills_xml
1058 } else {
1059 format!("{}\n\n{}", current_prompt, skills_xml)
1061 };
1062
1063 session.set_system_prompt(new_prompt).await;
1064 tracing::debug!(session_id, "Refreshed session skills");
1065 Ok(())
1066 }
1067}
1068
/// Replaces the first `<available_skills>…</available_skills>` section of `prompt`
/// with `new_skills_xml`. The text before and after the section is kept unchanged.
///
/// The closing tag is searched for only after the opening tag. This means a stray
/// `</available_skills>` earlier in the prompt cannot produce an inverted range,
/// which would duplicate text and keep the old section.
///
/// If the prompt has no opening tag, or the section is never closed, the prompt is
/// kept intact and `new_skills_xml` is appended after a blank line instead.
fn replace_skills_section(prompt: &str, new_skills_xml: &str) -> String {
    const OPEN_TAG: &str = "<available_skills>";
    const CLOSE_TAG: &str = "</available_skills>";

    // (start, end) byte range of the whole section, including both tags.
    let section = prompt.find(OPEN_TAG).and_then(|start| {
        prompt[start..]
            .find(CLOSE_TAG)
            .map(|offset| (start, start + offset + CLOSE_TAG.len()))
    });

    match section {
        Some((start, end)) => {
            let mut result =
                String::with_capacity(prompt.len() - (end - start) + new_skills_xml.len());
            result.push_str(&prompt[..start]);
            result.push_str(new_skills_xml);
            result.push_str(&prompt[end..]);
            result
        }
        None => format!("{}\n\n{}", prompt, new_skills_xml),
    }
}
1084
1085pub fn convert_controller_event_to_ui_message(event: ControllerEvent) -> UiMessage {
1102 match event {
1103 ControllerEvent::StreamStart { session_id, .. } => {
1104 UiMessage::System {
1106 session_id,
1107 message: String::new(),
1108 }
1109 }
1110 ControllerEvent::TextChunk {
1111 session_id,
1112 text,
1113 turn_id,
1114 } => UiMessage::TextChunk {
1115 session_id,
1116 turn_id,
1117 text,
1118 input_tokens: 0,
1119 output_tokens: 0,
1120 },
1121 ControllerEvent::ToolUseStart {
1122 session_id,
1123 tool_name,
1124 turn_id,
1125 ..
1126 } => UiMessage::Display {
1127 session_id,
1128 turn_id,
1129 message: format!("Executing tool: {}", tool_name),
1130 },
1131 ControllerEvent::ToolUse {
1132 session_id,
1133 tool,
1134 display_name,
1135 display_title,
1136 turn_id,
1137 } => UiMessage::ToolExecuting {
1138 session_id,
1139 turn_id,
1140 tool_use_id: tool.id.clone(),
1141 display_name: display_name.unwrap_or_else(|| tool.name.clone()),
1142 display_title: display_title.unwrap_or_default(),
1143 },
1144 ControllerEvent::Complete {
1145 session_id,
1146 turn_id,
1147 stop_reason,
1148 } => UiMessage::Complete {
1149 session_id,
1150 turn_id,
1151 input_tokens: 0,
1152 output_tokens: 0,
1153 stop_reason,
1154 },
1155 ControllerEvent::Error {
1156 session_id,
1157 error,
1158 turn_id,
1159 } => UiMessage::Error {
1160 session_id,
1161 turn_id,
1162 error,
1163 },
1164 ControllerEvent::TokenUpdate {
1165 session_id,
1166 input_tokens,
1167 output_tokens,
1168 context_limit,
1169 } => UiMessage::TokenUpdate {
1170 session_id,
1171 turn_id: None,
1172 input_tokens,
1173 output_tokens,
1174 context_limit,
1175 },
1176 ControllerEvent::ToolResult {
1177 session_id,
1178 tool_use_id,
1179 status,
1180 error,
1181 turn_id,
1182 ..
1183 } => UiMessage::ToolCompleted {
1184 session_id,
1185 turn_id,
1186 tool_use_id,
1187 status,
1188 error,
1189 },
1190 ControllerEvent::CommandComplete {
1191 session_id,
1192 command,
1193 success,
1194 message,
1195 } => UiMessage::CommandComplete {
1196 session_id,
1197 command,
1198 success,
1199 message,
1200 },
1201 ControllerEvent::UserInteractionRequired {
1202 session_id,
1203 tool_use_id,
1204 request,
1205 turn_id,
1206 } => UiMessage::UserInteractionRequired {
1207 session_id,
1208 tool_use_id,
1209 request,
1210 turn_id,
1211 },
1212 ControllerEvent::PermissionRequired {
1213 session_id,
1214 tool_use_id,
1215 request,
1216 turn_id,
1217 } => UiMessage::PermissionRequired {
1218 session_id,
1219 tool_use_id,
1220 request,
1221 turn_id,
1222 },
1223 ControllerEvent::BatchPermissionRequired {
1224 session_id,
1225 batch,
1226 turn_id,
1227 } => UiMessage::BatchPermissionRequired {
1228 session_id,
1229 batch,
1230 turn_id,
1231 },
1232 }
1233}
1234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::TurnId;

    /// Replacement skills block shared by the `replace_skills_section` tests.
    const NEW_SKILLS_XML: &str = "<available_skills>\n <skill>new</skill>\n</available_skills>";

    #[test]
    fn test_convert_text_chunk_event() {
        let converted = convert_controller_event_to_ui_message(ControllerEvent::TextChunk {
            session_id: 1,
            text: "Hello".to_string(),
            turn_id: Some(TurnId::new_user_turn(1)),
        });

        let UiMessage::TextChunk {
            session_id, text, ..
        } = converted
        else {
            panic!("Expected TextChunk message");
        };
        assert_eq!(session_id, 1);
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_convert_error_event() {
        let converted = convert_controller_event_to_ui_message(ControllerEvent::Error {
            session_id: 1,
            error: "Test error".to_string(),
            turn_id: None,
        });

        let UiMessage::Error {
            session_id, error, ..
        } = converted
        else {
            panic!("Expected Error message");
        };
        assert_eq!(session_id, 1);
        assert_eq!(error, "Test error");
    }

    #[test]
    fn test_replace_skills_section_replaces_existing() {
        let prompt = "System prompt.\n\n<available_skills>\n <skill>old</skill>\n</available_skills>\n\nMore text.";
        let replaced = replace_skills_section(prompt, NEW_SKILLS_XML);

        for expected in ["<skill>new</skill>", "System prompt.", "More text."] {
            assert!(replaced.contains(expected), "missing {expected:?} in {replaced:?}");
        }
        assert!(!replaced.contains("<skill>old</skill>"));
    }

    #[test]
    fn test_replace_skills_section_no_existing() {
        let replaced = replace_skills_section("System prompt without skills.", NEW_SKILLS_XML);

        for expected in ["System prompt without skills.", "<skill>new</skill>"] {
            assert!(replaced.contains(expected), "missing {expected:?} in {replaced:?}");
        }
    }

    #[test]
    fn test_replace_skills_section_malformed_no_closing_tag() {
        let prompt =
            "System prompt.\n\n<available_skills>\n <skill>old</skill>\n\nNo closing tag.";
        let replaced = replace_skills_section(prompt, NEW_SKILLS_XML);

        // Without a closing tag the original text survives and the new block is appended.
        for expected in ["<skill>old</skill>", "<skill>new</skill>"] {
            assert!(replaced.contains(expected), "missing {expected:?} in {replaced:?}");
        }
    }

    #[test]
    fn test_replace_skills_section_at_end() {
        let prompt =
            "System prompt.\n\n<available_skills>\n <skill>old</skill>\n</available_skills>";
        let replaced = replace_skills_section(prompt, NEW_SKILLS_XML);

        assert!(replaced.starts_with("System prompt."));
        assert!(replaced.contains("<skill>new</skill>"));
        assert!(!replaced.contains("<skill>old</skill>"));
    }
}
1332}