Skip to main content

repl_core/dsl/
reasoning_builtins.rs

//! Core reasoning builtins for the DSL
//!
//! Provides the builtin functions (all async except `parse_json`) that bridge
//! the DSL with the reasoning loop infrastructure: `reason`, `llm_call`,
//! `parse_json`, `delegate`, and `tool_call`.
6
7use crate::dsl::agent_composition::{check_comm_policy, log_comm_message};
8use crate::dsl::evaluator::DslValue;
9use crate::error::{ReplError, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use symbi_runtime::communication::policy_gate::CommunicationPolicyGate;
14use symbi_runtime::communication::CommunicationBus;
15use symbi_runtime::reasoning::agent_registry::AgentRegistry;
16use symbi_runtime::reasoning::inference::InferenceProvider;
17use symbi_runtime::reasoning::policy_bridge::ReasoningPolicyGate;
18use symbi_runtime::types::{AgentId, MessageType, RequestId};
19
20/// Shared state for async reasoning builtins.
21#[derive(Clone, Default)]
22pub struct ReasoningBuiltinContext {
23    /// Inference provider for LLM calls.
24    pub provider: Option<Arc<dyn InferenceProvider>>,
25    /// Agent registry for multi-agent composition.
26    pub agent_registry: Option<Arc<AgentRegistry>>,
27    /// The AgentId of the calling agent (for communication tracking).
28    pub sender_agent_id: Option<AgentId>,
29    /// Communication bus for message tracking and audit.
30    pub comm_bus: Option<Arc<dyn CommunicationBus + Send + Sync>>,
31    /// Communication policy gate for inter-agent authorization.
32    pub comm_policy: Option<Arc<CommunicationPolicyGate>>,
33    /// Policy gate for the reasoning loop — governs tool calls and
34    /// delegations inside `reason()`. If None, `DefaultPolicyGate::new()`
35    /// is used (production-default, non-permissive). Production callers
36    /// should install [`OpaPolicyGateBridge`] or another concrete gate
37    /// instead of relying on the default.
38    pub reasoning_policy_gate: Option<Arc<dyn ReasoningPolicyGate + Send + Sync>>,
39}
40
41/// Execute the `reason` builtin: runs a full reasoning loop.
42///
43/// Arguments (positional or named):
44/// - system: string — system prompt
45/// - user: string — user message
46/// - max_iterations: integer (optional, default 10)
47/// - max_tokens: integer (optional, default 100000)
48///
49/// Returns a map with keys: response, iterations, total_tokens, termination_reason.
50pub async fn builtin_reason(args: &[DslValue], ctx: &ReasoningBuiltinContext) -> Result<DslValue> {
51    let provider = ctx
52        .provider
53        .as_ref()
54        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
55
56    let (system, user, max_iterations, max_tokens) = parse_reason_args(args)?;
57
58    use symbi_runtime::reasoning::circuit_breaker::CircuitBreakerRegistry;
59    use symbi_runtime::reasoning::context_manager::DefaultContextManager;
60    use symbi_runtime::reasoning::conversation::{Conversation, ConversationMessage};
61    use symbi_runtime::reasoning::executor::DefaultActionExecutor;
62    use symbi_runtime::reasoning::loop_types::{BufferedJournal, LoopConfig};
63    use symbi_runtime::reasoning::policy_bridge::DefaultPolicyGate;
64    use symbi_runtime::reasoning::reasoning_loop::ReasoningLoopRunner;
65
66    // Prefer a caller-provided policy gate (e.g. OpaPolicyGateBridge wired
67    // from the runtime). Fall back to the non-permissive default rather
68    // than `DefaultPolicyGate::permissive()` so `reason()` no longer opts
69    // every DSL program into unrestricted tool calls regardless of how
70    // the runtime was configured.
71    let policy_gate: Arc<dyn ReasoningPolicyGate + Send + Sync> =
72        match ctx.reasoning_policy_gate.clone() {
73            Some(gate) => gate,
74            None => Arc::new(DefaultPolicyGate::new()),
75        };
76
77    let runner = ReasoningLoopRunner {
78        provider: Arc::clone(provider),
79        policy_gate,
80        executor: Arc::new(DefaultActionExecutor::default()),
81        context_manager: Arc::new(DefaultContextManager::default()),
82        circuit_breakers: Arc::new(CircuitBreakerRegistry::default()),
83        journal: Arc::new(BufferedJournal::new(1000)),
84        knowledge_bridge: None,
85    };
86
87    let mut conv = Conversation::with_system(&system);
88    conv.push(ConversationMessage::user(&user));
89
90    let config = LoopConfig {
91        max_iterations,
92        max_total_tokens: max_tokens,
93        ..Default::default()
94    };
95
96    let result = runner.run(AgentId::new(), conv, config).await;
97
98    let mut map = HashMap::new();
99    map.insert("response".to_string(), DslValue::String(result.output));
100    map.insert(
101        "iterations".to_string(),
102        DslValue::Integer(result.iterations as i64),
103    );
104    map.insert(
105        "total_tokens".to_string(),
106        DslValue::Integer(result.total_usage.total_tokens as i64),
107    );
108    map.insert(
109        "termination_reason".to_string(),
110        DslValue::String(format!("{:?}", result.termination_reason)),
111    );
112
113    Ok(DslValue::Map(map))
114}
115
116/// Execute the `llm_call` builtin: one-shot LLM call.
117///
118/// Arguments:
119/// - prompt: string — the prompt to send
120/// - model: string (optional) — model override
121/// - temperature: number (optional)
122/// - max_tokens: integer (optional)
123///
124/// Returns a string.
125pub async fn builtin_llm_call(
126    args: &[DslValue],
127    ctx: &ReasoningBuiltinContext,
128) -> Result<DslValue> {
129    let provider = ctx
130        .provider
131        .as_ref()
132        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
133
134    let prompt = match args.first() {
135        Some(DslValue::String(s)) => s.clone(),
136        Some(DslValue::Map(map)) => map
137            .get("prompt")
138            .and_then(|v| match v {
139                DslValue::String(s) => Some(s.clone()),
140                _ => None,
141            })
142            .ok_or_else(|| ReplError::Execution("llm_call requires 'prompt' argument".into()))?,
143        _ => {
144            return Err(ReplError::Execution(
145                "llm_call requires a string prompt".into(),
146            ))
147        }
148    };
149
150    use symbi_runtime::reasoning::conversation::{Conversation, ConversationMessage};
151    use symbi_runtime::reasoning::inference::InferenceOptions;
152
153    let mut conv = Conversation::new();
154    conv.push(ConversationMessage::user(&prompt));
155
156    let options = InferenceOptions::default();
157    let response = provider
158        .complete(&conv, &options)
159        .await
160        .map_err(|e| ReplError::Execution(format!("LLM call failed: {}", e)))?;
161
162    Ok(DslValue::String(response.content))
163}
164
165/// Execute the `parse_json` builtin: parse a string as JSON.
166///
167/// Arguments:
168/// - text: string — the JSON text to parse
169///
170/// Returns a DslValue (Map, List, String, Number, Boolean, or Null).
171pub fn builtin_parse_json(args: &[DslValue]) -> Result<DslValue> {
172    let text = match args.first() {
173        Some(DslValue::String(s)) => s,
174        _ => {
175            return Err(ReplError::Execution(
176                "parse_json requires a string argument".into(),
177            ))
178        }
179    };
180
181    let value: serde_json::Value = serde_json::from_str(text)
182        .map_err(|e| ReplError::Execution(format!("JSON parse error: {}", e)))?;
183
184    Ok(json_to_dsl_value(&value))
185}
186
187/// Execute the `tool_call` builtin: explicit tool invocation.
188///
189/// Arguments:
190/// - name: string — tool name
191/// - args: map — tool arguments
192///
193/// Returns the tool result as a string.
194pub async fn builtin_tool_call(
195    args: &[DslValue],
196    _ctx: &ReasoningBuiltinContext,
197) -> Result<DslValue> {
198    let (name, arguments) = match args {
199        [DslValue::String(name), DslValue::Map(args_map)] => {
200            let json_args: serde_json::Map<String, serde_json::Value> = args_map
201                .iter()
202                .map(|(k, v)| (k.clone(), v.to_json()))
203                .collect();
204            (
205                name.clone(),
206                serde_json::Value::Object(json_args).to_string(),
207            )
208        }
209        [DslValue::String(name), DslValue::String(args_str)] => (name.clone(), args_str.clone()),
210        [DslValue::String(name)] => (name.clone(), "{}".to_string()),
211        _ => {
212            return Err(ReplError::Execution(
213                "tool_call requires (name: string, args?: map|string)".into(),
214            ))
215        }
216    };
217
218    // In a full setup, this would go through ToolInvocationEnforcer.
219    // For now, return a structured result indicating the tool call was made.
220    let mut result = HashMap::new();
221    result.insert("tool".to_string(), DslValue::String(name));
222    result.insert("arguments".to_string(), DslValue::String(arguments));
223    result.insert(
224        "status".to_string(),
225        DslValue::String("executed".to_string()),
226    );
227
228    Ok(DslValue::Map(result))
229}
230
231/// Execute the `delegate` builtin: send a message to another agent.
232///
233/// Arguments:
234/// - agent: string — agent name
235/// - message: string — message to send
236/// - timeout: duration (optional)
237///
238/// Returns the agent's response as a string.
239pub async fn builtin_delegate(
240    args: &[DslValue],
241    ctx: &ReasoningBuiltinContext,
242) -> Result<DslValue> {
243    let (agent_name, message) = match args {
244        [DslValue::String(agent), DslValue::String(msg)] => (agent.clone(), msg.clone()),
245        [DslValue::Map(map)] => {
246            let agent = map
247                .get("agent")
248                .and_then(|v| match v {
249                    DslValue::String(s) => Some(s.clone()),
250                    _ => None,
251                })
252                .ok_or_else(|| ReplError::Execution("delegate requires 'agent' argument".into()))?;
253            let msg = map
254                .get("message")
255                .and_then(|v| match v {
256                    DslValue::String(s) => Some(s.clone()),
257                    _ => None,
258                })
259                .ok_or_else(|| {
260                    ReplError::Execution("delegate requires 'message' argument".into())
261                })?;
262            (agent, msg)
263        }
264        _ => {
265            return Err(ReplError::Execution(
266                "delegate requires (agent: string, message: string)".into(),
267            ))
268        }
269    };
270
271    // Communication bus wiring: resolve recipient (fallback for unregistered agents)
272    let recipient_id = if let Some(registry) = &ctx.agent_registry {
273        registry
274            .get_agent(&agent_name)
275            .await
276            .map(|a| a.agent_id)
277            .unwrap_or_default()
278    } else {
279        AgentId::new()
280    };
281    let sender_id = ctx.sender_agent_id.unwrap_or_default();
282    let request_id = RequestId::new();
283
284    check_comm_policy(
285        ctx,
286        sender_id,
287        recipient_id,
288        MessageType::Request(request_id),
289    )?;
290    log_comm_message(
291        ctx,
292        sender_id,
293        recipient_id,
294        &message,
295        MessageType::Request(request_id),
296        Duration::from_secs(30),
297    )
298    .await;
299
300    // Use inference provider to simulate delegation (each agent is a separate conversation)
301    let provider = ctx
302        .provider
303        .as_ref()
304        .ok_or_else(|| ReplError::Execution("No inference provider configured".into()))?;
305
306    use symbi_runtime::reasoning::conversation::{Conversation, ConversationMessage};
307    use symbi_runtime::reasoning::inference::InferenceOptions;
308
309    let mut conv = Conversation::with_system(format!(
310        "You are agent '{}'. Respond to the delegated task.",
311        agent_name
312    ));
313    conv.push(ConversationMessage::user(&message));
314
315    let response = provider
316        .complete(&conv, &InferenceOptions::default())
317        .await
318        .map_err(|e| {
319            ReplError::Execution(format!("Delegation to '{}' failed: {}", agent_name, e))
320        })?;
321
322    log_comm_message(
323        ctx,
324        recipient_id,
325        sender_id,
326        &response.content,
327        MessageType::Response(request_id),
328        Duration::from_secs(30),
329    )
330    .await;
331
332    Ok(DslValue::String(response.content))
333}
334
// --- Helper functions ---
336
337fn parse_reason_args(args: &[DslValue]) -> Result<(String, String, u32, u32)> {
338    match args {
339        // Named arguments via map
340        [DslValue::Map(map)] => {
341            let system = map
342                .get("system")
343                .and_then(|v| match v {
344                    DslValue::String(s) => Some(s.clone()),
345                    _ => None,
346                })
347                .ok_or_else(|| ReplError::Execution("reason requires 'system' argument".into()))?;
348            let user = map
349                .get("user")
350                .and_then(|v| match v {
351                    DslValue::String(s) => Some(s.clone()),
352                    _ => None,
353                })
354                .ok_or_else(|| ReplError::Execution("reason requires 'user' argument".into()))?;
355            let max_iterations = map
356                .get("max_iterations")
357                .and_then(|v| match v {
358                    DslValue::Integer(i) => Some(*i as u32),
359                    DslValue::Number(n) => Some(*n as u32),
360                    _ => None,
361                })
362                .unwrap_or(10);
363            let max_tokens = map
364                .get("max_tokens")
365                .and_then(|v| match v {
366                    DslValue::Integer(i) => Some(*i as u32),
367                    DslValue::Number(n) => Some(*n as u32),
368                    _ => None,
369                })
370                .unwrap_or(100_000);
371            Ok((system, user, max_iterations, max_tokens))
372        }
373        // Positional: system, user
374        [DslValue::String(system), DslValue::String(user)] => {
375            Ok((system.clone(), user.clone(), 10, 100_000))
376        }
377        // Positional: system, user, max_iterations
378        [DslValue::String(system), DslValue::String(user), DslValue::Integer(max_iter)] => {
379            Ok((system.clone(), user.clone(), *max_iter as u32, 100_000))
380        }
381        _ => Err(ReplError::Execution(
382            "reason requires (system: string, user: string, [max_iterations?, max_tokens?])".into(),
383        )),
384    }
385}
386
387/// Convert a serde_json::Value to a DslValue.
388pub fn json_to_dsl_value(value: &serde_json::Value) -> DslValue {
389    match value {
390        serde_json::Value::Null => DslValue::Null,
391        serde_json::Value::Bool(b) => DslValue::Boolean(*b),
392        serde_json::Value::Number(n) => {
393            if let Some(i) = n.as_i64() {
394                DslValue::Integer(i)
395            } else if let Some(f) = n.as_f64() {
396                DslValue::Number(f)
397            } else {
398                DslValue::Number(0.0)
399            }
400        }
401        serde_json::Value::String(s) => DslValue::String(s.clone()),
402        serde_json::Value::Array(arr) => {
403            DslValue::List(arr.iter().map(json_to_dsl_value).collect())
404        }
405        serde_json::Value::Object(obj) => {
406            let map: HashMap<String, DslValue> = obj
407                .iter()
408                .map(|(k, v)| (k.clone(), json_to_dsl_value(v)))
409                .collect();
410            DslValue::Map(map)
411        }
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_json_valid() {
        let result =
            builtin_parse_json(&[DslValue::String(r#"{"key": "value", "num": 42}"#.into())])
                .unwrap();
        match result {
            DslValue::Map(map) => {
                assert_eq!(map.get("key"), Some(&DslValue::String("value".into())));
                assert_eq!(map.get("num"), Some(&DslValue::Integer(42)));
            }
            _ => panic!("Expected Map"),
        }
    }

    #[test]
    fn test_parse_json_array() {
        let result = builtin_parse_json(&[DslValue::String("[1, 2, 3]".into())]).unwrap();
        match result {
            DslValue::List(items) => {
                assert_eq!(items.len(), 3);
                assert_eq!(items[0], DslValue::Integer(1));
            }
            _ => panic!("Expected List"),
        }
    }

    #[test]
    fn test_parse_json_invalid() {
        let result = builtin_parse_json(&[DslValue::String("not json".into())]);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_json_requires_string_argument() {
        assert!(builtin_parse_json(&[]).is_err());
        assert!(builtin_parse_json(&[DslValue::Integer(1)]).is_err());
    }

    #[test]
    fn test_parse_json_nested() {
        let json = r#"{"tasks": [{"id": 1, "done": false}], "count": 1}"#;
        let result = builtin_parse_json(&[DslValue::String(json.into())]).unwrap();
        match result {
            DslValue::Map(map) => match map.get("tasks") {
                Some(DslValue::List(tasks)) => {
                    assert_eq!(tasks.len(), 1);
                    match &tasks[0] {
                        DslValue::Map(task) => {
                            assert_eq!(task.get("id"), Some(&DslValue::Integer(1)));
                            assert_eq!(task.get("done"), Some(&DslValue::Boolean(false)));
                        }
                        _ => panic!("Expected Map in list"),
                    }
                }
                _ => panic!("Expected List for tasks"),
            },
            _ => panic!("Expected Map"),
        }
    }

    #[test]
    fn test_json_to_dsl_value_all_types() {
        let json = serde_json::json!({
            "str": "hello",
            "int": 42,
            "float": 1.5,
            "bool": true,
            "null": null,
            "arr": [1, 2],
            "obj": {"nested": "value"}
        });

        let dsl = json_to_dsl_value(&json);
        match dsl {
            DslValue::Map(map) => {
                assert_eq!(map.len(), 7);
                assert_eq!(map.get("str"), Some(&DslValue::String("hello".into())));
                assert_eq!(map.get("int"), Some(&DslValue::Integer(42)));
                assert_eq!(map.get("float"), Some(&DslValue::Number(1.5)));
                assert_eq!(map.get("bool"), Some(&DslValue::Boolean(true)));
                assert_eq!(map.get("null"), Some(&DslValue::Null));
                assert_eq!(
                    map.get("arr"),
                    Some(&DslValue::List(vec![
                        DslValue::Integer(1),
                        DslValue::Integer(2)
                    ]))
                );
                let mut nested = HashMap::new();
                nested.insert("nested".to_string(), DslValue::String("value".into()));
                assert_eq!(map.get("obj"), Some(&DslValue::Map(nested)));
            }
            _ => panic!("Expected Map"),
        }
    }

    #[test]
    fn test_json_to_dsl_value_large_unsigned_becomes_number() {
        // u64::MAX does not fit in i64, so it falls back to the f64 representation.
        let dsl = json_to_dsl_value(&serde_json::json!(u64::MAX));
        assert_eq!(dsl, DslValue::Number(u64::MAX as f64));
    }

    #[test]
    fn test_parse_reason_args_positional() {
        let args = vec![
            DslValue::String("system prompt".into()),
            DslValue::String("user message".into()),
        ];
        let (system, user, max_iter, max_tokens) = parse_reason_args(&args).unwrap();
        assert_eq!(system, "system prompt");
        assert_eq!(user, "user message");
        assert_eq!(max_iter, 10);
        assert_eq!(max_tokens, 100_000);
    }

    #[test]
    fn test_parse_reason_args_positional_max_iterations() {
        let args = vec![
            DslValue::String("sys".into()),
            DslValue::String("usr".into()),
            DslValue::Integer(3),
        ];
        let (_, _, max_iter, max_tokens) = parse_reason_args(&args).unwrap();
        assert_eq!(max_iter, 3);
        assert_eq!(max_tokens, 100_000);
    }

    #[test]
    fn test_parse_reason_args_named() {
        let mut map = HashMap::new();
        map.insert("system".into(), DslValue::String("sys".into()));
        map.insert("user".into(), DslValue::String("usr".into()));
        map.insert("max_iterations".into(), DslValue::Integer(5));

        let args = vec![DslValue::Map(map)];
        let (system, user, max_iter, max_tokens) = parse_reason_args(&args).unwrap();
        assert_eq!(system, "sys");
        assert_eq!(user, "usr");
        assert_eq!(max_iter, 5);
        assert_eq!(max_tokens, 100_000);
    }

    #[test]
    fn test_parse_reason_args_named_numeric_max_tokens() {
        let mut map = HashMap::new();
        map.insert("system".into(), DslValue::String("sys".into()));
        map.insert("user".into(), DslValue::String("usr".into()));
        map.insert("max_tokens".into(), DslValue::Number(2048.0));

        let args = vec![DslValue::Map(map)];
        let (_, _, max_iter, max_tokens) = parse_reason_args(&args).unwrap();
        assert_eq!(max_iter, 10);
        assert_eq!(max_tokens, 2048);
    }

    #[test]
    fn test_parse_reason_args_missing_required() {
        let mut map = HashMap::new();
        map.insert("system".into(), DslValue::String("sys".into()));
        // Missing "user"

        let args = vec![DslValue::Map(map)];
        assert!(parse_reason_args(&args).is_err());
    }

    #[test]
    fn test_parse_reason_args_rejects_wrong_shape() {
        assert!(parse_reason_args(&[]).is_err());
        assert!(parse_reason_args(&[DslValue::String("only system".into())]).is_err());
        assert!(
            parse_reason_args(&[DslValue::String("sys".into()), DslValue::Integer(1)]).is_err()
        );
    }
}