Skip to main content

repl_core/dsl/
agent_composition.rs

//! Agent composition builtins for the DSL.
//!
//! Provides async builtins for spawning agents, sending messages,
//! and executing concurrent patterns: `spawn_agent`, `ask`, `send_to`,
//! `parallel`, and `race`.
6
7use crate::dsl::evaluator::DslValue;
8use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
9use crate::error::{ReplError, Result};
10use std::collections::HashMap;
11use std::time::Duration;
12use symbi_runtime::communication::policy_gate::CommunicationRequest;
13use symbi_runtime::types::{AgentId, MessageType, RequestId};
14
15/// Execute the `spawn_agent` builtin: register a new named agent.
16///
17/// Arguments (named via map or positional):
18/// - name: string — agent name
19/// - system: string — system prompt
20/// - tools: list of strings (optional)
21/// - response_format: string (optional)
22///
23/// Returns a map with `agent_id` and `name`.
24pub async fn builtin_spawn_agent(
25    args: &[DslValue],
26    ctx: &ReasoningBuiltinContext,
27) -> Result<DslValue> {
28    let registry = ctx
29        .agent_registry
30        .as_ref()
31        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
32
33    let (name, system_prompt, tools, response_format) = parse_spawn_args(args)?;
34
35    let agent_id = registry
36        .spawn_agent(&name, &system_prompt, tools, response_format)
37        .await;
38
39    let mut result = HashMap::new();
40    result.insert(
41        "agent_id".to_string(),
42        DslValue::String(agent_id.to_string()),
43    );
44    result.insert("name".to_string(), DslValue::String(name));
45    Ok(DslValue::Map(result))
46}
47
48/// Execute the `ask` builtin: send a message to a named agent and wait for response.
49///
50/// Arguments:
51/// - agent: string — agent name
52/// - message: string
53///
54/// Returns the agent's response as a string.
55pub async fn builtin_ask(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
56    let registry = ctx
57        .agent_registry
58        .as_ref()
59        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
60
61    let provider = ctx
62        .provider
63        .as_ref()
64        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
65
66    let (agent_name, message) = parse_ask_args(args)?;
67
68    // Communication bus wiring: policy check + message logging
69    let recipient_id = resolve_agent_id(&agent_name, ctx).await?;
70    let sender_id = ctx.sender_agent_id.unwrap_or_default();
71    let request_id = RequestId::new();
72
73    check_comm_policy(
74        ctx,
75        sender_id,
76        recipient_id,
77        MessageType::Request(request_id),
78    )?;
79    log_comm_message(
80        ctx,
81        sender_id,
82        recipient_id,
83        &message,
84        MessageType::Request(request_id),
85        Duration::from_secs(30),
86    )
87    .await;
88
89    let response = registry
90        .ask_agent(&agent_name, &message, provider.as_ref())
91        .await
92        .map_err(|e| ReplError::Execution(format!("ask({}) failed: {}", agent_name, e)))?;
93
94    log_comm_message(
95        ctx,
96        recipient_id,
97        sender_id,
98        &response,
99        MessageType::Response(request_id),
100        Duration::from_secs(30),
101    )
102    .await;
103
104    Ok(DslValue::String(response))
105}
106
107/// Execute the `send_to` builtin: fire-and-forget message to a named agent.
108///
109/// Arguments:
110/// - agent: string — agent name
111/// - message: string
112///
113/// Returns null (fire-and-forget).
114pub async fn builtin_send_to(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
115    let registry = ctx
116        .agent_registry
117        .as_ref()
118        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
119
120    let provider = ctx
121        .provider
122        .as_ref()
123        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
124
125    let (agent_name, message) = parse_ask_args(args)?;
126
127    // Communication bus wiring: policy check + message logging
128    let recipient_id = resolve_agent_id(&agent_name, ctx).await?;
129    let sender_id = ctx.sender_agent_id.unwrap_or_default();
130
131    check_comm_policy(
132        ctx,
133        sender_id,
134        recipient_id,
135        MessageType::Direct(recipient_id),
136    )?;
137    log_comm_message(
138        ctx,
139        sender_id,
140        recipient_id,
141        &message,
142        MessageType::Direct(recipient_id),
143        Duration::from_secs(30),
144    )
145    .await;
146
147    // Fire-and-forget: spawn a background task
148    let registry = registry.clone();
149    let provider = provider.clone();
150    tokio::spawn(async move {
151        let _ = registry
152            .ask_agent(&agent_name, &message, provider.as_ref())
153            .await;
154    });
155
156    Ok(DslValue::Null)
157}
158
159/// Execute the `parallel` builtin: run multiple agent calls concurrently.
160///
161/// Arguments:
162/// - tasks: list of maps, each with `{agent: string, message: string}`
163///
164/// Returns a list of results (strings or error maps).
165pub async fn builtin_parallel(
166    args: &[DslValue],
167    ctx: &ReasoningBuiltinContext,
168) -> Result<DslValue> {
169    let registry = ctx
170        .agent_registry
171        .as_ref()
172        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
173
174    let provider = ctx
175        .provider
176        .as_ref()
177        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
178
179    let tasks = parse_parallel_args(args)?;
180
181    // Pre-spawn policy checks: all must pass before any task is spawned
182    let sender_id = ctx.sender_agent_id.unwrap_or_default();
183    let mut checked_tasks = Vec::new();
184    for (agent_name, message) in &tasks {
185        let recipient_id = resolve_agent_id(agent_name, ctx).await?;
186        let request_id = RequestId::new();
187        check_comm_policy(
188            ctx,
189            sender_id,
190            recipient_id,
191            MessageType::Request(request_id),
192        )?;
193        checked_tasks.push((
194            agent_name.clone(),
195            message.clone(),
196            recipient_id,
197            request_id,
198        ));
199    }
200
201    // All checks passed — log outbound messages and spawn tasks
202    let comm_bus = ctx.comm_bus.clone();
203    let mut handles = Vec::new();
204    for (agent_name, message, recipient_id, request_id) in checked_tasks {
205        log_comm_message(
206            ctx,
207            sender_id,
208            recipient_id,
209            &message,
210            MessageType::Request(request_id),
211            Duration::from_secs(30),
212        )
213        .await;
214
215        let registry = registry.clone();
216        let provider = provider.clone();
217        let bus = comm_bus.clone();
218        handles.push(tokio::spawn(async move {
219            let result = registry
220                .ask_agent(&agent_name, &message, provider.as_ref())
221                .await
222                .map_err(|e| format!("{}", e));
223
224            // Log response via cloned bus
225            if let Ok(ref response) = result {
226                if let Some(ref bus) = bus {
227                    let msg = bus.create_internal_message(
228                        recipient_id,
229                        sender_id,
230                        bytes::Bytes::from(response.clone()),
231                        MessageType::Response(request_id),
232                        Duration::from_secs(30),
233                    );
234                    if let Err(e) = bus.send_message(msg).await {
235                        tracing::warn!("Failed to log inter-agent response: {}", e);
236                    }
237                }
238            }
239
240            result
241        }));
242    }
243
244    let mut results = Vec::new();
245    for handle in handles {
246        match handle.await {
247            Ok(Ok(response)) => results.push(DslValue::String(response)),
248            Ok(Err(e)) => {
249                let mut error_map = HashMap::new();
250                error_map.insert("error".to_string(), DslValue::String(e));
251                results.push(DslValue::Map(error_map));
252            }
253            Err(e) => {
254                let mut error_map = HashMap::new();
255                error_map.insert("error".to_string(), DslValue::String(e.to_string()));
256                results.push(DslValue::Map(error_map));
257            }
258        }
259    }
260
261    Ok(DslValue::List(results))
262}
263
264/// Execute the `race` builtin: run multiple agent calls, return first to complete.
265///
266/// Arguments:
267/// - tasks: list of maps, each with `{agent: string, message: string}`
268///
269/// Returns the first successful result as a string.
270pub async fn builtin_race(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
271    let registry = ctx
272        .agent_registry
273        .as_ref()
274        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
275
276    let provider = ctx
277        .provider
278        .as_ref()
279        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
280
281    let tasks = parse_parallel_args(args)?;
282
283    if tasks.is_empty() {
284        return Err(ReplError::Execution(
285            "race requires at least one task".into(),
286        ));
287    }
288
289    // Pre-spawn policy checks: all must pass before any task is spawned
290    let sender_id = ctx.sender_agent_id.unwrap_or_default();
291    let mut checked_tasks = Vec::new();
292    for (agent_name, message) in &tasks {
293        let recipient_id = resolve_agent_id(agent_name, ctx).await?;
294        let request_id = RequestId::new();
295        check_comm_policy(
296            ctx,
297            sender_id,
298            recipient_id,
299            MessageType::Request(request_id),
300        )?;
301        checked_tasks.push((
302            agent_name.clone(),
303            message.clone(),
304            recipient_id,
305            request_id,
306        ));
307    }
308
309    // All checks passed — log outbound messages and spawn tasks
310    let comm_bus = ctx.comm_bus.clone();
311    let mut join_set = tokio::task::JoinSet::new();
312    for (agent_name, message, recipient_id, request_id) in checked_tasks {
313        log_comm_message(
314            ctx,
315            sender_id,
316            recipient_id,
317            &message,
318            MessageType::Request(request_id),
319            Duration::from_secs(30),
320        )
321        .await;
322
323        let registry = registry.clone();
324        let provider = provider.clone();
325        let bus = comm_bus.clone();
326        join_set.spawn(async move {
327            let result = registry
328                .ask_agent(&agent_name, &message, provider.as_ref())
329                .await
330                .map_err(|e| format!("{}", e));
331
332            // Log response via cloned bus
333            if let Ok(ref response) = result {
334                if let Some(ref bus) = bus {
335                    let msg = bus.create_internal_message(
336                        recipient_id,
337                        sender_id,
338                        bytes::Bytes::from(response.clone()),
339                        MessageType::Response(request_id),
340                        Duration::from_secs(30),
341                    );
342                    if let Err(e) = bus.send_message(msg).await {
343                        tracing::warn!("Failed to log inter-agent response: {}", e);
344                    }
345                }
346            }
347
348            result
349        });
350    }
351
352    // Return the first completed result
353    match join_set.join_next().await {
354        Some(Ok(Ok(response))) => {
355            join_set.abort_all();
356            Ok(DslValue::String(response))
357        }
358        Some(Ok(Err(e))) => {
359            join_set.abort_all();
360            Err(ReplError::Execution(format!(
361                "race: first completed with error: {}",
362                e
363            )))
364        }
365        Some(Err(e)) => {
366            join_set.abort_all();
367            Err(ReplError::Execution(format!("race: task panic: {}", e)))
368        }
369        None => Err(ReplError::Execution("race: no tasks to run".into())),
370    }
371}
372
// --- Communication helpers ---
374
375/// Resolve an agent name to its AgentId via the registry.
376pub(crate) async fn resolve_agent_id(name: &str, ctx: &ReasoningBuiltinContext) -> Result<AgentId> {
377    let registry = ctx
378        .agent_registry
379        .as_ref()
380        .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
381
382    registry
383        .get_agent(name)
384        .await
385        .map(|agent| agent.agent_id)
386        .ok_or_else(|| ReplError::Execution(format!("Unknown agent: {}", name)))
387}
388
389/// Check communication policy. Returns Ok(()) if allowed or if no policy gate is configured.
390pub(crate) fn check_comm_policy(
391    ctx: &ReasoningBuiltinContext,
392    sender: AgentId,
393    recipient: AgentId,
394    message_type: MessageType,
395) -> Result<()> {
396    if let Some(policy) = &ctx.comm_policy {
397        let request = CommunicationRequest {
398            sender,
399            recipient,
400            message_type,
401            topic: None,
402        };
403        policy
404            .evaluate(&request)
405            .map_err(|e| ReplError::Execution(format!("Inter-agent communication denied: {}", e)))
406    } else {
407        Ok(())
408    }
409}
410
411/// Log an outbound message via the CommunicationBus. Best-effort (errors logged, not propagated).
412pub(crate) async fn log_comm_message(
413    ctx: &ReasoningBuiltinContext,
414    sender: AgentId,
415    recipient: AgentId,
416    payload: &str,
417    message_type: MessageType,
418    ttl: Duration,
419) {
420    if let Some(bus) = &ctx.comm_bus {
421        let msg = bus.create_internal_message(
422            sender,
423            recipient,
424            bytes::Bytes::from(payload.to_string()),
425            message_type,
426            ttl,
427        );
428        if let Err(e) = bus.send_message(msg).await {
429            tracing::warn!("Failed to log inter-agent message: {}", e);
430        }
431    }
432}
433
// --- Argument parsing helpers ---
435
436fn parse_spawn_args(args: &[DslValue]) -> Result<(String, String, Vec<String>, Option<String>)> {
437    match args {
438        [DslValue::Map(map)] => {
439            let name = extract_string(map, "name")?;
440            let system = extract_string(map, "system")?;
441            let tools = map
442                .get("tools")
443                .and_then(|v| match v {
444                    DslValue::List(items) => Some(
445                        items
446                            .iter()
447                            .filter_map(|i| match i {
448                                DslValue::String(s) => Some(s.clone()),
449                                _ => None,
450                            })
451                            .collect(),
452                    ),
453                    _ => None,
454                })
455                .unwrap_or_default();
456            let response_format = map.get("response_format").and_then(|v| match v {
457                DslValue::String(s) => Some(s.clone()),
458                _ => None,
459            });
460            Ok((name, system, tools, response_format))
461        }
462        [DslValue::String(name), DslValue::String(system)] => {
463            Ok((name.clone(), system.clone(), Vec::new(), None))
464        }
465        [DslValue::String(name), DslValue::String(system), DslValue::List(tools)] => {
466            let tool_names = tools
467                .iter()
468                .filter_map(|t| match t {
469                    DslValue::String(s) => Some(s.clone()),
470                    _ => None,
471                })
472                .collect();
473            Ok((name.clone(), system.clone(), tool_names, None))
474        }
475        _ => Err(ReplError::Execution(
476            "spawn_agent requires (name: string, system: string, [tools?, response_format?])"
477                .into(),
478        )),
479    }
480}
481
482fn parse_ask_args(args: &[DslValue]) -> Result<(String, String)> {
483    match args {
484        [DslValue::String(agent), DslValue::String(message)] => {
485            Ok((agent.clone(), message.clone()))
486        }
487        [DslValue::Map(map)] => {
488            let agent = extract_string(map, "agent")?;
489            let message = extract_string(map, "message")?;
490            Ok((agent, message))
491        }
492        _ => Err(ReplError::Execution(
493            "requires (agent: string, message: string)".into(),
494        )),
495    }
496}
497
498fn parse_parallel_args(args: &[DslValue]) -> Result<Vec<(String, String)>> {
499    match args {
500        [DslValue::List(items)] => {
501            let mut tasks = Vec::new();
502            for item in items {
503                match item {
504                    DslValue::Map(map) => {
505                        let agent = extract_string(map, "agent")?;
506                        let message = extract_string(map, "message")?;
507                        tasks.push((agent, message));
508                    }
509                    _ => {
510                        return Err(ReplError::Execution(
511                            "parallel/race items must be maps with {agent, message}".into(),
512                        ))
513                    }
514                }
515            }
516            Ok(tasks)
517        }
518        _ => Err(ReplError::Execution(
519            "parallel/race requires a list of {agent, message} maps".into(),
520        )),
521    }
522}
523
524fn extract_string(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
525    map.get(key)
526        .and_then(|v| match v {
527            DslValue::String(s) => Some(s.clone()),
528            _ => None,
529        })
530        .ok_or_else(|| ReplError::Execution(format!("Missing required string argument '{}'", key)))
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_spawn_args_named() {
        let mut map = HashMap::new();
        map.insert("name".into(), DslValue::String("researcher".into()));
        map.insert("system".into(), DslValue::String("You research.".into()));
        map.insert(
            "tools".into(),
            DslValue::List(vec![DslValue::String("search".into())]),
        );

        let (name, system, tools, format) = parse_spawn_args(&[DslValue::Map(map)]).unwrap();
        assert_eq!(name, "researcher");
        assert_eq!(system, "You research.");
        assert_eq!(tools, vec!["search"]);
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_positional() {
        let args = vec![
            DslValue::String("coder".into()),
            DslValue::String("You code.".into()),
        ];
        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "coder");
        assert_eq!(system, "You code.");
        assert!(tools.is_empty());
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_with_tools() {
        let args = vec![
            DslValue::String("worker".into()),
            DslValue::String("You work.".into()),
            DslValue::List(vec![
                DslValue::String("read".into()),
                DslValue::String("write".into()),
            ]),
        ];
        let (name, system, tools, _) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "worker");
        assert_eq!(system, "You work.");
        assert_eq!(tools, vec!["read", "write"]);
    }

    #[test]
    fn test_parse_spawn_args_with_response_format() {
        let mut map = HashMap::new();
        map.insert("name".into(), DslValue::String("parser".into()));
        map.insert("system".into(), DslValue::String("Parse data.".into()));
        map.insert("response_format".into(), DslValue::String("json".into()));

        let (_, _, _, format) = parse_spawn_args(&[DslValue::Map(map)]).unwrap();
        assert_eq!(format, Some("json".into()));
    }

    #[test]
    fn test_parse_spawn_args_map_ignores_non_list_tools() {
        let mut map = HashMap::new();
        map.insert("name".into(), DslValue::String("n".into()));
        map.insert("system".into(), DslValue::String("s".into()));
        map.insert("tools".into(), DslValue::String("search".into()));

        let (_, _, tools, _) = parse_spawn_args(&[DslValue::Map(map)]).unwrap();
        assert!(tools.is_empty());
    }

    #[test]
    fn test_parse_spawn_args_skips_non_string_tools() {
        let args = vec![
            DslValue::String("worker".into()),
            DslValue::String("You work.".into()),
            DslValue::List(vec![DslValue::String("read".into()), DslValue::Integer(7)]),
        ];
        let (_, _, tools, _) = parse_spawn_args(&args).unwrap();
        assert_eq!(tools, vec!["read"]);
    }

    #[test]
    fn test_parse_spawn_args_no_args() {
        assert!(parse_spawn_args(&[]).is_err());
    }

    #[test]
    fn test_parse_ask_args_positional() {
        let args = vec![
            DslValue::String("agent1".into()),
            DslValue::String("hello".into()),
        ];
        let (agent, msg) = parse_ask_args(&args).unwrap();
        assert_eq!(agent, "agent1");
        assert_eq!(msg, "hello");
    }

    #[test]
    fn test_parse_ask_args_named() {
        let mut map = HashMap::new();
        map.insert("agent".into(), DslValue::String("bot".into()));
        map.insert("message".into(), DslValue::String("hi".into()));
        let (agent, msg) = parse_ask_args(&[DslValue::Map(map)]).unwrap();
        assert_eq!(agent, "bot");
        assert_eq!(msg, "hi");
    }

    #[test]
    fn test_parse_ask_args_non_string_message() {
        let args = vec![DslValue::String("agent1".into()), DslValue::Integer(1)];
        assert!(parse_ask_args(&args).is_err());
    }

    #[test]
    fn test_parse_ask_args_map_missing_message() {
        let mut map = HashMap::new();
        map.insert("agent".into(), DslValue::String("bot".into()));
        assert!(parse_ask_args(&[DslValue::Map(map)]).is_err());
    }

    #[test]
    fn test_parse_parallel_args() {
        let mut task1 = HashMap::new();
        task1.insert("agent".into(), DslValue::String("a".into()));
        task1.insert("message".into(), DslValue::String("m1".into()));

        let mut task2 = HashMap::new();
        task2.insert("agent".into(), DslValue::String("b".into()));
        task2.insert("message".into(), DslValue::String("m2".into()));

        let args = vec![DslValue::List(vec![
            DslValue::Map(task1),
            DslValue::Map(task2),
        ])];
        let tasks = parse_parallel_args(&args).unwrap();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0], ("a".into(), "m1".into()));
        assert_eq!(tasks[1], ("b".into(), "m2".into()));
    }

    #[test]
    fn test_parse_spawn_args_missing_name() {
        let map = HashMap::new();
        assert!(parse_spawn_args(&[DslValue::Map(map)]).is_err());
    }

    #[test]
    fn test_parse_ask_args_invalid() {
        assert!(parse_ask_args(&[DslValue::Integer(42)]).is_err());
    }

    #[test]
    fn test_parse_parallel_args_empty_list() {
        let args = vec![DslValue::List(vec![])];
        let tasks = parse_parallel_args(&args).unwrap();
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_parse_parallel_args_invalid_item() {
        let args = vec![DslValue::List(vec![DslValue::String("not a map".into())])];
        assert!(parse_parallel_args(&args).is_err());
    }

    #[test]
    fn test_parse_parallel_args_requires_list() {
        assert!(parse_parallel_args(&[DslValue::String("a".into())]).is_err());
        assert!(parse_parallel_args(&[]).is_err());
    }

    #[test]
    fn test_parse_parallel_args_item_missing_message() {
        let mut task = HashMap::new();
        task.insert("agent".into(), DslValue::String("a".into()));
        let args = vec![DslValue::List(vec![DslValue::Map(task)])];
        assert!(parse_parallel_args(&args).is_err());
    }

    #[test]
    fn test_extract_string_wrong_type_or_missing() {
        let mut map = HashMap::new();
        map.insert("name".to_string(), DslValue::Integer(1));
        assert!(extract_string(&map, "name").is_err());
        assert!(extract_string(&map, "missing").is_err());
    }
}