1#![allow(dead_code)]
3
4use std::collections::HashMap;
5use std::path::Path;
6use std::sync::Arc;
7
8use super::agent_memory::AgentMemoryScope;
9
10#[derive(Debug, Clone)]
12pub enum AgentMcpServerSpec {
13 Reference(String),
15 Inline {
17 name: String,
18 config: serde_json::Value,
19 },
20}
21
22#[derive(Clone)]
24pub struct AgentDefinition {
25 pub agent_type: String,
26 pub when_to_use: String,
27 pub tools: Vec<String>,
28 pub disallowed_tools: Vec<String>,
29 pub source: String,
30 pub base_dir: String,
31 pub get_system_prompt: Arc<dyn Fn() -> String + Send + Sync>,
32 pub model: Option<String>,
33 pub max_turns: Option<usize>,
34 pub permission_mode: Option<String>,
35 pub effort: Option<String>,
36 pub color: Option<String>,
37 pub mcp_servers: Vec<AgentMcpServerSpec>,
38 pub hooks: Option<serde_json::Value>,
39 pub skills: Vec<String>,
40 pub background: bool,
41 pub initial_prompt: Option<String>,
42 pub memory: Option<AgentMemoryScope>,
43 pub isolation: Option<String>,
44 pub required_mcp_servers: Vec<String>,
45 pub omit_claude_md: bool,
46 pub critical_system_reminder_experimental: Option<String>,
47}
48
49impl std::fmt::Debug for AgentDefinition {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("AgentDefinition")
52 .field("agent_type", &self.agent_type)
53 .field("when_to_use", &self.when_to_use)
54 .field("tools", &self.tools)
55 .field("disallowed_tools", &self.disallowed_tools)
56 .field("source", &self.source)
57 .field("base_dir", &self.base_dir)
58 .field("model", &self.model)
59 .field("max_turns", &self.max_turns)
60 .field("permission_mode", &self.permission_mode)
61 .field("effort", &self.effort)
62 .field("color", &self.color)
63 .field("mcp_servers", &self.mcp_servers)
64 .field("skills", &self.skills)
65 .field("background", &self.background)
66 .field("initial_prompt", &self.initial_prompt)
67 .field("memory", &self.memory)
68 .field("isolation", &self.isolation)
69 .field("required_mcp_servers", &self.required_mcp_servers)
70 .field("omit_claude_md", &self.omit_claude_md)
71 .field(
72 "critical_system_reminder_experimental",
73 &self.critical_system_reminder_experimental,
74 )
75 .finish_non_exhaustive()
76 }
77}
78
79impl AgentDefinition {
80 pub fn system_prompt(&self) -> String {
81 (self.get_system_prompt)()
82 }
83
84 pub fn is_built_in(&self) -> bool {
85 self.source == "built-in"
86 }
87}
88
89pub struct AgentDefinitionsResult {
91 pub active_agents: Vec<AgentDefinition>,
92 pub all_agents: Vec<AgentDefinition>,
93 pub failed_files: Vec<(String, String)>,
94 pub allowed_agent_types: Option<Vec<String>>,
95}
96
97pub fn get_active_agents_from_list(all_agents: &[AgentDefinition]) -> Vec<AgentDefinition> {
100 let priority = [
102 "built-in",
103 "plugin",
104 "userSettings",
105 "projectSettings",
106 "flagSettings",
107 "policySettings",
108 ];
109
110 let mut agent_map: HashMap<String, (usize, AgentDefinition)> = HashMap::new();
111
112 for agent in all_agents {
113 let priority_idx = priority
114 .iter()
115 .position(|&p| p == agent.source)
116 .unwrap_or(0);
117 let entry = agent_map.entry(agent.agent_type.clone());
118 entry
119 .and_modify(|(existing_priority, existing_agent)| {
120 if priority_idx > *existing_priority {
121 *existing_priority = priority_idx;
122 *existing_agent = agent.clone();
123 }
124 })
125 .or_insert((priority_idx, agent.clone()));
126 }
127
128 agent_map.into_values().map(|(_, agent)| agent).collect()
129}
130
131pub fn has_required_mcp_servers(agent: &AgentDefinition, available_servers: &[&str]) -> bool {
133 if agent.required_mcp_servers.is_empty() {
134 return true;
135 }
136 agent.required_mcp_servers.iter().all(|pattern| {
137 available_servers
138 .iter()
139 .any(|server| server.to_lowercase().contains(&pattern.to_lowercase()))
140 })
141}
142
143pub fn filter_agents_by_mcp_requirements<'a>(
145 agents: impl IntoIterator<Item = &'a AgentDefinition>,
146 available_servers: &[&str],
147) -> Vec<&'a AgentDefinition> {
148 agents
149 .into_iter()
150 .filter(|agent| has_required_mcp_servers(agent, available_servers))
151 .collect()
152}
153
154pub fn parse_agent_from_json(
156 name: &str,
157 definition: &serde_json::Value,
158 source: &str,
159) -> Option<AgentDefinition> {
160 let when_to_use = definition.get("description")?.as_str()?.to_string();
161 let prompt = definition.get("prompt")?.as_str()?.to_string();
162
163 let tools = definition
164 .get("tools")
165 .and_then(|t| t.as_array())
166 .map(|arr| {
167 arr.iter()
168 .filter_map(|v| v.as_str())
169 .map(|s| s.to_string())
170 .collect()
171 });
172
173 let disallowed_tools = definition
174 .get("disallowedTools")
175 .and_then(|t| t.as_array())
176 .map(|arr| {
177 arr.iter()
178 .filter_map(|v| v.as_str())
179 .map(|s| s.to_string())
180 .collect()
181 })
182 .unwrap_or_default();
183
184 let model = definition.get("model").and_then(|m| m.as_str()).map(|m| {
185 let trimmed = m.trim();
186 if trimmed.to_lowercase() == "inherit" {
187 "inherit".to_string()
188 } else {
189 trimmed.to_string()
190 }
191 });
192
193 let max_turns = definition
194 .get("maxTurns")
195 .and_then(|v| v.as_u64())
196 .map(|v| v as usize);
197
198 let permission_mode = definition
199 .get("permissionMode")
200 .and_then(|v| v.as_str())
201 .map(|s| s.to_string());
202
203 let effort = definition.get("effort").map(|v| v.to_string());
204
205 let background = definition
206 .get("background")
207 .and_then(|v| v.as_bool())
208 .unwrap_or(false);
209
210 let memory = definition
211 .get("memory")
212 .and_then(|v| v.as_str())
213 .and_then(AgentMemoryScope::from_str);
214
215 let isolation = definition
216 .get("isolation")
217 .and_then(|v| v.as_str())
218 .map(|s| s.to_string());
219
220 let initial_prompt = definition
221 .get("initialPrompt")
222 .and_then(|v| v.as_str())
223 .map(|s| s.to_string());
224
225 let skills = definition
226 .get("skills")
227 .and_then(|v| v.as_array())
228 .map(|arr| {
229 arr.iter()
230 .filter_map(|v| v.as_str())
231 .map(|s| s.to_string())
232 .collect()
233 })
234 .unwrap_or_default();
235
236 let memory_prompt = if memory.is_some() {
238 Some(super::agent_memory::load_agent_memory_prompt(
239 name,
240 memory.unwrap(),
241 ))
242 } else {
243 None
244 };
245
246 let system_prompt = prompt.clone();
247 let get_system_prompt: Arc<dyn Fn() -> String + Send + Sync> = Arc::new(move || {
248 if let Some(ref mp) = memory_prompt {
249 format!("{}\n\n{}", system_prompt, mp)
250 } else {
251 system_prompt.clone()
252 }
253 });
254
255 let mcp_servers = definition
257 .get("mcpServers")
258 .and_then(|v| v.as_array())
259 .map(|arr| {
260 arr.iter()
261 .filter_map(|item| {
262 if let Some(s) = item.as_str() {
263 Some(AgentMcpServerSpec::Reference(s.to_string()))
264 } else if let Some(obj) = item.as_object() {
265 if let Some(name) = obj.keys().next() {
266 Some(AgentMcpServerSpec::Inline {
267 name: name.clone(),
268 config: obj[name].clone(),
269 })
270 } else {
271 None
272 }
273 } else {
274 None
275 }
276 })
277 .collect()
278 })
279 .unwrap_or_default();
280
281 Some(AgentDefinition {
282 agent_type: name.to_string(),
283 when_to_use,
284 tools: tools.unwrap_or_default(),
285 disallowed_tools,
286 source: source.to_string(),
287 base_dir: source.to_string(),
288 get_system_prompt,
289 model,
290 max_turns,
291 permission_mode,
292 effort,
293 color: None,
294 mcp_servers,
295 hooks: definition.get("hooks").cloned(),
296 skills,
297 background,
298 initial_prompt,
299 memory,
300 isolation,
301 required_mcp_servers: vec![],
302 omit_claude_md: false,
303 critical_system_reminder_experimental: None,
304 })
305}
306
307pub fn parse_agents_from_json(
309 agents_json: &serde_json::Value,
310 source: &str,
311) -> Vec<AgentDefinition> {
312 if let Some(obj) = agents_json.as_object() {
313 obj.iter()
314 .filter_map(|(name, def)| parse_agent_from_json(name, def, source))
315 .collect()
316 } else {
317 vec![]
318 }
319}
320
321pub fn parse_agent_tools_from_frontmatter(value: &serde_json::Value) -> Option<Vec<String>> {
323 if let Some(arr) = value.as_array() {
324 Some(
325 arr.iter()
326 .filter_map(|v| v.as_str())
327 .map(|s| s.to_string())
328 .collect(),
329 )
330 } else if let Some(s) = value.as_str() {
331 if s.is_empty() {
332 return None;
333 }
334 Some(
335 s.split(',')
336 .map(|s| s.trim().to_string())
337 .filter(|s| !s.is_empty())
338 .collect(),
339 )
340 } else {
341 None
342 }
343}
344
345pub fn parse_slash_command_tools_from_frontmatter(value: &serde_json::Value) -> Vec<String> {
347 parse_agent_tools_from_frontmatter(value).unwrap_or_default()
348}
349
350pub fn load_agents_dir(cwd: &Path) -> AgentDefinitionsResult {
353 let agents_dir = cwd.join(".claude").join("agents");
354
355 if !agents_dir.exists() {
356 let built_ins = super::built_in_agents::get_built_in_agents();
357 return AgentDefinitionsResult {
358 active_agents: get_active_agents_from_list(&built_ins),
359 all_agents: built_ins,
360 failed_files: vec![],
361 allowed_agent_types: None,
362 };
363 }
364
365 let mut all_agents = super::built_in_agents::get_built_in_agents();
366 let mut failed_files: Vec<(String, String)> = Vec::new();
367
368 if let Ok(entries) = std::fs::read_dir(&agents_dir) {
370 for entry in entries.flatten() {
371 let path = entry.path();
372 if path.extension().and_then(|e| e.to_str()) != Some("md") {
373 continue;
374 }
375
376 match parse_agent_from_markdown(&path) {
377 Some(agent) => all_agents.push(agent),
378 None => {
379 failed_files.push((
380 path.display().to_string(),
381 "Failed to parse agent definition".to_string(),
382 ));
383 }
384 }
385 }
386 }
387
388 let active_agents = get_active_agents_from_list(&all_agents);
389
390 AgentDefinitionsResult {
391 active_agents,
392 all_agents,
393 failed_files,
394 allowed_agent_types: None,
395 }
396}
397
398fn parse_agent_from_markdown(path: &Path) -> Option<AgentDefinition> {
401 let content = std::fs::read_to_string(path).ok()?;
402
403 let (frontmatter, body) = parse_markdown_frontmatter(&content)?;
405
406 let agent_type = frontmatter.get("name")?.as_str()?.to_string();
407 let when_to_use = frontmatter
408 .get("description")?
409 .as_str()?
410 .replace("\\n", "\n");
411
412 let model = frontmatter.get("model").and_then(|v| {
414 v.as_str().map(|m| {
415 let trimmed = m.trim();
416 if trimmed.to_lowercase() == "inherit" {
417 "inherit".to_string()
418 } else {
419 trimmed.to_string()
420 }
421 })
422 });
423
424 let background = frontmatter
425 .get("background")
426 .and_then(|v| v.as_bool())
427 .unwrap_or(false);
428
429 let memory = frontmatter
430 .get("memory")
431 .and_then(|v| v.as_str())
432 .and_then(AgentMemoryScope::from_str);
433
434 let isolation = frontmatter
435 .get("isolation")
436 .and_then(|v| v.as_str())
437 .map(|s| s.to_string());
438
439 let max_turns = frontmatter
440 .get("maxTurns")
441 .and_then(|v| v.as_u64())
442 .map(|v| v as usize);
443
444 let permission_mode = frontmatter
445 .get("permissionMode")
446 .and_then(|v| v.as_str())
447 .map(|s| s.to_string());
448
449 let effort = frontmatter.get("effort").map(|v| v.to_string());
450
451 let initial_prompt = frontmatter
452 .get("initialPrompt")
453 .and_then(|v| v.as_str())
454 .map(|s| s.to_string());
455
456 let color = frontmatter
457 .get("color")
458 .and_then(|v| v.as_str())
459 .map(|s| s.to_string());
460
461 let tools = frontmatter
462 .get("tools")
463 .and_then(parse_agent_tools_from_frontmatter)
464 .unwrap_or_default();
465
466 let disallowed_tools = frontmatter
467 .get("disallowedTools")
468 .and_then(parse_agent_tools_from_frontmatter)
469 .unwrap_or_default();
470
471 let skills = parse_slash_command_tools_from_frontmatter(
472 frontmatter
473 .get("skills")
474 .unwrap_or(&serde_json::Value::Null),
475 );
476
477 let system_prompt = body.trim().to_string();
478
479 let memory_prompt_val =
481 memory.map(|m| super::agent_memory::load_agent_memory_prompt(&agent_type, m));
482
483 let get_system_prompt: Arc<dyn Fn() -> String + Send + Sync> = {
484 let prompt = system_prompt.clone();
485 let memory_prompt = memory_prompt_val.clone();
486 Arc::new(move || {
487 if let Some(ref mp) = memory_prompt {
488 format!("{}\n\n{}", prompt, mp)
489 } else {
490 prompt.clone()
491 }
492 })
493 };
494
495 let mcp_servers = frontmatter
497 .get("mcpServers")
498 .and_then(|v| v.as_array())
499 .map(|arr| {
500 arr.iter()
501 .filter_map(|item| {
502 if let Some(s) = item.as_str() {
503 Some(AgentMcpServerSpec::Reference(s.to_string()))
504 } else {
505 None
506 }
507 })
508 .collect()
509 })
510 .unwrap_or_default();
511
512 let filename = path
513 .file_stem()
514 .and_then(|s| s.to_str())
515 .unwrap_or("")
516 .to_string();
517
518 Some(AgentDefinition {
519 agent_type,
520 when_to_use,
521 tools,
522 disallowed_tools,
523 source: "userSettings".to_string(),
524 base_dir: "agents".to_string(),
525 get_system_prompt,
526 model,
527 max_turns,
528 permission_mode,
529 effort,
530 color,
531 mcp_servers,
532 hooks: frontmatter.get("hooks").cloned(),
533 skills,
534 background,
535 initial_prompt,
536 memory,
537 isolation,
538 required_mcp_servers: vec![],
539 omit_claude_md: false,
540 critical_system_reminder_experimental: None,
541 })
542}
543
544fn parse_markdown_frontmatter(content: &str) -> Option<(serde_json::Value, String)> {
547 let content = content.trim();
548 if !content.starts_with("---") {
549 return None;
550 }
551
552 let rest = &content[3..];
553 let end = rest.find("---")?;
554 let yaml_str = &rest[..end].trim();
555
556 let mut map: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
558
559 for line in yaml_str.lines() {
560 let line = line.trim();
561 if line.is_empty() || line.starts_with('#') {
562 continue;
563 }
564
565 if let Some(pos) = line.find(':') {
566 let key = line[..pos].trim().to_string();
567 let value = line[pos + 1..].trim();
568
569 if value.is_empty() {
570 continue;
571 }
572
573 let json_value = if value.starts_with('[') {
574 serde_json::from_str(value)
576 .ok()
577 .unwrap_or(serde_json::Value::String(value.to_string()))
578 } else if value.starts_with('{') {
579 serde_json::from_str(value)
581 .ok()
582 .unwrap_or(serde_json::Value::String(value.to_string()))
583 } else if let Ok(b) = value.parse::<bool>() {
584 serde_json::Value::Bool(b)
585 } else if let Ok(n) = value.parse::<u64>() {
586 serde_json::Value::Number(serde_json::Number::from(n))
587 } else {
588 let trimmed = value.trim_matches(|c: char| c == '"' || c == '\'');
590 serde_json::Value::String(trimmed.to_string())
591 };
592
593 map.insert(key, json_value);
594 }
595 }
596
597 let body = content[3 + end + 3..].trim().to_string();
598 Some((serde_json::Value::Object(map), body))
599}
600
601pub async fn initialize_agent_memory_snapshots(agents: &mut [AgentDefinition]) {
603 for agent in agents.iter_mut() {
604 if let Some(scope) = agent.memory {
605 match super::agent_memory_snapshot::check_agent_memory_snapshot(
606 &agent.agent_type,
607 scope,
608 )
609 .await
610 {
611 super::agent_memory_snapshot::SnapshotAction::Initialize {
612 ref snapshot_timestamp,
613 } => {
614 log::debug!(
615 "Initializing {} memory from project snapshot",
616 agent.agent_type
617 );
618 let _ = super::agent_memory_snapshot::initialize_from_snapshot(
619 &agent.agent_type,
620 scope,
621 snapshot_timestamp,
622 )
623 .await;
624 }
625 super::agent_memory_snapshot::SnapshotAction::PromptUpdate {
626 ref snapshot_timestamp,
627 } => {
628 log::debug!("Newer snapshot available for {} memory", agent.agent_type);
629 let _ = snapshot_timestamp.clone();
631 }
632 _ => {}
633 }
634 }
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal definition with no memory, so no memory prompt is ever loaded.
    fn make_agent(agent_type: &str, source: &str) -> AgentDefinition {
        AgentDefinition {
            agent_type: agent_type.to_string(),
            when_to_use: "test".to_string(),
            tools: vec!["*".to_string()],
            disallowed_tools: vec![],
            source: source.to_string(),
            base_dir: source.to_string(),
            get_system_prompt: Arc::new(String::new),
            model: None,
            max_turns: None,
            permission_mode: None,
            effort: None,
            color: None,
            mcp_servers: vec![],
            hooks: None,
            skills: vec![],
            background: false,
            initial_prompt: None,
            memory: None,
            isolation: None,
            required_mcp_servers: vec![],
            omit_claude_md: false,
            critical_system_reminder_experimental: None,
        }
    }

    #[test]
    fn test_get_active_agents_priority() {
        let agents = vec![
            make_agent("test", "built-in"),
            make_agent("test", "userSettings"),
        ];
        let active = get_active_agents_from_list(&agents);
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].source, "userSettings");
    }

    #[test]
    fn test_get_active_agents_project_overrides_user() {
        let agents = vec![
            make_agent("reviewer", "projectSettings"),
            make_agent("reviewer", "userSettings"),
        ];
        let active = get_active_agents_from_list(&agents);
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].source, "projectSettings");
    }

    #[test]
    fn test_get_active_agents_keeps_distinct_types() {
        let agents = vec![make_agent("a", "built-in"), make_agent("b", "plugin")];
        let mut types: Vec<String> = get_active_agents_from_list(&agents)
            .into_iter()
            .map(|agent| agent.agent_type)
            .collect();
        types.sort();
        assert_eq!(types, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_markdown_frontmatter() {
        let content = r#"---
name: test-agent
description: A test agent
tools: [Bash, Read]
---

System prompt content"#;
        let (fm, body) = parse_markdown_frontmatter(content).unwrap();
        assert_eq!(fm["name"].as_str().unwrap(), "test-agent");
        assert_eq!(fm["description"].as_str().unwrap(), "A test agent");
        assert_eq!(body, "System prompt content");
    }

    #[test]
    fn test_parse_markdown_frontmatter_scalars() {
        let content = "---\nname: typed\nbackground: true\nmaxTurns: 7\ncolor: 'blue'\n---\nBody";
        let (fm, body) = parse_markdown_frontmatter(content).unwrap();
        assert_eq!(fm["background"].as_bool(), Some(true));
        assert_eq!(fm["maxTurns"].as_u64(), Some(7));
        assert_eq!(fm["color"].as_str(), Some("blue"));
        assert_eq!(body, "Body");
    }

    #[test]
    fn test_parse_markdown_frontmatter_requires_opening_fence() {
        assert!(parse_markdown_frontmatter("name: x\nBody").is_none());
    }

    #[test]
    fn test_parse_agent_tools_from_frontmatter() {
        let listed = serde_json::Value::String("Bash, Read ,, Grep".to_string());
        assert_eq!(
            parse_agent_tools_from_frontmatter(&listed),
            Some(vec![
                "Bash".to_string(),
                "Read".to_string(),
                "Grep".to_string()
            ])
        );
        assert_eq!(
            parse_agent_tools_from_frontmatter(&serde_json::json!("")),
            None
        );
        assert_eq!(
            parse_agent_tools_from_frontmatter(&serde_json::json!(["Bash", 1, "Read"])),
            Some(vec!["Bash".to_string(), "Read".to_string()])
        );
        assert_eq!(
            parse_agent_tools_from_frontmatter(&serde_json::json!(true)),
            None
        );
        assert!(parse_slash_command_tools_from_frontmatter(&serde_json::Value::Null).is_empty());
    }

    #[test]
    fn test_has_required_mcp_servers() {
        let agent = make_agent("test", "built-in");
        assert!(has_required_mcp_servers(&agent, &[]));

        let agent_with_req = AgentDefinition {
            required_mcp_servers: vec!["slack".to_string()],
            ..agent
        };
        assert!(has_required_mcp_servers(&agent_with_req, &["slack-api"]));
        assert!(!has_required_mcp_servers(&agent_with_req, &["other"]));
        // Matching is a case-insensitive substring test.
        assert!(has_required_mcp_servers(&agent_with_req, &["My-SLACK-Server"]));
    }

    #[test]
    fn test_filter_agents_by_mcp_requirements() {
        let plain = make_agent("plain", "built-in");
        let needs_github = AgentDefinition {
            required_mcp_servers: vec!["github".to_string()],
            ..make_agent("gh", "built-in")
        };
        let agents = vec![plain, needs_github];
        let eligible = filter_agents_by_mcp_requirements(&agents, &["slack"]);
        assert_eq!(eligible.len(), 1);
        assert_eq!(eligible[0].agent_type, "plain");
    }

    #[test]
    fn test_parse_agent_from_json() {
        let definition = serde_json::json!({
            "description": "Reviews code",
            "prompt": "You review code.",
            "tools": ["Read", "Grep"],
            "model": "  Inherit ",
            "background": true
        });
        let agent = parse_agent_from_json("reviewer", &definition, "flagSettings").unwrap();
        assert_eq!(agent.agent_type, "reviewer");
        assert_eq!(agent.when_to_use, "Reviews code");
        assert_eq!(agent.tools, vec!["Read", "Grep"]);
        assert_eq!(agent.model.as_deref(), Some("inherit"));
        assert!(agent.background);
        assert_eq!(agent.source, "flagSettings");
        assert_eq!(agent.system_prompt(), "You review code.");

        let missing_prompt = serde_json::json!({ "description": "x" });
        assert!(parse_agent_from_json("x", &missing_prompt, "flagSettings").is_none());
    }
}