1use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::Duration;
6
7use tokio::sync::{mpsc, Mutex};
8use tokio_util::sync::CancellationToken;
9
10use std::sync::Arc;
11
12use crate::client::error::LlmError;
13use crate::controller::session::{LLMSession, LLMSessionConfig, LLMSessionManager};
14use crate::controller::tools::{
15 AskUserQuestionsResponse, PendingPermissionInfo, PendingQuestionInfo, PermissionError,
16 PermissionRegistry, PermissionResponse, ToolBatchResult, ToolExecutor, ToolRegistry,
17 ToolRequest, ToolResult, UserInteractionError, UserInteractionRegistry,
18};
19use crate::controller::error::ControllerError;
20use crate::controller::types::{
21 ControlCmd, ControllerEvent, ControllerInputPayload, FromLLMPayload, InputType,
22 LLMRequestType, LLMResponseType, ToLLMPayload, TurnId,
23};
24use crate::controller::usage::TokenUsageTracker;
25
26pub type EventFunc = Box<dyn Fn(ControllerEvent) + Send + Sync>;
28
/// Capacity used for each bounded `mpsc` channel created in `LLMController::new`.
pub const DEFAULT_CHANNEL_SIZE: usize = 100;

/// Upper bound on how long `LLMController::send_input` waits for room in the
/// input channel before giving up (five seconds).
const SEND_INPUT_TIMEOUT: Duration = Duration::from_millis(5_000);
34
/// Coordinates LLM sessions, tool execution and the events reported to the caller.
///
/// All work is delivered over internal `mpsc` channels and consumed by the
/// event loop in `LLMController::start`; outcomes are reported through the
/// optional `event_func` callback.
pub struct LLMController {
    /// Manages the sessions created through `create_session`.
    session_mgr: LLMSessionManager,

    /// Token usage accumulated from `TokenUpdate` responses.
    token_usage: TokenUsageTracker,

    /// Receives responses from sessions; drained by the event loop.
    from_llm_rx: Mutex<mpsc::Receiver<FromLLMPayload>>,

    /// Sender half of `from_llm_rx`; a clone is handed to every new session.
    from_llm_tx: mpsc::Sender<FromLLMPayload>,

    /// Receives inputs queued through `send_input`; drained by the event loop.
    input_rx: Mutex<mpsc::Receiver<ControllerInputPayload>>,

    /// Sender half of `input_rx`, used by `send_input`.
    input_tx: mpsc::Sender<ControllerInputPayload>,

    /// Set once `start` has begun; prevents a second event loop.
    started: AtomicBool,

    /// Set once `shutdown` has run; makes `shutdown` idempotent and rejects new input.
    shutdown: AtomicBool,

    /// Cancelled by `shutdown` to stop the event loop; also passed to tool executions.
    cancel_token: CancellationToken,

    /// Optional callback that receives every emitted `ControllerEvent`.
    event_func: Option<EventFunc>,

    /// Registered tools; shared with `tool_executor`.
    tool_registry: Arc<ToolRegistry>,

    /// Executes tool requests and reports results on the channels below.
    tool_executor: ToolExecutor,

    /// Receives individual `ToolResult`s from the executor; forwarded as events.
    tool_result_rx: Mutex<mpsc::Receiver<ToolResult>>,

    /// Receives completed tool batches from the executor; sent back to the session.
    batch_result_rx: Mutex<mpsc::Receiver<ToolBatchResult>>,

    /// Tracks pending user questions raised by tools.
    user_interaction_registry: Arc<UserInteractionRegistry>,

    /// Receives events produced by `user_interaction_registry`; forwarded to `event_func`.
    user_interaction_rx: Mutex<mpsc::Receiver<ControllerEvent>>,

    /// Tracks pending permission requests raised by tools.
    permission_registry: Arc<PermissionRegistry>,

