use adk_core::{
Llm, LlmRequest, LlmResponse, LlmResponseStream, Result as AdkResult, types::Content,
};
use async_stream::stream;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicUsize, Ordering};
/// One scripted model reply, replayed verbatim by [`ScriptedLlm`].
///
/// When replayed, `text` (if any) becomes the first text part of the response, and each entry
/// of `tool_calls` becomes a function-call part after it, in the order listed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptedTurn {
    /// Plain-text portion of the reply. Missing in input means `None`; omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Tool calls issued in this turn. Missing in input means empty; omitted from serialized
    /// output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ScriptedToolCall>,
}
/// A single tool (function) call that a [`ScriptedTurn`] instructs [`ScriptedLlm`] to emit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptedToolCall {
    /// Name of the tool to call.
    pub name: String,
    /// Arguments passed to the tool.
    ///
    /// May be omitted when deserializing, in which case it defaults to an empty JSON object,
    /// so fixtures for argument-less tools need not spell out `"input": {}`.
    #[serde(default = "empty_scripted_tool_input")]
    pub input: serde_json::Value,
    /// Explicit call id. When `None`, `ScriptedLlm` generates one when it replays the call.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
}

/// Serde default for [`ScriptedToolCall::input`]: an empty JSON object.
fn empty_scripted_tool_input() -> serde_json::Value {
    serde_json::Value::Object(serde_json::Map::new())
}
/// A deterministic [`Llm`] test double that replays a fixed list of [`ScriptedTurn`]s.
///
/// Each call to `generate_content` consumes the next turn. The turn index lives in an atomic
/// counter, so the model can be shared by reference.
pub struct ScriptedLlm {
    /// Model name reported by `Llm::name`.
    name: String,
    /// The script, replayed in order.
    turns: Vec<ScriptedTurn>,
    /// Index of the next turn to replay. It is incremented on every call, including calls made
    /// after the script is exhausted.
    current_turn: AtomicUsize,
}
impl ScriptedLlm {
    /// Creates a model named `name` that replays `turns` in order, starting with the first.
    pub fn new(name: impl Into<String>, turns: Vec<ScriptedTurn>) -> Self {
        let name = name.into();
        let current_turn = AtomicUsize::new(0);
        Self { name, turns, current_turn }
    }

    /// Returns the number of `generate_content` calls made so far.
    ///
    /// The counter advances on every call, including calls made after the script has run out,
    /// so this value can exceed [`total_turns`](Self::total_turns).
    pub fn turns_consumed(&self) -> usize {
        self.current_turn.load(Ordering::Relaxed)
    }

    /// Returns the number of turns in the script.
    pub fn total_turns(&self) -> usize {
        self.turns.len()
    }

