//! agent_air_runtime/controller/session/llm_session.rs
1// This implements a single session with an LLM
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, Ordering};
6use std::time::Instant;
7
8use tokio::sync::{Mutex, RwLock, mpsc};
9use tokio_util::sync::CancellationToken;
10
11use crate::client::LLMClient;
12use crate::client::error::LlmError;
13use crate::client::models::Tool as LLMTool;
14use crate::client::providers::anthropic::AnthropicProvider;
15use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
16use crate::client::providers::cohere::CohereProvider;
17use crate::client::providers::gemini::GeminiProvider;
18use crate::client::providers::openai::OpenAIProvider;
19
20use super::compactor::{AsyncCompactor, Compactor, LLMCompactor, ThresholdCompactor};
21use super::config::{CompactorType, LLMProvider, LLMSessionConfig};
22
23/// Creates an LLMClient from the session configuration.
24fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
25    match config.provider {
26        LLMProvider::Anthropic => {
27            let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
28            LLMClient::new(Box::new(provider))
29        }
30        LLMProvider::OpenAI => {
31            // Check for Azure configuration first
32            let provider = if let (Some(resource), Some(deployment)) =
33                (&config.azure_resource, &config.azure_deployment)
34            {
35                let api_version = config
36                    .azure_api_version
37                    .clone()
38                    .unwrap_or_else(|| "2024-10-21".to_string());
39                OpenAIProvider::azure(
40                    config.api_key.clone(),
41                    resource.clone(),
42                    deployment.clone(),
43                    api_version,
44                )
45            } else if let Some(base_url) = &config.base_url {
46                OpenAIProvider::with_base_url(
47                    config.api_key.clone(),
48                    config.model.clone(),
49                    base_url.clone(),
50                )
51            } else {
52                OpenAIProvider::new(config.api_key.clone(), config.model.clone())
53            };
54            LLMClient::new(Box::new(provider))
55        }
56        LLMProvider::Google => {
57            let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
58            LLMClient::new(Box::new(provider))
59        }
60        LLMProvider::Cohere => {
61            let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
62            LLMClient::new(Box::new(provider))
63        }
64        LLMProvider::Bedrock => {
65            // Bedrock requires all four credential/region fields
66            let region = config.bedrock_region.clone().ok_or_else(|| {
67                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_region")
68            })?;
69            let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
70                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_access_key_id")
71            })?;
72            let secret_access_key = config.bedrock_secret_access_key.clone().ok_or_else(|| {
73                LlmError::new(
74                    "MISSING_CONFIG",
75                    "Bedrock requires bedrock_secret_access_key",
76                )
77            })?;
78
79            let credentials = match &config.bedrock_session_token {
80                Some(token) => BedrockCredentials::with_session_token(
81                    access_key_id,
82                    secret_access_key,
83                    token.clone(),
84                ),
85                None => BedrockCredentials::new(access_key_id, secret_access_key),
86            };
87
88            let provider = BedrockProvider::new(credentials, region, config.model.clone());
89            LLMClient::new(Box::new(provider))
90        }
91    }
92}
93use crate::controller::types::{
94    AssistantMessage, ContentBlock, FromLLMPayload, Message, ToLLMPayload, TurnId, UserMessage,
95};
96
/// Cumulative and most-recent token accounting for a session.
///
/// All counters start at zero (via `Default`) and grow as requests complete.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Sum of input (prompt) tokens over every request so far.
    pub total_input_tokens: i64,
    /// Sum of output (completion) tokens over every request so far.
    pub total_output_tokens: i64,
    /// How many LLM requests have completed.
    pub request_count: i64,
    /// Input tokens consumed by the most recent request.
    pub last_input_tokens: i64,
    /// Output tokens produced by the most recent request.
    pub last_output_tokens: i64,
}
111
/// Point-in-time snapshot of an LLM session's state.
///
/// All fields are plain copies, so a snapshot does not stay in sync with the
/// live session after it is taken.
#[derive(Debug, Clone)]
pub struct SessionStatus {
    /// Unique identifier of the session.
    pub session_id: i64,
    /// Name of the model the session talks to.
    pub model: String,
    /// Instant at which the session was created.
    pub created_at: Instant,
    /// Number of messages currently in the conversation history.
    pub conversation_len: usize,
    /// Input tokens currently occupying the context window.
    pub context_used: i64,
    /// Maximum context window size for the model.
    pub context_limit: i32,
    /// Percentage (0-100) of the context window in use.
    pub utilization: f64,
    /// Input tokens accumulated across all requests.
    pub total_input: i64,
    /// Output tokens accumulated across all requests.
    pub total_output: i64,
    /// Total number of API calls made.
    pub request_count: i64,
}
136
/// Result of a manual compaction operation.
/// Used by `force_compact()` to report what happened during compaction.
#[derive(Debug, Clone, Default)]
pub struct CompactResult {
    /// True when compaction actually ran and changed the conversation;
    /// false when there was nothing to compact.
    pub compacted: bool,
    /// Conversation message count before compaction.
    pub messages_before: usize,
    /// Conversation message count after compaction.
    pub messages_after: usize,
    /// How many turns were folded into the summary.
    pub turns_compacted: usize,
    /// How many recent turns survived untouched.
    pub turns_kept: usize,
    /// Length in bytes of the generated summary text (LLM compaction only).
    pub summary_length: usize,
    /// Human-readable reason when compaction failed or was unavailable.
    pub error: Option<String>,
}
157
/// Global counter for generating unique session IDs.
/// Incremented once per `LLMSession::new`; IDs start at 1 and are unique for
/// the lifetime of the process (not stable across restarts).
static SESSION_COUNTER: AtomicI64 = AtomicI64::new(0);
160
/// A session that manages communication with an LLM
///
/// Owns the conversation history, the provider client, and the channels used
/// to exchange payloads with the rest of the runtime. Shared state is behind
/// atomics or async locks so every method can be driven through `&self`.
pub struct LLMSession {
    // Session identification
    // NOTE(review): `id` never changes after construction; it is atomic only
    // so it can be read through `&self` — a plain i64 would likely do. Confirm.
    id: AtomicI64,

