1use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::Duration;
6
7use tokio::sync::{Mutex, mpsc};
8use tokio_util::sync::CancellationToken;
9
10use std::sync::Arc;
11
12use crate::agent::{UiMessage, convert_controller_event_to_ui_message};
13use crate::client::error::LlmError;
14use crate::controller::error::ControllerError;
15use crate::controller::session::{LLMSession, LLMSessionConfig, LLMSessionManager};
16use crate::controller::tools::{
17 ToolBatchResult, ToolExecutor, ToolRegistry, ToolRequest, ToolResult,
18};
19use crate::controller::types::{
20 ControlCmd, ControllerEvent, ControllerInputPayload, FromLLMPayload, InputType, LLMRequestType,
21 LLMResponseType, ToLLMPayload, TurnId,
22};
23use crate::controller::usage::TokenUsageTracker;
24use crate::permissions::PermissionRegistry;
25
/// Bound applied to every channel the controller creates when the caller does not pass an
/// explicit size to [`LLMController::new`].
pub const DEFAULT_CHANNEL_SIZE: usize = 500;

/// How long [`LLMController::send_input`] waits for room in the input channel before giving
/// up with `ControllerError::SendTimeout`.
const SEND_INPUT_TIMEOUT: Duration = Duration::from_millis(5_000);
33
/// Central dispatcher between LLM sessions, tool execution and an optional UI.
///
/// Owns the session manager, the tool registry/executor and the internal channels that
/// [`LLMController::start`] multiplexes in its event loop.
pub struct LLMController {
    /// Creates, looks up and shuts down sessions.
    session_mgr: LLMSessionManager,

    /// Per-session / per-model token accounting, updated on `TokenUpdate` responses.
    token_usage: TokenUsageTracker,

    /// Receiving end for responses produced by sessions; drained by the event loop.
    from_llm_rx: Mutex<mpsc::Receiver<FromLLMPayload>>,

    /// Sending end cloned into every session created via `create_session`.
    from_llm_tx: mpsc::Sender<FromLLMPayload>,

    /// Receiving end for data/control input; drained by the event loop.
    input_rx: Mutex<mpsc::Receiver<ControllerInputPayload>>,

    /// Sending end used by `send_input`.
    input_tx: mpsc::Sender<ControllerInputPayload>,

    /// Set by the first `start` call so the event loop runs at most once.
    started: AtomicBool,

    /// Set by the first `shutdown` call; makes shutdown idempotent and rejects new input.
    shutdown: AtomicBool,

    /// Cancelled on shutdown; stops the event loop and is handed to tool executions.
    cancel_token: CancellationToken,

    /// Optional UI sink; when `None`, controller events are not forwarded anywhere.
    ui_tx: Option<mpsc::Sender<UiMessage>>,

    /// Registry of available tools, shared with `tool_executor`.
    tool_registry: Arc<ToolRegistry>,

    /// Runs tool requests and reports results through the two result channels below.
    tool_executor: ToolExecutor,

    /// Individual tool results; the event loop forwards them to the UI only.
    tool_result_rx: Mutex<mpsc::Receiver<ToolResult>>,

    /// Completed tool batches; the event loop sends them back to the owning session.
    batch_result_rx: Mutex<mpsc::Receiver<ToolBatchResult>>,

