1use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use std::path::Path;
39
40use crate::error::{BenchError, Result};
41
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
47#[serde(rename_all = "camelCase")]
48pub struct Workload {
49 pub name: String,
51 pub description: String,
53 pub agent: AgentConfig,
55 pub model: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub output_schema: Option<serde_json::Value>,
60 pub expected_turns: usize,
62 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
64 pub metadata: HashMap<String, serde_json::Value>,
65 #[serde(default = "default_schema_version")]
67 pub schema_version: u32,
68}
69
/// Serde default for [`Workload::schema_version`] when the key is missing.
fn default_schema_version() -> u32 {
    const INITIAL_SCHEMA_VERSION: u32 = 1;
    INITIAL_SCHEMA_VERSION
}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79#[serde(rename_all = "camelCase")]
80pub struct AgentConfig {
81 pub instructions: String,
83 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
85 pub tools: HashMap<String, ToolDefinition>,
86 pub user_message: String,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
95#[serde(rename_all = "camelCase")]
96pub struct ToolDefinition {
97 pub description: String,
99 pub parameters: serde_json::Value,
101 #[serde(default)]
103 pub simulated_latency_ms: u64,
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub fixed_response: Option<serde_json::Value>,
107}
108
109pub fn load_workload(path: &Path) -> Result<Workload> {
120 let path_str = path.display().to_string();
121
122 if !path.exists() {
123 return Err(BenchError::WorkloadNotFound { path: path_str });
124 }
125
126 let content = std::fs::read_to_string(path).map_err(|e| BenchError::WorkloadValidation {
127 field: "file".to_string(),
128 reason: format!("failed to read workload file '{path_str}': {e}"),
129 })?;
130
131 let workload: Workload =
132 serde_json::from_str(&content).map_err(|e| BenchError::WorkloadValidation {
133 field: parse_error_field(&e),
134 reason: format!("invalid workload JSON: {e}"),
135 })?;
136
137 validate_workload(&workload)?;
138
139 Ok(workload)
140}
141
142pub fn builtin_workloads() -> Vec<Workload> {
153 vec![
154 simple_tool_call_workload(),
155 multi_step_reasoning_workload(),
156 parallel_tool_invocation_workload(),
157 ]
158}
159
160pub fn multi_agent_delegation_workload() -> Workload {
167 let mut tools = HashMap::new();
168 tools.insert(
169 "delegate_to_researcher".to_string(),
170 ToolDefinition {
171 description: "Delegate a research subtask to the researcher agent".to_string(),
172 parameters: serde_json::json!({
173 "type": "object",
174 "properties": {
175 "query": {
176 "type": "string",
177 "description": "The research query to investigate"
178 },
179 "depth": {
180 "type": "string",
181 "enum": ["shallow", "deep"],
182 "description": "How thorough the research should be"
183 }
184 },
185 "required": ["query"]
186 }),
187 simulated_latency_ms: 50,
188 fixed_response: Some(serde_json::json!({
189 "findings": "Research results on the topic",
190 "confidence": 0.85,
191 "sources": ["source_1", "source_2"]
192 })),
193 },
194 );
195 tools.insert(
196 "delegate_to_writer".to_string(),
197 ToolDefinition {
198 description: "Delegate a writing subtask to the writer agent".to_string(),
199 parameters: serde_json::json!({
200 "type": "object",
201 "properties": {
202 "topic": {
203 "type": "string",
204 "description": "The topic to write about"
205 },
206 "style": {
207 "type": "string",
208 "enum": ["formal", "casual", "technical"],
209 "description": "Writing style"
210 },
211 "max_words": {
212 "type": "integer",
213 "description": "Maximum word count"
214 }
215 },
216 "required": ["topic", "style"]
217 }),
218 simulated_latency_ms: 75,
219 fixed_response: Some(serde_json::json!({
220 "content": "Generated content based on research findings",
221 "word_count": 250
222 })),
223 },
224 );
225
226 let mut metadata = HashMap::new();
227 metadata.insert("category".to_string(), serde_json::Value::String("multi-agent".to_string()));
228 metadata.insert("stability".to_string(), serde_json::Value::String("experimental".to_string()));
229
230 Workload {
231 name: "multi_agent_delegation".to_string(),
232 description: "Coordinator agent delegates research and writing subtasks to specialist agents, measuring multi-agent orchestration overhead".to_string(),
233 agent: AgentConfig {
234 instructions: "You are a project coordinator. Break down the user's request into research and writing subtasks. First delegate research to gather information, then delegate writing to produce the final output.".to_string(),
235 tools,
236 user_message: "Write a technical summary about the performance benefits of async runtimes in systems programming.".to_string(),
237 },
238 model: "gemini-2.5-flash".to_string(),
239 output_schema: Some(serde_json::json!({
240 "type": "object",
241 "properties": {
242 "summary": { "type": "string" },
243 "research_quality": { "type": "number" },
244 "delegations_made": { "type": "integer" }
245 },
246 "required": ["summary", "delegations_made"]
247 })),
248 expected_turns: 5,
249 metadata,
250 schema_version: 1,
251 }
252}
253
254fn simple_tool_call_workload() -> Workload {
255 let mut tools = HashMap::new();
256 tools.insert(
257 "get_weather".to_string(),
258 ToolDefinition {
259 description: "Get the current weather for a given city".to_string(),
260 parameters: serde_json::json!({
261 "type": "object",
262 "properties": {
263 "city": {
264 "type": "string",
265 "description": "The city name to get weather for"
266 },
267 "units": {
268 "type": "string",
269 "enum": ["celsius", "fahrenheit"],
270 "description": "Temperature units"
271 }
272 },
273 "required": ["city"]
274 }),
275 simulated_latency_ms: 10,
276 fixed_response: Some(serde_json::json!({
277 "temperature": 22.5,
278 "condition": "sunny",
279 "humidity": 45
280 })),
281 },
282 );
283
284 Workload {
285 name: "simple_tool_call".to_string(),
286 description: "Single tool invocation measuring basic dispatch overhead. The agent receives a weather query and must call one tool to respond."
287 .to_string(),
288 agent: AgentConfig {
289 instructions: "You are a helpful weather assistant. When asked about weather, use the get_weather tool to retrieve current conditions.".to_string(),
290 tools,
291 user_message: "What is the weather in San Francisco?".to_string(),
292 },
293 model: "gemini-2.5-flash".to_string(),
294 output_schema: Some(serde_json::json!({
295 "type": "object",
296 "properties": {
297 "temperature": { "type": "number" },
298 "condition": { "type": "string" },
299 "city": { "type": "string" }
300 },
301 "required": ["temperature", "condition", "city"]
302 })),
303 expected_turns: 2,
304 metadata: HashMap::new(),
305 schema_version: 1,
306 }
307}
308
309fn multi_step_reasoning_workload() -> Workload {
310 let mut tools = HashMap::new();
311 tools.insert(
312 "search_database".to_string(),
313 ToolDefinition {
314 description: "Search a product database by query".to_string(),
315 parameters: serde_json::json!({
316 "type": "object",
317 "properties": {
318 "query": {
319 "type": "string",
320 "description": "Search query"
321 },
322 "category": {
323 "type": "string",
324 "description": "Product category filter"
325 },
326 "max_results": {
327 "type": "integer",
328 "description": "Maximum number of results to return"
329 }
330 },
331 "required": ["query"]
332 }),
333 simulated_latency_ms: 15,
334 fixed_response: Some(serde_json::json!({
335 "results": [
336 {"id": "p1", "name": "Widget A", "price": 29.99, "rating": 4.5},
337 {"id": "p2", "name": "Widget B", "price": 19.99, "rating": 4.2},
338 {"id": "p3", "name": "Widget C", "price": 39.99, "rating": 4.8}
339 ],
340 "total_count": 3
341 })),
342 },
343 );
344 tools.insert(
345 "get_product_details".to_string(),
346 ToolDefinition {
347 description: "Get detailed information about a specific product".to_string(),
348 parameters: serde_json::json!({
349 "type": "object",
350 "properties": {
351 "product_id": {
352 "type": "string",
353 "description": "The product identifier"
354 }
355 },
356 "required": ["product_id"]
357 }),
358 simulated_latency_ms: 10,
359 fixed_response: Some(serde_json::json!({
360 "id": "p3",
361 "name": "Widget C",
362 "price": 39.99,
363 "rating": 4.8,
364 "reviews": 128,
365 "in_stock": true,
366 "description": "Premium widget with advanced features"
367 })),
368 },
369 );
370 tools.insert(
371 "calculate_shipping".to_string(),
372 ToolDefinition {
373 description: "Calculate shipping cost for a product to a destination".to_string(),
374 parameters: serde_json::json!({
375 "type": "object",
376 "properties": {
377 "product_id": {
378 "type": "string",
379 "description": "The product identifier"
380 },
381 "destination": {
382 "type": "string",
383 "description": "Shipping destination (zip code or city)"
384 }
385 },
386 "required": ["product_id", "destination"]
387 }),
388 simulated_latency_ms: 10,
389 fixed_response: Some(serde_json::json!({
390 "cost": 5.99,
391 "estimated_days": 3,
392 "carrier": "standard"
393 })),
394 },
395 );
396
397 Workload {
398 name: "multi_step_reasoning".to_string(),
399 description: "Multi-turn reasoning chain with sequential tool use. The agent must search products, get details on the best match, and calculate shipping — each step depends on previous results."
400 .to_string(),
401 agent: AgentConfig {
402 instructions: "You are a shopping assistant. Help the user find the best product by searching the database, getting details on the top-rated result, and calculating shipping to their location.".to_string(),
403 tools,
404 user_message: "Find me the best-rated widget and tell me the total cost including shipping to 94105.".to_string(),
405 },
406 model: "gemini-2.5-flash".to_string(),
407 output_schema: Some(serde_json::json!({
408 "type": "object",
409 "properties": {
410 "product_name": { "type": "string" },
411 "product_price": { "type": "number" },
412 "shipping_cost": { "type": "number" },
413 "total_cost": { "type": "number" },
414 "estimated_delivery_days": { "type": "integer" }
415 },
416 "required": ["product_name", "total_cost"]
417 })),
418 expected_turns: 4,
419 metadata: HashMap::new(),
420 schema_version: 1,
421 }
422}
423
424fn parallel_tool_invocation_workload() -> Workload {
425 let mut tools = HashMap::new();
426 tools.insert(
427 "fetch_stock_price".to_string(),
428 ToolDefinition {
429 description: "Fetch the current stock price for a ticker symbol".to_string(),
430 parameters: serde_json::json!({
431 "type": "object",
432 "properties": {
433 "ticker": {
434 "type": "string",
435 "description": "Stock ticker symbol (e.g., AAPL, GOOGL)"
436 }
437 },
438 "required": ["ticker"]
439 }),
440 simulated_latency_ms: 20,
441 fixed_response: Some(serde_json::json!({
442 "ticker": "AAPL",
443 "price": 178.50,
444 "change": 2.30,
445 "change_percent": 1.31
446 })),
447 },
448 );
449 tools.insert(
450 "fetch_company_news".to_string(),
451 ToolDefinition {
452 description: "Fetch recent news headlines for a company".to_string(),
453 parameters: serde_json::json!({
454 "type": "object",
455 "properties": {
456 "ticker": {
457 "type": "string",
458 "description": "Stock ticker symbol"
459 },
460 "limit": {
461 "type": "integer",
462 "description": "Maximum number of headlines"
463 }
464 },
465 "required": ["ticker"]
466 }),
467 simulated_latency_ms: 25,
468 fixed_response: Some(serde_json::json!({
469 "headlines": [
470 "Company reports strong Q4 earnings",
471 "New product launch announced for next quarter"
472 ]
473 })),
474 },
475 );
476 tools.insert(
477 "fetch_analyst_rating".to_string(),
478 ToolDefinition {
479 description: "Fetch analyst consensus rating for a stock".to_string(),
480 parameters: serde_json::json!({
481 "type": "object",
482 "properties": {
483 "ticker": {
484 "type": "string",
485 "description": "Stock ticker symbol"
486 }
487 },
488 "required": ["ticker"]
489 }),
490 simulated_latency_ms: 15,
491 fixed_response: Some(serde_json::json!({
492 "rating": "buy",
493 "target_price": 195.00,
494 "analyst_count": 32
495 })),
496 },
497 );
498
499 Workload {
500 name: "parallel_tool_invocation".to_string(),
501 description: "Concurrent tool calls measuring parallel dispatch efficiency. The agent must fetch stock price, news, and analyst rating simultaneously for a portfolio analysis."
502 .to_string(),
503 agent: AgentConfig {
504 instructions: "You are a financial analyst assistant. When asked about a stock, fetch the current price, recent news, and analyst rating in parallel to provide a comprehensive summary.".to_string(),
505 tools,
506 user_message: "Give me a complete analysis of AAPL including current price, recent news, and analyst consensus.".to_string(),
507 },
508 model: "gemini-2.5-flash".to_string(),
509 output_schema: Some(serde_json::json!({
510 "type": "object",
511 "properties": {
512 "ticker": { "type": "string" },
513 "current_price": { "type": "number" },
514 "analyst_rating": { "type": "string" },
515 "target_price": { "type": "number" },
516 "summary": { "type": "string" }
517 },
518 "required": ["ticker", "current_price", "analyst_rating"]
519 })),
520 expected_turns: 2,
521 metadata: HashMap::new(),
522 schema_version: 1,
523 }
524}
525
526fn validate_workload(workload: &Workload) -> Result<()> {
528 if workload.name.is_empty() {
529 return Err(BenchError::WorkloadValidation {
530 field: "name".to_string(),
531 reason: "workload name must not be empty".to_string(),
532 });
533 }
534
535 if workload.description.is_empty() {
536 return Err(BenchError::WorkloadValidation {
537 field: "description".to_string(),
538 reason: "workload description must not be empty".to_string(),
539 });
540 }
541
542 if workload.model.is_empty() {
543 return Err(BenchError::WorkloadValidation {
544 field: "model".to_string(),
545 reason: "model identifier must not be empty".to_string(),
546 });
547 }
548
549 if workload.agent.instructions.is_empty() {
550 return Err(BenchError::WorkloadValidation {
551 field: "agent.instructions".to_string(),
552 reason: "agent instructions must not be empty".to_string(),
553 });
554 }
555
556 if workload.agent.user_message.is_empty() {
557 return Err(BenchError::WorkloadValidation {
558 field: "agent.userMessage".to_string(),
559 reason: "agent user message must not be empty".to_string(),
560 });
561 }
562
563 if workload.expected_turns == 0 {
564 return Err(BenchError::WorkloadValidation {
565 field: "expectedTurns".to_string(),
566 reason: "expected turns must be at least 1".to_string(),
567 });
568 }
569
570 if workload.schema_version == 0 {
571 return Err(BenchError::WorkloadValidation {
572 field: "schemaVersion".to_string(),
573 reason: "schema version must be at least 1".to_string(),
574 });
575 }
576
577 for (tool_name, tool_def) in &workload.agent.tools {
579 if tool_def.description.is_empty() {
580 return Err(BenchError::WorkloadValidation {
581 field: format!("agent.tools.{tool_name}.description"),
582 reason: "tool description must not be empty".to_string(),
583 });
584 }
585 }
586
587 Ok(())
588}
589
590fn parse_error_field(error: &serde_json::Error) -> String {
592 let msg = error.to_string();
595 if msg.contains("missing field") {
596 if let Some(start) = msg.find('`')
598 && let Some(end) = msg[start + 1..].find('`')
599 {
600 return msg[start + 1..start + 1 + end].to_string();
601 }
602 }
603 "root".to_string()
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Returns the `field` of a `WorkloadValidation` error, or `None` for other variants.
    fn validation_field(err: &BenchError) -> Option<&str> {
        match err {
            BenchError::WorkloadValidation { field, .. } => Some(field.as_str()),
            _ => None,
        }
    }

    #[test]
    fn test_builtin_workloads_count() {
        let workloads = builtin_workloads();
        assert_eq!(workloads.len(), 3);
    }

    #[test]
    fn test_builtin_workload_names() {
        let workloads = builtin_workloads();
        let names: Vec<&str> = workloads.iter().map(|w| w.name.as_str()).collect();
        assert!(names.contains(&"simple_tool_call"));
        assert!(names.contains(&"multi_step_reasoning"));
        assert!(names.contains(&"parallel_tool_invocation"));
    }

    #[test]
    fn test_multi_agent_delegation_not_in_builtin() {
        let workloads = builtin_workloads();
        let names: Vec<&str> = workloads.iter().map(|w| w.name.as_str()).collect();
        assert!(!names.contains(&"multi_agent_delegation"));
    }

    #[test]
    fn test_multi_agent_delegation_workload() {
        let workload = multi_agent_delegation_workload();
        assert_eq!(workload.name, "multi_agent_delegation");
        assert_eq!(workload.expected_turns, 5);
        assert!(workload.agent.tools.contains_key("delegate_to_researcher"));
        assert!(workload.agent.tools.contains_key("delegate_to_writer"));
        assert!(workload.metadata.contains_key("stability"));
    }

    #[test]
    fn test_all_shipped_workloads_pass_validation() {
        let mut workloads = builtin_workloads();
        workloads.push(multi_agent_delegation_workload());
        for workload in &workloads {
            validate_workload(workload).unwrap();
        }
    }

    #[test]
    fn test_workload_serialization_round_trip() {
        let workloads = builtin_workloads();
        for workload in &workloads {
            let json = serde_json::to_string(workload).unwrap();
            let deserialized: Workload = serde_json::from_str(&json).unwrap();
            assert_eq!(workload, &deserialized);
        }
    }

    #[test]
    fn test_serialization_uses_camel_case_keys() {
        let value = serde_json::to_value(simple_tool_call_workload()).unwrap();
        assert!(value.get("expectedTurns").is_some());
        assert!(value.get("schemaVersion").is_some());
        assert!(value.get("expected_turns").is_none());
        assert!(value["agent"].get("userMessage").is_some());
        assert!(value["agent"]["tools"]["get_weather"]
            .get("simulatedLatencyMs")
            .is_some());
        // Empty metadata is skipped entirely.
        assert!(value.get("metadata").is_none());
    }

    #[test]
    fn test_load_workload_not_found() {
        let result = load_workload(Path::new("/nonexistent/path.json"));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, BenchError::WorkloadNotFound { .. }));
    }

    #[test]
    fn test_load_workload_invalid_json() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "not valid json").unwrap();
        let err = load_workload(file.path()).unwrap_err();
        // Syntax errors do not name a field, so they are attributed to the root.
        assert_eq!(validation_field(&err), Some("root"));
    }

    #[test]
    fn test_load_workload_missing_field() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, r#"{{"name": "test"}}"#).unwrap();
        let err = load_workload(file.path()).unwrap_err();
        // `description` is the first required field after `name` in declaration order.
        assert_eq!(validation_field(&err), Some("description"));
    }

    #[test]
    fn test_load_workload_valid() {
        let workload = simple_tool_call_workload();
        let json = serde_json::to_string_pretty(&workload).unwrap();

        let mut file = NamedTempFile::new().unwrap();
        write!(file, "{json}").unwrap();

        let loaded = load_workload(file.path()).unwrap();
        assert_eq!(workload, loaded);
    }

    #[test]
    fn test_validate_empty_name() {
        let mut workload = simple_tool_call_workload();
        workload.name = String::new();
        let err = validate_workload(&workload).unwrap_err();
        assert_eq!(validation_field(&err), Some("name"));
    }

    #[test]
    fn test_validate_zero_expected_turns() {
        let mut workload = simple_tool_call_workload();
        workload.expected_turns = 0;
        let err = validate_workload(&workload).unwrap_err();
        assert_eq!(validation_field(&err), Some("expectedTurns"));
    }

    #[test]
    fn test_validate_zero_schema_version() {
        let mut workload = simple_tool_call_workload();
        workload.schema_version = 0;
        let err = validate_workload(&workload).unwrap_err();
        assert_eq!(validation_field(&err), Some("schemaVersion"));
    }

    #[test]
    fn test_validate_empty_tool_description() {
        let mut workload = simple_tool_call_workload();
        workload
            .agent
            .tools
            .get_mut("get_weather")
            .unwrap()
            .description = String::new();
        let err = validate_workload(&workload).unwrap_err();
        assert_eq!(
            validation_field(&err),
            Some("agent.tools.get_weather.description")
        );
    }

    #[test]
    fn test_schema_version_defaults_to_1() {
        let json = r#"{
            "name": "test",
            "description": "test workload",
            "agent": {
                "instructions": "do something",
                "userMessage": "hello"
            },
            "model": "gemini-2.5-flash",
            "expectedTurns": 2
        }"#;
        let workload: Workload = serde_json::from_str(json).unwrap();
        assert_eq!(workload.schema_version, 1);
    }

    #[test]
    fn test_metadata_preserved_in_round_trip() {
        let mut workload = simple_tool_call_workload();
        workload
            .metadata
            .insert("author".to_string(), serde_json::Value::String("test-user".to_string()));
        workload.metadata.insert("version".to_string(), serde_json::json!(2));

        let json = serde_json::to_string(&workload).unwrap();
        let deserialized: Workload = serde_json::from_str(&json).unwrap();
        assert_eq!(workload.metadata, deserialized.metadata);
    }
}