1use std::sync::Arc;
35
36use futures::StreamExt;
37use tokio::sync::{Mutex, Notify, RwLock, broadcast, mpsc};
38use tokio_util::sync::CancellationToken;
39use tracing::{debug, info, warn};
40
41#[cfg(feature = "memory")]
42use adk_core::Memory;
43use adk_core::{Agent, Content, Event, Part};
44use adk_runner::Runner;
45use adk_session::service::SessionService;
46
47use crate::checkpoint::{CheckpointManager, RunState};
48use crate::event_mapping::{RunnerOutput, custom_tool_use_id, map_runner_output, requires_parking};
49use crate::parking::ToolParkingLot;
50use crate::sequence::SequenceCounter;
51use crate::types::{
52 ContentBlock, RuntimeError, SessionEvent, SessionStatus, StopReason, UserEvent,
53};
54use crate::usage::{SessionUsageTracker, UsageReport};
55
56pub struct SessionLoop {
91 session_id: String,
93 event_rx: mpsc::Receiver<UserEvent>,
95 event_tx: broadcast::Sender<SessionEvent>,
97 seq: SequenceCounter,
99 parking: Arc<ToolParkingLot>,
101 checkpoint: Arc<RwLock<CheckpointManager>>,
103 cancel_token: CancellationToken,
105 pause_flag: Arc<Mutex<bool>>,
107 pause_notify: Arc<Notify>,
109 status: SessionStatus,
111 agent: Arc<dyn Agent>,
113 session_service: Arc<dyn SessionService>,
115 #[cfg(feature = "memory")]
117 memory: Option<Arc<dyn Memory>>,
118 usage_tracker: SessionUsageTracker,
120}
121
122impl SessionLoop {
123 pub fn new(
135 session_id: String,
136 event_rx: mpsc::Receiver<UserEvent>,
137 event_tx: broadcast::Sender<SessionEvent>,
138 parking: Arc<ToolParkingLot>,
139 cancel_token: CancellationToken,
140 agent: Arc<dyn Agent>,
141 session_service: Arc<dyn SessionService>,
142 ) -> Self {
143 let checkpoint = Arc::new(RwLock::new(CheckpointManager::new(session_id.clone())));
144 Self {
145 session_id,
146 event_rx,
147 event_tx,
148 seq: SequenceCounter::default(),
149 parking,
150 checkpoint,
151 cancel_token,
152 pause_flag: Arc::new(Mutex::new(false)),
153 pause_notify: Arc::new(Notify::new()),
154 status: SessionStatus::Queued,
155 agent,
156 session_service,
157 #[cfg(feature = "memory")]
158 memory: None,
159 usage_tracker: SessionUsageTracker::new(),
160 }
161 }
162
163 #[cfg(feature = "memory")]
169 #[allow(clippy::too_many_arguments)]
170 pub fn with_pause_controls(
171 session_id: String,
172 event_rx: mpsc::Receiver<UserEvent>,
173 event_tx: broadcast::Sender<SessionEvent>,
174 parking: Arc<ToolParkingLot>,
175 cancel_token: CancellationToken,
176 pause_flag: Arc<Mutex<bool>>,
177 pause_notify: Arc<Notify>,
178 checkpoint: Arc<RwLock<CheckpointManager>>,
179 agent: Arc<dyn Agent>,
180 session_service: Arc<dyn SessionService>,
181 memory: Option<Arc<dyn Memory>>,
182 ) -> Self {
183 Self {
184 session_id,
185 event_rx,
186 event_tx,
187 seq: SequenceCounter::default(),
188 parking,
189 checkpoint,
190 cancel_token,
191 pause_flag,
192 pause_notify,
193 status: SessionStatus::Queued,
194 agent,
195 session_service,
196 memory,
197 usage_tracker: SessionUsageTracker::new(),
198 }
199 }
200
201 #[cfg(not(feature = "memory"))]
205 #[allow(clippy::too_many_arguments)]
206 pub fn with_pause_controls(
207 session_id: String,
208 event_rx: mpsc::Receiver<UserEvent>,
209 event_tx: broadcast::Sender<SessionEvent>,
210 parking: Arc<ToolParkingLot>,
211 cancel_token: CancellationToken,
212 pause_flag: Arc<Mutex<bool>>,
213 pause_notify: Arc<Notify>,
214 checkpoint: Arc<RwLock<CheckpointManager>>,
215 agent: Arc<dyn Agent>,
216 session_service: Arc<dyn SessionService>,
217 ) -> Self {
218 Self {
219 session_id,
220 event_rx,
221 event_tx,
222 seq: SequenceCounter::default(),
223 parking,
224 checkpoint,
225 cancel_token,
226 pause_flag,
227 pause_notify,
228 status: SessionStatus::Queued,
229 agent,
230 session_service,
231 usage_tracker: SessionUsageTracker::new(),
232 }
233 }
234
235 pub fn pause_flag(&self) -> Arc<Mutex<bool>> {
237 Arc::clone(&self.pause_flag)
238 }
239
240 pub fn pause_notify(&self) -> Arc<Notify> {
242 Arc::clone(&self.pause_notify)
243 }
244
245 pub async fn run(mut self) -> Result<(), RuntimeError> {
255 info!(session_id = %self.session_id, "session loop started");
256
257 loop {
258 if self.cancel_token.is_cancelled() {
260 debug!(session_id = %self.session_id, "interrupt detected, shutting down");
261 self.emit_idle(Some(StopReason::EndTurn), None).await;
262 break;
263 }
264
265 self.check_pause().await;
267
268 let event = tokio::select! {
270 biased;
271 _ = self.cancel_token.cancelled() => {
272 debug!(session_id = %self.session_id, "interrupted while waiting for event");
273 self.emit_idle(Some(StopReason::EndTurn), None).await;
274 break;
275 }
276 ev = self.event_rx.recv() => {
277 match ev {
278 Some(event) => event,
279 None => {
280 debug!(session_id = %self.session_id, "event channel closed, shutting down");
281 break;
282 }
283 }
284 }
285 };
286
287 match event {
289 UserEvent::Message { content } => {
290 self.process_turn(content).await?;
291 }
292 UserEvent::Interrupt {} => {
293 debug!(session_id = %self.session_id, "user.interrupt received");
294 self.emit_idle(Some(StopReason::EndTurn), None).await;
295 break;
296 }
297 UserEvent::CustomToolResult { custom_tool_use_id, content } => {
298 debug!(
299 session_id = %self.session_id,
300 tool_use_id = %custom_tool_use_id,
301 "delivering custom tool result"
302 );
303 if let Err(e) = self.parking.deliver(&custom_tool_use_id, content).await {
304 warn!(
305 session_id = %self.session_id,
306 error = %e,
307 "failed to deliver custom tool result"
308 );
309 }
310 }
311 UserEvent::ToolConfirmation { tool_use_id, result, deny_message } => {
312 debug!(
313 session_id = %self.session_id,
314 tool_use_id = %tool_use_id,
315 result = ?result,
316 "tool confirmation received, delivering to parking lot"
317 );
318 let content = match result {
323 crate::types::ConfirmationResult::Allow => {
324 vec![ContentBlock::Text {
325 text: serde_json::json!({
326 "confirmation": "approved",
327 "tool_use_id": tool_use_id
328 })
329 .to_string(),
330 }]
331 }
332 crate::types::ConfirmationResult::Deny => {
333 let message = deny_message
334 .unwrap_or_else(|| "Tool execution denied by user".to_string());
335 vec![ContentBlock::Text {
336 text: serde_json::json!({
337 "confirmation": "denied",
338 "tool_use_id": tool_use_id,
339 "reason": message
340 })
341 .to_string(),
342 }]
343 }
344 };
345 if let Err(e) = self.parking.deliver(&tool_use_id, content).await {
346 warn!(
347 session_id = %self.session_id,
348 error = %e,
349 "failed to deliver tool confirmation"
350 );
351 }
352 }
353 UserEvent::ToolResult { tool_use_id, .. } => {
354 debug!(
355 session_id = %self.session_id,
356 tool_use_id = %tool_use_id,
357 "tool result received (self-hosted only, not yet wired)"
358 );
359 }
360 UserEvent::DefineOutcome { criteria } => {
361 debug!(
362 session_id = %self.session_id,
363 criteria = %criteria,
364 "outcome criteria defined"
365 );
366 }
368 }
369 }
370
371 info!(session_id = %self.session_id, "session loop exited");
372 Ok(())
373 }
374
375 async fn process_turn(&mut self, content: Vec<ContentBlock>) -> Result<(), RuntimeError> {
377 self.status = SessionStatus::Running;
379 let running_event = SessionEvent::StatusRunning { seq: self.seq.next() };
380 self.emit_event(running_event).await;
381
382 if self.check_interrupt() {
384 self.emit_idle(Some(StopReason::EndTurn), None).await;
385 return Ok(());
386 }
387
388 let user_content = self.build_user_content(&content);
390
391 let runner = self.build_runner()?;
393
394 let event_stream = runner
395 .run_str("managed_user", &self.session_id, user_content)
396 .await
397 .map_err(|e| RuntimeError::internal(format!("runner invocation failed: {e}")))?;
398
399 let mut turn_usage = UsageReport::default();
401 let mut custom_tool_ids = Vec::new();
402
403 futures::pin_mut!(event_stream);
404
405 while let Some(event_result) = event_stream.next().await {
406 if self.check_interrupt() {
408 self.emit_idle(Some(StopReason::EndTurn), None).await;
409 return Ok(());
410 }
411
412 match event_result {
413 Ok(event) => {
414 self.process_runner_event(&event, &mut turn_usage, &mut custom_tool_ids).await;
415 }
416 Err(e) => {
417 warn!(
418 session_id = %self.session_id,
419 error = %e,
420 "runner event stream error"
421 );
422 let error_event = SessionEvent::Error {
423 code: "runner_error".to_string(),
424 message: e.to_string(),
425 seq: self.seq.next(),
426 };
427 self.emit_event(error_event).await;
428 }
429 }
430 }
431
432 let turn_usage_report = if !turn_usage.is_empty() {
435 self.usage_tracker.record_turn(turn_usage.clone());
436 Some(turn_usage)
437 } else {
438 None
439 };
440
441 let stop_reason = if custom_tool_ids.is_empty() {
443 Some(StopReason::EndTurn)
444 } else {
445 Some(StopReason::RequiresAction { event_ids: custom_tool_ids })
446 };
447
448 self.emit_idle(stop_reason, turn_usage_report).await;
450
451 Ok(())
452 }
453
454 fn build_runner(&self) -> Result<Runner, RuntimeError> {
456 #[allow(unused_mut)]
457 let mut builder = Runner::builder()
458 .app_name("managed")
459 .agent(Arc::clone(&self.agent))
460 .session_service(Arc::clone(&self.session_service))
461 .cancellation_token(self.cancel_token.clone());
462
463 #[cfg(feature = "memory")]
464 if let Some(ref memory) = self.memory {
465 builder = builder.memory_service(Arc::clone(memory));
466 }
467
468 builder.build().map_err(|e| RuntimeError::internal(format!("failed to build runner: {e}")))
469 }
470
471 fn build_user_content(&self, blocks: &[ContentBlock]) -> Content {
473 let mut parts = Vec::new();
474 for block in blocks {
475 match block {
476 ContentBlock::Text { text } => {
477 parts.push(Part::Text { text: text.clone() });
478 }
479 ContentBlock::Image { source } => {
480 if let Some(url) = source.get("url").and_then(|v| v.as_str()) {
482 parts.push(Part::FileData {
483 mime_type: source
484 .get("media_type")
485 .and_then(|v| v.as_str())
486 .unwrap_or("image/png")
487 .to_string(),
488 file_uri: url.to_string(),
489 });
490 }
491 }
492 ContentBlock::File { file_id } => {
493 parts.push(Part::FileData {
494 mime_type: "application/octet-stream".to_string(),
495 file_uri: file_id.clone(),
496 });
497 }
498 }
499 }
500
501 Content { role: "user".to_string(), parts }
502 }
503
504 async fn process_runner_event(
506 &mut self,
507 event: &Event,
508 turn_usage: &mut UsageReport,
509 custom_tool_ids: &mut Vec<String>,
510 ) {
511 if let Some(ref usage_meta) = event.llm_response.usage_metadata {
513 let report = UsageReport::from_usage_metadata(usage_meta);
514 turn_usage.accumulate(&report);
515 }
516
517 if event.llm_response.partial {
519 return;
520 }
521
522 if let Some(ref content) = event.llm_response.content {
524 for part in &content.parts {
525 match part {
526 Part::Text { text } => {
527 if text.is_empty() {
528 continue;
529 }
530 let output = RunnerOutput::TextContent { text: text.clone() };
531 let session_event = map_runner_output(output, self.seq.next());
532 self.emit_event(session_event).await;
533 }
534 Part::FunctionCall { name, args, id, .. } => {
535 let tool_use_id =
536 id.clone().unwrap_or_else(|| format!("tu_{}", uuid::Uuid::new_v4()));
537
538 let tool_kind = self.classify_tool(name);
540
541 let output = match tool_kind {
542 ToolKind::Custom => {
543 let ctu_id = format!("ctu_{}", uuid::Uuid::new_v4());
544 custom_tool_ids.push(ctu_id.clone());
545 RunnerOutput::CustomToolCall {
546 custom_tool_use_id: ctu_id,
547 name: name.clone(),
548 input: args.clone(),
549 }
550 }
551 ToolKind::Builtin => RunnerOutput::BuiltinToolCall {
552 tool_use_id,
553 name: name.clone(),
554 input: args.clone(),
555 },
556 ToolKind::Mcp => RunnerOutput::McpToolCall {
557 tool_use_id,
558 name: name.clone(),
559 input: args.clone(),
560 },
561 };
562
563 let session_event = map_runner_output(output.clone(), self.seq.next());
564 self.emit_event(session_event).await;
565
566 if requires_parking(&output)
568 && let Some(ctu_id) = custom_tool_use_id(&output)
569 {
570 let ctu_id_owned = ctu_id.to_string();
571 debug!(
572 session_id = %self.session_id,
573 custom_tool_use_id = %ctu_id_owned,
574 "parking for custom tool result"
575 );
576 match self.parking.park(&ctu_id_owned).await {
577 Ok(_result_blocks) => {
578 debug!(
579 session_id = %self.session_id,
580 custom_tool_use_id = %ctu_id_owned,
581 "custom tool result delivered"
582 );
583 }
584 Err(e) => {
585 warn!(
586 session_id = %self.session_id,
587 error = %e,
588 "custom tool park failed or timed out"
589 );
590 }
591 }
592 }
593 }
594 _ => {}
596 }
597 }
598 }
599 }
600
601 fn classify_tool(&self, name: &str) -> ToolKind {
603 const BUILTIN_TOOLS: &[&str] =
605 &["bash", "filesystem", "web_search", "web_fetch", "code_execution"];
606
607 if BUILTIN_TOOLS.contains(&name) {
608 ToolKind::Builtin
609 } else if name.starts_with("mcp_") || name.contains("::") {
610 ToolKind::Mcp
611 } else {
612 ToolKind::Custom
614 }
615 }
616
617 async fn emit_event(&mut self, event: SessionEvent) {
619 let run_state =
621 RunState { seq: self.seq.current(), pending_tool_ids: Vec::new(), status: self.status };
622 self.checkpoint.write().await.checkpoint(event.clone(), run_state);
623
624 let _ = self.event_tx.send(event);
626 }
627
628 async fn emit_idle(&mut self, stop_reason: Option<StopReason>, usage: Option<UsageReport>) {
630 self.status = SessionStatus::Idle;
631 let idle_event = SessionEvent::StatusIdle { seq: self.seq.next(), stop_reason, usage };
632 self.emit_event(idle_event).await;
633 }
634
635 fn check_interrupt(&self) -> bool {
639 self.cancel_token.is_cancelled()
640 }
641
642 async fn check_pause(&self) {
644 loop {
645 let is_paused = *self.pause_flag.lock().await;
646 if !is_paused {
647 break;
648 }
649 debug!(session_id = %self.session_id, "session loop paused, waiting for resume");
650 self.pause_notify.notified().await;
651 }
652 }
653}
654
655use crate::event_mapping::ToolKind;
659
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use adk_core::{FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream};
    use async_stream::stream;
    use async_trait::async_trait;

    /// Upper bound on any single wait for a broadcast event, so a regression fails the
    /// test instead of hanging it.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// [`Llm`] stub that answers every request with one complete text response.
    struct TestLlm {
        response_text: String,
    }

    impl TestLlm {
        fn new(text: &str) -> Self {
            Self { response_text: text.to_string() }
        }
    }

    #[async_trait]
    impl Llm for TestLlm {
        fn name(&self) -> &str {
            "test-llm"
        }

        async fn generate_content(
            &self,
            _request: LlmRequest,
            _stream: bool,
        ) -> adk_core::Result<LlmResponseStream> {
            let text = self.response_text.clone();
            let s = stream! {
                yield Ok(LlmResponse {
                    content: Some(Content::new("model").with_text(&text)),
                    partial: false,
                    turn_complete: true,
                    finish_reason: Some(FinishReason::Stop),
                    ..Default::default()
                });
            };
            Ok(Box::pin(s))
        }
    }

    /// Wraps `llm` in a minimal `LlmAgent`.
    fn build_test_agent(llm: impl Llm + 'static) -> Arc<dyn Agent> {
        let agent =
            adk_agent::LlmAgentBuilder::new("test-agent").model(Arc::new(llm)).build().unwrap();
        Arc::new(agent)
    }

    /// Fresh in-memory session service.
    fn test_session_service() -> Arc<dyn SessionService> {
        Arc::new(adk_session::InMemorySessionService::new())
    }

    /// Parking lot with the timeout used throughout these tests.
    fn test_parking() -> Arc<ToolParkingLot> {
        Arc::new(ToolParkingLot::new(Duration::from_secs(5)))
    }

    /// Builds a loop via `SessionLoop::new`, returning the inbound sender, a broadcast
    /// subscriber, and the loop itself.
    fn new_test_loop(
        session_id: &str,
        response_text: &str,
        parking: Arc<ToolParkingLot>,
        cancel: CancellationToken,
    ) -> (mpsc::Sender<UserEvent>, broadcast::Receiver<SessionEvent>, SessionLoop) {
        let (event_tx, event_rx) = mpsc::channel(64);
        let (broadcast_tx, broadcast_rx) = broadcast::channel(256);
        let session_loop = SessionLoop::new(
            session_id.to_string(),
            event_rx,
            broadcast_tx,
            parking,
            cancel,
            build_test_agent(TestLlm::new(response_text)),
            test_session_service(),
        );
        (event_tx, broadcast_rx, session_loop)
    }

    /// Default loop used by most tests, plus a handle to its cancellation token.
    fn create_test_loop()
    -> (mpsc::Sender<UserEvent>, broadcast::Receiver<SessionEvent>, CancellationToken, SessionLoop)
    {
        let cancel = CancellationToken::new();
        let (event_tx, broadcast_rx, session_loop) =
            new_test_loop("test_session", "Hello from the agent", test_parking(), cancel.clone());
        (event_tx, broadcast_rx, cancel, session_loop)
    }

    /// Receives the next broadcast event, panicking on timeout, lag, or closure.
    async fn next_event(rx: &mut broadcast::Receiver<SessionEvent>) -> SessionEvent {
        tokio::time::timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for session event")
            .expect("broadcast channel closed or lagged")
    }

    /// A user message consisting of a single text block.
    fn text_message(text: &str) -> UserEvent {
        UserEvent::Message { content: vec![ContentBlock::Text { text: text.to_string() }] }
    }

    #[tokio::test]
    async fn test_basic_message_flow() {
        let (event_tx, mut broadcast_rx, _cancel, session_loop) = create_test_loop();
        let handle = tokio::spawn(session_loop.run());

        event_tx.send(text_message("Hello")).await.unwrap();

        match next_event(&mut broadcast_rx).await {
            SessionEvent::StatusRunning { seq } => assert_eq!(seq, 0),
            other => panic!("expected StatusRunning, got: {other:?}"),
        }

        // Intermediate message/error events are tolerated; the turn must end idle.
        let mut got_idle = false;
        for _ in 0..10 {
            match tokio::time::timeout(EVENT_TIMEOUT, broadcast_rx.recv()).await {
                Ok(Ok(SessionEvent::Message { content, .. })) => {
                    assert!(!content.is_empty());
                }
                Ok(Ok(SessionEvent::StatusIdle { stop_reason, .. })) => {
                    assert!(matches!(stop_reason, Some(StopReason::EndTurn)));
                    got_idle = true;
                    break;
                }
                Ok(Ok(SessionEvent::Error { message, .. })) => {
                    debug!("got error event: {message}");
                }
                Ok(Ok(other)) => {
                    debug!("got other event: {other:?}");
                }
                Ok(Err(_)) | Err(_) => break,
            }
        }

        assert!(got_idle, "expected StatusIdle event");

        drop(event_tx);
        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_seq_monotonically_increases() {
        let (event_tx, mut broadcast_rx, _cancel, session_loop) = create_test_loop();
        let handle = tokio::spawn(session_loop.run());

        event_tx.send(text_message("First")).await.unwrap();

        let mut seqs = Vec::new();
        for _ in 0..10 {
            match tokio::time::timeout(EVENT_TIMEOUT, broadcast_rx.recv()).await {
                Ok(Ok(ev)) => {
                    let seq = match &ev {
                        SessionEvent::StatusRunning { seq } => *seq,
                        SessionEvent::Message { seq, .. } => *seq,
                        SessionEvent::StatusIdle { seq, .. } => *seq,
                        SessionEvent::ToolUse { seq, .. } => *seq,
                        SessionEvent::CustomToolUse { seq, .. } => *seq,
                        SessionEvent::McpToolUse { seq, .. } => *seq,
                        SessionEvent::Error { seq, .. } => *seq,
                    };
                    seqs.push(seq);
                    if matches!(ev, SessionEvent::StatusIdle { .. }) {
                        break;
                    }
                }
                _ => break,
            }
        }

        assert!(seqs.len() >= 2, "expected at least 2 events");
        for window in seqs.windows(2) {
            assert!(
                window[1] > window[0],
                "seq must be strictly increasing: {} should be > {}",
                window[1],
                window[0]
            );
        }

        drop(event_tx);
        handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn test_interrupt_stops_loop() {
        let (event_tx, mut broadcast_rx, cancel, session_loop) = create_test_loop();
        let handle = tokio::spawn(session_loop.run());

        // Let the loop reach its event wait before cancelling.
        tokio::time::sleep(Duration::from_millis(10)).await;
        cancel.cancel();

        match next_event(&mut broadcast_rx).await {
            SessionEvent::StatusIdle { stop_reason, .. } => {
                assert!(matches!(stop_reason, Some(StopReason::EndTurn)));
            }
            other => panic!("expected StatusIdle on interrupt, got: {other:?}"),
        }

        let result = handle.await.unwrap();
        assert!(result.is_ok());

        drop(event_tx);
    }

    #[tokio::test]
    async fn test_user_interrupt_event_stops_loop() {
        let (event_tx, mut broadcast_rx, _cancel, session_loop) = create_test_loop();
        let handle = tokio::spawn(session_loop.run());

        event_tx.send(UserEvent::Interrupt {}).await.unwrap();

        match next_event(&mut broadcast_rx).await {
            SessionEvent::StatusIdle { stop_reason, .. } => {
                assert!(matches!(stop_reason, Some(StopReason::EndTurn)));
            }
            other => panic!("expected StatusIdle, got: {other:?}"),
        }

        let result = handle.await.unwrap();
        assert!(result.is_ok());

        drop(event_tx);
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let (event_tx, event_rx) = mpsc::channel(64);
        let (broadcast_tx, mut broadcast_rx) = broadcast::channel(256);
        let cancel = CancellationToken::new();
        let pause_flag = Arc::new(Mutex::new(false));
        let pause_notify = Arc::new(Notify::new());
        let agent = build_test_agent(TestLlm::new("resumed response"));
        let checkpoint = Arc::new(RwLock::new(CheckpointManager::new("pause_test".to_string())));

        #[cfg(feature = "memory")]
        let session_loop = SessionLoop::with_pause_controls(
            "pause_test".to_string(),
            event_rx,
            broadcast_tx,
            test_parking(),
            cancel.clone(),
            Arc::clone(&pause_flag),
            Arc::clone(&pause_notify),
            checkpoint,
            agent,
            test_session_service(),
            None,
        );
        #[cfg(not(feature = "memory"))]
        let session_loop = SessionLoop::with_pause_controls(
            "pause_test".to_string(),
            event_rx,
            broadcast_tx,
            test_parking(),
            cancel.clone(),
            Arc::clone(&pause_flag),
            Arc::clone(&pause_notify),
            checkpoint,
            agent,
            test_session_service(),
        );

        // Pause before the loop task exists, so it is guaranteed to block in
        // `check_pause` regardless of how the runtime schedules the spawned task.
        *pause_flag.lock().await = true;
        let handle = tokio::spawn(session_loop.run());

        event_tx.send(text_message("While paused")).await.unwrap();

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(broadcast_rx.try_recv().is_err(), "paused loop must not start a turn");

        *pause_flag.lock().await = false;
        pause_notify.notify_one();

        match next_event(&mut broadcast_rx).await {
            SessionEvent::StatusRunning { .. } => {}
            other => panic!("expected StatusRunning after resume, got: {other:?}"),
        }

        drop(event_tx);
        handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn test_channel_close_stops_loop() {
        let (event_tx, _broadcast_rx, session_loop) =
            new_test_loop("close_test", "test", test_parking(), CancellationToken::new());

        let handle = tokio::spawn(session_loop.run());

        drop(event_tx);

        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_custom_tool_result_delivery() {
        let parking = test_parking();
        let (event_tx, _broadcast_rx, session_loop) = new_test_loop(
            "parking_test",
            "test",
            Arc::clone(&parking),
            CancellationToken::new(),
        );

        let handle = tokio::spawn(session_loop.run());

        // Park from outside the loop, then deliver through the loop's event channel.
        let parking_for_park = Arc::clone(&parking);
        let park_handle = tokio::spawn(async move { parking_for_park.park("ctu_test_001").await });

        tokio::time::sleep(Duration::from_millis(10)).await;

        event_tx
            .send(UserEvent::CustomToolResult {
                custom_tool_use_id: "ctu_test_001".to_string(),
                content: vec![ContentBlock::Text { text: "tool output".to_string() }],
            })
            .await
            .unwrap();

        let result = tokio::time::timeout(Duration::from_secs(2), park_handle)
            .await
            .expect("park timed out")
            .unwrap()
            .unwrap();

        assert_eq!(result.len(), 1);
        match &result[0] {
            ContentBlock::Text { text } => assert_eq!(text, "tool output"),
            _ => panic!("expected Text"),
        }

        drop(event_tx);
        handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn test_tool_classification() {
        let (event_tx, _broadcast_rx, session_loop) =
            new_test_loop("classify_test", "test", test_parking(), CancellationToken::new());

        assert!(matches!(session_loop.classify_tool("bash"), ToolKind::Builtin));
        assert!(matches!(session_loop.classify_tool("filesystem"), ToolKind::Builtin));
        assert!(matches!(session_loop.classify_tool("web_search"), ToolKind::Builtin));
        assert!(matches!(session_loop.classify_tool("web_fetch"), ToolKind::Builtin));
        assert!(matches!(session_loop.classify_tool("code_execution"), ToolKind::Builtin));

        assert!(matches!(session_loop.classify_tool("mcp_file_read"), ToolKind::Mcp));
        assert!(matches!(session_loop.classify_tool("server::tool"), ToolKind::Mcp));

        assert!(matches!(session_loop.classify_tool("get_weather"), ToolKind::Custom));
        assert!(matches!(session_loop.classify_tool("deploy"), ToolKind::Custom));

        drop(event_tx);
    }
}