1use crate::dsl::evaluator::DslValue;
8use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
9use crate::error::{ReplError, Result};
10use std::collections::HashMap;
11use std::time::Duration;
12use symbi_runtime::communication::policy_gate::CommunicationRequest;
13use symbi_runtime::types::{AgentId, MessageType, RequestId};
14
15pub async fn builtin_spawn_agent(
25 args: &[DslValue],
26 ctx: &ReasoningBuiltinContext,
27) -> Result<DslValue> {
28 let registry = ctx
29 .agent_registry
30 .as_ref()
31 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
32
33 let (name, system_prompt, tools, response_format) = parse_spawn_args(args)?;
34
35 let agent_id = registry
36 .spawn_agent(&name, &system_prompt, tools, response_format)
37 .await;
38
39 let mut result = HashMap::new();
40 result.insert(
41 "agent_id".to_string(),
42 DslValue::String(agent_id.to_string()),
43 );
44 result.insert("name".to_string(), DslValue::String(name));
45 Ok(DslValue::Map(result))
46}
47
48pub async fn builtin_ask(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
56 let registry = ctx
57 .agent_registry
58 .as_ref()
59 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
60
61 let provider = ctx
62 .provider
63 .as_ref()
64 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
65
66 let (agent_name, message) = parse_ask_args(args)?;
67
68 let recipient_id = resolve_agent_id(&agent_name, ctx).await?;
70 let sender_id = ctx.sender_agent_id.unwrap_or_default();
71 let request_id = RequestId::new();
72
73 check_comm_policy(
74 ctx,
75 sender_id,
76 recipient_id,
77 MessageType::Request(request_id),
78 )?;
79 log_comm_message(
80 ctx,
81 sender_id,
82 recipient_id,
83 &message,
84 MessageType::Request(request_id),
85 Duration::from_secs(30),
86 )
87 .await;
88
89 let response = registry
90 .ask_agent(&agent_name, &message, provider.as_ref())
91 .await
92 .map_err(|e| ReplError::Execution(format!("ask({}) failed: {}", agent_name, e)))?;
93
94 log_comm_message(
95 ctx,
96 recipient_id,
97 sender_id,
98 &response,
99 MessageType::Response(request_id),
100 Duration::from_secs(30),
101 )
102 .await;
103
104 Ok(DslValue::String(response))
105}
106
107pub async fn builtin_send_to(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
115 let registry = ctx
116 .agent_registry
117 .as_ref()
118 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
119
120 let provider = ctx
121 .provider
122 .as_ref()
123 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
124
125 let (agent_name, message) = parse_ask_args(args)?;
126
127 let recipient_id = resolve_agent_id(&agent_name, ctx).await?;
129 let sender_id = ctx.sender_agent_id.unwrap_or_default();
130
131 check_comm_policy(
132 ctx,
133 sender_id,
134 recipient_id,
135 MessageType::Direct(recipient_id),
136 )?;
137 log_comm_message(
138 ctx,
139 sender_id,
140 recipient_id,
141 &message,
142 MessageType::Direct(recipient_id),
143 Duration::from_secs(30),
144 )
145 .await;
146
147 let registry = registry.clone();
151 let provider = provider.clone();
152 tokio::spawn(async move {
153 match registry
154 .ask_agent(&agent_name, &message, provider.as_ref())
155 .await
156 {
157 Ok(_) => {
158 tracing::debug!(
159 agent = %agent_name,
160 sender = %sender_id,
161 "send_to: background ask_agent succeeded",
162 );
163 }
164 Err(e) => {
165 tracing::warn!(
166 agent = %agent_name,
167 sender = %sender_id,
168 error = %e,
169 "send_to: background ask_agent failed",
170 );
171 }
172 }
173 });
174
175 Ok(DslValue::Null)
176}
177
178pub async fn builtin_parallel(
185 args: &[DslValue],
186 ctx: &ReasoningBuiltinContext,
187) -> Result<DslValue> {
188 let registry = ctx
189 .agent_registry
190 .as_ref()
191 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
192
193 let provider = ctx
194 .provider
195 .as_ref()
196 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
197
198 let tasks = parse_parallel_args(args)?;
199
200 let sender_id = ctx.sender_agent_id.unwrap_or_default();
202 let mut checked_tasks = Vec::new();
203 for (agent_name, message) in &tasks {
204 let recipient_id = resolve_agent_id(agent_name, ctx).await?;
205 let request_id = RequestId::new();
206 check_comm_policy(
207 ctx,
208 sender_id,
209 recipient_id,
210 MessageType::Request(request_id),
211 )?;
212 checked_tasks.push((
213 agent_name.clone(),
214 message.clone(),
215 recipient_id,
216 request_id,
217 ));
218 }
219
220 let comm_bus = ctx.comm_bus.clone();
222 let mut handles = Vec::new();
223 for (agent_name, message, recipient_id, request_id) in checked_tasks {
224 log_comm_message(
225 ctx,
226 sender_id,
227 recipient_id,
228 &message,
229 MessageType::Request(request_id),
230 Duration::from_secs(30),
231 )
232 .await;
233
234 let registry = registry.clone();
235 let provider = provider.clone();
236 let bus = comm_bus.clone();
237 handles.push(tokio::spawn(async move {
238 let result = registry
239 .ask_agent(&agent_name, &message, provider.as_ref())
240 .await
241 .map_err(|e| format!("{}", e));
242
243 if let Ok(ref response) = result {
245 if let Some(ref bus) = bus {
246 let msg = bus.create_internal_message(
247 recipient_id,
248 sender_id,
249 bytes::Bytes::from(response.clone()),
250 MessageType::Response(request_id),
251 Duration::from_secs(30),
252 );
253 if let Err(e) = bus.send_message(msg).await {
254 tracing::warn!("Failed to log inter-agent response: {}", e);
255 }
256 }
257 }
258
259 result
260 }));
261 }
262
263 let mut results = Vec::new();
264 for handle in handles {
265 match handle.await {
266 Ok(Ok(response)) => results.push(DslValue::String(response)),
267 Ok(Err(e)) => {
268 let mut error_map = HashMap::new();
269 error_map.insert("error".to_string(), DslValue::String(e));
270 results.push(DslValue::Map(error_map));
271 }
272 Err(e) => {
273 let mut error_map = HashMap::new();
274 error_map.insert("error".to_string(), DslValue::String(e.to_string()));
275 results.push(DslValue::Map(error_map));
276 }
277 }
278 }
279
280 Ok(DslValue::List(results))
281}
282
283pub async fn builtin_race(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
290 let registry = ctx
291 .agent_registry
292 .as_ref()
293 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
294
295 let provider = ctx
296 .provider
297 .as_ref()
298 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
299
300 let tasks = parse_parallel_args(args)?;
301
302 if tasks.is_empty() {
303 return Err(ReplError::Execution(
304 "race requires at least one task".into(),
305 ));
306 }
307
308 let sender_id = ctx.sender_agent_id.unwrap_or_default();
310 let mut checked_tasks = Vec::new();
311 for (agent_name, message) in &tasks {
312 let recipient_id = resolve_agent_id(agent_name, ctx).await?;
313 let request_id = RequestId::new();
314 check_comm_policy(
315 ctx,
316 sender_id,
317 recipient_id,
318 MessageType::Request(request_id),
319 )?;
320 checked_tasks.push((
321 agent_name.clone(),
322 message.clone(),
323 recipient_id,
324 request_id,
325 ));
326 }
327
328 let comm_bus = ctx.comm_bus.clone();
330 let mut join_set = tokio::task::JoinSet::new();
331 for (agent_name, message, recipient_id, request_id) in checked_tasks {
332 log_comm_message(
333 ctx,
334 sender_id,
335 recipient_id,
336 &message,
337 MessageType::Request(request_id),
338 Duration::from_secs(30),
339 )
340 .await;
341
342 let registry = registry.clone();
343 let provider = provider.clone();
344 let bus = comm_bus.clone();
345 join_set.spawn(async move {
346 let result = registry
347 .ask_agent(&agent_name, &message, provider.as_ref())
348 .await
349 .map_err(|e| format!("{}", e));
350
351 if let Ok(ref response) = result {
353 if let Some(ref bus) = bus {
354 let msg = bus.create_internal_message(
355 recipient_id,
356 sender_id,
357 bytes::Bytes::from(response.clone()),
358 MessageType::Response(request_id),
359 Duration::from_secs(30),
360 );
361 if let Err(e) = bus.send_message(msg).await {
362 tracing::warn!("Failed to log inter-agent response: {}", e);
363 }
364 }
365 }
366
367 result
368 });
369 }
370
371 match join_set.join_next().await {
373 Some(Ok(Ok(response))) => {
374 join_set.abort_all();
375 Ok(DslValue::String(response))
376 }
377 Some(Ok(Err(e))) => {
378 join_set.abort_all();
379 Err(ReplError::Execution(format!(
380 "race: first completed with error: {}",
381 e
382 )))
383 }
384 Some(Err(e)) => {
385 join_set.abort_all();
386 Err(ReplError::Execution(format!("race: task panic: {}", e)))
387 }
388 None => Err(ReplError::Execution("race: no tasks to run".into())),
389 }
390}
391
392pub(crate) async fn resolve_agent_id(name: &str, ctx: &ReasoningBuiltinContext) -> Result<AgentId> {
396 let registry = ctx
397 .agent_registry
398 .as_ref()
399 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
400
401 registry
402 .get_agent(name)
403 .await
404 .map(|agent| agent.agent_id)
405 .ok_or_else(|| ReplError::Execution(format!("Unknown agent: {}", name)))
406}
407
408pub(crate) fn check_comm_policy(
410 ctx: &ReasoningBuiltinContext,
411 sender: AgentId,
412 recipient: AgentId,
413 message_type: MessageType,
414) -> Result<()> {
415 if let Some(policy) = &ctx.comm_policy {
416 let request = CommunicationRequest {
417 sender,
418 recipient,
419 message_type,
420 topic: None,
421 };
422 policy
423 .evaluate(&request)
424 .map_err(|e| ReplError::Execution(format!("Inter-agent communication denied: {}", e)))
425 } else {
426 Ok(())
427 }
428}
429
430pub(crate) async fn log_comm_message(
432 ctx: &ReasoningBuiltinContext,
433 sender: AgentId,
434 recipient: AgentId,
435 payload: &str,
436 message_type: MessageType,
437 ttl: Duration,
438) {
439 if let Some(bus) = &ctx.comm_bus {
440 let msg = bus.create_internal_message(
441 sender,
442 recipient,
443 bytes::Bytes::from(payload.to_string()),
444 message_type,
445 ttl,
446 );
447 if let Err(e) = bus.send_message(msg).await {
448 tracing::warn!("Failed to log inter-agent message: {}", e);
449 }
450 }
451}
452
453fn parse_spawn_args(args: &[DslValue]) -> Result<(String, String, Vec<String>, Option<String>)> {
456 match args {
457 [DslValue::Map(map)] => {
458 let name = extract_string(map, "name")?;
459 let system = extract_string(map, "system")?;
460 let tools = map
461 .get("tools")
462 .and_then(|v| match v {
463 DslValue::List(items) => Some(
464 items
465 .iter()
466 .filter_map(|i| match i {
467 DslValue::String(s) => Some(s.clone()),
468 _ => None,
469 })
470 .collect(),
471 ),
472 _ => None,
473 })
474 .unwrap_or_default();
475 let response_format = map.get("response_format").and_then(|v| match v {
476 DslValue::String(s) => Some(s.clone()),
477 _ => None,
478 });
479 Ok((name, system, tools, response_format))
480 }
481 [DslValue::String(name), DslValue::String(system)] => {
482 Ok((name.clone(), system.clone(), Vec::new(), None))
483 }
484 [DslValue::String(name), DslValue::String(system), DslValue::List(tools)] => {
485 let tool_names = tools
486 .iter()
487 .filter_map(|t| match t {
488 DslValue::String(s) => Some(s.clone()),
489 _ => None,
490 })
491 .collect();
492 Ok((name.clone(), system.clone(), tool_names, None))
493 }
494 _ => Err(ReplError::Execution(
495 "spawn_agent requires (name: string, system: string, [tools?, response_format?])"
496 .into(),
497 )),
498 }
499}
500
501fn parse_ask_args(args: &[DslValue]) -> Result<(String, String)> {
502 match args {
503 [DslValue::String(agent), DslValue::String(message)] => {
504 Ok((agent.clone(), message.clone()))
505 }
506 [DslValue::Map(map)] => {
507 let agent = extract_string(map, "agent")?;
508 let message = extract_string(map, "message")?;
509 Ok((agent, message))
510 }
511 _ => Err(ReplError::Execution(
512 "requires (agent: string, message: string)".into(),
513 )),
514 }
515}
516
/// Fan-out cap used by `parallel`/`race` when `SYMBIONT_MAX_PARALLEL_TASKS` is unset or
/// invalid.
const DEFAULT_MAX_PARALLEL_TASKS: usize = 32;

