1use crate::dsl::evaluator::DslValue;
8use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
9use crate::error::{ReplError, Result};
10use std::collections::HashMap;
11use std::time::Duration;
12use symbi_runtime::communication::policy_gate::CommunicationRequest;
13use symbi_runtime::types::{AgentId, MessageType, RequestId};
14
15pub async fn builtin_spawn_agent(
25 args: &[DslValue],
26 ctx: &ReasoningBuiltinContext,
27) -> Result<DslValue> {
28 let registry = ctx
29 .agent_registry
30 .as_ref()
31 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
32
33 let (name, system_prompt, tools, response_format) = parse_spawn_args(args)?;
34
35 let agent_id = registry
36 .spawn_agent(&name, &system_prompt, tools, response_format)
37 .await;
38
39 let mut result = HashMap::new();
40 result.insert(
41 "agent_id".to_string(),
42 DslValue::String(agent_id.to_string()),
43 );
44 result.insert("name".to_string(), DslValue::String(name));
45 Ok(DslValue::Map(result))
46}
47
48pub async fn builtin_ask(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
56 let registry = ctx
57 .agent_registry
58 .as_ref()
59 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
60
61 let provider = ctx
62 .provider
63 .as_ref()
64 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
65
66 let (agent_name, message) = parse_ask_args(args)?;
67
68 let recipient_id = resolve_agent_id(&agent_name, ctx).await?;
70 let sender_id = ctx.sender_agent_id.unwrap_or_default();
71 let request_id = RequestId::new();
72
73 check_comm_policy(
74 ctx,
75 sender_id,
76 recipient_id,
77 MessageType::Request(request_id),
78 )?;
79 log_comm_message(
80 ctx,
81 sender_id,
82 recipient_id,
83 &message,
84 MessageType::Request(request_id),
85 Duration::from_secs(30),
86 )
87 .await;
88
89 let response = registry
90 .ask_agent(&agent_name, &message, provider.as_ref())
91 .await
92 .map_err(|e| ReplError::Execution(format!("ask({}) failed: {}", agent_name, e)))?;
93
94 log_comm_message(
95 ctx,
96 recipient_id,
97 sender_id,
98 &response,
99 MessageType::Response(request_id),
100 Duration::from_secs(30),
101 )
102 .await;
103
104 Ok(DslValue::String(response))
105}
106
107pub async fn builtin_send_to(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
115 let registry = ctx
116 .agent_registry
117 .as_ref()
118 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
119
120 let provider = ctx
121 .provider
122 .as_ref()
123 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
124
125 let (agent_name, message) = parse_ask_args(args)?;
126
127 let recipient_id = resolve_agent_id(&agent_name, ctx).await?;
129 let sender_id = ctx.sender_agent_id.unwrap_or_default();
130
131 check_comm_policy(
132 ctx,
133 sender_id,
134 recipient_id,
135 MessageType::Direct(recipient_id),
136 )?;
137 log_comm_message(
138 ctx,
139 sender_id,
140 recipient_id,
141 &message,
142 MessageType::Direct(recipient_id),
143 Duration::from_secs(30),
144 )
145 .await;
146
147 let registry = registry.clone();
149 let provider = provider.clone();
150 tokio::spawn(async move {
151 let _ = registry
152 .ask_agent(&agent_name, &message, provider.as_ref())
153 .await;
154 });
155
156 Ok(DslValue::Null)
157}
158
159pub async fn builtin_parallel(
166 args: &[DslValue],
167 ctx: &ReasoningBuiltinContext,
168) -> Result<DslValue> {
169 let registry = ctx
170 .agent_registry
171 .as_ref()
172 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
173
174 let provider = ctx
175 .provider
176 .as_ref()
177 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
178
179 let tasks = parse_parallel_args(args)?;
180
181 let sender_id = ctx.sender_agent_id.unwrap_or_default();
183 let mut checked_tasks = Vec::new();
184 for (agent_name, message) in &tasks {
185 let recipient_id = resolve_agent_id(agent_name, ctx).await?;
186 let request_id = RequestId::new();
187 check_comm_policy(
188 ctx,
189 sender_id,
190 recipient_id,
191 MessageType::Request(request_id),
192 )?;
193 checked_tasks.push((
194 agent_name.clone(),
195 message.clone(),
196 recipient_id,
197 request_id,
198 ));
199 }
200
201 let comm_bus = ctx.comm_bus.clone();
203 let mut handles = Vec::new();
204 for (agent_name, message, recipient_id, request_id) in checked_tasks {
205 log_comm_message(
206 ctx,
207 sender_id,
208 recipient_id,
209 &message,
210 MessageType::Request(request_id),
211 Duration::from_secs(30),
212 )
213 .await;
214
215 let registry = registry.clone();
216 let provider = provider.clone();
217 let bus = comm_bus.clone();
218 handles.push(tokio::spawn(async move {
219 let result = registry
220 .ask_agent(&agent_name, &message, provider.as_ref())
221 .await
222 .map_err(|e| format!("{}", e));
223
224 if let Ok(ref response) = result {
226 if let Some(ref bus) = bus {
227 let msg = bus.create_internal_message(
228 recipient_id,
229 sender_id,
230 bytes::Bytes::from(response.clone()),
231 MessageType::Response(request_id),
232 Duration::from_secs(30),
233 );
234 if let Err(e) = bus.send_message(msg).await {
235 tracing::warn!("Failed to log inter-agent response: {}", e);
236 }
237 }
238 }
239
240 result
241 }));
242 }
243
244 let mut results = Vec::new();
245 for handle in handles {
246 match handle.await {
247 Ok(Ok(response)) => results.push(DslValue::String(response)),
248 Ok(Err(e)) => {
249 let mut error_map = HashMap::new();
250 error_map.insert("error".to_string(), DslValue::String(e));
251 results.push(DslValue::Map(error_map));
252 }
253 Err(e) => {
254 let mut error_map = HashMap::new();
255 error_map.insert("error".to_string(), DslValue::String(e.to_string()));
256 results.push(DslValue::Map(error_map));
257 }
258 }
259 }
260
261 Ok(DslValue::List(results))
262}
263
264pub async fn builtin_race(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
271 let registry = ctx
272 .agent_registry
273 .as_ref()
274 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
275
276 let provider = ctx
277 .provider
278 .as_ref()
279 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
280
281 let tasks = parse_parallel_args(args)?;
282
283 if tasks.is_empty() {
284 return Err(ReplError::Execution(
285 "race requires at least one task".into(),
286 ));
287 }
288
289 let sender_id = ctx.sender_agent_id.unwrap_or_default();
291 let mut checked_tasks = Vec::new();
292 for (agent_name, message) in &tasks {
293 let recipient_id = resolve_agent_id(agent_name, ctx).await?;
294 let request_id = RequestId::new();
295 check_comm_policy(
296 ctx,
297 sender_id,
298 recipient_id,
299 MessageType::Request(request_id),
300 )?;
301 checked_tasks.push((
302 agent_name.clone(),
303 message.clone(),
304 recipient_id,
305 request_id,
306 ));
307 }
308
309 let comm_bus = ctx.comm_bus.clone();
311 let mut join_set = tokio::task::JoinSet::new();
312 for (agent_name, message, recipient_id, request_id) in checked_tasks {
313 log_comm_message(
314 ctx,
315 sender_id,
316 recipient_id,
317 &message,
318 MessageType::Request(request_id),
319 Duration::from_secs(30),
320 )
321 .await;
322
323 let registry = registry.clone();
324 let provider = provider.clone();
325 let bus = comm_bus.clone();
326 join_set.spawn(async move {
327 let result = registry
328 .ask_agent(&agent_name, &message, provider.as_ref())
329 .await
330 .map_err(|e| format!("{}", e));
331
332 if let Ok(ref response) = result {
334 if let Some(ref bus) = bus {
335 let msg = bus.create_internal_message(
336 recipient_id,
337 sender_id,
338 bytes::Bytes::from(response.clone()),
339 MessageType::Response(request_id),
340 Duration::from_secs(30),
341 );
342 if let Err(e) = bus.send_message(msg).await {
343 tracing::warn!("Failed to log inter-agent response: {}", e);
344 }
345 }
346 }
347
348 result
349 });
350 }
351
352 match join_set.join_next().await {
354 Some(Ok(Ok(response))) => {
355 join_set.abort_all();
356 Ok(DslValue::String(response))
357 }
358 Some(Ok(Err(e))) => {
359 join_set.abort_all();
360 Err(ReplError::Execution(format!(
361 "race: first completed with error: {}",
362 e
363 )))
364 }
365 Some(Err(e)) => {
366 join_set.abort_all();
367 Err(ReplError::Execution(format!("race: task panic: {}", e)))
368 }
369 None => Err(ReplError::Execution("race: no tasks to run".into())),
370 }
371}
372
373pub(crate) async fn resolve_agent_id(name: &str, ctx: &ReasoningBuiltinContext) -> Result<AgentId> {
377 let registry = ctx
378 .agent_registry
379 .as_ref()
380 .ok_or_else(|| ReplError::Execution("No agent registry configured".into()))?;
381
382 registry
383 .get_agent(name)
384 .await
385 .map(|agent| agent.agent_id)
386 .ok_or_else(|| ReplError::Execution(format!("Unknown agent: {}", name)))
387}
388
389pub(crate) fn check_comm_policy(
391 ctx: &ReasoningBuiltinContext,
392 sender: AgentId,
393 recipient: AgentId,
394 message_type: MessageType,
395) -> Result<()> {
396 if let Some(policy) = &ctx.comm_policy {
397 let request = CommunicationRequest {
398 sender,
399 recipient,
400 message_type,
401 topic: None,
402 };
403 policy
404 .evaluate(&request)
405 .map_err(|e| ReplError::Execution(format!("Inter-agent communication denied: {}", e)))
406 } else {
407 Ok(())
408 }
409}
410
411pub(crate) async fn log_comm_message(
413 ctx: &ReasoningBuiltinContext,
414 sender: AgentId,
415 recipient: AgentId,
416 payload: &str,
417 message_type: MessageType,
418 ttl: Duration,
419) {
420 if let Some(bus) = &ctx.comm_bus {
421 let msg = bus.create_internal_message(
422 sender,
423 recipient,
424 bytes::Bytes::from(payload.to_string()),
425 message_type,
426 ttl,
427 );
428 if let Err(e) = bus.send_message(msg).await {
429 tracing::warn!("Failed to log inter-agent message: {}", e);
430 }
431 }
432}
433
434fn parse_spawn_args(args: &[DslValue]) -> Result<(String, String, Vec<String>, Option<String>)> {
437 match args {
438 [DslValue::Map(map)] => {
439 let name = extract_string(map, "name")?;
440 let system = extract_string(map, "system")?;
441 let tools = map
442 .get("tools")
443 .and_then(|v| match v {
444 DslValue::List(items) => Some(
445 items
446 .iter()
447 .filter_map(|i| match i {
448 DslValue::String(s) => Some(s.clone()),
449 _ => None,
450 })
451 .collect(),
452 ),
453 _ => None,
454 })
455 .unwrap_or_default();
456 let response_format = map.get("response_format").and_then(|v| match v {
457 DslValue::String(s) => Some(s.clone()),
458 _ => None,
459 });
460 Ok((name, system, tools, response_format))
461 }
462 [DslValue::String(name), DslValue::String(system)] => {
463 Ok((name.clone(), system.clone(), Vec::new(), None))
464 }
465 [DslValue::String(name), DslValue::String(system), DslValue::List(tools)] => {
466 let tool_names = tools
467 .iter()
468 .filter_map(|t| match t {
469 DslValue::String(s) => Some(s.clone()),
470 _ => None,
471 })
472 .collect();
473 Ok((name.clone(), system.clone(), tool_names, None))
474 }
475 _ => Err(ReplError::Execution(
476 "spawn_agent requires (name: string, system: string, [tools?, response_format?])"
477 .into(),
478 )),
479 }
480}
481
482fn parse_ask_args(args: &[DslValue]) -> Result<(String, String)> {
483 match args {
484 [DslValue::String(agent), DslValue::String(message)] => {
485 Ok((agent.clone(), message.clone()))
486 }
487 [DslValue::Map(map)] => {
488 let agent = extract_string(map, "agent")?;
489 let message = extract_string(map, "message")?;
490 Ok((agent, message))
491 }
492 _ => Err(ReplError::Execution(
493 "requires (agent: string, message: string)".into(),
494 )),
495 }
496}
497
498fn parse_parallel_args(args: &[DslValue]) -> Result<Vec<(String, String)>> {
499 match args {
500 [DslValue::List(items)] => {
501 let mut tasks = Vec::new();
502 for item in items {
503 match item {
504 DslValue::Map(map) => {
505 let agent = extract_string(map, "agent")?;
506 let message = extract_string(map, "message")?;
507 tasks.push((agent, message));
508 }
509 _ => {
510 return Err(ReplError::Execution(
511 "parallel/race items must be maps with {agent, message}".into(),
512 ))
513 }
514 }
515 }
516 Ok(tasks)
517 }
518 _ => Err(ReplError::Execution(
519 "parallel/race requires a list of {agent, message} maps".into(),
520 )),
521 }
522}
523
524fn extract_string(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
525 map.get(key)
526 .and_then(|v| match v {
527 DslValue::String(s) => Some(s.clone()),
528 _ => None,
529 })
530 .ok_or_else(|| ReplError::Execution(format!("Missing required string argument '{}'", key)))
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a string slice in a `DslValue::String`.
    fn str_val(text: &str) -> DslValue {
        DslValue::String(text.to_string())
    }

    /// Builds the payload of a `DslValue::Map` from `(key, value)` pairs.
    fn map_of(pairs: Vec<(&str, DslValue)>) -> HashMap<String, DslValue> {
        pairs
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    #[test]
    fn test_parse_spawn_args_named() {
        let args = [DslValue::Map(map_of(vec![
            ("name", str_val("researcher")),
            ("system", str_val("You research.")),
            ("tools", DslValue::List(vec![str_val("search")])),
        ]))];

        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "researcher");
        assert_eq!(system, "You research.");
        assert_eq!(tools, vec!["search"]);
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_positional() {
        let args = [str_val("coder"), str_val("You code.")];

        let (name, system, tools, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "coder");
        assert_eq!(system, "You code.");
        assert!(tools.is_empty());
        assert!(format.is_none());
    }

    #[test]
    fn test_parse_spawn_args_with_tools() {
        let args = [
            str_val("worker"),
            str_val("You work."),
            DslValue::List(vec![str_val("read"), str_val("write")]),
        ];

        let (name, system, tools, _) = parse_spawn_args(&args).unwrap();
        assert_eq!(name, "worker");
        assert_eq!(system, "You work.");
        assert_eq!(tools, vec!["read", "write"]);
    }

    #[test]
    fn test_parse_spawn_args_with_response_format() {
        let args = [DslValue::Map(map_of(vec![
            ("name", str_val("parser")),
            ("system", str_val("Parse data.")),
            ("response_format", str_val("json")),
        ]))];

        let (_, _, _, format) = parse_spawn_args(&args).unwrap();
        assert_eq!(format, Some("json".to_string()));
    }

    #[test]
    fn test_parse_ask_args_positional() {
        let args = [str_val("agent1"), str_val("hello")];

        let (agent, msg) = parse_ask_args(&args).unwrap();
        assert_eq!(agent, "agent1");
        assert_eq!(msg, "hello");
    }

    #[test]
    fn test_parse_ask_args_named() {
        let args = [DslValue::Map(map_of(vec![
            ("agent", str_val("bot")),
            ("message", str_val("hi")),
        ]))];

        let (agent, msg) = parse_ask_args(&args).unwrap();
        assert_eq!(agent, "bot");
        assert_eq!(msg, "hi");
    }

    #[test]
    fn test_parse_parallel_args() {
        let args = [DslValue::List(vec![
            DslValue::Map(map_of(vec![
                ("agent", str_val("a")),
                ("message", str_val("m1")),
            ])),
            DslValue::Map(map_of(vec![
                ("agent", str_val("b")),
                ("message", str_val("m2")),
            ])),
        ])];

        let tasks = parse_parallel_args(&args).unwrap();
        assert_eq!(
            tasks,
            vec![
                ("a".to_string(), "m1".to_string()),
                ("b".to_string(), "m2".to_string()),
            ]
        );
    }

    #[test]
    fn test_parse_spawn_args_missing_name() {
        let args = [DslValue::Map(HashMap::new())];
        assert!(parse_spawn_args(&args).is_err());
    }

    #[test]
    fn test_parse_ask_args_invalid() {
        let args = [DslValue::Integer(42)];
        assert!(parse_ask_args(&args).is_err());
    }

    #[test]
    fn test_parse_parallel_args_empty_list() {
        let args = [DslValue::List(Vec::new())];
        let tasks = parse_parallel_args(&args).unwrap();
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_parse_parallel_args_invalid_item() {
        let args = [DslValue::List(vec![str_val("not a map")])];
        assert!(parse_parallel_args(&args).is_err());
    }
}