1use adk_core::{
46 Llm, LlmRequest, LlmResponse, LlmResponseStream, Result as AdkResult, types::Content,
47};
48use async_stream::stream;
49use async_trait::async_trait;
50use serde::{Deserialize, Serialize};
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ScriptedTurn {
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub text: Option<String>,
76 #[serde(default, skip_serializing_if = "Vec::is_empty")]
78 pub tool_calls: Vec<ScriptedToolCall>,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct ScriptedToolCall {
88 pub name: String,
90 pub input: serde_json::Value,
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub id: Option<String>,
95}
96
97pub struct ScriptedLlm {
116 name: String,
118 turns: Vec<ScriptedTurn>,
120 current_turn: AtomicUsize,
122}
123
124impl ScriptedLlm {
125 pub fn new(name: impl Into<String>, turns: Vec<ScriptedTurn>) -> Self {
130 Self { name: name.into(), turns, current_turn: AtomicUsize::new(0) }
131 }
132
133 pub fn turns_consumed(&self) -> usize {
135 self.current_turn.load(Ordering::Relaxed)
136 }
137
138 pub fn total_turns(&self) -> usize {
140 self.turns.len()
141 }
142
143 fn build_response(turn: &ScriptedTurn, turn_index: usize) -> LlmResponse {
145 use adk_core::FinishReason;
146 use adk_core::types::Part;
147
148 let mut parts = Vec::new();
149
150 if let Some(text) = &turn.text {
152 parts.push(Part::Text { text: text.clone() });
153 }
154
155 for (i, tool_call) in turn.tool_calls.iter().enumerate() {
157 let id =
158 tool_call.id.clone().unwrap_or_else(|| format!("scripted_tc_{turn_index}_{i}"));
159 parts.push(Part::FunctionCall {
160 name: tool_call.name.clone(),
161 args: tool_call.input.clone(),
162 id: Some(id),
163 thought_signature: None,
164 });
165 }
166
167 let content = if parts.is_empty() {
168 None
169 } else {
170 Some(Content { role: "model".to_string(), parts })
171 };
172
173 LlmResponse {
174 content,
175 usage_metadata: None,
176 finish_reason: Some(FinishReason::Stop),
177 citation_metadata: None,
178 partial: false,
179 turn_complete: true,
180 interrupted: false,
181 error_code: None,
182 error_message: None,
183 provider_metadata: None,
184 interaction_id: None,
185 }
186 }
187}
188
189impl std::fmt::Debug for ScriptedLlm {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.debug_struct("ScriptedLlm")
192 .field("name", &self.name)
193 .field("turns", &self.turns.len())
194 .field("current_turn", &self.current_turn.load(Ordering::Relaxed))
195 .finish()
196 }
197}
198
199#[async_trait]
200impl Llm for ScriptedLlm {
201 fn name(&self) -> &str {
202 &self.name
203 }
204
205 async fn generate_content(
206 &self,
207 _request: LlmRequest,
208 _stream: bool,
209 ) -> AdkResult<LlmResponseStream> {
210 let turn_index = self.current_turn.fetch_add(1, Ordering::Relaxed);
211
212 let response = if turn_index < self.turns.len() {
213 Self::build_response(&self.turns[turn_index], turn_index)
214 } else {
215 LlmResponse {
217 content: Some(Content {
218 role: "model".to_string(),
219 parts: vec![adk_core::types::Part::Text {
220 text: "[ScriptedLlm: no more scripted turns]".to_string(),
221 }],
222 }),
223 usage_metadata: None,
224 finish_reason: Some(adk_core::FinishReason::Stop),
225 citation_metadata: None,
226 partial: false,
227 turn_complete: true,
228 interrupted: false,
229 error_code: None,
230 error_message: None,
231 provider_metadata: None,
232 interaction_id: None,
233 }
234 };
235
236 let response_stream = stream! {
237 yield Ok(response);
238 };
239
240 Ok(Box::pin(response_stream))
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde_json::json;

    #[tokio::test]
    async fn test_scripted_llm_returns_text() {
        let turns =
            vec![ScriptedTurn { text: Some("Hello, world!".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("test-model", turns);

        assert_eq!(llm.name(), "test-model");

        let request = LlmRequest::new("test-model", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();

        let response = stream.next().await.unwrap().unwrap();
        assert!(response.turn_complete);
        assert!(!response.partial);

        let content = response.content.unwrap();
        assert_eq!(content.role, "model");
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::Text { text } => {
                assert_eq!(text, "Hello, world!");
            }
            other => panic!("expected Text part, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_scripted_llm_returns_tool_calls() {
        let turns = vec![ScriptedTurn {
            text: None,
            tool_calls: vec![ScriptedToolCall {
                name: "web_search".to_string(),
                input: json!({"query": "rust async"}),
                id: Some("tc_001".to_string()),
            }],
        }];
        let llm = ScriptedLlm::new("tool-model", turns);

        let request = LlmRequest::new("tool-model", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();

        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { name, args, id, .. } => {
                assert_eq!(name, "web_search");
                assert_eq!(args, &json!({"query": "rust async"}));
                assert_eq!(id, &Some("tc_001".to_string()));
            }
            other => panic!("expected FunctionCall part, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_scripted_llm_advances_through_turns() {
        let turns = vec![
            ScriptedTurn { text: Some("First".to_string()), tool_calls: vec![] },
            ScriptedTurn { text: Some("Second".to_string()), tool_calls: vec![] },
            ScriptedTurn { text: Some("Third".to_string()), tool_calls: vec![] },
        ];
        let llm = ScriptedLlm::new("multi-turn", turns);

        for (i, expected) in ["First", "Second", "Third"].iter().enumerate() {
            let request = LlmRequest::new("multi-turn", vec![]);
            let mut stream = llm.generate_content(request, false).await.unwrap();
            let response = stream.next().await.unwrap().unwrap();
            let content = response.content.unwrap();
            match &content.parts[0] {
                adk_core::types::Part::Text { text } => {
                    assert_eq!(text, *expected);
                }
                other => panic!("turn {i}: expected Text, got: {other:?}"),
            }
        }

        assert_eq!(llm.turns_consumed(), 3);
    }

    #[tokio::test]
    async fn test_scripted_llm_handles_exhaustion() {
        let turns = vec![ScriptedTurn { text: Some("Only one".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("exhausted", turns);

        // Consume the only scripted turn.
        let request = LlmRequest::new("exhausted", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let _ = stream.next().await.unwrap().unwrap();

        // The next call falls back to the exhaustion message.
        let request = LlmRequest::new("exhausted", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        assert!(response.turn_complete);
        let content = response.content.unwrap();
        match &content.parts[0] {
            adk_core::types::Part::Text { text } => {
                assert!(text.contains("no more scripted turns"));
            }
            other => panic!("expected fallback Text, got: {other:?}"),
        }

        // Calls past the end of the script are still counted.
        assert_eq!(llm.total_turns(), 1);
        assert_eq!(llm.turns_consumed(), 2);
    }

    #[tokio::test]
    async fn test_scripted_llm_mixed_text_and_tool_calls() {
        let turns = vec![ScriptedTurn {
            text: Some("Let me search for that.".to_string()),
            tool_calls: vec![ScriptedToolCall {
                name: "web_search".to_string(),
                input: json!({"query": "ADK Rust"}),
                id: Some("tc_mixed".to_string()),
            }],
        }];
        let llm = ScriptedLlm::new("mixed", turns);

        let request = LlmRequest::new("mixed", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();

        assert_eq!(content.parts.len(), 2);
        assert!(matches!(&content.parts[0], adk_core::types::Part::Text { .. }));
        assert!(matches!(&content.parts[1], adk_core::types::Part::FunctionCall { .. }));
    }

    #[tokio::test]
    async fn test_empty_turn_has_no_content_and_single_item_stream() {
        let llm = ScriptedLlm::new("empty", vec![ScriptedTurn { text: None, tool_calls: vec![] }]);

        let request = LlmRequest::new("empty", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();

        assert!(response.content.is_none());
        assert!(response.turn_complete);
        assert!(!response.partial);
        // Exactly one response is produced per call.
        assert!(stream.next().await.is_none());
    }

    // Pure (de)serialization: no async runtime needed.
    #[test]
    fn test_scripted_turn_serialization_roundtrip() {
        let turn = ScriptedTurn {
            text: Some("Hello".to_string()),
            tool_calls: vec![ScriptedToolCall {
                name: "search".to_string(),
                input: json!({"q": "test"}),
                id: Some("id_1".to_string()),
            }],
        };

        let json = serde_json::to_string(&turn).unwrap();
        let deserialized: ScriptedTurn = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.text, turn.text);
        assert_eq!(deserialized.tool_calls.len(), 1);
        assert_eq!(deserialized.tool_calls[0].name, "search");
        assert_eq!(deserialized.tool_calls[0].input, json!({"q": "test"}));
        assert_eq!(deserialized.tool_calls[0].id, Some("id_1".to_string()));
    }

    #[test]
    fn test_scripted_turn_serde_defaults_and_skipping() {
        // Empty optional fields are omitted on serialization.
        let empty = ScriptedTurn { text: None, tool_calls: vec![] };
        assert_eq!(serde_json::to_value(&empty).unwrap(), json!({}));

        // Missing optional fields fall back to their defaults on deserialization.
        let minimal: ScriptedTurn = serde_json::from_str("{}").unwrap();
        assert!(minimal.text.is_none());
        assert!(minimal.tool_calls.is_empty());

        let call_only: ScriptedTurn =
            serde_json::from_str(r#"{"tool_calls":[{"name":"noop","input":{}}]}"#).unwrap();
        assert!(call_only.text.is_none());
        assert_eq!(call_only.tool_calls.len(), 1);
        assert_eq!(call_only.tool_calls[0].name, "noop");
        assert!(call_only.tool_calls[0].id.is_none());
    }

    #[tokio::test]
    async fn test_auto_generated_tool_call_ids() {
        let turns = vec![ScriptedTurn {
            text: None,
            tool_calls: vec![
                ScriptedToolCall { name: "tool_a".to_string(), input: json!({}), id: None },
                ScriptedToolCall { name: "tool_b".to_string(), input: json!({}), id: None },
            ],
        }];
        let llm = ScriptedLlm::new("auto-id", turns);

        let request = LlmRequest::new("auto-id", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();

        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_0_0".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
        match &content.parts[1] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_0_1".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_auto_generated_ids_include_turn_index() {
        let turns = vec![
            ScriptedTurn { text: Some("warm-up".to_string()), tool_calls: vec![] },
            ScriptedTurn {
                text: None,
                tool_calls: vec![ScriptedToolCall {
                    name: "tool_c".to_string(),
                    input: json!({}),
                    id: None,
                }],
            },
        ];
        let llm = ScriptedLlm::new("auto-id-turns", turns);

        // Skip the first (text-only) turn.
        let request = LlmRequest::new("auto-id-turns", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let _ = stream.next().await.unwrap().unwrap();

        let request = LlmRequest::new("auto-id-turns", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();

        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_1_0".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
    }
}