    // LLM client used for all requests in this session
    client: LLMClient,

    // Channels for communication
    // `to_llm_tx` is exercised via `send()`; the receiver is drained by `start()`.
    to_llm_tx: mpsc::Sender<ToLLMPayload>,
    to_llm_rx: Mutex<mpsc::Receiver<ToLLMPayload>>,
    from_llm: mpsc::Sender<FromLLMPayload>,

    // Session configuration (no method mutates it; all methods take `&self`)
    config: LLMSessionConfig,

    // Runtime overrides for LLM options
    system_prompt: RwLock<Option<String>>,
    max_tokens: AtomicI64,
    created_at: Instant,

    // Conversation state; the Arc enables O(1) snapshots for async compaction
    conversation: RwLock<Arc<Vec<Message>>>,

    // Shutdown management
    shutdown: AtomicBool,
    cancel_token: CancellationToken,

    // Per-request cancellation (Some only while a request is in flight)
    current_cancel: Mutex<Option<CancellationToken>>,

    // Current turn ID for the active request (used for filtering on interrupt)
    current_turn_id: RwLock<Option<TurnId>>,

    // Token tracking for current request
    current_input_tokens: AtomicI64,
    current_output_tokens: AtomicI64,

    // Cumulative token tracking
    request_count: AtomicI64,

    // Tool definitions for LLM API calls
    tool_definitions: RwLock<Vec<LLMTool>>,

