Skip to main content

adk_bench/
workload.rs

1//! Workload schema, loading, and validation.
2//!
3//! Defines the JSON workload format for reproducible benchmarks
4//! and provides built-in workload definitions.
5//!
6//! # Schema
7//!
8//! Workloads are JSON files conforming to the [`Workload`] schema. Each workload
9//! describes an agent scenario with instructions, tools, expected turns, and
10//! metadata annotations.
11//!
12//! # Built-in Workloads
13//!
14//! Three standard workloads are provided via [`builtin_workloads()`]:
15//! - **simple_tool_call** — single tool invocation measuring basic dispatch overhead
16//! - **multi_step_reasoning** — multi-turn reasoning chain with sequential tool use
17//! - **parallel_tool_invocation** — concurrent tool calls measuring parallel dispatch
18//!
19//! A fourth workload, **multi_agent_delegation**, is available via
20//! [`multi_agent_delegation_workload()`] and intended for use when the
21//! `experimental` runtime flag is enabled.
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use adk_bench::workload::{load_workload, builtin_workloads};
27//! use std::path::Path;
28//!
29//! // Load from file
30//! let workload = load_workload(Path::new("my_workload.json"))?;
31//!
32//! // Use built-in workloads
33//! let workloads = builtin_workloads();
34//! ```
35
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use std::path::Path;
39
40use crate::error::{BenchError, Result};
41
42/// A benchmark workload definition loaded from JSON.
43///
44/// Workloads define reproducible agent scenarios for benchmarking,
45/// including agent configuration, expected behavior, and metadata.
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
47#[serde(rename_all = "camelCase")]
48pub struct Workload {
49    /// Unique workload name.
50    pub name: String,
51    /// Human-readable description.
52    pub description: String,
53    /// Agent configuration for this workload.
54    pub agent: AgentConfig,
55    /// LLM model identifier to use.
56    pub model: String,
57    /// Structured output schema (JSON Schema for response format).
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub output_schema: Option<serde_json::Value>,
60    /// Expected number of agent turns.
61    pub expected_turns: usize,
62    /// Optional metadata annotations (arbitrary key-value pairs).
63    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
64    pub metadata: HashMap<String, serde_json::Value>,
65    /// Schema version for forward compatibility.
66    #[serde(default = "default_schema_version")]
67    pub schema_version: u32,
68}
69
/// Serde default for `Workload::schema_version`: a workload that omits the
/// field is treated as schema version 1.
fn default_schema_version() -> u32 {
    const INITIAL_SCHEMA_VERSION: u32 = 1;
    INITIAL_SCHEMA_VERSION
}
73
74/// Agent configuration within a workload.
75///
76/// Specifies the agent's instructions, available tools, and the
77/// initial user message to benchmark.
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79#[serde(rename_all = "camelCase")]
80pub struct AgentConfig {
81    /// System instructions for the agent.
82    pub instructions: String,
83    /// Tool definitions available to the agent.
84    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
85    pub tools: HashMap<String, ToolDefinition>,
86    /// User message to send as the initial prompt.
87    pub user_message: String,
88}
89
90/// Tool definition within a workload.
91///
92/// Describes a simulated tool for benchmarking purposes, including
93/// its schema and optional fixed response for deterministic execution.
94#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
95#[serde(rename_all = "camelCase")]
96pub struct ToolDefinition {
97    /// Tool description for the LLM.
98    pub description: String,
99    /// JSON Schema for tool parameters.
100    pub parameters: serde_json::Value,
101    /// Simulated execution time in milliseconds (for benchmarking tool dispatch).
102    #[serde(default)]
103    pub simulated_latency_ms: u64,
104    /// Fixed response value returned by the simulated tool.
105    #[serde(skip_serializing_if = "Option::is_none")]
106    pub fixed_response: Option<serde_json::Value>,
107}
108
109/// Loads and validates a workload from a JSON file path.
110///
111/// Reads the file, parses JSON, and validates all required fields are present
112/// and well-formed. Returns a descriptive [`BenchError::WorkloadValidation`]
113/// error on schema violations.
114///
115/// # Errors
116///
117/// - [`BenchError::WorkloadNotFound`] if the file does not exist
118/// - [`BenchError::WorkloadValidation`] if JSON parsing fails or required fields are invalid
119pub fn load_workload(path: &Path) -> Result<Workload> {
120    let path_str = path.display().to_string();
121
122    if !path.exists() {
123        return Err(BenchError::WorkloadNotFound { path: path_str });
124    }
125
126    let content = std::fs::read_to_string(path).map_err(|e| BenchError::WorkloadValidation {
127        field: "file".to_string(),
128        reason: format!("failed to read workload file '{path_str}': {e}"),
129    })?;
130
131    let workload: Workload =
132        serde_json::from_str(&content).map_err(|e| BenchError::WorkloadValidation {
133            field: parse_error_field(&e),
134            reason: format!("invalid workload JSON: {e}"),
135        })?;
136
137    validate_workload(&workload)?;
138
139    Ok(workload)
140}
141
142/// Returns built-in benchmark workloads.
143///
144/// Provides three standard workloads for common benchmarking scenarios:
145/// - `simple_tool_call` — single tool invocation
146/// - `multi_step_reasoning` — multi-turn reasoning chain
147/// - `parallel_tool_invocation` — concurrent tool calls
148///
149/// The multi-agent delegation workload is intentionally excluded here;
150/// use [`multi_agent_delegation_workload()`] when the `experimental`
151/// runtime flag is enabled.
152pub fn builtin_workloads() -> Vec<Workload> {
153    vec![
154        simple_tool_call_workload(),
155        multi_step_reasoning_workload(),
156        parallel_tool_invocation_workload(),
157    ]
158}
159
160/// Returns the multi-agent delegation workload.
161///
162/// This workload exercises multi-agent orchestration where a coordinator
163/// agent delegates subtasks to specialist agents. It is intended for use
164/// only when the `experimental` runtime configuration flag is enabled,
165/// as the multi-agent API may not be stable.
166pub fn multi_agent_delegation_workload() -> Workload {
167    let mut tools = HashMap::new();
168    tools.insert(
169        "delegate_to_researcher".to_string(),
170        ToolDefinition {
171            description: "Delegate a research subtask to the researcher agent".to_string(),
172            parameters: serde_json::json!({
173                "type": "object",
174                "properties": {
175                    "query": {
176                        "type": "string",
177                        "description": "The research query to investigate"
178                    },
179                    "depth": {
180                        "type": "string",
181                        "enum": ["shallow", "deep"],
182                        "description": "How thorough the research should be"
183                    }
184                },
185                "required": ["query"]
186            }),
187            simulated_latency_ms: 50,
188            fixed_response: Some(serde_json::json!({
189                "findings": "Research results on the topic",
190                "confidence": 0.85,
191                "sources": ["source_1", "source_2"]
192            })),
193        },
194    );
195    tools.insert(
196        "delegate_to_writer".to_string(),
197        ToolDefinition {
198            description: "Delegate a writing subtask to the writer agent".to_string(),
199            parameters: serde_json::json!({
200                "type": "object",
201                "properties": {
202                    "topic": {
203                        "type": "string",
204                        "description": "The topic to write about"
205                    },
206                    "style": {
207                        "type": "string",
208                        "enum": ["formal", "casual", "technical"],
209                        "description": "Writing style"
210                    },
211                    "max_words": {
212                        "type": "integer",
213                        "description": "Maximum word count"
214                    }
215                },
216                "required": ["topic", "style"]
217            }),
218            simulated_latency_ms: 75,
219            fixed_response: Some(serde_json::json!({
220                "content": "Generated content based on research findings",
221                "word_count": 250
222            })),
223        },
224    );
225
226    let mut metadata = HashMap::new();
227    metadata.insert("category".to_string(), serde_json::Value::String("multi-agent".to_string()));
228    metadata.insert("stability".to_string(), serde_json::Value::String("experimental".to_string()));
229
230    Workload {
231        name: "multi_agent_delegation".to_string(),
232        description: "Coordinator agent delegates research and writing subtasks to specialist agents, measuring multi-agent orchestration overhead".to_string(),
233        agent: AgentConfig {
234            instructions: "You are a project coordinator. Break down the user's request into research and writing subtasks. First delegate research to gather information, then delegate writing to produce the final output.".to_string(),
235            tools,
236            user_message: "Write a technical summary about the performance benefits of async runtimes in systems programming.".to_string(),
237        },
238        model: "gemini-2.5-flash".to_string(),
239        output_schema: Some(serde_json::json!({
240            "type": "object",
241            "properties": {
242                "summary": { "type": "string" },
243                "research_quality": { "type": "number" },
244                "delegations_made": { "type": "integer" }
245            },
246            "required": ["summary", "delegations_made"]
247        })),
248        expected_turns: 5,
249        metadata,
250        schema_version: 1,
251    }
252}
253
254fn simple_tool_call_workload() -> Workload {
255    let mut tools = HashMap::new();
256    tools.insert(
257        "get_weather".to_string(),
258        ToolDefinition {
259            description: "Get the current weather for a given city".to_string(),
260            parameters: serde_json::json!({
261                "type": "object",
262                "properties": {
263                    "city": {
264                        "type": "string",
265                        "description": "The city name to get weather for"
266                    },
267                    "units": {
268                        "type": "string",
269                        "enum": ["celsius", "fahrenheit"],
270                        "description": "Temperature units"
271                    }
272                },
273                "required": ["city"]
274            }),
275            simulated_latency_ms: 10,
276            fixed_response: Some(serde_json::json!({
277                "temperature": 22.5,
278                "condition": "sunny",
279                "humidity": 45
280            })),
281        },
282    );
283
284    Workload {
285        name: "simple_tool_call".to_string(),
286        description: "Single tool invocation measuring basic dispatch overhead. The agent receives a weather query and must call one tool to respond."
287            .to_string(),
288        agent: AgentConfig {
289            instructions: "You are a helpful weather assistant. When asked about weather, use the get_weather tool to retrieve current conditions.".to_string(),
290            tools,
291            user_message: "What is the weather in San Francisco?".to_string(),
292        },
293        model: "gemini-2.5-flash".to_string(),
294        output_schema: Some(serde_json::json!({
295            "type": "object",
296            "properties": {
297                "temperature": { "type": "number" },
298                "condition": { "type": "string" },
299                "city": { "type": "string" }
300            },
301            "required": ["temperature", "condition", "city"]
302        })),
303        expected_turns: 2,
304        metadata: HashMap::new(),
305        schema_version: 1,
306    }
307}
308
309fn multi_step_reasoning_workload() -> Workload {
310    let mut tools = HashMap::new();
311    tools.insert(
312        "search_database".to_string(),
313        ToolDefinition {
314            description: "Search a product database by query".to_string(),
315            parameters: serde_json::json!({
316                "type": "object",
317                "properties": {
318                    "query": {
319                        "type": "string",
320                        "description": "Search query"
321                    },
322                    "category": {
323                        "type": "string",
324                        "description": "Product category filter"
325                    },
326                    "max_results": {
327                        "type": "integer",
328                        "description": "Maximum number of results to return"
329                    }
330                },
331                "required": ["query"]
332            }),
333            simulated_latency_ms: 15,
334            fixed_response: Some(serde_json::json!({
335                "results": [
336                    {"id": "p1", "name": "Widget A", "price": 29.99, "rating": 4.5},
337                    {"id": "p2", "name": "Widget B", "price": 19.99, "rating": 4.2},
338                    {"id": "p3", "name": "Widget C", "price": 39.99, "rating": 4.8}
339                ],
340                "total_count": 3
341            })),
342        },
343    );
344    tools.insert(
345        "get_product_details".to_string(),
346        ToolDefinition {
347            description: "Get detailed information about a specific product".to_string(),
348            parameters: serde_json::json!({
349                "type": "object",
350                "properties": {
351                    "product_id": {
352                        "type": "string",
353                        "description": "The product identifier"
354                    }
355                },
356                "required": ["product_id"]
357            }),
358            simulated_latency_ms: 10,
359            fixed_response: Some(serde_json::json!({
360                "id": "p3",
361                "name": "Widget C",
362                "price": 39.99,
363                "rating": 4.8,
364                "reviews": 128,
365                "in_stock": true,
366                "description": "Premium widget with advanced features"
367            })),
368        },
369    );
370    tools.insert(
371        "calculate_shipping".to_string(),
372        ToolDefinition {
373            description: "Calculate shipping cost for a product to a destination".to_string(),
374            parameters: serde_json::json!({
375                "type": "object",
376                "properties": {
377                    "product_id": {
378                        "type": "string",
379                        "description": "The product identifier"
380                    },
381                    "destination": {
382                        "type": "string",
383                        "description": "Shipping destination (zip code or city)"
384                    }
385                },
386                "required": ["product_id", "destination"]
387            }),
388            simulated_latency_ms: 10,
389            fixed_response: Some(serde_json::json!({
390                "cost": 5.99,
391                "estimated_days": 3,
392                "carrier": "standard"
393            })),
394        },
395    );
396
397    Workload {
398        name: "multi_step_reasoning".to_string(),
399        description: "Multi-turn reasoning chain with sequential tool use. The agent must search products, get details on the best match, and calculate shipping — each step depends on previous results."
400            .to_string(),
401        agent: AgentConfig {
402            instructions: "You are a shopping assistant. Help the user find the best product by searching the database, getting details on the top-rated result, and calculating shipping to their location.".to_string(),
403            tools,
404            user_message: "Find me the best-rated widget and tell me the total cost including shipping to 94105.".to_string(),
405        },
406        model: "gemini-2.5-flash".to_string(),
407        output_schema: Some(serde_json::json!({
408            "type": "object",
409            "properties": {
410                "product_name": { "type": "string" },
411                "product_price": { "type": "number" },
412                "shipping_cost": { "type": "number" },
413                "total_cost": { "type": "number" },
414                "estimated_delivery_days": { "type": "integer" }
415            },
416            "required": ["product_name", "total_cost"]
417        })),
418        expected_turns: 4,
419        metadata: HashMap::new(),
420        schema_version: 1,
421    }
422}
423
424fn parallel_tool_invocation_workload() -> Workload {
425    let mut tools = HashMap::new();
426    tools.insert(
427        "fetch_stock_price".to_string(),
428        ToolDefinition {
429            description: "Fetch the current stock price for a ticker symbol".to_string(),
430            parameters: serde_json::json!({
431                "type": "object",
432                "properties": {
433                    "ticker": {
434                        "type": "string",
435                        "description": "Stock ticker symbol (e.g., AAPL, GOOGL)"
436                    }
437                },
438                "required": ["ticker"]
439            }),
440            simulated_latency_ms: 20,
441            fixed_response: Some(serde_json::json!({
442                "ticker": "AAPL",
443                "price": 178.50,
444                "change": 2.30,
445                "change_percent": 1.31
446            })),
447        },
448    );
449    tools.insert(
450        "fetch_company_news".to_string(),
451        ToolDefinition {
452            description: "Fetch recent news headlines for a company".to_string(),
453            parameters: serde_json::json!({
454                "type": "object",
455                "properties": {
456                    "ticker": {
457                        "type": "string",
458                        "description": "Stock ticker symbol"
459                    },
460                    "limit": {
461                        "type": "integer",
462                        "description": "Maximum number of headlines"
463                    }
464                },
465                "required": ["ticker"]
466            }),
467            simulated_latency_ms: 25,
468            fixed_response: Some(serde_json::json!({
469                "headlines": [
470                    "Company reports strong Q4 earnings",
471                    "New product launch announced for next quarter"
472                ]
473            })),
474        },
475    );
476    tools.insert(
477        "fetch_analyst_rating".to_string(),
478        ToolDefinition {
479            description: "Fetch analyst consensus rating for a stock".to_string(),
480            parameters: serde_json::json!({
481                "type": "object",
482                "properties": {
483                    "ticker": {
484                        "type": "string",
485                        "description": "Stock ticker symbol"
486                    }
487                },
488                "required": ["ticker"]
489            }),
490            simulated_latency_ms: 15,
491            fixed_response: Some(serde_json::json!({
492                "rating": "buy",
493                "target_price": 195.00,
494                "analyst_count": 32
495            })),
496        },
497    );
498
499    Workload {
500        name: "parallel_tool_invocation".to_string(),
501        description: "Concurrent tool calls measuring parallel dispatch efficiency. The agent must fetch stock price, news, and analyst rating simultaneously for a portfolio analysis."
502            .to_string(),
503        agent: AgentConfig {
504            instructions: "You are a financial analyst assistant. When asked about a stock, fetch the current price, recent news, and analyst rating in parallel to provide a comprehensive summary.".to_string(),
505            tools,
506            user_message: "Give me a complete analysis of AAPL including current price, recent news, and analyst consensus.".to_string(),
507        },
508        model: "gemini-2.5-flash".to_string(),
509        output_schema: Some(serde_json::json!({
510            "type": "object",
511            "properties": {
512                "ticker": { "type": "string" },
513                "current_price": { "type": "number" },
514                "analyst_rating": { "type": "string" },
515                "target_price": { "type": "number" },
516                "summary": { "type": "string" }
517            },
518            "required": ["ticker", "current_price", "analyst_rating"]
519        })),
520        expected_turns: 2,
521        metadata: HashMap::new(),
522        schema_version: 1,
523    }
524}
525
526/// Validates a workload's required fields and constraints.
527fn validate_workload(workload: &Workload) -> Result<()> {
528    if workload.name.is_empty() {
529        return Err(BenchError::WorkloadValidation {
530            field: "name".to_string(),
531            reason: "workload name must not be empty".to_string(),
532        });
533    }
534
535    if workload.description.is_empty() {
536        return Err(BenchError::WorkloadValidation {
537            field: "description".to_string(),
538            reason: "workload description must not be empty".to_string(),
539        });
540    }
541
542    if workload.model.is_empty() {
543        return Err(BenchError::WorkloadValidation {
544            field: "model".to_string(),
545            reason: "model identifier must not be empty".to_string(),
546        });
547    }
548
549    if workload.agent.instructions.is_empty() {
550        return Err(BenchError::WorkloadValidation {
551            field: "agent.instructions".to_string(),
552            reason: "agent instructions must not be empty".to_string(),
553        });
554    }
555
556    if workload.agent.user_message.is_empty() {
557        return Err(BenchError::WorkloadValidation {
558            field: "agent.userMessage".to_string(),
559            reason: "agent user message must not be empty".to_string(),
560        });
561    }
562
563    if workload.expected_turns == 0 {
564        return Err(BenchError::WorkloadValidation {
565            field: "expectedTurns".to_string(),
566            reason: "expected turns must be at least 1".to_string(),
567        });
568    }
569
570    if workload.schema_version == 0 {
571        return Err(BenchError::WorkloadValidation {
572            field: "schemaVersion".to_string(),
573            reason: "schema version must be at least 1".to_string(),
574        });
575    }
576
577    // Validate tool definitions
578    for (tool_name, tool_def) in &workload.agent.tools {
579        if tool_def.description.is_empty() {
580            return Err(BenchError::WorkloadValidation {
581                field: format!("agent.tools.{tool_name}.description"),
582                reason: "tool description must not be empty".to_string(),
583            });
584        }
585    }
586
587    Ok(())
588}
589
590/// Extracts the field name from a serde_json parse error when possible.
591fn parse_error_field(error: &serde_json::Error) -> String {
592    // serde_json errors include line/column but not always the field name.
593    // We provide the best context available.
594    let msg = error.to_string();
595    if msg.contains("missing field") {
596        // Extract field name from "missing field `fieldName`"
597        if let Some(start) = msg.find('`')
598            && let Some(end) = msg[start + 1..].find('`')
599        {
600            return msg[start + 1..start + 1 + end].to_string();
601        }
602    }
603    "root".to_string()
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_builtin_workloads_count() {
        let workloads = builtin_workloads();
        assert_eq!(workloads.len(), 3);
    }

    #[test]
    fn test_builtin_workload_names() {
        let workloads = builtin_workloads();
        let names: Vec<&str> = workloads.iter().map(|w| w.name.as_str()).collect();
        assert!(names.contains(&"simple_tool_call"));
        assert!(names.contains(&"multi_step_reasoning"));
        assert!(names.contains(&"parallel_tool_invocation"));
    }

    #[test]
    fn test_multi_agent_delegation_not_in_builtin() {
        let workloads = builtin_workloads();
        let names: Vec<&str> = workloads.iter().map(|w| w.name.as_str()).collect();
        assert!(!names.contains(&"multi_agent_delegation"));
    }

    #[test]
    fn test_multi_agent_delegation_workload() {
        let workload = multi_agent_delegation_workload();
        assert_eq!(workload.name, "multi_agent_delegation");
        assert_eq!(workload.expected_turns, 5);
        assert!(workload.agent.tools.contains_key("delegate_to_researcher"));
        assert!(workload.agent.tools.contains_key("delegate_to_writer"));
        assert!(workload.metadata.contains_key("stability"));
    }

    /// Every workload shipped with the crate must satisfy its own validator.
    #[test]
    fn test_all_shipped_workloads_pass_validation() {
        let mut workloads = builtin_workloads();
        workloads.push(multi_agent_delegation_workload());
        for workload in &workloads {
            assert!(
                validate_workload(workload).is_ok(),
                "workload '{}' failed validation",
                workload.name
            );
        }
    }

    #[test]
    fn test_workload_serialization_round_trip() {
        let mut workloads = builtin_workloads();
        workloads.push(multi_agent_delegation_workload());
        for workload in &workloads {
            let json = serde_json::to_string(workload).unwrap();
            let deserialized: Workload = serde_json::from_str(&json).unwrap();
            assert_eq!(workload, &deserialized);
        }
    }

    #[test]
    fn test_load_workload_not_found() {
        let result = load_workload(Path::new("/nonexistent/path.json"));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, BenchError::WorkloadNotFound { .. }));
    }

    #[test]
    fn test_load_workload_invalid_json() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "not valid json").unwrap();
        let result = load_workload(file.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, BenchError::WorkloadValidation { .. }));
    }

    #[test]
    fn test_load_workload_missing_field() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, r#"{{"name": "test"}}"#).unwrap();
        let result = load_workload(file.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, BenchError::WorkloadValidation { .. }));
    }

    /// The first missing required field (in declaration order) is reported by name.
    #[test]
    fn test_load_workload_missing_field_reports_field_name() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, r#"{{"name": "test"}}"#).unwrap();
        match load_workload(file.path()) {
            Err(BenchError::WorkloadValidation { field, .. }) => {
                assert_eq!(field, "description");
            }
            other => panic!("expected WorkloadValidation, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_error_field_extracts_name_or_root() {
        let missing = serde_json::from_str::<Workload>("{}").unwrap_err();
        assert_eq!(parse_error_field(&missing), "name");

        let syntax = serde_json::from_str::<Workload>("not json").unwrap_err();
        assert_eq!(parse_error_field(&syntax), "root");
    }

    #[test]
    fn test_load_workload_valid() {
        let workload = simple_tool_call_workload();
        let json = serde_json::to_string_pretty(&workload).unwrap();

        let mut file = NamedTempFile::new().unwrap();
        write!(file, "{json}").unwrap();

        let loaded = load_workload(file.path()).unwrap();
        assert_eq!(workload, loaded);
    }

    #[test]
    fn test_validate_empty_name() {
        let mut workload = simple_tool_call_workload();
        workload.name = String::new();
        let result = validate_workload(&workload);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_zero_expected_turns() {
        let mut workload = simple_tool_call_workload();
        workload.expected_turns = 0;
        let result = validate_workload(&workload);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_empty_tool_description_reports_tool_path() {
        let mut workload = simple_tool_call_workload();
        workload
            .agent
            .tools
            .get_mut("get_weather")
            .expect("simple_tool_call defines get_weather")
            .description = String::new();
        match validate_workload(&workload) {
            Err(BenchError::WorkloadValidation { field, .. }) => {
                assert_eq!(field, "agent.tools.get_weather.description");
            }
            other => panic!("expected WorkloadValidation, got {other:?}"),
        }
    }

    #[test]
    fn test_schema_version_defaults_to_1() {
        let json = r#"{
            "name": "test",
            "description": "test workload",
            "agent": {
                "instructions": "do something",
                "userMessage": "hello"
            },
            "model": "gemini-2.5-flash",
            "expectedTurns": 2
        }"#;
        let workload: Workload = serde_json::from_str(json).unwrap();
        assert_eq!(workload.schema_version, 1);
    }

    #[test]
    fn test_metadata_preserved_in_round_trip() {
        let mut workload = simple_tool_call_workload();
        workload
            .metadata
            .insert("author".to_string(), serde_json::Value::String("test-user".to_string()));
        workload.metadata.insert("version".to_string(), serde_json::json!(2));

        let json = serde_json::to_string(&workload).unwrap();
        let deserialized: Workload = serde_json::from_str(&json).unwrap();
        assert_eq!(workload.metadata, deserialized.metadata);
    }
}