
agent_core/controller/session/session.rs

//! Implements a single session with an LLM.

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::Instant;

use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;

use crate::client::error::LlmError;
use crate::client::models::Tool as LLMTool;
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;
use crate::controller::types::{
    AssistantMessage, ContentBlock, FromLLMPayload, Message, ToLLMPayload, TurnId, UserMessage,
};

use super::compactor::{AsyncCompactor, Compactor, LLMCompactor, ThresholdCompactor};
use super::config::{CompactorType, LLMProvider, LLMSessionConfig};
/// Creates an `LLMClient` from the session configuration.
fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
    match config.provider {
        LLMProvider::Anthropic => {
            let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::OpenAI => {
            // Check for Azure configuration first
            let provider = if let (Some(resource), Some(deployment)) =
                (&config.azure_resource, &config.azure_deployment)
            {
                let api_version = config
                    .azure_api_version
                    .clone()
                    .unwrap_or_else(|| "2024-10-21".to_string());
                OpenAIProvider::azure(
                    config.api_key.clone(),
                    resource.clone(),
                    deployment.clone(),
                    api_version,
                )
            } else if let Some(base_url) = &config.base_url {
                OpenAIProvider::with_base_url(
                    config.api_key.clone(),
                    config.model.clone(),
                    base_url.clone(),
                )
            } else {
                OpenAIProvider::new(config.api_key.clone(), config.model.clone())
            };
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Google => {
            let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Cohere => {
            let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Bedrock => {
            // Bedrock requires a region and static credentials; the session token is optional
            let region = config.bedrock_region.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_region")
            })?;
            let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_access_key_id")
            })?;
            let secret_access_key = config.bedrock_secret_access_key.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_secret_access_key")
            })?;

            let credentials = match &config.bedrock_session_token {
                Some(token) => BedrockCredentials::with_session_token(
                    access_key_id,
                    secret_access_key,
                    token.clone(),
                ),
                None => BedrockCredentials::new(access_key_id, secret_access_key),
            };

            let provider = BedrockProvider::new(credentials, region, config.model.clone());
            LLMClient::new(Box::new(provider))
        }
    }
}
/// Token usage statistics for the session
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Total input tokens across all requests
    pub total_input_tokens: i64,
    /// Total output tokens across all requests
    pub total_output_tokens: i64,
    /// Number of completed LLM requests
    pub request_count: i64,
    /// Input tokens from the most recent request
    pub last_input_tokens: i64,
    /// Output tokens from the most recent request
    pub last_output_tokens: i64,
}

/// Current status of an LLM session
#[derive(Debug, Clone)]
pub struct SessionStatus {
    /// Session identifier
    pub session_id: i64,
    /// Model name
    pub model: String,
    /// When the session was created
    pub created_at: Instant,
    /// Number of messages in the conversation
    pub conversation_len: usize,
    /// Current input tokens (context size)
    pub context_used: i64,
    /// Model's context window limit
    pub context_limit: i32,
    /// Context utilization percentage (0-100)
    pub utilization: f64,
    /// Cumulative input tokens
    pub total_input: i64,
    /// Cumulative output tokens
    pub total_output: i64,
    /// Number of API calls
    pub request_count: i64,
}

/// Result of a manual compaction operation.
/// Used by `force_compact()` to report what happened during compaction.
#[derive(Debug, Clone, Default)]
pub struct CompactResult {
    /// Whether compaction was actually performed.
    /// `false` if there weren't enough turns to compact.
    pub compacted: bool,
    /// Number of messages before compaction.
    pub messages_before: usize,
    /// Number of messages after compaction.
    pub messages_after: usize,
    /// Number of turns that were summarized.
    pub turns_compacted: usize,
    /// Number of recent turns that were preserved.
    pub turns_kept: usize,
    /// Character length of the generated summary (for LLM compaction).
    pub summary_length: usize,
    /// Error message if compaction failed.
    pub error: Option<String>,
}

/// Global counter for generating unique session IDs
static SESSION_COUNTER: AtomicI64 = AtomicI64::new(0);

/// A session that manages communication with an LLM
pub struct LLMSession {
    // Session identification
    id: AtomicI64,

    // LLM client
    client: LLMClient,

    // Channels for communication
    to_llm_tx: mpsc::Sender<ToLLMPayload>,
    to_llm_rx: Mutex<mpsc::Receiver<ToLLMPayload>>,
    from_llm: mpsc::Sender<FromLLMPayload>,

    // Session configuration
    config: LLMSessionConfig,

    // Runtime overrides for LLM options
    system_prompt: RwLock<Option<String>>,
    max_tokens: AtomicI64,
    created_at: Instant,

    // Conversation state
    conversation: RwLock<Arc<Vec<Message>>>,

    // Shutdown management
    shutdown: AtomicBool,
    cancel_token: CancellationToken,

    // Per-request cancellation
    current_cancel: Mutex<Option<CancellationToken>>,

    // Current turn ID for the active request (used for filtering on interrupt)
    current_turn_id: RwLock<Option<TurnId>>,

    // Token tracking for the current request
    current_input_tokens: AtomicI64,
    current_output_tokens: AtomicI64,

    // Cumulative token tracking
    request_count: AtomicI64,

    // Tool definitions for LLM API calls
    tool_definitions: RwLock<Vec<LLMTool>>,

    // Compaction support
    compactor: Option<Box<dyn Compactor>>,
    llm_compactor: Option<LLMCompactor>,
    context_limit: AtomicI32,
    compact_summaries: RwLock<HashMap<String, String>>,
}

impl LLMSession {
    /// Creates a new LLM session
    ///
    /// # Arguments
    /// * `config` - Session configuration
    /// * `from_llm` - Sender for outgoing responses
    /// * `cancel_token` - Token for session cancellation
    /// * `channel_size` - Buffer size for the session's input channel
    ///
    /// # Errors
    /// Returns an error if the LLM client fails to initialize (e.g., TLS setup failure)
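    ///
    /// # Example
    ///
    /// A minimal sketch of wiring up a session. It assumes `LLMSessionConfig`
    /// implements `Default` and that the placeholder field values are valid;
    /// neither is confirmed by this file.
    ///
    /// ```ignore
    /// use tokio::sync::mpsc;
    /// use tokio_util::sync::CancellationToken;
    ///
    /// let config = LLMSessionConfig {
    ///     model: "model-id".to_string(),
    ///     api_key: "api-key".to_string(),
    ///     ..Default::default()
    /// };
    /// let (from_llm_tx, mut from_llm_rx) = mpsc::channel(64);
    /// let session = LLMSession::new(config, from_llm_tx, CancellationToken::new(), 64)?;
    /// ```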
    pub fn new(
        config: LLMSessionConfig,
        from_llm: mpsc::Sender<FromLLMPayload>,
        cancel_token: CancellationToken,
        channel_size: usize,
    ) -> Result<Self, LlmError> {
        let session_id = SESSION_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
        let (to_llm_tx, to_llm_rx) = mpsc::channel(channel_size);
        let max_tokens = config.max_tokens.unwrap_or(4096) as i64;
        let system_prompt = config.system_prompt.clone();

        // Create the LLM client based on the configured provider
        let client = create_llm_client(&config)?;

        // Create a compactor if configured
        let mut compactor: Option<Box<dyn Compactor>> = None;
        let mut llm_compactor: Option<LLMCompactor> = None;

        if let Some(ref compactor_type) = config.compaction {
            match compactor_type {
                CompactorType::Threshold(c) => {
                    match ThresholdCompactor::new(c.threshold, c.keep_recent_turns, c.tool_compaction) {
                        Ok(tc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                tool_compaction = %c.tool_compaction,
                                "Threshold compaction enabled for session"
                            );
                            compactor = Some(Box::new(tc) as Box<dyn Compactor>);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create threshold compactor, compaction disabled");
                        }
                    }
                }
                CompactorType::LLM(c) => {
                    // Create a separate LLM client dedicated to compaction
                    let llm_client = create_llm_client(&config)?;

                    match LLMCompactor::new(llm_client, c.clone()) {
                        Ok(lc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                "LLM compaction enabled for session"
                            );
                            llm_compactor = Some(lc);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create LLM compactor, compaction disabled");
                        }
                    }
                }
            }
        }

        let context_limit = config.context_limit;

        Ok(Self {
            id: AtomicI64::new(session_id),
            client,
            to_llm_tx,
            to_llm_rx: Mutex::new(to_llm_rx),
            from_llm,
            config,
            system_prompt: RwLock::new(system_prompt),
            max_tokens: AtomicI64::new(max_tokens),
            created_at: Instant::now(),
            conversation: RwLock::new(Arc::new(Vec::new())),
            shutdown: AtomicBool::new(false),
            cancel_token,
            current_cancel: Mutex::new(None),
            current_turn_id: RwLock::new(None),
            current_input_tokens: AtomicI64::new(0),
            current_output_tokens: AtomicI64::new(0),
            request_count: AtomicI64::new(0),
            tool_definitions: RwLock::new(Vec::new()),
            compactor,
            llm_compactor,
            context_limit: AtomicI32::new(context_limit),
            compact_summaries: RwLock::new(HashMap::new()),
        })
    }

    /// Returns the session ID
    pub fn id(&self) -> i64 {
        self.id.load(Ordering::SeqCst)
    }

    /// Returns when the session was created
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Returns the model for this session
    pub fn model(&self) -> &str {
        &self.config.model
    }

    // ---- Max Tokens ----

    /// Sets the default maximum tokens for the session
    pub fn set_max_tokens(&self, max_tokens: i64) {
        self.max_tokens.store(max_tokens, Ordering::SeqCst);
    }

    /// Returns the current max tokens setting
    pub fn max_tokens(&self) -> i64 {
        self.max_tokens.load(Ordering::SeqCst)
    }

    /// Returns the context limit for this session's model
    pub fn context_limit(&self) -> i32 {
        self.context_limit.load(Ordering::SeqCst)
    }

    // ---- System Prompt ----

    /// Sets the default system prompt for the session
    pub async fn set_system_prompt(&self, prompt: String) {
        let mut guard = self.system_prompt.write().await;
        *guard = Some(prompt);
    }

    /// Clears the default system prompt
    pub async fn clear_system_prompt(&self) {
        let mut guard = self.system_prompt.write().await;
        *guard = None;
    }

    /// Returns the current system prompt
    pub async fn system_prompt(&self) -> Option<String> {
        self.system_prompt.read().await.clone()
    }

    // ---- Tools ----

    /// Sets the tool definitions for this session.
    /// Tools will be included in all subsequent LLM API calls.
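    ///
    /// # Example
    ///
    /// ```ignore
    /// // `tools` is a Vec<LLMTool> built elsewhere; how each LLMTool is
    /// // constructed is provider-specific and not shown in this file.
    /// session.set_tools(tools).await;
    /// ```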
    pub async fn set_tools(&self, tools: Vec<LLMTool>) {
        let mut guard = self.tool_definitions.write().await;
        *guard = tools;
    }

    /// Clears all tool definitions for this session.
    pub async fn clear_tools(&self) {
        let mut guard = self.tool_definitions.write().await;
        guard.clear();
    }

    /// Returns a copy of the current tool definitions.
    pub async fn tools(&self) -> Vec<LLMTool> {
        self.tool_definitions.read().await.clone()
    }

    // ---- Compaction ----

    /// Stores compact summaries for later use during compaction.
    /// Called when tool results are received.
    async fn store_compact_summaries(&self, summaries: &HashMap<String, String>) {
        if summaries.is_empty() {
            tracing::warn!(
                session_id = self.id(),
                "No compact summaries provided with tool results"
            );
            return;
        }
        let mut guard = self.compact_summaries.write().await;
        for (tool_use_id, summary) in summaries {
            tracing::info!(
                session_id = self.id(),
                tool_use_id = %tool_use_id,
                summary_len = summary.len(),
                summary_preview = %summary.chars().take(50).collect::<String>(),
                "Storing compact summary"
            );
            guard.insert(tool_use_id.clone(), summary.clone());
        }
        tracing::info!(
            session_id = self.id(),
            new_summaries = summaries.len(),
            total_stored = guard.len(),
            "Stored compact summaries for tool results"
        );
    }

    /// Performs compaction if needed based on context usage.
    /// Should be called before each LLM request.
    async fn maybe_compact(&self) {
        let context_used = self.current_input_tokens.load(Ordering::SeqCst);
        let context_limit = self.context_limit.load(Ordering::SeqCst);
        let conversation_len = self.conversation.read().await.len();
        let summaries_count = self.compact_summaries.read().await.len();

        let utilization = if context_limit > 0 {
            context_used as f64 / context_limit as f64
        } else {
            0.0
        };

        tracing::debug!(
            session_id = self.id(),
            context_used,
            context_limit,
            utilization = format!("{:.2}%", utilization * 100.0),
            conversation_len,
            summaries_available = summaries_count,
            "Checking if compaction is needed"
        );

        // Check for the LLM compactor first (async compaction)
        if let Some(ref llm_compactor) = self.llm_compactor {
            if !llm_compactor.should_compact(context_used, context_limit) {
                tracing::debug!(
                    session_id = self.id(),
                    "LLM compaction not triggered"
                );
                return;
            }

            // Get the conversation and summaries for async compaction
            let summaries = self.compact_summaries.read().await.clone();
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard) // O(1): clones the Arc pointer, not the Vec
            };
            // The session still holds its own Arc, so try_unwrap normally
            // fails here and we fall back to a deep clone of the messages.
            let conversation = Arc::try_unwrap(conversation_arc)
                .unwrap_or_else(|arc| (*arc).clone());

            tracing::info!(
                session_id = self.id(),
                conversation_len = conversation.len(),
                summaries_count = summaries.len(),
                "Starting LLM compaction"
            );

            // Perform async LLM compaction
            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    // Replace the conversation with the compacted version
                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            "LLM compaction completed"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "LLM compaction failed"
                    );
                }
            }
            return;
        }

        // Fall back to the sync compactor (ThresholdCompactor)
        let compactor = match &self.compactor {
            Some(c) => c,
            None => {
                tracing::debug!(
                    session_id = self.id(),
                    "No compactor configured"
                );
                return;
            }
        };

        if !compactor.should_compact(context_used, context_limit) {
            tracing::debug!(
                session_id = self.id(),
                "Threshold compaction not triggered"
            );
            return;
        }

        // Perform sync compaction
        let summaries = self.compact_summaries.read().await.clone();
        let mut guard = self.conversation.write().await;

        tracing::info!(
            session_id = self.id(),
            conversation_len = guard.len(),
            summaries_count = summaries.len(),
            "Starting threshold compaction"
        );

        let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

        tracing::info!(
            session_id = self.id(),
            tool_results_summarized = result.tool_results_summarized,
            tool_results_redacted = result.tool_results_redacted,
            turns_compacted = result.turns_compacted,
            conversation_len_after = guard.len(),
            "Threshold compaction completed"
        );
    }

    // ---- Session Control ----

    /// Clears the conversation history and compact summaries.
    pub async fn clear_conversation(&self) {
        let mut guard = self.conversation.write().await;
        Arc::make_mut(&mut *guard).clear();

        let mut summaries = self.compact_summaries.write().await;
        summaries.clear();

        // Reset the token counters
        self.current_input_tokens.store(0, Ordering::SeqCst);
        self.current_output_tokens.store(0, Ordering::SeqCst);

        tracing::info!(session_id = self.id(), "Conversation cleared");
    }

    /// Forces compaction to run immediately, regardless of threshold.
    /// Returns a `CompactResult` with details about what happened.
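    ///
    /// # Example
    ///
    /// A sketch of inspecting the result; the field names come from
    /// `CompactResult` above:
    ///
    /// ```ignore
    /// let result = session.force_compact().await;
    /// if let Some(err) = result.error {
    ///     eprintln!("compaction failed: {err}");
    /// } else if result.compacted {
    ///     println!(
    ///         "{} -> {} messages ({} turns summarized, {} kept)",
    ///         result.messages_before,
    ///         result.messages_after,
    ///         result.turns_compacted,
    ///         result.turns_kept,
    ///     );
    /// }
    /// ```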
    pub async fn force_compact(&self) -> CompactResult {
        // Check for the LLM compactor first (async compaction)
        if let Some(ref llm_compactor) = self.llm_compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard) // O(1): clones the Arc pointer, not the Vec
            };
            // As in maybe_compact, the session still holds its own Arc, so
            // this normally falls back to a deep clone of the messages.
            let conversation = Arc::try_unwrap(conversation_arc)
                .unwrap_or_else(|arc| (*arc).clone());
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    let messages_after = new_conversation.len();
                    let turns_after = self.count_unique_turns(&new_conversation);
                    let compacted = messages_after < messages_before;

                    // Extract the summary length if compaction happened
                    let summary_length = if compacted && !new_conversation.is_empty() {
                        self.extract_summary_length(&new_conversation[0])
                    } else {
                        0
                    };

                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            messages_before,
                            messages_after,
                            "Forced LLM compaction completed"
                        );
                    }

                    return CompactResult {
                        compacted,
                        messages_before,
                        messages_after,
                        turns_compacted: turns_before.saturating_sub(turns_after),
                        turns_kept: turns_after,
                        summary_length,
                        error: None,
                    };
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "Forced LLM compaction failed"
                    );
                    return CompactResult {
                        compacted: false,
                        messages_before,
                        messages_after: messages_before,
                        turns_compacted: 0,
                        turns_kept: turns_before,
                        summary_length: 0,
                        error: Some(format!("Compaction failed: {}", e)),
                    };
                }
            }
        }

        // Fall back to the sync compactor (ThresholdCompactor)
        if let Some(ref compactor) = self.compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let mut guard = self.conversation.write().await;
            let messages_before = guard.len();
            let turns_before = self.count_unique_turns(&guard);

            let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

            let messages_after = guard.len();
            let turns_after = self.count_unique_turns(&guard);
            let compacted = result.turns_compacted > 0 || result.total_compacted() > 0;

            if result.total_compacted() > 0 {
                tracing::info!(
                    session_id = self.id(),
                    tool_results_summarized = result.tool_results_summarized,
                    tool_results_redacted = result.tool_results_redacted,
                    turns_compacted = result.turns_compacted,
                    "Forced threshold compaction completed"
                );
            }

            return CompactResult {
                compacted,
                messages_before,
                messages_after,
                turns_compacted: turns_before.saturating_sub(turns_after),
                turns_kept: turns_after,
                summary_length: 0,
                error: None,
            };
        }

        // No compactor configured
        CompactResult {
            compacted: false,
            error: Some("No compactor configured".to_string()),
            ..Default::default()
        }
    }

    /// Counts unique turn IDs in a conversation.
    fn count_unique_turns(&self, conversation: &[Message]) -> usize {
        use std::collections::HashSet;
        let mut turn_ids = HashSet::new();
        for msg in conversation {
            turn_ids.insert(msg.turn_id().clone());
        }
        turn_ids.len()
    }

    /// Extracts the summary length from a summary message.
    fn extract_summary_length(&self, message: &Message) -> usize {
        if let Message::User(user_msg) = message {
            for block in &user_msg.content {
                if let ContentBlock::Text(text_block) = block {
                    if text_block.text.starts_with("[Previous conversation summary]") {
                        return text_block.text.len();
                    }
                }
            }
        }
        0
    }

    /// Sends a message to the LLM session for processing.
    /// Returns `false` if the session is shut down or the channel is closed.
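    ///
    /// # Example
    ///
    /// A sketch of submitting a user message; it assumes `ToLLMPayload`
    /// implements `Default`, which is not confirmed by this file:
    ///
    /// ```ignore
    /// let accepted = session
    ///     .send(ToLLMPayload {
    ///         request_type: LLMRequestType::UserMessage,
    ///         content: "Hello".to_string(),
    ///         ..Default::default()
    ///     })
    ///     .await;
    /// assert!(accepted, "session was shut down or its channel closed");
    /// ```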
    pub async fn send(&self, msg: ToLLMPayload) -> bool {
        if self.shutdown.load(Ordering::SeqCst) {
            return false;
        }
        self.to_llm_tx.send(msg).await.is_ok()
    }

    /// Interrupts the currently executing LLM request.
    /// This cancels any in-flight request and removes the current turn's
    /// messages from the conversation history. Does not shut down the session.
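    ///
    /// # Example
    ///
    /// ```ignore
    /// // Cancel the in-flight request; the session keeps running and
    /// // will accept the next `send`.
    /// session.interrupt().await;
    /// ```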
    pub async fn interrupt(&self) {
        let guard = self.current_cancel.lock().await;
        if let Some(token) = guard.as_ref() {
            token.cancel();

            // Remove the current turn's messages from the conversation history.
            // This prevents any messages from the cancelled turn (user message,
            // assistant responses, etc.) from being included in subsequent API calls.
            let turn_id = self.current_turn_id.read().await.clone();
            if let Some(turn_id) = turn_id {
                let mut guard = self.conversation.write().await;
                let original_len = guard.len();
                Arc::make_mut(&mut *guard).retain(|msg| msg.turn_id() != &turn_id);
                let removed = original_len - guard.len();
                tracing::debug!(
                    session_id = self.id(),
                    turn_id = %turn_id,
                    messages_removed = removed,
                    conversation_length = guard.len(),
                    "Removed messages from cancelled turn"
                );
            }
        }
    }

    /// Gracefully shuts down the session.
    /// After calling this, the session will not accept new messages.
    pub fn shutdown(&self) {
        // Mark as shut down to prevent new messages
        self.shutdown.store(true, Ordering::SeqCst);
        // Cancel the session's main loop
        self.cancel_token.cancel();
    }

    /// Returns true if the session has been shut down
    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

    // ---- Main Processing Loop ----

    /// Starts the session's main processing loop.
    /// This method processes requests from the ToLLM channel until shutdown.
    /// Should be spawned as a tokio task.
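    ///
    /// # Example
    ///
    /// A sketch of running the loop on a task; assumes the session is held
    /// in an `Arc` so it can be shared with the spawned task:
    ///
    /// ```ignore
    /// let session = std::sync::Arc::new(session);
    /// let runner = std::sync::Arc::clone(&session);
    /// tokio::spawn(async move { runner.start().await });
    /// ```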
    pub async fn start(&self) {
        tracing::info!(session_id = self.id(), "Session starting");

        loop {
            let mut rx_guard = self.to_llm_rx.lock().await;

            tokio::select! {
                _ = self.cancel_token.cancelled() => {
                    tracing::info!(session_id = self.id(), "Session cancelled");
                    break;
                }
                msg = rx_guard.recv() => {
                    match msg {
                        Some(request) => {
                            // Drop the lock before handling the request
                            drop(rx_guard);
                            self.handle_request(request).await;
                        }
                        None => {
                            // Channel closed
                            tracing::info!(session_id = self.id(), "Session channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::info!(session_id = self.id(), "Session stopped");
    }

    // ---- Request Helper Methods ----

    /// Returns the current timestamp in milliseconds.
    fn current_timestamp_millis() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as i64)
            .unwrap_or(0)
    }

    /// Prepares the request context by setting up the cancellation token and turn ID.
    /// Returns the request token and effective turn ID.
    async fn prepare_request(&self, request: &ToLLMPayload) -> (CancellationToken, TurnId) {
        let request_token = CancellationToken::new();
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = Some(request_token.clone());
        }

        let effective_turn_id = request
            .turn_id
            .clone()
            .unwrap_or_else(|| TurnId::new_user_turn(0));
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = Some(effective_turn_id.clone());
        }

        (request_token, effective_turn_id)
    }

    /// Builds the message options with tools for the LLM request.
    async fn build_message_options(&self) -> crate::client::models::MessageOptions {
        use crate::client::models::MessageOptions;

        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
        let tools = self.tool_definitions.read().await.clone();
        let tools_option = if tools.is_empty() { None } else { Some(tools) };

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature: self.config.temperature,
            tools: tools_option,
            ..Default::default()
        }
    }

    /// Clears the request cancellation token and turn ID after request completion.
    async fn cleanup_request(&self) {
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = None;
        }
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = None;
        }
    }

    /// Handles a single request from the ToLLM channel.
    async fn handle_request(&self, request: ToLLMPayload) {
        if self.config.streaming {
            self.handle_streaming_request(request).await;
        } else {
            self.handle_non_streaming_request(request).await;
        }
    }

    /// Handles a non-streaming request.
    async fn handle_non_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::{from_llm_message, to_llm_messages};
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use crate::client::models::Message as LLMMessage;

        // Prepare the request context
        let (_request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling request");

        // Build the conversation messages
        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        // Add the system prompt if set
        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        // Add the conversation history
        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        // Add the new message based on the request type
        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    // Add the user message to the conversation history
                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                // Store compact summaries for later compaction
                self.store_compact_summaries(&request.compact_summaries).await;

                // Add tool result messages using the LLM client's format
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    // Get the compact summary if available
                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    // Add the tool result to the conversation history
                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            content: tool_result.content.clone(),
                            is_error: tool_result.is_error,
                            compact_summary,
                        })],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        // Perform compaction if needed before the LLM call
        self.maybe_compact().await;

        // Build the message options with tools
        let options = self.build_message_options().await;

        // Call the LLM
        let result = self.client.send_message(&llm_messages, &options).await;

        match result {
            Ok(response) => {
                // Convert the response to our types
                let content_blocks = from_llm_message(&response);

                // Extract text for the text chunk response
                let text: String = content_blocks
                    .iter()
                    .filter_map(|block| {
                        if let ContentBlock::Text(t) = block {
                            Some(t.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("");

                // Send the text chunk response
                if !text.is_empty() {
                    let payload = FromLLMPayload {
                        session_id,
                        response_type: LLMResponseType::TextChunk,
                        text: text.clone(),
                        turn_id: request.turn_id.clone(),
                        ..Default::default()
                    };
                    let _ = self.from_llm.send(payload).await;
                }

                // Check for tool use
                for block in &content_blocks {
                    if let ContentBlock::ToolUse(tool_use) = block {
                        let payload = FromLLMPayload {
                            session_id,
                            response_type: LLMResponseType::ToolUse,
                            tool_use: Some(crate::controller::types::ToolUseInfo {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                input: serde_json::to_value(&tool_use.input).unwrap_or_default(),
                            }),
                            turn_id: request.turn_id.clone(),
                            ..Default::default()
                        };
                        let _ = self.from_llm.send(payload).await;
                    }
                }

                // Add the assistant message to the conversation history
                let now = Self::current_timestamp_millis();
                let asst_msg = Message::Assistant(AssistantMessage {
                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                    session_id: session_id.to_string(),
                    turn_id: effective_turn_id.clone(),
                    parent_id: String::new(),
                    created_at: now,
                    completed_at: Some(now),
                    model_id: self.config.model.clone(),
                    provider_id: String::new(),
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    finish_reason: None,
                    error: None,
                    content: content_blocks,
                });
                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);

                // Send the completion
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Complete,
                    is_complete: true,
                    turn_id: request.turn_id.clone(),
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;

                // Update the request count
                self.request_count.fetch_add(1, Ordering::SeqCst);

                tracing::debug!(session_id, "Request completed successfully");
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "LLM request failed");

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        // Clear the request cancellation token and turn ID when done
        self.cleanup_request().await;
    }

    /// Handles a streaming request.
    async fn handle_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::to_llm_messages;
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use futures::StreamExt;
        use crate::client::models::{
            ContentBlockType, Message as LLMMessage, StreamEvent,
        };

        // Prepare the request context
        let (request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling streaming request");

        // Build the conversation messages
        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        // Add the system prompt if set
        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        // Add the conversation history
        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        // Add the new message based on the request type
        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    // Add the user message to the conversation history
                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                // Store compact summaries for later compaction
                self.store_compact_summaries(&request.compact_summaries).await;

                // Log the conversation state before adding tool results (streaming path)
                {
                    let conv = self.conversation.read().await;
                    tracing::debug!(
                        session_id,
                        conversation_len = conv.len(),
                        tool_result_count = request.tool_results.len(),
                        "STREAMING ToolResult: conversation state before adding results"
                    );
                }
                // Add tool result messages using the LLM client's format
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    // Get the compact summary if available
                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    // Add the tool result to the conversation history
                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            content: tool_result.content.clone(),
                            is_error: tool_result.is_error,
                            compact_summary,
                        })],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        // Perform compaction if needed before the LLM call
        self.maybe_compact().await;

        // Build the message options with tools
        let options = self.build_message_options().await;

        // Call the streaming LLM API
        let stream_result = self
            .client
            .send_message_stream(&llm_messages, &options)
            .await;

        match stream_result {
            Ok(mut stream) => {
                // Track the current content block for tool use accumulation
                let mut current_tool_id: Option<String> = None;
                let mut current_tool_name: Option<String> = None;
                let mut tool_input_json = String::new();
                // Accumulate response text for the conversation history
                let mut response_text = String::new();
                // Accumulate completed tool uses for the conversation history
                let mut completed_tool_uses: Vec<crate::controller::types::ToolUseBlock> = Vec::new();

                // Process stream events
                loop {
                    tokio::select! {
                        _ = request_token.cancelled() => {
                            tracing::info!(session_id, "Streaming request cancelled");
                            break;
                        }
                        event = stream.next() => {
                            match event {
                                Some(Ok(stream_event)) => {
                                    match stream_event {
                                        StreamEvent::MessageStart { message_id, model } => {
                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::StreamStart,
                                                message_id,
                                                model,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStart { index: _, block_type } => {
                                            match block_type {
                                                ContentBlockType::Text => {
                                                    // Text block starting; nothing special to do
                                                }
                                                ContentBlockType::ToolUse { id, name } => {
                                                    // Track the tool use for later batch execution.
                                                    // Don't emit events here; wait until execution begins.
                                                    current_tool_id = Some(id);
                                                    current_tool_name = Some(name);
                                                    tool_input_json.clear();
                                                }
                                            }
                                        }
                                        StreamEvent::TextDelta { index, text } => {
                                            // Accumulate for the conversation history
                                            response_text.push_str(&text);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::TextChunk,
                                                text,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::InputJsonDelta { index, json } => {
                                            // Accumulate JSON for the tool input
                                            tool_input_json.push_str(&json);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::ToolInputDelta,
                                                text: json,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStop { index: _ } => {
                                            // If we were accumulating a tool use, save it for later.
                                            // No event is emitted here (none was emitted at
                                            // ContentBlockStart either); the accumulated tool uses
                                            // are dispatched as a batch once the message completes.
                                            if let (Some(id), Some(name)) =
                                                (current_tool_id.take(), current_tool_name.take())
                                            {
                                                let input: serde_json::Value =
                                                    serde_json::from_str(&tool_input_json)
                                                        .unwrap_or(serde_json::Value::Object(
                                                            serde_json::Map::new(),
                                                        ));

                                                // Save the tool use for conversation history and batch execution
                                                tracing::debug!(
                                                    session_id,
                                                    tool_id = %id,
                                                    tool_name = %name,
                                                    "Saving tool use to completed_tool_uses"
                                                );
                                                completed_tool_uses.push(crate::controller::types::ToolUseBlock {
                                                    id: id.clone(),
                                                    name: name.clone(),
                                                    input: input
                                                        .as_object()
                                                        .map(|obj| {
                                                            obj.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect()
                                                        })
                                                        .unwrap_or_default(),
                                                });

                                                tool_input_json.clear();
                                            }
                                        }
                                        StreamEvent::MessageDelta { stop_reason, usage } => {
                                            if let Some(usage) = usage {
                                                tracing::info!(
                                                    session_id,
                                                    input_tokens = usage.input_tokens,
                                                    output_tokens = usage.output_tokens,
                                                    "API token usage for this turn"
                                                );
                                                self.current_input_tokens
                                                    .store(usage.input_tokens as i64, Ordering::SeqCst);
                                                self.current_output_tokens
                                                    .store(usage.output_tokens as i64, Ordering::SeqCst);

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::TokenUpdate,
                                                    input_tokens: usage.input_tokens as i64,
                                                    output_tokens: usage.output_tokens as i64,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            if stop_reason.is_some() {
                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::Complete,
                                                    is_complete: true,
                                                    stop_reason,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }
                                        }
1269                                        StreamEvent::MessageStop => {
                                            // Add the assistant message to the conversation history.
                                            // Both the text and any tool uses must be saved.
                                            tracing::debug!(
                                                session_id,
                                                text_len = response_text.len(),
                                                tool_use_count = completed_tool_uses.len(),
                                                "MessageStop: saving assistant message to history"
                                            );
                                            if !response_text.is_empty() || !completed_tool_uses.is_empty() {
                                                let now = Self::current_timestamp_millis();

                                                // Build content blocks: text first, then tool uses
                                                let mut content_blocks = Vec::new();
                                                if !response_text.is_empty() {
                                                    content_blocks.push(ContentBlock::text(&response_text));
                                                }
                                                for tool_use in &completed_tool_uses {
                                                    content_blocks.push(ContentBlock::ToolUse(tool_use.clone()));
                                                }

                                                let content_block_count = content_blocks.len();
                                                let asst_msg = Message::Assistant(AssistantMessage {
                                                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                                                    session_id: session_id.to_string(),
                                                    turn_id: effective_turn_id.clone(),
                                                    parent_id: String::new(),
                                                    created_at: now,
                                                    completed_at: Some(now),
                                                    model_id: self.config.model.clone(),
                                                    provider_id: String::new(),
                                                    input_tokens: self.current_input_tokens.load(Ordering::SeqCst),
                                                    output_tokens: self.current_output_tokens.load(Ordering::SeqCst),
                                                    cache_read_tokens: 0,
                                                    cache_write_tokens: 0,
                                                    finish_reason: None,
                                                    error: None,
                                                    content: content_blocks,
                                                });
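                                                // Copy-on-write: Arc::make_mut clones the conversation
                                                // only if another reader still holds a reference to it.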
                                                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);
                                                tracing::debug!(
                                                    session_id,
                                                    content_block_count,
                                                    "MessageStop: saved assistant message with content blocks"
                                                );
                                            }

                                            // If there are tool uses, emit them as a batch for execution.
                                            // This ensures all tools are executed together and their
                                            // results are sent back in a single message.
                                            if !completed_tool_uses.is_empty() {
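                                                // Rebuild each tool's input as a JSON object so the
                                                // executor receives it in its original JSON shape.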
                                                let tool_uses: Vec<crate::controller::types::ToolUseInfo> = completed_tool_uses
                                                    .iter()
                                                    .map(|tu| crate::controller::types::ToolUseInfo {
                                                        id: tu.id.clone(),
                                                        name: tu.name.clone(),
                                                        input: serde_json::Value::Object(
                                                            tu.input.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect()
                                                        ),
                                                    })
                                                    .collect();

                                                tracing::debug!(
                                                    session_id,
                                                    tool_count = tool_uses.len(),
                                                    "MessageStop: emitting ToolBatch for execution"
                                                );

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::ToolBatch,
                                                    tool_uses,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            // Stream complete: count the finished request and exit the read loop
                                            self.request_count.fetch_add(1, Ordering::SeqCst);
                                            tracing::debug!(session_id, "Streaming request completed");
                                            break;
                                        }
                                        StreamEvent::Ping => {
                                            // Keep-alive, ignore
                                        }
                                    }
                                }
                                Some(Err(err)) => {
                                    tracing::error!(session_id, error = %err, "Stream error");
                                    let payload = FromLLMPayload {
                                        session_id,
                                        response_type: LLMResponseType::Error,
                                        error: Some(err.to_string()),
                                        turn_id: request.turn_id.clone(),
                                        ..Default::default()
                                    };
                                    let _ = self.from_llm.send(payload).await;
                                    break;
                                }
                                None => {
                                    // Stream ended
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            Err(err) => {
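                // The request failed before a stream was established; report the
                // error to the controller under the turn that initiated it.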
                tracing::error!(session_id, error = %err, "Failed to start streaming");
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        // Clear the request cancellation token and turn ID when done
        self.cleanup_request().await;
    }
}