    // Compaction support; at most one of `compactor`/`llm_compactor` is Some
    compactor: Option<Box<dyn Compactor>>,
    llm_compactor: Option<LLMCompactor>,
    context_limit: AtomicI32,
    // Tool-result summaries keyed by tool_use_id, consumed during compaction
    compact_summaries: RwLock<HashMap<String, String>>,
}
211
212impl LLMSession {
213    /// Creates a new LLM session
214    ///
215    /// # Arguments
216    /// * `config` - Session configuration
217    /// * `from_llm` - Sender for outgoing responses
218    /// * `cancel_token` - Token for session cancellation
219    /// * `channel_size` - Buffer size for the session's input channel
220    ///
221    /// # Errors
222    /// Returns an error if the LLM client fails to initialize (e.g., TLS setup failure)
223    pub fn new(
224        config: LLMSessionConfig,
225        from_llm: mpsc::Sender<FromLLMPayload>,
226        cancel_token: CancellationToken,
227        channel_size: usize,
228    ) -> Result<Self, LlmError> {
229        let session_id = SESSION_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
230        let (to_llm_tx, to_llm_rx) = mpsc::channel(channel_size);
231        let max_tokens = config.max_tokens.unwrap_or(4096) as i64;
232        let system_prompt = config.system_prompt.clone();
233
234        // Create the LLMClient client based on the provider
235        let client = create_llm_client(&config)?;
236
237        // Create compactor if configured
238        let mut compactor: Option<Box<dyn Compactor>> = None;
239        let mut llm_compactor: Option<LLMCompactor> = None;
240
241        if let Some(ref compactor_type) = config.compaction {
242            match compactor_type {
243                CompactorType::Threshold(c) => {
244                    match ThresholdCompactor::new(
245                        c.threshold,
246                        c.keep_recent_turns,
247                        c.tool_compaction,
248                    ) {
249                        Ok(tc) => {
250                            tracing::info!(
251                                threshold = c.threshold,
252                                keep_recent_turns = c.keep_recent_turns,
253                                tool_compaction = %c.tool_compaction,
254                                "Threshold compaction enabled for session"
255                            );
256                            compactor = Some(Box::new(tc) as Box<dyn Compactor>);
257                        }
258                        Err(e) => {
259                            tracing::warn!(error = %e, "Failed to create threshold compactor, compaction disabled");
260                        }
261                    }
262                }
263                CompactorType::LLM(c) => {
264                    // Create a separate LLMClient client for LLM compaction
265                    let llm_client = create_llm_client(&config)?;
266
267                    match LLMCompactor::new(llm_client, c.clone()) {
268                        Ok(lc) => {
269                            tracing::info!(
270                                threshold = c.threshold,
271                                keep_recent_turns = c.keep_recent_turns,
272                                "LLM compaction enabled for session"
273                            );
274                            llm_compactor = Some(lc);
275                        }
276                        Err(e) => {
277                            tracing::warn!(error = %e, "Failed to create LLM compactor, compaction disabled");
278                        }
279                    }
280                }
281            }
282        }
283
284        let context_limit = config.context_limit;
285
286        Ok(Self {
287            id: AtomicI64::new(session_id),
288            client,
289            to_llm_tx,
290            to_llm_rx: Mutex::new(to_llm_rx),
291            from_llm,
292            config,
293            system_prompt: RwLock::new(system_prompt),
294            max_tokens: AtomicI64::new(max_tokens),
295            created_at: Instant::now(),
296            conversation: RwLock::new(Arc::new(Vec::new())),
297            shutdown: AtomicBool::new(false),
298            cancel_token,
299            current_cancel: Mutex::new(None),
300            current_turn_id: RwLock::new(None),
301            current_input_tokens: AtomicI64::new(0),
302            current_output_tokens: AtomicI64::new(0),
303            request_count: AtomicI64::new(0),
304            tool_definitions: RwLock::new(Vec::new()),
305            compactor,
306            llm_compactor,
307            context_limit: AtomicI32::new(context_limit),
308            compact_summaries: RwLock::new(HashMap::new()),
309        })
310    }
311
312    /// Returns the session ID
313    pub fn id(&self) -> i64 {
314        self.id.load(Ordering::SeqCst)
315    }
316
317    /// Returns when the session was created
318    pub fn created_at(&self) -> Instant {
319        self.created_at
320    }
321
322    /// Returns the model for this session
323    pub fn model(&self) -> &str {
324        &self.config.model
325    }
326
327    // ---- Max Tokens ----
328
329    /// Sets the default maximum tokens for the session
330    pub fn set_max_tokens(&self, max_tokens: i64) {
331        self.max_tokens.store(max_tokens, Ordering::SeqCst);
332    }
333
334    /// Returns the current max tokens setting
335    pub fn max_tokens(&self) -> i64 {
336        self.max_tokens.load(Ordering::SeqCst)
337    }
338
339    /// Returns the context limit for this session's model
340    pub fn context_limit(&self) -> i32 {
341        self.context_limit.load(Ordering::SeqCst)
342    }
343
344    // ---- System Prompt ----
345
346    /// Sets the default system prompt for the session
347    pub async fn set_system_prompt(&self, prompt: String) {
348        let mut guard = self.system_prompt.write().await;
349        *guard = Some(prompt);
350    }
351
352    /// Clears the default system prompt
353    pub async fn clear_system_prompt(&self) {
354        let mut guard = self.system_prompt.write().await;
355        *guard = None;
356    }
357
358    /// Returns the current system prompt
359    pub async fn system_prompt(&self) -> Option<String> {
360        self.system_prompt.read().await.clone()
361    }
362
363    // ---- Tools ----
364
365    /// Sets the tool definitions for this session.
366    /// Tools will be included in all subsequent LLM API calls.
367    pub async fn set_tools(&self, tools: Vec<LLMTool>) {
368        let mut guard = self.tool_definitions.write().await;
369        *guard = tools;
370    }
371
372    /// Clears all tool definitions for this session.
373    pub async fn clear_tools(&self) {
374        let mut guard = self.tool_definitions.write().await;
375        guard.clear();
376    }
377
378    /// Returns a copy of the current tool definitions.
379    pub async fn tools(&self) -> Vec<LLMTool> {
380        self.tool_definitions.read().await.clone()
381    }
382
383    // ---- Compaction ----
384
385    /// Stores compact summaries for later use during compaction.
386    /// Called when tool results are received.
387    async fn store_compact_summaries(&self, summaries: &HashMap<String, String>) {
388        if summaries.is_empty() {
389            tracing::warn!(
390                session_id = self.id(),
391                "No compact summaries provided with tool results"
392            );
393            return;
394        }
395        let mut guard = self.compact_summaries.write().await;
396        for (tool_use_id, summary) in summaries {
397            tracing::info!(
398                session_id = self.id(),
399                tool_use_id = %tool_use_id,
400                summary_len = summary.len(),
401                summary_preview = %summary.chars().take(50).collect::<String>(),
402                "Storing compact summary"
403            );
404            guard.insert(tool_use_id.clone(), summary.clone());
405        }
406        tracing::info!(
407            session_id = self.id(),
408            new_summaries = summaries.len(),
409            total_stored = guard.len(),
410            "Stored compact summaries for tool results"
411        );
412    }
413
    /// Performs compaction if needed based on context usage.
    /// Should be called before each LLM request.
    ///
    /// Prefers the LLM (summarizing) compactor when configured; otherwise
    /// falls back to the synchronous threshold compactor. No-op when neither
    /// is configured or the trigger condition is not met.
    async fn maybe_compact(&self) {
        // `current_input_tokens` holds the last request's prompt size and is
        // used here as a proxy for current context occupancy.
        let context_used = self.current_input_tokens.load(Ordering::SeqCst);
        let context_limit = self.context_limit.load(Ordering::SeqCst);
        let conversation_len = self.conversation.read().await.len();
        let summaries_count = self.compact_summaries.read().await.len();

        // Guard against a zero/unset limit to avoid division by zero.
        let utilization = if context_limit > 0 {
            context_used as f64 / context_limit as f64
        } else {
            0.0
        };

        tracing::debug!(
            session_id = self.id(),
            context_used,
            context_limit,
            utilization = format!("{:.2}%", utilization * 100.0),
            conversation_len,
            summaries_available = summaries_count,
            "Checking if compaction needed"
        );

        // Check for LLM compactor first (async compaction)
        if let Some(ref llm_compactor) = self.llm_compactor {
            if !llm_compactor.should_compact(context_used, context_limit) {
                tracing::debug!(session_id = self.id(), "LLM compaction not triggered");
                return;
            }

            // Get conversation and summaries for async compaction.
            // The Arc snapshot releases the read lock before the await below.
            let summaries = self.compact_summaries.read().await.clone();
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard) // O(1)
            };
            // Take ownership when this is the only reference; clone otherwise.
            let conversation =
                Arc::try_unwrap(conversation_arc).unwrap_or_else(|arc| (*arc).clone());

            tracing::info!(
                session_id = self.id(),
                conversation_len = conversation.len(),
                summaries_count = summaries.len(),
                "Starting LLM compaction"
            );

            // Perform async LLM compaction.
            // NOTE(review): the conversation lock is not held across this
            // await; messages appended concurrently would be lost when the
            // compacted snapshot is written back — confirm callers serialize
            // compaction with request handling.
            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    // Replace conversation with compacted version
                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            "LLM compaction completed"
                        );
                    }
                }
                Err(e) => {
                    // On failure the original conversation is left unchanged.
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "LLM compaction failed"
                    );
                }
            }
            return;
        }

        // Fall back to sync compactor (ThresholdCompactor)
        let compactor = match &self.compactor {
            Some(c) => c,
            None => {
                tracing::debug!(session_id = self.id(), "No compactor configured");
                return;
            }
        };

        if !compactor.should_compact(context_used, context_limit) {
            tracing::debug!(session_id = self.id(), "Threshold compaction not triggered");
            return;
        }

        // Perform sync compaction; the write lock is held for the whole
        // in-place mutation, so no concurrent update can interleave.
        let summaries = self.compact_summaries.read().await.clone();
        let mut guard = self.conversation.write().await;

        tracing::info!(
            session_id = self.id(),
            conversation_len = guard.len(),
            summaries_count = summaries.len(),
            "Starting threshold compaction"
        );

        let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

        tracing::info!(
            session_id = self.id(),
            tool_results_summarized = result.tool_results_summarized,
            tool_results_redacted = result.tool_results_redacted,
            turns_compacted = result.turns_compacted,
            conversation_len_after = guard.len(),
            "Threshold compaction completed"
        );
    }
