1use crate::dsl::evaluator::DslValue;
7use crate::dsl::reasoning_builtins::ReasoningBuiltinContext;
8use crate::error::{ReplError, Result};
9use std::collections::HashMap;
10use std::sync::Arc;
11use symbi_runtime::reasoning::conversation::{Conversation, ConversationMessage};
12use symbi_runtime::reasoning::inference::InferenceOptions;
13
14pub async fn builtin_chain(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
24 let provider = ctx
25 .provider
26 .as_ref()
27 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
28
29 let steps = match args.first() {
30 Some(DslValue::List(steps)) => steps.clone(),
31 Some(DslValue::Map(map)) => map
32 .get("steps")
33 .and_then(|v| match v {
34 DslValue::List(l) => Some(l.clone()),
35 _ => None,
36 })
37 .ok_or_else(|| ReplError::Execution("chain requires 'steps' as a list".into()))?,
38 _ => {
39 return Err(ReplError::Execution(
40 "chain requires a list of steps".into(),
41 ))
42 }
43 };
44
45 if steps.is_empty() {
46 return Err(ReplError::Execution(
47 "chain requires at least one step".into(),
48 ));
49 }
50
51 let mut current_output = String::new();
52 let mut results = Vec::new();
53
54 for (i, step) in steps.iter().enumerate() {
55 let (system, user_template) = match step {
56 DslValue::String(prompt) => (
57 "You are a helpful assistant. Process the input and respond.".to_string(),
58 prompt.clone(),
59 ),
60 DslValue::Map(map) => {
61 let system = map
62 .get("system")
63 .and_then(|v| match v {
64 DslValue::String(s) => Some(s.clone()),
65 _ => None,
66 })
67 .unwrap_or_else(|| "You are a helpful assistant.".to_string());
68 let template = map
69 .get("prompt")
70 .and_then(|v| match v {
71 DslValue::String(s) => Some(s.clone()),
72 _ => None,
73 })
74 .unwrap_or_else(|| "Process the following input:".to_string());
75 (system, template)
76 }
77 _ => {
78 return Err(ReplError::Execution(format!(
79 "chain step {} must be a string or map",
80 i
81 )))
82 }
83 };
84
85 let user_msg = if current_output.is_empty() {
86 user_template
87 } else {
88 format!("{}\n\nPrevious output:\n{}", user_template, current_output)
89 };
90
91 let mut conv = Conversation::with_system(&system);
92 conv.push(ConversationMessage::user(&user_msg));
93
94 let response = provider
95 .complete(&conv, &InferenceOptions::default())
96 .await
97 .map_err(|e| ReplError::Execution(format!("Chain step {} failed: {}", i, e)))?;
98
99 current_output = response.content.clone();
100 results.push(DslValue::String(response.content));
101 }
102
103 let mut result_map = HashMap::new();
105 result_map.insert("output".to_string(), DslValue::String(current_output));
106 result_map.insert("steps".to_string(), DslValue::List(results));
107
108 Ok(DslValue::Map(result_map))
109}
110
111pub async fn builtin_debate(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
125 let provider = ctx
126 .provider
127 .as_ref()
128 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
129
130 let params = match args.first() {
131 Some(DslValue::Map(map)) => map.clone(),
132 _ => {
133 return Err(ReplError::Execution(
134 "debate requires named arguments as a map".into(),
135 ))
136 }
137 };
138
139 let writer_prompt = get_string_param(¶ms, "writer_prompt")?;
140 let critic_prompt = get_string_param(¶ms, "critic_prompt")?;
141 let topic = get_string_param(¶ms, "topic")?;
142 let rounds = params
143 .get("rounds")
144 .and_then(|v| match v {
145 DslValue::Integer(i) => Some(*i as u32),
146 DslValue::Number(n) => Some(*n as u32),
147 _ => None,
148 })
149 .unwrap_or(3);
150
151 let mut history = Vec::new();
152 let mut current_content = topic.clone();
153
154 for round in 0..rounds {
155 let mut writer_conv = Conversation::with_system(&writer_prompt);
157 if round == 0 {
158 writer_conv.push(ConversationMessage::user(format!(
159 "Topic: {}",
160 current_content
161 )));
162 } else {
163 writer_conv.push(ConversationMessage::user(format!(
164 "Revise your response based on this critique:\n\n{}\n\nOriginal topic: {}",
165 current_content, topic
166 )));
167 }
168
169 let writer_response = provider
170 .complete(&writer_conv, &InferenceOptions::default())
171 .await
172 .map_err(|e| {
173 ReplError::Execution(format!("Debate writer round {} failed: {}", round, e))
174 })?;
175
176 let mut round_map = HashMap::new();
177 round_map.insert("round".to_string(), DslValue::Integer(round as i64 + 1));
178 round_map.insert(
179 "writer".to_string(),
180 DslValue::String(writer_response.content.clone()),
181 );
182
183 let mut critic_conv = Conversation::with_system(&critic_prompt);
185 critic_conv.push(ConversationMessage::user(format!(
186 "Evaluate the following response:\n\n{}",
187 writer_response.content
188 )));
189
190 let critic_response = provider
191 .complete(&critic_conv, &InferenceOptions::default())
192 .await
193 .map_err(|e| {
194 ReplError::Execution(format!("Debate critic round {} failed: {}", round, e))
195 })?;
196
197 round_map.insert(
198 "critic".to_string(),
199 DslValue::String(critic_response.content.clone()),
200 );
201 history.push(DslValue::Map(round_map));
202
203 current_content = critic_response.content;
204 }
205
206 let mut final_conv = Conversation::with_system(&writer_prompt);
208 final_conv.push(ConversationMessage::user(format!(
209 "Provide your final, refined response incorporating all critiques.\n\nLatest critique: {}\n\nOriginal topic: {}",
210 current_content, topic
211 )));
212
213 let final_response = provider
214 .complete(&final_conv, &InferenceOptions::default())
215 .await
216 .map_err(|e| ReplError::Execution(format!("Debate final response failed: {}", e)))?;
217
218 let mut result = HashMap::new();
219 result.insert(
220 "final_answer".to_string(),
221 DslValue::String(final_response.content),
222 );
223 result.insert(
224 "rounds_completed".to_string(),
225 DslValue::Integer(rounds as i64),
226 );
227 result.insert("history".to_string(), DslValue::List(history));
228
229 Ok(DslValue::Map(result))
230}
231
232pub async fn builtin_map_reduce(
241 args: &[DslValue],
242 ctx: &ReasoningBuiltinContext,
243) -> Result<DslValue> {
244 let provider = ctx
245 .provider
246 .as_ref()
247 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
248
249 let params = match args.first() {
250 Some(DslValue::Map(map)) => map.clone(),
251 _ => {
252 return Err(ReplError::Execution(
253 "map_reduce requires named arguments as a map".into(),
254 ))
255 }
256 };
257
258 let inputs = match params.get("inputs") {
259 Some(DslValue::List(items)) => items.clone(),
260 _ => {
261 return Err(ReplError::Execution(
262 "map_reduce requires 'inputs' as a list".into(),
263 ))
264 }
265 };
266 let mapper_prompt = get_string_param(¶ms, "mapper")?;
267 let reducer_prompt = get_string_param(¶ms, "reducer")?;
268
269 let mut map_futures = Vec::new();
271 for input in &inputs {
272 let input_str = match input {
273 DslValue::String(s) => s.clone(),
274 other => format!("{:?}", other),
275 };
276 let provider = Arc::clone(provider);
277 let mapper_prompt = mapper_prompt.clone();
278
279 map_futures.push(async move {
280 let mut conv = Conversation::with_system(&mapper_prompt);
281 conv.push(ConversationMessage::user(&input_str));
282 provider
283 .complete(&conv, &InferenceOptions::default())
284 .await
285 .map(|r| r.content)
286 .map_err(|e| ReplError::Execution(format!("Map failed: {}", e)))
287 });
288 }
289
290 let mapped_results: Vec<String> = futures::future::try_join_all(map_futures).await?;
291
292 let combined = mapped_results
294 .iter()
295 .enumerate()
296 .map(|(i, r)| format!("Result {}: {}", i + 1, r))
297 .collect::<Vec<_>>()
298 .join("\n\n");
299
300 let mut reduce_conv = Conversation::with_system(&reducer_prompt);
301 reduce_conv.push(ConversationMessage::user(format!(
302 "Aggregate the following results:\n\n{}",
303 combined
304 )));
305
306 let reduce_response = provider
307 .complete(&reduce_conv, &InferenceOptions::default())
308 .await
309 .map_err(|e| ReplError::Execution(format!("Reduce failed: {}", e)))?;
310
311 let mut result = HashMap::new();
312 result.insert(
313 "result".to_string(),
314 DslValue::String(reduce_response.content),
315 );
316 result.insert(
317 "mapped_results".to_string(),
318 DslValue::List(mapped_results.into_iter().map(DslValue::String).collect()),
319 );
320
321 Ok(DslValue::Map(result))
322}
323
324pub async fn builtin_director(
333 args: &[DslValue],
334 ctx: &ReasoningBuiltinContext,
335) -> Result<DslValue> {
336 let provider = ctx
337 .provider
338 .as_ref()
339 .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
340
341 let params = match args.first() {
342 Some(DslValue::Map(map)) => map.clone(),
343 _ => {
344 return Err(ReplError::Execution(
345 "director requires named arguments as a map".into(),
346 ))
347 }
348 };
349
350 let orchestrator_prompt = get_string_param(¶ms, "orchestrator_prompt")?;
351 let task = get_string_param(¶ms, "task")?;
352
353 let workers = match params.get("workers") {
354 Some(DslValue::List(items)) => items.clone(),
355 _ => {
356 return Err(ReplError::Execution(
357 "director requires 'workers' as a list".into(),
358 ))
359 }
360 };
361
362 let worker_defs: Vec<(String, String)> = workers
364 .iter()
365 .map(|w| match w {
366 DslValue::Map(map) => {
367 let name = map
368 .get("name")
369 .and_then(|v| match v {
370 DslValue::String(s) => Some(s.clone()),
371 _ => None,
372 })
373 .unwrap_or_else(|| "worker".to_string());
374 let system = map
375 .get("system")
376 .and_then(|v| match v {
377 DslValue::String(s) => Some(s.clone()),
378 _ => None,
379 })
380 .unwrap_or_else(|| "You are a helpful assistant.".to_string());
381 Ok((name, system))
382 }
383 _ => Err(ReplError::Execution(
384 "Each worker must be a map with 'name' and 'system'".into(),
385 )),
386 })
387 .collect::<Result<Vec<_>>>()?;
388
389 let worker_names: Vec<String> = worker_defs.iter().map(|(n, _)| n.clone()).collect();
391 let mut plan_conv = Conversation::with_system(&orchestrator_prompt);
392 plan_conv.push(ConversationMessage::user(format!(
393 "Task: {}\n\nAvailable workers: {}\n\nCreate a plan assigning subtasks to each worker. Respond with a JSON object like: {{\"assignments\": [{{\"worker\": \"name\", \"subtask\": \"description\"}}]}}",
394 task,
395 worker_names.join(", ")
396 )));
397
398 let plan_options = InferenceOptions {
399 response_format: symbi_runtime::reasoning::inference::ResponseFormat::JsonObject,
400 ..Default::default()
401 };
402
403 let plan_response = provider
404 .complete(&plan_conv, &plan_options)
405 .await
406 .map_err(|e| ReplError::Execution(format!("Director planning failed: {}", e)))?;
407
408 let plan_text = plan_response.content.clone();
409
410 let assignments = parse_assignments(&plan_text, &worker_defs);
412
413 let mut worker_results = Vec::new();
415 for (worker_name, worker_system, subtask) in &assignments {
416 let mut worker_conv = Conversation::with_system(worker_system);
417 worker_conv.push(ConversationMessage::user(subtask));
418
419 let response = provider
420 .complete(&worker_conv, &InferenceOptions::default())
421 .await
422 .map_err(|e| ReplError::Execution(format!("Worker '{}' failed: {}", worker_name, e)))?;
423
424 let mut r = HashMap::new();
425 r.insert("worker".to_string(), DslValue::String(worker_name.clone()));
426 r.insert("subtask".to_string(), DslValue::String(subtask.clone()));
427 r.insert("result".to_string(), DslValue::String(response.content));
428 worker_results.push(DslValue::Map(r));
429 }
430
431 let results_summary = worker_results
433 .iter()
434 .map(|r| match r {
435 DslValue::Map(m) => {
436 let worker = m
437 .get("worker")
438 .and_then(|v| match v {
439 DslValue::String(s) => Some(s.as_str()),
440 _ => None,
441 })
442 .unwrap_or("unknown");
443 let result = m
444 .get("result")
445 .and_then(|v| match v {
446 DslValue::String(s) => Some(s.as_str()),
447 _ => None,
448 })
449 .unwrap_or("");
450 format!("Worker '{}': {}", worker, result)
451 }
452 _ => String::new(),
453 })
454 .collect::<Vec<_>>()
455 .join("\n\n");
456
457 let mut synth_conv = Conversation::with_system(&orchestrator_prompt);
458 synth_conv.push(ConversationMessage::user(format!(
459 "Synthesize the following worker results into a final answer:\n\n{}\n\nOriginal task: {}",
460 results_summary, task
461 )));
462
463 let synth_response = provider
464 .complete(&synth_conv, &InferenceOptions::default())
465 .await
466 .map_err(|e| ReplError::Execution(format!("Director synthesis failed: {}", e)))?;
467
468 let mut result = HashMap::new();
469 result.insert(
470 "result".to_string(),
471 DslValue::String(synth_response.content),
472 );
473 result.insert("plan".to_string(), DslValue::String(plan_text));
474 result.insert("worker_results".to_string(), DslValue::List(worker_results));
475
476 Ok(DslValue::Map(result))
477}
478
479fn get_string_param(map: &HashMap<String, DslValue>, key: &str) -> Result<String> {
482 map.get(key)
483 .and_then(|v| match v {
484 DslValue::String(s) => Some(s.clone()),
485 _ => None,
486 })
487 .ok_or_else(|| ReplError::Execution(format!("Missing required parameter '{}'", key)))
488}
489
490fn parse_assignments(
491 plan_text: &str,
492 worker_defs: &[(String, String)],
493) -> Vec<(String, String, String)> {
494 if let Ok(plan_json) = serde_json::from_str::<serde_json::Value>(plan_text) {
496 if let Some(assignments) = plan_json["assignments"].as_array() {
497 return assignments
498 .iter()
499 .filter_map(|a| {
500 let worker = a["worker"].as_str()?;
501 let subtask = a["subtask"].as_str()?;
502 let system = worker_defs
503 .iter()
504 .find(|(n, _)| n == worker)
505 .map(|(_, s)| s.clone())
506 .unwrap_or_else(|| "You are a helpful assistant.".to_string());
507 Some((worker.to_string(), system, subtask.to_string()))
508 })
509 .collect();
510 }
511 }
512
513 worker_defs
515 .iter()
516 .map(|(name, system)| {
517 (
518 name.clone(),
519 system.clone(),
520 format!("Complete this task: {}", plan_text),
521 )
522 })
523 .collect()
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts `(name, system_prompt)` string pairs into owned worker definitions.
    fn worker_defs(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(name, system)| (name.to_owned(), system.to_owned()))
            .collect()
    }

    /// Builds a parameter map holding a single `"key"` entry.
    fn single_entry(value: DslValue) -> HashMap<String, DslValue> {
        let mut params = HashMap::new();
        params.insert("key".to_string(), value);
        params
    }

    #[test]
    fn test_parse_assignments_valid_json() {
        let plan = r#"{"assignments": [{"worker": "researcher", "subtask": "Find data"}, {"worker": "writer", "subtask": "Write report"}]}"#;
        let defs = worker_defs(&[
            ("researcher", "Research system"),
            ("writer", "Writer system"),
        ]);

        let assignments = parse_assignments(plan, &defs);
        let worker_and_subtask: Vec<(&str, &str)> = assignments
            .iter()
            .map(|(worker, _, subtask)| (worker.as_str(), subtask.as_str()))
            .collect();

        assert_eq!(
            worker_and_subtask,
            [("researcher", "Find data"), ("writer", "Write report")]
        );
    }

    #[test]
    fn test_parse_assignments_fallback() {
        let defs = worker_defs(&[("a", "System A"), ("b", "System B")]);

        let assignments = parse_assignments("This is not JSON", &defs);

        assert_eq!(assignments.len(), 2);
        assert!(assignments[0].2.contains("This is not JSON"));
    }

    #[test]
    fn test_get_string_param() {
        let params = single_entry(DslValue::String("value".into()));

        assert_eq!(get_string_param(&params, "key").unwrap(), "value");
        assert!(get_string_param(&params, "missing").is_err());
    }

    #[test]
    fn test_get_string_param_wrong_type() {
        let params = single_entry(DslValue::Integer(42));

        assert!(get_string_param(&params, "key").is_err());
    }
}