1use std::collections::{BTreeMap, HashMap, HashSet};
10
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Deserializer, Serialize};
13use serde_json::Value;
14
15use crate::contract::inference::{ContextWindowPolicy, ReasoningEffort};
16use crate::error::StateError;
17
18pub trait PluginConfigKey: 'static + Send + Sync {
34 const KEY: &'static str;
36
37 type Config: Default
39 + Clone
40 + Serialize
41 + DeserializeOwned
42 + schemars::JsonSchema
43 + Send
44 + Sync
45 + 'static;
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
60#[serde(deny_unknown_fields)]
61pub struct AgentSpec {
62 pub id: String,
64 pub model_id: String,
66 pub system_prompt: String,
68 #[serde(default = "default_max_rounds")]
70 pub max_rounds: usize,
71 #[serde(default = "default_max_continuation_retries")]
73 pub max_continuation_retries: usize,
74 #[serde(default, skip_serializing_if = "Option::is_none")]
76 pub context_policy: Option<ContextWindowPolicy>,
77 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub reasoning_effort: Option<ReasoningEffort>,
81 #[serde(default)]
83 pub plugin_ids: Vec<String>,
84 #[serde(
88 default,
89 skip_serializing_if = "HashSet::is_empty",
90 alias = "active_plugins"
91 )]
92 pub active_hook_filter: HashSet<String>,
93 #[serde(default)]
95 pub allowed_tools: Option<Vec<String>>,
96 #[serde(default)]
98 pub excluded_tools: Option<Vec<String>>,
99 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub endpoint: Option<RemoteEndpoint>,
103 #[serde(default, skip_serializing_if = "Vec::is_empty")]
106 pub delegates: Vec<String>,
107 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
109 pub sections: HashMap<String, Value>,
110 #[serde(default, skip_serializing_if = "Option::is_none")]
113 pub registry: Option<String>,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
118pub struct RemoteAuth {
119 #[serde(rename = "type")]
120 pub auth_type: String,
121 #[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
122 pub params: BTreeMap<String, Value>,
123}
124
125impl RemoteAuth {
126 #[must_use]
127 pub fn bearer(token: impl Into<String>) -> Self {
128 let mut params = BTreeMap::new();
129 params.insert("token".into(), Value::String(token.into()));
130 Self {
131 auth_type: "bearer".into(),
132 params,
133 }
134 }
135
136 #[must_use]
137 pub fn param_str(&self, key: &str) -> Option<&str> {
138 self.params.get(key).and_then(Value::as_str)
139 }
140}
141
142#[derive(Debug, Clone, Serialize, PartialEq, schemars::JsonSchema)]
144pub struct RemoteEndpoint {
145 #[serde(default = "default_remote_backend")]
146 pub backend: String,
147 pub base_url: String,
148 #[serde(default, skip_serializing_if = "Option::is_none")]
149 pub auth: Option<RemoteAuth>,
150 #[serde(default, skip_serializing_if = "Option::is_none")]
152 pub target: Option<String>,
153 #[serde(default = "default_timeout")]
154 pub timeout_ms: u64,
155 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
156 pub options: BTreeMap<String, Value>,
157}
158
159impl Default for RemoteEndpoint {
160 fn default() -> Self {
161 Self {
162 backend: default_remote_backend(),
163 base_url: String::new(),
164 auth: None,
165 target: None,
166 timeout_ms: default_timeout(),
167 options: BTreeMap::new(),
168 }
169 }
170}
171
/// Backend name used when an endpoint does not specify one: `"a2a"`.
fn default_remote_backend() -> String {
    String::from("a2a")
}
175
/// Endpoint timeout, in milliseconds, used when none is supplied.
fn default_timeout() -> u64 {
    // 5 minutes expressed in milliseconds (= 300_000).
    const FIVE_MINUTES_MS: u64 = 5 * 60 * 1_000;
    FIVE_MINUTES_MS
}
179
180#[derive(Debug, Deserialize)]
181struct RawRemoteEndpoint {
182 #[serde(default)]
183 backend: Option<String>,
184 base_url: String,
185 #[serde(default)]
186 auth: Option<RemoteAuth>,
187 #[serde(default)]
188 target: Option<String>,
189 #[serde(default)]
190 timeout_ms: Option<u64>,
191 #[serde(default)]
192 options: BTreeMap<String, Value>,
193 #[serde(default)]
194 bearer_token: Option<String>,
195 #[serde(default)]
196 agent_id: Option<String>,
197 #[serde(default)]
198 poll_interval_ms: Option<u64>,
199}
200
201impl<'de> Deserialize<'de> for RemoteEndpoint {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: Deserializer<'de>,
205 {
206 let raw = RawRemoteEndpoint::deserialize(deserializer)?;
207 let has_legacy_fields =
208 raw.bearer_token.is_some() || raw.agent_id.is_some() || raw.poll_interval_ms.is_some();
209 let has_canonical_fields = raw.backend.is_some()
210 || raw.auth.is_some()
211 || raw.target.is_some()
212 || !raw.options.is_empty();
213
214 if has_legacy_fields && has_canonical_fields {
215 return Err(serde::de::Error::custom(
216 "cannot mix legacy A2A endpoint fields with canonical remote endpoint fields",
217 ));
218 }
219
220 if has_legacy_fields {
221 let mut options = BTreeMap::new();
222 if let Some(poll_interval_ms) = raw.poll_interval_ms {
223 options.insert("poll_interval_ms".into(), Value::from(poll_interval_ms));
224 }
225 return Ok(Self {
226 backend: default_remote_backend(),
227 base_url: raw.base_url,
228 auth: raw.bearer_token.map(RemoteAuth::bearer),
229 target: raw.agent_id,
230 timeout_ms: raw.timeout_ms.unwrap_or_else(default_timeout),
231 options,
232 });
233 }
234
235 let backend = raw.backend.unwrap_or_else(default_remote_backend);
236 if backend.trim().is_empty() {
237 return Err(serde::de::Error::custom(
238 "remote endpoint backend must not be empty",
239 ));
240 }
241
242 Ok(Self {
243 backend,
244 base_url: raw.base_url,
245 auth: raw.auth,
246 target: raw.target,
247 timeout_ms: raw.timeout_ms.unwrap_or_else(default_timeout),
248 options: raw.options,
249 })
250 }
251}
252
253#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
259#[serde(deny_unknown_fields)]
260pub struct ModelBindingSpec {
261 pub id: String,
263 pub provider_id: String,
265 pub upstream_model: String,
267}
268
269#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
275pub struct ProviderSpec {
276 pub id: String,
278 pub adapter: String,
280 #[serde(default, skip_serializing_if = "Option::is_none")]
282 pub api_key: Option<String>,
283 #[serde(default, skip_serializing_if = "Option::is_none")]
285 pub base_url: Option<String>,
286 #[serde(default = "default_provider_timeout_secs")]
288 pub timeout_secs: u64,
289}
290
/// Provider timeout, in seconds, used when none is supplied.
fn default_provider_timeout_secs() -> u64 {
    // 5 minutes expressed in seconds (= 300).
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    FIVE_MINUTES_SECS
}
294
295#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
301#[serde(rename_all = "lowercase")]
302pub enum McpTransportKind {
303 Stdio,
305 Http,
307}
308
309#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
311pub struct McpRestartPolicy {
312 #[serde(default)]
314 pub enabled: bool,
315 #[serde(default, skip_serializing_if = "Option::is_none")]
317 pub max_attempts: Option<u32>,
318 #[serde(default = "default_mcp_restart_delay_ms")]
320 pub delay_ms: u64,
321 #[serde(default = "default_mcp_restart_backoff_multiplier")]
323 pub backoff_multiplier: f64,
324 #[serde(default = "default_mcp_restart_max_delay_ms")]
326 pub max_delay_ms: u64,
327}
328
329impl Default for McpRestartPolicy {
330 fn default() -> Self {
331 Self {
332 enabled: false,
333 max_attempts: None,
334 delay_ms: default_mcp_restart_delay_ms(),
335 backoff_multiplier: default_mcp_restart_backoff_multiplier(),
336 max_delay_ms: default_mcp_restart_max_delay_ms(),
337 }
338 }
339}
340
/// Default value of `McpRestartPolicy::delay_ms`: 1000 milliseconds.
const fn default_mcp_restart_delay_ms() -> u64 {
    1_000
}
344
/// Default value of `McpRestartPolicy::backoff_multiplier`: 2.0.
const fn default_mcp_restart_backoff_multiplier() -> f64 {
    2.0_f64
}
348
/// Default value of `McpRestartPolicy::max_delay_ms`: 30 000 milliseconds.
const fn default_mcp_restart_max_delay_ms() -> u64 {
    // 30 seconds expressed in milliseconds.
    30 * 1_000
}
352
353#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
355pub struct McpServerSpec {
356 pub id: String,
358 pub transport: McpTransportKind,
360 #[serde(default, skip_serializing_if = "Option::is_none")]
362 pub command: Option<String>,
363 #[serde(default, skip_serializing_if = "Vec::is_empty")]
365 pub args: Vec<String>,
366 #[serde(default, skip_serializing_if = "Option::is_none")]
368 pub url: Option<String>,
369 #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
371 pub config: serde_json::Map<String, Value>,
372 #[serde(default = "default_mcp_timeout_secs")]
374 pub timeout_secs: u64,
375 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
377 pub env: BTreeMap<String, String>,
378 #[serde(default)]
380 pub restart_policy: McpRestartPolicy,
381}
382
/// Default value of `McpServerSpec::timeout_secs`: 30 seconds.
fn default_mcp_timeout_secs() -> u64 {
    const HALF_MINUTE_SECS: u64 = 30;
    HALF_MINUTE_SECS
}
386
387impl Default for McpServerSpec {
388 fn default() -> Self {
389 Self {
390 id: String::new(),
391 transport: McpTransportKind::Stdio,
392 command: None,
393 args: Vec::new(),
394 url: None,
395 config: serde_json::Map::new(),
396 timeout_secs: default_mcp_timeout_secs(),
397 env: BTreeMap::new(),
398 restart_policy: McpRestartPolicy::default(),
399 }
400 }
401}
402
403impl Default for ProviderSpec {
404 fn default() -> Self {
405 Self {
406 id: String::new(),
407 adapter: String::new(),
408 api_key: None,
409 base_url: None,
410 timeout_secs: default_provider_timeout_secs(),
411 }
412 }
413}
414
415impl Default for AgentSpec {
416 fn default() -> Self {
417 Self {
418 id: String::new(),
419 model_id: String::new(),
420 system_prompt: String::new(),
421 max_rounds: default_max_rounds(),
422 max_continuation_retries: default_max_continuation_retries(),
423 context_policy: None,
424 reasoning_effort: None,
425 plugin_ids: Vec::new(),
426 active_hook_filter: HashSet::new(),
427 allowed_tools: None,
428 excluded_tools: None,
429 endpoint: None,
430 delegates: Vec::new(),
431 sections: HashMap::new(),
432 registry: None,
433 }
434 }
435}
436
/// Default value of `AgentSpec::max_rounds`: 16.
fn default_max_rounds() -> usize {
    const DEFAULT_MAX_ROUNDS: usize = 16;
    DEFAULT_MAX_ROUNDS
}
440
/// Default value of `AgentSpec::max_continuation_retries`: 2.
fn default_max_continuation_retries() -> usize {
    const DEFAULT_RETRIES: usize = 2;
    DEFAULT_RETRIES
}
444
445impl AgentSpec {
446 pub fn new(id: impl Into<String>) -> Self {
463 Self {
464 id: id.into(),
465 ..Default::default()
466 }
467 }
468
469 pub fn config<K: PluginConfigKey>(&self) -> Result<K::Config, StateError> {
475 match self.sections.get(K::KEY) {
476 Some(value) => {
477 serde_json::from_value(value.clone()).map_err(|e| StateError::KeyDecode {
478 key: K::KEY.into(),
479 message: e.to_string(),
480 })
481 }
482 None => Ok(K::Config::default()),
483 }
484 }
485
486 pub fn set_config<K: PluginConfigKey>(&mut self, config: K::Config) -> Result<(), StateError> {
488 let value = serde_json::to_value(config).map_err(|e| StateError::KeyEncode {
489 key: K::KEY.into(),
490 message: e.to_string(),
491 })?;
492 self.sections.insert(K::KEY.to_string(), value);
493 Ok(())
494 }
495
496 #[must_use]
499 pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
500 self.model_id = model_id.into();
501 self
502 }
503
504 #[must_use]
505 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
506 self.system_prompt = prompt.into();
507 self
508 }
509
510 #[must_use]
511 pub fn with_max_rounds(mut self, n: usize) -> Self {
512 self.max_rounds = n;
513 self
514 }
515
516 #[must_use]
517 pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
518 self.reasoning_effort = Some(effort);
519 self
520 }
521
522 #[must_use]
523 pub fn with_hook_filter(mut self, plugin_id: impl Into<String>) -> Self {
524 self.active_hook_filter.insert(plugin_id.into());
525 self
526 }
527
528 pub fn with_config<K: PluginConfigKey>(
530 mut self,
531 config: K::Config,
532 ) -> Result<Self, StateError> {
533 self.set_config::<K>(config)?;
534 Ok(self)
535 }
536
537 #[must_use]
538 pub fn with_delegate(mut self, agent_id: impl Into<String>) -> Self {
539 self.delegates.push(agent_id.into());
540 self
541 }
542
543 #[must_use]
544 pub fn with_endpoint(mut self, endpoint: RemoteEndpoint) -> Self {
545 self.endpoint = Some(endpoint);
546 self
547 }
548
549 #[must_use]
551 pub fn with_section(mut self, key: impl Into<String>, value: Value) -> Self {
552 self.sections.insert(key.into(), value);
553 self
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- fixtures for the typed-config tests ----

    /// Fixture key whose section lives under `"model_name"`.
    struct ModelNameKey;

    #[derive(
        Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize, schemars::JsonSchema,
    )]
    struct ModelNameConfig {
        pub name: String,
    }

    impl PluginConfigKey for ModelNameKey {
        type Config = ModelNameConfig;
        const KEY: &'static str = "model_name";
    }

    /// Fixture key whose section lives under `"permission"`.
    struct PermKey;

    #[derive(
        Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize, schemars::JsonSchema,
    )]
    struct PermConfig {
        pub mode: String,
    }

    impl PluginConfigKey for PermKey {
        type Config = PermConfig;
        const KEY: &'static str = "permission";
    }

    // ---- AgentSpec / ModelBindingSpec serde ----

    /// A populated spec survives a JSON string round trip.
    #[test]
    fn agent_spec_serde_roundtrip() {
        let original = AgentSpec {
            id: "coder".into(),
            model_id: "claude-opus".into(),
            system_prompt: "You are a coding assistant.".into(),
            max_rounds: 8,
            plugin_ids: vec!["permission".into(), "logging".into()],
            allowed_tools: Some(vec!["read_file".into(), "write_file".into()]),
            excluded_tools: Some(vec!["delete_file".into()]),
            sections: HashMap::from([(String::from("permission"), json!({"mode": "strict"}))]),
            ..AgentSpec::default()
        };

        let wire = serde_json::to_string(&original).unwrap();
        let decoded: AgentSpec = serde_json::from_str(&wire).unwrap();

        assert_eq!(decoded.id, "coder");
        assert_eq!(decoded.model_id, "claude-opus");
        assert_eq!(decoded.system_prompt, "You are a coding assistant.");
        assert_eq!(decoded.max_rounds, 8);
        assert_eq!(decoded.plugin_ids, ["permission", "logging"]);
        assert_eq!(
            decoded.allowed_tools,
            Some(vec![String::from("read_file"), String::from("write_file")])
        );
        assert_eq!(
            decoded.excluded_tools,
            Some(vec![String::from("delete_file")])
        );
        assert_eq!(decoded.sections["permission"]["mode"], "strict");
    }

    /// Input with only the required keys picks up every serde default.
    #[test]
    fn agent_spec_defaults() {
        let minimal = r#"{"id":"min","model_id":"m","system_prompt":"sp"}"#;
        let decoded: AgentSpec = serde_json::from_str(minimal).unwrap();

        assert_eq!(decoded.model_id, "m");
        assert_eq!(decoded.max_rounds, 16);
        assert_eq!(decoded.max_continuation_retries, 2);
        assert!(decoded.context_policy.is_none());
        assert!(decoded.plugin_ids.is_empty());
        assert!(decoded.active_hook_filter.is_empty());
        assert!(decoded.allowed_tools.is_none());
        assert!(decoded.excluded_tools.is_none());
        assert!(decoded.sections.is_empty());
    }

    /// The binding serializes with `provider_id`/`upstream_model` only.
    #[test]
    fn model_binding_spec_uses_canonical_names() {
        let binding = ModelBindingSpec {
            id: "default".into(),
            provider_id: "openai".into(),
            upstream_model: "gpt-4o-mini".into(),
        };

        let wire = serde_json::to_value(&binding).unwrap();

        assert_eq!(wire["provider_id"], "openai");
        assert_eq!(wire["upstream_model"], "gpt-4o-mini");
        assert!(wire.get("provider").is_none());
        assert!(wire.get("model").is_none());
    }

    /// The old `model` / `provider` keys fail to decode.
    #[test]
    fn provider_model_legacy_fields_are_rejected() {
        let agent_input = r#"{"id":"min","model":"m","system_prompt":"sp"}"#;
        assert!(serde_json::from_str::<AgentSpec>(agent_input).is_err());

        let binding_input = json!({
            "id": "default",
            "provider": "openai",
            "model": "gpt-4o-mini"
        });
        assert!(serde_json::from_value::<ModelBindingSpec>(binding_input).is_err());
    }

    // ---- typed config sections ----

    /// Values stored through `with_config` read back through `config`.
    #[test]
    fn typed_config_roundtrip() {
        let model_cfg = ModelNameConfig {
            name: "opus".into(),
        };
        let perm_cfg = PermConfig {
            mode: "strict".into(),
        };

        let spec = AgentSpec::new("test")
            .with_config::<ModelNameKey>(model_cfg)
            .unwrap()
            .with_config::<PermKey>(perm_cfg)
            .unwrap();

        assert_eq!(spec.config::<ModelNameKey>().unwrap().name, "opus");
        assert_eq!(spec.config::<PermKey>().unwrap().mode, "strict");
    }

    /// A section that was never set yields the config's `Default`.
    #[test]
    fn missing_config_returns_default() {
        let spec = AgentSpec::new("test");
        let fetched: ModelNameConfig = spec.config::<ModelNameKey>().unwrap();
        assert_eq!(fetched, ModelNameConfig::default());
    }

    /// Typed config survives a JSON string round trip of the whole spec.
    #[test]
    fn config_serializes_to_json() {
        let spec = AgentSpec::new("coder")
            .with_model_id("sonnet")
            .with_config::<ModelNameKey>(ModelNameConfig {
                name: "custom".into(),
            })
            .unwrap();

        let wire = serde_json::to_string(&spec).unwrap();
        let decoded: AgentSpec = serde_json::from_str(&wire).unwrap();

        assert_eq!(decoded.id, "coder");
        assert_eq!(decoded.model_id, "sonnet");
        assert_eq!(decoded.config::<ModelNameKey>().unwrap().name, "custom");
    }

    /// Overwriting one key leaves the other key's section untouched.
    #[test]
    fn multiple_configs_independent() {
        let mut spec = AgentSpec::new("test");
        spec.set_config::<ModelNameKey>(ModelNameConfig { name: "a".into() })
            .unwrap();
        spec.set_config::<PermKey>(PermConfig { mode: "b".into() })
            .unwrap();

        let replacement = ModelNameConfig {
            name: "updated".into(),
        };
        spec.set_config::<ModelNameKey>(replacement).unwrap();

        assert_eq!(spec.config::<ModelNameKey>().unwrap().name, "updated");
        assert_eq!(spec.config::<PermKey>().unwrap().mode, "b");
    }

    /// Raw JSON sections can still be attached directly.
    #[test]
    fn with_section_raw_json_still_works() {
        let spec = AgentSpec::new("test").with_section("custom", json!({"key": "value"}));
        assert_eq!(spec.sections["custom"]["key"], "value");
    }

    // ---- RemoteEndpoint ----

    /// The canonical shape serializes without legacy keys and decodes back
    /// to an equal value.
    #[test]
    fn remote_endpoint_canonical_roundtrip_uses_single_shape() {
        let endpoint = RemoteEndpoint {
            backend: "a2a".into(),
            base_url: "https://remote.example.com/v1/a2a".into(),
            auth: Some(RemoteAuth::bearer("tok_123")),
            target: Some("worker".into()),
            timeout_ms: 60_000,
            options: BTreeMap::from([(String::from("poll_interval_ms"), json!(1000))]),
        };

        let wire = serde_json::to_value(&endpoint).unwrap();

        assert_eq!(wire["backend"], "a2a");
        assert_eq!(wire["auth"]["type"], "bearer");
        assert_eq!(wire["auth"]["token"], "tok_123");
        assert_eq!(wire["target"], "worker");
        assert_eq!(wire["options"]["poll_interval_ms"], 1000);
        for legacy_key in ["bearer_token", "agent_id", "poll_interval_ms"] {
            assert!(wire.get(legacy_key).is_none());
        }

        let decoded: RemoteEndpoint = serde_json::from_value(wire).unwrap();
        assert_eq!(decoded, endpoint);
    }

    /// Legacy keys are mapped onto the canonical fields.
    #[test]
    fn remote_endpoint_legacy_a2a_input_normalizes_to_canonical_shape() {
        let legacy_input = json!({
            "base_url": "https://remote.example.com/v1/a2a",
            "bearer_token": "tok_legacy",
            "agent_id": "worker",
            "poll_interval_ms": 750,
            "timeout_ms": 60_000
        });

        let endpoint: RemoteEndpoint = serde_json::from_value(legacy_input).unwrap();
        let token = endpoint
            .auth
            .as_ref()
            .and_then(|auth| auth.param_str("token"));

        assert_eq!(endpoint.backend, "a2a");
        assert_eq!(token, Some("tok_legacy"));
        assert_eq!(endpoint.target.as_deref(), Some("worker"));
        assert_eq!(endpoint.options.get("poll_interval_ms"), Some(&json!(750)));
        assert_eq!(endpoint.timeout_ms, 60_000);
    }

    /// Mixing legacy and canonical keys is an error.
    #[test]
    fn remote_endpoint_rejects_mixed_legacy_and_canonical_fields() {
        let mixed_input = json!({
            "backend": "a2a",
            "base_url": "https://remote.example.com/v1/a2a",
            "auth": { "type": "bearer", "token": "tok_new" },
            "bearer_token": "tok_old"
        });

        let message = serde_json::from_value::<RemoteEndpoint>(mixed_input)
            .unwrap_err()
            .to_string();

        assert!(message.contains("cannot mix legacy A2A endpoint fields"));
    }

    // ---- builder ----

    /// Chained builder calls set the expected fields.
    #[test]
    fn builder() {
        let spec = AgentSpec::new("reviewer")
            .with_model_id("claude-opus")
            .with_hook_filter("permission")
            .with_config::<PermKey>(PermConfig {
                mode: "strict".into(),
            })
            .unwrap();

        assert_eq!(spec.id, "reviewer");
        assert_eq!(spec.model_id, "claude-opus");
        assert!(spec.active_hook_filter.contains("permission"));
    }
}