// agentforge_core/agent.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4/// The AgentForge native agent file schema (v1).
5/// Also the normalized representation after parsing any supported format.
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub struct AgentFile {
8    pub agentforge_schema_version: String,
9    pub name: String,
10    pub version: String,
11    pub model: ModelConfig,
12    pub system_prompt: String,
13    pub tools: Vec<ToolDefinition>,
14    pub output_schema: Option<serde_json::Value>,
15    pub constraints: Vec<String>,
16    pub eval_hints: Option<EvalHints>,
17    pub metadata: Option<HashMap<String, serde_json::Value>>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub struct ModelConfig {
22    pub provider: ModelProvider,
23    pub model_id: String,
24    pub temperature: Option<f64>,
25    pub max_tokens: Option<u32>,
26    pub top_p: Option<f64>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
30#[serde(rename_all = "snake_case")]
31pub enum ModelProvider {
32    Openai,
33    Anthropic,
34    Ollama,
35    Bedrock,
36    NvidiaNim,
37    Custom,
38}
39
40impl std::fmt::Display for ModelProvider {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            ModelProvider::Openai => write!(f, "openai"),
44            ModelProvider::Anthropic => write!(f, "anthropic"),
45            ModelProvider::Ollama => write!(f, "ollama"),
46            ModelProvider::Bedrock => write!(f, "bedrock"),
47            ModelProvider::NvidiaNim => write!(f, "nvidia_nim"),
48            ModelProvider::Custom => write!(f, "custom"),
49        }
50    }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
54pub struct ToolDefinition {
55    pub name: String,
56    pub description: String,
57    pub parameters: serde_json::Value,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
61pub struct EvalHints {
62    pub domain: Option<String>,
63    pub typical_turns: Option<u32>,
64    pub critical_tools: Vec<String>,
65    pub pass_threshold: Option<f64>,
66    pub scenario_count: Option<u32>,
67}
68
69impl Default for EvalHints {
70    fn default() -> Self {
71        Self {
72            domain: None,
73            typical_turns: Some(3),
74            critical_tools: vec![],
75            pass_threshold: Some(0.85),
76            scenario_count: Some(100),
77        }
78    }
79}
80
81/// Parsed and versioned agent file stored in the database.
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct AgentVersion {
84    pub id: uuid::Uuid,
85    pub name: String,
86    pub version: String,
87    pub sha: String,
88    pub file_content: AgentFile,
89    pub raw_content: String,
90    pub format: AgentFileFormat,
91    pub promoted: bool,
92    pub is_champion: bool,
93    pub changelog: Option<String>,
94    pub parent_sha: Option<String>,
95    pub created_at: chrono::DateTime<chrono::Utc>,
96    pub updated_at: chrono::DateTime<chrono::Utc>,
97}
98
99/// Supported agent file input formats.
100#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
101#[serde(rename_all = "snake_case")]
102pub enum AgentFileFormat {
103    NativeYaml,
104    OpenaiJson,
105    AnthropicJson,
106    LangchainYaml,
107    CrewaiYaml,
108    /// GitHub Copilot `.agent.md` format — YAML frontmatter + Markdown system prompt body.
109    CopilotAgentMd,
110}
111
112impl std::fmt::Display for AgentFileFormat {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            AgentFileFormat::NativeYaml => write!(f, "native_yaml"),
116            AgentFileFormat::OpenaiJson => write!(f, "openai_json"),
117            AgentFileFormat::AnthropicJson => write!(f, "anthropic_json"),
118            AgentFileFormat::LangchainYaml => write!(f, "langchain_yaml"),
119            AgentFileFormat::CrewaiYaml => write!(f, "crewai_yaml"),
120            AgentFileFormat::CopilotAgentMd => write!(f, "copilot_agent_md"),
121        }
122    }
123}
124
125impl std::str::FromStr for AgentFileFormat {
126    type Err = crate::AgentForgeError;
127
128    fn from_str(s: &str) -> Result<Self, Self::Err> {
129        match s {
130            "native_yaml" => Ok(AgentFileFormat::NativeYaml),
131            "openai_json" => Ok(AgentFileFormat::OpenaiJson),
132            "anthropic_json" => Ok(AgentFileFormat::AnthropicJson),
133            "langchain_yaml" => Ok(AgentFileFormat::LangchainYaml),
134            "crewai_yaml" => Ok(AgentFileFormat::CrewaiYaml),
135            "copilot_agent_md" => Ok(AgentFileFormat::CopilotAgentMd),
136            _ => Err(crate::AgentForgeError::InvalidFormat(s.to_string())),
137        }
138    }
139}
140
141/// Lint error surfaced during agent file validation.
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct LintError {
144    pub field: String,
145    pub message: String,
146    pub severity: LintSeverity,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
150#[serde(rename_all = "snake_case")]
151pub enum LintSeverity {
152    Error,
153    Warning,
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ModelProvider` variant, for exhaustive consistency checks.
    fn all_providers() -> [ModelProvider; 6] {
        [
            ModelProvider::Openai,
            ModelProvider::Anthropic,
            ModelProvider::Ollama,
            ModelProvider::Bedrock,
            ModelProvider::NvidiaNim,
            ModelProvider::Custom,
        ]
    }

    /// Every `AgentFileFormat` variant, for exhaustive consistency checks.
    fn all_formats() -> [AgentFileFormat; 6] {
        [
            AgentFileFormat::NativeYaml,
            AgentFileFormat::OpenaiJson,
            AgentFileFormat::AnthropicJson,
            AgentFileFormat::LangchainYaml,
            AgentFileFormat::CrewaiYaml,
            AgentFileFormat::CopilotAgentMd,
        ]
    }

    /// Minimal `AgentFile` fixture shared by tests that need one.
    fn sample_agent_file() -> AgentFile {
        AgentFile {
            agentforge_schema_version: "1".to_string(),
            name: "test".to_string(),
            version: "1.0.0".to_string(),
            model: ModelConfig {
                provider: ModelProvider::Openai,
                model_id: "gpt-4o".to_string(),
                temperature: None,
                max_tokens: None,
                top_p: None,
            },
            system_prompt: "You are helpful.".to_string(),
            tools: vec![],
            output_schema: None,
            constraints: vec![],
            eval_hints: None,
            metadata: None,
        }
    }

    #[test]
    fn model_provider_display() {
        assert_eq!(ModelProvider::Openai.to_string(), "openai");
        assert_eq!(ModelProvider::Anthropic.to_string(), "anthropic");
    }

    #[test]
    fn agent_file_format_roundtrip() {
        use std::str::FromStr;
        assert_eq!(
            AgentFileFormat::from_str("native_yaml").unwrap(),
            AgentFileFormat::NativeYaml
        );
        assert_eq!(AgentFileFormat::NativeYaml.to_string(), "native_yaml");
    }

    #[test]
    fn eval_hints_default() {
        let hints = EvalHints::default();
        assert_eq!(hints.pass_threshold, Some(0.85));
        assert_eq!(hints.scenario_count, Some(100));
    }

    #[test]
    fn model_provider_display_all_variants() {
        assert_eq!(ModelProvider::Openai.to_string(), "openai");
        assert_eq!(ModelProvider::Anthropic.to_string(), "anthropic");
        assert_eq!(ModelProvider::Ollama.to_string(), "ollama");
        assert_eq!(ModelProvider::Bedrock.to_string(), "bedrock");
        assert_eq!(ModelProvider::NvidiaNim.to_string(), "nvidia_nim");
        assert_eq!(ModelProvider::Custom.to_string(), "custom");
    }

    #[test]
    fn model_provider_serde_roundtrip() {
        let json = serde_json::to_string(&ModelProvider::NvidiaNim).unwrap();
        assert_eq!(json, r#""nvidia_nim""#);
        let back: ModelProvider = serde_json::from_str(&json).unwrap();
        assert_eq!(back, ModelProvider::NvidiaNim);
    }

    #[test]
    fn model_provider_serde_name_matches_display() {
        // The serde `snake_case` rename and the hand-written `Display` impl are
        // maintained separately; this keeps them from drifting apart.
        for provider in &all_providers() {
            let json = serde_json::to_string(provider).unwrap();
            assert_eq!(json, format!("\"{}\"", provider));
            let back: ModelProvider = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, provider);
        }
    }

    #[test]
    fn agent_file_format_display_all_variants() {
        assert_eq!(AgentFileFormat::NativeYaml.to_string(), "native_yaml");
        assert_eq!(AgentFileFormat::OpenaiJson.to_string(), "openai_json");
        assert_eq!(AgentFileFormat::AnthropicJson.to_string(), "anthropic_json");
        assert_eq!(AgentFileFormat::LangchainYaml.to_string(), "langchain_yaml");
        assert_eq!(AgentFileFormat::CrewaiYaml.to_string(), "crewai_yaml");
        assert_eq!(
            AgentFileFormat::CopilotAgentMd.to_string(),
            "copilot_agent_md"
        );
    }

    #[test]
    fn agent_file_format_from_str_all_variants() {
        use std::str::FromStr;
        let pairs = [
            ("native_yaml", AgentFileFormat::NativeYaml),
            ("openai_json", AgentFileFormat::OpenaiJson),
            ("anthropic_json", AgentFileFormat::AnthropicJson),
            ("langchain_yaml", AgentFileFormat::LangchainYaml),
            ("crewai_yaml", AgentFileFormat::CrewaiYaml),
            ("copilot_agent_md", AgentFileFormat::CopilotAgentMd),
        ];
        for (s, expected) in &pairs {
            assert_eq!(AgentFileFormat::from_str(s).unwrap(), *expected);
        }
    }

    #[test]
    fn agent_file_format_serde_display_and_from_str_agree() {
        // Three independent string tables (serde rename, Display, FromStr) must agree.
        use std::str::FromStr;
        for format in &all_formats() {
            let label = format.to_string();
            assert_eq!(
                serde_json::to_string(format).unwrap(),
                format!("\"{}\"", label)
            );
            assert_eq!(&AgentFileFormat::from_str(&label).unwrap(), format);
            let back: AgentFileFormat =
                serde_json::from_str(&format!("\"{}\"", label)).unwrap();
            assert_eq!(&back, format);
        }
    }

    #[test]
    fn agent_file_format_from_str_unknown_returns_err() {
        use std::str::FromStr;
        assert!(AgentFileFormat::from_str("unknown_format").is_err());
        // Matching is case-sensitive.
        assert!(AgentFileFormat::from_str("Native_Yaml").is_err());
    }

    #[test]
    fn eval_hints_default_typical_turns() {
        let hints = EvalHints::default();
        assert_eq!(hints.typical_turns, Some(3));
        assert!(hints.critical_tools.is_empty());
        assert!(hints.domain.is_none());
    }

    #[test]
    fn lint_severity_serde() {
        let json = serde_json::to_string(&LintSeverity::Error).unwrap();
        assert_eq!(json, r#""error""#);
        let back: LintSeverity = serde_json::from_str(&json).unwrap();
        assert_eq!(back, LintSeverity::Error);

        let json2 = serde_json::to_string(&LintSeverity::Warning).unwrap();
        assert_eq!(json2, r#""warning""#);
        let back2: LintSeverity = serde_json::from_str(&json2).unwrap();
        assert_eq!(back2, LintSeverity::Warning);
    }

    #[test]
    fn tool_definition_stores_fields() {
        let tool = ToolDefinition {
            name: "get_order".to_string(),
            description: "Fetch an order by ID".to_string(),
            parameters: serde_json::json!({"type": "object", "properties": {"id": {"type": "string"}}}),
        };
        assert_eq!(tool.name, "get_order");
        assert_eq!(tool.parameters["type"], "object");
    }

    #[test]
    fn agent_file_json_roundtrip() {
        let original = sample_agent_file();
        let json = serde_json::to_string(&original).unwrap();
        let back: AgentFile = serde_json::from_str(&json).unwrap();
        assert_eq!(back, original);
    }

    #[test]
    fn agent_version_parent_sha_is_optional() {
        let with_parent = AgentVersion {
            id: uuid::Uuid::new_v4(),
            name: "test".to_string(),
            version: "1.0.0".to_string(),
            sha: "abc123".to_string(),
            file_content: sample_agent_file(),
            raw_content: "{}".to_string(),
            format: AgentFileFormat::NativeYaml,
            promoted: false,
            is_champion: false,
            changelog: None,
            parent_sha: Some("parent_sha_123".to_string()),
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
        };
        assert_eq!(with_parent.parent_sha.as_deref(), Some("parent_sha_123"));

        // A root version has no parent.
        let root = AgentVersion {
            parent_sha: None,
            ..with_parent
        };
        assert!(root.parent_sha.is_none());
    }
}