1use std::fmt;
23
24use serde::{Deserialize, Serialize};
25
26use crate::generation::GenerationOptions;
27use crate::tool::ToolChoice;
28use crate::types::{ModelId, ProviderId};
29
30#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
48pub struct AgentName(String);
49
50impl AgentName {
51 pub fn new(name: impl Into<String>) -> Self {
53 Self(name.into())
54 }
55
56 pub fn as_str(&self) -> &str {
58 &self.0
59 }
60}
61
62impl fmt::Display for AgentName {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 self.0.fmt(f)
65 }
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
83#[serde(rename_all = "snake_case")]
84pub enum AgentRole {
85 Primary,
87 SubAgent,
89 Internal,
91}
92
93impl AgentRole {
94 pub fn is_primary(&self) -> bool {
96 matches!(self, Self::Primary)
97 }
98
99 pub fn is_sub_agent(&self) -> bool {
101 matches!(self, Self::SubAgent)
102 }
103
104 pub fn is_internal(&self) -> bool {
106 matches!(self, Self::Internal)
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
137#[serde(tag = "type", rename_all = "snake_case")]
138pub enum AgentModelRef {
139 Inherit,
141 ById {
143 model_id: ModelId,
144 provider_id: ProviderId,
145 },
146 ByAlias {
148 alias: String,
149 },
150}
151
152impl AgentModelRef {
153 pub fn by_id(model_id: ModelId, provider_id: ProviderId) -> Self {
155 Self::ById {
156 model_id,
157 provider_id,
158 }
159 }
160
161 pub fn by_alias(alias: impl Into<String>) -> Self {
163 Self::ByAlias {
164 alias: alias.into(),
165 }
166 }
167}
168
169#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
201#[serde(tag = "type", rename_all = "snake_case")]
202pub enum ToolFilter {
203 #[default]
205 AllowAll,
206 AllowList { tools: Vec<String> },
208 DenyList { tools: Vec<String> },
210 None,
212}
213
214impl ToolFilter {
215 pub fn allow_list(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
217 Self::AllowList {
218 tools: tools.into_iter().map(Into::into).collect(),
219 }
220 }
221
222 pub fn deny_list(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
224 Self::DenyList {
225 tools: tools.into_iter().map(Into::into).collect(),
226 }
227 }
228
229 pub fn is_allowed(&self, tool_name: &str) -> bool {
231 match self {
232 Self::AllowAll => true,
233 Self::AllowList { tools } => tools.iter().any(|t| t == tool_name),
234 Self::DenyList { tools } => !tools.iter().any(|t| t == tool_name),
235 Self::None => false,
236 }
237 }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct AgentDefinition {
281 pub name: AgentName,
283
284 pub role: AgentRole,
286
287 #[serde(default)]
293 pub description: String,
294
295 #[serde(default)]
300 pub system_prompt: Vec<String>,
301
302 #[serde(default, skip_serializing_if = "Option::is_none")]
304 pub model: Option<AgentModelRef>,
305
306 #[serde(default)]
308 pub tool_filter: ToolFilter,
309
310 #[serde(default, skip_serializing_if = "Option::is_none")]
315 pub tool_choice: Option<ToolChoice>,
316
317 #[serde(default, skip_serializing_if = "Option::is_none")]
322 pub generation: Option<GenerationOptions>,
323
324 #[serde(default, skip_serializing_if = "Option::is_none")]
329 pub max_steps: Option<u32>,
330
331 #[serde(default, skip_serializing_if = "Vec::is_empty")]
336 pub sub_agents: Vec<AgentName>,
337
338 #[serde(default, skip_serializing_if = "Option::is_none")]
343 pub output_schema: Option<serde_json::Value>,
344
345 #[serde(default, skip_serializing_if = "Option::is_none")]
350 pub provider_options: Option<serde_json::Value>,
351}
352
353impl AgentDefinition {
354 pub fn new(name: impl Into<String>, role: AgentRole) -> Self {
356 Self {
357 name: AgentName::new(name),
358 role,
359 description: String::new(),
360 system_prompt: Vec::new(),
361 model: None,
362 tool_filter: ToolFilter::default(),
363 tool_choice: None,
364 generation: None,
365 max_steps: None,
366 sub_agents: Vec::new(),
367 output_schema: None,
368 provider_options: None,
369 }
370 }
371
372 pub fn with_description(mut self, description: impl Into<String>) -> Self {
374 self.description = description.into();
375 self
376 }
377
378 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
380 self.system_prompt = vec![prompt.into()];
381 self
382 }
383
384 pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
386 self.system_prompt.push(prompt.into());
387 self
388 }
389
390 pub fn with_system_prompts(mut self, prompts: Vec<String>) -> Self {
392 self.system_prompt = prompts;
393 self
394 }
395
396 pub fn with_model(mut self, model: AgentModelRef) -> Self {
398 self.model = Some(model);
399 self
400 }
401
402 pub fn with_tool_filter(mut self, filter: ToolFilter) -> Self {
404 self.tool_filter = filter;
405 self
406 }
407
408 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
410 self.tool_choice = Some(choice);
411 self
412 }
413
414 pub fn with_generation(mut self, generation: GenerationOptions) -> Self {
416 self.generation = Some(generation);
417 self
418 }
419
420 pub fn with_max_steps(mut self, steps: u32) -> Self {
422 self.max_steps = Some(steps);
423 self
424 }
425
426 pub fn with_sub_agents(mut self, agents: Vec<AgentName>) -> Self {
428 self.sub_agents = agents;
429 self
430 }
431
432 pub fn add_sub_agent(mut self, agent: impl Into<String>) -> Self {
434 self.sub_agents.push(AgentName::new(agent));
435 self
436 }
437
438 pub fn with_output_schema(mut self, schema: serde_json::Value) -> Self {
440 self.output_schema = Some(schema);
441 self
442 }
443
444 pub fn with_provider_options(mut self, options: serde_json::Value) -> Self {
446 self.provider_options = Some(options);
447 self
448 }
449
450 pub fn joined_system_prompt(&self) -> String {
454 self.system_prompt.join("\n\n")
455 }
456}
457
/// Unit tests for agent definitions: newtype/enum helpers, builder semantics,
/// and the serde wire format (which is a compatibility contract for stored
/// configuration, so exact JSON is pinned where it is fully determined here).
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (ToolChoice,
    // GenerationOptions, ModelId, ProviderId), so no per-test `use` is needed.
    use super::*;

    // ---- AgentName --------------------------------------------------------

    #[test]
    fn test_agent_name_new() {
        let name = AgentName::new("explore");
        assert_eq!(name.as_str(), "explore");
        assert_eq!(name.to_string(), "explore");
    }

    #[test]
    fn test_agent_name_serde_roundtrip() {
        let name = AgentName::new("build");
        let json = serde_json::to_string(&name).unwrap();
        // Newtype struct: written as the bare string, not as `["build"]`.
        assert_eq!(json, r#""build""#);
        let restored: AgentName = serde_json::from_str(&json).unwrap();
        assert_eq!(name, restored);
    }

    // ---- AgentRole --------------------------------------------------------

    #[test]
    fn test_agent_role_predicates() {
        assert!(AgentRole::Primary.is_primary());
        assert!(!AgentRole::Primary.is_sub_agent());
        assert!(!AgentRole::Primary.is_internal());

        assert!(!AgentRole::SubAgent.is_primary());
        assert!(AgentRole::SubAgent.is_sub_agent());
        assert!(!AgentRole::SubAgent.is_internal());

        assert!(!AgentRole::Internal.is_primary());
        assert!(!AgentRole::Internal.is_sub_agent());
        assert!(AgentRole::Internal.is_internal());
    }

    #[test]
    fn test_agent_role_serde() {
        for (role, expected) in [
            (AgentRole::Primary, r#""primary""#),
            (AgentRole::SubAgent, r#""sub_agent""#),
            (AgentRole::Internal, r#""internal""#),
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, expected);
            let restored: AgentRole = serde_json::from_str(&json).unwrap();
            assert_eq!(restored, role);
        }
    }

    // ---- AgentModelRef ----------------------------------------------------

    #[test]
    fn test_agent_model_ref_inherit() {
        let json = serde_json::to_string(&AgentModelRef::Inherit).unwrap();
        assert_eq!(json, r#"{"type":"inherit"}"#);
    }

    #[test]
    fn test_agent_model_ref_by_id() {
        let r = AgentModelRef::by_id(ModelId::new("gpt-4o"), ProviderId::new("openai"));
        match &r {
            AgentModelRef::ById { model_id, provider_id } => {
                assert_eq!(model_id.as_str(), "gpt-4o");
                assert_eq!(provider_id.as_str(), "openai");
            }
            other => panic!("expected ById, got {other:?}"),
        }
    }

    #[test]
    fn test_agent_model_ref_by_alias() {
        let r = AgentModelRef::by_alias("fast");
        assert_eq!(r, AgentModelRef::ByAlias { alias: "fast".to_string() });

        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, r#"{"type":"by_alias","alias":"fast"}"#);

        // Internally tagged: the tag may appear anywhere in the object.
        let reordered: AgentModelRef =
            serde_json::from_str(r#"{"alias":"fast","type":"by_alias"}"#).unwrap();
        assert_eq!(reordered, r);
    }

    #[test]
    fn test_agent_model_ref_serde_roundtrip() {
        let refs = vec![
            AgentModelRef::Inherit,
            AgentModelRef::by_id(
                ModelId::new("claude-sonnet-4-20250514"),
                ProviderId::new("anthropic"),
            ),
            AgentModelRef::by_alias("cheap"),
        ];
        for r in refs {
            let json = serde_json::to_string(&r).unwrap();
            let restored: AgentModelRef = serde_json::from_str(&json).unwrap();
            assert_eq!(r, restored);
        }
    }

    // ---- ToolFilter -------------------------------------------------------

    #[test]
    fn test_tool_filter_allow_all() {
        assert!(ToolFilter::AllowAll.is_allowed("anything"));
    }

    #[test]
    fn test_tool_filter_allow_list() {
        let f = ToolFilter::allow_list(["read_file", "grep"]);
        assert!(f.is_allowed("read_file"));
        assert!(f.is_allowed("grep"));
        assert!(!f.is_allowed("bash"));
    }

    #[test]
    fn test_tool_filter_deny_list() {
        let f = ToolFilter::deny_list(["bash", "write_file"]);
        assert!(f.is_allowed("read_file"));
        assert!(!f.is_allowed("bash"));
        assert!(!f.is_allowed("write_file"));
    }

    #[test]
    fn test_tool_filter_none() {
        assert!(!ToolFilter::None.is_allowed("anything"));
    }

    #[test]
    fn test_tool_filter_empty_lists() {
        assert!(!ToolFilter::allow_list(Vec::<&str>::new()).is_allowed("read_file"));
        assert!(ToolFilter::deny_list(Vec::<&str>::new()).is_allowed("read_file"));
    }

    #[test]
    fn test_tool_filter_match_is_exact() {
        let f = ToolFilter::allow_list(["grep"]);
        assert!(!f.is_allowed("Grep"));
        assert!(!f.is_allowed("gre"));
        assert!(!f.is_allowed("grep "));
    }

    #[test]
    fn test_tool_filter_default_is_allow_all() {
        assert_eq!(ToolFilter::default(), ToolFilter::AllowAll);
    }

    #[test]
    fn test_tool_filter_wire_format() {
        let cases = [
            (ToolFilter::AllowAll, r#"{"type":"allow_all"}"#),
            (
                ToolFilter::allow_list(["read_file"]),
                r#"{"type":"allow_list","tools":["read_file"]}"#,
            ),
            (ToolFilter::deny_list(["bash"]), r#"{"type":"deny_list","tools":["bash"]}"#),
            (ToolFilter::None, r#"{"type":"none"}"#),
        ];
        for (filter, expected) in cases {
            assert_eq!(serde_json::to_string(&filter).unwrap(), expected);
        }
    }

    #[test]
    fn test_tool_filter_serde_roundtrip() {
        let filters = vec![
            ToolFilter::AllowAll,
            ToolFilter::allow_list(["read_file"]),
            ToolFilter::deny_list(["bash"]),
            ToolFilter::None,
        ];
        for f in filters {
            let json = serde_json::to_string(&f).unwrap();
            let restored: ToolFilter = serde_json::from_str(&json).unwrap();
            assert_eq!(f, restored);
        }
    }

    // ---- AgentDefinition --------------------------------------------------

    #[test]
    fn test_agent_definition_minimal() {
        let agent = AgentDefinition::new("test", AgentRole::Primary);
        assert_eq!(agent.name.as_str(), "test");
        assert_eq!(agent.role, AgentRole::Primary);
        assert!(agent.description.is_empty());
        assert!(agent.system_prompt.is_empty());
        assert!(agent.model.is_none());
        assert_eq!(agent.tool_filter, ToolFilter::AllowAll);
        assert!(agent.tool_choice.is_none());
        assert!(agent.generation.is_none());
        assert!(agent.max_steps.is_none());
        assert!(agent.sub_agents.is_empty());
        assert!(agent.output_schema.is_none());
        assert!(agent.provider_options.is_none());
    }

    #[test]
    fn test_agent_definition_builder() {
        let agent = AgentDefinition::new("explore", AgentRole::SubAgent)
            .with_description("Search agent")
            .with_system_prompt("You are a search specialist.")
            .append_system_prompt("Be thorough.")
            .with_model(AgentModelRef::by_alias("fast"))
            .with_tool_filter(ToolFilter::allow_list(["read_file", "grep"]))
            .with_tool_choice(ToolChoice::Auto)
            .with_generation(GenerationOptions::new().with_temperature(0.3))
            .with_max_steps(10)
            .add_sub_agent("deep_search");

        assert_eq!(agent.name.as_str(), "explore");
        assert_eq!(agent.role, AgentRole::SubAgent);
        assert_eq!(agent.description, "Search agent");
        assert_eq!(agent.system_prompt.len(), 2);
        assert_eq!(
            agent.joined_system_prompt(),
            "You are a search specialist.\n\nBe thorough."
        );
        assert_eq!(agent.model, Some(AgentModelRef::by_alias("fast")));
        assert!(agent.tool_filter.is_allowed("read_file"));
        assert!(!agent.tool_filter.is_allowed("bash"));
        assert_eq!(agent.tool_choice, Some(ToolChoice::Auto));
        assert_eq!(agent.generation.as_ref().unwrap().temperature, Some(0.3));
        assert_eq!(agent.max_steps, Some(10));
        assert_eq!(agent.sub_agents.len(), 1);
        assert_eq!(agent.sub_agents[0].as_str(), "deep_search");
    }

    #[test]
    fn test_agent_definition_prompt_builders_replace_or_append() {
        // `with_system_prompt` discards previously appended fragments.
        let agent = AgentDefinition::new("t", AgentRole::Primary)
            .append_system_prompt("first")
            .with_system_prompt("replaced");
        assert_eq!(agent.system_prompt, vec!["replaced".to_string()]);

        let agent = agent.with_system_prompts(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(agent.joined_system_prompt(), "a\n\nb");
    }

    #[test]
    fn test_agent_definition_sub_agent_builders() {
        // `with_sub_agents` replaces the list; `add_sub_agent` appends to it.
        let agent = AgentDefinition::new("t", AgentRole::Primary)
            .add_sub_agent("old")
            .with_sub_agents(vec![AgentName::new("explore")])
            .add_sub_agent("title");
        let names: Vec<&str> = agent.sub_agents.iter().map(AgentName::as_str).collect();
        assert_eq!(names, ["explore", "title"]);
    }

    #[test]
    fn test_agent_definition_serde_roundtrip() {
        let agent = AgentDefinition::new("build", AgentRole::Primary)
            .with_description("Coding agent")
            .with_system_prompt("Help with code.")
            .with_model(AgentModelRef::by_id(
                ModelId::new("gpt-4o"),
                ProviderId::new("openai"),
            ))
            .with_tool_filter(ToolFilter::deny_list(["dangerous_tool"]))
            .with_tool_choice(ToolChoice::Required)
            .with_generation(
                GenerationOptions::new()
                    .with_temperature(0.5)
                    .with_max_tokens(4096),
            )
            .with_max_steps(50)
            .add_sub_agent("explore")
            .add_sub_agent("title")
            .with_output_schema(serde_json::json!({"type": "object"}))
            .with_provider_options(serde_json::json!({"service_tier": "default"}));

        let json = serde_json::to_string_pretty(&agent).unwrap();
        let restored: AgentDefinition = serde_json::from_str(&json).unwrap();

        assert_eq!(agent.name, restored.name);
        assert_eq!(agent.role, restored.role);
        assert_eq!(agent.description, restored.description);
        assert_eq!(agent.system_prompt, restored.system_prompt);
        assert_eq!(agent.model, restored.model);
        assert_eq!(agent.tool_filter, restored.tool_filter);
        assert_eq!(agent.tool_choice, restored.tool_choice);
        assert_eq!(agent.generation, restored.generation);
        assert_eq!(agent.max_steps, restored.max_steps);
        assert_eq!(agent.sub_agents, restored.sub_agents);
        assert_eq!(agent.output_schema, restored.output_schema);
        assert_eq!(agent.provider_options, restored.provider_options);
    }

    #[test]
    fn test_agent_definition_minimal_serialization_skips_unset_fields() {
        let agent = AgentDefinition::new("x", AgentRole::Primary);
        let value = serde_json::to_value(&agent).unwrap();
        let mut keys: Vec<&str> = value
            .as_object()
            .expect("agent definition serializes as a JSON object")
            .keys()
            .map(String::as_str)
            .collect();
        // Sort so the check is independent of serde_json's map ordering.
        keys.sort_unstable();
        assert_eq!(
            keys,
            ["description", "name", "role", "system_prompt", "tool_filter"]
        );
        assert_eq!(value["tool_filter"], serde_json::json!({"type": "allow_all"}));
    }

    #[test]
    fn test_agent_definition_deserialize_minimal_json_uses_defaults() {
        let agent: AgentDefinition =
            serde_json::from_str(r#"{"name":"scout","role":"sub_agent"}"#).unwrap();
        assert_eq!(agent.name.as_str(), "scout");
        assert_eq!(agent.role, AgentRole::SubAgent);
        assert!(agent.description.is_empty());
        assert!(agent.system_prompt.is_empty());
        assert!(agent.model.is_none());
        assert_eq!(agent.tool_filter, ToolFilter::AllowAll);
        assert!(agent.tool_choice.is_none());
        assert!(agent.generation.is_none());
        assert!(agent.max_steps.is_none());
        assert!(agent.sub_agents.is_empty());
        assert!(agent.output_schema.is_none());
        assert!(agent.provider_options.is_none());
    }

    #[test]
    fn test_agent_definition_requires_name_and_role() {
        assert!(serde_json::from_str::<AgentDefinition>(r#"{"name":"x"}"#).is_err());
        assert!(serde_json::from_str::<AgentDefinition>(r#"{"role":"primary"}"#).is_err());
    }

    #[test]
    fn test_agent_definition_joined_prompt_empty() {
        let agent = AgentDefinition::new("empty", AgentRole::Internal);
        assert_eq!(agent.joined_system_prompt(), "");
    }

    #[test]
    fn test_agent_definition_joined_prompt_single() {
        let agent = AgentDefinition::new("t", AgentRole::Internal).with_system_prompt("Hello");
        assert_eq!(agent.joined_system_prompt(), "Hello");
    }
}