Skip to main content

adk_managed/
session_loop.rs

1//! Supervised session loop for the managed agent runtime.
2//!
3//! The [`SessionLoop`] is the core execution engine. It runs as a
4//! `tokio::spawn`ed background task, dequeues [`UserEvent`]s from an
5//! mpsc channel, processes each turn, and broadcasts [`SessionEvent`]s
6//! to stream subscribers.
7//!
8//! # Architecture
9//!
10//! The loop composes:
11//! - [`SequenceCounter`] — assigns monotonically increasing `seq` to each event
12//! - [`ToolParkingLot`] — parks on `custom_tool_use` until client delivers a result
13//! - [`CheckpointManager`] — atomic checkpoint after each event
14//! - `tokio::broadcast` — fan-out to stream subscribers
15//! - [`Runner`] — drives the agent through the real LLM
16//! - [`SessionUsageTracker`] — tracks per-turn and cumulative token usage
17//!
18//! # Control Flow
19//!
20//! ```text
21//! Dequeue UserEvent → emit status.running → invoke Runner
22//!   → for each output event: classify, map, assign seq, checkpoint, broadcast
23//!   → if custom tool call: park, wait for result, resume
24//!   → track usage → emit status.idle → loop
25//! ```
26//!
27//! # Interrupt and Pause
28//!
29//! - **Interrupt**: A [`CancellationToken`] signals the loop to stop at the next
30//!   boundary. On interrupt, the loop emits `status.idle` and exits.
31//! - **Pause/Resume**: A pause flag + [`Notify`] allow the loop to park until
32//!   resumed.
33
34use std::sync::Arc;
35
36use futures::StreamExt;
37use tokio::sync::{Mutex, Notify, RwLock, broadcast, mpsc};
38use tokio_util::sync::CancellationToken;
39use tracing::{debug, info, warn};
40
41#[cfg(feature = "memory")]
42use adk_core::Memory;
43use adk_core::{Agent, Content, Event, Part};
44use adk_runner::Runner;
45use adk_session::service::SessionService;
46
47use crate::checkpoint::{CheckpointManager, RunState};
48use crate::event_mapping::{RunnerOutput, custom_tool_use_id, map_runner_output, requires_parking};
49use crate::parking::ToolParkingLot;
50use crate::sequence::SequenceCounter;
51use crate::types::{
52    ContentBlock, RuntimeError, SessionEvent, SessionStatus, StopReason, UserEvent,
53};
54use crate::usage::{SessionUsageTracker, UsageReport};
55
56/// Supervised session loop — one per active session.
57///
58/// Runs as a background `tokio::spawn`ed task. Receives user events via an
59/// mpsc channel, processes each turn through the real Runner, and broadcasts
60/// session events via a `tokio::broadcast` channel.
61///
62/// # Example
63///
64/// ```rust,ignore
65/// use std::sync::Arc;
66/// use std::time::Duration;
67/// use tokio::sync::{broadcast, mpsc, Mutex, Notify};
68/// use tokio_util::sync::CancellationToken;
69/// use adk_managed::session_loop::SessionLoop;
70/// use adk_managed::parking::ToolParkingLot;
71///
72/// let (event_tx, event_rx) = mpsc::channel(64);
73/// let (broadcast_tx, _) = broadcast::channel(256);
74/// let cancel = CancellationToken::new();
75/// let parking = Arc::new(ToolParkingLot::new(Duration::from_secs(300)));
76///
77/// let loop_handle = SessionLoop::new(
78///     "session_001".to_string(),
79///     event_rx,
80///     broadcast_tx,
81///     parking,
82///     cancel.clone(),
83///     agent,
84///     session_service,
85/// );
86///
87/// let handle = tokio::spawn(loop_handle.run());
88/// // Send events via event_tx...
89/// ```
90pub struct SessionLoop {
91    /// Session identifier.
92    session_id: String,
93    /// Input channel for user events.
94    event_rx: mpsc::Receiver<UserEvent>,
95    /// Broadcast channel for session events (fan-out to subscribers).
96    event_tx: broadcast::Sender<SessionEvent>,
97    /// Monotonic sequence counter.
98    seq: SequenceCounter,
99    /// Custom tool parking lot.
100    parking: Arc<ToolParkingLot>,
101    /// Checkpoint manager for durable state (shared with ActiveSession for replay).
102    checkpoint: Arc<RwLock<CheckpointManager>>,
103    /// Cancellation token for interrupt handling.
104    cancel_token: CancellationToken,
105    /// Pause flag — when true, the loop parks until resumed.
106    pause_flag: Arc<Mutex<bool>>,
107    /// Notify used to wake the loop after resume.
108    pause_notify: Arc<Notify>,
109    /// Current session status.
110    status: SessionStatus,
111    /// The agent driving this session.
112    agent: Arc<dyn Agent>,
113    /// Session persistence backend (needed by the Runner).
114    session_service: Arc<dyn SessionService>,
115    /// Optional memory service for cross-session RAG injection.
116    #[cfg(feature = "memory")]
117    memory: Option<Arc<dyn Memory>>,
118    /// Accumulated usage tracking across all turns.
119    usage_tracker: SessionUsageTracker,
120}
121
122impl SessionLoop {
123    /// Create a new session loop.
124    ///
125    /// # Arguments
126    ///
127    /// * `session_id` - The session this loop operates on.
128    /// * `event_rx` - Receiver for incoming user events.
129    /// * `event_tx` - Broadcast sender for outgoing session events.
130    /// * `parking` - Shared parking lot for custom tool calls.
131    /// * `cancel_token` - Token to signal interrupt/shutdown.
132    /// * `agent` - The built agent to drive through the Runner.
133    /// * `session_service` - Session persistence for the Runner.
134    pub fn new(
135        session_id: String,
136        event_rx: mpsc::Receiver<UserEvent>,
137        event_tx: broadcast::Sender<SessionEvent>,
138        parking: Arc<ToolParkingLot>,
139        cancel_token: CancellationToken,
140        agent: Arc<dyn Agent>,
141        session_service: Arc<dyn SessionService>,
142    ) -> Self {
143        let checkpoint = Arc::new(RwLock::new(CheckpointManager::new(session_id.clone())));
144        Self {
145            session_id,
146            event_rx,
147            event_tx,
148            seq: SequenceCounter::default(),
149            parking,
150            checkpoint,
151            cancel_token,
152            pause_flag: Arc::new(Mutex::new(false)),
153            pause_notify: Arc::new(Notify::new()),
154            status: SessionStatus::Queued,
155            agent,
156            session_service,
157            #[cfg(feature = "memory")]
158            memory: None,
159            usage_tracker: SessionUsageTracker::new(),
160        }
161    }
162
163    /// Create a session loop with custom pause controls (for external pause/resume).
164    ///
165    /// This allows the runtime to share the pause flag, notify, and checkpoint
166    /// with the session handle so that `pause()`, `resume()`, and `stream_events()`
167    /// (replay) work correctly against the same state the loop writes to.
168    #[cfg(feature = "memory")]
169    #[allow(clippy::too_many_arguments)]
170    pub fn with_pause_controls(
171        session_id: String,
172        event_rx: mpsc::Receiver<UserEvent>,
173        event_tx: broadcast::Sender<SessionEvent>,
174        parking: Arc<ToolParkingLot>,
175        cancel_token: CancellationToken,
176        pause_flag: Arc<Mutex<bool>>,
177        pause_notify: Arc<Notify>,
178        checkpoint: Arc<RwLock<CheckpointManager>>,
179        agent: Arc<dyn Agent>,
180        session_service: Arc<dyn SessionService>,
181        memory: Option<Arc<dyn Memory>>,
182    ) -> Self {
183        Self {
184            session_id,
185            event_rx,
186            event_tx,
187            seq: SequenceCounter::default(),
188            parking,
189            checkpoint,
190            cancel_token,
191            pause_flag,
192            pause_notify,
193            status: SessionStatus::Queued,
194            agent,
195            session_service,
196            memory,
197            usage_tracker: SessionUsageTracker::new(),
198        }
199    }
200
201    /// Create a session loop with custom pause controls (for external pause/resume).
202    ///
203    /// See the `memory`-enabled variant for full documentation.
204    #[cfg(not(feature = "memory"))]
205    #[allow(clippy::too_many_arguments)]
206    pub fn with_pause_controls(
207        session_id: String,
208        event_rx: mpsc::Receiver<UserEvent>,
209        event_tx: broadcast::Sender<SessionEvent>,
210        parking: Arc<ToolParkingLot>,
211        cancel_token: CancellationToken,
212        pause_flag: Arc<Mutex<bool>>,
213        pause_notify: Arc<Notify>,
214        checkpoint: Arc<RwLock<CheckpointManager>>,
215        agent: Arc<dyn Agent>,
216        session_service: Arc<dyn SessionService>,
217    ) -> Self {
218        Self {
219            session_id,
220            event_rx,
221            event_tx,
222            seq: SequenceCounter::default(),
223            parking,
224            checkpoint,
225            cancel_token,
226            pause_flag,
227            pause_notify,
228            status: SessionStatus::Queued,
229            agent,
230            session_service,
231            usage_tracker: SessionUsageTracker::new(),
232        }
233    }
234
235    /// Get a clone of the pause flag for external control.
236    pub fn pause_flag(&self) -> Arc<Mutex<bool>> {
237        Arc::clone(&self.pause_flag)
238    }
239
240    /// Get a clone of the pause notify for external control.
241    pub fn pause_notify(&self) -> Arc<Notify> {
242        Arc::clone(&self.pause_notify)
243    }
244
245    /// Run the session loop (consumes self).
246    ///
247    /// This is the main loop body, designed to be `tokio::spawn`ed. It runs
248    /// until the input channel is closed or the cancellation token is triggered.
249    ///
250    /// # Returns
251    ///
252    /// Returns `Ok(())` on graceful shutdown, or `Err(RuntimeError)` if an
253    /// unrecoverable error occurs.
254    pub async fn run(mut self) -> Result<(), RuntimeError> {
255        info!(session_id = %self.session_id, "session loop started");
256
257        loop {
258            // Check for interrupt before waiting for the next event.
259            if self.cancel_token.is_cancelled() {
260                debug!(session_id = %self.session_id, "interrupt detected, shutting down");
261                self.emit_idle(Some(StopReason::EndTurn), None).await;
262                break;
263            }
264
265            // Check for pause.
266            self.check_pause().await;
267
268            // Wait for next event or cancellation.
269            let event = tokio::select! {
270                biased;
271                _ = self.cancel_token.cancelled() => {
272                    debug!(session_id = %self.session_id, "interrupted while waiting for event");
273                    self.emit_idle(Some(StopReason::EndTurn), None).await;
274                    break;
275                }
276                ev = self.event_rx.recv() => {
277                    match ev {
278                        Some(event) => event,
279                        None => {
280                            debug!(session_id = %self.session_id, "event channel closed, shutting down");
281                            break;
282                        }
283                    }
284                }
285            };
286
287            // Dispatch based on event type.
288            match event {
289                UserEvent::Message { content } => {
290                    self.process_turn(content).await?;
291                }
292                UserEvent::Interrupt {} => {
293                    debug!(session_id = %self.session_id, "user.interrupt received");
294                    self.emit_idle(Some(StopReason::EndTurn), None).await;
295                    break;
296                }
297                UserEvent::CustomToolResult { custom_tool_use_id, content } => {
298                    debug!(
299                        session_id = %self.session_id,
300                        tool_use_id = %custom_tool_use_id,
301                        "delivering custom tool result"
302                    );
303                    if let Err(e) = self.parking.deliver(&custom_tool_use_id, content).await {
304                        warn!(
305                            session_id = %self.session_id,
306                            error = %e,
307                            "failed to deliver custom tool result"
308                        );
309                    }
310                }
311                UserEvent::ToolConfirmation { tool_use_id, result, deny_message } => {
312                    debug!(
313                        session_id = %self.session_id,
314                        tool_use_id = %tool_use_id,
315                        result = ?result,
316                        "tool confirmation received, delivering to parking lot"
317                    );
318                    // Tool confirmation decisions are delivered via the parking lot.
319                    // The session loop parks on tool_use_id when a confirmation is
320                    // required (emitted as RequiresAction). The client sends back
321                    // Allow/Deny which we convert to a ContentBlock result.
322                    let content = match result {
323                        crate::types::ConfirmationResult::Allow => {
324                            vec![ContentBlock::Text {
325                                text: serde_json::json!({
326                                    "confirmation": "approved",
327                                    "tool_use_id": tool_use_id
328                                })
329                                .to_string(),
330                            }]
331                        }
332                        crate::types::ConfirmationResult::Deny => {
333                            let message = deny_message
334                                .unwrap_or_else(|| "Tool execution denied by user".to_string());
335                            vec![ContentBlock::Text {
336                                text: serde_json::json!({
337                                    "confirmation": "denied",
338                                    "tool_use_id": tool_use_id,
339                                    "reason": message
340                                })
341                                .to_string(),
342                            }]
343                        }
344                    };
345                    if let Err(e) = self.parking.deliver(&tool_use_id, content).await {
346                        warn!(
347                            session_id = %self.session_id,
348                            error = %e,
349                            "failed to deliver tool confirmation"
350                        );
351                    }
352                }
353                UserEvent::ToolResult { tool_use_id, .. } => {
354                    debug!(
355                        session_id = %self.session_id,
356                        tool_use_id = %tool_use_id,
357                        "tool result received (self-hosted only, not yet wired)"
358                    );
359                }
360                UserEvent::DefineOutcome { criteria } => {
361                    debug!(
362                        session_id = %self.session_id,
363                        criteria = %criteria,
364                        "outcome criteria defined"
365                    );
366                    // Stored for future use — outcome evaluation is a later task.
367                }
368            }
369        }
370
371        info!(session_id = %self.session_id, "session loop exited");
372        Ok(())
373    }
374
375    /// Process a single turn: emit status.running, invoke Runner, emit events, emit status.idle.
376    async fn process_turn(&mut self, content: Vec<ContentBlock>) -> Result<(), RuntimeError> {
377        // 1. Emit status.running
378        self.status = SessionStatus::Running;
379        let running_event = SessionEvent::StatusRunning { seq: self.seq.next() };
380        self.emit_event(running_event).await;
381
382        // 2. Check interrupt before processing.
383        if self.check_interrupt() {
384            self.emit_idle(Some(StopReason::EndTurn), None).await;
385            return Ok(());
386        }
387
388        // 3. Build user Content from ContentBlocks
389        let user_content = self.build_user_content(&content);
390
391        // 4. Build and invoke the Runner
392        let runner = self.build_runner()?;
393
394        let event_stream = runner
395            .run_str("managed_user", &self.session_id, user_content)
396            .await
397            .map_err(|e| RuntimeError::internal(format!("runner invocation failed: {e}")))?;
398
399        // 5. Consume event stream, mapping each event to SessionEvents
400        let mut turn_usage = UsageReport::default();
401        let mut custom_tool_ids = Vec::new();
402
403        futures::pin_mut!(event_stream);
404
405        while let Some(event_result) = event_stream.next().await {
406            // Check interrupt between events
407            if self.check_interrupt() {
408                self.emit_idle(Some(StopReason::EndTurn), None).await;
409                return Ok(());
410            }
411
412            match event_result {
413                Ok(event) => {
414                    self.process_runner_event(&event, &mut turn_usage, &mut custom_tool_ids).await;
415                }
416                Err(e) => {
417                    warn!(
418                        session_id = %self.session_id,
419                        error = %e,
420                        "runner event stream error"
421                    );
422                    let error_event = SessionEvent::Error {
423                        code: "runner_error".to_string(),
424                        message: e.to_string(),
425                        seq: self.seq.next(),
426                    };
427                    self.emit_event(error_event).await;
428                }
429            }
430        }
431
432        // 6. Track usage
433        // 6. Track usage
434        let turn_usage_report = if !turn_usage.is_empty() {
435            self.usage_tracker.record_turn(turn_usage.clone());
436            Some(turn_usage)
437        } else {
438            None
439        };
440
441        // 7. Determine stop reason
442        let stop_reason = if custom_tool_ids.is_empty() {
443            Some(StopReason::EndTurn)
444        } else {
445            Some(StopReason::RequiresAction { event_ids: custom_tool_ids })
446        };
447
448        // 8. Emit status.idle with usage from this turn
449        self.emit_idle(stop_reason, turn_usage_report).await;
450
451        Ok(())
452    }
453
454    /// Build a Runner instance for this turn.
455    fn build_runner(&self) -> Result<Runner, RuntimeError> {
456        #[allow(unused_mut)]
457        let mut builder = Runner::builder()
458            .app_name("managed")
459            .agent(Arc::clone(&self.agent))
460            .session_service(Arc::clone(&self.session_service))
461            .cancellation_token(self.cancel_token.clone());
462
463        #[cfg(feature = "memory")]
464        if let Some(ref memory) = self.memory {
465            builder = builder.memory_service(Arc::clone(memory));
466        }
467
468        builder.build().map_err(|e| RuntimeError::internal(format!("failed to build runner: {e}")))
469    }
470
471    /// Convert managed ContentBlocks into an adk-core Content for the Runner.
472    fn build_user_content(&self, blocks: &[ContentBlock]) -> Content {
473        let mut parts = Vec::new();
474        for block in blocks {
475            match block {
476                ContentBlock::Text { text } => {
477                    parts.push(Part::Text { text: text.clone() });
478                }
479                ContentBlock::Image { source } => {
480                    // Convert image block to inline data or file reference
481                    if let Some(url) = source.get("url").and_then(|v| v.as_str()) {
482                        parts.push(Part::FileData {
483                            mime_type: source
484                                .get("media_type")
485                                .and_then(|v| v.as_str())
486                                .unwrap_or("image/png")
487                                .to_string(),
488                            file_uri: url.to_string(),
489                        });
490                    }
491                }
492                ContentBlock::File { file_id } => {
493                    parts.push(Part::FileData {
494                        mime_type: "application/octet-stream".to_string(),
495                        file_uri: file_id.clone(),
496                    });
497                }
498            }
499        }
500
501        Content { role: "user".to_string(), parts }
502    }
503
504    /// Process a single Runner event, mapping it to SessionEvents and tracking usage.
505    async fn process_runner_event(
506        &mut self,
507        event: &Event,
508        turn_usage: &mut UsageReport,
509        custom_tool_ids: &mut Vec<String>,
510    ) {
511        // Extract usage metadata from the LLM response
512        if let Some(ref usage_meta) = event.llm_response.usage_metadata {
513            let report = UsageReport::from_usage_metadata(usage_meta);
514            turn_usage.accumulate(&report);
515        }
516
517        // Skip partial streaming chunks — we only emit complete events
518        if event.llm_response.partial {
519            return;
520        }
521
522        // Extract content from the LLM response
523        if let Some(ref content) = event.llm_response.content {
524            for part in &content.parts {
525                match part {
526                    Part::Text { text } => {
527                        if text.is_empty() {
528                            continue;
529                        }
530                        let output = RunnerOutput::TextContent { text: text.clone() };
531                        let session_event = map_runner_output(output, self.seq.next());
532                        self.emit_event(session_event).await;
533                    }
534                    Part::FunctionCall { name, args, id, .. } => {
535                        let tool_use_id =
536                            id.clone().unwrap_or_else(|| format!("tu_{}", uuid::Uuid::new_v4()));
537
538                        // Classify the tool call
539                        let tool_kind = self.classify_tool(name);
540
541                        let output = match tool_kind {
542                            ToolKind::Custom => {
543                                let ctu_id = format!("ctu_{}", uuid::Uuid::new_v4());
544                                custom_tool_ids.push(ctu_id.clone());
545                                RunnerOutput::CustomToolCall {
546                                    custom_tool_use_id: ctu_id,
547                                    name: name.clone(),
548                                    input: args.clone(),
549                                }
550                            }
551                            ToolKind::Builtin => RunnerOutput::BuiltinToolCall {
552                                tool_use_id,
553                                name: name.clone(),
554                                input: args.clone(),
555                            },
556                            ToolKind::Mcp => RunnerOutput::McpToolCall {
557                                tool_use_id,
558                                name: name.clone(),
559                                input: args.clone(),
560                            },
561                        };
562
563                        let session_event = map_runner_output(output.clone(), self.seq.next());
564                        self.emit_event(session_event).await;
565
566                        // If custom tool, park and wait for client result
567                        if requires_parking(&output)
568                            && let Some(ctu_id) = custom_tool_use_id(&output)
569                        {
570                            let ctu_id_owned = ctu_id.to_string();
571                            debug!(
572                                session_id = %self.session_id,
573                                custom_tool_use_id = %ctu_id_owned,
574                                "parking for custom tool result"
575                            );
576                            match self.parking.park(&ctu_id_owned).await {
577                                Ok(_result_blocks) => {
578                                    debug!(
579                                        session_id = %self.session_id,
580                                        custom_tool_use_id = %ctu_id_owned,
581                                        "custom tool result delivered"
582                                    );
583                                }
584                                Err(e) => {
585                                    warn!(
586                                        session_id = %self.session_id,
587                                        error = %e,
588                                        "custom tool park failed or timed out"
589                                    );
590                                }
591                            }
592                        }
593                    }
594                    // Skip FunctionResponse, Thinking, and other part types
595                    _ => {}
596                }
597            }
598        }
599    }
600
601    /// Classify a tool call by name to determine which RunnerOutput variant to use.
602    fn classify_tool(&self, name: &str) -> ToolKind {
603        // Known built-in tools execute server-side
604        const BUILTIN_TOOLS: &[&str] =
605            &["bash", "filesystem", "web_search", "web_fetch", "code_execution"];
606
607        if BUILTIN_TOOLS.contains(&name) {
608            ToolKind::Builtin
609        } else if name.starts_with("mcp_") || name.contains("::") {
610            ToolKind::Mcp
611        } else {
612            // All other tools are custom (client-executed)
613            ToolKind::Custom
614        }
615    }
616
617    /// Emit a session event: assign to checkpoint and broadcast.
618    async fn emit_event(&mut self, event: SessionEvent) {
619        // Checkpoint atomically via the shared manager.
620        let run_state =
621            RunState { seq: self.seq.current(), pending_tool_ids: Vec::new(), status: self.status };
622        self.checkpoint.write().await.checkpoint(event.clone(), run_state);
623
624        // Broadcast to subscribers (ignore if no receivers).
625        let _ = self.event_tx.send(event);
626    }
627
628    /// Emit a `status.idle` event and update internal status.
629    async fn emit_idle(&mut self, stop_reason: Option<StopReason>, usage: Option<UsageReport>) {
630        self.status = SessionStatus::Idle;
631        let idle_event = SessionEvent::StatusIdle { seq: self.seq.next(), stop_reason, usage };
632        self.emit_event(idle_event).await;
633    }
634
635    /// Check if the cancellation token has been triggered.
636    ///
637    /// Returns `true` if interrupted.
638    fn check_interrupt(&self) -> bool {
639        self.cancel_token.is_cancelled()
640    }
641
642    /// Check and handle pause state. If paused, blocks until resumed.
643    async fn check_pause(&self) {
644        loop {
645            let is_paused = *self.pause_flag.lock().await;
646            if !is_paused {
647                break;
648            }
649            debug!(session_id = %self.session_id, "session loop paused, waiting for resume");
650            self.pause_notify.notified().await;
651        }
652    }
653}
654
// Tool classification used internally by the session loop.
//
// This is a private import of `crate::event_mapping::ToolKind`, not a re-export.
658use crate::event_mapping::ToolKind;
659
660#[cfg(test)]
661mod tests {
662    use std::time::Duration;
663
664    use super::*;
665    use adk_core::{FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream};
666    use async_stream::stream;
667    use async_trait::async_trait;
668
669    /// Mock LLM that returns a configurable response.
670    struct TestLlm {
671        response_text: String,
672    }
673
674    impl TestLlm {
675        fn new(text: &str) -> Self {
676            Self { response_text: text.to_string() }
677        }
678    }
679
680    #[async_trait]
681    impl Llm for TestLlm {
682        fn name(&self) -> &str {
683            "test-llm"
684        }
685
686        async fn generate_content(
687            &self,
688            _request: LlmRequest,
689            _stream: bool,
690        ) -> adk_core::Result<LlmResponseStream> {
691            let text = self.response_text.clone();
692            let s = stream! {
693                yield Ok(LlmResponse {
694                    content: Some(Content::new("model").with_text(&text)),
695                    partial: false,
696                    turn_complete: true,
697                    finish_reason: Some(FinishReason::Stop),
698                    ..Default::default()
699                });
700            };
701            Ok(Box::pin(s))
702        }
703    }
704
705    /// Build a test agent with the given LLM.
706    fn build_test_agent(llm: impl Llm + 'static) -> Arc<dyn Agent> {
707        let agent =
708            adk_agent::LlmAgentBuilder::new("test-agent").model(Arc::new(llm)).build().unwrap();
709        Arc::new(agent)
710    }
711
712    /// Helper to create a session loop with default test configuration.
713    fn create_test_loop()
714    -> (mpsc::Sender<UserEvent>, broadcast::Receiver<SessionEvent>, CancellationToken, SessionLoop)
715    {
716        let (event_tx, event_rx) = mpsc::channel(64);
717        let (broadcast_tx, broadcast_rx) = broadcast::channel(256);
718        let cancel = CancellationToken::new();
719        let parking = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
720        let agent = build_test_agent(TestLlm::new("Hello from the agent"));
721        let session_service: Arc<dyn SessionService> =
722            Arc::new(adk_session::InMemorySessionService::new());
723
724        let session_loop = SessionLoop::new(
725            "test_session".to_string(),
726            event_rx,
727            broadcast_tx,
728            parking,
729            cancel.clone(),
730            agent,
731            session_service,
732        );
733
734        (event_tx, broadcast_rx, cancel, session_loop)
735    }
736
    /// One `user.message` must open a turn with `status.running` (seq 0) and
    /// close it with `status.idle` carrying `EndTurn`. Dropping the sender
    /// afterwards must end the loop with `Ok(())`.
    #[tokio::test]
    async fn test_basic_message_flow() {
        let (event_tx, mut broadcast_rx, _cancel, session_loop) = create_test_loop();

        let handle = tokio::spawn(session_loop.run());

        // Send a message.
        event_tx
            .send(UserEvent::Message {
                content: vec![ContentBlock::Text { text: "Hello".to_string() }],
            })
            .await
            .unwrap();

        // Expect: status.running, then agent response events, then status.idle.
        // The first event of a fresh loop carries seq 0.
        let ev1 = broadcast_rx.recv().await.unwrap();
        match ev1 {
            SessionEvent::StatusRunning { seq } => assert_eq!(seq, 0),
            other => panic!("expected StatusRunning, got: {other:?}"),
        }

        // Collect remaining events until we get StatusIdle. The loop is bounded
        // (10 events, 5 s timeout each) so a stuck session loop fails the test
        // instead of hanging it.
        let mut got_message = false;
        let mut got_idle = false;
        for _ in 0..10 {
            match tokio::time::timeout(Duration::from_secs(5), broadcast_rx.recv()).await {
                Ok(Ok(SessionEvent::Message { content, .. })) => {
                    assert!(!content.is_empty());
                    got_message = true;
                }
                Ok(Ok(SessionEvent::StatusIdle { stop_reason, .. })) => {
                    assert!(matches!(stop_reason, Some(StopReason::EndTurn)));
                    got_idle = true;
                    break;
                }
                Ok(Ok(SessionEvent::Error { message, .. })) => {
                    // In test environments without a real model, errors are acceptable
                    debug!("got error event: {message}");
                }
                Ok(Ok(other)) => {
                    debug!("got other event: {other:?}");
                }
                // Broadcast channel closed or this receiver lagged.
                Ok(Err(_)) => break,
                // Timed out waiting for the next event.
                Err(_) => break,
            }
        }

        // We must at least get status.idle (the turn completes regardless)
        assert!(got_idle, "expected StatusIdle event");

        // Close the channel to stop the loop.
        drop(event_tx);
        let result = handle.await.unwrap();
        assert!(result.is_ok());

        // Note: got_message depends on whether the Runner successfully invoked
        // the mock LLM. In unit tests, InMemorySessionService may not have the
        // session pre-created so the Runner creates one — either way the flow
        // should complete without panics.
        let _ = got_message;
    }
798
799    #[tokio::test]
800    async fn test_seq_monotonically_increases() {
801        let (event_tx, mut broadcast_rx, _cancel, session_loop) = create_test_loop();
802
803        let handle = tokio::spawn(session_loop.run());
804
805        // Send a message
806        event_tx
807            .send(UserEvent::Message {
808                content: vec![ContentBlock::Text { text: "First".to_string() }],
809            })
810            .await
811            .unwrap();
812
813        // Collect events from the turn
814        let mut seqs = Vec::new();
815        for _ in 0..10 {
816            match tokio::time::timeout(Duration::from_secs(5), broadcast_rx.recv()).await {
817                Ok(Ok(ev)) => {
818                    let seq = match &ev {
819                        SessionEvent::StatusRunning { seq } => *seq,
820                        SessionEvent::Message { seq, .. } => *seq,
821                        SessionEvent::StatusIdle { seq, .. } => *seq,
822                        SessionEvent::ToolUse { seq, .. } => *seq,
823                        SessionEvent::CustomToolUse { seq, .. } => *seq,
824                        SessionEvent::McpToolUse { seq, .. } => *seq,
825                        SessionEvent::Error { seq, .. } => *seq,
826                    };
827                    seqs.push(seq);
828                    if matches!(ev, SessionEvent::StatusIdle { .. }) {
829                        break;
830                    }
831                }
832                _ => break,
833            }
834        }
835
836        // Verify strict monotonic increase.
837        assert!(seqs.len() >= 2, "expected at least 2 events");
838        for window in seqs.windows(2) {
839            assert!(
840                window[1] > window[0],
841                "seq must be strictly increasing: {} should be > {}",
842                window[1],
843                window[0]
844            );
845        }
846
847        drop(event_tx);
848        handle.await.unwrap().unwrap();
849    }
850
851    #[tokio::test]
852    async fn test_interrupt_stops_loop() {
853        let (event_tx, mut broadcast_rx, cancel, session_loop) = create_test_loop();
854
855        let handle = tokio::spawn(session_loop.run());
856
857        // Give the loop a moment to start waiting.
858        tokio::time::sleep(Duration::from_millis(10)).await;
859
860        // Trigger interrupt.
861        cancel.cancel();
862
863        // Should emit status.idle on interrupt.
864        let ev = broadcast_rx.recv().await.unwrap();
865        match ev {
866            SessionEvent::StatusIdle { stop_reason, .. } => {
867                assert!(matches!(stop_reason, Some(StopReason::EndTurn)));
868            }
869            other => panic!("expected StatusIdle on interrupt, got: {other:?}"),
870        }
871
872        // The loop should exit cleanly.
873        let result = handle.await.unwrap();
874        assert!(result.is_ok());
875
876        drop(event_tx);
877    }
878
879    #[tokio::test]
880    async fn test_user_interrupt_event_stops_loop() {
881        let (event_tx, mut broadcast_rx, _cancel, session_loop) = create_test_loop();
882
883        let handle = tokio::spawn(session_loop.run());
884
885        // Send an interrupt event.
886        event_tx.send(UserEvent::Interrupt {}).await.unwrap();
887
888        // Should emit status.idle.
889        let ev = broadcast_rx.recv().await.unwrap();
890        match ev {
891            SessionEvent::StatusIdle { stop_reason, .. } => {
892                assert!(matches!(stop_reason, Some(StopReason::EndTurn)));
893            }
894            other => panic!("expected StatusIdle, got: {other:?}"),
895        }
896
897        let result = handle.await.unwrap();
898        assert!(result.is_ok());
899
900        drop(event_tx);
901    }
902
903    #[tokio::test]
904    async fn test_pause_and_resume() {
905        let (event_tx, event_rx) = mpsc::channel(64);
906        let (broadcast_tx, mut broadcast_rx) = broadcast::channel(256);
907        let cancel = CancellationToken::new();
908        let parking = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
909        let pause_flag = Arc::new(Mutex::new(false));
910        let pause_notify = Arc::new(Notify::new());
911        let agent = build_test_agent(TestLlm::new("resumed response"));
912        let session_service: Arc<dyn SessionService> =
913            Arc::new(adk_session::InMemorySessionService::new());
914
915        #[cfg(feature = "memory")]
916        let session_loop = SessionLoop::with_pause_controls(
917            "pause_test".to_string(),
918            event_rx,
919            broadcast_tx,
920            parking,
921            cancel.clone(),
922            Arc::clone(&pause_flag),
923            Arc::clone(&pause_notify),
924            Arc::new(RwLock::new(CheckpointManager::new("pause_test".to_string()))),
925            agent,
926            session_service,
927            None,
928        );
929        #[cfg(not(feature = "memory"))]
930        let session_loop = SessionLoop::with_pause_controls(
931            "pause_test".to_string(),
932            event_rx,
933            broadcast_tx,
934            parking,
935            cancel.clone(),
936            Arc::clone(&pause_flag),
937            Arc::clone(&pause_notify),
938            Arc::new(RwLock::new(CheckpointManager::new("pause_test".to_string()))),
939            agent,
940            session_service,
941        );
942
943        let handle = tokio::spawn(session_loop.run());
944
945        // Pause the loop.
946        *pause_flag.lock().await = true;
947
948        // Send a message — should not be processed while paused.
949        event_tx
950            .send(UserEvent::Message {
951                content: vec![ContentBlock::Text { text: "While paused".to_string() }],
952            })
953            .await
954            .unwrap();
955
956        // Give the loop time to potentially process (it shouldn't).
957        tokio::time::sleep(Duration::from_millis(50)).await;
958
959        // Verify nothing was broadcast yet (try_recv should fail).
960        assert!(broadcast_rx.try_recv().is_err());
961
962        // Resume.
963        *pause_flag.lock().await = false;
964        pause_notify.notify_one();
965
966        // Now the message should be processed.
967        let ev1 = tokio::time::timeout(Duration::from_secs(2), broadcast_rx.recv())
968            .await
969            .expect("timed out waiting for event after resume")
970            .unwrap();
971
972        match ev1 {
973            SessionEvent::StatusRunning { .. } => {}
974            other => panic!("expected StatusRunning after resume, got: {other:?}"),
975        }
976
977        // Clean up.
978        drop(event_tx);
979        handle.await.unwrap().unwrap();
980    }
981
982    #[tokio::test]
983    async fn test_channel_close_stops_loop() {
984        let (event_tx, event_rx) = mpsc::channel(64);
985        let (broadcast_tx, _broadcast_rx) = broadcast::channel(256);
986        let cancel = CancellationToken::new();
987        let parking = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
988        let agent = build_test_agent(TestLlm::new("test"));
989        let session_service: Arc<dyn SessionService> =
990            Arc::new(adk_session::InMemorySessionService::new());
991
992        let session_loop = SessionLoop::new(
993            "close_test".to_string(),
994            event_rx,
995            broadcast_tx,
996            parking,
997            cancel,
998            agent,
999            session_service,
1000        );
1001
1002        let handle = tokio::spawn(session_loop.run());
1003
1004        // Drop the sender — closes the channel.
1005        drop(event_tx);
1006
1007        // Loop should exit cleanly.
1008        let result = handle.await.unwrap();
1009        assert!(result.is_ok());
1010    }
1011
1012    #[tokio::test]
1013    async fn test_custom_tool_result_delivery() {
1014        let (event_tx, event_rx) = mpsc::channel(64);
1015        let (broadcast_tx, _broadcast_rx) = broadcast::channel(256);
1016        let cancel = CancellationToken::new();
1017        let parking = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
1018        let parking_clone = Arc::clone(&parking);
1019        let agent = build_test_agent(TestLlm::new("test"));
1020        let session_service: Arc<dyn SessionService> =
1021            Arc::new(adk_session::InMemorySessionService::new());
1022
1023        let session_loop = SessionLoop::new(
1024            "parking_test".to_string(),
1025            event_rx,
1026            broadcast_tx,
1027            parking_clone,
1028            cancel,
1029            agent,
1030            session_service,
1031        );
1032
1033        let handle = tokio::spawn(session_loop.run());
1034
1035        // Park a tool call from another task.
1036        let parking_for_park = Arc::clone(&parking);
1037        let park_handle = tokio::spawn(async move { parking_for_park.park("ctu_test_001").await });
1038
1039        // Give the park a moment to register.
1040        tokio::time::sleep(Duration::from_millis(10)).await;
1041
1042        // Send custom tool result via the session loop.
1043        event_tx
1044            .send(UserEvent::CustomToolResult {
1045                custom_tool_use_id: "ctu_test_001".to_string(),
1046                content: vec![ContentBlock::Text { text: "tool output".to_string() }],
1047            })
1048            .await
1049            .unwrap();
1050
1051        // The parked task should receive the result.
1052        let result = tokio::time::timeout(Duration::from_secs(2), park_handle)
1053            .await
1054            .expect("park timed out")
1055            .unwrap()
1056            .unwrap();
1057
1058        assert_eq!(result.len(), 1);
1059        match &result[0] {
1060            ContentBlock::Text { text } => assert_eq!(text, "tool output"),
1061            _ => panic!("expected Text"),
1062        }
1063
1064        // Clean up.
1065        drop(event_tx);
1066        handle.await.unwrap().unwrap();
1067    }
1068
1069    #[tokio::test]
1070    async fn test_tool_classification() {
1071        let (event_tx, event_rx) = mpsc::channel(64);
1072        let (broadcast_tx, _) = broadcast::channel(256);
1073        let cancel = CancellationToken::new();
1074        let parking = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
1075        let agent = build_test_agent(TestLlm::new("test"));
1076        let session_service: Arc<dyn SessionService> =
1077            Arc::new(adk_session::InMemorySessionService::new());
1078
1079        let session_loop = SessionLoop::new(
1080            "classify_test".to_string(),
1081            event_rx,
1082            broadcast_tx,
1083            parking,
1084            cancel,
1085            agent,
1086            session_service,
1087        );
1088
1089        // Test builtin tools
1090        assert!(matches!(session_loop.classify_tool("bash"), ToolKind::Builtin));
1091        assert!(matches!(session_loop.classify_tool("filesystem"), ToolKind::Builtin));
1092        assert!(matches!(session_loop.classify_tool("web_search"), ToolKind::Builtin));
1093        assert!(matches!(session_loop.classify_tool("web_fetch"), ToolKind::Builtin));
1094        assert!(matches!(session_loop.classify_tool("code_execution"), ToolKind::Builtin));
1095
1096        // Test MCP tools
1097        assert!(matches!(session_loop.classify_tool("mcp_file_read"), ToolKind::Mcp));
1098        assert!(matches!(session_loop.classify_tool("server::tool"), ToolKind::Mcp));
1099
1100        // Test custom tools
1101        assert!(matches!(session_loop.classify_tool("get_weather"), ToolKind::Custom));
1102        assert!(matches!(session_loop.classify_tool("deploy"), ToolKind::Custom));
1103
1104        drop(event_tx);
1105    }
1106}