1use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::Duration;
6
7use tokio::sync::{mpsc, Mutex};
8use tokio_util::sync::CancellationToken;
9
10use std::sync::Arc;
11
12use crate::client::error::LlmError;
13use crate::controller::session::{LLMSession, LLMSessionConfig, LLMSessionManager};
14use crate::controller::tools::{
15 ToolBatchResult, ToolExecutor, ToolRegistry, ToolRequest, ToolResult,
16};
17use crate::controller::error::ControllerError;
18use crate::controller::types::{
19 ControlCmd, ControllerEvent, ControllerInputPayload, FromLLMPayload, InputType,
20 LLMRequestType, LLMResponseType, ToLLMPayload, TurnId,
21};
22use crate::controller::usage::TokenUsageTracker;
23use crate::agent::{convert_controller_event_to_ui_message, UiMessage};
24
/// Default bounded-channel capacity used for every controller channel when
/// the caller does not supply a size to `LLMController::new`.
pub const DEFAULT_CHANNEL_SIZE: usize = 500;

/// Maximum time `send_input` waits for space in the input channel before
/// failing with `ControllerError::SendTimeout`.
const SEND_INPUT_TIMEOUT: Duration = Duration::from_secs(5);
32
/// Central coordinator that routes traffic between LLM sessions, the tool
/// executor, and an optional UI channel.
pub struct LLMController {
    /// Creates, looks up, and shuts down LLM sessions.
    session_mgr: LLMSessionManager,

    /// Per-session and per-model token accounting.
    token_usage: TokenUsageTracker,

    /// Responses coming back from sessions. Wrapped in a `Mutex` so the
    /// event loop in `start` can run from `&self`.
    from_llm_rx: Mutex<mpsc::Receiver<FromLLMPayload>>,

    /// Sender half cloned into every newly created session.
    from_llm_tx: mpsc::Sender<FromLLMPayload>,

    /// Caller input (data messages and control commands); see `from_llm_rx`
    /// for why this is behind a `Mutex`.
    input_rx: Mutex<mpsc::Receiver<ControllerInputPayload>>,

    /// Sender used by `send_input`.
    input_tx: mpsc::Sender<ControllerInputPayload>,

    /// Set once by `start`; guards against running the event loop twice.
    started: AtomicBool,

    /// Set once by `shutdown`; `send_input` rejects input afterwards.
    shutdown: AtomicBool,

    /// Cancels the event loop and in-flight tool executions.
    cancel_token: CancellationToken,

    /// Optional channel to the UI; `None` disables UI events entirely.
    ui_tx: Option<mpsc::Sender<UiMessage>>,

    /// Registry of available tools, shared with the executor.
    tool_registry: Arc<ToolRegistry>,

    /// Executes tool requests; reports on the two receivers below.
    tool_executor: ToolExecutor,

    /// Individual tool results, forwarded to the UI by the event loop.
    tool_result_rx: Mutex<mpsc::Receiver<ToolResult>>,

    /// Completed tool batches, fed back to the owning session.
    batch_result_rx: Mutex<mpsc::Receiver<ToolBatchResult>>,