    /// Receives events produced by `permission_registry`; forwarded to `event_func`.
    permission_rx: Mutex<mpsc::Receiver<ControllerEvent>>,
}
91
92impl LLMController {
93 pub fn new(event_func: Option<EventFunc>) -> Self {
98 let (from_llm_tx, from_llm_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
99 let (input_tx, input_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
100
101 let (tool_result_tx, tool_result_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
103 let (batch_result_tx, batch_result_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
104
105 let (user_interaction_tx, user_interaction_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
107 let user_interaction_registry = Arc::new(UserInteractionRegistry::new(user_interaction_tx));
108
109 let (permission_tx, permission_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
111 let permission_registry = Arc::new(PermissionRegistry::new(permission_tx));
112
113 let tool_registry = Arc::new(ToolRegistry::new());
114 let tool_executor = ToolExecutor::new(
115 tool_registry.clone(),
116 tool_result_tx,
117 batch_result_tx,
118 );
119
120 Self {
121 session_mgr: LLMSessionManager::new(),
122 token_usage: TokenUsageTracker::new(),
123 from_llm_rx: Mutex::new(from_llm_rx),
124 from_llm_tx,
125 input_rx: Mutex::new(input_rx),
126 input_tx,
127 started: AtomicBool::new(false),
128 shutdown: AtomicBool::new(false),
129 cancel_token: CancellationToken::new(),
130 event_func,
131 tool_registry,
132 tool_executor,
133 tool_result_rx: Mutex::new(tool_result_rx),
134 batch_result_rx: Mutex::new(batch_result_rx),
135 user_interaction_registry,
136 user_interaction_rx: Mutex::new(user_interaction_rx),
137 permission_registry,
138 permission_rx: Mutex::new(permission_rx),
139 }
140 }
141
142 pub async fn start(&self) {
146 if self
148 .started
149 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
150 .is_err()
151 {
152 tracing::warn!("Controller already started");
153 return;
154 }
155
156 tracing::info!("Controller starting");
157
158 loop {
183 let mut from_llm_guard = self.from_llm_rx.lock().await;
184 let mut input_guard = self.input_rx.lock().await;
185 let mut batch_result_guard = self.batch_result_rx.lock().await;
186 let mut tool_result_guard = self.tool_result_rx.lock().await;
187 let mut user_interaction_guard = self.user_interaction_rx.lock().await;
188 let mut permission_guard = self.permission_rx.lock().await;
189
190 tokio::select! {
191 _ = self.cancel_token.cancelled() => {
192 tracing::info!("Controller cancelled");
193 break;
194 }
195 msg = from_llm_guard.recv() => {
196 drop(from_llm_guard);
197 drop(input_guard);
198 drop(batch_result_guard);
199 drop(tool_result_guard);
200 drop(user_interaction_guard);
201 drop(permission_guard);
202 if let Some(payload) = msg {
203 self.handle_llm_response(payload).await;
204 } else {
205 tracing::info!("FromLLM channel closed");
206 break;
207 }
208 }
209 msg = input_guard.recv() => {
210 drop(from_llm_guard);
211 drop(input_guard);
212 drop(batch_result_guard);
213 drop(tool_result_guard);
214 drop(user_interaction_guard);
215 drop(permission_guard);
216 if let Some(payload) = msg {
217 self.handle_input(payload).await;
218 } else {
219 tracing::info!("Input channel closed");
220 break;
221 }
222 }
223 batch_result = batch_result_guard.recv() => {
224 drop(from_llm_guard);
225 drop(input_guard);
226 drop(batch_result_guard);
227 drop(tool_result_guard);
228 drop(user_interaction_guard);
229 drop(permission_guard);
230 if let Some(result) = batch_result {
231 self.handle_tool_batch_result(result).await;
232 }
233 }
234 tool_result = tool_result_guard.recv() => {
235 drop(from_llm_guard);
236 drop(input_guard);
237 drop(batch_result_guard);
238 drop(tool_result_guard);
239 drop(user_interaction_guard);
240 drop(permission_guard);
241 if let Some(result) = tool_result {
242 if let Some(ref func) = self.event_func {
244 func(ControllerEvent::ToolResult {
245 session_id: result.session_id,
246 tool_use_id: result.tool_use_id,
247 tool_name: result.tool_name,
248 display_name: result.display_name,
249 status: result.status,
250 content: result.content,
251 error: result.error,
252 turn_id: result.turn_id,
253 });
254 }
255 }
256 }
257 user_interaction_event = user_interaction_guard.recv() => {
258 drop(from_llm_guard);
259 drop(input_guard);
260 drop(batch_result_guard);
261 drop(tool_result_guard);
262 drop(user_interaction_guard);
263 drop(permission_guard);
264 if let Some(event) = user_interaction_event {
265 if let Some(ref func) = self.event_func {
267 func(event);
268 }
269 }
270 }
271 permission_event = permission_guard.recv() => {
272 drop(from_llm_guard);
273 drop(input_guard);
274 drop(batch_result_guard);
275 drop(tool_result_guard);
276 drop(user_interaction_guard);
277 drop(permission_guard);
278 if let Some(event) = permission_event {
279 if let Some(ref func) = self.event_func {
281 func(event);
282 }
283 }
284 }
285 }
286 }
287
288 tracing::info!("Controller stopped");
289 }
290
291 async fn handle_llm_response(&self, payload: FromLLMPayload) {
293 if payload.response_type == LLMResponseType::TokenUpdate {
295 if let Some(session) = self.session_mgr.get_session_by_id(payload.session_id).await {
296 self.token_usage
297 .increment(
298 payload.session_id,
299 session.model(),
300 payload.input_tokens,
301 payload.output_tokens,
302 )
303 .await;
304 }
305 }
306
307 let event = match payload.response_type {
308 LLMResponseType::StreamStart => Some(ControllerEvent::StreamStart {
309 session_id: payload.session_id,
310 message_id: payload.message_id,
311 model: payload.model,
312 turn_id: payload.turn_id,
313 }),
314 LLMResponseType::TextChunk => Some(ControllerEvent::TextChunk {
315 session_id: payload.session_id,
316 text: payload.text,
317 turn_id: payload.turn_id,
318 }),
319 LLMResponseType::ToolUseStart => {
320 if let Some(tool) = payload.tool_use {
321 Some(ControllerEvent::ToolUseStart {
322 session_id: payload.session_id,
323 tool_id: tool.id,
324 tool_name: tool.name,
325 turn_id: payload.turn_id,
326 })
327 } else {
328 None
329 }
330 }
331 LLMResponseType::ToolInputDelta => {
332 None
335 }
336 LLMResponseType::ToolUse => {
337 if let Some(ref tool) = payload.tool_use {
338 let input: HashMap<String, serde_json::Value> = tool
340 .input
341 .as_object()
342 .map(|obj| {
343 obj.iter()
344 .map(|(k, v)| (k.clone(), v.clone()))
345 .collect()
346 })
347 .unwrap_or_default();
348
349 let (display_name, display_title) =
351 if let Some(t) = self.tool_registry().get(&tool.name).await {
352 let config = t.display_config();
353 (Some(config.display_name), Some((config.display_title)(&input)))
354 } else {
355 (None, None)
356 };
357
358 let request = ToolRequest {
359 tool_use_id: tool.id.clone(),
360 tool_name: tool.name.clone(),
361 input,
362 };
363
364 self.tool_executor
365 .execute(
366 payload.session_id,
367 payload.turn_id.clone(),
368 request,
369 self.cancel_token.clone(),
370 )
371 .await;
372
373 Some(ControllerEvent::ToolUse {
374 session_id: payload.session_id,
375 tool: payload.tool_use.unwrap(),
376 display_name,
377 display_title,
378 turn_id: payload.turn_id,
379 })
380 } else {
381 None
382 }
383 }
384 LLMResponseType::ToolBatch => {
385 if payload.tool_uses.is_empty() {
387 tracing::error!(
388 session_id = payload.session_id,
389 "Received tool batch response with empty tool_uses"
390 );
391 return;
392 }
393
394 tracing::debug!(
395 session_id = payload.session_id,
396 tool_count = payload.tool_uses.len(),
397 "LLM requested tool batch execution"
398 );
399
400 let mut requests = Vec::with_capacity(payload.tool_uses.len());
402 for tool_info in &payload.tool_uses {
403 let input: HashMap<String, serde_json::Value> = tool_info
404 .input
405 .as_object()
406 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
407 .unwrap_or_default();
408
409 requests.push(ToolRequest {
410 tool_use_id: tool_info.id.clone(),
411 tool_name: tool_info.name.clone(),
412 input: input.clone(),
413 });
414
415 let (display_name, display_title) =
417 if let Some(tool) = self.tool_registry().get(&tool_info.name).await {
418 let config = tool.display_config();
419 (Some(config.display_name), Some((config.display_title)(&input)))
420 } else {
421 (None, None)
422 };
423
424 if let Some(ref func) = self.event_func {
426 func(ControllerEvent::ToolUse {
427 session_id: payload.session_id,
428 tool: tool_info.clone(),
429 display_name,
430 display_title,
431 turn_id: payload.turn_id.clone(),
432 });
433 }
434 }
435
436 self.tool_executor
438 .execute_batch(
439 payload.session_id,
440 payload.turn_id.clone(),
441 requests,
442 self.cancel_token.clone(),
443 )
444 .await;
445
446 None
447 }
448 LLMResponseType::Complete => Some(ControllerEvent::Complete {
449 session_id: payload.session_id,
450 stop_reason: payload.stop_reason,
451 turn_id: payload.turn_id,
452 }),
453 LLMResponseType::Error => Some(ControllerEvent::Error {
454 session_id: payload.session_id,
455 error: payload.error.unwrap_or_else(|| "Unknown error".to_string()),
456 turn_id: payload.turn_id,
457 }),
458 LLMResponseType::TokenUpdate => {
459 let context_limit = if let Some(session) =
461 self.session_mgr.get_session_by_id(payload.session_id).await
462 {
463 session.context_limit()
464 } else {
465 0
466 };
467 Some(ControllerEvent::TokenUpdate {
468 session_id: payload.session_id,
469 input_tokens: payload.input_tokens,
470 output_tokens: payload.output_tokens,
471 context_limit,
472 })
473 }
474 };
475
476 if let (Some(event), Some(func)) = (event, &self.event_func) {
478 func(event);
479 }
480 }
481
482 async fn handle_input(&self, payload: ControllerInputPayload) {
484 match payload.input_type {
485 InputType::Data => {
486 self.handle_data_input(payload).await;
487 }
488 InputType::Control => {
489 self.handle_control_input(payload).await;
490 }
491 }
492 }
493
494 async fn handle_data_input(&self, payload: ControllerInputPayload) {
496 let session_id = payload.session_id;
497
498 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
500 tracing::error!(session_id, "Session not found for data input");
501 self.emit_error(session_id, "Session not found".to_string(), payload.turn_id);
502 return;
503 };
504
505 let llm_payload = ToLLMPayload {
507 request_type: LLMRequestType::UserMessage,
508 content: payload.content,
509 tool_results: Vec::new(),
510 options: None,
511 turn_id: payload.turn_id,
512 compact_summaries: HashMap::new(),
513 };
514
515 let sent = session.send(llm_payload).await;
517 if !sent {
518 tracing::error!(session_id, "Failed to send message to session");
519 self.emit_error(
520 session_id,
521 "Failed to send message to session".to_string(),
522 None,
523 );
524 }
525 }
526
527 async fn handle_control_input(&self, payload: ControllerInputPayload) {
529 let session_id = payload.session_id;
530
531 let Some(cmd) = payload.control_cmd else {
532 tracing::warn!(session_id, "Control input without command");
533 return;
534 };
535
536 match cmd {
537 ControlCmd::Interrupt => {
538 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
540 session.interrupt().await;
541 tracing::info!(session_id, "Session interrupted");
542 } else {
543 tracing::warn!(session_id, "Cannot interrupt: session not found");
544 }
545 }
546 ControlCmd::Shutdown => {
547 tracing::info!("Shutdown command received");
549 self.shutdown().await;
550 }
551 ControlCmd::Clear => {
552 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
554 session.clear_conversation().await;
555 tracing::info!(session_id, "Session conversation cleared");
556 self.emit_command_complete(session_id, cmd, true, None);
557 } else {
558 tracing::warn!(session_id, "Cannot clear: session not found");
559 self.emit_command_complete(
560 session_id,
561 cmd,
562 false,
563 Some("Session not found".to_string()),
564 );
565 }
566 }
567 ControlCmd::Compact => {
568 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
570 let result = session.force_compact().await;
571
572 if let Some(error) = result.error {
573 tracing::warn!(session_id, error = %error, "Session compaction failed");
575 self.emit_command_complete(session_id, cmd, false, Some(error));
576 } else if !result.compacted {
577 tracing::info!(session_id, "Nothing to compact");
579 self.emit_command_complete(
580 session_id,
581 cmd,
582 true,
583 Some("Nothing to compact - not enough turns in conversation".to_string()),
584 );
585 } else {
586 let message = format!(
588 "Conversation compacted\n Turns summarized: {}\n Turns kept: {}\n Messages: {} -> {}{}",
589 result.turns_compacted,
590 result.turns_kept,
591 result.messages_before,
592 result.messages_after,
593 if result.summary_length > 0 {
594 format!("\n Summary length: {} chars", result.summary_length)
595 } else {
596 String::new()
597 }
598 );
599 tracing::info!(
600 session_id,
601 turns_compacted = result.turns_compacted,
602 messages_before = result.messages_before,
603 messages_after = result.messages_after,
604 "Session compaction completed"
605 );
606 self.emit_command_complete(session_id, cmd, true, Some(message));
607 }
608 } else {
609 tracing::warn!(session_id, "Cannot compact: session not found");
610 self.emit_command_complete(
611 session_id,
612 cmd,
613 false,
614 Some("Session not found".to_string()),
615 );
616 }
617 }
618 }
619 }
620
621 fn emit_error(&self, session_id: i64, error: String, turn_id: Option<TurnId>) {
623 if let Some(ref func) = self.event_func {
624 func(ControllerEvent::Error {
625 session_id,
626 error,
627 turn_id,
628 });
629 }
630 }
631
632 fn emit_command_complete(
634 &self,
635 session_id: i64,
636 command: ControlCmd,
637 success: bool,
638 message: Option<String>,
639 ) {
640 if let Some(ref func) = self.event_func {
641 func(ControllerEvent::CommandComplete {
642 session_id,
643 command,
644 success,
645 message,
646 });
647 }
648 }
649
650 async fn handle_tool_batch_result(&self, batch_result: ToolBatchResult) {
652 use crate::controller::types::ToolResultInfo;
653
654 let session_id = batch_result.session_id;
655
656 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
657 tracing::error!(session_id, "Session not found for tool result");
658 return;
659 };
660
661 let mut compact_summaries = HashMap::new();
663 let tool_results: Vec<ToolResultInfo> = batch_result
664 .results
665 .iter()
666 .map(|result| {
667 let (content, is_error) = if let Some(ref error) = result.error {
668 (error.clone(), true)
669 } else {
670 (result.content.clone(), false)
671 };
672
673 if let Some(ref summary) = result.compact_summary {
675 compact_summaries.insert(result.tool_use_id.clone(), summary.clone());
676 tracing::debug!(
677 tool_use_id = %result.tool_use_id,
678 summary_len = summary.len(),
679 "Extracted compact summary for tool result"
680 );
681 }
682
683 ToolResultInfo {
684 tool_use_id: result.tool_use_id.clone(),
685 content,
686 is_error,
687 }
688 })
689 .collect();
690
691 tracing::info!(
692 session_id,
693 tool_count = tool_results.len(),
694 compact_summary_count = compact_summaries.len(),
695 "Sending tool results to session with compact summaries"
696 );
697
698 let llm_payload = ToLLMPayload {
700 request_type: LLMRequestType::ToolResult,
701 content: String::new(),
702 tool_results,
703 options: None,
704 turn_id: batch_result.turn_id,
705 compact_summaries,
706 };
707
708 let sent = session.send(llm_payload).await;
710 if !sent {
711 tracing::error!(session_id, "Failed to send tool result to session");
712 } else {
713 tracing::debug!(
714 session_id,
715 batch_id = batch_result.batch_id,
716 "Sent tool results to session"
717 );
718 }
719 }
720
721 pub async fn shutdown(&self) {
724 if self
725 .shutdown
726 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
727 .is_err()
728 {
729 return; }
731
732 tracing::info!("Controller shutting down");
733
734 self.session_mgr.shutdown().await;
736
737 self.cancel_token.cancel();
739
740 tracing::info!("Controller shutdown complete");
741 }
742
743 pub fn is_shutdown(&self) -> bool {
745 self.shutdown.load(Ordering::SeqCst)
746 }
747
748 pub fn is_started(&self) -> bool {
750 self.started.load(Ordering::SeqCst)
751 }
752
753 pub async fn create_session(&self, config: LLMSessionConfig) -> Result<i64, LlmError> {
766 let session_id = self
767 .session_mgr
768 .create_session(config, self.from_llm_tx.clone())
769 .await?;
770
771 tracing::info!(session_id, "Session created via controller");
772 Ok(session_id)
773 }
774
775 pub async fn get_session(&self, session_id: i64) -> Option<Arc<LLMSession>> {
780 self.session_mgr.get_session_by_id(session_id).await
781 }
782
783 pub async fn session_count(&self) -> usize {
785 self.session_mgr.session_count().await
786 }
787
788 pub async fn send_input(&self, input: ControllerInputPayload) -> Result<(), ControllerError> {
798 if self.is_shutdown() {
799 return Err(ControllerError::Shutdown);
800 }
801
802 match tokio::time::timeout(SEND_INPUT_TIMEOUT, self.input_tx.send(input)).await {
803 Ok(Ok(())) => Ok(()),
804 Ok(Err(_)) => Err(ControllerError::ChannelClosed),
805 Err(_) => Err(ControllerError::SendTimeout(SEND_INPUT_TIMEOUT.as_secs())),
806 }
807 }
808
809 pub async fn get_session_token_usage(
813 &self,
814 session_id: i64,
815 ) -> Option<crate::controller::usage::TokenMeter> {
816 self.token_usage.get_session_usage(session_id).await
817 }
818
819 pub async fn get_model_token_usage(&self, model: &str) -> Option<crate::controller::usage::TokenMeter> {
821 self.token_usage.get_model_usage(model).await
822 }
823
824 pub async fn get_total_token_usage(&self) -> crate::controller::usage::TokenMeter {
826 self.token_usage.get_total_usage().await
827 }
828
829 pub fn token_usage(&self) -> &TokenUsageTracker {
831 &self.token_usage
832 }
833
834 pub fn tool_registry(&self) -> &Arc<ToolRegistry> {
838 &self.tool_registry
839 }
840
841 pub fn user_interaction_registry(&self) -> &Arc<UserInteractionRegistry> {
845 &self.user_interaction_registry
846 }
847
848 pub async fn respond_to_interaction(
856 &self,
857 tool_use_id: &str,
858 response: AskUserQuestionsResponse,
859 ) -> Result<(), UserInteractionError> {
860 self.user_interaction_registry
861 .respond(tool_use_id, response)
862 .await
863 }
864
865 pub async fn pending_interactions_for_session(
870 &self,
871 session_id: i64,
872 ) -> Vec<PendingQuestionInfo> {
873 self.user_interaction_registry
874 .pending_for_session(session_id)
875 .await
876 }
877
878 pub async fn has_pending_interactions(&self, session_id: i64) -> bool {
880 self.user_interaction_registry.has_pending(session_id).await
881 }
882
883 pub fn permission_registry(&self) -> &Arc<PermissionRegistry> {
887 &self.permission_registry
888 }
889
890 pub async fn respond_to_permission(
898 &self,
899 tool_use_id: &str,
900 response: PermissionResponse,
901 ) -> Result<(), PermissionError> {
902 self.permission_registry
903 .respond(tool_use_id, response)
904 .await
905 }
906
907 pub async fn pending_permissions_for_session(
912 &self,
913 session_id: i64,
914 ) -> Vec<PendingPermissionInfo> {
915 self.permission_registry
916 .pending_for_session(session_id)
917 .await
918 }
919
920 pub async fn has_pending_permissions(&self, session_id: i64) -> bool {
922 self.permission_registry.has_pending(session_id).await
923 }
924
925 pub async fn cancel_permission(&self, tool_use_id: &str) -> Result<(), PermissionError> {
930 self.permission_registry.cancel(tool_use_id).await
931 }
932}