Skip to main content

repl_core/dsl/
agent_composition.rs

1//! Agent composition builtins for the DSL
2//!
3//! Provides async builtins for spawning agents, sending messages,
4//! and executing concurrent patterns: `spawn_agent`, `ask`, `send_to`,
5//! `parallel`, and `race`.
6
7use crate::dsl::evaluator::DslValue;
8use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
9use crate::error::{ReplError, Result};
10use std::collections::HashMap;
11
12/// Execute the `spawn_agent` builtin: register a new named agent.
13///
14/// Arguments (named via map or positional):
15/// - name: string — agent name
16/// - system: string — system prompt
17/// - tools: list of strings (optional)
18/// - response_format: string (optional)
19///
20/// Returns a map with `agent_id` and `name`.
21pub async fn builtin_spawn_agent(
22    args: &[DslValue],
23    ctx: &ReasoningBuiltinContext,
24) -> Result<DslValue> {
25    let registry = ctx
26        .agent_registry
27        .as_ref()
28        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
29
30    let (name, system_prompt, tools, response_format) = parse_spawn_args(args)?;
31
32    let agent_id = registry
33        .spawn_agent(&name, &system_prompt, tools, response_format)
34        .await;
35
36    let mut result = HashMap::new();
37    result.insert(
38        "agent_id".to_string(),
39        DslValue::String(agent_id.to_string()),
40    );
41    result.insert("name".to_string(), DslValue::String(name));
42    Ok(DslValue::Map(result))
43}
44
45/// Execute the `ask` builtin: send a message to a named agent and wait for response.
46///
47/// Arguments:
48/// - agent: string — agent name
49/// - message: string
50///
51/// Returns the agent's response as a string.
52pub async fn builtin_ask(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
53    let registry = ctx
54        .agent_registry
55        .as_ref()
56        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
57
58    let provider = ctx
59        .provider
60        .as_ref()
61        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
62
63    let (agent_name, message) = parse_ask_args(args)?;
64
65    let response = registry
66        .ask_agent(&agent_name, &message, provider.as_ref())
67        .await
68        .map_err(|e| ReplError::Execution(format!("ask({}) failed: {}", agent_name, e)))?;
69
70    Ok(DslValue::String(response))
71}
72
73/// Execute the `send_to` builtin: fire-and-forget message to a named agent.
74///
75/// Arguments:
76/// - agent: string — agent name
77/// - message: string
78///
79/// Returns null (fire-and-forget).
80pub async fn builtin_send_to(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
81    let registry = ctx
82        .agent_registry
83        .as_ref()
84        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
85
86    let provider = ctx
87        .provider
88        .as_ref()
89        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
90
91    let (agent_name, message) = parse_ask_args(args)?;
92
93    if !registry.has_agent(&agent_name).await {
94        return Err(ReplError::Execution(format!(
95            "Agent '{}' not found",
96            agent_name
97        )));
98    }
99
100    // Fire-and-forget: spawn a background task
101    let registry = registry.clone();
102    let provider = provider.clone();
103    tokio::spawn(async move {
104        let _ = registry
105            .ask_agent(&agent_name, &message, provider.as_ref())
106            .await;
107    });
108
109    Ok(DslValue::Null)
110}
111
112/// Execute the `parallel` builtin: run multiple agent calls concurrently.
113///
114/// Arguments:
115/// - tasks: list of maps, each with `{agent: string, message: string}`
116///
117/// Returns a list of results (strings or error maps).
118pub async fn builtin_parallel(
119    args: &[DslValue],
120    ctx: &ReasoningBuiltinContext,
121) -> Result<DslValue> {
122    let registry = ctx
123        .agent_registry
124        .as_ref()
125        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
126
127    let provider = ctx
128        .provider
129        .as_ref()
130        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
131
132    let tasks = parse_parallel_args(args)?;
133
134    let mut handles = Vec::new();
135    for (agent_name, message) in tasks {
136        let registry = registry.clone();
137        let provider = provider.clone();
138        handles.push(tokio::spawn(async move {
139            registry
140                .ask_agent(&agent_name, &message, provider.as_ref())
141                .await
142                .map_err(|e| format!("{}", e))
143        }));
144    }
145
146    let mut results = Vec::new();
147    for handle in handles {
148        match handle.await {
149            Ok(Ok(response)) => results.push(DslValue::String(response)),
150            Ok(Err(e)) => {
151                let mut error_map = HashMap::new();
152                error_map.insert("error".to_string(), DslValue::String(e));
153                results.push(DslValue::Map(error_map));
154            }
155            Err(e) => {
156                let mut error_map = HashMap::new();
157                error_map.insert("error".to_string(), DslValue::String(e.to_string()));
158                results.push(DslValue::Map(error_map));
159            }
160        }
161    }
162
163    Ok(DslValue::List(results))
164}
165
166/// Execute the `race` builtin: run multiple agent calls, return first to complete.
167///
168/// Arguments:
169/// - tasks: list of maps, each with `{agent: string, message: string}`
170///
171/// Returns the first successful result as a string.
172pub async fn builtin_race(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
173    let registry = ctx
174        .agent_registry
175        .as_ref()
176        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
177
178    let provider = ctx
179        .provider
180        .as_ref()
181        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
182
183    let tasks = parse_parallel_args(args)?;
184
185    if tasks.is_empty() {
186        return Err(ReplError::Execution(
187            "race requires at least one task".into(),
188        ));
189    }
190
191    let mut join_set = tokio::task::JoinSet::new();
192    for (agent_name, message) in tasks {
193        let registry = registry.clone();
194        let provider = provider.clone();
195        join_set.spawn(async move {
196            registry
197                .ask_agent(&agent_name, &message, provider.as_ref())
198                .await
199                .map_err(|e| format!("{}", e))
200        });
201    }
202
203    // Return the first completed result
204    match join_set.join_next().await {
205        Some(Ok(Ok(response))) => {
206            join_set.abort_all();
207            Ok(DslValue::String(response))
208        }
209        Some(Ok(Err(e))) => {
210            join_set.abort_all();
211            Err(ReplError::Execution(format!(
212                "race: first completed with error: {}",
213                e
214            )))
215        }
216        Some(Err(e)) => {
217            join_set.abort_all();
218            Err(ReplError::Execution(format!("race: task panic: {}", e)))
219        }
220        None => Err(ReplError::Execution("race: no tasks to run".into())),
221    }
222}
223
224// --- Argument parsing helpers ---
225
226fn parse_spawn_args(args: &[DslValue]) -> Result<(String, String, Vec<String>, Option<String>)> {
227    match args {
228        [DslValue::Map(map)] => {
229            let name = extract_string(map, "name")?;
230            let system = extract_string(map, "system")?;
231            let tools = map
232                .get("tools")
233                .and_then(|v| match v {
234                    DslValue::List(items) => Some(
235                        items
236                            .iter()
237                            .filter_map(|i| match i {
238                                DslValue::String(s) => Some(s.clone()),
239                                _ => None,
240                            })
241                            .collect(),
242                    ),
243                    _ => None,
244                })
245                .unwrap_or_default();
246            let response_format = map.get("response_format").and_then(|v| match v {
247                DslValue::String(s) => Some(s.clone()),
248                _ => None,
249            });
250            Ok((name, system, tools, response_format))
251        }
252        [DslValue::String(name), DslValue::String(system)] => {
253            Ok((name.clone(), system.clone(), Vec::new(), None))
254        }
255        [DslValue::String(name), DslValue::String(system), DslValue::List(tools)] => {
256            let tool_names = tools
257                .iter()
258                .filter_map(|t| match t {
259                    DslValue::String(s) => Some(s.clone()),
260                    _ => None,
261                })
262                .collect();
263            Ok((name.clone(), system.clone(), tool_names, None))
264        }
265        _ => Err(ReplError::Execution(
266            "spawn_agent requires (name: string, system: string, [tools?, response_format?])"
267                .into(),
268        )),
269    }
270}
271
272fn parse_ask_args(args: &[DslValue]) -> Result<(String, String)> {
273    match args {
274        [DslValue::String(agent), DslValue::String(message)] => {
275            Ok((agent.clone(), message.clone()))
276        }
277        [DslValue::Map(map)] => {
278            let agent = extract_string(map, "agent")?;
279            let message = extract_string(map, "message")?;
280            Ok((agent, message))
281        }
282        _ => Err(ReplError::Execution(
283            "requires (agent: string, message: string)".into(),
284        )),
285    }
286}
287
288fn parse_parallel_args(args: &[DslValue]) -> Result<Vec<(String, String)>> {
289    match args {
290        [DslValue::List(items)] => {
291            let mut tasks = Vec::new();
292            for item in items {
293                match item {
294                    DslValue::Map(map) => {
295                        let agent = extract_string(map, "agent")?;
296                        let message = extract_string(map, "message")?;
297                        tasks.push((agent, message));
298                    }
299                    _ => {
300                        return Err(ReplError::Execution(
301                            "parallel/race items must be maps with {agent, message}".into(),
302                        ))
303                    }
304                }
305            }
306            Ok(tasks)
307        }
308        _ => Err(ReplError::Execution(
309            "parallel/race requires a list of {agent, message} maps".into(),
310        )),
311    }
312}
313
314fn extract_string(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
315    map.get(key)
316        .and_then(|v| match v {
317            DslValue::String(s) => Some(s.clone()),
318            _ => None,
319        })
320        .ok_or_else(|| ReplError::Execution(format!("Missing required string argument '{}'", key)))
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a DSL string value.
    fn text(value: &str) -> DslValue {
        DslValue::String(value.to_string())
    }

    /// Build a DSL map from `(key, value)` pairs.
    fn map_of(entries: Vec<(&str, DslValue)>) -> HashMap<String, DslValue> {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    #[test]
    fn test_parse_spawn_args_named() {
        let args = [DslValue::Map(map_of(vec![
            ("name", text("researcher")),
            ("system", text("You research.")),
            ("tools", DslValue::List(vec![text("search")])),
        ]))];

        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "researcher");
        assert_eq!(system, "You research.");
        assert_eq!(tools, vec!["search"]);
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_positional() {
        let args = [text("coder"), text("You code.")];

        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "coder");
        assert_eq!(system, "You code.");
        assert!(tools.is_empty());
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_with_tools() {
        let args = [
            text("worker"),
            text("You work."),
            DslValue::List(vec![text("read"), text("write")]),
        ];

        let (name, system, tools, _) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "worker");
        assert_eq!(system, "You work.");
        assert_eq!(tools, vec!["read", "write"]);
    }

    #[test]
    fn test_parse_spawn_args_with_response_format() {
        let args = [DslValue::Map(map_of(vec![
            ("name", text("parser")),
            ("system", text("Parse data.")),
            ("response_format", text("json")),
        ]))];

        let (_, _, _, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(format.as_deref(), Some("json"));
    }

    #[test]
    fn test_parse_ask_args_positional() {
        let (agent, msg) = parse_ask_args(&[text("agent1"), text("hello")]).unwrap();
        assert_eq!(agent, "agent1");
        assert_eq!(msg, "hello");
    }

    #[test]
    fn test_parse_ask_args_named() {
        let args = [DslValue::Map(map_of(vec![
            ("agent", text("bot")),
            ("message", text("hi")),
        ]))];

        let (agent, msg) = parse_ask_args(&args).unwrap();
        assert_eq!(agent, "bot");
        assert_eq!(msg, "hi");
    }

    #[test]
    fn test_parse_parallel_args() {
        let task = |agent: &str, message: &str| {
            DslValue::Map(map_of(vec![
                ("agent", text(agent)),
                ("message", text(message)),
            ]))
        };
        let args = [DslValue::List(vec![task("a", "m1"), task("b", "m2")])];

        let tasks = parse_parallel_args(&args).unwrap();
        assert_eq!(
            tasks,
            vec![
                ("a".to_string(), "m1".to_string()),
                ("b".to_string(), "m2".to_string()),
            ]
        );
    }

    #[test]
    fn test_parse_spawn_args_missing_name() {
        let args = [DslValue::Map(HashMap::new())];
        assert!(parse_spawn_args(&args).is_err());
    }

    #[test]
    fn test_parse_ask_args_invalid() {
        assert!(parse_ask_args(&[DslValue::Integer(42)]).is_err());
    }

    #[test]
    fn test_parse_parallel_args_empty_list() {
        let tasks = parse_parallel_args(&[DslValue::List(Vec::new())]).unwrap();
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_parse_parallel_args_invalid_item() {
        let args = [DslValue::List(vec![text("not a map")])];
        assert!(parse_parallel_args(&args).is_err());
    }
}