    /// Capacity applied to every channel this controller creates.
    channel_size: usize,
}
82
83impl LLMController {
84 pub fn new(ui_tx: Option<mpsc::Sender<UiMessage>>, channel_size: Option<usize>) -> Self {
90 let size = channel_size.unwrap_or(DEFAULT_CHANNEL_SIZE);
91
92 let (from_llm_tx, from_llm_rx) = mpsc::channel(size);
93 let (input_tx, input_rx) = mpsc::channel(size);
94
95 let (tool_result_tx, tool_result_rx) = mpsc::channel(size);
97 let (batch_result_tx, batch_result_rx) = mpsc::channel(size);
98
99 let tool_registry = Arc::new(ToolRegistry::new());
100 let tool_executor = ToolExecutor::new(
101 tool_registry.clone(),
102 tool_result_tx,
103 batch_result_tx,
104 );
105
106 Self {
107 session_mgr: LLMSessionManager::new(),
108 token_usage: TokenUsageTracker::new(),
109 from_llm_rx: Mutex::new(from_llm_rx),
110 from_llm_tx,
111 input_rx: Mutex::new(input_rx),
112 input_tx,
113 started: AtomicBool::new(false),
114 shutdown: AtomicBool::new(false),
115 cancel_token: CancellationToken::new(),
116 ui_tx,
117 tool_registry,
118 tool_executor,
119 tool_result_rx: Mutex::new(tool_result_rx),
120 batch_result_rx: Mutex::new(batch_result_rx),
121 channel_size: size,
122 }
123 }
124
125 fn ui_has_capacity(&self) -> bool {
128 match &self.ui_tx {
129 Some(tx) => tx.capacity() > 0,
130 None => true, }
132 }
133
134 async fn send_to_ui(&self, event: ControllerEvent) {
137 if let Some(ref tx) = self.ui_tx {
138 let msg = convert_controller_event_to_ui_message(event);
139 if let Err(e) = tx.send(msg).await {
140 tracing::warn!("Failed to send event to UI: {}", e);
141 }
142 }
143 }
144
    /// Run the controller's main event loop until cancelled or a primary
    /// channel closes. Idempotent: a second call logs a warning and returns.
    pub async fn start(&self) {
        // compare_exchange makes the "already started" check race-free.
        if self
            .started
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            tracing::warn!("Controller already started");
            return;
        }

        tracing::info!("Controller starting");

        loop {
            // The receivers live behind Mutexes so this loop can run from
            // `&self`. All four guards are taken up front each iteration and
            // explicitly dropped before a handler runs, so handlers are free
            // to touch the channels themselves.
            let mut from_llm_guard = self.from_llm_rx.lock().await;
            let mut input_guard = self.input_rx.lock().await;
            let mut batch_result_guard = self.batch_result_rx.lock().await;
            let mut tool_result_guard = self.tool_result_rx.lock().await;

            // Backpressure: UI-bound arms are only polled while the UI channel
            // has capacity. NOTE(review): nothing wakes this select! when UI
            // capacity frees up, so progress then relies on one of the other
            // arms firing — confirm this cannot stall under a saturated UI.
            let ui_ready = self.ui_has_capacity();

            tokio::select! {
                // Cooperative shutdown via the controller-wide token.
                _ = self.cancel_token.cancelled() => {
                    tracing::info!("Controller cancelled");
                    break;
                }
                // Streamed responses from LLM sessions (UI-gated).
                msg = from_llm_guard.recv(), if ui_ready => {
                    drop(from_llm_guard);
                    drop(input_guard);
                    drop(batch_result_guard);
                    drop(tool_result_guard);
                    if let Some(payload) = msg {
                        self.handle_llm_response(payload).await;
                    } else {
                        // All senders dropped: no sessions can report anymore.
                        tracing::info!("FromLLM channel closed");
                        break;
                    }
                }
                // Caller input: data messages and control commands.
                msg = input_guard.recv() => {
                    drop(from_llm_guard);
                    drop(input_guard);
                    drop(batch_result_guard);
                    drop(tool_result_guard);
                    if let Some(payload) = msg {
                        self.handle_input(payload).await;
                    } else {
                        tracing::info!("Input channel closed");
                        break;
                    }
                }
                // Completed tool batches flow back to the owning session; a
                // closed channel is not fatal here (None simply ignored).
                batch_result = batch_result_guard.recv() => {
                    drop(from_llm_guard);
                    drop(input_guard);
                    drop(batch_result_guard);
                    drop(tool_result_guard);
                    if let Some(result) = batch_result {
                        self.handle_tool_batch_result(result).await;
                    }
                }
                // Individual tool results are surfaced to the UI (UI-gated).
                tool_result = tool_result_guard.recv(), if ui_ready => {
                    drop(from_llm_guard);
                    drop(input_guard);
                    drop(batch_result_guard);
                    drop(tool_result_guard);
                    if let Some(result) = tool_result {
                        self.send_to_ui(ControllerEvent::ToolResult {
                            session_id: result.session_id,
                            tool_use_id: result.tool_use_id,
                            tool_name: result.tool_name,
                            display_name: result.display_name,
                            status: result.status,
                            content: result.content,
                            error: result.error,
                            turn_id: result.turn_id,
                        }).await;
                    }
                }
            }
        }

        tracing::info!("Controller stopped");
    }