/// Returns the maximum number of tasks a single `parallel`/`race` call may spawn.
///
/// Reads `SYMBIONT_MAX_PARALLEL_TASKS` on every call. A missing variable, a value that
/// does not parse as `usize`, or `0` all fall back to `DEFAULT_MAX_PARALLEL_TASKS`.
fn max_parallel_tasks() -> usize {
    match std::env::var("SYMBIONT_MAX_PARALLEL_TASKS").map(|raw| raw.parse::<usize>()) {
        Ok(Ok(limit)) if limit > 0 => limit,
        _ => DEFAULT_MAX_PARALLEL_TASKS,
    }
}
533
534fn parse_parallel_args(args: &[DslValue]) -> Result<Vec<(String, String)>> {
535 let cap = max_parallel_tasks();
536 match args {
537 [DslValue::List(items)] => {
538 if items.len() > cap {
539 return Err(ReplError::Execution(format!(
540 "parallel/race: too many tasks ({} > {}); raise SYMBIONT_MAX_PARALLEL_TASKS \
541 if intentional",
542 items.len(),
543 cap
544 )));
545 }
546 let mut tasks = Vec::new();
547 for item in items {
548 match item {
549 DslValue::Map(map) => {
550 let agent = extract_string(map, "agent")?;
551 let message = extract_string(map, "message")?;
552 tasks.push((agent, message));
553 }
554 _ => {
555 return Err(ReplError::Execution(
556 "parallel/race items must be maps with {agent, message}".into(),
557 ))
558 }
559 }
560 }
561 Ok(tasks)
562 }
563 _ => Err(ReplError::Execution(
564 "parallel/race requires a list of {agent, message} maps".into(),
565 )),
566 }
567}
568
569fn extract_string(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
570 map.get(key)
571 .and_then(|v| match v {
572 DslValue::String(s) => Some(s.clone()),
573 _ => None,
574 })
575 .ok_or_else(|| ReplError::Execution(format!("Missing required string argument '{}'", key)))
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes every test that reads or writes `SYMBIONT_MAX_PARALLEL_TASKS`.
    ///
    /// `parse_parallel_args` reads that variable on each call, and two tests below
    /// modify it. The process environment is global and the test harness runs tests on
    /// several threads. Concurrent `getenv`/`setenv` is a data race on common platforms,
    /// so every caller of `parse_parallel_args` must hold this lock, not just the
    /// mutating tests. A poisoned lock is recovered so one failing test doesn't cascade.
    fn env_test_lock() -> std::sync::MutexGuard<'static, ()> {
        use std::sync::{Mutex, OnceLock};
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn test_parse_spawn_args_named() {
        let mut map = HashMap::new();
        map.insert("name".into(), DslValue::String("researcher".into()));
        map.insert("system".into(), DslValue::String("You research.".into()));
        map.insert(
            "tools".into(),
            DslValue::List(vec![DslValue::String("search".into())]),
        );

        let (name, system, tools, format) = parse_spawn_args(&[DslValue::Map(map)]).unwrap();
        assert_eq!(name, "researcher");
        assert_eq!(system, "You research.");
        assert_eq!(tools, vec!["search"]);
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_positional() {
        let args = vec![
            DslValue::String("coder".into()),
            DslValue::String("You code.".into()),
        ];
        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "coder");
        assert_eq!(system, "You code.");
        assert!(tools.is_empty());
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_with_tools() {
        let args = vec![
            DslValue::String("worker".into()),
            DslValue::String("You work.".into()),
            DslValue::List(vec![
                DslValue::String("read".into()),
                DslValue::String("write".into()),
            ]),
        ];
        let (name, system, tools, _) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "worker");
        assert_eq!(system, "You work.");
        assert_eq!(tools, vec!["read", "write"]);
    }

    #[test]
    fn test_parse_spawn_args_with_response_format() {
        let mut map = HashMap::new();
        map.insert("name".into(), DslValue::String("parser".into()));
        map.insert("system".into(), DslValue::String("Parse data.".into()));
        map.insert("response_format".into(), DslValue::String("json".into()));

        let (_, _, _, format) = parse_spawn_args(&[DslValue::Map(map)]).unwrap();
        assert_eq!(format, Some("json".into()));
    }

    #[test]
    fn test_parse_ask_args_positional() {
        let args = vec![
            DslValue::String("agent1".into()),
            DslValue::String("hello".into()),
        ];
        let (agent, msg) = parse_ask_args(&args).unwrap();
        assert_eq!(agent, "agent1");
        assert_eq!(msg, "hello");
    }

    #[test]
    fn test_parse_ask_args_named() {
        let mut map = HashMap::new();
        map.insert("agent".into(), DslValue::String("bot".into()));
        map.insert("message".into(), DslValue::String("hi".into()));
        let (agent, msg) = parse_ask_args(&[DslValue::Map(map)]).unwrap();
        assert_eq!(agent, "bot");
        assert_eq!(msg, "hi");
    }

    #[test]
    fn test_parse_parallel_args() {
        // Reads the fan-out env var; see `env_test_lock`.
        let _g = env_test_lock();
        let mut task1 = HashMap::new();
        task1.insert("agent".into(), DslValue::String("a".into()));
        task1.insert("message".into(), DslValue::String("m1".into()));

        let mut task2 = HashMap::new();
        task2.insert("agent".into(), DslValue::String("b".into()));
        task2.insert("message".into(), DslValue::String("m2".into()));

        let args = vec![DslValue::List(vec![
            DslValue::Map(task1),
            DslValue::Map(task2),
        ])];
        let tasks = parse_parallel_args(&args).unwrap();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0], ("a".into(), "m1".into()));
        assert_eq!(tasks[1], ("b".into(), "m2".into()));
    }

    #[test]
    fn test_parse_spawn_args_missing_name() {
        let map = HashMap::new();
        assert!(parse_spawn_args(&[DslValue::Map(map)]).is_err());
    }

    #[test]
    fn test_parse_ask_args_invalid() {
        assert!(parse_ask_args(&[DslValue::Integer(42)]).is_err());
    }

    #[test]
    fn test_parse_parallel_args_empty_list() {
        // Reads the fan-out env var; see `env_test_lock`.
        let _g = env_test_lock();
        let args = vec![DslValue::List(vec![])];
        let tasks = parse_parallel_args(&args).unwrap();
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_parse_parallel_args_invalid_item() {
        // Reads the fan-out env var; see `env_test_lock`.
        let _g = env_test_lock();
        let args = vec![DslValue::List(vec![DslValue::String("not a map".into())])];
        assert!(parse_parallel_args(&args).is_err());
    }

    #[test]
    fn test_parse_parallel_args_rejects_oversize_list() {
        let _g = env_test_lock();
        // Ensure the default cap applies regardless of the outer environment.
        std::env::remove_var("SYMBIONT_MAX_PARALLEL_TASKS");
        let mut items = Vec::new();
        for i in 0..(DEFAULT_MAX_PARALLEL_TASKS + 1) {
            let mut map = HashMap::new();
            map.insert("agent".into(), DslValue::String(format!("a{}", i)));
            map.insert("message".into(), DslValue::String("hi".into()));
            items.push(DslValue::Map(map));
        }
        let args = vec![DslValue::List(items)];
        let err = parse_parallel_args(&args).unwrap_err();
        let msg = format!("{}", err);
        assert!(
            msg.contains("too many tasks"),
            "expected fan-out cap error, got: {}",
            msg
        );
    }

    #[test]
    fn test_parse_parallel_args_env_override_allows_larger_list() {
        let _g = env_test_lock();
        std::env::set_var("SYMBIONT_MAX_PARALLEL_TASKS", "64");
        let mut items = Vec::new();
        for i in 0..40 {
            let mut map = HashMap::new();
            map.insert("agent".into(), DslValue::String(format!("a{}", i)));
            map.insert("message".into(), DslValue::String("hi".into()));
            items.push(DslValue::Map(map));
        }
        let args = vec![DslValue::List(items)];
        let res = parse_parallel_args(&args);
        // Restore before asserting so a failure does not leak the override.
        std::env::remove_var("SYMBIONT_MAX_PARALLEL_TASKS");
        assert!(res.is_ok(), "override must widen the cap");
    }
}