Skip to main content

adk_managed/
testing.rs

1//! Testing utilities for the managed agent runtime.
2//!
3//! This module provides deterministic test doubles for the managed agent runtime
4//! pipeline. These are **not mocks** — they implement the real traits and exercise
5//! the full runtime pipeline (parking, checkpoints, replay, event mapping). Only
6//! the LLM provider API call is replaced with pre-scripted deterministic responses.
7//!
8//! # Architecture
9//!
10//! ```text
11//! ScriptedLlm (deterministic responses)
12//!   │
13//!   ▼
14//! Full runtime pipeline (SessionLoop, CheckpointManager, ToolParkingLot, etc.)
15//!   │
16//!   ▼
17//! SessionEvent stream (byte-identical assertions possible)
18//! ```
19//!
20//! # Usage
21//!
22//! ```rust,ignore
23//! use adk_managed::testing::{ScriptedLlm, ScriptedTurn, ScriptedToolCall};
24//! use serde_json::json;
25//!
26//! let turns = vec![
27//!     ScriptedTurn {
28//!         text: Some("Hello! How can I help you?".to_string()),
29//!         tool_calls: vec![],
30//!     },
31//!     ScriptedTurn {
32//!         text: None,
33//!         tool_calls: vec![ScriptedToolCall {
34//!             name: "web_search".to_string(),
35//!             input: json!({"query": "rust async"}),
36//!             id: Some("tc_001".to_string()),
37//!         }],
38//!     },
39//! ];
40//!
41//! let llm = ScriptedLlm::new("scripted-model", turns);
42//! // Use llm in place of any Arc<dyn Llm> in the runtime pipeline
43//! ```
44
45use adk_core::{
46    Llm, LlmRequest, LlmResponse, LlmResponseStream, Result as AdkResult, types::Content,
47};
48use async_stream::stream;
49use async_trait::async_trait;
50use serde::{Deserialize, Serialize};
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53/// A pre-scripted turn that the [`ScriptedLlm`] will return.
54///
55/// Each turn represents one complete LLM response. It can contain text content,
56/// tool calls, or both — mirroring real LLM behavior where a response may
57/// include reasoning text followed by tool invocations.
58///
59/// # Wire Format
60///
61/// Serializes to/from JSON for use in fixture files:
62///
63/// ```json
64/// {
65///   "text": "Let me search for that.",
66///   "tool_calls": [
67///     { "name": "web_search", "input": {"query": "rust"}, "id": "tc_001" }
68///   ]
69/// }
70/// ```
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ScriptedTurn {
73    /// Text response content (if any).
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub text: Option<String>,
76    /// Tool calls to make (if any).
77    #[serde(default, skip_serializing_if = "Vec::is_empty")]
78    pub tool_calls: Vec<ScriptedToolCall>,
79}
80
81/// A scripted tool call within a [`ScriptedTurn`].
82///
83/// Represents a function call that the LLM "decides" to make.
84/// The `id` field maps to the tool_use_id / function call ID used
85/// for round-trip correlation with tool results.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct ScriptedToolCall {
88    /// Name of the tool to call.
89    pub name: String,
90    /// Input arguments as JSON.
91    pub input: serde_json::Value,
92    /// Optional tool call ID. If not provided, a deterministic ID is generated.
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub id: Option<String>,
95}
96
97/// A deterministic LLM double with pre-scripted responses.
98///
99/// `ScriptedLlm` implements the real [`Llm`] trait and exercises the full
100/// runtime pipeline. Only the provider API call is replaced — everything else
101/// (session loop, checkpoints, parking, event mapping) runs exactly as it would
102/// with a real provider.
103///
104/// This is explicitly **NOT a mock**. It:
105/// - Implements the full `Llm` trait contract
106/// - Returns complete `LlmResponse` objects with proper `Content` and `Part` types
107/// - Supports tool calls (function calls) in responses
108/// - Advances through turns deterministically (FIFO order)
109/// - Is thread-safe (`Send + Sync` via `AtomicUsize`)
110///
111/// # Panics
112///
113/// If more turns are requested than were scripted, the LLM returns an empty
114/// response with `turn_complete = true` rather than panicking.
115pub struct ScriptedLlm {
116    /// Model name identifier.
117    name: String,
118    /// Pre-scripted turns in FIFO order.
119    turns: Vec<ScriptedTurn>,
120    /// Current turn index (atomic for thread safety).
121    current_turn: AtomicUsize,
122}
123
124impl ScriptedLlm {
125    /// Create a new `ScriptedLlm` with the given name and pre-scripted turns.
126    ///
127    /// Turns are consumed in FIFO order — each call to `generate_content`
128    /// advances to the next turn.
129    pub fn new(name: impl Into<String>, turns: Vec<ScriptedTurn>) -> Self {
130        Self { name: name.into(), turns, current_turn: AtomicUsize::new(0) }
131    }
132
133    /// Returns the number of turns that have been consumed so far.
134    pub fn turns_consumed(&self) -> usize {
135        self.current_turn.load(Ordering::Relaxed)
136    }
137
138    /// Returns the total number of scripted turns.
139    pub fn total_turns(&self) -> usize {
140        self.turns.len()
141    }
142
143    /// Build an `LlmResponse` from a `ScriptedTurn`.
144    fn build_response(turn: &ScriptedTurn, turn_index: usize) -> LlmResponse {
145        use adk_core::FinishReason;
146        use adk_core::types::Part;
147
148        let mut parts = Vec::new();
149
150        // Add text part if present.
151        if let Some(text) = &turn.text {
152            parts.push(Part::Text { text: text.clone() });
153        }
154
155        // Add function call parts.
156        for (i, tool_call) in turn.tool_calls.iter().enumerate() {
157            let id =
158                tool_call.id.clone().unwrap_or_else(|| format!("scripted_tc_{turn_index}_{i}"));
159            parts.push(Part::FunctionCall {
160                name: tool_call.name.clone(),
161                args: tool_call.input.clone(),
162                id: Some(id),
163                thought_signature: None,
164            });
165        }
166
167        let content = if parts.is_empty() {
168            None
169        } else {
170            Some(Content { role: "model".to_string(), parts })
171        };
172
173        LlmResponse {
174            content,
175            usage_metadata: None,
176            finish_reason: Some(FinishReason::Stop),
177            citation_metadata: None,
178            partial: false,
179            turn_complete: true,
180            interrupted: false,
181            error_code: None,
182            error_message: None,
183            provider_metadata: None,
184            interaction_id: None,
185        }
186    }
187}
188
189impl std::fmt::Debug for ScriptedLlm {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        f.debug_struct("ScriptedLlm")
192            .field("name", &self.name)
193            .field("turns", &self.turns.len())
194            .field("current_turn", &self.current_turn.load(Ordering::Relaxed))
195            .finish()
196    }
197}
198
199#[async_trait]
200impl Llm for ScriptedLlm {
201    fn name(&self) -> &str {
202        &self.name
203    }
204
205    async fn generate_content(
206        &self,
207        _request: LlmRequest,
208        _stream: bool,
209    ) -> AdkResult<LlmResponseStream> {
210        let turn_index = self.current_turn.fetch_add(1, Ordering::Relaxed);
211
212        let response = if turn_index < self.turns.len() {
213            Self::build_response(&self.turns[turn_index], turn_index)
214        } else {
215            // Beyond scripted turns — return empty complete response.
216            LlmResponse {
217                content: Some(Content {
218                    role: "model".to_string(),
219                    parts: vec![adk_core::types::Part::Text {
220                        text: "[ScriptedLlm: no more scripted turns]".to_string(),
221                    }],
222                }),
223                usage_metadata: None,
224                finish_reason: Some(adk_core::FinishReason::Stop),
225                citation_metadata: None,
226                partial: false,
227                turn_complete: true,
228                interrupted: false,
229                error_code: None,
230                error_message: None,
231                provider_metadata: None,
232                interaction_id: None,
233            }
234        };
235
236        let response_stream = stream! {
237            yield Ok(response);
238        };
239
240        Ok(Box::pin(response_stream))
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde_json::json;

    #[tokio::test]
    async fn test_scripted_llm_returns_text() {
        let turns =
            vec![ScriptedTurn { text: Some("Hello, world!".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("test-model", turns);

        assert_eq!(llm.name(), "test-model");

        let request = LlmRequest::new("test-model", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();

        let response = stream.next().await.unwrap().unwrap();
        assert!(response.turn_complete);
        assert!(!response.partial);

        // Each call yields exactly one complete response.
        assert!(stream.next().await.is_none());

        let content = response.content.unwrap();
        assert_eq!(content.role, "model");
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::Text { text } => {
                assert_eq!(text, "Hello, world!");
            }
            other => panic!("expected Text part, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_scripted_llm_returns_tool_calls() {
        let turns = vec![ScriptedTurn {
            text: None,
            tool_calls: vec![ScriptedToolCall {
                name: "web_search".to_string(),
                input: json!({"query": "rust async"}),
                id: Some("tc_001".to_string()),
            }],
        }];
        let llm = ScriptedLlm::new("tool-model", turns);

        let request = LlmRequest::new("tool-model", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();

        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { name, args, id, .. } => {
                assert_eq!(name, "web_search");
                assert_eq!(args, &json!({"query": "rust async"}));
                assert_eq!(id, &Some("tc_001".to_string()));
            }
            other => panic!("expected FunctionCall part, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_scripted_llm_empty_turn_has_no_content() {
        let turns = vec![ScriptedTurn { text: None, tool_calls: vec![] }];
        let llm = ScriptedLlm::new("empty", turns);

        let request = LlmRequest::new("empty", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();

        assert!(response.turn_complete);
        assert!(response.content.is_none());
    }

    #[tokio::test]
    async fn test_scripted_llm_advances_through_turns() {
        let turns = vec![
            ScriptedTurn { text: Some("First".to_string()), tool_calls: vec![] },
            ScriptedTurn { text: Some("Second".to_string()), tool_calls: vec![] },
            ScriptedTurn { text: Some("Third".to_string()), tool_calls: vec![] },
        ];
        let llm = ScriptedLlm::new("multi-turn", turns);

        for (i, expected) in ["First", "Second", "Third"].iter().enumerate() {
            let request = LlmRequest::new("multi-turn", vec![]);
            let mut stream = llm.generate_content(request, false).await.unwrap();
            let response = stream.next().await.unwrap().unwrap();
            let content = response.content.unwrap();
            match &content.parts[0] {
                adk_core::types::Part::Text { text } => {
                    assert_eq!(text, *expected);
                }
                other => panic!("turn {i}: expected Text, got: {other:?}"),
            }
        }

        assert_eq!(llm.turns_consumed(), 3);
    }

    #[tokio::test]
    async fn test_scripted_llm_handles_exhaustion() {
        let turns = vec![ScriptedTurn { text: Some("Only one".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("exhausted", turns);

        // Consume the only turn.
        let request = LlmRequest::new("exhausted", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let _ = stream.next().await.unwrap().unwrap();

        // Every call past the end of the script returns the same fallback.
        for _ in 0..2 {
            let request = LlmRequest::new("exhausted", vec![]);
            let mut stream = llm.generate_content(request, false).await.unwrap();
            let response = stream.next().await.unwrap().unwrap();
            assert!(response.turn_complete);
            let content = response.content.unwrap();
            match &content.parts[0] {
                adk_core::types::Part::Text { text } => {
                    assert!(text.contains("no more scripted turns"));
                }
                other => panic!("expected fallback Text, got: {other:?}"),
            }
        }

        // Calls past the end are still counted, so over-consumption is observable.
        assert_eq!(llm.total_turns(), 1);
        assert_eq!(llm.turns_consumed(), 3);
    }

    #[tokio::test]
    async fn test_scripted_llm_mixed_text_and_tool_calls() {
        let turns = vec![ScriptedTurn {
            text: Some("Let me search for that.".to_string()),
            tool_calls: vec![ScriptedToolCall {
                name: "web_search".to_string(),
                input: json!({"query": "ADK Rust"}),
                id: Some("tc_mixed".to_string()),
            }],
        }];
        let llm = ScriptedLlm::new("mixed", turns);

        let request = LlmRequest::new("mixed", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();

        assert_eq!(content.parts.len(), 2);
        assert!(matches!(&content.parts[0], adk_core::types::Part::Text { .. }));
        assert!(matches!(&content.parts[1], adk_core::types::Part::FunctionCall { .. }));
    }

    #[test]
    fn test_scripted_turn_serialization_roundtrip() {
        let turn = ScriptedTurn {
            text: Some("Hello".to_string()),
            tool_calls: vec![ScriptedToolCall {
                name: "search".to_string(),
                input: json!({"q": "test"}),
                id: Some("id_1".to_string()),
            }],
        };

        let json = serde_json::to_string(&turn).unwrap();
        let deserialized: ScriptedTurn = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.text, turn.text);
        assert_eq!(deserialized.tool_calls.len(), 1);
        assert_eq!(deserialized.tool_calls[0].name, "search");
        assert_eq!(deserialized.tool_calls[0].input, json!({"q": "test"}));
        assert_eq!(deserialized.tool_calls[0].id, Some("id_1".to_string()));
    }

    #[test]
    fn test_scripted_turn_deserializes_minimal_fixture() {
        // Fixture files may omit `text`, `tool_calls`, and `id`.
        let turn: ScriptedTurn = serde_json::from_value(json!({
            "tool_calls": [{ "name": "lookup", "input": {"k": 1} }]
        }))
        .unwrap();
        assert_eq!(turn.text, None);
        assert_eq!(turn.tool_calls.len(), 1);
        assert_eq!(turn.tool_calls[0].name, "lookup");
        assert_eq!(turn.tool_calls[0].input, json!({"k": 1}));
        assert_eq!(turn.tool_calls[0].id, None);

        let text_only: ScriptedTurn = serde_json::from_value(json!({ "text": "hi" })).unwrap();
        assert_eq!(text_only.text.as_deref(), Some("hi"));
        assert!(text_only.tool_calls.is_empty());
    }

    #[test]
    fn test_scripted_turn_omits_empty_fields_when_serializing() {
        let turn = ScriptedTurn { text: None, tool_calls: vec![] };
        assert_eq!(serde_json::to_value(&turn).unwrap(), json!({}));
    }

    #[tokio::test]
    async fn test_auto_generated_tool_call_ids() {
        let turns = vec![ScriptedTurn {
            text: None,
            tool_calls: vec![
                ScriptedToolCall {
                    name: "tool_a".to_string(),
                    input: json!({}),
                    id: None, // auto-generate
                },
                ScriptedToolCall {
                    name: "tool_b".to_string(),
                    input: json!({}),
                    id: None, // auto-generate
                },
            ],
        }];
        let llm = ScriptedLlm::new("auto-id", turns);

        let request = LlmRequest::new("auto-id", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();

        // Both should have deterministic IDs based on turn and index.
        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_0_0".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
        match &content.parts[1] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_0_1".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_auto_generated_tool_call_ids_use_turn_index() {
        let turns = vec![
            ScriptedTurn { text: Some("warm-up".to_string()), tool_calls: vec![] },
            ScriptedTurn {
                text: None,
                tool_calls: vec![ScriptedToolCall {
                    name: "tool_a".to_string(),
                    input: json!({}),
                    id: None, // auto-generate
                }],
            },
        ];
        let llm = ScriptedLlm::new("auto-id-later", turns);

        // Consume the warm-up turn so the tool call comes from turn index 1.
        let request = LlmRequest::new("auto-id-later", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let _ = stream.next().await.unwrap().unwrap();

        let request = LlmRequest::new("auto-id-later", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let content = stream.next().await.unwrap().unwrap().content.unwrap();
        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_1_0".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
    }
}