262
263 async fn handle_llm_response(&self, payload: FromLLMPayload) {
265 if payload.response_type == LLMResponseType::TokenUpdate {
267 if let Some(session) = self.session_mgr.get_session_by_id(payload.session_id).await {
268 self.token_usage
269 .increment(
270 payload.session_id,
271 session.model(),
272 payload.input_tokens,
273 payload.output_tokens,
274 )
275 .await;
276 }
277 }
278
279 let event = match payload.response_type {
280 LLMResponseType::StreamStart => Some(ControllerEvent::StreamStart {
281 session_id: payload.session_id,
282 message_id: payload.message_id,
283 model: payload.model,
284 turn_id: payload.turn_id,
285 }),
286 LLMResponseType::TextChunk => Some(ControllerEvent::TextChunk {
287 session_id: payload.session_id,
288 text: payload.text,
289 turn_id: payload.turn_id,
290 }),
291 LLMResponseType::ToolUseStart => {
292 if let Some(tool) = payload.tool_use {
293 Some(ControllerEvent::ToolUseStart {
294 session_id: payload.session_id,
295 tool_id: tool.id,
296 tool_name: tool.name,
297 turn_id: payload.turn_id,
298 })
299 } else {
300 None
301 }
302 }
303 LLMResponseType::ToolInputDelta => {
304 None
307 }
308 LLMResponseType::ToolUse => {
309 if let Some(ref tool) = payload.tool_use {
310 let input: HashMap<String, serde_json::Value> = tool
312 .input
313 .as_object()
314 .map(|obj| {
315 obj.iter()
316 .map(|(k, v)| (k.clone(), v.clone()))
317 .collect()
318 })
319 .unwrap_or_default();
320
321 let (display_name, display_title) =
323 if let Some(t) = self.tool_registry().get(&tool.name).await {
324 let config = t.display_config();
325 (Some(config.display_name), Some((config.display_title)(&input)))
326 } else {
327 (None, None)
328 };
329
330 let request = ToolRequest {
331 tool_use_id: tool.id.clone(),
332 tool_name: tool.name.clone(),
333 input,
334 };
335
336 self.tool_executor
337 .execute(
338 payload.session_id,
339 payload.turn_id.clone(),
340 request,
341 self.cancel_token.clone(),
342 )
343 .await;
344
345 Some(ControllerEvent::ToolUse {
346 session_id: payload.session_id,
347 tool: payload.tool_use.unwrap(),
348 display_name,
349 display_title,
350 turn_id: payload.turn_id,
351 })
352 } else {
353 None
354 }
355 }
356 LLMResponseType::ToolBatch => {
357 if payload.tool_uses.is_empty() {
359 tracing::error!(
360 session_id = payload.session_id,
361 "Received tool batch response with empty tool_uses"
362 );
363 return;
364 }
365
366 tracing::debug!(
367 session_id = payload.session_id,
368 tool_count = payload.tool_uses.len(),
369 "LLM requested tool batch execution"
370 );
371
372 let mut requests = Vec::with_capacity(payload.tool_uses.len());
374 for tool_info in &payload.tool_uses {
375 let input: HashMap<String, serde_json::Value> = tool_info
376 .input
377 .as_object()
378 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
379 .unwrap_or_default();
380
381 requests.push(ToolRequest {
382 tool_use_id: tool_info.id.clone(),
383 tool_name: tool_info.name.clone(),
384 input: input.clone(),
385 });
386
387 let (display_name, display_title) =
389 if let Some(tool) = self.tool_registry().get(&tool_info.name).await {
390 let config = tool.display_config();
391 (Some(config.display_name), Some((config.display_title)(&input)))
392 } else {
393 (None, None)
394 };
395
396 self.send_to_ui(ControllerEvent::ToolUse {
398 session_id: payload.session_id,
399 tool: tool_info.clone(),
400 display_name,
401 display_title,
402 turn_id: payload.turn_id.clone(),
403 }).await;
404 }
405
406 self.tool_executor
408 .execute_batch(
409 payload.session_id,
410 payload.turn_id.clone(),
411 requests,
412 self.cancel_token.clone(),
413 )
414 .await;
415
416 None
417 }
418 LLMResponseType::Complete => Some(ControllerEvent::Complete {
419 session_id: payload.session_id,
420 stop_reason: payload.stop_reason,
421 turn_id: payload.turn_id,
422 }),
423 LLMResponseType::Error => Some(ControllerEvent::Error {
424 session_id: payload.session_id,
425 error: payload.error.unwrap_or_else(|| "Unknown error".to_string()),
426 turn_id: payload.turn_id,
427 }),
428 LLMResponseType::TokenUpdate => {
429 let context_limit = if let Some(session) =
431 self.session_mgr.get_session_by_id(payload.session_id).await
432 {
433 session.context_limit()
434 } else {
435 0
436 };
437 Some(ControllerEvent::TokenUpdate {
438 session_id: payload.session_id,
439 input_tokens: payload.input_tokens,
440 output_tokens: payload.output_tokens,
441 context_limit,
442 })
443 }
444 };
445
446 if let Some(event) = event {
448 self.send_to_ui(event).await;
449 }
450 }
451
452 async fn handle_input(&self, payload: ControllerInputPayload) {
454 match payload.input_type {
455 InputType::Data => {
456 self.handle_data_input(payload).await;
457 }
458 InputType::Control => {
459 self.handle_control_input(payload).await;
460 }
461 }
462 }
463
464 async fn handle_data_input(&self, payload: ControllerInputPayload) {
466 let session_id = payload.session_id;
467
468 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
470 tracing::error!(session_id, "Session not found for data input");
471 self.emit_error(session_id, "Session not found".to_string(), payload.turn_id).await;
472 return;
473 };
474
475 let llm_payload = ToLLMPayload {
477 request_type: LLMRequestType::UserMessage,
478 content: payload.content,
479 tool_results: Vec::new(),
480 options: None,
481 turn_id: payload.turn_id,
482 compact_summaries: HashMap::new(),
483 };
484
485 let sent = session.send(llm_payload).await;
487 if !sent {
488 tracing::error!(session_id, "Failed to send message to session");
489 self.emit_error(
490 session_id,
491 "Failed to send message to session".to_string(),
492 None,
493 ).await;
494 }
495 }
496
497 async fn handle_control_input(&self, payload: ControllerInputPayload) {
499 let session_id = payload.session_id;
500
501 let Some(cmd) = payload.control_cmd else {
502 tracing::warn!(session_id, "Control input without command");
503 return;
504 };
505
506 match cmd {
507 ControlCmd::Interrupt => {
508 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
510 session.interrupt().await;
511 tracing::info!(session_id, "Session interrupted");
512 } else {
513 tracing::warn!(session_id, "Cannot interrupt: session not found");
514 }
515 }
516 ControlCmd::Shutdown => {
517 tracing::info!("Shutdown command received");
519 self.shutdown().await;
520 }
521 ControlCmd::Clear => {
522 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
524 session.clear_conversation().await;
525 tracing::info!(session_id, "Session conversation cleared");
526 self.emit_command_complete(session_id, cmd, true, None).await;
527 } else {
528 tracing::warn!(session_id, "Cannot clear: session not found");
529 self.emit_command_complete(
530 session_id,
531 cmd,
532 false,
533 Some("Session not found".to_string()),
534 ).await;
535 }
536 }
537 ControlCmd::Compact => {
538 if let Some(session) = self.session_mgr.get_session_by_id(session_id).await {
540 let result = session.force_compact().await;
541
542 if let Some(error) = result.error {
543 tracing::warn!(session_id, error = %error, "Session compaction failed");
545 self.emit_command_complete(session_id, cmd, false, Some(error)).await;
546 } else if !result.compacted {
547 tracing::info!(session_id, "Nothing to compact");
549 self.emit_command_complete(
550 session_id,
551 cmd,
552 true,
553 Some("Nothing to compact - not enough turns in conversation".to_string()),
554 ).await;
555 } else {
556 let message = format!(
558 "Conversation compacted\n Turns summarized: {}\n Turns kept: {}\n Messages: {} -> {}{}",
559 result.turns_compacted,
560 result.turns_kept,
561 result.messages_before,
562 result.messages_after,
563 if result.summary_length > 0 {
564 format!("\n Summary length: {} chars", result.summary_length)
565 } else {
566 String::new()
567 }
568 );
569 tracing::info!(
570 session_id,
571 turns_compacted = result.turns_compacted,
572 messages_before = result.messages_before,
573 messages_after = result.messages_after,
574 "Session compaction completed"
575 );
576 self.emit_command_complete(session_id, cmd, true, Some(message)).await;
577 }
578 } else {
579 tracing::warn!(session_id, "Cannot compact: session not found");
580 self.emit_command_complete(
581 session_id,
582 cmd,
583 false,
584 Some("Session not found".to_string()),
585 ).await;
586 }
587 }
588 }
589 }
590
591 async fn emit_error(&self, session_id: i64, error: String, turn_id: Option<TurnId>) {
593 self.send_to_ui(ControllerEvent::Error {
594 session_id,
595 error,
596 turn_id,
597 }).await;
598 }
599
600 async fn emit_command_complete(
602 &self,
603 session_id: i64,
604 command: ControlCmd,
605 success: bool,
606 message: Option<String>,
607 ) {
608 self.send_to_ui(ControllerEvent::CommandComplete {
609 session_id,
610 command,
611 success,
612 message,
613 }).await;
614 }
615
616 async fn handle_tool_batch_result(&self, batch_result: ToolBatchResult) {
618 use crate::controller::types::ToolResultInfo;
619
620 let session_id = batch_result.session_id;
621
622 let Some(session) = self.session_mgr.get_session_by_id(session_id).await else {
623 tracing::error!(session_id, "Session not found for tool result");
624 return;
625 };
626
627 let mut compact_summaries = HashMap::new();
629 let tool_results: Vec<ToolResultInfo> = batch_result
630 .results
631 .iter()
632 .map(|result| {
633 let (content, is_error) = if let Some(ref error) = result.error {
634 (error.clone(), true)
635 } else {
636 (result.content.clone(), false)
637 };
638
639 if let Some(ref summary) = result.compact_summary {
641 compact_summaries.insert(result.tool_use_id.clone(), summary.clone());
642 tracing::debug!(
643 tool_use_id = %result.tool_use_id,
644 summary_len = summary.len(),
645 "Extracted compact summary for tool result"
646 );
647 }
648
649 ToolResultInfo {
650 tool_use_id: result.tool_use_id.clone(),
651 content,
652 is_error,
653 }
654 })
655 .collect();
656
657 tracing::info!(
658 session_id,
659 tool_count = tool_results.len(),
660 compact_summary_count = compact_summaries.len(),
661 "Sending tool results to session with compact summaries"
662 );
663
664 let llm_payload = ToLLMPayload {
666 request_type: LLMRequestType::ToolResult,
667 content: String::new(),
668 tool_results,
669 options: None,
670 turn_id: batch_result.turn_id,
671 compact_summaries,
672 };
673
674 let sent = session.send(llm_payload).await;
676 if !sent {
677 tracing::error!(session_id, "Failed to send tool result to session");
678 } else {
679 tracing::debug!(
680 session_id,
681 batch_id = batch_result.batch_id,
682 "Sent tool results to session"
683 );
684 }
685 }
686
687 pub async fn shutdown(&self) {
690 if self
691 .shutdown
692 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
693 .is_err()
694 {
695 return; }
697
698 tracing::info!("Controller shutting down");
699
700 self.session_mgr.shutdown().await;
702
703 self.cancel_token.cancel();
705
706 tracing::info!("Controller shutdown complete");
707 }
708
    /// True once `shutdown` has begun; `send_input` rejects input after this.
    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }
