1use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::Duration;
6
7use tokio::sync::{mpsc, Mutex};
8use tokio_util::sync::CancellationToken;
9
10use std::sync::Arc;
11
12use crate::client::error::LlmError;
13use crate::controller::session::{LLMSession, LLMSessionConfig, LLMSessionManager};
14use crate::controller::tools::{
15 ToolBatchResult, ToolExecutor, ToolRegistry, ToolRequest, ToolResult,
16};
17use crate::controller::error::ControllerError;
18use crate::controller::types::{
19 ControlCmd, ControllerEvent, ControllerInputPayload, FromLLMPayload, InputType,
20 LLMRequestType, LLMResponseType, ToLLMPayload, TurnId,
21};
22use crate::controller::usage::TokenUsageTracker;
23use crate::permissions::PermissionRegistry;
24use crate::agent::{convert_controller_event_to_ui_message, UiMessage};
25
/// Capacity used for every internal channel when `LLMController::new` is
/// given `None` for its `channel_size` argument.
pub const DEFAULT_CHANNEL_SIZE: usize = 500;
30
/// Longest time `send_input` waits for room in the controller's input channel
/// before giving up with a timeout error.
const SEND_INPUT_TIMEOUT: Duration = Duration::from_millis(5_000);
33
34pub struct LLMController {
36 session_mgr: LLMSessionManager,
38
39 token_usage: TokenUsageTracker,
41
42 from_llm_rx: Mutex<mpsc::Receiver<FromLLMPayload>>,
44
45 from_llm_tx: mpsc::Sender<FromLLMPayload>,
47
48 input_rx: Mutex<mpsc::Receiver<ControllerInputPayload>>,
50
51 input_tx: mpsc::Sender<ControllerInputPayload>,
53
54 started: AtomicBool,
56
57 shutdown: AtomicBool,
59
60 cancel_token: CancellationToken,
62
63 ui_tx: Option<mpsc::Sender<UiMessage>>,
67
68 tool_registry: Arc<ToolRegistry>,
70
71 tool_executor: ToolExecutor,
73
74 tool_result_rx: Mutex<mpsc::Receiver<ToolResult>>,
76
77 batch_result_rx: Mutex<mpsc::Receiver<ToolBatchResult>>,
79
80 channel_size: usize,
82}
83
84impl LLMController {
85 pub fn new(
92 permission_registry: Arc<PermissionRegistry>,
93 ui_tx: Option<mpsc::Sender<UiMessage>>,
94 channel_size: Option<usize>,
95 ) -> Self {
96 let size = channel_size.unwrap_or(DEFAULT_CHANNEL_SIZE);
97
98 let (from_llm_tx, from_llm_rx) = mpsc::channel(size);
99 let (input_tx, input_rx) = mpsc::channel(size);
100
101 let (tool_result_tx, tool_result_rx) = mpsc::channel(size);
103 let (batch_result_tx, batch_result_rx) = mpsc::channel(size);
104
105 let tool_registry = Arc::new(ToolRegistry::new());
106 let tool_executor = ToolExecutor::new(
107 tool_registry.clone(),
108 permission_registry.clone(),
109 tool_result_tx,
110 batch_result_tx,
111 );
112
113 Self {
114 session_mgr: LLMSessionManager::new(),
115 token_usage: TokenUsageTracker::new(),
116 from_llm_rx: Mutex::new(from_llm_rx),
117 from_llm_tx,
118 input_rx: Mutex::new(input_rx),
119 input_tx,
120 started: AtomicBool::new(false),
121 shutdown: AtomicBool::new(false),
122 cancel_token: CancellationToken::new(),
123 ui_tx,
124 tool_registry,
125 tool_executor,
126 tool_result_rx: Mutex::new(tool_result_rx),
127 batch_result_rx: Mutex::new(batch_result_rx),
128 channel_size: size,
129 }
130 }
131
132 fn ui_has_capacity(&self) -> bool {
135 match &self.ui_tx {
136 Some(tx) => tx.capacity() > 0,
137 None => true, }
139 }
140
141 async fn send_to_ui(&self, event: ControllerEvent) {
144 if let Some(ref tx) = self.ui_tx {
145 let msg = convert_controller_event_to_ui_message(event);
146 if let Err(e) = tx.send(msg).await {
147 tracing::warn!("Failed to send event to UI: {}", e);
148 }
149 }
150 }
151
152 pub async fn start(&self) {
156 if self
158 .started
159 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
160 .is_err()
161 {
162 tracing::warn!("Controller already started");
163 return;
164 }
165
166 tracing::info!("Controller starting");
167
168 loop {
195 let mut from_llm_guard = self.from_llm_rx.lock().await;
196 let mut input_guard = self.input_rx.lock().await;
197 let mut batch_result_guard = self.batch_result_rx.lock().await;
198 let mut tool_result_guard = self.tool_result_rx.lock().await;
199
200 let ui_ready = self.ui_has_capacity();
202
203 tokio::select! {
204 _ = self.cancel_token.cancelled() => {
205 tracing::info!("Controller cancelled");
206 break;
207 }
208 msg = from_llm_guard.recv(), if ui_ready => {
210 drop(from_llm_guard);
211 drop(input_guard);
212 drop(batch_result_guard);
213 drop(tool_result_guard);
214 if let Some(payload) = msg {
215 self.handle_llm_response(payload).await;
216 } else {
217 tracing::info!("FromLLM channel closed");
218 break;
219 }
220 }
221 msg = input_guard.recv() => {
223 drop(from_llm_guard);
224 drop(input_guard);
225 drop(batch_result_guard);
226 drop(tool_result_guard);
227 if let Some(payload) = msg {
228 self.handle_input(payload).await;
229 } else {
230 tracing::info!("Input channel closed");
231 break;
232 }
233 }
234 batch_result = batch_result_guard.recv() => {
236 drop(from_llm_guard);
237 drop(input_guard);
238 drop(batch_result_guard);
239 drop(tool_result_guard);
240 if let Some(result) = batch_result {
241 self.handle_tool_batch_result(result).await;
242 }
243 }
244 tool_result = tool_result_guard.recv(), if ui_ready => {
246 drop(from_llm_guard);
247 drop(input_guard);
248 drop(batch_result_guard);
249 drop(tool_result_guard);
250 if let Some(result) = tool_result {
251 self.send_to_ui(ControllerEvent::ToolResult {
253 session_id: result.session_id,
254 tool_use_id: result.tool_use_id,
255 tool_name: result.tool_name,
256 display_name: result.display_name,
257 status: result.status,
258 content: result.content,
259 error: result.error,
260 turn_id: result.turn_id,
261 }).await;
262 }
263 }
264 }
265 }
266
267 tracing::info!("Controller stopped");
268 }
269
270 async fn handle_llm_response(&self, payload: FromLLMPayload) {
272 if payload.response_type == LLMResponseType::TokenUpdate {
274 if let Some(session) = self.session_mgr.get_session_by_id(payload.session_id).await {
275 self.token_usage
276 .increment(
277 payload.session_id,
278 session.model(),
279 payload.input_tokens,
280 payload.output_tokens,
281 )
282 .await;
283 }
284 }
285
286 let event = match payload.response_type {
287 LLMResponseType::StreamStart => Some(ControllerEvent::StreamStart {
288 session_id: payload.session_id,
289 message_id: payload.message_id,
290 model: payload.model,
291 turn_id: payload.turn_id,
292 }),
293 LLMResponseType::TextChunk => Some(ControllerEvent::TextChunk {
294 session_id: payload.session_id,
295 text: payload.text,
296 turn_id: payload.turn_id,
297 }),
298 LLMResponseType::ToolUseStart => {
299 if let Some(tool) = payload.tool_use {
300 Some(ControllerEvent::ToolUseStart {
301 session_id: payload.session_id,
302 tool_id: tool.id,
303 tool_name: tool.name,
304 turn_id: payload.turn_id,
305 })
306 } else {
307 None
308 }
309 }
310 LLMResponseType::ToolInputDelta => {
311 None
314 }
315 LLMResponseType::ToolUse => {
316 if let Some(ref tool) = payload.tool_use {
317 let input: HashMap<String, serde_json::Value> = tool
319 .input
320 .as_object()
321 .map(|obj| {
322 obj.iter()
323 .map(|(k, v)| (k.clone(), v.clone()))
324 .collect()
325 })
326 .unwrap_or_default();
327
328 let (display_name, display_title) =
330 if let Some(t) = self.tool_registry().get(&tool.name).await {
331 let config = t.display_config();
332 (Some(config.display_name), Some((config.display_title)(&input)))
333 } else {
334 (None, None)
335 };
336
337 let request = ToolRequest {
338 tool_use_id: tool.id.clone(),
339 tool_name: tool.name.clone(),
340 input,
341 };
342
343 self.tool_executor
344 .execute(
345 payload.session_id,
346 payload.turn_id.clone(),
347 request,
348 self.cancel_token.clone(),
349 )
350 .await;
351
352 Some(ControllerEvent::ToolUse {
353 session_id: payload.session_id,
354 tool: payload.tool_use.unwrap(),
355 display_name,
356 display_title,
357 turn_id: payload.turn_id,
358 })
359 } else {
360 None
361 }
362 }
363 LLMResponseType::ToolBatch => {
364 if payload.tool_uses.is_empty() {
366 tracing::error!(
367 session_id = payload.session_id,
368 "Received tool batch response with empty tool_uses"
369 );
370 return;
371 }
372
373 tracing::debug!(
374 session_id = payload.session_id,
375 tool_count = payload.tool_uses.len(),
376 "LLM requested tool batch execution"
377 );
378
379 let mut requests = Vec::with_capacity(payload.tool_uses.len());
381 for tool_info in &payload.tool_uses {
382 let input: HashMap<String, serde_json::Value> = tool_info
383 .input
384 .as_object()
385 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
386 .unwrap_or_default();
387
388 requests.push(ToolRequest {
389 tool_use_id: tool_info.id.clone(),
390 tool_name: tool_info.name.clone(),
391 input: input.clone(),
392 });
393
394 let (display_name, display_title) =
396 if let Some(tool) = self.tool_registry().get(&tool_info.name).await {
397 let config = tool.display_config();
398 (Some(config.display_name), Some((config.display_title)(&input)))
399 } else {
400 (None, None)
401 };
402
403 self.send_to_ui(ControllerEvent::ToolUse {
405 session_id: payload.session_id,
406 tool: tool_info.clone(),
407 display_name,
408 display_title,
409 turn_id: payload.turn_id.clone(),
410 }).await;
411 }
412
413 self.tool_executor
415 .execute_batch(
416 payload.session_id,
417 payload.turn_id.clone(),
418 requests,
419 self.cancel_token.clone(),
420 )
421 .await;
422
423 None
424 }
425 LLMResponseType::Complete => Some(ControllerEvent::Complete {
426 session_id: payload.session_id,
427 stop_reason: payload.stop_reason,
428 turn_id: payload.turn_id,
429 }),
430 LLMResponseType::Error => Some(ControllerEvent::Error {
431 session_id: payload.session_id,
432 error: payload.error.unwrap_or_else(|| "Unknown error".to_string()),
433 turn_id: payload.turn_id,
434 }),
435 LLMResponseType::TokenUpdate => {
436 let context_limit = if let Some(session) =
438 self.session_mgr.get_session_by_id(payload.session_id).await
439 {
440 session.context_limit()
441 } else {
442 0
443 };
444 Some(ControllerEvent::TokenUpdate {
445 session_id: payload.session_id,
446 input_tokens: payload.input_tokens,
447 output_tokens: payload.output_tokens,
448 context_limit,
449 })
450 }
451 };
452
453 if let Some(event) = event {
455 self.send_to_ui(event).await;
456 }
457 }
458
459 async fn handle_input(&self, payload: ControllerInputPayload) {
461 match payload.input_type {
462 InputType::Data => {
463 self.handle_data_input(payload).await;
464 }
465 InputType::Control => {
466 self.handle_control_input(payload).await;
467 }
468 }
469 }
470
471 async fn handle_data_input(&self, payload: ControllerInputPayload) {
473 let session_id = payload.session_id;
474
475 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
477 tracing::error!(session_id, "Session not found for data input");
478 self.emit_error(session_id, "Session not found".to_string(), payload.turn_id).await;
479 return;
480 };
481
482 let llm_payload = ToLLMPayload {
484 request_type: LLMRequestType::UserMessage,
485 content: payload.content,
486 tool_results: Vec::new(),
487 options: None,
488 turn_id: payload.turn_id,
489 compact_summaries: HashMap::new(),
490 };
491
492 let sent = session.send(llm_payload).await;
494 if !sent {
495 tracing::error!(session_id, "Failed to send message to session");
496 self.emit_error(
497 session_id,
498 "Failed to send message to session".to_string(),
499 None,
500 ).await;
501 }
502 }
503
504 async fn handle_control_input(&self, payload: ControllerInputPayload) {
506 let session_id = payload.session_id;
507
508 let Some(cmd) = payload.control_cmd else {
509 tracing::warn!(session_id, "Control input without command");
510 return;
511 };
512
513 match cmd {
514 ControlCmd::Interrupt => {
515 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
517 session.interrupt().await;
518 tracing::info!(session_id, "Session interrupted");
519 } else {
520 tracing::warn!(session_id, "Cannot interrupt: session not found");
521 }
522 }
523 ControlCmd::Shutdown => {
524 tracing::info!("Shutdown command received");
526 self.shutdown().await;
527 }
528 ControlCmd::Clear => {
529 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
531 session.clear_conversation().await;
532 tracing::info!(session_id, "Session conversation cleared");
533 self.emit_command_complete(session_id, cmd, true, None).await;
534 } else {
535 tracing::warn!(session_id, "Cannot clear: session not found");
536 self.emit_command_complete(
537 session_id,
538 cmd,
539 false,
540 Some("Session not found".to_string()),
541 ).await;
542 }
543 }
544 ControlCmd::Compact => {
545 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
547 let result = session.force_compact().await;
548
549 if let Some(error) = result.error {
550 tracing::warn!(session_id, error = %error, "Session compaction failed");
552 self.emit_command_complete(session_id, cmd, false, Some(error)).await;
553 } else if !result.compacted {
554 tracing::info!(session_id, "Nothing to compact");
556 self.emit_command_complete(
557 session_id,
558 cmd,
559 true,
560 Some("Nothing to compact - not enough turns in conversation".to_string()),
561 ).await;
562 } else {
563 let message = format!(
565 "Conversation compacted\n Turns summarized: {}\n Turns kept: {}\n Messages: {} -> {}{}",
566 result.turns_compacted,
567 result.turns_kept,
568 result.messages_before,
569 result.messages_after,
570 if result.summary_length > 0 {
571 format!("\n Summary length: {} chars", result.summary_length)
572 } else {
573 String::new()
574 }
575 );
576 tracing::info!(
577 session_id,
578 turns_compacted = result.turns_compacted,
579 messages_before = result.messages_before,
580 messages_after = result.messages_after,
581 "Session compaction completed"
582 );
583 self.emit_command_complete(session_id, cmd, true, Some(message)).await;
584 }
585 } else {
586 tracing::warn!(session_id, "Cannot compact: session not found");
587 self.emit_command_complete(
588 session_id,
589 cmd,
590 false,
591 Some("Session not found".to_string()),
592 ).await;
593 }
594 }
595 }
596 }
597
598 async fn emit_error(&self, session_id: i64, error: String, turn_id: Option<TurnId>) {
600 self.send_to_ui(ControllerEvent::Error {
601 session_id,
602 error,
603 turn_id,
604 }).await;
605 }
606
607 async fn emit_command_complete(
609 &self,
610 session_id: i64,
611 command: ControlCmd,
612 success: bool,
613 message: Option<String>,
614 ) {
615 self.send_to_ui(ControllerEvent::CommandComplete {
616 session_id,
617 command,
618 success,
619 message,
620 }).await;
621 }
622
623 async fn handle_tool_batch_result(&self, batch_result: ToolBatchResult) {
625 use crate::controller::types::ToolResultInfo;
626
627 let session_id = batch_result.session_id;
628
629 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
630 tracing::error!(session_id, "Session not found for tool result");
631 return;
632 };
633
634 let mut compact_summaries = HashMap::new();
636 let tool_results: Vec<ToolResultInfo> = batch_result
637 .results
638 .iter()
639 .map(|result| {
640 let (content, is_error) = if let Some(ref error) = result.error {
641 (error.clone(), true)
642 } else {
643 (result.content.clone(), false)
644 };
645
646 if let Some(ref summary) = result.compact_summary {
648 compact_summaries.insert(result.tool_use_id.clone(), summary.clone());
649 tracing::debug!(
650 tool_use_id = %result.tool_use_id,
651 summary_len = summary.len(),
652 "Extracted compact summary for tool result"
653 );
654 }
655
656 ToolResultInfo {
657 tool_use_id: result.tool_use_id.clone(),
658 content,
659 is_error,
660 }
661 })
662 .collect();
663
664 tracing::info!(
665 session_id,
666 tool_count = tool_results.len(),
667 compact_summary_count = compact_summaries.len(),
668 "Sending tool results to session with compact summaries"
669 );
670
671 let llm_payload = ToLLMPayload {
673 request_type: LLMRequestType::ToolResult,
674 content: String::new(),
675 tool_results,
676 options: None,
677 turn_id: batch_result.turn_id,
678 compact_summaries,
679 };
680
681 let sent = session.send(llm_payload).await;
683 if !sent {
684 tracing::error!(session_id, "Failed to send tool result to session");
685 } else {
686 tracing::debug!(
687 session_id,
688 batch_id = batch_result.batch_id,
689 "Sent tool results to session"
690 );
691 }
692 }
693
694 pub async fn shutdown(&self) {
697 if self
698 .shutdown
699 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
700 .is_err()
701 {
702 return; }
704
705 tracing::info!("Controller shutting down");
706
707 self.session_mgr.shutdown().await;
709
710 self.cancel_token.cancel();
712
713 tracing::info!("Controller shutdown complete");
714 }
715
716 pub fn is_shutdown(&self) -> bool {
718 self.shutdown.load(Ordering::SeqCst)
719 }
720
721 pub fn is_started(&self) -> bool {
723 self.started.load(Ordering::SeqCst)
724 }
725
726 pub async fn create_session(&self, config: LLMSessionConfig) -> Result<i64, LlmError> {
739 let session_id = self
740 .session_mgr
741 .create_session(config, self.from_llm_tx.clone(), self.channel_size)
742 .await?;
743
744 tracing::info!(session_id, "Session created via controller");
745 Ok(session_id)
746 }
747
748 pub async fn get_session(&self, session_id: i64) -> Option<Arc<LLMSession>> {
753 self.session_mgr.get_session_by_id(session_id).await
754 }
755
756 pub async fn session_count(&self) -> usize {
758 self.session_mgr.session_count().await
759 }
760
761 pub async fn remove_session(&self, session_id: i64) -> bool {
772 self.session_mgr.remove_session(session_id).await
773 }
774
775 pub async fn send_input(&self, input: ControllerInputPayload) -> Result<(), ControllerError> {
785 if self.is_shutdown() {
786 return Err(ControllerError::Shutdown);
787 }
788
789 match tokio::time::timeout(SEND_INPUT_TIMEOUT, self.input_tx.send(input)).await {
790 Ok(Ok(())) => Ok(()),
791 Ok(Err(_)) => Err(ControllerError::ChannelClosed),
792 Err(_) => Err(ControllerError::SendTimeout(SEND_INPUT_TIMEOUT.as_secs())),
793 }
794 }
795
796 pub async fn get_session_token_usage(
800 &self,
801 session_id: i64,
802 ) -> Option<crate::controller::usage::TokenMeter> {
803 self.token_usage.get_session_usage(session_id).await
804 }
805
806 pub async fn get_model_token_usage(&self, model: &str) -> Option<crate::controller::usage::TokenMeter> {
808 self.token_usage.get_model_usage(model).await
809 }
810
811 pub async fn get_total_token_usage(&self) -> crate::controller::usage::TokenMeter {
813 self.token_usage.get_total_usage().await
814 }
815
816 pub fn token_usage(&self) -> &TokenUsageTracker {
818 &self.token_usage
819 }
820
821 pub fn tool_registry(&self) -> &Arc<ToolRegistry> {
825 &self.tool_registry
826 }
827
828}