    /// Channel bound used for the controller's channels and passed to new sessions.
    channel_size: usize,
}
83
84impl LLMController {
85 pub fn new(
92 permission_registry: Arc<PermissionRegistry>,
93 ui_tx: Option<mpsc::Sender<UiMessage>>,
94 channel_size: Option<usize>,
95 ) -> Self {
96 let size = channel_size.unwrap_or(DEFAULT_CHANNEL_SIZE);
97
98 let (from_llm_tx, from_llm_rx) = mpsc::channel(size);
99 let (input_tx, input_rx) = mpsc::channel(size);
100
101 let (tool_result_tx, tool_result_rx) = mpsc::channel(size);
103 let (batch_result_tx, batch_result_rx) = mpsc::channel(size);
104
105 let tool_registry = Arc::new(ToolRegistry::new());
106 let tool_executor = ToolExecutor::new(
107 tool_registry.clone(),
108 permission_registry.clone(),
109 tool_result_tx,
110 batch_result_tx,
111 );
112
113 Self {
114 session_mgr: LLMSessionManager::new(),
115 token_usage: TokenUsageTracker::new(),
116 from_llm_rx: Mutex::new(from_llm_rx),
117 from_llm_tx,
118 input_rx: Mutex::new(input_rx),
119 input_tx,
120 started: AtomicBool::new(false),
121 shutdown: AtomicBool::new(false),
122 cancel_token: CancellationToken::new(),
123 ui_tx,
124 tool_registry,
125 tool_executor,
126 tool_result_rx: Mutex::new(tool_result_rx),
127 batch_result_rx: Mutex::new(batch_result_rx),
128 channel_size: size,
129 }
130 }
131
132 fn ui_has_capacity(&self) -> bool {
135 match &self.ui_tx {
136 Some(tx) => tx.capacity() > 0,
137 None => true, }
139 }
140
141 async fn send_to_ui(&self, event: ControllerEvent) {
144 if let Some(ref tx) = self.ui_tx {
145 let msg = convert_controller_event_to_ui_message(event);
146 if let Err(e) = tx.send(msg).await {
147 tracing::warn!("Failed to send event to UI: {}", e);
148 }
149 }
150 }
151
152 pub async fn start(&self) {
156 if self
158 .started
159 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
160 .is_err()
161 {
162 tracing::warn!("Controller already started");
163 return;
164 }
165
166 tracing::info!("Controller starting");
167
168 loop {
195 let mut from_llm_guard = self.from_llm_rx.lock().await;
196 let mut input_guard = self.input_rx.lock().await;
197 let mut batch_result_guard = self.batch_result_rx.lock().await;
198 let mut tool_result_guard = self.tool_result_rx.lock().await;
199
200 let ui_ready = self.ui_has_capacity();
202
203 tokio::select! {
204 _ = self.cancel_token.cancelled() => {
205 tracing::info!("Controller cancelled");
206 break;
207 }
208 msg = from_llm_guard.recv(), if ui_ready => {
210 drop(from_llm_guard);
211 drop(input_guard);
212 drop(batch_result_guard);
213 drop(tool_result_guard);
214 if let Some(payload) = msg {
215 self.handle_llm_response(payload).await;
216 } else {
217 tracing::info!("FromLLM channel closed");
218 break;
219 }
220 }
221 msg = input_guard.recv() => {
223 drop(from_llm_guard);
224 drop(input_guard);
225 drop(batch_result_guard);
226 drop(tool_result_guard);
227 if let Some(payload) = msg {
228 self.handle_input(payload).await;
229 } else {
230 tracing::info!("Input channel closed");
231 break;
232 }
233 }
234 batch_result = batch_result_guard.recv() => {
236 drop(from_llm_guard);
237 drop(input_guard);
238 drop(batch_result_guard);
239 drop(tool_result_guard);
240 if let Some(result) = batch_result {
241 self.handle_tool_batch_result(result).await;
242 }
243 }
244 tool_result = tool_result_guard.recv(), if ui_ready => {
246 drop(from_llm_guard);
247 drop(input_guard);
248 drop(batch_result_guard);
249 drop(tool_result_guard);
250 if let Some(result) = tool_result {
251 self.send_to_ui(ControllerEvent::ToolResult {
253 session_id: result.session_id,
254 tool_use_id: result.tool_use_id,
255 tool_name: result.tool_name,
256 display_name: result.display_name,
257 status: result.status,
258 content: result.content,
259 error: result.error,
260 turn_id: result.turn_id,
261 }).await;
262 }
263 }
264 }
265 }
266
267 tracing::info!("Controller stopped");
268 }
269
270 async fn handle_llm_response(&self, payload: FromLLMPayload) {
272 if payload.response_type == LLMResponseType::TokenUpdate
274 && let Some(session) = self.session_mgr.get_session_by_id(payload.session_id).await
275 {
276 self.token_usage
277 .increment(
278 payload.session_id,
279 session.model(),
280 payload.input_tokens,
281 payload.output_tokens,
282 )
283 .await;
284 }
285
286 let event = match payload.response_type {
287 LLMResponseType::StreamStart => Some(ControllerEvent::StreamStart {
288 session_id: payload.session_id,
289 message_id: payload.message_id,
290 model: payload.model,
291 turn_id: payload.turn_id,
292 }),
293 LLMResponseType::TextChunk => Some(ControllerEvent::TextChunk {
294 session_id: payload.session_id,
295 text: payload.text,
296 turn_id: payload.turn_id,
297 }),
298 LLMResponseType::ToolUseStart => {
299 if let Some(tool) = payload.tool_use {
300 Some(ControllerEvent::ToolUseStart {
301 session_id: payload.session_id,
302 tool_id: tool.id,
303 tool_name: tool.name,
304 turn_id: payload.turn_id,
305 })
306 } else {
307 None
308 }
309 }
310 LLMResponseType::ToolInputDelta => {
311 None
314 }
315 LLMResponseType::ToolUse => {
316 if let Some(ref tool) = payload.tool_use {
317 let input: HashMap<String, serde_json::Value> = tool
319 .input
320 .as_object()
321 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
322 .unwrap_or_default();
323
324 let (display_name, display_title) =
326 if let Some(t) = self.tool_registry().get(&tool.name).await {
327 let config = t.display_config();
328 (
329 Some(config.display_name),
330 Some((config.display_title)(&input)),
331 )
332 } else {
333 (None, None)
334 };
335
336 let request = ToolRequest {
337 tool_use_id: tool.id.clone(),
338 tool_name: tool.name.clone(),
339 input,
340 };
341
342 self.tool_executor
343 .execute(
344 payload.session_id,
345 payload.turn_id.clone(),
346 request,
347 self.cancel_token.clone(),
348 )
349 .await;
350
351 Some(ControllerEvent::ToolUse {
352 session_id: payload.session_id,
353 tool: payload.tool_use.unwrap(),
354 display_name,
355 display_title,
356 turn_id: payload.turn_id,
357 })
358 } else {
359 None
360 }
361 }
362 LLMResponseType::ToolBatch => {
363 if payload.tool_uses.is_empty() {
365 tracing::error!(
366 session_id = payload.session_id,
367 "Received tool batch response with empty tool_uses"
368 );
369 return;
370 }
371
372 tracing::debug!(
373 session_id = payload.session_id,
374 tool_count = payload.tool_uses.len(),
375 "LLM requested tool batch execution"
376 );
377
378 let mut requests = Vec::with_capacity(payload.tool_uses.len());
380 for tool_info in &payload.tool_uses {
381 let input: HashMap<String, serde_json::Value> = tool_info
382 .input
383 .as_object()
384 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
385 .unwrap_or_default();
386
387 requests.push(ToolRequest {
388 tool_use_id: tool_info.id.clone(),
389 tool_name: tool_info.name.clone(),
390 input: input.clone(),
391 });
392
393 let (display_name, display_title) =
395 if let Some(tool) = self.tool_registry().get(&tool_info.name).await {
396 let config = tool.display_config();
397 (
398 Some(config.display_name),
399 Some((config.display_title)(&input)),
400 )
401 } else {
402 (None, None)
403 };
404
405 self.send_to_ui(ControllerEvent::ToolUse {
407 session_id: payload.session_id,
408 tool: tool_info.clone(),
409 display_name,
410 display_title,
411 turn_id: payload.turn_id.clone(),
412 })
413 .await;
414 }
415
416 self.tool_executor
418 .execute_batch(
419 payload.session_id,
420 payload.turn_id.clone(),
421 requests,
422 self.cancel_token.clone(),
423 )
424 .await;
425
426 None
427 }
428 LLMResponseType::Complete => Some(ControllerEvent::Complete {
429 session_id: payload.session_id,
430 stop_reason: payload.stop_reason,
431 turn_id: payload.turn_id,
432 }),
433 LLMResponseType::Error => Some(ControllerEvent::Error {
434 session_id: payload.session_id,
435 error: payload.error.unwrap_or_else(|| "Unknown error".to_string()),
436 turn_id: payload.turn_id,
437 }),
438 LLMResponseType::TokenUpdate => {
439 let context_limit = if let Some(session) =
441 self.session_mgr.get_session_by_id(payload.session_id).await
442 {
443 session.context_limit()
444 } else {
445 0
446 };
447 Some(ControllerEvent::TokenUpdate {
448 session_id: payload.session_id,
449 input_tokens: payload.input_tokens,
450 output_tokens: payload.output_tokens,
451 context_limit,
452 })
453 }
454 };
455
456 if let Some(event) = event {
458 self.send_to_ui(event).await;
459 }
460 }
461
462 async fn handle_input(&self, payload: ControllerInputPayload) {
464 match payload.input_type {
465 InputType::Data => {
466 self.handle_data_input(payload).await;
467 }
468 InputType::Control => {
469 self.handle_control_input(payload).await;
470 }
471 }
472 }
473
474 async fn handle_data_input(&self, payload: ControllerInputPayload) {
476 let session_id = payload.session_id;
477
478 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
480 tracing::error!(session_id, "Session not found for data input");
481 self.emit_error(session_id, "Session not found".to_string(), payload.turn_id)
482 .await;
483 return;
484 };
485
486 let llm_payload = ToLLMPayload {
488 request_type: LLMRequestType::UserMessage,
489 content: payload.content,
490 tool_results: Vec::new(),
491 options: None,
492 turn_id: payload.turn_id,
493 compact_summaries: HashMap::new(),
494 };
495
496 let sent = session.send(llm_payload).await;
498 if !sent {
499 tracing::error!(session_id, "Failed to send message to session");
500 self.emit_error(
501 session_id,
502 "Failed to send message to session".to_string(),
503 None,
504 )
505 .await;
506 }
507 }
508
509 async fn handle_control_input(&self, payload: ControllerInputPayload) {
511 let session_id = payload.session_id;
512
513 let Some(cmd) = payload.control_cmd else {
514 tracing::warn!(session_id, "Control input without command");
515 return;
516 };
517
518 match cmd {
519 ControlCmd::Interrupt => {
520 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
522 session.interrupt().await;
523 tracing::info!(session_id, "Session interrupted");
524 } else {
525 tracing::warn!(session_id, "Cannot interrupt: session not found");
526 }
527 }
528 ControlCmd::Shutdown => {
529 tracing::info!("Shutdown command received");
531 self.shutdown().await;
532 }
533 ControlCmd::Clear => {
534 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
536 session.clear_conversation().await;
537 tracing::info!(session_id, "Session conversation cleared");
538 self.emit_command_complete(session_id, cmd, true, None)
539 .await;
540 } else {
541 tracing::warn!(session_id, "Cannot clear: session not found");
542 self.emit_command_complete(
543 session_id,
544 cmd,
545 false,
546 Some("Session not found".to_string()),
547 )
548 .await;
549 }
550 }
551 ControlCmd::Compact => {
552 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
554 let result = session.force_compact().await;
555
556 if let Some(error) = result.error {
557 tracing::warn!(session_id, error = %error, "Session compaction failed");
559 self.emit_command_complete(session_id, cmd, false, Some(error))
560 .await;
561 } else if !result.compacted {
562 tracing::info!(session_id, "Nothing to compact");
564 self.emit_command_complete(
565 session_id,
566 cmd,
567 true,
568 Some(
569 "Nothing to compact - not enough turns in conversation".to_string(),
570 ),
571 )
572 .await;
573 } else {
574 let message = format!(
576 "Conversation compacted\n Turns summarized: {}\n Turns kept: {}\n Messages: {} -> {}{}",
577 result.turns_compacted,
578 result.turns_kept,
579 result.messages_before,
580 result.messages_after,
581 if result.summary_length > 0 {
582 format!("\n Summary length: {} chars", result.summary_length)
583 } else {
584 String::new()
585 }
586 );
587 tracing::info!(
588 session_id,
589 turns_compacted = result.turns_compacted,
590 messages_before = result.messages_before,
591 messages_after = result.messages_after,
592 "Session compaction completed"
593 );
594 self.emit_command_complete(session_id, cmd, true, Some(message))
595 .await;
596 }
597 } else {
598 tracing::warn!(session_id, "Cannot compact: session not found");
599 self.emit_command_complete(
600 session_id,
601 cmd,
602 false,
603 Some("Session not found".to_string()),
604 )
605 .await;
606 }
607 }
608 }
609 }
610
611 async fn emit_error(&self, session_id: i64, error: String, turn_id: Option<TurnId>) {
613 self.send_to_ui(ControllerEvent::Error {
614 session_id,
615 error,
616 turn_id,
617 })
618 .await;
619 }
620
621 async fn emit_command_complete(
623 &self,
624 session_id: i64,
625 command: ControlCmd,
626 success: bool,
627 message: Option<String>,
628 ) {
629 self.send_to_ui(ControllerEvent::CommandComplete {
630 session_id,
631 command,
632 success,
633 message,
634 })
635 .await;
636 }
637
638 async fn handle_tool_batch_result(&self, batch_result: ToolBatchResult) {
640 use crate::controller::types::ToolResultInfo;
641
642 let session_id = batch_result.session_id;
643
644 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
645 tracing::error!(session_id, "Session not found for tool result");
646 return;
647 };
648
649 let mut compact_summaries = HashMap::new();
651 let tool_results: Vec<ToolResultInfo> = batch_result
652 .results
653 .iter()
654 .map(|result| {
655 let (content, is_error) = if let Some(ref error) = result.error {
656 (error.clone(), true)
657 } else {
658 (result.content.clone(), false)
659 };
660
661 if let Some(ref summary) = result.compact_summary {
663 compact_summaries.insert(result.tool_use_id.clone(), summary.clone());
664 tracing::debug!(
665 tool_use_id = %result.tool_use_id,
666 summary_len = summary.len(),
667 "Extracted compact summary for tool result"
668 );
669 }
670
671 ToolResultInfo {
672 tool_use_id: result.tool_use_id.clone(),
673 content,
674 is_error,
675 }
676 })
677 .collect();
678
679 tracing::info!(
680 session_id,
681 tool_count = tool_results.len(),
682 compact_summary_count = compact_summaries.len(),
683 "Sending tool results to session with compact summaries"
684 );
685
686 let llm_payload = ToLLMPayload {
688 request_type: LLMRequestType::ToolResult,
689 content: String::new(),
690 tool_results,
691 options: None,
692 turn_id: batch_result.turn_id,
693 compact_summaries,
694 };
695
696 let sent = session.send(llm_payload).await;
698 if !sent {
699 tracing::error!(session_id, "Failed to send tool result to session");
700 } else {
701 tracing::debug!(
702 session_id,
703 batch_id = batch_result.batch_id,
704 "Sent tool results to session"
705 );
706 }
707 }
708
709 pub async fn shutdown(&self) {
712 if self
713 .shutdown
714 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
715 .is_err()
716 {
717 return; }
719
720 tracing::info!("Controller shutting down");
721
722 self.session_mgr.shutdown().await;
724
725 self.cancel_token.cancel();
727
728 tracing::info!("Controller shutdown complete");
729 }
730
731 pub fn is_shutdown(&self) -> bool {
733 self.shutdown.load(Ordering::SeqCst)
734 }
735
736 pub fn is_started(&self) -> bool {
738 self.started.load(Ordering::SeqCst)
739 }
740
741 pub async fn create_session(&self, config: LLMSessionConfig) -> Result<i64, LlmError> {
754 let session_id = self
755 .session_mgr
756 .create_session(config, self.from_llm_tx.clone(), self.channel_size)
757 .await?;
758
759 tracing::info!(session_id, "Session created via controller");
760 Ok(session_id)
761 }
762
763 pub async fn get_session(&self, session_id: i64) -> Option<Arc<LLMSession>> {
768 self.session_mgr.get_session_by_id(session_id).await
769 }
770
771 pub async fn session_count(&self) -> usize {
773 self.session_mgr.session_count().await
774 }
775
776 pub async fn remove_session(&self, session_id: i64) -> bool {
787 self.session_mgr.remove_session(session_id).await
788 }
789
790 pub async fn send_input(&self, input: ControllerInputPayload) -> Result<(), ControllerError> {
800 if self.is_shutdown() {
801 return Err(ControllerError::Shutdown);
802 }
803
804 match tokio::time::timeout(SEND_INPUT_TIMEOUT, self.input_tx.send(input)).await {
805 Ok(Ok(())) => Ok(()),
806 Ok(Err(_)) => Err(ControllerError::ChannelClosed),
807 Err(_) => Err(ControllerError::SendTimeout(SEND_INPUT_TIMEOUT.as_secs())),
808 }
809 }
810
811 pub async fn get_session_token_usage(
815 &self,
816 session_id: i64,
817 ) -> Option<crate::controller::usage::TokenMeter> {
818 self.token_usage.get_session_usage(session_id).await
819 }
820
821 pub async fn get_model_token_usage(
823 &self,
824 model: &str,
825 ) -> Option<crate::controller::usage::TokenMeter> {
826 self.token_usage.get_model_usage(model).await
827 }
828
829 pub async fn get_total_token_usage(&self) -> crate::controller::usage::TokenMeter {
831 self.token_usage.get_total_usage().await
832 }
833
834 pub fn token_usage(&self) -> &TokenUsageTracker {
836 &self.token_usage
837 }
838
839 pub fn tool_registry(&self) -> &Arc<ToolRegistry> {
843 &self.tool_registry
844 }
845}