713
    /// True once `start` has been called (it stays true after shutdown).
    pub fn is_started(&self) -> bool {
        self.started.load(Ordering::SeqCst)
    }
718
719 pub async fn create_session(&self, config: LLMSessionConfig) -> Result<i64, LlmError> {
732 let session_id = self
733 .session_mgr
734 .create_session(config, self.from_llm_tx.clone(), self.channel_size)
735 .await?;
736
737 tracing::info!(session_id, "Session created via controller");
738 Ok(session_id)
739 }
740
    /// Look up a session by id; `None` if it does not exist (or was removed).
    pub async fn get_session(&self, session_id: i64) -> Option<Arc<LLMSession>> {
        self.session_mgr.get_session_by_id(session_id).await
    }
748
    /// Number of sessions currently managed by this controller.
    pub async fn session_count(&self) -> usize {
        self.session_mgr.session_count().await
    }
753
754 pub async fn send_input(&self, input: ControllerInputPayload) -> Result<(), ControllerError> {
764 if self.is_shutdown() {
765 return Err(ControllerError::Shutdown);
766 }
767
768 match tokio::time::timeout(SEND_INPUT_TIMEOUT, self.input_tx.send(input)).await {
769 Ok(Ok(())) => Ok(()),
770 Ok(Err(_)) => Err(ControllerError::ChannelClosed),
771 Err(_) => Err(ControllerError::SendTimeout(SEND_INPUT_TIMEOUT.as_secs())),
772 }
773 }
774
    /// Token usage recorded for one session; `None` if none has been tracked.
    pub async fn get_session_token_usage(
        &self,
        session_id: i64,
    ) -> Option<crate::controller::usage::TokenMeter> {
        self.token_usage.get_session_usage(session_id).await
    }
784
    /// Token usage aggregated for one model; `None` if none has been tracked.
    pub async fn get_model_token_usage(&self, model: &str) -> Option<crate::controller::usage::TokenMeter> {
        self.token_usage.get_model_usage(model).await
    }
789
    /// Token usage summed across all sessions and models.
    pub async fn get_total_token_usage(&self) -> crate::controller::usage::TokenMeter {
        self.token_usage.get_total_usage().await
    }
794
    /// Direct access to the token-usage tracker.
    pub fn token_usage(&self) -> &TokenUsageTracker {
        &self.token_usage
    }
799
    /// The shared tool registry (also held by the tool executor).
    pub fn tool_registry(&self) -> &Arc<ToolRegistry> {
        &self.tool_registry
    }
806
807}