use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, Ordering};
use std::time::Instant;

use tokio::sync::{Mutex, RwLock, mpsc};
use tokio_util::sync::CancellationToken;

use crate::client::LLMClient;
use crate::client::error::LlmError;
use crate::client::models::Tool as LLMTool;
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;

use crate::controller::types::{
    AssistantMessage, ContentBlock, FromLLMPayload, Message, ToLLMPayload, TurnId, UserMessage,
};

use super::compactor::{AsyncCompactor, Compactor, LLMCompactor, ThresholdCompactor};
use super::config::{CompactorType, LLMProvider, LLMSessionConfig};

/// Builds an `LLMClient` for the provider selected in `config`.
fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
    match config.provider {
        LLMProvider::Anthropic => {
            let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::OpenAI => {
            // Azure takes precedence when both resource and deployment are set;
            // otherwise fall back to a custom base URL, then the default endpoint.
            let provider = if let (Some(resource), Some(deployment)) =
                (&config.azure_resource, &config.azure_deployment)
            {
                let api_version = config
                    .azure_api_version
                    .clone()
                    .unwrap_or_else(|| "2024-10-21".to_string());
                OpenAIProvider::azure(
                    config.api_key.clone(),
                    resource.clone(),
                    deployment.clone(),
                    api_version,
                )
            } else if let Some(base_url) = &config.base_url {
                OpenAIProvider::with_base_url(
                    config.api_key.clone(),
                    config.model.clone(),
                    base_url.clone(),
                )
            } else {
                OpenAIProvider::new(config.api_key.clone(), config.model.clone())
            };
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Google => {
            let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Cohere => {
            let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Bedrock => {
            let region = config.bedrock_region.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_region")
            })?;
            let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_access_key_id")
            })?;
            let secret_access_key = config.bedrock_secret_access_key.clone().ok_or_else(|| {
                LlmError::new(
                    "MISSING_CONFIG",
                    "Bedrock requires bedrock_secret_access_key",
                )
            })?;

            let credentials = match &config.bedrock_session_token {
                Some(token) => BedrockCredentials::with_session_token(
                    access_key_id,
                    secret_access_key,
                    token.clone(),
                ),
                None => BedrockCredentials::new(access_key_id, secret_access_key),
            };

            let provider = BedrockProvider::new(credentials, region, config.model.clone());
            LLMClient::new(Box::new(provider))
        }
    }
}
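
// Illustrative sketch of the caller side (not from this module): the field
// values are hypothetical, and a `Default` impl on `LLMSessionConfig` is an
// assumption.
//
//     let config = LLMSessionConfig {
//         provider: LLMProvider::OpenAI,
//         api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
//         model: "gpt-4o".to_string(),
//         base_url: None,
//         ..Default::default()
//     };
//     let client = create_llm_client(&config)?;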

/// Aggregated token accounting for a session.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input tokens consumed across all requests.
    pub total_input_tokens: i64,
    /// Output tokens produced across all requests.
    pub total_output_tokens: i64,
    /// Number of requests sent so far.
    pub request_count: i64,
    /// Input tokens reported by the most recent request.
    pub last_input_tokens: i64,
    /// Output tokens reported by the most recent request.
    pub last_output_tokens: i64,
}

/// Point-in-time snapshot of a session's state and context usage.
#[derive(Debug, Clone)]
pub struct SessionStatus {
    /// Session identifier.
    pub session_id: i64,
    /// Model name the session is configured with.
    pub model: String,
    /// When the session was created.
    pub created_at: Instant,
    /// Number of messages currently in the conversation.
    pub conversation_len: usize,
    /// Input tokens counted against the context window.
    pub context_used: i64,
    /// Maximum context size in tokens.
    pub context_limit: i32,
    /// `context_used / context_limit` when the limit is set, else 0.
    pub utilization: f64,
    /// Total input tokens across all requests.
    pub total_input: i64,
    /// Total output tokens across all requests.
    pub total_output: i64,
    /// Number of requests sent so far.
    pub request_count: i64,
}

/// Outcome of a compaction pass over the conversation history.
#[derive(Debug, Clone, Default)]
pub struct CompactResult {
    /// Whether anything was actually compacted.
    pub compacted: bool,
    /// Message count before compaction.
    pub messages_before: usize,
    /// Message count after compaction.
    pub messages_after: usize,
    /// Number of turns folded into the summary.
    pub turns_compacted: usize,
    /// Number of recent turns preserved verbatim.
    pub turns_kept: usize,
    /// Length in bytes of the generated summary, if any.
    pub summary_length: usize,
    /// Error message if compaction failed.
    pub error: Option<String>,
}

/// Monotonic source of session ids.
static SESSION_COUNTER: AtomicI64 = AtomicI64::new(0);

/// A single LLM conversation session: owns the conversation history, the
/// request/response channels, and optional compaction.
pub struct LLMSession {
    /// Unique session id.
    id: AtomicI64,

    /// Client for the configured provider.
    client: LLMClient,

    /// Inbound request channel (sender and guarded receiver) and the
    /// outbound response channel.
    to_llm_tx: mpsc::Sender<ToLLMPayload>,
    to_llm_rx: Mutex<mpsc::Receiver<ToLLMPayload>>,
    from_llm: mpsc::Sender<FromLLMPayload>,

    /// Session configuration.
    config: LLMSessionConfig,

    /// Optional system prompt prepended to every request.
    system_prompt: RwLock<Option<String>>,
    max_tokens: AtomicI64,
    created_at: Instant,

    /// Conversation history, shared copy-on-write.
    conversation: RwLock<Arc<Vec<Message>>>,

    /// Shutdown flag and session-wide cancellation token.
    shutdown: AtomicBool,
    cancel_token: CancellationToken,

    /// Cancellation token for the in-flight request, if any.
    current_cancel: Mutex<Option<CancellationToken>>,

    /// Turn id of the in-flight request, if any.
    current_turn_id: RwLock<Option<TurnId>>,

    /// Token counts reported by the most recent API response.
    current_input_tokens: AtomicI64,
    current_output_tokens: AtomicI64,

    /// Number of requests handled so far.
    request_count: AtomicI64,

    /// Tool definitions advertised to the model.
    tool_definitions: RwLock<Vec<LLMTool>>,

    /// Optional compactors and the summaries they draw from.
    compactor: Option<Box<dyn Compactor>>,
    llm_compactor: Option<LLMCompactor>,
    context_limit: AtomicI32,
    compact_summaries: RwLock<HashMap<String, String>>,
}

impl LLMSession {
    /// Creates a new session from `config`, wiring the request channel with
    /// `channel_size` capacity and building the configured compactor, if any.
    pub fn new(
        config: LLMSessionConfig,
        from_llm: mpsc::Sender<FromLLMPayload>,
        cancel_token: CancellationToken,
        channel_size: usize,
    ) -> Result<Self, LlmError> {
        let session_id = SESSION_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
        let (to_llm_tx, to_llm_rx) = mpsc::channel(channel_size);
        let max_tokens = config.max_tokens.unwrap_or(4096) as i64;
        let system_prompt = config.system_prompt.clone();

        let client = create_llm_client(&config)?;

        let mut compactor: Option<Box<dyn Compactor>> = None;
        let mut llm_compactor: Option<LLMCompactor> = None;

        if let Some(ref compactor_type) = config.compaction {
            match compactor_type {
                CompactorType::Threshold(c) => {
                    match ThresholdCompactor::new(
                        c.threshold,
                        c.keep_recent_turns,
                        c.tool_compaction,
                    ) {
                        Ok(tc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                tool_compaction = %c.tool_compaction,
                                "Threshold compaction enabled for session"
                            );
                            compactor = Some(Box::new(tc) as Box<dyn Compactor>);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create threshold compactor, compaction disabled");
                        }
                    }
                }
                CompactorType::LLM(c) => {
                    // The LLM compactor gets its own client so summarization
                    // requests do not contend with the session's client.
                    let llm_client = create_llm_client(&config)?;

                    match LLMCompactor::new(llm_client, c.clone()) {
                        Ok(lc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                "LLM compaction enabled for session"
                            );
                            llm_compactor = Some(lc);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create LLM compactor, compaction disabled");
                        }
                    }
                }
            }
        }

        let context_limit = config.context_limit;

        Ok(Self {
            id: AtomicI64::new(session_id),
            client,
            to_llm_tx,
            to_llm_rx: Mutex::new(to_llm_rx),
            from_llm,
            config,
            system_prompt: RwLock::new(system_prompt),
            max_tokens: AtomicI64::new(max_tokens),
            created_at: Instant::now(),
            conversation: RwLock::new(Arc::new(Vec::new())),
            shutdown: AtomicBool::new(false),
            cancel_token,
            current_cancel: Mutex::new(None),
            current_turn_id: RwLock::new(None),
            current_input_tokens: AtomicI64::new(0),
            current_output_tokens: AtomicI64::new(0),
            request_count: AtomicI64::new(0),
            tool_definitions: RwLock::new(Vec::new()),
            compactor,
            llm_compactor,
            context_limit: AtomicI32::new(context_limit),
            compact_summaries: RwLock::new(HashMap::new()),
        })
    }
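
    // Illustrative sketch of constructing a session (caller side; the channel
    // capacity and variable names are hypothetical):
    //
    //     let (from_llm_tx, mut from_llm_rx) = mpsc::channel(64);
    //     let session = LLMSession::new(config, from_llm_tx, CancellationToken::new(), 64)?;
    //     assert!(session.id() > 0);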

    /// Returns the session id.
    pub fn id(&self) -> i64 {
        self.id.load(Ordering::SeqCst)
    }

    /// Returns when the session was created.
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        &self.config.model
    }

    /// Sets the per-request output token cap.
    pub fn set_max_tokens(&self, max_tokens: i64) {
        self.max_tokens.store(max_tokens, Ordering::SeqCst);
    }

    /// Returns the per-request output token cap.
    pub fn max_tokens(&self) -> i64 {
        self.max_tokens.load(Ordering::SeqCst)
    }

    /// Returns the context window limit in tokens.
    pub fn context_limit(&self) -> i32 {
        self.context_limit.load(Ordering::SeqCst)
    }

    /// Replaces the system prompt.
    pub async fn set_system_prompt(&self, prompt: String) {
        let mut guard = self.system_prompt.write().await;
        *guard = Some(prompt);
    }

    /// Clears the system prompt.
    pub async fn clear_system_prompt(&self) {
        let mut guard = self.system_prompt.write().await;
        *guard = None;
    }

    /// Returns the current system prompt, if any.
    pub async fn system_prompt(&self) -> Option<String> {
        self.system_prompt.read().await.clone()
    }

    /// Replaces the tool definitions advertised to the model.
    pub async fn set_tools(&self, tools: Vec<LLMTool>) {
        let mut guard = self.tool_definitions.write().await;
        *guard = tools;
    }

    /// Removes all tool definitions.
    pub async fn clear_tools(&self) {
        let mut guard = self.tool_definitions.write().await;
        guard.clear();
    }

    /// Returns the current tool definitions.
    pub async fn tools(&self) -> Vec<LLMTool> {
        self.tool_definitions.read().await.clone()
    }

    /// Stores tool-result summaries so a later compaction pass can substitute
    /// them for the full tool output.
    async fn store_compact_summaries(&self, summaries: &HashMap<String, String>) {
        if summaries.is_empty() {
            tracing::warn!(
                session_id = self.id(),
                "No compact summaries provided with tool results"
            );
            return;
        }
        let mut guard = self.compact_summaries.write().await;
        for (tool_use_id, summary) in summaries {
            tracing::info!(
                session_id = self.id(),
                tool_use_id = %tool_use_id,
                summary_len = summary.len(),
                summary_preview = %summary.chars().take(50).collect::<String>(),
                "Storing compact summary"
            );
            guard.insert(tool_use_id.clone(), summary.clone());
        }
        tracing::info!(
            session_id = self.id(),
            new_summaries = summaries.len(),
            total_stored = guard.len(),
            "Stored compact summaries for tool results"
        );
    }

    /// Runs compaction if the configured compactor's threshold is crossed.
    /// The LLM compactor takes precedence over the threshold compactor.
    async fn maybe_compact(&self) {
        let context_used = self.current_input_tokens.load(Ordering::SeqCst);
        let context_limit = self.context_limit.load(Ordering::SeqCst);
        let conversation_len = self.conversation.read().await.len();
        let summaries_count = self.compact_summaries.read().await.len();

        let utilization = if context_limit > 0 {
            context_used as f64 / context_limit as f64
        } else {
            0.0
        };

        tracing::debug!(
            session_id = self.id(),
            context_used,
            context_limit,
            utilization = format!("{:.2}%", utilization * 100.0),
            conversation_len,
            summaries_available = summaries_count,
            "Checking if compaction needed"
        );

        if let Some(ref llm_compactor) = self.llm_compactor {
            if !llm_compactor.should_compact(context_used, context_limit) {
                tracing::debug!(session_id = self.id(), "LLM compaction not triggered");
                return;
            }

            let summaries = self.compact_summaries.read().await.clone();
            // Clone the Arc so the read lock is released before the
            // potentially slow LLM call; fall back to a deep clone if the
            // history is shared elsewhere.
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard)
            };
            let conversation =
                Arc::try_unwrap(conversation_arc).unwrap_or_else(|arc| (*arc).clone());

            tracing::info!(
                session_id = self.id(),
                conversation_len = conversation.len(),
                summaries_count = summaries.len(),
                "Starting LLM compaction"
            );

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            "LLM compaction completed"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "LLM compaction failed"
                    );
                }
            }
            return;
        }

        let compactor = match &self.compactor {
            Some(c) => c,
            None => {
                tracing::debug!(session_id = self.id(), "No compactor configured");
                return;
            }
        };

        if !compactor.should_compact(context_used, context_limit) {
            tracing::debug!(session_id = self.id(), "Threshold compaction not triggered");
            return;
        }

        let summaries = self.compact_summaries.read().await.clone();
        let mut guard = self.conversation.write().await;

        tracing::info!(
            session_id = self.id(),
            conversation_len = guard.len(),
            summaries_count = summaries.len(),
            "Starting threshold compaction"
        );

        let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

        tracing::info!(
            session_id = self.id(),
            tool_results_summarized = result.tool_results_summarized,
            tool_results_redacted = result.tool_results_redacted,
            turns_compacted = result.turns_compacted,
            conversation_len_after = guard.len(),
            "Threshold compaction completed"
        );
    }
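
    // For reference, the trigger both compactors are expected to apply,
    // sketched under the assumption that `should_compact` compares context
    // utilization against a configured threshold (the real check lives in
    // the compactor implementations):
    //
    //     fn should_compact(used: i64, limit: i32, threshold: f64) -> bool {
    //         limit > 0 && (used as f64 / limit as f64) >= threshold
    //     }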

    /// Clears the conversation history, stored summaries, and token counters.
    pub async fn clear_conversation(&self) {
        let mut guard = self.conversation.write().await;
        Arc::make_mut(&mut *guard).clear();

        let mut summaries = self.compact_summaries.write().await;
        summaries.clear();

        self.current_input_tokens.store(0, Ordering::SeqCst);
        self.current_output_tokens.store(0, Ordering::SeqCst);

        tracing::info!(session_id = self.id(), "Conversation cleared");
    }

    /// Runs compaction immediately, regardless of thresholds, and reports
    /// what changed.
    pub async fn force_compact(&self) -> CompactResult {
        if let Some(ref llm_compactor) = self.llm_compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard)
            };
            let conversation =
                Arc::try_unwrap(conversation_arc).unwrap_or_else(|arc| (*arc).clone());
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    let messages_after = new_conversation.len();
                    let turns_after = self.count_unique_turns(&new_conversation);
                    let compacted = messages_after < messages_before;

                    let summary_length = if compacted && !new_conversation.is_empty() {
                        self.extract_summary_length(&new_conversation[0])
                    } else {
                        0
                    };

                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            messages_before,
                            messages_after,
                            "Forced LLM compaction completed"
                        );
                    }

                    return CompactResult {
                        compacted,
                        messages_before,
                        messages_after,
                        turns_compacted: turns_before.saturating_sub(turns_after),
                        turns_kept: turns_after,
                        summary_length,
                        error: None,
                    };
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "Forced LLM compaction failed"
                    );
                    return CompactResult {
                        compacted: false,
                        messages_before,
                        messages_after: messages_before,
                        turns_compacted: 0,
                        turns_kept: turns_before,
                        summary_length: 0,
                        error: Some(format!("Compaction failed: {}", e)),
                    };
                }
            }
        }

        if let Some(ref compactor) = self.compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let mut guard = self.conversation.write().await;
            let messages_before = guard.len();
            let turns_before = self.count_unique_turns(&guard);

            let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

            let messages_after = guard.len();
            let turns_after = self.count_unique_turns(&guard);
            let compacted = result.turns_compacted > 0 || result.total_compacted() > 0;

            if result.total_compacted() > 0 {
                tracing::info!(
                    session_id = self.id(),
                    tool_results_summarized = result.tool_results_summarized,
                    tool_results_redacted = result.tool_results_redacted,
                    turns_compacted = result.turns_compacted,
                    "Forced threshold compaction completed"
                );
            }

            return CompactResult {
                compacted,
                messages_before,
                messages_after,
                turns_compacted: turns_before.saturating_sub(turns_after),
                turns_kept: turns_after,
                summary_length: 0,
                error: None,
            };
        }

        CompactResult {
            compacted: false,
            error: Some("No compactor configured".to_string()),
            ..Default::default()
        }
    }
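
    // Illustrative caller-side use of `force_compact` (hypothetical names):
    //
    //     let result = session.force_compact().await;
    //     match result.error {
    //         Some(err) => eprintln!("compaction failed: {err}"),
    //         None if result.compacted => {
    //             println!("{} -> {} messages", result.messages_before, result.messages_after)
    //         }
    //         None => println!("nothing to compact"),
    //     }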

    /// Counts distinct turn ids in `conversation`.
    fn count_unique_turns(&self, conversation: &[Message]) -> usize {
        use std::collections::HashSet;
        let mut turn_ids = HashSet::new();
        for msg in conversation {
            turn_ids.insert(msg.turn_id().clone());
        }
        turn_ids.len()
    }

    /// Returns the length of a leading compaction summary block, or 0 if
    /// `message` is not a summary.
    fn extract_summary_length(&self, message: &Message) -> usize {
        if let Message::User(user_msg) = message {
            for block in &user_msg.content {
                if let ContentBlock::Text(text_block) = block
                    && text_block
                        .text
                        .starts_with("[Previous conversation summary]")
                {
                    return text_block.text.len();
                }
            }
        }
        0
    }

    /// Enqueues a request for the session loop. Returns `false` if the
    /// session is shut down or the channel is closed.
    pub async fn send(&self, msg: ToLLMPayload) -> bool {
        if self.shutdown.load(Ordering::SeqCst) {
            return false;
        }
        self.to_llm_tx.send(msg).await.is_ok()
    }

    /// Cancels the in-flight request, if any, and removes the cancelled
    /// turn's messages from the conversation history.
    pub async fn interrupt(&self) {
        let guard = self.current_cancel.lock().await;
        if let Some(token) = guard.as_ref() {
            token.cancel();

            let turn_id = self.current_turn_id.read().await.clone();
            if let Some(turn_id) = turn_id {
                let mut guard = self.conversation.write().await;
                let original_len = guard.len();
                Arc::make_mut(&mut *guard).retain(|msg| msg.turn_id() != &turn_id);
                let removed = original_len - guard.len();
                tracing::debug!(
                    session_id = self.id(),
                    turn_id = %turn_id,
                    messages_removed = removed,
                    conversation_length = guard.len(),
                    "Removed messages from cancelled turn"
                );
            }
        }
    }

    /// Marks the session as shut down and cancels its run loop.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
        self.cancel_token.cancel();
    }

    /// Returns whether `shutdown` has been called.
    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

    /// Runs the session loop: receives requests from the channel and handles
    /// them until the channel closes or the session is cancelled.
    pub async fn start(&self) {
        tracing::info!(session_id = self.id(), "Session starting");

        loop {
            let mut rx_guard = self.to_llm_rx.lock().await;

            tokio::select! {
                _ = self.cancel_token.cancelled() => {
                    tracing::info!(session_id = self.id(), "Session cancelled");
                    break;
                }
                msg = rx_guard.recv() => {
                    match msg {
                        Some(request) => {
                            // Release the receiver lock before the potentially
                            // long-running request.
                            drop(rx_guard);
                            self.handle_request(request).await;
                        }
                        None => {
                            tracing::info!(session_id = self.id(), "Session channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::info!(session_id = self.id(), "Session stopped");
    }
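
    // Illustrative sketch of driving the loop (caller side, hypothetical
    // names; assumes the session is shared behind an `Arc`):
    //
    //     let session = Arc::new(LLMSession::new(config, tx, cancel, 64)?);
    //     let runner = Arc::clone(&session);
    //     tokio::spawn(async move { runner.start().await });
    //     session.send(payload).await; // enqueue a request
    //     session.shutdown();          // later: stop the loop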

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn current_timestamp_millis() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as i64)
            .unwrap_or(0)
    }

    /// Installs a per-request cancellation token and records the effective
    /// turn id for the request.
    async fn prepare_request(&self, request: &ToLLMPayload) -> (CancellationToken, TurnId) {
        let request_token = CancellationToken::new();
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = Some(request_token.clone());
        }

        let effective_turn_id = request
            .turn_id
            .clone()
            .unwrap_or_else(|| TurnId::new_user_turn(0));
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = Some(effective_turn_id.clone());
        }

        (request_token, effective_turn_id)
    }

    /// Builds the per-request options from current session state.
    async fn build_message_options(&self) -> crate::client::models::MessageOptions {
        use crate::client::models::MessageOptions;

        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
        let tools = self.tool_definitions.read().await.clone();
        let tools_option = if tools.is_empty() { None } else { Some(tools) };

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature: self.config.temperature,
            tools: tools_option,
            ..Default::default()
        }
    }

    /// Clears the per-request cancellation token and turn id.
    async fn cleanup_request(&self) {
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = None;
        }
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = None;
        }
    }

    /// Dispatches a request to the streaming or non-streaming path.
    async fn handle_request(&self, request: ToLLMPayload) {
        if self.config.streaming {
            self.handle_streaming_request(request).await;
        } else {
            self.handle_non_streaming_request(request).await;
        }
    }

    async fn handle_non_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::{from_llm_message, to_llm_messages};
        use crate::client::models::Message as LLMMessage;
        use crate::controller::types::{LLMRequestType, LLMResponseType};

        let (_request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries)
                    .await;

                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(
                            crate::controller::types::ToolResultBlock {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                content: tool_result.content.clone(),
                                is_error: tool_result.is_error,
                                compact_summary,
                            },
                        )],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let options = self.build_message_options().await;

        let result = self.client.send_message(&llm_messages, &options).await;

        match result {
            Ok(response) => {
                let content_blocks = from_llm_message(&response);

                // Concatenate all text blocks into a single chunk for the caller.
                let text: String = content_blocks
                    .iter()
                    .filter_map(|block| {
                        if let ContentBlock::Text(t) = block {
                            Some(t.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("");

                if !text.is_empty() {
                    let payload = FromLLMPayload {
                        session_id,
                        response_type: LLMResponseType::TextChunk,
                        text: text.clone(),
                        turn_id: request.turn_id.clone(),
                        ..Default::default()
                    };
                    let _ = self.from_llm.send(payload).await;
                }

                for block in &content_blocks {
                    if let ContentBlock::ToolUse(tool_use) = block {
                        let payload = FromLLMPayload {
                            session_id,
                            response_type: LLMResponseType::ToolUse,
                            tool_use: Some(crate::controller::types::ToolUseInfo {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                input: serde_json::to_value(&tool_use.input).unwrap_or_default(),
                            }),
                            turn_id: request.turn_id.clone(),
                            ..Default::default()
                        };
                        let _ = self.from_llm.send(payload).await;
                    }
                }

                let now = Self::current_timestamp_millis();
                let asst_msg = Message::Assistant(AssistantMessage {
                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                    session_id: session_id.to_string(),
                    turn_id: effective_turn_id.clone(),
                    parent_id: String::new(),
                    created_at: now,
                    completed_at: Some(now),
                    model_id: self.config.model.clone(),
                    provider_id: String::new(),
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    finish_reason: None,
                    error: None,
                    content: content_blocks,
                });
                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Complete,
                    is_complete: true,
                    turn_id: request.turn_id.clone(),
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;

                self.request_count.fetch_add(1, Ordering::SeqCst);

                tracing::debug!(session_id, "Request completed successfully");
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "LLM request failed");

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        self.cleanup_request().await;
    }

    async fn handle_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::to_llm_messages;
        use crate::client::models::{ContentBlockType, Message as LLMMessage, StreamEvent};
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use futures::StreamExt;

        let (request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling streaming request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries)
                    .await;

                {
                    let conv = self.conversation.read().await;
                    tracing::debug!(
                        session_id,
                        conversation_len = conv.len(),
                        tool_result_count = request.tool_results.len(),
                        "STREAMING ToolResult: conversation state before adding results"
                    );
                }
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(
                            crate::controller::types::ToolResultBlock {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                content: tool_result.content.clone(),
                                is_error: tool_result.is_error,
                                compact_summary,
                            },
                        )],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let options = self.build_message_options().await;

        let stream_result = self
            .client
            .send_message_stream(&llm_messages, &options)
            .await;

        match stream_result {
            Ok(mut stream) => {
                // Accumulators for the in-progress tool call and response text.
                let mut current_tool_id: Option<String> = None;
                let mut current_tool_name: Option<String> = None;
                let mut tool_input_json = String::new();
                let mut response_text = String::new();
                let mut completed_tool_uses: Vec<crate::controller::types::ToolUseBlock> =
                    Vec::new();

                loop {
                    tokio::select! {
                        _ = request_token.cancelled() => {
                            tracing::info!(session_id, "Streaming request cancelled");
                            break;
                        }
                        event = stream.next() => {
                            match event {
                                Some(Ok(stream_event)) => {
                                    match stream_event {
                                        StreamEvent::MessageStart { message_id, model } => {
                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::StreamStart,
                                                message_id,
                                                model,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStart { index: _, block_type } => {
                                            match block_type {
                                                ContentBlockType::Text => {}
                                                ContentBlockType::ToolUse { id, name } => {
                                                    current_tool_id = Some(id);
                                                    current_tool_name = Some(name);
                                                    tool_input_json.clear();
                                                }
                                            }
                                        }
                                        StreamEvent::TextDelta { index, text } => {
                                            response_text.push_str(&text);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::TextChunk,
                                                text,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::InputJsonDelta { index, json } => {
                                            tool_input_json.push_str(&json);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::ToolInputDelta,
                                                text: json,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStop { index: _ } => {
                                            if let (Some(id), Some(name)) =
                                                (current_tool_id.take(), current_tool_name.take())
                                            {
                                                // Parse the accumulated JSON, falling back to an
                                                // empty object on malformed input.
                                                let input: serde_json::Value =
                                                    serde_json::from_str(&tool_input_json)
                                                        .unwrap_or(serde_json::Value::Object(
                                                            serde_json::Map::new(),
                                                        ));

                                                tracing::debug!(
                                                    session_id,
                                                    tool_id = %id,
                                                    tool_name = %name,
                                                    "Saving tool use to completed_tool_uses"
                                                );
                                                completed_tool_uses.push(crate::controller::types::ToolUseBlock {
                                                    id: id.clone(),
                                                    name: name.clone(),
                                                    input: input
                                                        .as_object()
                                                        .map(|obj| {
                                                            obj.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect()
                                                        })
                                                        .unwrap_or_default(),
                                                });

                                                tool_input_json.clear();
                                            }
                                        }
                                        StreamEvent::MessageDelta { stop_reason, usage } => {
                                            if let Some(usage) = usage {
                                                tracing::info!(
                                                    session_id,
                                                    input_tokens = usage.input_tokens,
                                                    output_tokens = usage.output_tokens,
                                                    "API token usage for this turn"
                                                );
                                                self.current_input_tokens
                                                    .store(usage.input_tokens as i64, Ordering::SeqCst);
                                                self.current_output_tokens
                                                    .store(usage.output_tokens as i64, Ordering::SeqCst);

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::TokenUpdate,
                                                    input_tokens: usage.input_tokens as i64,
                                                    output_tokens: usage.output_tokens as i64,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            if stop_reason.is_some() {
                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::Complete,
                                                    is_complete: true,
                                                    stop_reason,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }
                                        }
                                        StreamEvent::MessageStop => {
                                            tracing::debug!(
                                                session_id,
                                                text_len = response_text.len(),
                                                tool_use_count = completed_tool_uses.len(),
                                                "MessageStop: saving assistant message to history"
                                            );
                                            if !response_text.is_empty() || !completed_tool_uses.is_empty() {
                                                let now = Self::current_timestamp_millis();

                                                let mut content_blocks = Vec::new();
                                                if !response_text.is_empty() {
                                                    content_blocks.push(ContentBlock::text(&response_text));
                                                }
                                                for tool_use in &completed_tool_uses {
                                                    content_blocks.push(ContentBlock::ToolUse(tool_use.clone()));
                                                }

                                                let content_block_count = content_blocks.len();
                                                let asst_msg = Message::Assistant(AssistantMessage {
                                                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                                                    session_id: session_id.to_string(),
                                                    turn_id: effective_turn_id.clone(),
                                                    parent_id: String::new(),
                                                    created_at: now,
                                                    completed_at: Some(now),
                                                    model_id: self.config.model.clone(),
                                                    provider_id: String::new(),
                                                    input_tokens: self.current_input_tokens.load(Ordering::SeqCst),
                                                    output_tokens: self.current_output_tokens.load(Ordering::SeqCst),
                                                    cache_read_tokens: 0,
                                                    cache_write_tokens: 0,
                                                    finish_reason: None,
                                                    error: None,
                                                    content: content_blocks,
                                                });
                                                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);
                                                tracing::debug!(
                                                    session_id,
                                                    content_block_count,
                                                    "MessageStop: saved assistant message with content blocks"
                                                );
                                            }

                                            if !completed_tool_uses.is_empty() {
                                                let tool_uses: Vec<crate::controller::types::ToolUseInfo> = completed_tool_uses
                                                    .iter()
                                                    .map(|tu| crate::controller::types::ToolUseInfo {
                                                        id: tu.id.clone(),
                                                        name: tu.name.clone(),
                                                        input: serde_json::Value::Object(
                                                            tu.input.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect(),
                                                        ),
                                                    })
                                                    .collect();

                                                tracing::debug!(
                                                    session_id,
                                                    tool_count = tool_uses.len(),
                                                    "MessageStop: emitting ToolBatch for execution"
                                                );

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::ToolBatch,
                                                    tool_uses,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            self.request_count.fetch_add(1, Ordering::SeqCst);
                                            tracing::debug!(session_id, "Streaming request completed");
                                            break;
                                        }
                                        StreamEvent::Ping => {}
                                    }
                                }
                                Some(Err(err)) => {
                                    tracing::error!(session_id, error = %err, "Stream error");
                                    let payload = FromLLMPayload {
                                        session_id,
                                        response_type: LLMResponseType::Error,
                                        error: Some(err.to_string()),
                                        turn_id: request.turn_id.clone(),
                                        ..Default::default()
                                    };
                                    let _ = self.from_llm.send(payload).await;
                                    break;
                                }
                                None => {
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "Failed to start streaming");
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        self.cleanup_request().await;
    }
}