use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, Ordering};
use std::time::Instant;

use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;

use crate::client::error::LlmError;
use crate::client::models::Tool as LLMTool;
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;

use super::compactor::{AsyncCompactor, Compactor, LLMCompactor, ThresholdCompactor};
use super::config::{CompactorType, LLMProvider, LLMSessionConfig};
use crate::controller::types::{
    AssistantMessage, ContentBlock, FromLLMPayload, Message, ToLLMPayload, TurnId, UserMessage,
};

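/// Cumulative and most-recent token usage counters for a session.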
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    pub total_input_tokens: i64,
    pub total_output_tokens: i64,
    pub request_count: i64,
    pub last_input_tokens: i64,
    pub last_output_tokens: i64,
}

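/// Point-in-time snapshot of a session: identity, context usage, and token totals.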
#[derive(Debug, Clone)]
pub struct SessionStatus {
    pub session_id: i64,
    pub model: String,
    pub created_at: Instant,
    pub conversation_len: usize,
    pub context_used: i64,
    pub context_limit: i32,
    pub utilization: f64,
    pub total_input: i64,
    pub total_output: i64,
    pub request_count: i64,
}

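/// Outcome of a single compaction pass: whether anything was compacted and how
/// the conversation changed.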
#[derive(Debug, Clone, Default)]
pub struct CompactResult {
    pub compacted: bool,
    pub messages_before: usize,
    pub messages_after: usize,
    pub turns_compacted: usize,
    pub turns_kept: usize,
    pub summary_length: usize,
    pub error: Option<String>,
}

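/// Monotonically increasing counter used to assign session IDs.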
static SESSION_COUNTER: AtomicI64 = AtomicI64::new(0);

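/// A single LLM conversation session. Owns the provider client, the
/// conversation history, token accounting, and the optional compactors used
/// to keep the context window under its limit.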
pub struct LLMSession {
    id: AtomicI64,

    client: LLMClient,

    to_llm_tx: mpsc::Sender<ToLLMPayload>,
    to_llm_rx: Mutex<mpsc::Receiver<ToLLMPayload>>,
    from_llm: mpsc::Sender<FromLLMPayload>,

    config: LLMSessionConfig,

    system_prompt: RwLock<Option<String>>,
    max_tokens: AtomicI64,
    created_at: Instant,

    conversation: RwLock<Vec<Message>>,

    shutdown: AtomicBool,
    cancel_token: CancellationToken,

    current_cancel: Mutex<Option<CancellationToken>>,

    current_turn_id: RwLock<Option<TurnId>>,

    current_input_tokens: AtomicI64,
    current_output_tokens: AtomicI64,

    request_count: AtomicI64,

    tool_definitions: RwLock<Vec<LLMTool>>,

    compactor: Option<Box<dyn Compactor>>,
    llm_compactor: Option<LLMCompactor>,
    context_limit: AtomicI32,
    compact_summaries: RwLock<HashMap<String, String>>,
}

impl LLMSession {
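    /// Builds the provider client from `config`, sets up the optional
    /// threshold or LLM compactor, and returns the new session. Fails only if
    /// a client cannot be constructed.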
    pub fn new(
        config: LLMSessionConfig,
        from_llm: mpsc::Sender<FromLLMPayload>,
        cancel_token: CancellationToken,
    ) -> Result<Self, LlmError> {
        let session_id = SESSION_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
        let (to_llm_tx, to_llm_rx) = mpsc::channel(32);
        let max_tokens = config.max_tokens.unwrap_or(4096) as i64;
        let system_prompt = config.system_prompt.clone();

        let client = match config.provider {
            LLMProvider::Anthropic => {
                let provider = AnthropicProvider::new(
                    config.api_key.clone(),
                    config.model.clone(),
                );
                LLMClient::new(Box::new(provider))?
            }
            LLMProvider::OpenAI => {
                let provider = OpenAIProvider::new(
                    config.api_key.clone(),
                    config.model.clone(),
                );
                LLMClient::new(Box::new(provider))?
            }
        };

        let mut compactor: Option<Box<dyn Compactor>> = None;
        let mut llm_compactor: Option<LLMCompactor> = None;

        if let Some(ref compactor_type) = config.compaction {
            match compactor_type {
                CompactorType::Threshold(c) => {
                    match ThresholdCompactor::new(c.threshold, c.keep_recent_turns, c.tool_compaction) {
                        Ok(tc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                tool_compaction = %c.tool_compaction,
                                "Threshold compaction enabled for session"
                            );
                            compactor = Some(Box::new(tc) as Box<dyn Compactor>);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create threshold compactor, compaction disabled");
                        }
                    }
                }
                CompactorType::LLM(c) => {
                    let llm_client = match config.provider {
                        LLMProvider::Anthropic => {
                            let provider = AnthropicProvider::new(
                                config.api_key.clone(),
                                config.model.clone(),
                            );
                            LLMClient::new(Box::new(provider))?
                        }
                        LLMProvider::OpenAI => {
                            let provider = OpenAIProvider::new(
                                config.api_key.clone(),
                                config.model.clone(),
                            );
                            LLMClient::new(Box::new(provider))?
                        }
                    };

                    match LLMCompactor::new(llm_client, c.clone()) {
                        Ok(lc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                "LLM compaction enabled for session"
                            );
                            llm_compactor = Some(lc);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create LLM compactor, compaction disabled");
                        }
                    }
                }
            }
        }

        let context_limit = config.context_limit;

        Ok(Self {
            id: AtomicI64::new(session_id),
            client,
            to_llm_tx,
            to_llm_rx: Mutex::new(to_llm_rx),
            from_llm,
            config,
            system_prompt: RwLock::new(system_prompt),
            max_tokens: AtomicI64::new(max_tokens),
            created_at: Instant::now(),
            conversation: RwLock::new(Vec::new()),
            shutdown: AtomicBool::new(false),
            cancel_token,
            current_cancel: Mutex::new(None),
            current_turn_id: RwLock::new(None),
            current_input_tokens: AtomicI64::new(0),
            current_output_tokens: AtomicI64::new(0),
            request_count: AtomicI64::new(0),
            tool_definitions: RwLock::new(Vec::new()),
            compactor,
            llm_compactor,
            context_limit: AtomicI32::new(context_limit),
            compact_summaries: RwLock::new(HashMap::new()),
        })
    }

    pub fn id(&self) -> i64 {
        self.id.load(Ordering::SeqCst)
    }

    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    pub fn model(&self) -> &str {
        &self.config.model
    }

    pub fn set_max_tokens(&self, max_tokens: i64) {
        self.max_tokens.store(max_tokens, Ordering::SeqCst);
    }

    pub fn max_tokens(&self) -> i64 {
        self.max_tokens.load(Ordering::SeqCst)
    }

    pub fn context_limit(&self) -> i32 {
        self.context_limit.load(Ordering::SeqCst)
    }

    pub async fn set_system_prompt(&self, prompt: String) {
        let mut guard = self.system_prompt.write().await;
        *guard = Some(prompt);
    }

    pub async fn clear_system_prompt(&self) {
        let mut guard = self.system_prompt.write().await;
        *guard = None;
    }

    pub async fn system_prompt(&self) -> Option<String> {
        self.system_prompt.read().await.clone()
    }

    pub async fn set_tools(&self, tools: Vec<LLMTool>) {
        let mut guard = self.tool_definitions.write().await;
        *guard = tools;
    }

    pub async fn clear_tools(&self) {
        let mut guard = self.tool_definitions.write().await;
        guard.clear();
    }

    pub async fn tools(&self) -> Vec<LLMTool> {
        self.tool_definitions.read().await.clone()
    }

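    /// Stores compact summaries supplied with tool results, keyed by
    /// `tool_use_id`, so later compaction passes can use them in place of the
    /// full tool output.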
    async fn store_compact_summaries(&self, summaries: &HashMap<String, String>) {
        if summaries.is_empty() {
            tracing::warn!(
                session_id = self.id(),
                "No compact summaries provided with tool results"
            );
            return;
        }
        let mut guard = self.compact_summaries.write().await;
        for (tool_use_id, summary) in summaries {
            tracing::info!(
                session_id = self.id(),
                tool_use_id = %tool_use_id,
                summary_len = summary.len(),
                summary_preview = %summary.chars().take(50).collect::<String>(),
                "Storing compact summary"
            );
            guard.insert(tool_use_id.clone(), summary.clone());
        }
        tracing::info!(
            session_id = self.id(),
            new_summaries = summaries.len(),
            total_stored = guard.len(),
            "Stored compact summaries for tool results"
        );
    }

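    /// Compacts the conversation if the configured compactor reports that
    /// context usage has crossed its threshold. The LLM compactor takes
    /// precedence over the threshold compactor when both are configured.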
    async fn maybe_compact(&self) {
        let context_used = self.current_input_tokens.load(Ordering::SeqCst);
        let context_limit = self.context_limit.load(Ordering::SeqCst);
        let conversation_len = self.conversation.read().await.len();
        let summaries_count = self.compact_summaries.read().await.len();

        let utilization = if context_limit > 0 {
            context_used as f64 / context_limit as f64
        } else {
            0.0
        };

        tracing::debug!(
            session_id = self.id(),
            context_used,
            context_limit,
            utilization = format!("{:.2}%", utilization * 100.0),
            conversation_len,
            summaries_available = summaries_count,
            "Checking if compaction needed"
        );

        if let Some(ref llm_compactor) = self.llm_compactor {
            if !llm_compactor.should_compact(context_used, context_limit) {
                tracing::debug!(
                    session_id = self.id(),
                    "LLM compaction not triggered"
                );
                return;
            }

            let summaries = self.compact_summaries.read().await.clone();
            let conversation = self.conversation.read().await.clone();

            tracing::info!(
                session_id = self.id(),
                conversation_len = conversation.len(),
                summaries_count = summaries.len(),
                "Starting LLM compaction"
            );

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    *self.conversation.write().await = new_conversation;

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            "LLM compaction completed"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "LLM compaction failed"
                    );
                }
            }
            return;
        }

        let compactor = match &self.compactor {
            Some(c) => c,
            None => {
                tracing::debug!(
                    session_id = self.id(),
                    "No compactor configured"
                );
                return;
            }
        };

        if !compactor.should_compact(context_used, context_limit) {
            tracing::debug!(
                session_id = self.id(),
                "Threshold compaction not triggered"
            );
            return;
        }

        let summaries = self.compact_summaries.read().await.clone();
        let mut conversation = self.conversation.write().await;

        tracing::info!(
            session_id = self.id(),
            conversation_len = conversation.len(),
            summaries_count = summaries.len(),
            "Starting threshold compaction"
        );

        let result = compactor.compact(&mut conversation, &summaries);

        tracing::info!(
            session_id = self.id(),
            tool_results_summarized = result.tool_results_summarized,
            tool_results_redacted = result.tool_results_redacted,
            turns_compacted = result.turns_compacted,
            conversation_len_after = conversation.len(),
            "Threshold compaction completed"
        );
    }

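    /// Clears the conversation history, the stored compact summaries, and the
    /// current token counters.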
    pub async fn clear_conversation(&self) {
        let mut conversation = self.conversation.write().await;
        conversation.clear();

        let mut summaries = self.compact_summaries.write().await;
        summaries.clear();

        self.current_input_tokens.store(0, Ordering::SeqCst);
        self.current_output_tokens.store(0, Ordering::SeqCst);

        tracing::info!(session_id = self.id(), "Conversation cleared");
    }

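    /// Runs compaction immediately, regardless of the usage threshold, and
    /// reports how the conversation changed. Returns an error result if no
    /// compactor is configured.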
    pub async fn force_compact(&self) -> CompactResult {
        if let Some(ref llm_compactor) = self.llm_compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let conversation = self.conversation.read().await.clone();
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    let messages_after = new_conversation.len();
                    let turns_after = self.count_unique_turns(&new_conversation);
                    let compacted = messages_after < messages_before;

                    let summary_length = if compacted && !new_conversation.is_empty() {
                        self.extract_summary_length(&new_conversation[0])
                    } else {
                        0
                    };

                    *self.conversation.write().await = new_conversation;

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            messages_before,
                            messages_after,
                            "Forced LLM compaction completed"
                        );
                    }

                    return CompactResult {
                        compacted,
                        messages_before,
                        messages_after,
                        turns_compacted: turns_before.saturating_sub(turns_after),
                        turns_kept: turns_after,
                        summary_length,
                        error: None,
                    };
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "Forced LLM compaction failed"
                    );
                    return CompactResult {
                        compacted: false,
                        messages_before,
                        messages_after: messages_before,
                        turns_compacted: 0,
                        turns_kept: turns_before,
                        summary_length: 0,
                        error: Some(format!("Compaction failed: {}", e)),
                    };
                }
            }
        }

        if let Some(ref compactor) = self.compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let mut conversation = self.conversation.write().await;
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            let result = compactor.compact(&mut conversation, &summaries);

            let messages_after = conversation.len();
            let turns_after = self.count_unique_turns(&conversation);
            let compacted = result.turns_compacted > 0 || result.total_compacted() > 0;

            if result.total_compacted() > 0 {
                tracing::info!(
                    session_id = self.id(),
                    tool_results_summarized = result.tool_results_summarized,
                    tool_results_redacted = result.tool_results_redacted,
                    turns_compacted = result.turns_compacted,
                    "Forced threshold compaction completed"
                );
            }

            return CompactResult {
                compacted,
                messages_before,
                messages_after,
                turns_compacted: turns_before.saturating_sub(turns_after),
                turns_kept: turns_after,
                summary_length: 0,
                error: None,
            };
        }

        CompactResult {
            compacted: false,
            error: Some("No compactor configured".to_string()),
            ..Default::default()
        }
    }

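    /// Counts the distinct turn IDs present in `conversation`.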
    fn count_unique_turns(&self, conversation: &[Message]) -> usize {
        use std::collections::HashSet;
        let mut turn_ids = HashSet::new();
        for msg in conversation {
            turn_ids.insert(msg.turn_id().clone());
        }
        turn_ids.len()
    }

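    /// Returns the length of the previous-conversation summary carried by
    /// `message`, or 0 if the message is not a summary message.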
    fn extract_summary_length(&self, message: &Message) -> usize {
        if let Message::User(user_msg) = message {
            for block in &user_msg.content {
                if let ContentBlock::Text(text_block) = block {
                    if text_block.text.starts_with("[Previous conversation summary]") {
                        return text_block.text.len();
                    }
                }
            }
        }
        0
    }

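    /// Queues a request for the session loop. Returns `false` if the session
    /// has shut down or the request channel is closed.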
    pub async fn send(&self, msg: ToLLMPayload) -> bool {
        if self.shutdown.load(Ordering::SeqCst) {
            return false;
        }
        self.to_llm_tx.send(msg).await.is_ok()
    }

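    /// Cancels the in-flight request, if any, and removes the interrupted
    /// turn's messages from the conversation history.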
    pub async fn interrupt(&self) {
        let guard = self.current_cancel.lock().await;
        if let Some(token) = guard.as_ref() {
            token.cancel();

            let turn_id = self.current_turn_id.read().await.clone();
            if let Some(turn_id) = turn_id {
                let mut conversation = self.conversation.write().await;
                let original_len = conversation.len();
                conversation.retain(|msg| msg.turn_id() != &turn_id);
                let removed = original_len - conversation.len();
                tracing::debug!(
                    session_id = self.id(),
                    turn_id = %turn_id,
                    messages_removed = removed,
                    conversation_length = conversation.len(),
                    "Removed messages from cancelled turn"
                );
            }
        }
    }

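    /// Marks the session as shut down and cancels its root token, stopping
    /// the `start` loop.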
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
        self.cancel_token.cancel();
    }

    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

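    /// Runs the session loop: receives requests from the internal channel and
    /// handles them until the channel closes or the session is cancelled.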
    pub async fn start(&self) {
        tracing::info!(session_id = self.id(), "Session starting");

        loop {
            let mut rx_guard = self.to_llm_rx.lock().await;

            tokio::select! {
                _ = self.cancel_token.cancelled() => {
                    tracing::info!(session_id = self.id(), "Session cancelled");
                    break;
                }
                msg = rx_guard.recv() => {
                    match msg {
                        Some(request) => {
                            drop(rx_guard);
                            self.handle_request(request).await;
                        }
                        None => {
                            tracing::info!(session_id = self.id(), "Session channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::info!(session_id = self.id(), "Session stopped");
    }

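    /// Dispatches a request to the streaming or non-streaming handler
    /// depending on the session configuration.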
    async fn handle_request(&self, request: ToLLMPayload) {
        if self.config.streaming {
            self.handle_streaming_request(request).await;
        } else {
            self.handle_non_streaming_request(request).await;
        }
    }

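    /// Handles a request with a single blocking completion: builds the
    /// message list, compacts if needed, calls the provider, then emits text,
    /// tool-use, and completion payloads on the `from_llm` channel.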
    async fn handle_non_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::{from_llm_message, to_llm_messages};
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use crate::client::models::{Message as LLMMessage, MessageOptions};

        let request_token = CancellationToken::new();
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = Some(request_token.clone());
        }

        let effective_turn_id = request
            .turn_id
            .clone()
            .unwrap_or_else(|| TurnId::new_user_turn(0));
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = Some(effective_turn_id.clone());
        }

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_millis() as i64,
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    self.conversation.write().await.push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries).await;

                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_millis() as i64,
                        content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            content: tool_result.content.clone(),
                            is_error: tool_result.is_error,
                            compact_summary,
                        })],
                    });
                    self.conversation.write().await.push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
        let tools = self.tool_definitions.read().await.clone();
        let tools_option = if tools.is_empty() { None } else { Some(tools) };
        let options = MessageOptions {
            max_tokens: Some(max_tokens),
            temperature: self.config.temperature,
            tools: tools_option,
            ..Default::default()
        };

        let result = self.client.send_message(&llm_messages, &options).await;

        match result {
            Ok(response) => {
                let content_blocks = from_llm_message(&response);

                let text: String = content_blocks
                    .iter()
                    .filter_map(|block| {
                        if let ContentBlock::Text(t) = block {
                            Some(t.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("");

                if !text.is_empty() {
                    let payload = FromLLMPayload {
                        session_id,
                        response_type: LLMResponseType::TextChunk,
                        text: text.clone(),
                        turn_id: request.turn_id.clone(),
                        ..Default::default()
                    };
                    let _ = self.from_llm.send(payload).await;
                }

                for block in &content_blocks {
                    if let ContentBlock::ToolUse(tool_use) = block {
                        let payload = FromLLMPayload {
                            session_id,
                            response_type: LLMResponseType::ToolUse,
                            tool_use: Some(crate::controller::types::ToolUseInfo {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                input: serde_json::to_value(&tool_use.input).unwrap_or_default(),
                            }),
                            turn_id: request.turn_id.clone(),
                            ..Default::default()
                        };
                        let _ = self.from_llm.send(payload).await;
                    }
                }

                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_millis() as i64;
                let asst_msg = Message::Assistant(AssistantMessage {
                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                    session_id: session_id.to_string(),
                    turn_id: effective_turn_id.clone(),
                    parent_id: String::new(),
                    created_at: now,
                    completed_at: Some(now),
                    model_id: self.config.model.clone(),
                    provider_id: String::new(),
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    finish_reason: None,
                    error: None,
                    content: content_blocks,
                });
                self.conversation.write().await.push(asst_msg);

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Complete,
                    is_complete: true,
                    turn_id: request.turn_id.clone(),
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;

                self.request_count.fetch_add(1, Ordering::SeqCst);

                tracing::debug!(session_id, "Request completed successfully");
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "LLM request failed");

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        {
            let mut guard = self.current_cancel.lock().await;
            *guard = None;
        }
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = None;
        }
    }

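    /// Handles a request over the streaming API: forwards text and tool-input
    /// deltas as they arrive, accumulates the full assistant message, and
    /// emits a ToolBatch payload for any tool uses once the stream completes.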
    async fn handle_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::to_llm_messages;
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use futures::StreamExt;
        use crate::client::models::{
            ContentBlockType, Message as LLMMessage, MessageOptions, StreamEvent,
        };

        let request_token = CancellationToken::new();
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = Some(request_token.clone());
        }

        let effective_turn_id = request
            .turn_id
            .clone()
            .unwrap_or_else(|| TurnId::new_user_turn(0));
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = Some(effective_turn_id.clone());
        }

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling streaming request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_millis() as i64,
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    self.conversation.write().await.push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries).await;

                {
                    let conv = self.conversation.read().await;
                    tracing::debug!(
                        session_id,
                        conversation_len = conv.len(),
                        tool_result_count = request.tool_results.len(),
                        "STREAMING ToolResult: conversation state before adding results"
                    );
                }
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_millis() as i64,
                        content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            content: tool_result.content.clone(),
                            is_error: tool_result.is_error,
                            compact_summary,
                        })],
                    });
                    self.conversation.write().await.push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
        let tools = self.tool_definitions.read().await.clone();
        let tools_option = if tools.is_empty() { None } else { Some(tools) };
        let options = MessageOptions {
            max_tokens: Some(max_tokens),
            temperature: self.config.temperature,
            tools: tools_option,
            ..Default::default()
        };

        let stream_result = self
            .client
            .send_message_stream(&llm_messages, &options)
            .await;

        match stream_result {
            Ok(mut stream) => {
                let mut current_tool_id: Option<String> = None;
                let mut current_tool_name: Option<String> = None;
                let mut tool_input_json = String::new();
                let mut response_text = String::new();
                let mut completed_tool_uses: Vec<crate::controller::types::ToolUseBlock> = Vec::new();

                loop {
                    tokio::select! {
                        _ = request_token.cancelled() => {
                            tracing::info!(session_id, "Streaming request cancelled");
                            break;
                        }
                        event = stream.next() => {
                            match event {
                                Some(Ok(stream_event)) => {
                                    match stream_event {
                                        StreamEvent::MessageStart { message_id, model } => {
                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::StreamStart,
                                                message_id,
                                                model,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStart { index: _, block_type } => {
                                            match block_type {
                                                ContentBlockType::Text => {}
                                                ContentBlockType::ToolUse { id, name } => {
                                                    current_tool_id = Some(id);
                                                    current_tool_name = Some(name);
                                                    tool_input_json.clear();
                                                }
                                            }
                                        }
                                        StreamEvent::TextDelta { index, text } => {
                                            response_text.push_str(&text);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::TextChunk,
                                                text,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::InputJsonDelta { index, json } => {
                                            tool_input_json.push_str(&json);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::ToolInputDelta,
                                                text: json,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStop { index: _ } => {
                                            if let (Some(id), Some(name)) =
                                                (current_tool_id.take(), current_tool_name.take())
                                            {
                                                let input: serde_json::Value =
                                                    serde_json::from_str(&tool_input_json)
                                                        .unwrap_or(serde_json::Value::Object(
                                                            serde_json::Map::new(),
                                                        ));

                                                tracing::debug!(
                                                    session_id,
                                                    tool_id = %id,
                                                    tool_name = %name,
                                                    "Saving tool use to completed_tool_uses"
                                                );
                                                completed_tool_uses.push(crate::controller::types::ToolUseBlock {
                                                    id: id.clone(),
                                                    name: name.clone(),
                                                    input: input
                                                        .as_object()
                                                        .map(|obj| {
                                                            obj.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect()
                                                        })
                                                        .unwrap_or_default(),
                                                });

                                                tool_input_json.clear();
                                            }
                                        }
                                        StreamEvent::MessageDelta { stop_reason, usage } => {
                                            if let Some(usage) = usage {
                                                tracing::info!(
                                                    session_id,
                                                    input_tokens = usage.input_tokens,
                                                    output_tokens = usage.output_tokens,
                                                    "API token usage for this turn"
                                                );
                                                self.current_input_tokens
                                                    .store(usage.input_tokens as i64, Ordering::SeqCst);
                                                self.current_output_tokens
                                                    .store(usage.output_tokens as i64, Ordering::SeqCst);

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::TokenUpdate,
                                                    input_tokens: usage.input_tokens as i64,
                                                    output_tokens: usage.output_tokens as i64,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            if stop_reason.is_some() {
                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::Complete,
                                                    is_complete: true,
                                                    stop_reason,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }
                                        }
                                        StreamEvent::MessageStop => {
                                            tracing::debug!(
                                                session_id,
                                                text_len = response_text.len(),
                                                tool_use_count = completed_tool_uses.len(),
                                                "MessageStop: saving assistant message to history"
                                            );
                                            if !response_text.is_empty() || !completed_tool_uses.is_empty() {
                                                let now = std::time::SystemTime::now()
                                                    .duration_since(std::time::UNIX_EPOCH)
                                                    .unwrap_or_default()
                                                    .as_millis() as i64;

                                                let mut content_blocks = Vec::new();
                                                if !response_text.is_empty() {
                                                    content_blocks.push(ContentBlock::text(&response_text));
                                                }
                                                for tool_use in &completed_tool_uses {
                                                    content_blocks.push(ContentBlock::ToolUse(tool_use.clone()));
                                                }

                                                let content_block_count = content_blocks.len();
                                                let asst_msg = Message::Assistant(AssistantMessage {
                                                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                                                    session_id: session_id.to_string(),
                                                    turn_id: effective_turn_id.clone(),
                                                    parent_id: String::new(),
                                                    created_at: now,
                                                    completed_at: Some(now),
                                                    model_id: self.config.model.clone(),
                                                    provider_id: String::new(),
                                                    input_tokens: self.current_input_tokens.load(Ordering::SeqCst),
                                                    output_tokens: self.current_output_tokens.load(Ordering::SeqCst),
                                                    cache_read_tokens: 0,
                                                    cache_write_tokens: 0,
                                                    finish_reason: None,
                                                    error: None,
                                                    content: content_blocks,
                                                });
                                                self.conversation.write().await.push(asst_msg);
                                                tracing::debug!(
                                                    session_id,
                                                    content_block_count,
                                                    "MessageStop: saved assistant message with content blocks"
                                                );
                                            }

                                            if !completed_tool_uses.is_empty() {
                                                let tool_uses: Vec<crate::controller::types::ToolUseInfo> = completed_tool_uses
                                                    .iter()
                                                    .map(|tu| crate::controller::types::ToolUseInfo {
                                                        id: tu.id.clone(),
                                                        name: tu.name.clone(),
                                                        input: serde_json::Value::Object(
                                                            tu.input.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect()
                                                        ),
                                                    })
                                                    .collect();

                                                tracing::debug!(
                                                    session_id,
                                                    tool_count = tool_uses.len(),
                                                    "MessageStop: emitting ToolBatch for execution"
                                                );

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::ToolBatch,
                                                    tool_uses,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            self.request_count.fetch_add(1, Ordering::SeqCst);
                                            tracing::debug!(session_id, "Streaming request completed");
                                            break;
                                        }
                                        StreamEvent::Ping => {}
                                    }
                                }
                                Some(Err(err)) => {
                                    tracing::error!(session_id, error = %err, "Stream error");
                                    let payload = FromLLMPayload {
                                        session_id,
                                        response_type: LLMResponseType::Error,
                                        error: Some(err.to_string()),
                                        turn_id: request.turn_id.clone(),
                                        ..Default::default()
                                    };
                                    let _ = self.from_llm.send(payload).await;
                                    break;
                                }
                                None => {
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "Failed to start streaming");
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        {
            let mut guard = self.current_cancel.lock().await;
            *guard = None;
        }
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = None;
        }
    }
}