    /// Converts one scripted turn into a complete, non-partial `LlmResponse`.
    ///
    /// Part layout:
    /// - The optional text comes first.
    /// - It is followed by one function-call part per tool call, in script order.
    ///
    /// Tool calls without an explicit id are given `scripted_tc_{turn_index}_{i}`, where `i`
    /// is the call's position within the turn. A turn with neither text nor tool calls
    /// produces `content: None`.
    fn build_response(turn: &ScriptedTurn, turn_index: usize) -> LlmResponse {
        use adk_core::FinishReason;
        use adk_core::types::Part;

        let text_part = turn.text.iter().map(|text| Part::Text { text: text.clone() });
        let call_parts = turn.tool_calls.iter().enumerate().map(|(position, call)| {
            let call_id = match &call.id {
                Some(explicit) => explicit.clone(),
                None => format!("scripted_tc_{turn_index}_{position}"),
            };
            Part::FunctionCall {
                name: call.name.clone(),
                args: call.input.clone(),
                id: Some(call_id),
                thought_signature: None,
            }
        });

        let parts: Vec<Part> = text_part.chain(call_parts).collect();
        let has_parts = !parts.is_empty();
        let content = has_parts.then(|| Content { role: "model".to_string(), parts });

        LlmResponse {
            content,
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
            interaction_id: None,
        }
    }
}
/// Compact debug view: reports the script length rather than dumping every turn.
impl std::fmt::Debug for ScriptedLlm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let script_len = self.turns.len();
        let next_index = self.current_turn.load(Ordering::Relaxed);
        let mut builder = f.debug_struct("ScriptedLlm");
        builder.field("name", &self.name);
        builder.field("turns", &script_len);
        builder.field("current_turn", &next_index);
        builder.finish()
    }
}
#[async_trait]
impl Llm for ScriptedLlm {
fn name(&self) -> &str {
&self.name
}
async fn generate_content(
&self,
_request: LlmRequest,
_stream: bool,
) -> AdkResult<LlmResponseStream> {
let turn_index = self.current_turn.fetch_add(1, Ordering::Relaxed);
let response = if turn_index < self.turns.len() {
Self::build_response(&self.turns[turn_index], turn_index)
} else {
LlmResponse {
content: Some(Content {
role: "model".to_string(),
parts: vec![adk_core::types::Part::Text {
text: "[ScriptedLlm: no more scripted turns]".to_string(),
}],
}),
usage_metadata: None,
finish_reason: Some(adk_core::FinishReason::Stop),
citation_metadata: None,
partial: false,
turn_complete: true,
interrupted: false,
error_code: None,
error_message: None,
provider_metadata: None,
interaction_id: None,
}
};
let response_stream = stream! {
yield Ok(response);
};
Ok(Box::pin(response_stream))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde_json::json;

    #[tokio::test]
    async fn test_scripted_llm_returns_text() {
        let turns =
            vec![ScriptedTurn { text: Some("Hello, world!".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("test-model", turns);
        assert_eq!(llm.name(), "test-model");
        let request = LlmRequest::new("test-model", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        assert!(response.turn_complete);
        assert!(!response.partial);
        let content = response.content.unwrap();
        assert_eq!(content.role, "model");
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::Text { text } => {
                assert_eq!(text, "Hello, world!");
            }
            other => panic!("expected Text part, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_scripted_llm_returns_tool_calls() {
        let turns = vec![ScriptedTurn {
            text: None,
            tool_calls: vec![ScriptedToolCall {
                name: "web_search".to_string(),
                input: json!({"query": "rust async"}),
                id: Some("tc_001".to_string()),
            }],
        }];
        let llm = ScriptedLlm::new("tool-model", turns);
        let request = LlmRequest::new("tool-model", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { name, args, id, .. } => {
                assert_eq!(name, "web_search");
                assert_eq!(args, &json!({"query": "rust async"}));
                assert_eq!(id, &Some("tc_001".to_string()));
            }
            other => panic!("expected FunctionCall part, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_scripted_llm_advances_through_turns() {
        let turns = vec![
            ScriptedTurn { text: Some("First".to_string()), tool_calls: vec![] },
            ScriptedTurn { text: Some("Second".to_string()), tool_calls: vec![] },
            ScriptedTurn { text: Some("Third".to_string()), tool_calls: vec![] },
        ];
        let llm = ScriptedLlm::new("multi-turn", turns);
        for (i, expected) in ["First", "Second", "Third"].iter().enumerate() {
            let request = LlmRequest::new("multi-turn", vec![]);
            let mut stream = llm.generate_content(request, false).await.unwrap();
            let response = stream.next().await.unwrap().unwrap();
            let content = response.content.unwrap();
            match &content.parts[0] {
                adk_core::types::Part::Text { text } => {
                    assert_eq!(text, *expected);
                }
                other => panic!("turn {i}: expected Text, got: {other:?}"),
            }
        }
        assert_eq!(llm.turns_consumed(), 3);
    }

    #[tokio::test]
    async fn test_scripted_llm_handles_exhaustion() {
        let turns = vec![ScriptedTurn { text: Some("Only one".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("exhausted", turns);
        let request = LlmRequest::new("exhausted", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let _ = stream.next().await.unwrap().unwrap();
        let request = LlmRequest::new("exhausted", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        assert!(response.turn_complete);
        let content = response.content.unwrap();
        assert_eq!(content.parts.len(), 1);
        match &content.parts[0] {
            adk_core::types::Part::Text { text } => {
                assert!(text.contains("no more scripted turns"));
            }
            other => panic!("expected fallback Text, got: {other:?}"),
        }
    }

    /// The call counter keeps advancing after the script runs out, so it can exceed the
    /// script length.
    #[tokio::test]
    async fn test_counter_keeps_advancing_after_exhaustion() {
        let turns = vec![ScriptedTurn { text: Some("Only".to_string()), tool_calls: vec![] }];
        let llm = ScriptedLlm::new("counter", turns);
        assert_eq!(llm.total_turns(), 1);
        assert_eq!(llm.turns_consumed(), 0);
        for _ in 0..3 {
            let request = LlmRequest::new("counter", vec![]);
            let mut stream = llm.generate_content(request, false).await.unwrap();
            let _ = stream.next().await.unwrap().unwrap();
        }
        assert_eq!(llm.turns_consumed(), 3);
        assert_eq!(llm.total_turns(), 1);
    }

    /// A turn with neither text nor tool calls yields a response without content, and the
    /// stream ends after that single item.
    #[tokio::test]
    async fn test_empty_turn_yields_no_content() {
        let turns = vec![ScriptedTurn { text: None, tool_calls: vec![] }];
        let llm = ScriptedLlm::new("empty", turns);
        let request = LlmRequest::new("empty", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        assert!(response.content.is_none());
        assert!(response.turn_complete);
        assert!(!response.partial);
        assert!(matches!(response.finish_reason, Some(adk_core::FinishReason::Stop)));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_scripted_llm_mixed_text_and_tool_calls() {
        let turns = vec![ScriptedTurn {
            text: Some("Let me search for that.".to_string()),
            tool_calls: vec![ScriptedToolCall {
                name: "web_search".to_string(),
                input: json!({"query": "ADK Rust"}),
                id: Some("tc_mixed".to_string()),
            }],
        }];
        let llm = ScriptedLlm::new("mixed", turns);
        let request = LlmRequest::new("mixed", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();
        assert_eq!(content.parts.len(), 2);
        assert!(matches!(&content.parts[0], adk_core::types::Part::Text { .. }));
        assert!(matches!(&content.parts[1], adk_core::types::Part::FunctionCall { .. }));
    }

    // Purely synchronous: serde round-trips need no async runtime.
    #[test]
    fn test_scripted_turn_serialization_roundtrip() {
        let turn = ScriptedTurn {
            text: Some("Hello".to_string()),
            tool_calls: vec![ScriptedToolCall {
                name: "search".to_string(),
                input: json!({"q": "test"}),
                id: Some("id_1".to_string()),
            }],
        };
        let json = serde_json::to_string(&turn).unwrap();
        let deserialized: ScriptedTurn = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.text, turn.text);
        assert_eq!(deserialized.tool_calls.len(), 1);
        assert_eq!(deserialized.tool_calls[0].name, "search");
        assert_eq!(deserialized.tool_calls[0].input, json!({"q": "test"}));
        assert_eq!(deserialized.tool_calls[0].id, Some("id_1".to_string()));
    }

    /// Optional fields can be omitted on input and are skipped on output when empty.
    #[test]
    fn test_scripted_turn_optional_fields() {
        let turn: ScriptedTurn = serde_json::from_value(json!({})).unwrap();
        assert!(turn.text.is_none());
        assert!(turn.tool_calls.is_empty());

        let empty = ScriptedTurn { text: None, tool_calls: vec![] };
        assert_eq!(serde_json::to_value(&empty).unwrap(), json!({}));

        let call = ScriptedToolCall { name: "t".to_string(), input: json!({}), id: None };
        assert_eq!(serde_json::to_value(&call).unwrap(), json!({"name": "t", "input": {}}));
    }

    #[tokio::test]
    async fn test_auto_generated_tool_call_ids() {
        let turns = vec![ScriptedTurn {
            text: None,
            tool_calls: vec![
                ScriptedToolCall { name: "tool_a".to_string(), input: json!({}), id: None },
                ScriptedToolCall { name: "tool_b".to_string(), input: json!({}), id: None },
            ],
        }];
        let llm = ScriptedLlm::new("auto-id", turns);
        let request = LlmRequest::new("auto-id", vec![]);
        let mut stream = llm.generate_content(request, false).await.unwrap();
        let response = stream.next().await.unwrap().unwrap();
        let content = response.content.unwrap();
        assert_eq!(content.parts.len(), 2);
        match &content.parts[0] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_0_0".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
        match &content.parts[1] {
            adk_core::types::Part::FunctionCall { id, .. } => {
                assert_eq!(id, &Some("scripted_tc_0_1".to_string()));
            }
            other => panic!("expected FunctionCall, got: {other:?}"),
        }
    }
}