Skip to main content

ai_agents_runtime/spec/
provider.rs

//! Provider configuration types for the YAML specification.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use ai_agents_tools::{ToolAliases, TrustLevel};
7
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
9pub struct ProvidersConfig {
10    #[serde(default)]
11    pub builtin: BuiltinProviderConfig,
12
13    #[serde(default, skip_serializing_if = "Option::is_none")]
14    pub yaml: Option<YamlProviderConfig>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct BuiltinProviderConfig {
19    #[serde(default = "default_true")]
20    pub enabled: bool,
21
22    #[serde(default)]
23    pub excluded_tools: Vec<String>,
24}
25
/// Serde `default = "..."` helper for flags that are on unless explicitly
/// set to `false` in the YAML input.
fn default_true() -> bool {
    true
}
29
30impl Default for BuiltinProviderConfig {
31    fn default() -> Self {
32        Self {
33            enabled: true,
34            excluded_tools: Vec::new(),
35        }
36    }
37}
38
39#[derive(Debug, Clone, Default, Serialize, Deserialize)]
40pub struct YamlProviderConfig {
41    #[serde(default = "default_true")]
42    pub enabled: bool,
43
44    #[serde(default)]
45    pub tools: Vec<YamlToolConfig>,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct YamlToolConfig {
50    pub id: String,
51
52    pub name: String,
53
54    pub description: String,
55
56    #[serde(default)]
57    pub input_schema: serde_json::Value,
58
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub aliases: Option<ToolAliases>,
61}
62
63#[derive(Debug, Clone, Default, Serialize, Deserialize)]
64pub struct ProviderSecurityConfig {
65    #[serde(flatten)]
66    pub providers: HashMap<String, ProviderPolicyConfig>,
67}
68
69/// Policy configuration for a provider
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct ProviderPolicyConfig {
72    #[serde(default = "default_true")]
73    pub enabled: bool,
74
75    #[serde(default)]
76    pub trust_level: TrustLevel,
77
78    #[serde(default)]
79    pub tools: HashMap<String, ToolPolicyConfig>,
80
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub timeout_ms: Option<u64>,
83}
84
85impl Default for ProviderPolicyConfig {
86    fn default() -> Self {
87        Self {
88            enabled: true,
89            trust_level: TrustLevel::default(),
90            tools: HashMap::new(),
91            timeout_ms: None,
92        }
93    }
94}
95
96/// Policy configuration for a tool within a provider
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ToolPolicyConfig {
99    #[serde(default = "default_true")]
100    pub enabled: bool,
101
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub timeout_ms: Option<u64>,
104
105    #[serde(default)]
106    pub require_approval: bool,
107}
108
109impl Default for ToolPolicyConfig {
110    fn default() -> Self {
111        Self {
112            enabled: true,
113            timeout_ms: None,
114            require_approval: false,
115        }
116    }
117}
118
119/// Global tool aliases configuration
120#[derive(Debug, Clone, Default, Serialize, Deserialize)]
121pub struct ToolAliasesConfig {
122    #[serde(flatten)]
123    pub tools: HashMap<String, ToolAliases>,
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `ProvidersConfig` enables the built-in provider and has no YAML section.
    #[test]
    fn test_providers_config_default() {
        let ProvidersConfig { builtin, yaml } = ProvidersConfig::default();
        assert!(builtin.enabled);
        assert!(yaml.is_none());
    }

    /// A default `BuiltinProviderConfig` is enabled and excludes nothing.
    #[test]
    fn test_builtin_provider_config_default() {
        let BuiltinProviderConfig {
            enabled,
            excluded_tools,
        } = BuiltinProviderConfig::default();
        assert!(enabled);
        assert!(excluded_tools.is_empty());
    }

    /// Both provider sections parse from YAML.
    #[test]
    fn test_providers_config_yaml() {
        let yaml = r#"
builtin:
  enabled: true
  excluded_tools:
    - http
yaml:
  enabled: true
  tools:
    - id: custom_search
      name: Custom Search
      description: Search custom API
"#;

        let parsed: ProvidersConfig = serde_yaml::from_str(yaml).unwrap();
        let builtin = &parsed.builtin;
        assert!(builtin.enabled);
        assert_eq!(builtin.excluded_tools.len(), 1);

        // Panics (fails the test) if the `yaml` section was not parsed.
        let inline = parsed.yaml.unwrap();
        assert!(inline.enabled);
        assert_eq!(inline.tools.len(), 1);
    }

    /// Provider policies are keyed by provider name and carry per-tool policies.
    #[test]
    fn test_provider_security_config_yaml() {
        let yaml = r#"
yaml:
  trust_level: high
  tools:
    run_script:
      require_approval: true
      timeout_ms: 30000
"#;

        let parsed: ProviderSecurityConfig = serde_yaml::from_str(yaml).unwrap();
        let policy = parsed.providers.get("yaml").unwrap();
        assert_eq!(policy.trust_level, TrustLevel::High);

        // `Some(true)` checks both that the tool entry exists and that it
        // requires approval.
        let needs_approval = policy
            .tools
            .get("run_script")
            .map(|tool| tool.require_approval);
        assert_eq!(needs_approval, Some(true));
    }

    /// Global aliases parse into a map keyed by tool name.
    #[test]
    fn test_tool_aliases_config_yaml() {
        let yaml = r#"
http:
  names:
    ko: 웹요청
    ja: ウェブリクエスト
  descriptions:
    ko: HTTP 요청을 보냅니다
    ja: HTTPリクエストを送信
calculator:
  names:
    ko: 계산기
"#;

        let parsed: ToolAliasesConfig = serde_yaml::from_str(yaml).unwrap();
        for tool in ["http", "calculator"] {
            assert!(parsed.tools.contains_key(tool));
        }

        let http = parsed.tools.get("http").unwrap();
        assert_eq!(http.get_name("ko"), Some("웹요청"));
    }

    /// An inline tool definition may carry its own aliases.
    #[test]
    fn test_yaml_tool_config_with_aliases() {
        let yaml = r#"
id: search
name: Web Search
description: Search the web
input_schema:
  type: object
  properties:
    query:
      type: string
aliases:
  names:
    ko: 웹검색
    ja: ウェブ検索
  descriptions:
    ko: 웹에서 검색합니다
"#;

        let parsed: YamlToolConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(parsed.id, "search");

        // Panics (fails the test) if `aliases` was not parsed.
        let localized = parsed.aliases.unwrap();
        assert_eq!(localized.get_name("ko"), Some("웹검색"));
    }
}