522
523    // ---- Session Control ----
524
525    /// Clears the conversation history and compact summaries.
526    pub async fn clear_conversation(&self) {
527        let mut guard = self.conversation.write().await;
528        Arc::make_mut(&mut *guard).clear();
529
530        let mut summaries = self.compact_summaries.write().await;
531        summaries.clear();
532
533        // Reset token counters
534        self.current_input_tokens.store(0, Ordering::SeqCst);
535        self.current_output_tokens.store(0, Ordering::SeqCst);
536
537        tracing::info!(session_id = self.id(), "Conversation cleared");
538    }
539
    /// Forces compaction to run immediately, regardless of threshold.
    /// Returns a `CompactResult` with details about what happened.
    ///
    /// Prefers the LLM (summarizing) compactor when configured, otherwise the
    /// synchronous threshold compactor. With no compactor configured, the
    /// returned result carries an explanatory `error`.
    pub async fn force_compact(&self) -> CompactResult {
        // Check for LLM compactor first (async compaction)
        if let Some(ref llm_compactor) = self.llm_compactor {
            let summaries = self.compact_summaries.read().await.clone();
            // Snapshot the Arc so the read lock is released before the
            // (potentially slow) LLM call below.
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard) // O(1)
            };
            // Take ownership when this is the only reference; clone otherwise.
            let conversation =
                Arc::try_unwrap(conversation_arc).unwrap_or_else(|arc| (*arc).clone());
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            // NOTE(review): the conversation lock is not held across this
            // await; messages appended concurrently would be overwritten by
            // the write-back below — confirm callers serialize with requests.
            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    let messages_after = new_conversation.len();
                    let turns_after = self.count_unique_turns(&new_conversation);
                    let compacted = messages_after < messages_before;

                    // Extract summary length if compaction happened
                    let summary_length = if compacted && !new_conversation.is_empty() {
                        self.extract_summary_length(&new_conversation[0])
                    } else {
                        0
                    };

                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            messages_before,
                            messages_after,
                            "Forced LLM compaction completed"
                        );
                    }

                    return CompactResult {
                        compacted,
                        messages_before,
                        messages_after,
                        turns_compacted: turns_before.saturating_sub(turns_after),
                        turns_kept: turns_after,
                        summary_length,
                        error: None,
                    };
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "Forced LLM compaction failed"
                    );
                    // Failure leaves the stored conversation untouched.
                    return CompactResult {
                        compacted: false,
                        messages_before,
                        messages_after: messages_before,
                        turns_compacted: 0,
                        turns_kept: turns_before,
                        summary_length: 0,
                        error: Some(format!("Compaction failed: {}", e)),
                    };
                }
            }
        }

        // Fall back to sync compactor (ThresholdCompactor)
        if let Some(ref compactor) = self.compactor {
            let summaries = self.compact_summaries.read().await.clone();
            // Write lock held for the whole in-place compaction.
            let mut guard = self.conversation.write().await;
            let messages_before = guard.len();
            let turns_before = self.count_unique_turns(&guard);

            let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

            let messages_after = guard.len();
            let turns_after = self.count_unique_turns(&guard);
            let compacted = result.turns_compacted > 0 || result.total_compacted() > 0;

            if result.total_compacted() > 0 {
                tracing::info!(
                    session_id = self.id(),
                    tool_results_summarized = result.tool_results_summarized,
                    tool_results_redacted = result.tool_results_redacted,
                    turns_compacted = result.turns_compacted,
                    "Forced threshold compaction completed"
                );
            }

            return CompactResult {
                compacted,
                messages_before,
                messages_after,
                turns_compacted: turns_before.saturating_sub(turns_after),
                turns_kept: turns_after,
                summary_length: 0,
                error: None,
            };
        }

        // No compactor configured
        CompactResult {
            compacted: false,
            error: Some("No compactor configured".to_string()),
            ..Default::default()
        }
    }
