1use crate::dsl::evaluator::DslValue;
8use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
9use crate::error::{ReplError, Result};
10use std::collections::HashMap;
11
12pub async fn builtin_spawn_agent(
22 args: &[DslValue],
23 ctx: &ReasoningBuiltinContext,
24) -> Result<DslValue> {
25 let registry = ctx
26 .agent_registry
27 .as_ref()
28 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
29
30 let (name, system_prompt, tools, response_format) = parse_spawn_args(args)?;
31
32 let agent_id = registry
33 .spawn_agent(&name, &system_prompt, tools, response_format)
34 .await;
35
36 let mut result = HashMap::new();
37 result.insert(
38 "agent_id".to_string(),
39 DslValue::String(agent_id.to_string()),
40 );
41 result.insert("name".to_string(), DslValue::String(name));
42 Ok(DslValue::Map(result))
43}
44
45pub async fn builtin_ask(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
53 let registry = ctx
54 .agent_registry
55 .as_ref()
56 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
57
58 let provider = ctx
59 .provider
60 .as_ref()
61 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
62
63 let (agent_name, message) = parse_ask_args(args)?;
64
65 let response = registry
66 .ask_agent(&agent_name, &message, provider.as_ref())
67 .await
68 .map_err(|e| ReplError::Execution(format!("ask({}) failed: {}", agent_name, e)))?;
69
70 Ok(DslValue::String(response))
71}
72
73pub async fn builtin_send_to(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
81 let registry = ctx
82 .agent_registry
83 .as_ref()
84 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
85
86 let provider = ctx
87 .provider
88 .as_ref()
89 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
90
91 let (agent_name, message) = parse_ask_args(args)?;
92
93 if !registry.has_agent(&agent_name).await {
94 return Err(ReplError::Execution(format!(
95 "Agent '{}' not found",
96 agent_name
97 )));
98 }
99
100 let registry = registry.clone();
102 let provider = provider.clone();
103 tokio::spawn(async move {
104 let _ = registry
105 .ask_agent(&agent_name, &message, provider.as_ref())
106 .await;
107 });
108
109 Ok(DslValue::Null)
110}
111
112pub async fn builtin_parallel(
119 args: &[DslValue],
120 ctx: &ReasoningBuiltinContext,
121) -> Result<DslValue> {
122 let registry = ctx
123 .agent_registry
124 .as_ref()
125 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
126
127 let provider = ctx
128 .provider
129 .as_ref()
130 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
131
132 let tasks = parse_parallel_args(args)?;
133
134 let mut handles = Vec::new();
135 for (agent_name, message) in tasks {
136 let registry = registry.clone();
137 let provider = provider.clone();
138 handles.push(tokio::spawn(async move {
139 registry
140 .ask_agent(&agent_name, &message, provider.as_ref())
141 .await
142 .map_err(|e| format!("{}", e))
143 }));
144 }
145
146 let mut results = Vec::new();
147 for handle in handles {
148 match handle.await {
149 Ok(Ok(response)) => results.push(DslValue::String(response)),
150 Ok(Err(e)) => {
151 let mut error_map = HashMap::new();
152 error_map.insert("error".to_string(), DslValue::String(e));
153 results.push(DslValue::Map(error_map));
154 }
155 Err(e) => {
156 let mut error_map = HashMap::new();
157 error_map.insert("error".to_string(), DslValue::String(e.to_string()));
158 results.push(DslValue::Map(error_map));
159 }
160 }
161 }
162
163 Ok(DslValue::List(results))
164}
165
166pub async fn builtin_race(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
173 let registry = ctx
174 .agent_registry
175 .as_ref()
176 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
177
178 let provider = ctx
179 .provider
180 .as_ref()
181 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
182
183 let tasks = parse_parallel_args(args)?;
184
185 if tasks.is_empty() {
186 return Err(ReplError::Execution(
187 "race requires at least one task".into(),
188 ));
189 }
190
191 let mut join_set = tokio::task::JoinSet::new();
192 for (agent_name, message) in tasks {
193 let registry = registry.clone();
194 let provider = provider.clone();
195 join_set.spawn(async move {
196 registry
197 .ask_agent(&agent_name, &message, provider.as_ref())
198 .await
199 .map_err(|e| format!("{}", e))
200 });
201 }
202
203 match join_set.join_next().await {
205 Some(Ok(Ok(response))) => {
206 join_set.abort_all();
207 Ok(DslValue::String(response))
208 }
209 Some(Ok(Err(e))) => {
210 join_set.abort_all();
211 Err(ReplError::Execution(format!(
212 "race: first completed with error: {}",
213 e
214 )))
215 }
216 Some(Err(e)) => {
217 join_set.abort_all();
218 Err(ReplError::Execution(format!("race: task panic: {}", e)))
219 }
220 None => Err(ReplError::Execution("race: no tasks to run".into())),
221 }
222}
223
224fn parse_spawn_args(args: &[DslValue]) -> Result<(String, String, Vec<String>, Option<String>)> {
227 match args {
228 [DslValue::Map(map)] => {
229 let name = extract_string(map, "name")?;
230 let system = extract_string(map, "system")?;
231 let tools = map
232 .get("tools")
233 .and_then(|v| match v {
234 DslValue::List(items) => Some(
235 items
236 .iter()
237 .filter_map(|i| match i {
238 DslValue::String(s) => Some(s.clone()),
239 _ => None,
240 })
241 .collect(),
242 ),
243 _ => None,
244 })
245 .unwrap_or_default();
246 let response_format = map.get("response_format").and_then(|v| match v {
247 DslValue::String(s) => Some(s.clone()),
248 _ => None,
249 });
250 Ok((name, system, tools, response_format))
251 }
252 [DslValue::String(name), DslValue::String(system)] => {
253 Ok((name.clone(), system.clone(), Vec::new(), None))
254 }
255 [DslValue::String(name), DslValue::String(system), DslValue::List(tools)] => {
256 let tool_names = tools
257 .iter()
258 .filter_map(|t| match t {
259 DslValue::String(s) => Some(s.clone()),
260 _ => None,
261 })
262 .collect();
263 Ok((name.clone(), system.clone(), tool_names, None))
264 }
265 _ => Err(ReplError::Execution(
266 "spawn_agent requires (name: string, system: string, [tools?, response_format?])"
267 .into(),
268 )),
269 }
270}
271
272fn parse_ask_args(args: &[DslValue]) -> Result<(String, String)> {
273 match args {
274 [DslValue::String(agent), DslValue::String(message)] => {
275 Ok((agent.clone(), message.clone()))
276 }
277 [DslValue::Map(map)] => {
278 let agent = extract_string(map, "agent")?;
279 let message = extract_string(map, "message")?;
280 Ok((agent, message))
281 }
282 _ => Err(ReplError::Execution(
283 "requires (agent: string, message: string)".into(),
284 )),
285 }
286}
287
288fn parse_parallel_args(args: &[DslValue]) -> Result<Vec<(String, String)>> {
289 match args {
290 [DslValue::List(items)] => {
291 let mut tasks = Vec::new();
292 for item in items {
293 match item {
294 DslValue::Map(map) => {
295 let agent = extract_string(map, "agent")?;
296 let message = extract_string(map, "message")?;
297 tasks.push((agent, message));
298 }
299 _ => {
300 return Err(ReplError::Execution(
301 "parallel/race items must be maps with {agent, message}".into(),
302 ))
303 }
304 }
305 }
306 Ok(tasks)
307 }
308 _ => Err(ReplError::Execution(
309 "parallel/race requires a list of {agent, message} maps".into(),
310 )),
311 }
312}
313
314fn extract_string(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
315 map.get(key)
316 .and_then(|v| match v {
317 DslValue::String(s) => Some(s.clone()),
318 _ => None,
319 })
320 .ok_or_else(|| ReplError::Execution(format!("Missing required string argument '{}'", key)))
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a string slice in a `DslValue::String`.
    fn text(value: &str) -> DslValue {
        DslValue::String(value.to_string())
    }

    /// Builds a `DslValue::Map` from `(key, value)` pairs.
    fn record(entries: Vec<(&str, DslValue)>) -> DslValue {
        DslValue::Map(
            entries
                .into_iter()
                .map(|(key, value)| (key.to_string(), value))
                .collect(),
        )
    }

    // --- parse_spawn_args -------------------------------------------------

    #[test]
    fn test_parse_spawn_args_named() {
        let spec = record(vec![
            ("name", text("researcher")),
            ("system", text("You research.")),
            ("tools", DslValue::List(vec![text("search")])),
        ]);

        let (name, system, tools, format) = parse_spawn_args(&[spec]).unwrap();
        assert_eq!(name, "researcher");
        assert_eq!(system, "You research.");
        assert_eq!(tools, vec!["search"]);
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_positional() {
        let args = [text("coder"), text("You code.")];

        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "coder");
        assert_eq!(system, "You code.");
        assert!(tools.is_empty());
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_with_tools() {
        let args = [
            text("worker"),
            text("You work."),
            DslValue::List(vec![text("read"), text("write")]),
        ];

        let (name, system, tools, _) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "worker");
        assert_eq!(system, "You work.");
        assert_eq!(tools, vec!["read", "write"]);
    }

    #[test]
    fn test_parse_spawn_args_with_response_format() {
        let spec = record(vec![
            ("name", text("parser")),
            ("system", text("Parse data.")),
            ("response_format", text("json")),
        ]);

        let (_, _, _, format) = parse_spawn_args(&[spec]).unwrap();
        assert_eq!(format, Some("json".into()));
    }

    #[test]
    fn test_parse_spawn_args_missing_name() {
        // An empty map has neither `name` nor `system`.
        let spec = record(vec![]);
        assert!(parse_spawn_args(&[spec]).is_err());
    }

    // --- parse_ask_args ---------------------------------------------------

    #[test]
    fn test_parse_ask_args_positional() {
        let args = [text("agent1"), text("hello")];

        let (agent, msg) = parse_ask_args(&args).unwrap();
        assert_eq!(agent, "agent1");
        assert_eq!(msg, "hello");
    }

    #[test]
    fn test_parse_ask_args_named() {
        let request = record(vec![("agent", text("bot")), ("message", text("hi"))]);

        let (agent, msg) = parse_ask_args(&[request]).unwrap();
        assert_eq!(agent, "bot");
        assert_eq!(msg, "hi");
    }

    #[test]
    fn test_parse_ask_args_invalid() {
        let args = [DslValue::Integer(42)];
        assert!(parse_ask_args(&args).is_err());
    }

    // --- parse_parallel_args ----------------------------------------------

    #[test]
    fn test_parse_parallel_args() {
        let first = record(vec![("agent", text("a")), ("message", text("m1"))]);
        let second = record(vec![("agent", text("b")), ("message", text("m2"))]);

        let args = [DslValue::List(vec![first, second])];
        let tasks = parse_parallel_args(&args).unwrap();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0], ("a".into(), "m1".into()));
        assert_eq!(tasks[1], ("b".into(), "m2".into()));
    }

    #[test]
    fn test_parse_parallel_args_empty_list() {
        let args = [DslValue::List(Vec::new())];
        let tasks = parse_parallel_args(&args).unwrap();
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_parse_parallel_args_invalid_item() {
        let args = [DslValue::List(vec![text("not a map")])];
        assert!(parse_parallel_args(&args).is_err());
    }
}