use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::Instant;

use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;

use crate::client::error::LlmError;
use crate::client::models::Tool as LLMTool;
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;
use crate::controller::types::{
    AssistantMessage, ContentBlock, FromLLMPayload, Message, ToLLMPayload, TurnId, UserMessage,
};

use super::compactor::{AsyncCompactor, Compactor, LLMCompactor, ThresholdCompactor};
use super::config::{CompactorType, LLMProvider, LLMSessionConfig};

/// Builds an `LLMClient` for the provider selected in the session config.
fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
    match config.provider {
        LLMProvider::Anthropic => {
            let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
        LLMProvider::OpenAI => {
            let provider = OpenAIProvider::new(config.api_key.clone(), config.model.clone());
            LLMClient::new(Box::new(provider))
        }
    }
}

/// Cumulative token accounting for a session.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Input tokens summed across all requests.
    pub total_input_tokens: i64,
    /// Output tokens summed across all requests.
    pub total_output_tokens: i64,
    /// Number of requests sent to the provider.
    pub request_count: i64,
    /// Input tokens for the most recent request.
    pub last_input_tokens: i64,
    /// Output tokens for the most recent request.
    pub last_output_tokens: i64,
}

/// Point-in-time snapshot of a session's state and context utilization.
#[derive(Debug, Clone)]
pub struct SessionStatus {
    pub session_id: i64,
    pub model: String,
    pub created_at: Instant,
    pub conversation_len: usize,
    /// Input tokens consumed by the most recent request.
    pub context_used: i64,
    /// Context window size, in tokens.
    pub context_limit: i32,
    /// `context_used / context_limit`, as a fraction in `[0.0, 1.0]`.
    pub utilization: f64,
    pub total_input: i64,
    pub total_output: i64,
    pub request_count: i64,
}

/// Outcome of a compaction pass over the conversation history.
#[derive(Debug, Clone, Default)]
pub struct CompactResult {
    /// Whether any compaction actually happened.
    pub compacted: bool,
    pub messages_before: usize,
    pub messages_after: usize,
    pub turns_compacted: usize,
    pub turns_kept: usize,
    /// Length in bytes of the generated summary, if one was produced.
    pub summary_length: usize,
    pub error: Option<String>,
}

/// Process-wide counter used to hand out session ids, starting at 1.
static SESSION_COUNTER: AtomicI64 = AtomicI64::new(0);

/// A single LLM conversation session: owns the provider client, the message
/// history, token accounting, and optional context compaction.
pub struct LLMSession {
    /// Monotonically increasing session identifier.
    id: AtomicI64,

    client: LLMClient,

    /// Inbound request channel (caller -> session loop).
    to_llm_tx: mpsc::Sender<ToLLMPayload>,
    to_llm_rx: Mutex<mpsc::Receiver<ToLLMPayload>>,
    /// Outbound response channel (session loop -> caller).
    from_llm: mpsc::Sender<FromLLMPayload>,

    config: LLMSessionConfig,

    system_prompt: RwLock<Option<String>>,
    max_tokens: AtomicI64,
    created_at: Instant,

    /// Message history, shared via `Arc` for cheap snapshots.
    conversation: RwLock<Arc<Vec<Message>>>,

    shutdown: AtomicBool,
    cancel_token: CancellationToken,

    /// Cancellation token for the request currently in flight, if any.
    current_cancel: Mutex<Option<CancellationToken>>,

    current_turn_id: RwLock<Option<TurnId>>,

    /// Token usage reported by the most recent request.
    current_input_tokens: AtomicI64,
    current_output_tokens: AtomicI64,

    request_count: AtomicI64,

    tool_definitions: RwLock<Vec<LLMTool>>,

    /// At most one of these is active, depending on the configured `CompactorType`.
    compactor: Option<Box<dyn Compactor>>,
    llm_compactor: Option<LLMCompactor>,
    context_limit: AtomicI32,
    /// Tool-result summaries keyed by `tool_use_id`, used during compaction.
    compact_summaries: RwLock<HashMap<String, String>>,
}

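// Message flow: callers enqueue `ToLLMPayload`s with `send`, the loop in
// `start` forwards them to the provider, and results come back over the
// `from_llm` channel as `FromLLMPayload`s. A caller-side receiving sketch
// (the receiver half is whatever was paired with the sender given to `new`;
// this loop is an assumption about usage, not part of this module):
//
//     while let Some(event) = from_llm_rx.recv().await {
//         match event.response_type {
//             LLMResponseType::TextChunk => print!("{}", event.text),
//             LLMResponseType::Complete => break,
//             _ => {}
//         }
//     }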
impl LLMSession {
    /// Creates a session, building the provider client and, if configured,
    /// a threshold or LLM-based compactor.
    pub fn new(
        config: LLMSessionConfig,
        from_llm: mpsc::Sender<FromLLMPayload>,
        cancel_token: CancellationToken,
    ) -> Result<Self, LlmError> {
        let session_id = SESSION_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
        let (to_llm_tx, to_llm_rx) = mpsc::channel(32);
        let max_tokens = config.max_tokens.unwrap_or(4096) as i64;
        let system_prompt = config.system_prompt.clone();

        let client = create_llm_client(&config)?;

        let mut compactor: Option<Box<dyn Compactor>> = None;
        let mut llm_compactor: Option<LLMCompactor> = None;

        if let Some(ref compactor_type) = config.compaction {
            match compactor_type {
                CompactorType::Threshold(c) => {
                    match ThresholdCompactor::new(c.threshold, c.keep_recent_turns, c.tool_compaction) {
                        Ok(tc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                tool_compaction = %c.tool_compaction,
                                "Threshold compaction enabled for session"
                            );
                            compactor = Some(Box::new(tc) as Box<dyn Compactor>);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create threshold compactor, compaction disabled");
                        }
                    }
                }
                CompactorType::LLM(c) => {
                    // A dedicated client for summarization requests.
                    let llm_client = create_llm_client(&config)?;

                    match LLMCompactor::new(llm_client, c.clone()) {
                        Ok(lc) => {
                            tracing::info!(
                                threshold = c.threshold,
                                keep_recent_turns = c.keep_recent_turns,
                                "LLM compaction enabled for session"
                            );
                            llm_compactor = Some(lc);
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Failed to create LLM compactor, compaction disabled");
                        }
                    }
                }
            }
        }

        let context_limit = config.context_limit;

        Ok(Self {
            id: AtomicI64::new(session_id),
            client,
            to_llm_tx,
            to_llm_rx: Mutex::new(to_llm_rx),
            from_llm,
            config,
            system_prompt: RwLock::new(system_prompt),
            max_tokens: AtomicI64::new(max_tokens),
            created_at: Instant::now(),
            conversation: RwLock::new(Arc::new(Vec::new())),
            shutdown: AtomicBool::new(false),
            cancel_token,
            current_cancel: Mutex::new(None),
            current_turn_id: RwLock::new(None),
            current_input_tokens: AtomicI64::new(0),
            current_output_tokens: AtomicI64::new(0),
            request_count: AtomicI64::new(0),
            tool_definitions: RwLock::new(Vec::new()),
            compactor,
            llm_compactor,
            context_limit: AtomicI32::new(context_limit),
            compact_summaries: RwLock::new(HashMap::new()),
        })
    }

    pub fn id(&self) -> i64 {
        self.id.load(Ordering::SeqCst)
    }

    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    pub fn model(&self) -> &str {
        &self.config.model
    }

    /// Sets the per-request `max_tokens` cap.
    pub fn set_max_tokens(&self, max_tokens: i64) {
        self.max_tokens.store(max_tokens, Ordering::SeqCst);
    }

    pub fn max_tokens(&self) -> i64 {
        self.max_tokens.load(Ordering::SeqCst)
    }

    pub fn context_limit(&self) -> i32 {
        self.context_limit.load(Ordering::SeqCst)
    }

    /// Replaces the system prompt for subsequent requests.
    pub async fn set_system_prompt(&self, prompt: String) {
        let mut guard = self.system_prompt.write().await;
        *guard = Some(prompt);
    }

    pub async fn clear_system_prompt(&self) {
        let mut guard = self.system_prompt.write().await;
        *guard = None;
    }

    pub async fn system_prompt(&self) -> Option<String> {
        self.system_prompt.read().await.clone()
    }

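    /// Replaces the tool definitions advertised to the LLM on each request.
    ///
    /// A minimal sketch; the tool values (`read_file_tool`, `write_file_tool`)
    /// are hypothetical `LLMTool` definitions built elsewhere:
    ///
    /// ```ignore
    /// session.set_tools(vec![read_file_tool, write_file_tool]).await;
    /// assert_eq!(session.tools().await.len(), 2);
    /// ```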
    pub async fn set_tools(&self, tools: Vec<LLMTool>) {
        let mut guard = self.tool_definitions.write().await;
        *guard = tools;
    }

    pub async fn clear_tools(&self) {
        let mut guard = self.tool_definitions.write().await;
        guard.clear();
    }

    pub async fn tools(&self) -> Vec<LLMTool> {
        self.tool_definitions.read().await.clone()
    }

    /// Records tool-result summaries (keyed by `tool_use_id`) for later use
    /// by the compactor.
    async fn store_compact_summaries(&self, summaries: &HashMap<String, String>) {
        if summaries.is_empty() {
            tracing::warn!(
                session_id = self.id(),
                "No compact summaries provided with tool results"
            );
            return;
        }
        let mut guard = self.compact_summaries.write().await;
        for (tool_use_id, summary) in summaries {
            tracing::info!(
                session_id = self.id(),
                tool_use_id = %tool_use_id,
                summary_len = summary.len(),
                summary_preview = %summary.chars().take(50).collect::<String>(),
                "Storing compact summary"
            );
            guard.insert(tool_use_id.clone(), summary.clone());
        }
        tracing::info!(
            session_id = self.id(),
            new_summaries = summaries.len(),
            total_stored = guard.len(),
            "Stored compact summaries for tool results"
        );
    }

    /// Runs a compaction pass if the configured compactor's threshold is met.
    /// The LLM compactor takes precedence over the threshold compactor.
    async fn maybe_compact(&self) {
        let context_used = self.current_input_tokens.load(Ordering::SeqCst);
        let context_limit = self.context_limit.load(Ordering::SeqCst);
        let conversation_len = self.conversation.read().await.len();
        let summaries_count = self.compact_summaries.read().await.len();

        let utilization = if context_limit > 0 {
            context_used as f64 / context_limit as f64
        } else {
            0.0
        };

        tracing::debug!(
            session_id = self.id(),
            context_used,
            context_limit,
            utilization = format!("{:.2}%", utilization * 100.0),
            conversation_len,
            summaries_available = summaries_count,
            "Checking if compaction needed"
        );

        if let Some(ref llm_compactor) = self.llm_compactor {
            if !llm_compactor.should_compact(context_used, context_limit) {
                tracing::debug!(session_id = self.id(), "LLM compaction not triggered");
                return;
            }

            let summaries = self.compact_summaries.read().await.clone();
            // Snapshot the conversation. The RwLock keeps its own `Arc`, so
            // `Arc::try_unwrap` could never succeed here; clone the underlying
            // `Vec` directly.
            let conversation: Vec<Message> = {
                let guard = self.conversation.read().await;
                (**guard).clone()
            };

            tracing::info!(
                session_id = self.id(),
                conversation_len = conversation.len(),
                summaries_count = summaries.len(),
                "Starting LLM compaction"
            );

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            "LLM compaction completed"
                        );
                    }
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "LLM compaction failed"
                    );
                }
            }
            return;
        }

        let compactor = match &self.compactor {
            Some(c) => c,
            None => {
                tracing::debug!(session_id = self.id(), "No compactor configured");
                return;
            }
        };

        if !compactor.should_compact(context_used, context_limit) {
            tracing::debug!(session_id = self.id(), "Threshold compaction not triggered");
            return;
        }

        let summaries = self.compact_summaries.read().await.clone();
        let mut guard = self.conversation.write().await;

        tracing::info!(
            session_id = self.id(),
            conversation_len = guard.len(),
            summaries_count = summaries.len(),
            "Starting threshold compaction"
        );

        // Compact in place; `make_mut` clones the `Vec` only if the `Arc` is shared.
        let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

        tracing::info!(
            session_id = self.id(),
            tool_results_summarized = result.tool_results_summarized,
            tool_results_redacted = result.tool_results_redacted,
            turns_compacted = result.turns_compacted,
            conversation_len_after = guard.len(),
            "Threshold compaction completed"
        );
    }

    /// Drops all history, stored summaries, and token counters.
    pub async fn clear_conversation(&self) {
        let mut guard = self.conversation.write().await;
        Arc::make_mut(&mut *guard).clear();

        let mut summaries = self.compact_summaries.write().await;
        summaries.clear();

        self.current_input_tokens.store(0, Ordering::SeqCst);
        self.current_output_tokens.store(0, Ordering::SeqCst);

        tracing::info!(session_id = self.id(), "Conversation cleared");
    }

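    /// Runs a compaction pass immediately, bypassing the threshold check
    /// used by `maybe_compact`.
    ///
    /// A usage sketch (the surrounding setup is assumed):
    ///
    /// ```ignore
    /// let result = session.force_compact().await;
    /// if let Some(err) = result.error {
    ///     eprintln!("compaction unavailable: {err}");
    /// } else if result.compacted {
    ///     println!("{} -> {} messages", result.messages_before, result.messages_after);
    /// }
    /// ```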
    pub async fn force_compact(&self) -> CompactResult {
        if let Some(ref llm_compactor) = self.llm_compactor {
            let summaries = self.compact_summaries.read().await.clone();
            // Snapshot the conversation (see `maybe_compact`).
            let conversation: Vec<Message> = {
                let guard = self.conversation.read().await;
                (**guard).clone()
            };
            let messages_before = conversation.len();
            let turns_before = self.count_unique_turns(&conversation);

            match llm_compactor.compact_async(conversation, &summaries).await {
                Ok((new_conversation, result)) => {
                    let messages_after = new_conversation.len();
                    let turns_after = self.count_unique_turns(&new_conversation);
                    let compacted = messages_after < messages_before;

                    let summary_length = if compacted && !new_conversation.is_empty() {
                        self.extract_summary_length(&new_conversation[0])
                    } else {
                        0
                    };

                    *self.conversation.write().await = Arc::new(new_conversation);

                    if result.turns_compacted > 0 {
                        tracing::info!(
                            session_id = self.id(),
                            turns_compacted = result.turns_compacted,
                            messages_before,
                            messages_after,
                            "Forced LLM compaction completed"
                        );
                    }

                    return CompactResult {
                        compacted,
                        messages_before,
                        messages_after,
                        turns_compacted: turns_before.saturating_sub(turns_after),
                        turns_kept: turns_after,
                        summary_length,
                        error: None,
                    };
                }
                Err(e) => {
                    tracing::error!(
                        session_id = self.id(),
                        error = %e,
                        "Forced LLM compaction failed"
                    );
                    return CompactResult {
                        compacted: false,
                        messages_before,
                        messages_after: messages_before,
                        turns_compacted: 0,
                        turns_kept: turns_before,
                        summary_length: 0,
                        error: Some(format!("Compaction failed: {}", e)),
                    };
                }
            }
        }

        if let Some(ref compactor) = self.compactor {
            let summaries = self.compact_summaries.read().await.clone();
            let mut guard = self.conversation.write().await;
            let messages_before = guard.len();
            let turns_before = self.count_unique_turns(&guard);

            let result = compactor.compact(Arc::make_mut(&mut *guard), &summaries);

            let messages_after = guard.len();
            let turns_after = self.count_unique_turns(&guard);
            let compacted = result.turns_compacted > 0 || result.total_compacted() > 0;

            if result.total_compacted() > 0 {
                tracing::info!(
                    session_id = self.id(),
                    tool_results_summarized = result.tool_results_summarized,
                    tool_results_redacted = result.tool_results_redacted,
                    turns_compacted = result.turns_compacted,
                    "Forced threshold compaction completed"
                );
            }

            return CompactResult {
                compacted,
                messages_before,
                messages_after,
                turns_compacted: turns_before.saturating_sub(turns_after),
                turns_kept: turns_after,
                summary_length: 0,
                error: None,
            };
        }

        CompactResult {
            compacted: false,
            error: Some("No compactor configured".to_string()),
            ..Default::default()
        }
    }

    /// Counts distinct `TurnId`s in the conversation.
    fn count_unique_turns(&self, conversation: &[Message]) -> usize {
        use std::collections::HashSet;
        let mut turn_ids = HashSet::new();
        for msg in conversation {
            turn_ids.insert(msg.turn_id().clone());
        }
        turn_ids.len()
    }

    /// Returns the length of the leading summary block, if `message` is the
    /// summary message produced by compaction.
    fn extract_summary_length(&self, message: &Message) -> usize {
        if let Message::User(user_msg) = message {
            for block in &user_msg.content {
                if let ContentBlock::Text(text_block) = block {
                    if text_block.text.starts_with("[Previous conversation summary]") {
                        return text_block.text.len();
                    }
                }
            }
        }
        0
    }

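    /// Queues a request for the session loop. Returns `false` if the session
    /// has shut down or the request channel is closed.
    ///
    /// A usage sketch; the payload construction is an assumption:
    ///
    /// ```ignore
    /// if !session.send(payload).await {
    ///     eprintln!("session is shut down; request dropped");
    /// }
    /// ```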
    pub async fn send(&self, msg: ToLLMPayload) -> bool {
        if self.shutdown.load(Ordering::SeqCst) {
            return false;
        }
        self.to_llm_tx.send(msg).await.is_ok()
    }

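    /// Cancels the in-flight request, if any, and removes the partial turn
    /// from history so a retry starts clean.
    ///
    /// Sketch (wiring to a signal handler is an assumption):
    ///
    /// ```ignore
    /// tokio::signal::ctrl_c().await?;
    /// session.interrupt().await;
    /// ```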
    pub async fn interrupt(&self) {
        let guard = self.current_cancel.lock().await;
        if let Some(token) = guard.as_ref() {
            token.cancel();

            // Roll back any messages recorded for the cancelled turn.
            let turn_id = self.current_turn_id.read().await.clone();
            if let Some(turn_id) = turn_id {
                let mut conversation = self.conversation.write().await;
                let original_len = conversation.len();
                Arc::make_mut(&mut *conversation).retain(|msg| msg.turn_id() != &turn_id);
                let removed = original_len - conversation.len();
                tracing::debug!(
                    session_id = self.id(),
                    turn_id = %turn_id,
                    messages_removed = removed,
                    conversation_length = conversation.len(),
                    "Removed messages from cancelled turn"
                );
            }
        }
    }

    /// Marks the session as shut down and cancels the main loop.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
        self.cancel_token.cancel();
    }

    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

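    /// Runs the session loop until the cancel token fires or the request
    /// channel closes. A sketch of driving it from a task; sharing the
    /// session behind an `Arc` is an assumption about caller usage:
    ///
    /// ```ignore
    /// let session = Arc::new(LLMSession::new(config, from_llm_tx, token)?);
    /// let runner = Arc::clone(&session);
    /// tokio::spawn(async move { runner.start().await });
    /// ```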
    pub async fn start(&self) {
        tracing::info!(session_id = self.id(), "Session starting");

        loop {
            let mut rx_guard = self.to_llm_rx.lock().await;

            tokio::select! {
                _ = self.cancel_token.cancelled() => {
                    tracing::info!(session_id = self.id(), "Session cancelled");
                    break;
                }
                msg = rx_guard.recv() => {
                    match msg {
                        Some(request) => {
                            // Release the receiver lock while handling the request.
                            drop(rx_guard);
                            self.handle_request(request).await;
                        }
                        None => {
                            tracing::info!(session_id = self.id(), "Session channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::info!(session_id = self.id(), "Session stopped");
    }

    /// Milliseconds since the Unix epoch, or 0 if the system clock reads
    /// earlier than the epoch.
    fn current_timestamp_millis() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as i64)
            .unwrap_or(0)
    }

    /// Installs a fresh per-request cancellation token and records the turn id.
    async fn prepare_request(&self, request: &ToLLMPayload) -> (CancellationToken, TurnId) {
        let request_token = CancellationToken::new();
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = Some(request_token.clone());
        }

        // Fall back to a synthetic turn id when the request carries none.
        let effective_turn_id = request
            .turn_id
            .clone()
            .unwrap_or_else(|| TurnId::new_user_turn(0));
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = Some(effective_turn_id.clone());
        }

        (request_token, effective_turn_id)
    }

    async fn build_message_options(&self) -> crate::client::models::MessageOptions {
        use crate::client::models::MessageOptions;

        let max_tokens = self.max_tokens.load(Ordering::SeqCst) as u32;
        let tools = self.tool_definitions.read().await.clone();
        let tools_option = if tools.is_empty() { None } else { Some(tools) };

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature: self.config.temperature,
            tools: tools_option,
            ..Default::default()
        }
    }

    /// Clears the per-request cancellation token and turn id.
    async fn cleanup_request(&self) {
        {
            let mut guard = self.current_cancel.lock().await;
            *guard = None;
        }
        {
            let mut guard = self.current_turn_id.write().await;
            *guard = None;
        }
    }

    /// Dispatches to the streaming or non-streaming path per the config.
    async fn handle_request(&self, request: ToLLMPayload) {
        if self.config.streaming {
            self.handle_streaming_request(request).await;
        } else {
            self.handle_non_streaming_request(request).await;
        }
    }

    async fn handle_non_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::{from_llm_message, to_llm_messages};
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use crate::client::models::Message as LLMMessage;

        let (_request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries).await;

                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            content: tool_result.content.clone(),
                            is_error: tool_result.is_error,
                            compact_summary,
                        })],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        // Compaction may rewrite history, but `llm_messages` for this request
        // has already been assembled above.
        self.maybe_compact().await;

        let options = self.build_message_options().await;

        let result = self.client.send_message(&llm_messages, &options).await;

        match result {
            Ok(response) => {
                let content_blocks = from_llm_message(&response);

                // Concatenate all text blocks into a single chunk for the caller.
                let text: String = content_blocks
                    .iter()
                    .filter_map(|block| {
                        if let ContentBlock::Text(t) = block {
                            Some(t.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("");

                if !text.is_empty() {
                    let payload = FromLLMPayload {
                        session_id,
                        response_type: LLMResponseType::TextChunk,
                        text: text.clone(),
                        turn_id: request.turn_id.clone(),
                        ..Default::default()
                    };
                    let _ = self.from_llm.send(payload).await;
                }

                for block in &content_blocks {
                    if let ContentBlock::ToolUse(tool_use) = block {
                        let payload = FromLLMPayload {
                            session_id,
                            response_type: LLMResponseType::ToolUse,
                            tool_use: Some(crate::controller::types::ToolUseInfo {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                input: serde_json::to_value(&tool_use.input).unwrap_or_default(),
                            }),
                            turn_id: request.turn_id.clone(),
                            ..Default::default()
                        };
                        let _ = self.from_llm.send(payload).await;
                    }
                }

                let now = Self::current_timestamp_millis();
                // Token counts are not surfaced on this path; only the
                // streaming handler records usage.
                let asst_msg = Message::Assistant(AssistantMessage {
                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                    session_id: session_id.to_string(),
                    turn_id: effective_turn_id.clone(),
                    parent_id: String::new(),
                    created_at: now,
                    completed_at: Some(now),
                    model_id: self.config.model.clone(),
                    provider_id: String::new(),
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    finish_reason: None,
                    error: None,
                    content: content_blocks,
                });
                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Complete,
                    is_complete: true,
                    turn_id: request.turn_id.clone(),
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;

                self.request_count.fetch_add(1, Ordering::SeqCst);

                tracing::debug!(session_id, "Request completed successfully");
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "LLM request failed");

                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        self.cleanup_request().await;
    }

    async fn handle_streaming_request(&self, request: ToLLMPayload) {
        use super::convert::to_llm_messages;
        use crate::controller::types::{LLMRequestType, LLMResponseType};
        use futures::StreamExt;
        use crate::client::models::{ContentBlockType, Message as LLMMessage, StreamEvent};

        let (request_token, effective_turn_id) = self.prepare_request(&request).await;

        let session_id = self.id();
        tracing::debug!(session_id, turn_id = %effective_turn_id, "Handling streaming request");

        let mut llm_messages: Vec<LLMMessage> = Vec::new();

        if let Some(prompt) = self.system_prompt.read().await.as_ref() {
            llm_messages.push(LLMMessage::system(prompt.clone()));
        }

        let conversation = self.conversation.read().await;
        llm_messages.extend(to_llm_messages(&conversation));
        drop(conversation);

        match request.request_type {
            LLMRequestType::UserMessage => {
                if !request.content.is_empty() {
                    llm_messages.push(LLMMessage::user(&request.content));

                    let user_msg = Message::User(UserMessage {
                        id: format!("user_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::text(&request.content)],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
            LLMRequestType::ToolResult => {
                self.store_compact_summaries(&request.compact_summaries).await;

                {
                    let conv = self.conversation.read().await;
                    tracing::debug!(
                        session_id,
                        conversation_len = conv.len(),
                        tool_result_count = request.tool_results.len(),
                        "Streaming ToolResult: conversation state before adding results"
                    );
                }
                for tool_result in &request.tool_results {
                    llm_messages.push(LLMMessage::tool_result(
                        &tool_result.tool_use_id,
                        &tool_result.content,
                        tool_result.is_error,
                    ));

                    let compact_summary = request
                        .compact_summaries
                        .get(&tool_result.tool_use_id)
                        .cloned();

                    let user_msg = Message::User(UserMessage {
                        id: format!("tool_result_{}", self.request_count.load(Ordering::SeqCst)),
                        session_id: session_id.to_string(),
                        turn_id: effective_turn_id.clone(),
                        created_at: Self::current_timestamp_millis(),
                        content: vec![ContentBlock::ToolResult(crate::controller::types::ToolResultBlock {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            content: tool_result.content.clone(),
                            is_error: tool_result.is_error,
                            compact_summary,
                        })],
                    });
                    Arc::make_mut(&mut *self.conversation.write().await).push(user_msg);
                }
            }
        }

        self.maybe_compact().await;

        let options = self.build_message_options().await;

        let stream_result = self
            .client
            .send_message_stream(&llm_messages, &options)
            .await;

        match stream_result {
            Ok(mut stream) => {
                // Accumulators for the in-progress tool call and response text.
                let mut current_tool_id: Option<String> = None;
                let mut current_tool_name: Option<String> = None;
                let mut tool_input_json = String::new();
                let mut response_text = String::new();
                let mut completed_tool_uses: Vec<crate::controller::types::ToolUseBlock> = Vec::new();

                loop {
                    tokio::select! {
                        _ = request_token.cancelled() => {
                            tracing::info!(session_id, "Streaming request cancelled");
                            break;
                        }
                        event = stream.next() => {
                            match event {
                                Some(Ok(stream_event)) => {
                                    match stream_event {
                                        StreamEvent::MessageStart { message_id, model } => {
                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::StreamStart,
                                                message_id,
                                                model,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStart { index: _, block_type } => {
                                            match block_type {
                                                ContentBlockType::Text => {
                                                    // Text accumulates via TextDelta; nothing to do here.
                                                }
                                                ContentBlockType::ToolUse { id, name } => {
                                                    current_tool_id = Some(id);
                                                    current_tool_name = Some(name);
                                                    tool_input_json.clear();
                                                }
                                            }
                                        }
                                        StreamEvent::TextDelta { index, text } => {
                                            response_text.push_str(&text);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::TextChunk,
                                                text,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::InputJsonDelta { index, json } => {
                                            tool_input_json.push_str(&json);

                                            let payload = FromLLMPayload {
                                                session_id,
                                                response_type: LLMResponseType::ToolInputDelta,
                                                text: json,
                                                content_index: index,
                                                turn_id: request.turn_id.clone(),
                                                ..Default::default()
                                            };
                                            let _ = self.from_llm.send(payload).await;
                                        }
                                        StreamEvent::ContentBlockStop { index: _ } => {
                                            // A tool-use block finished: parse its accumulated
                                            // JSON input (empty object on parse failure) and
                                            // stash it until MessageStop.
                                            if let (Some(id), Some(name)) =
                                                (current_tool_id.take(), current_tool_name.take())
                                            {
                                                let input: serde_json::Value =
                                                    serde_json::from_str(&tool_input_json)
                                                        .unwrap_or(serde_json::Value::Object(
                                                            serde_json::Map::new(),
                                                        ));

                                                tracing::debug!(
                                                    session_id,
                                                    tool_id = %id,
                                                    tool_name = %name,
                                                    "Saving tool use to completed_tool_uses"
                                                );
                                                completed_tool_uses.push(crate::controller::types::ToolUseBlock {
                                                    id: id.clone(),
                                                    name: name.clone(),
                                                    input: input
                                                        .as_object()
                                                        .map(|obj| {
                                                            obj.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect()
                                                        })
                                                        .unwrap_or_default(),
                                                });

                                                tool_input_json.clear();
                                            }
                                        }
                                        StreamEvent::MessageDelta { stop_reason, usage } => {
                                            if let Some(usage) = usage {
                                                tracing::info!(
                                                    session_id,
                                                    input_tokens = usage.input_tokens,
                                                    output_tokens = usage.output_tokens,
                                                    "API token usage for this turn"
                                                );
                                                self.current_input_tokens
                                                    .store(usage.input_tokens as i64, Ordering::SeqCst);
                                                self.current_output_tokens
                                                    .store(usage.output_tokens as i64, Ordering::SeqCst);

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::TokenUpdate,
                                                    input_tokens: usage.input_tokens as i64,
                                                    output_tokens: usage.output_tokens as i64,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            if stop_reason.is_some() {
                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::Complete,
                                                    is_complete: true,
                                                    stop_reason,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }
                                        }
                                        StreamEvent::MessageStop => {
                                            tracing::debug!(
                                                session_id,
                                                text_len = response_text.len(),
                                                tool_use_count = completed_tool_uses.len(),
                                                "MessageStop: saving assistant message to history"
                                            );
                                            if !response_text.is_empty() || !completed_tool_uses.is_empty() {
                                                let now = Self::current_timestamp_millis();

                                                let mut content_blocks = Vec::new();
                                                if !response_text.is_empty() {
                                                    content_blocks.push(ContentBlock::text(&response_text));
                                                }
                                                for tool_use in &completed_tool_uses {
                                                    content_blocks.push(ContentBlock::ToolUse(tool_use.clone()));
                                                }

                                                let content_block_count = content_blocks.len();
                                                let asst_msg = Message::Assistant(AssistantMessage {
                                                    id: format!("asst_{}", self.request_count.load(Ordering::SeqCst)),
                                                    session_id: session_id.to_string(),
                                                    turn_id: effective_turn_id.clone(),
                                                    parent_id: String::new(),
                                                    created_at: now,
                                                    completed_at: Some(now),
                                                    model_id: self.config.model.clone(),
                                                    provider_id: String::new(),
                                                    input_tokens: self.current_input_tokens.load(Ordering::SeqCst),
                                                    output_tokens: self.current_output_tokens.load(Ordering::SeqCst),
                                                    cache_read_tokens: 0,
                                                    cache_write_tokens: 0,
                                                    finish_reason: None,
                                                    error: None,
                                                    content: content_blocks,
                                                });
                                                Arc::make_mut(&mut *self.conversation.write().await).push(asst_msg);
                                                tracing::debug!(
                                                    session_id,
                                                    content_block_count,
                                                    "MessageStop: saved assistant message with content blocks"
                                                );
                                            }

                                            if !completed_tool_uses.is_empty() {
                                                // Emit the whole batch so the controller can run
                                                // the tools and send back a ToolResult request.
                                                let tool_uses: Vec<crate::controller::types::ToolUseInfo> = completed_tool_uses
                                                    .iter()
                                                    .map(|tu| crate::controller::types::ToolUseInfo {
                                                        id: tu.id.clone(),
                                                        name: tu.name.clone(),
                                                        input: serde_json::Value::Object(
                                                            tu.input.iter()
                                                                .map(|(k, v)| (k.clone(), v.clone()))
                                                                .collect(),
                                                        ),
                                                    })
                                                    .collect();

                                                tracing::debug!(
                                                    session_id,
                                                    tool_count = tool_uses.len(),
                                                    "MessageStop: emitting ToolBatch for execution"
                                                );

                                                let payload = FromLLMPayload {
                                                    session_id,
                                                    response_type: LLMResponseType::ToolBatch,
                                                    tool_uses,
                                                    turn_id: request.turn_id.clone(),
                                                    ..Default::default()
                                                };
                                                let _ = self.from_llm.send(payload).await;
                                            }

                                            self.request_count.fetch_add(1, Ordering::SeqCst);
                                            tracing::debug!(session_id, "Streaming request completed");
                                            break;
                                        }
                                        StreamEvent::Ping => {
                                            // Keepalive; nothing to do.
                                        }
                                    }
                                }
                                Some(Err(err)) => {
                                    tracing::error!(session_id, error = %err, "Stream error");
                                    let payload = FromLLMPayload {
                                        session_id,
                                        response_type: LLMResponseType::Error,
                                        error: Some(err.to_string()),
                                        turn_id: request.turn_id.clone(),
                                        ..Default::default()
                                    };
                                    let _ = self.from_llm.send(payload).await;
                                    break;
                                }
                                None => {
                                    // Stream ended without an explicit MessageStop.
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            Err(err) => {
                tracing::error!(session_id, error = %err, "Failed to start streaming");
                let payload = FromLLMPayload {
                    session_id,
                    response_type: LLMResponseType::Error,
                    error: Some(err.to_string()),
                    turn_id: request.turn_id,
                    ..Default::default()
                };
                let _ = self.from_llm.send(payload).await;
            }
        }

        self.cleanup_request().await;
    }
}