650
651    /// Count unique turn IDs in a conversation.
652    fn count_unique_turns(&self, conversation: &[Message]) -> usize {
653        use std::collections::HashSet;
654        let mut turn_ids = HashSet::new();
655        for msg in conversation {
656            turn_ids.insert(msg.turn_id().clone());
657        }
658        turn_ids.len()
659    }
660
661    /// Extract the summary length from a summary message.
662    fn extract_summary_length(&self, message: &Message) -> usize {
663        if let Message::User(user_msg) = message {
664            for block in &user_msg.content {
665                if let ContentBlock::Text(text_block) = block
666                    && text_block
667                        .text
668                        .starts_with("[Previous conversation summary]")
669                {
670                    return text_block.text.len();
671                }
672            }
673        }
674        0
675    }
676
677    /// Sends a message to the LLM session for processing.
678    /// Returns false if the session is shutdown or the channel is closed.
679    pub async fn send(&self, msg: ToLLMPayload) -> bool {
680        if self.shutdown.load(Ordering::SeqCst) {
681            return false;
682        }
683        self.to_llm_tx.send(msg).await.is_ok()
684    }
685
686    /// Interrupts the currently executing LLM request.
687    /// This cancels any in-flight request and removes all messages from the
688    /// current turn from conversation history. Does not shutdown the session.
689    pub async fn interrupt(&self) {
690        let guard = self.current_cancel.lock().await;
691        if let Some(token) = guard.as_ref() {
692            token.cancel();
693
694            // Remove all messages from the current turn from conversation history.
695            // This prevents any messages from the cancelled turn (user message,
696            // assistant responses, etc.) from being included in subsequent API calls.
697            let turn_id = self.current_turn_id.read().await.clone();
698            if let Some(turn_id) = turn_id {
699                let mut guard = self.conversation.write().await;
700                let original_len = guard.len();
701                Arc::make_mut(&mut *guard).retain(|msg| msg.turn_id() != &turn_id);
702                let removed = original_len - guard.len();
703                tracing::debug!(
704                    session_id = self.id(),
705                    turn_id = %turn_id,
706                    messages_removed = removed,
707                    conversation_length = guard.len(),
708                    "Removed messages from cancelled turn"
709                );
710            }
711        }
712    }
713
714    /// Gracefully shuts down the session.
715    /// After calling this, the session will not accept new messages.
716    pub fn shutdown(&self) {
717        // Mark as shutdown to prevent new messages
718        self.shutdown.store(true, Ordering::SeqCst);
719        // Cancel the session's main loop
720        self.cancel_token.cancel();
721    }
722
723    /// Returns true if the session has been shutdown
724    pub fn is_shutdown(&self) -> bool {
725        self.shutdown.load(Ordering::SeqCst)
726    }
727
728    // ---- Main Processing Loop ----
729
    /// Starts the session's main processing loop.
    /// This method processes requests from the ToLLM channel until shutdown.
    /// Should be spawned as a tokio task.
    ///
    /// The loop exits when either the cancellation token fires (see
    /// `shutdown()`) or every sender for the input channel is dropped.
    pub async fn start(&self) {
        tracing::info!(session_id = self.id(), "Session starting");

        loop {
            // The receiver lock is re-acquired on every iteration and
            // explicitly dropped before handling a request, so the channel is
            // never held locked across (slow) request processing.
            let mut rx_guard = self.to_llm_rx.lock().await;

            tokio::select! {
                _ = self.cancel_token.cancelled() => {
                    tracing::info!(session_id = self.id(), "Session cancelled");
                    break;
                }
                msg = rx_guard.recv() => {
                    match msg {
                        Some(request) => {
                            // Drop the lock before handling the request
                            drop(rx_guard);
                            self.handle_request(request).await;
                        }
                        None => {
                            // Channel closed: all senders dropped — exit loop.
                            tracing::info!(session_id = self.id(), "Session channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::info!(session_id = self.id(), "Session stopped");
    }
763
764    // ---- Request Helper Methods ----
765
766    /// Returns the current timestamp in milliseconds.
767    fn current_timestamp_millis() -> i64 {
768        std::time::SystemTime::now()
769            .duration_since(std::time::UNIX_EPOCH)
770            .map(|d| d.as_millis() as i64)
771            .unwrap_or(0)
772    }
773
774    /// Prepares the request context by setting up cancellation token and turn ID.
775    /// Returns the request token and effective turn ID.
776    async fn prepare_request(&self, request: &ToLLMPayload) -> (CancellationToken, TurnId) {
777        let request_token = CancellationToken::new();
778        {
779            let mut guard = self.current_cancel.lock().await;
780            *guard = Some(request_token.clone());
781        }
782
783        let effective_turn_id = request
784            .turn_id
785            .clone()
786            .unwrap_or_else(|| TurnId::new_user_turn(0));
787        {
788            let mut guard = self.current_turn_id.write().await;
789            *guard = Some(effective_turn_id.clone());
790        }
791
792        (request_token, effective_turn_id)
793    }
794
795    /// Builds the message options with tools for the LLM request.
796    async fn build_message_options(&self) -> crate::client::models::MessageOptions {
797        use crate::client::models::MessageOptions;
798
799        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
800        let tools = self.tool_definitions.read().await.clone();
801        let tools_option = if tools.is_empty() { None } else { Some(tools) };
802
803        MessageOptions {
804            max_tokens: Some(max_tokens),
805            temperature: self.config.temperature,
806            tools: tools_option,
807            ..Default::default()
808        }
809    }
810
811    /// Clears the request cancellation token and turn ID after request completion.
812    async fn cleanup_request(&self) {
813        {
814            let mut guard = self.current_cancel.lock().await;
815            *guard = None;
816        }
817        {
818            let mut guard = self.current_turn_id.write().await;
819            *guard = None;
820        }
821    }
822
823    /// Handles a single request from the ToLLM channel.
824    async fn handle_request(&self, request: ToLLMPayload) {
825        if self.config.streaming {
826            self.handle_streaming_request(request).await;
827        } else {
828            self.handle_non_streaming_request(request).await;
829        }
830    }
831
    /// Handles a non-streaming request end to end.
    ///
    /// Flow: (1) install per-request state (cancel token, turn ID);
    /// (2) build the outbound message list from system prompt + stored
    /// conversation + the new user message or tool results; (3) record the
    /// new input in conversation history; (4) run compaction; (5) make a
    /// single blocking `send_message` call; (6) fan the response out on the
    /// `from_llm` channel as TextChunk / ToolUse / Complete (or Error) and
    /// append the assistant message to history; (7) clear per-request state.
    ///
    /// Channel send results are deliberately ignored (`let _ =`): a closed
    /// receiver must not abort request processing.
    async fn handle_non_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::{from_llm_message, to_llm_messages};
        use crate::client::models::Message as LLMMessage;
        use crate::controller::types::{LLMRequestType, LLMResponseType};

        // Prepare request context.
        // NOTE(review): the token is bound as `_request_token` and never
        // polled below, so cancellation appears to have no effect on the
        // non-streaming path (unlike the streaming path) — confirm intended.
        let (_request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling request");

        // Build the conversation messages sent to the provider.
        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        // Add system prompt if set.
        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        // Add conversation history. The read guard is dropped explicitly so
        // the write lock taken further down cannot deadlock against it.
        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        // Add the new message based on request type. In both arms the new
        // input is pushed to `llm_messages` (for this call) AND appended to
        // the stored conversation (for future calls). Message IDs reuse the
        // current `request_count`, which is only incremented on success below.
        match request.request_type {
            LLMRequestType::UserMessage => {
                // Empty content is silently skipped — nothing new is sent or stored.
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    // Add user message to conversation history.
                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    // Copy-on-write append: the conversation is Arc-shared, so
                    // make_mut clones it only if another reader holds a handle.
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                // Store compact summaries for later compaction.
                self.store_compact_summaries(&request.compact_summaries)
                    .await;

                // Add tool result messages using LLM client's proper format.
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    // Get compact summary if available (keyed by tool_use_id).
                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    // Add tool result to conversation history. Each result is
                    // stored as its own user-role message; all results in this
                    // request share the same `request_count`-derived ID.
                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(
                            crate::controller::types::ToolResultBlock {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                content: tool_result.content.clone(),
                                is_error: tool_result.is_error,
                                compact_summary,
                            },
                        )],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        // Perform compaction if needed before LLM call.
        // NOTE(review): `llm_messages` was snapshotted from the conversation
        // ABOVE, so compaction here only shrinks what future requests send —
        // this request still goes out uncompacted. Confirm that is intended.
        self.maybe_compact().await;

        // Build message options with tools.
        let options = self.build_message_options().await;

        // Call the LLM (single non-streaming round trip).
        let result = self.client.send_message(&llm_messages, &options).await;

        match result {
            Ok(response) => {
                // Convert response to our types.
                let content_blocks = from_llm_message(&response);

                // Extract text for the text chunk response: concatenate all
                // Text blocks in order; non-text blocks are ignored here.
                let text: String = content_blocks
                    .iter()
                    .filter_map(|block| {
                        if let ContentBlock::Text(t) = block {
                            Some(t.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("");

                // Send text chunk response (the full text in one chunk,
                // since this is the non-streaming path).
                if !text.is_empty() {
                    let payload = FromLLMPayload {
                        session_id,
                        response_type: LLMResponseType::TextChunk,
                        text: text.clone(),
                        turn_id: request.turn_id.clone(),
                        ..Default::default()
                    };
                    let _ = self.from_llm.send(payload).await;
                }

                // Check for tool use: emit one ToolUse payload per block.
                // NOTE(review): unlike the streaming path (which batches tool
                // uses into a single ToolBatch), these are emitted individually.
                for block in &content_blocks {
                    if let ContentBlock::ToolUse(tool_use) = block {
                        let payload = FromLLMPayload {
                            session_id,
                            response_type: LLMResponseType::ToolUse,
                            tool_use: Some(crate::controller::types::ToolUseInfo {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                // Serialization failure degrades to a default
                                // (null) value rather than dropping the event.
                                input: serde_json::to_value(&tool_use.input).unwrap_or_default(),
                            }),
                            turn_id: request.turn_id.clone(),
                            ..Default::default()
                        };
                        let _ = self.from_llm.send(payload).await;
                    }
                }

                // Add assistant message to conversation history. Token counts
                // are recorded as 0 here — usage is not extracted on this path.
                let now = Self::current_timestamp_millis();
                let asst_msg = Message::Assistant(AssistantMessage {
                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                    session_id: session_id.to_string(),
                    turn_id: effective_turn_id.clone(),
                    parent_id: String::new(),
                    created_at: now,
                    completed_at: Some(now),
                    model_id: self.config.model.clone(),
                    provider_id: String::new(),
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    finish_reason: None,
                    error: None,
                    content: content_blocks,
                });
                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);

                // Send completion marker for this request.
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Complete,
                    is_complete: true,
                    turn_id: request.turn_id.clone(),
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;

                // Update request count (only on success, so the IDs minted
                // above advance only when a request completes).
                self.request_count.fetch_add(1, Ordering::SeqCst);

                tracing::debug!(session_id, "Request completed successfully");
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "LLM request failed");

                // Report the failure downstream; no Complete is sent and the
                // conversation keeps the user/tool input recorded above.
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        // Clear the request cancellation token and turn ID when done.
        self.cleanup_request().await;
    }
1023
1024    /// Handles a streaming request.
1025    async fn handle_streaming_request(&self, request: ToLLMPayload) {
1026        use super::convert::to_llm_messages;
1027        use crate::client::models::{ContentBlockType, Message as LLMMessage, StreamEvent};
1028        use crate::controller::types::{LLMRequestType, LLMResponseType};
1029        use futures::StreamExt;
1030
1031        // Prepare request context
1032        let (request_token, effective_turn_id) = self.prepare_request(&request).await;
1033
1034        let session_id = self.id();
1035        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling streaming request");
1036
1037        // Build the conversation messages
1038        let mut llm_messages: Vec<LLMMessage> = Vec::new();
1039
1040        // Add system prompt if set
1041        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
1042            llm_messages.push(LLMMessage::system(prompt.clone()));
1043        }
1044
1045        // Add conversation history
1046        let conversation = self.conversation.read().await;
1047        llm_messages.extend(to_llm_messages(&conversation));
1048        drop(conversation);
1049
1050        // Add the new message based on request type
1051        match request.request_type {
1052            LLMRequestType::UserMessage => {
1053                if !request.content.is_empty() {
1054                    llm_messages.push(LLMMessage::user(&request.content));
1055
1056                    // Add user message to conversation history
1057                    let user_msg = Message::User(UserMessage {
1058                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
1059                        session_id: session_id.to_string(),
1060                        turn_id: effective_turn_id.clone(),
1061                        created_at: Self::current_timestamp_millis(),
1062                        content: vec![ContentBlock::text(&request.content)],
1063                    });
1064                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
1065                }
1066            }
1067            LLMRequestType::ToolResult => {
1068                // Store compact summaries for later compaction
1069                self.store_compact_summaries(&request.compact_summaries)
1070                    .await;
1071
1072                // Log conversation state before adding tool results (streaming path)
1073                {
1074                    let conv = self.conversation.read().await;
1075                    tracing::debug!(
1076                        session_id,
1077                        conversation_len = conv.len(),
1078                        tool_result_count = request.tool_results.len(),
1079                        "STREAMING ToolResult: conversation state before adding results"
1080                    );
1081                }
1082                // Add tool result messages using LLM client's proper format
1083                for tool_result in &request.tool_results {
1084                    llm_messages.push(LLMMessage::tool_result(
1085                        &tool_result.tool_use_id,
1086                        &tool_result.content,
1087                        tool_result.is_error,
1088                    ));
1089
1090                    // Get compact summary if available
1091                    let compact_summary = request
1092                        .compact_summaries
1093                        .get(&tool_result.tool_use_id)
1094                        .cloned();
1095
1096                    // Add tool result to conversation history
1097                    let user_msg = Message::User(UserMessage {
1098                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
1099                        session_id: session_id.to_string(),
1100                        turn_id: effective_turn_id.clone(),
1101                        created_at: Self::current_timestamp_millis(),
1102                        content: vec![ContentBlock::ToolResult(
1103                            crate::controller::types::ToolResultBlock {
1104                                tool_use_id: tool_result.tool_use_id.clone(),
1105                                content: tool_result.content.clone(),
1106                                is_error: tool_result.is_error,
1107                                compact_summary,
1108                            },
1109                        )],
1110                    });
1111                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
1112                }
1113            }
1114        }
1115
1116        // Perform compaction if needed before LLM call
1117        self.maybe_compact().await;
1118
1119        // Build message options with tools
1120        let options = self.build_message_options().await;
1121
1122        // Call the streaming LLM API
1123        let stream_result = self
1124            .client
1125            .send_message_stream(&llm_messages, &options)
1126            .await;
1127
1128        match stream_result {
1129            Ok(mut stream) => {
1130                // Track current content block for tool use accumulation
1131                let mut current_tool_id: Option<String> = None;
1132                let mut current_tool_name: Option<String> = None;
1133                let mut tool_input_json = String::new();
1134                // Accumulate response text for conversation history
1135                let mut response_text = String::new();
1136                // Accumulate completed tool uses for conversation history
1137                let mut completed_tool_uses: Vec<crate::controller::types::ToolUseBlock> =
1138                    Vec::new();
1139
1140                // Process stream events
1141                loop {
1142                    tokio::select! {
1143                        _ = request_token.cancelled() => {
1144                            tracing::info!(session_id, "Streaming request cancelled");
1145                            break;
1146                        }
1147                        event = stream.next() => {
1148                            match event {
1149                                Some(Ok(stream_event)) => {
1150                                    match stream_event {
1151                                        StreamEvent::MessageStart { message_id, model } => {
1152                                            let payload = FromLLMPayload {
1153                                                session_id,
1154                                                response_type: LLMResponseType::StreamStart,
1155                                                message_id,
1156                                                model,
1157                                                turn_id: request.turn_id.clone(),
1158                                                ..Default::default()
1159                                            };
1160                                            let _ = self.from_llm.send(payload).await;
1161                                        }
1162                                        StreamEvent::ContentBlockStart { index: _, block_type } => {
1163                                            match block_type {
1164                                                ContentBlockType::Text => {
1165                                                    // Text block starting, nothing special to do
1166                                                }
1167                                                ContentBlockType::ToolUse { id, name } => {
1168                                                    // Track tool use for later batch execution
1169                                                    // Don't emit events here - wait until execution begins
1170                                                    current_tool_id = Some(id);
1171                                                    current_tool_name = Some(name);
1172                                                    tool_input_json.clear();
1173                                                }
1174                                            }
1175                                        }
1176                                        StreamEvent::TextDelta { index, text } => {
1177                                            // Accumulate for conversation history
1178                                            response_text.push_str(&text);
1179
1180                                            let payload = FromLLMPayload {
1181                                                session_id,
1182                                                response_type: LLMResponseType::TextChunk,
1183                                                text,
1184                                                content_index: index,
1185                                                turn_id: request.turn_id.clone(),
1186                                                ..Default::default()
1187                                            };
1188                                            let _ = self.from_llm.send(payload).await;
1189                                        }
1190                                        StreamEvent::InputJsonDelta { index, json } => {
1191                                            // Accumulate JSON for tool input
1192                                            tool_input_json.push_str(&json);
1193
1194                                            let payload = FromLLMPayload {
1195                                                session_id,
1196                                                response_type: LLMResponseType::ToolInputDelta,
1197                                                text: json,
1198                                                content_index: index,
1199                                                turn_id: request.turn_id.clone(),
1200                                                ..Default::default()
1201                                            };
1202                                            let _ = self.from_llm.send(payload).await;
1203                                        }
1204                                        StreamEvent::ContentBlockStop { index: _ } => {
1205                                            // If we were accumulating a tool use, save it for later
1206                                            // Don't emit event - ToolUseStart already emitted at ContentBlockStart
1207                                            // Wait until MessageStop to emit ToolBatch for execution
1208                                            if let (Some(id), Some(name)) =
1209                                                (current_tool_id.take(), current_tool_name.take())
1210                                            {
1211                                                let input: serde_json::Value =
1212                                                    serde_json::from_str(&tool_input_json)
1213                                                        .unwrap_or(serde_json::Value::Object(
1214                                                            serde_json::Map::new(),
1215                                                        ));
1216
1217                                                // Save tool use for conversation history and batch execution
1218                                                tracing::debug!(
1219                                                    session_id,
1220                                                    tool_id = %id,
1221                                                    tool_name = %name,
1222                                                    "Saving tool use to completed_tool_uses"
1223                                                );
1224                                                completed_tool_uses.push(crate::controller::types::ToolUseBlock {
1225                                                    id: id.clone(),
1226                                                    name: name.clone(),
1227                                                    input: input
1228                                                        .as_object()
1229                                                        .map(|obj| {
1230                                                            obj.iter()
1231                                                                .map(|(k, v)| (k.clone(), v.clone()))
1232                                                                .collect()
1233                                                        })
1234                                                        .unwrap_or_default(),
1235                                                });
1236
1237                                                tool_input_json.clear();
1238                                            }
1239                                        }
1240                                        StreamEvent::MessageDelta { stop_reason, usage } => {
1241                                            if let Some(usage) = usage {
1242                                                tracing::info!(
1243                                                    session_id,
1244                                                    input_tokens = usage.input_tokens,
1245                                                    output_tokens = usage.output_tokens,
1246                                                    "API token usage for this turn"
1247                                                );
1248                                                self.current_input_tokens
1249                                                    .store(usage.input_tokens as i64, Ordering::SeqCst);
1250                                                self.current_output_tokens
1251                                                    .store(usage.output_tokens as i64, Ordering::SeqCst);
1252
1253                                                let payload = FromLLMPayload {
1254                                                    session_id,
1255                                                    response_type: LLMResponseType::TokenUpdate,
1256                                                    input_tokens: usage.input_tokens as i64,
1257                                                    output_tokens: usage.output_tokens as i64,
1258                                                    turn_id: request.turn_id.clone(),
1259                                                    ..Default::default()
1260                                                };
1261                                                let _ = self.from_llm.send(payload).await;
1262                                            }
1263
1264                                            if stop_reason.is_some() {
1265                                                let payload = FromLLMPayload {
1266                                                    session_id,
1267                                                    response_type: LLMResponseType::Complete,
1268                                                    is_complete: true,
1269                                                    stop_reason,
1270                                                    turn_id: request.turn_id.clone(),
1271                                                    ..Default::default()
1272                                                };
1273                                                let _ = self.from_llm.send(payload).await;
1274                                            }
1275                                        }
1276                                        StreamEvent::MessageStop => {
1277                                            // Add assistant message to conversation history
1278                                            // Must save both text AND tool uses
1279                                            tracing::debug!(
1280                                                session_id,
1281                                                text_len = response_text.len(),
1282                                                tool_use_count = completed_tool_uses.len(),
1283                                                "MessageStop: saving assistant message to history"
1284                                            );
1285                                            if !response_text.is_empty() || !completed_tool_uses.is_empty() {
1286                                                let now = Self::current_timestamp_millis();
1287
1288                                                // Build content blocks: text first, then tool uses
1289                                                let mut content_blocks = Vec::new();
1290                                                if !response_text.is_empty() {
1291                                                    content_blocks.push(ContentBlock::text(&response_text));
1292                                                }
1293                                                for tool_use in &completed_tool_uses {
1294                                                    content_blocks.push(ContentBlock::ToolUse(tool_use.clone()));
1295                                                }
1296
1297                                                let content_block_count = content_blocks.len();
1298                                                let asst_msg = Message::Assistant(AssistantMessage {
1299                                                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
1300                                                    session_id: session_id.to_string(),
1301                                                    turn_id: effective_turn_id.clone(),
1302                                                    parent_id: String::new(),
1303                                                    created_at: now,
1304                                                    completed_at: Some(now),
1305                                                    model_id: self.config.model.clone(),
1306                                                    provider_id: String::new(),
1307                                                    input_tokens: self.current_input_tokens.load(Ordering::SeqCst),
1308                                                    output_tokens: self.current_output_tokens.load(Ordering::SeqCst),
1309                                                    cache_read_tokens: 0,
1310                                                    cache_write_tokens: 0,
1311                                                    finish_reason: None,
1312                                                    error: None,
1313                                                    content: content_blocks,
1314                                                });
1315                                                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);
1316                                                tracing::debug!(
1317                                                    session_id,
1318                                                    content_block_count,
1319                                                    "MessageStop: saved assistant message with content blocks"
1320                                                );
1321                                            }
1322
1323                                            // If there are tool uses, emit them as a batch for execution
1324                                            // This ensures all tools are executed together and results sent back in one message
1325                                            if !completed_tool_uses.is_empty() {
1326                                                let tool_uses: Vec<crate::controller::types::ToolUseInfo> = completed_tool_uses
1327                                                    .iter()
1328                                                    .map(|tu| crate::controller::types::ToolUseInfo {
1329                                                        id: tu.id.clone(),
1330                                                        name: tu.name.clone(),
1331                                                        input: serde_json::Value::Object(
1332                                                            tu.input.iter()
1333                                                                .map(|(k, v)| (k.clone(), v.clone()))
1334                                                                .collect()
1335                                                        ),
1336                                                    })
1337                                                    .collect();
1338
1339                                                tracing::debug!(
1340                                                    session_id,
1341                                                    tool_count = tool_uses.len(),
1342                                                    "MessageStop: emitting ToolBatch for execution"
1343                                                );
1344
1345                                                let payload = FromLLMPayload {
1346                                                    session_id,
1347                                                    response_type: LLMResponseType::ToolBatch,
1348                                                    tool_uses,
1349                                                    turn_id: request.turn_id.clone(),
1350                                                    ..Default::default()
1351                                                };
1352                                                let _ = self.from_llm.send(payload).await;
1353                                            }
1354
1355                                            // Stream complete
1356                                            self.request_count.fetch_add(1, Ordering::SeqCst);
1357                                            tracing::debug!(session_id, "Streaming request completed");
1358                                            break;
1359                                        }
1360                                        StreamEvent::Ping => {
1361                                            // Keep-alive, ignore
1362                                        }
1363                                    }
1364                                }
1365                                Some(Err(err)) => {
1366                                    tracing::error!(session_id, error = %err, "Stream error");
1367                                    let payload = FromLLMPayload {
1368                                        session_id,
1369                                        response_type: LLMResponseType::Error,
1370                                        error: Some(err.to_string()),
1371                                        turn_id: request.turn_id.clone(),
1372                                        ..Default::default()
1373                                    };
1374                                    let _ = self.from_llm.send(payload).await;
1375                                    break;
1376                                }
1377                                None => {
1378                                    // Stream ended
1379                                    break;
1380                                }
1381                            }
1382                        }
1383                    }
1384                }
1385            }
1386            Err(err) => {
1387                tracing::error!(session_id, error = %err, "Failed to start streaming");
1388                let payload = FromLLMPayload {
1389                    session_id,
1390                    response_type: LLMResponseType::Error,
1391                    error: Some(err.to_string()),
1392                    turn_id: request.turn_id,
1393                    ..Default::default()
1394                };
1395                let _ = self.from_llm.send(payload).await;
1396            }
1397        }
1398
1399        // Clear the request cancellation token and turn ID when done
1400        self.cleanup_request().await;
1401    }
1402}