use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::Instant;

use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;

use crate::client::error::LlmError;
use crate::client::models::Tool as LLMTool;
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;
use crate::controller::types::{
    AssistantMessage, ContentBlock, FromLLMPayload, Message, ToLLMPayload, TurnId, UserMessage,
};

use super::compactor::{AsyncCompactor, Compactor, LLMCompactor, ThresholdCompactor};
use super::config::{CompactorType, LLMProvider, LLMSessionConfig};

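/// Builds the [`LLMClient`] for the provider selected in `config`.
///
/// Selection, as implemented below: Anthropic, Google, and Cohere need only
/// an API key and model; OpenAI prefers Azure (`azure_resource` +
/// `azure_deployment`) over a custom `base_url` over the default endpoint;
/// Bedrock requires a region and static AWS credentials, with an optional
/// session token.
///
/// A hypothetical call site (assumes `LLMSessionConfig` implements `Default`,
/// which this module does not itself guarantee):
///
/// ```ignore
/// let config = LLMSessionConfig {
///     provider: LLMProvider::Anthropic,
///     model: "model-id".to_string(),
///     ..Default::default()
/// };
/// let client = create_llm_client(&config)?;
/// ```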
fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
    match config.provider {
        LLMProvider::Anthropic => {
            let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::OpenAI => {
            let provider = if let (Some(resource), Some(deployment)) =
                (&config.azure_resource, &config.azure_deployment)
            {
                let api_version = config
                    .azure_api_version
                    .clone()
                    .unwrap_or_else(|| "2024-10-21".to_string());
                OpenAIProvider::azure(
                    config.api_key.clone(),
                    resource.clone(),
                    deployment.clone(),
                    api_version,
                )
            } else if let Some(base_url) = &config.base_url {
                OpenAIProvider::with_base_url(
                    config.api_key.clone(),
                    config.model.clone(),
                    base_url.clone(),
                )
            } else {
                OpenAIProvider::new(config.api_key.clone(), config.model.clone())
            };
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Google => {
            let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Cohere => {
            let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::Bedrock => {
            let region = config.bedrock_region.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_region")
            })?;
            let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_access_key_id")
            })?;
            let secret_access_key = config.bedrock_secret_access_key.clone().ok_or_else(|| {
                LlmError::new("MISSING_CONFIG", "Bedrock requires bedrock_secret_access_key")
            })?;

            let credentials = match &config.bedrock_session_token {
                Some(token) => BedrockCredentials::with_session_token(
                    access_key_id,
                    secret_access_key,
                    token.clone(),
                ),
                None => BedrockCredentials::new(access_key_id, secret_access_key),
            };

            let provider = BedrockProvider::new(credentials, region, config.model.clone());
            LLMClient::new(Box::new(provider))
        }
    }
}

/// Cumulative token accounting across all requests in a session.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Total input tokens consumed across all requests.
    pub total_input_tokens: i64,
    /// Total output tokens produced across all requests.
    pub total_output_tokens: i64,
    /// Number of requests sent so far.
    pub request_count: i64,
    /// Input tokens reported for the most recent request.
    pub last_input_tokens: i64,
    /// Output tokens reported for the most recent request.
    pub last_output_tokens: i64,
}

/// Point-in-time snapshot of a session's state and token usage.
#[derive(Debug, Clone)]
pub struct SessionStatus {
    pub session_id: i64,
    pub model: String,
    pub created_at: Instant,
    pub conversation_len: usize,
    /// Tokens currently counted against the context window.
    pub context_used: i64,
    /// Configured context window size, in tokens.
    pub context_limit: i32,
    /// Fraction of the context window in use (0.0 when the limit is unset).
    pub utilization: f64,
    pub total_input: i64,
    pub total_output: i64,
    pub request_count: i64,
}

/// Outcome of a compaction pass, for reporting back to callers.
#[derive(Debug, Clone, Default)]
pub struct CompactResult {
    /// Whether any compaction actually occurred.
    pub compacted: bool,
    pub messages_before: usize,
    pub messages_after: usize,
    pub turns_compacted: usize,
    pub turns_kept: usize,
    /// Length of the generated summary text, if any.
    pub summary_length: usize,
    /// Error message if compaction was attempted but failed.
    pub error: Option<String>,
}

/// Monotonic source of unique session ids; the first session gets id 1.
static SESSION_COUNTER: AtomicI64 = AtomicI64::new(0);

/// A single LLM conversation session: owns the conversation history, the
/// provider client, the channels to and from the controller, and optional
/// compaction state.
pub struct LLMSession {
    id: AtomicI64,

    client: LLMClient,

    to_llm_tx: mpsc::Sender<ToLLMPayload>,
    to_llm_rx: Mutex<mpsc::Receiver<ToLLMPayload>>,
    from_llm: mpsc::Sender<FromLLMPayload>,

    config: LLMSessionConfig,

    system_prompt: RwLock<Option<String>>,
    max_tokens: AtomicI64,
    created_at: Instant,

    conversation: RwLock<Arc<Vec<Message>>>,

    /// Set once on shutdown; checked before accepting new requests.
    shutdown: AtomicBool,
    /// Session-wide cancellation; fires when the whole session shuts down.
    cancel_token: CancellationToken,

    /// Cancellation token for the in-flight request, if any.
    current_cancel: Mutex<Option<CancellationToken>>,

    current_turn_id: RwLock<Option<TurnId>>,

    /// Token counts reported by the most recent API response.
    current_input_tokens: AtomicI64,
    current_output_tokens: AtomicI64,

    request_count: AtomicI64,

    tool_definitions: RwLock<Vec<LLMTool>>,

    /// Threshold-based compactor (used only when no LLM compactor is set).
    compactor: Option<Box<dyn Compactor>>,
    /// LLM-backed compactor; takes precedence over `compactor`.
    llm_compactor: Option<LLMCompactor>,
    context_limit: AtomicI32,
    /// Compact summaries keyed by `tool_use_id`, supplied with tool results.
    compact_summaries: RwLock<HashMap<String, String>>,
}

impl LLMSession {
    /// Creates a session, assigning the next id from `SESSION_COUNTER` and
    /// wiring up compaction according to `config.compaction`.
    pub fn new(
        config: LLMSessionConfig,
        from_llm: mpsc::Sender<FromLLMPayload>,
        cancel_token: CancellationToken,
        channel_size: usize,
    ) -> Result<Self, LlmError> {
        let session_id = SESSION_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
        let (to_llm_tx, to_llm_rx) = mpsc::channel(channel_size);
        let max_tokens = config.max_tokens.unwrap_or(4096) as i64;
        let system_prompt = config.system_prompt.clone();

        let client = create_llm_client(&config)?;

        let mut compactor: Option<Box<dyn Compactor>> = None;
        let mut llm_compactor: Option<LLMCompactor> = None;

        if let Some(ref compactor_type) = config.compaction {
            match compactor_type {
                CompactorType::Threshold(c) => {
                    match ThresholdCompactor::new(c.threshold, c.keep_recent_turns, c.tool_compaction)
                    {
                        Ok(tc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                tool_compaction = %c.tool_compaction,
                                "Threshold compaction enabled for session"
                            );
                            compactor = Some(Box::new(tc) as Box<dyn Compactor>);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create threshold compactor, compaction disabled");
                        }
                    }
                }
                CompactorType::LLM(c) => {
                    // Build a dedicated client for the compactor's own
                    // summarization requests.
                    let llm_client = create_llm_client(&config)?;

                    match LLMCompactor::new(llm_client, c.clone()) {
                        Ok(lc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                "LLM compaction enabled for session"
                            );
                            llm_compactor = Some(lc);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create LLM compactor, compaction disabled");
                        }
                    }
                }
            }
        }

        let context_limit = config.context_limit;

        Ok(Self {
            id: AtomicI64::new(session_id),
            client,
            to_llm_tx,
            to_llm_rx: Mutex::new(to_llm_rx),
            from_llm,
            config,
            system_prompt: RwLock::new(system_prompt),
            max_tokens: AtomicI64::new(max_tokens),
            created_at: Instant::now(),
            conversation: RwLock::new(Arc::new(Vec::new())),
            shutdown: AtomicBool::new(false),
            cancel_token,
            current_cancel: Mutex::new(None),
            current_turn_id: RwLock::new(None),
            current_input_tokens: AtomicI64::new(0),
            current_output_tokens: AtomicI64::new(0),
            request_count: AtomicI64::new(0),
            tool_definitions: RwLock::new(Vec::new()),
            compactor,
            llm_compactor,
            context_limit: AtomicI32::new(context_limit),
            compact_summaries: RwLock::new(HashMap::new()),
        })
    }

    /// Returns this session's unique id.
    pub fn id(&self) -> i64 {
        self.id.load(Ordering::SeqCst)
    }

    /// Returns when this session was created.
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        &self.config.model
    }

    /// Sets the per-request `max_tokens` limit.
    pub fn set_max_tokens(&self, max_tokens: i64) {
        self.max_tokens.store(max_tokens, Ordering::SeqCst);
    }

    pub fn max_tokens(&self) -> i64 {
        self.max_tokens.load(Ordering::SeqCst)
    }

    pub fn context_limit(&self) -> i32 {
        self.context_limit.load(Ordering::SeqCst)
    }

    /// Replaces the system prompt used for subsequent requests.
    pub async fn set_system_prompt(&self, prompt: String) {
        let mut guard = self.system_prompt.write().await;
        *guard = Some(prompt);
    }

    pub async fn clear_system_prompt(&self) {
        let mut guard = self.system_prompt.write().await;
        *guard = None;
    }

    pub async fn system_prompt(&self) -> Option<String> {
        self.system_prompt.read().await.clone()
    }

    /// Replaces the tool definitions advertised to the model.
    pub async fn set_tools(&self, tools: Vec<LLMTool>) {
        let mut guard = self.tool_definitions.write().await;
        *guard = tools;
    }

    pub async fn clear_tools(&self) {
        let mut guard = self.tool_definitions.write().await;
        guard.clear();
    }

    pub async fn tools(&self) -> Vec<LLMTool> {
        self.tool_definitions.read().await.clone()
    }

    /// Records compact summaries (keyed by `tool_use_id`) supplied alongside
    /// tool results, for later use during compaction.
    async fn store_compact_summaries(&self, summaries: &HashMap<String, String>) {
        if summaries.is_empty() {
            tracing::warn!(
                session_id = self.id(),
                "No compact summaries provided with tool results"
            );
            return;
        }
        let mut guard = self.compact_summaries.write().await;
        for (tool_use_id, summary) in summaries {
            tracing::info!(
                session_id = self.id(),
                tool_use_id = %tool_use_id,
                summary_len = summary.len(),
                summary_preview = %summary.chars().take(50).collect::<String>(),
                "Storing compact summary"
            );
            guard.insert(tool_use_id.clone(), summary.clone());
        }
        tracing::info!(
            session_id = self.id(),
            new_summaries = summaries.len(),
            total_stored = guard.len(),
            "Stored compact summaries for tool results"
        );
    }

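    // Compaction strategy, as implemented in `maybe_compact` below: the
    // LLM-backed compactor, when configured, always takes precedence, and the
    // threshold compactor is consulted only in its absence. Both share the
    // same trigger signal: the input-token count from the latest response,
    // measured against the configured context limit.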
    /// Runs compaction if the configured compactor's threshold is exceeded.
    async fn maybe_compact(&self) {
        let context_used = self.current_input_tokens.load(Ordering::SeqCst);
        let context_limit = self.context_limit.load(Ordering::SeqCst);
        let conversation_len = self.conversation.read().await.len();
        let summaries_count = self.compact_summaries.read().await.len();

        let utilization = if context_limit > 0 {
            context_used as f64 / context_limit as f64
        } else {
            0.0
        };

        tracing::debug!(
            session_id = self.id(),
            context_used,
            context_limit,
            utilization = format!("{:.2}%", utilization * 100.0),
            conversation_len,
            summaries_available = summaries_count,
            "Checking if compaction needed"
        );

        if let Some(ref llm_compactor) = self.llm_compactor {
            if !llm_compactor.should_compact(context_used, context_limit) {
                tracing::debug!(session_id = self.id(), "LLM compaction not triggered");
                return;
            }

            let summaries = self.compact_summaries.read().await.clone();
            // Take an owned copy of the conversation so the lock is not held
            // across the compaction call; `try_unwrap` avoids the clone only
            // if this was somehow the last reference.
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard)
            };
            let conversation =
                Arc::try_unwrap(conversation_arc).unwrap_or_else(|arc| (*arc).clone());

            tracing::info!(
                session_id = self.id(),
                conversation_len = conversation.len(),
                summaries_count = summaries.len(),
                "Starting LLM compaction"
            );

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            "LLM compaction completed"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "LLM compaction failed"
                    );
                }
            }
            return;
        }

        let compactor = match &self.compactor {
            Some(c) => c,
            None => {
                tracing::debug!(session_id = self.id(), "No compactor configured");
                return;
            }
        };

        if !compactor.should_compact(context_used, context_limit) {
            tracing::debug!(session_id = self.id(), "Threshold compaction not triggered");
            return;
        }

        let summaries = self.compact_summaries.read().await.clone();
        let mut guard = self.conversation.write().await;

        tracing::info!(
            session_id = self.id(),
            conversation_len = guard.len(),
            summaries_count = summaries.len(),
            "Starting threshold compaction"
        );

        let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

        tracing::info!(
            session_id = self.id(),
            tool_results_summarized = result.tool_results_summarized,
            tool_results_redacted = result.tool_results_redacted,
            turns_compacted = result.turns_compacted,
            conversation_len_after = guard.len(),
            "Threshold compaction completed"
        );
    }

    /// Clears the conversation history, stored summaries, and token counters.
    pub async fn clear_conversation(&self) {
        let mut guard = self.conversation.write().await;
        Arc::make_mut(&mut *guard).clear();

        let mut summaries = self.compact_summaries.write().await;
        summaries.clear();

        self.current_input_tokens.store(0, Ordering::SeqCst);
        self.current_output_tokens.store(0, Ordering::SeqCst);

        tracing::info!(session_id = self.id(), "Conversation cleared");
    }

    /// Runs compaction immediately, regardless of threshold, and reports
    /// what changed.
    pub async fn force_compact(&self) -> CompactResult {
        if let Some(ref llm_compactor) = self.llm_compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let conversation_arc = {
                let guard = self.conversation.read().await;
                Arc::clone(&*guard)
            };
            let conversation =
                Arc::try_unwrap(conversation_arc).unwrap_or_else(|arc| (*arc).clone());
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    let messages_after = new_conversation.len();
                    let turns_after = self.count_unique_turns(&new_conversation);
                    let compacted = messages_after < messages_before;

                    let summary_length = if compacted && !new_conversation.is_empty() {
                        self.extract_summary_length(&new_conversation[0])
                    } else {
                        0
                    };

                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            messages_before,
                            messages_after,
                            "Forced LLM compaction completed"
                        );
                    }

                    return CompactResult {
                        compacted,
                        messages_before,
                        messages_after,
                        turns_compacted: turns_before.saturating_sub(turns_after),
                        turns_kept: turns_after,
                        summary_length,
                        error: None,
                    };
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "Forced LLM compaction failed"
                    );
                    return CompactResult {
                        compacted: false,
                        messages_before,
                        messages_after: messages_before,
                        turns_compacted: 0,
                        turns_kept: turns_before,
                        summary_length: 0,
                        error: Some(format!("Compaction failed: {}", e)),
                    };
                }
            }
        }

        if let Some(ref compactor) = self.compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let mut guard = self.conversation.write().await;
            let messages_before = guard.len();
            let turns_before = self.count_unique_turns(&guard);

            let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

            let messages_after = guard.len();
            let turns_after = self.count_unique_turns(&guard);
            let compacted = result.turns_compacted > 0 || result.total_compacted() > 0;

            if result.total_compacted() > 0 {
                tracing::info!(
                    session_id = self.id(),
                    tool_results_summarized = result.tool_results_summarized,
                    tool_results_redacted = result.tool_results_redacted,
                    turns_compacted = result.turns_compacted,
                    "Forced threshold compaction completed"
                );
            }

            return CompactResult {
                compacted,
                messages_before,
                messages_after,
                turns_compacted: turns_before.saturating_sub(turns_after),
                turns_kept: turns_after,
                summary_length: 0,
                error: None,
            };
        }

        CompactResult {
            compacted: false,
            error: Some("No compactor configured".to_string()),
            ..Default::default()
        }
    }

    /// Counts the distinct turn ids present in the conversation.
    fn count_unique_turns(&self, conversation: &[Message]) -> usize {
        use std::collections::HashSet;
        let mut turn_ids = HashSet::new();
        for msg in conversation {
            turn_ids.insert(msg.turn_id().clone());
        }
        turn_ids.len()
    }

    /// Returns the length of a leading summary block produced by compaction,
    /// or 0 if `message` is not a summary message.
    fn extract_summary_length(&self, message: &Message) -> usize {
        if let Message::User(user_msg) = message {
            for block in &user_msg.content {
                if let ContentBlock::Text(text_block) = block {
                    if text_block.text.starts_with("[Previous conversation summary]") {
                        return text_block.text.len();
                    }
                }
            }
        }
        0
    }

    /// Queues a request for this session. Returns `false` if the session has
    /// shut down or the channel is closed.
    pub async fn send(&self, msg: ToLLMPayload) -> bool {
        if self.shutdown.load(Ordering::SeqCst) {
            return false;
        }
        self.to_llm_tx.send(msg).await.is_ok()
    }

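    // Interrupting cancels only the in-flight request (via its per-request
    // token) and rolls the cancelled turn's messages back out of history, so
    // a half-finished exchange does not leak into the next request's context.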
    pub async fn interrupt(&self) {
        let guard = self.current_cancel.lock().await;
        if let Some(token) = guard.as_ref() {
            token.cancel();

            let turn_id = self.current_turn_id.read().await.clone();
            if let Some(turn_id) = turn_id {
                let mut guard = self.conversation.write().await;
                let original_len = guard.len();
                Arc::make_mut(&mut *guard).retain(|msg| msg.turn_id() != &turn_id);
                let removed = original_len - guard.len();
                tracing::debug!(
                    session_id = self.id(),
                    turn_id = %turn_id,
                    messages_removed = removed,
                    conversation_length = guard.len(),
                    "Removed messages from cancelled turn"
                );
            }
        }
    }

    /// Marks the session as shut down and fires the session-wide
    /// cancellation token.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
        self.cancel_token.cancel();
    }

    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

    /// Runs the session's request loop until the channel closes or the
    /// session-wide cancellation token fires.
    pub async fn start(&self) {
        tracing::info!(session_id = self.id(), "Session starting");

        loop {
            let mut rx_guard = self.to_llm_rx.lock().await;

            tokio::select! {
                _ = self.cancel_token.cancelled() => {
                    tracing::info!(session_id = self.id(), "Session cancelled");
                    break;
                }
                msg = rx_guard.recv() => {
                    match msg {
                        Some(request) => {
                            // Release the receiver lock while the (possibly
                            // long-running) request is handled.
                            drop(rx_guard);
                            self.handle_request(request).await;
                        }
                        None => {
                            tracing::info!(session_id = self.id(), "Session channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::info!(session_id = self.id(), "Session stopped");
    }

    /// Milliseconds since the Unix epoch, or 0 if the system clock is set
    /// before the epoch.
    fn current_timestamp_millis() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as i64)
            .unwrap_or(0)
    }

    /// Installs a per-request cancellation token and records the effective
    /// turn id, falling back to a fresh user turn when the request has none.
    async fn prepare_request(&self, request: &ToLLMPayload) -> (CancellationToken, TurnId) {
        let request_token = CancellationToken::new();
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = Some(request_token.clone());
        }

        let effective_turn_id = request
            .turn_id
            .clone()
            .unwrap_or_else(|| TurnId::new_user_turn(0));
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = Some(effective_turn_id.clone());
        }

        (request_token, effective_turn_id)
    }

    /// Assembles per-request options from the session's current settings.
    async fn build_message_options(&self) -> crate::client::models::MessageOptions {
        use crate::client::models::MessageOptions;

        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
        let tools = self.tool_definitions.read().await.clone();
        let tools_option = if tools.is_empty() { None } else { Some(tools) };

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature: self.config.temperature,
            tools: tools_option,
            ..Default::default()
        }
    }

    /// Clears the per-request cancellation token and turn id.
    async fn cleanup_request(&self) {
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = None;
        }
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = None;
        }
    }

    /// Dispatches a request to the streaming or non-streaming path based on
    /// the session's configuration.
    async fn handle_request(&self, request: ToLLMPayload) {
        if self.config.streaming {
            self.handle_streaming_request(request).await;
        } else {
            self.handle_non_streaming_request(request).await;
        }
    }

    async fn handle_non_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::{from_llm_message, to_llm_messages};
        use crate::client::models::Message as LLMMessage;
        use crate::controller::types::{LLMRequestType, LLMResponseType};

        let (_request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries).await;

                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(
                            crate::controller::types::ToolResultBlock {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                content: tool_result.content.clone(),
                                is_error: tool_result.is_error,
                                compact_summary,
                            },
                        )],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let options = self.build_message_options().await;

        let result = self.client.send_message(&llm_messages, &options).await;

        match result {
            Ok(response) => {
                let content_blocks = from_llm_message(&response);

                // Concatenate all text blocks into a single chunk for the
                // non-streaming reply.
                let text: String = content_blocks
                    .iter()
                    .filter_map(|block| {
                        if let ContentBlock::Text(t) = block {
                            Some(t.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("");

                if !text.is_empty() {
                    let payload = FromLLMPayload {
                        session_id,
                        response_type: LLMResponseType::TextChunk,
                        text: text.clone(),
                        turn_id: request.turn_id.clone(),
                        ..Default::default()
                    };
                    let _ = self.from_llm.send(payload).await;
                }

                for block in &content_blocks {
                    if let ContentBlock::ToolUse(tool_use) = block {
                        let payload = FromLLMPayload {
                            session_id,
                            response_type: LLMResponseType::ToolUse,
                            tool_use: Some(crate::controller::types::ToolUseInfo {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                input: serde_json::to_value(&tool_use.input).unwrap_or_default(),
                            }),
                            turn_id: request.turn_id.clone(),
                            ..Default::default()
                        };
                        let _ = self.from_llm.send(payload).await;
                    }
                }

                let now = Self::current_timestamp_millis();
                let asst_msg = Message::Assistant(AssistantMessage {
                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                    session_id: session_id.to_string(),
                    turn_id: effective_turn_id.clone(),
                    parent_id: String::new(),
                    created_at: now,
                    completed_at: Some(now),
                    model_id: self.config.model.clone(),
                    provider_id: String::new(),
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    finish_reason: None,
                    error: None,
                    content: content_blocks,
                });
                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Complete,
                    is_complete: true,
                    turn_id: request.turn_id.clone(),
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;

                self.request_count.fetch_add(1, Ordering::SeqCst);

                tracing::debug!(session_id, "Request completed successfully");
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "LLM request failed");

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        self.cleanup_request().await;
    }

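    // The streaming path mirrors the non-streaming one but reassembles the
    // assistant message incrementally from stream events: text deltas are
    // forwarded as they arrive and buffered into `response_text`, tool-use
    // input JSON is accumulated until its ContentBlockStop, and the completed
    // message is committed to history on MessageStop.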
    async fn handle_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::to_llm_messages;
        use crate::client::models::{ContentBlockType, Message as LLMMessage, StreamEvent};
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use futures::StreamExt;

        let (request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling streaming request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries).await;

                {
                    let conv = self.conversation.read().await;
                    tracing::debug!(
                        session_id,
                        conversation_len = conv.len(),
                        tool_result_count = request.tool_results.len(),
                        "Streaming ToolResult: conversation state before adding results"
                    );
                }
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(
                            crate::controller::types::ToolResultBlock {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                content: tool_result.content.clone(),
                                is_error: tool_result.is_error,
                                compact_summary,
                            },
                        )],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let options = self.build_message_options().await;

        let stream_result = self.client.send_message_stream(&llm_messages, &options).await;

        match stream_result {
            Ok(mut stream) => {
                // Accumulators for reassembling the assistant message from
                // the event stream.
                let mut current_tool_id: Option<String> = None;
                let mut current_tool_name: Option<String> = None;
                let mut tool_input_json = String::new();
                let mut response_text = String::new();
                let mut completed_tool_uses: Vec<crate::controller::types::ToolUseBlock> =
                    Vec::new();

                loop {
                    tokio::select! {
                        _ = request_token.cancelled() => {
                            tracing::info!(session_id, "Streaming request cancelled");
                            break;
                        }
                        event = stream.next() => {
                            match event {
                                Some(Ok(stream_event)) => {
                                    match stream_event {
                                        StreamEvent::MessageStart { message_id, model } => {
                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::StreamStart,
                                                message_id,
                                                model,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStart { index: _, block_type } => {
                                            match block_type {
                                                ContentBlockType::Text => {
                                                    // Text blocks need no per-block state.
                                                }
                                                ContentBlockType::ToolUse { id, name } => {
                                                    current_tool_id = Some(id);
                                                    current_tool_name = Some(name);
                                                    tool_input_json.clear();
                                                }
                                            }
                                        }
                                        StreamEvent::TextDelta { index, text } => {
                                            response_text.push_str(&text);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::TextChunk,
                                                text,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::InputJsonDelta { index, json } => {
                                            tool_input_json.push_str(&json);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::ToolInputDelta,
                                                text: json,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStop { index: _ } => {
                                            if let (Some(id), Some(name)) =
                                                (current_tool_id.take(), current_tool_name.take())
                                            {
                                                // Invalid accumulated JSON degrades to an empty
                                                // object rather than aborting the stream.
                                                let input: serde_json::Value =
                                                    serde_json::from_str(&tool_input_json)
                                                        .unwrap_or(serde_json::Value::Object(
                                                            serde_json::Map::new(),
                                                        ));

                                                tracing::debug!(
                                                    session_id,
                                                    tool_id = %id,
                                                    tool_name = %name,
                                                    "Saving tool use to completed_tool_uses"
                                                );
                                                completed_tool_uses.push(
                                                    crate::controller::types::ToolUseBlock {
                                                        id: id.clone(),
                                                        name: name.clone(),
                                                        input: input
                                                            .as_object()
                                                            .map(|obj| {
                                                                obj.iter()
                                                                    .map(|(k, v)| (k.clone(), v.clone()))
                                                                    .collect()
                                                            })
                                                            .unwrap_or_default(),
                                                    },
                                                );

                                                tool_input_json.clear();
                                            }
                                        }
                                        StreamEvent::MessageDelta { stop_reason, usage } => {
                                            if let Some(usage) = usage {
                                                tracing::info!(
                                                    session_id,
                                                    input_tokens = usage.input_tokens,
                                                    output_tokens = usage.output_tokens,
                                                    "API token usage for this turn"
                                                );
                                                self.current_input_tokens
                                                    .store(usage.input_tokens as i64, Ordering::SeqCst);
                                                self.current_output_tokens
                                                    .store(usage.output_tokens as i64, Ordering::SeqCst);

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::TokenUpdate,
                                                    input_tokens: usage.input_tokens as i64,
                                                    output_tokens: usage.output_tokens as i64,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            if stop_reason.is_some() {
                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::Complete,
                                                    is_complete: true,
                                                    stop_reason,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }
                                        }
                                        StreamEvent::MessageStop => {
                                            tracing::debug!(
                                                session_id,
                                                text_len = response_text.len(),
                                                tool_use_count = completed_tool_uses.len(),
                                                "MessageStop: saving assistant message to history"
                                            );
                                            if !response_text.is_empty() || !completed_tool_uses.is_empty() {
                                                let now = Self::current_timestamp_millis();

                                                let mut content_blocks = Vec::new();
                                                if !response_text.is_empty() {
                                                    content_blocks.push(ContentBlock::text(&response_text));
                                                }
                                                for tool_use in &completed_tool_uses {
                                                    content_blocks.push(ContentBlock::ToolUse(tool_use.clone()));
                                                }

                                                let content_block_count = content_blocks.len();
                                                let asst_msg = Message::Assistant(AssistantMessage {
                                                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                                                    session_id: session_id.to_string(),
                                                    turn_id: effective_turn_id.clone(),
                                                    parent_id: String::new(),
                                                    created_at: now,
                                                    completed_at: Some(now),
                                                    model_id: self.config.model.clone(),
                                                    provider_id: String::new(),
                                                    input_tokens: self.current_input_tokens.load(Ordering::SeqCst),
                                                    output_tokens: self.current_output_tokens.load(Ordering::SeqCst),
                                                    cache_read_tokens: 0,
                                                    cache_write_tokens: 0,
                                                    finish_reason: None,
                                                    error: None,
                                                    content: content_blocks,
                                                });
                                                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);
                                                tracing::debug!(
                                                    session_id,
                                                    content_block_count,
                                                    "MessageStop: saved assistant message with content blocks"
                                                );
                                            }

                                            if !completed_tool_uses.is_empty() {
                                                let tool_uses: Vec<crate::controller::types::ToolUseInfo> =
                                                    completed_tool_uses
                                                        .iter()
                                                        .map(|tu| crate::controller::types::ToolUseInfo {
                                                            id: tu.id.clone(),
                                                            name: tu.name.clone(),
                                                            input: serde_json::Value::Object(
                                                                tu.input
                                                                    .iter()
                                                                    .map(|(k, v)| (k.clone(), v.clone()))
                                                                    .collect(),
                                                            ),
                                                        })
                                                        .collect();

                                                tracing::debug!(
                                                    session_id,
                                                    tool_count = tool_uses.len(),
                                                    "MessageStop: emitting ToolBatch for execution"
                                                );

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::ToolBatch,
                                                    tool_uses,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            self.request_count.fetch_add(1, Ordering::SeqCst);
                                            tracing::debug!(session_id, "Streaming request completed");
                                            break;
                                        }
                                        StreamEvent::Ping => {
                                            // Keep-alive; nothing to do.
                                        }
                                    }
                                }
                                Some(Err(err)) => {
                                    tracing::error!(session_id, error = %err, "Stream error");
                                    let payload = FromLLMPayload {
                                        session_id,
                                        response_type: LLMResponseType::Error,
                                        error: Some(err.to_string()),
                                        turn_id: request.turn_id.clone(),
                                        ..Default::default()
                                    };
                                    let _ = self.from_llm.send(payload).await;
                                    break;
                                }
                                None => {
                                    // Stream ended without an explicit MessageStop.
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "Failed to start streaming");
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        self.cleanup_request().await;
    }
}
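
#[cfg(test)]
mod tests {
    // Minimal sanity checks against items defined in this module only; a
    // sketch rather than full coverage, since constructing an `LLMSession`
    // requires a live provider configuration.
    use super::*;

    #[test]
    fn timestamp_is_after_unix_epoch() {
        // `current_timestamp_millis` falls back to 0 only for pre-epoch clocks.
        assert!(LLMSession::current_timestamp_millis() > 0);
    }

    #[test]
    fn compact_result_default_is_a_no_op() {
        let result = CompactResult::default();
        assert!(!result.compacted);
        assert_eq!(result.messages_before, 0);
        assert_eq!(result.messages_after, 0);
        assert!(result.error.is_none());
    }
}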