Skip to main content

repl_core/dsl/
pattern_builtins.rs

1//! Multi-agent pattern builtins for the DSL
2//!
3//! Provides convenience builtins for common multi-agent patterns:
4//! `chain`, `debate`, `map_reduce`, and `director`.
5
6use crate::dsl::evaluator::DslValue;
7use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
8use crate::error::{ReplError, Result};
9use std::collections::HashMap;
10use std::sync::Arc;
11use symbi_runtime::reasoning::conversation::{Conversation, ConversationMessage};
12use symbi_runtime::reasoning::inference::InferenceOptions;
13
14/// Execute the `chain` builtin: sequential execution where each step's
15/// output feeds into the next step's input.
16///
17/// Arguments:
18/// - steps: list of maps, each with keys: system, prompt_template (optional)
19///   OR
20/// - steps: list of strings (prompts, executed sequentially)
21///
22/// Returns the final step's output as a string.
23pub async fn builtin_chain(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
24    let provider = ctx
25        .provider
26        .as_ref()
27        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
28
29    let steps = match args.first() {
30        Some(DslValue::List(steps)) => steps.clone(),
31        Some(DslValue::Map(map)) => map
32            .get("steps")
33            .and_then(|v| match v {
34                DslValue::List(l) => Some(l.clone()),
35                _ => None,
36            })
37            .ok_or_else(|| ReplError::Execution("chain requires 'steps' as a list".into()))?,
38        _ => {
39            return Err(ReplError::Execution(
40                "chain requires a list of steps".into(),
41            ))
42        }
43    };
44
45    if steps.is_empty() {
46        return Err(ReplError::Execution(
47            "chain requires at least one step".into(),
48        ));
49    }
50
51    let mut current_output = String::new();
52    let mut results = Vec::new();
53
54    for (i, step) in steps.iter().enumerate() {
55        let (system, user_template) = match step {
56            DslValue::String(prompt) => (
57                "You are a helpful assistant. Process the input and respond.".to_string(),
58                prompt.clone(),
59            ),
60            DslValue::Map(map) => {
61                let system = map
62                    .get("system")
63                    .and_then(|v| match v {
64                        DslValue::String(s) => Some(s.clone()),
65                        _ => None,
66                    })
67                    .unwrap_or_else(|| "You are a helpful assistant.".to_string());
68                let template = map
69                    .get("prompt")
70                    .and_then(|v| match v {
71                        DslValue::String(s) => Some(s.clone()),
72                        _ => None,
73                    })
74                    .unwrap_or_else(|| "Process the following input:".to_string());
75                (system, template)
76            }
77            _ => {
78                return Err(ReplError::Execution(format!(
79                    "chain step {} must be a string or map",
80                    i
81                )))
82            }
83        };
84
85        let user_msg = if current_output.is_empty() {
86            user_template
87        } else {
88            format!("{}\n\nPrevious output:\n{}", user_template, current_output)
89        };
90
91        let mut conv = Conversation::with_system(&system);
92        conv.push(ConversationMessage::user(&user_msg));
93
94        let response = provider
95            .complete(&conv, &InferenceOptions::default())
96            .await
97            .map_err(|e| ReplError::Execution(format!("Chain step {} failed: {}", i, e)))?;
98
99        current_output = response.content.clone();
100        results.push(DslValue::String(response.content));
101    }
102
103    // Return a map with the final output and all intermediate results
104    let mut result_map = HashMap::new();
105    result_map.insert("output".to_string(), DslValue::String(current_output));
106    result_map.insert("steps".to_string(), DslValue::List(results));
107
108    Ok(DslValue::Map(result_map))
109}
110
111/// Execute the `debate` builtin: alternating two-agent critique with
112/// optional model routing and adaptive convergence.
113///
114/// Arguments (named via map):
115/// - writer_prompt: string — system prompt for the writer
116/// - critic_prompt: string — system prompt for the critic
117/// - topic: string — initial topic/content to debate
118/// - rounds: integer (optional, default 3)
119/// - writer_model: string (optional) — model for writer
120/// - critic_model: string (optional) — model for critic
121/// - convergence: string (optional) — "fixed" or "adaptive"
122///
123/// Returns a map with keys: final_answer, rounds_completed, history.
124pub async fn builtin_debate(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
125    let provider = ctx
126        .provider
127        .as_ref()
128        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
129
130    let params = match args.first() {
131        Some(DslValue::Map(map)) => map.clone(),
132        _ => {
133            return Err(ReplError::Execution(
134                "debate requires named arguments as a map".into(),
135            ))
136        }
137    };
138
139    let writer_prompt = get_string_param(&params, "writer_prompt")?;
140    let critic_prompt = get_string_param(&params, "critic_prompt")?;
141    let topic = get_string_param(&params, "topic")?;
142    let rounds = params
143        .get("rounds")
144        .and_then(|v| match v {
145            DslValue::Integer(i) => Some(*i as u32),
146            DslValue::Number(n) => Some(*n as u32),
147            _ => None,
148        })
149        .unwrap_or(3);
150
151    let mut history = Vec::new();
152    let mut current_content = topic.clone();
153
154    for round in 0..rounds {
155        // Writer phase
156        let mut writer_conv = Conversation::with_system(&writer_prompt);
157        if round == 0 {
158            writer_conv.push(ConversationMessage::user(format!(
159                "Topic: {}",
160                current_content
161            )));
162        } else {
163            writer_conv.push(ConversationMessage::user(format!(
164                "Revise your response based on this critique:\n\n{}\n\nOriginal topic: {}",
165                current_content, topic
166            )));
167        }
168
169        let writer_response = provider
170            .complete(&writer_conv, &InferenceOptions::default())
171            .await
172            .map_err(|e| {
173                ReplError::Execution(format!("Debate writer round {} failed: {}", round, e))
174            })?;
175
176        let mut round_map = HashMap::new();
177        round_map.insert("round".to_string(), DslValue::Integer(round as i64 + 1));
178        round_map.insert(
179            "writer".to_string(),
180            DslValue::String(writer_response.content.clone()),
181        );
182
183        // Critic phase
184        let mut critic_conv = Conversation::with_system(&critic_prompt);
185        critic_conv.push(ConversationMessage::user(format!(
186            "Evaluate the following response:\n\n{}",
187            writer_response.content
188        )));
189
190        let critic_response = provider
191            .complete(&critic_conv, &InferenceOptions::default())
192            .await
193            .map_err(|e| {
194                ReplError::Execution(format!("Debate critic round {} failed: {}", round, e))
195            })?;
196
197        round_map.insert(
198            "critic".to_string(),
199            DslValue::String(critic_response.content.clone()),
200        );
201        history.push(DslValue::Map(round_map));
202
203        current_content = critic_response.content;
204    }
205
206    // Final writer response incorporating all critique
207    let mut final_conv = Conversation::with_system(&writer_prompt);
208    final_conv.push(ConversationMessage::user(format!(
209        "Provide your final, refined response incorporating all critiques.\n\nLatest critique: {}\n\nOriginal topic: {}",
210        current_content, topic
211    )));
212
213    let final_response = provider
214        .complete(&final_conv, &InferenceOptions::default())
215        .await
216        .map_err(|e| ReplError::Execution(format!("Debate final response failed: {}", e)))?;
217
218    let mut result = HashMap::new();
219    result.insert(
220        "final_answer".to_string(),
221        DslValue::String(final_response.content),
222    );
223    result.insert(
224        "rounds_completed".to_string(),
225        DslValue::Integer(rounds as i64),
226    );
227    result.insert("history".to_string(), DslValue::List(history));
228
229    Ok(DslValue::Map(result))
230}
231
232/// Execute the `map_reduce` builtin: parallel fan-out + aggregate.
233///
234/// Arguments (named via map):
235/// - inputs: list — items to process
236/// - mapper: string — system prompt for the mapper
237/// - reducer: string — system prompt for the reducer
238///
239/// Returns a map with keys: result, mapped_results.
240pub async fn builtin_map_reduce(
241    args: &[DslValue],
242    ctx: &ReasoningBuiltinContext,
243) -> Result<DslValue> {
244    let provider = ctx
245        .provider
246        .as_ref()
247        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
248
249    let params = match args.first() {
250        Some(DslValue::Map(map)) => map.clone(),
251        _ => {
252            return Err(ReplError::Execution(
253                "map_reduce requires named arguments as a map".into(),
254            ))
255        }
256    };
257
258    let inputs = match params.get("inputs") {
259        Some(DslValue::List(items)) => items.clone(),
260        _ => {
261            return Err(ReplError::Execution(
262                "map_reduce requires 'inputs' as a list".into(),
263            ))
264        }
265    };
266    let mapper_prompt = get_string_param(&params, "mapper")?;
267    let reducer_prompt = get_string_param(&params, "reducer")?;
268
269    // Map phase: process each input concurrently
270    let mut map_futures = Vec::new();
271    for input in &inputs {
272        let input_str = match input {
273            DslValue::String(s) => s.clone(),
274            other => format!("{:?}", other),
275        };
276        let provider = Arc::clone(provider);
277        let mapper_prompt = mapper_prompt.clone();
278
279        map_futures.push(async move {
280            let mut conv = Conversation::with_system(&mapper_prompt);
281            conv.push(ConversationMessage::user(&input_str));
282            provider
283                .complete(&conv, &InferenceOptions::default())
284                .await
285                .map(|r| r.content)
286                .map_err(|e| ReplError::Execution(format!("Map failed: {}", e)))
287        });
288    }
289
290    let mapped_results: Vec<String> = futures::future::try_join_all(map_futures).await?;
291
292    // Reduce phase: aggregate all mapped results
293    let combined = mapped_results
294        .iter()
295        .enumerate()
296        .map(|(i, r)| format!("Result {}: {}", i + 1, r))
297        .collect::<Vec<_>>()
298        .join("\n\n");
299
300    let mut reduce_conv = Conversation::with_system(&reducer_prompt);
301    reduce_conv.push(ConversationMessage::user(format!(
302        "Aggregate the following results:\n\n{}",
303        combined
304    )));
305
306    let reduce_response = provider
307        .complete(&reduce_conv, &InferenceOptions::default())
308        .await
309        .map_err(|e| ReplError::Execution(format!("Reduce failed: {}", e)))?;
310
311    let mut result = HashMap::new();
312    result.insert(
313        "result".to_string(),
314        DslValue::String(reduce_response.content),
315    );
316    result.insert(
317        "mapped_results".to_string(),
318        DslValue::List(mapped_results.into_iter().map(DslValue::String).collect()),
319    );
320
321    Ok(DslValue::Map(result))
322}
323
324/// Execute the `director` builtin: decompose + delegate + synthesize.
325///
326/// Arguments (named via map):
327/// - orchestrator_prompt: string — system prompt for the director
328/// - workers: list of maps with {name, system} — worker agent definitions
329/// - task: string — the task to accomplish
330///
331/// Returns a map with keys: result, plan, worker_results.
332pub async fn builtin_director(
333    args: &[DslValue],
334    ctx: &ReasoningBuiltinContext,
335) -> Result<DslValue> {
336    let provider = ctx
337        .provider
338        .as_ref()
339        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
340
341    let params = match args.first() {
342        Some(DslValue::Map(map)) => map.clone(),
343        _ => {
344            return Err(ReplError::Execution(
345                "director requires named arguments as a map".into(),
346            ))
347        }
348    };
349
350    let orchestrator_prompt = get_string_param(&params, "orchestrator_prompt")?;
351    let task = get_string_param(&params, "task")?;
352
353    let workers = match params.get("workers") {
354        Some(DslValue::List(items)) => items.clone(),
355        _ => {
356            return Err(ReplError::Execution(
357                "director requires 'workers' as a list".into(),
358            ))
359        }
360    };
361
362    // Parse worker definitions
363    let worker_defs: Vec<(String, String)> = workers
364        .iter()
365        .map(|w| match w {
366            DslValue::Map(map) => {
367                let name = map
368                    .get("name")
369                    .and_then(|v| match v {
370                        DslValue::String(s) => Some(s.clone()),
371                        _ => None,
372                    })
373                    .unwrap_or_else(|| "worker".to_string());
374                let system = map
375                    .get("system")
376                    .and_then(|v| match v {
377                        DslValue::String(s) => Some(s.clone()),
378                        _ => None,
379                    })
380                    .unwrap_or_else(|| "You are a helpful assistant.".to_string());
381                Ok((name, system))
382            }
383            _ => Err(ReplError::Execution(
384                "Each worker must be a map with 'name' and 'system'".into(),
385            )),
386        })
387        .collect::<Result<Vec<_>>>()?;
388
389    // Step 1: Director creates a plan
390    let worker_names: Vec<String> = worker_defs.iter().map(|(n, _)| n.clone()).collect();
391    let mut plan_conv = Conversation::with_system(&orchestrator_prompt);
392    plan_conv.push(ConversationMessage::user(format!(
393        "Task: {}\n\nAvailable workers: {}\n\nCreate a plan assigning subtasks to each worker. Respond with a JSON object like: {{\"assignments\": [{{\"worker\": \"name\", \"subtask\": \"description\"}}]}}",
394        task,
395        worker_names.join(", ")
396    )));
397
398    let plan_options = InferenceOptions {
399        response_format: symbi_runtime::reasoning::inference::ResponseFormat::JsonObject,
400        ..Default::default()
401    };
402
403    let plan_response = provider
404        .complete(&plan_conv, &plan_options)
405        .await
406        .map_err(|e| ReplError::Execution(format!("Director planning failed: {}", e)))?;
407
408    let plan_text = plan_response.content.clone();
409
410    // Parse assignments
411    let assignments = parse_assignments(&plan_text, &worker_defs);
412
413    // Step 2: Execute worker subtasks
414    let mut worker_results = Vec::new();
415    for (worker_name, worker_system, subtask) in &assignments {
416        let mut worker_conv = Conversation::with_system(worker_system);
417        worker_conv.push(ConversationMessage::user(subtask));
418
419        let response = provider
420            .complete(&worker_conv, &InferenceOptions::default())
421            .await
422            .map_err(|e| ReplError::Execution(format!("Worker '{}' failed: {}", worker_name, e)))?;
423
424        let mut r = HashMap::new();
425        r.insert("worker".to_string(), DslValue::String(worker_name.clone()));
426        r.insert("subtask".to_string(), DslValue::String(subtask.clone()));
427        r.insert("result".to_string(), DslValue::String(response.content));
428        worker_results.push(DslValue::Map(r));
429    }
430
431    // Step 3: Director synthesizes results
432    let results_summary = worker_results
433        .iter()
434        .map(|r| match r {
435            DslValue::Map(m) => {
436                let worker = m
437                    .get("worker")
438                    .and_then(|v| match v {
439                        DslValue::String(s) => Some(s.as_str()),
440                        _ => None,
441                    })
442                    .unwrap_or("unknown");
443                let result = m
444                    .get("result")
445                    .and_then(|v| match v {
446                        DslValue::String(s) => Some(s.as_str()),
447                        _ => None,
448                    })
449                    .unwrap_or("");
450                format!("Worker '{}': {}", worker, result)
451            }
452            _ => String::new(),
453        })
454        .collect::<Vec<_>>()
455        .join("\n\n");
456
457    let mut synth_conv = Conversation::with_system(&orchestrator_prompt);
458    synth_conv.push(ConversationMessage::user(format!(
459        "Synthesize the following worker results into a final answer:\n\n{}\n\nOriginal task: {}",
460        results_summary, task
461    )));
462
463    let synth_response = provider
464        .complete(&synth_conv, &InferenceOptions::default())
465        .await
466        .map_err(|e| ReplError::Execution(format!("Director synthesis failed: {}", e)))?;
467
468    let mut result = HashMap::new();
469    result.insert(
470        "result".to_string(),
471        DslValue::String(synth_response.content),
472    );
473    result.insert("plan".to_string(), DslValue::String(plan_text));
474    result.insert("worker_results".to_string(), DslValue::List(worker_results));
475
476    Ok(DslValue::Map(result))
477}
478
479// --- Helpers ---
480
481fn get_string_param(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
482    map.get(key)
483        .and_then(|v| match v {
484            DslValue::String(s) => Some(s.clone()),
485            _ => None,
486        })
487        .ok_or_else(|| ReplError::Execution(format!("Missing required parameter '{}'", key)))
488}
489
490fn parse_assignments(
491    plan_text: &str,
492    worker_defs: &[(String, String)],
493) -> Vec<(String, String, String)> {
494    // Try to parse as JSON
495    if let Ok(plan_json) = serde_json::from_str::<serde_json::Value>(plan_text) {
496        if let Some(assignments) = plan_json["assignments"].as_array() {
497            return assignments
498                .iter()
499                .filter_map(|a| {
500                    let worker = a["worker"].as_str()?;
501                    let subtask = a["subtask"].as_str()?;
502                    let system = worker_defs
503                        .iter()
504                        .find(|(n, _)| n == worker)
505                        .map(|(_, s)| s.clone())
506                        .unwrap_or_else(|| "You are a helpful assistant.".to_string());
507                    Some((worker.to_string(), system, subtask.to_string()))
508                })
509                .collect();
510        }
511    }
512
513    // Fallback: assign the entire task to each worker
514    worker_defs
515        .iter()
516        .map(|(name, system)| {
517            (
518                name.clone(),
519                system.clone(),
520                format!("Complete this task: {}", plan_text),
521            )
522        })
523        .collect()
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned `(name, system)` worker definitions built from borrowed pairs.
    fn worker_defs(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(name, system)| (name.to_string(), system.to_string()))
            .collect()
    }

    /// A parameter map holding exactly one entry.
    fn single_param(key: &str, value: DslValue) -> HashMap<String, DslValue> {
        let mut params = HashMap::new();
        params.insert(key.to_string(), value);
        params
    }

    #[test]
    fn test_parse_assignments_valid_json() {
        let defs = worker_defs(&[
            ("researcher", "Research system"),
            ("writer", "Writer system"),
        ]);
        let plan = r#"{"assignments": [{"worker": "researcher", "subtask": "Find data"}, {"worker": "writer", "subtask": "Write report"}]}"#;

        let assignments = parse_assignments(plan, &defs);
        let routed: Vec<(&str, &str)> = assignments
            .iter()
            .map(|(worker, _, subtask)| (worker.as_str(), subtask.as_str()))
            .collect();

        assert_eq!(
            routed,
            vec![("researcher", "Find data"), ("writer", "Write report")]
        );
    }

    #[test]
    fn test_parse_assignments_fallback() {
        let defs = worker_defs(&[("a", "System A"), ("b", "System B")]);

        let assignments = parse_assignments("This is not JSON", &defs);

        assert_eq!(assignments.len(), 2);
        let (_, _, first_task) = &assignments[0];
        assert!(first_task.contains("This is not JSON"));
    }

    #[test]
    fn test_get_string_param() {
        let params = single_param("key", DslValue::String("value".into()));

        assert_eq!(get_string_param(&params, "key").unwrap(), "value");
        assert!(get_string_param(&params, "missing").is_err());
    }

    #[test]
    fn test_get_string_param_wrong_type() {
        let params = single_param("key", DslValue::Integer(42));

        assert!(get_string_param(&params, "key").is_err());
    }
}