openai_agents_rust/agent/
runner.rs

1use crate::agent::traits::{Agent, AgentContext};
2use crate::error::AgentError;
3use crate::model::Model;
4use crate::results::RunResult;
5use crate::tools::registry::ToolRegistry;
6use crate::utils::env::var_bool;
7use serde_json::json;
8
9/// Tool-use behavior modes analogous to the Python SDK behavior.
10#[derive(Clone)]
11pub enum ToolUseBehavior {
12    RunLlmAgain,
13    StopOnFirstTool,
14    StopAtTools(Vec<String>),
15    Custom(BehaviorFn),
16}
17
/// Verdict on whether the tool outputs seen so far should end the run.
///
/// The `Default` value means "not final, no output".
#[derive(Debug, Default, Clone)]
pub struct ToolsToFinalOutputResult {
    /// `true` when the run should stop and report `final_output`.
    pub is_final_output: bool,
    /// Text reported as the run's final output when `is_final_output` is set.
    pub final_output: Option<String>,
}
24
25/// Callback to decide final output based on tool results.
26pub type BehaviorFn = std::sync::Arc<
27    dyn Fn(&AgentContext, &[(String, String)]) -> ToolsToFinalOutputResult + Send + Sync,
28>;
29
30/// Minimal runner scaffold to orchestrate a single agent call.
31pub struct Runner;
32
33impl Runner {
34    pub async fn run<A: Agent>(agent: &A, ctx: &AgentContext) -> Result<(), AgentError> {
35        // For now, just run the agent.
36        agent.run(ctx).await
37    }
38
39    /// Execute tools and collect outputs; may short-circuit based on behavior.
40    pub async fn run_tools_collect(
41        registry: &ToolRegistry,
42        ctx: &AgentContext,
43        input: &str,
44        behavior: &ToolUseBehavior,
45    ) -> Result<(Vec<(String, String)>, Option<String>), AgentError> {
46        let mut results: Vec<(String, String)> = Vec::new();
47        for tool in registry.all() {
48            if tool.is_enabled(ctx).await {
49                let name = tool.name().to_string();
50                let out = tool.call(input).await?;
51                // If StopAtTools, check if this tool is among the stop set.
52                match behavior {
53                    ToolUseBehavior::StopOnFirstTool => {
54                        return Ok((vec![(name, out.clone())], Some(out)));
55                    }
56                    ToolUseBehavior::StopAtTools(stop_list) => {
57                        if stop_list.iter().any(|n| n == &name) {
58                            return Ok((vec![(name, out.clone())], Some(out)));
59                        }
60                    }
61                    _ => {}
62                }
63                results.push((name, out));
64            }
65        }
66        Ok((results, None))
67    }
68
69    /// Minimal agent loop: try tools first (optional), then call the model with instructions.
70    pub async fn run_agent_with_model<M: Model + ?Sized>(
71        model: &M,
72        ctx: &AgentContext,
73        instructions: Option<&str>,
74        input: &str,
75        behavior: ToolUseBehavior,
76    ) -> Result<RunResult, AgentError> {
77        // First attempt: run tools if any are enabled.
78        let (tool_results, early_final) =
79            Self::run_tools_collect(&ctx.tools, ctx, input, &behavior).await?;
80        if let Some(out) = early_final {
81            if !tool_results.is_empty() {
82                tracing::info!(
83                    target: "runner",
84                    tool_count = tool_results.len(),
85                    tools = %serde_json::json!(tool_results),
86                    "early stop with tool outputs"
87                );
88            }
89            return Ok(RunResult {
90                id: None,
91                text: Some(out),
92                tool_outputs: tool_results,
93            });
94        }
95
96        // If custom behavior, allow it to decide final output.
97        if let ToolUseBehavior::Custom(decider) = &behavior {
98            let res = decider(ctx, &tool_results);
99            if res.is_final_output {
100                return Ok(RunResult {
101                    id: None,
102                    text: res.final_output,
103                    tool_outputs: vec![],
104                });
105            }
106        }
107
108        // No tool output or no tools available; call the model directly.
109        let combined_input = if tool_results.is_empty() {
110            input.to_string()
111        } else {
112            let mut agg = String::from(input);
113            for (name, out) in &tool_results {
114                agg.push_str("\n\nTool ");
115                agg.push_str(name);
116                agg.push_str(" output:\n");
117                agg.push_str(out);
118            }
119            agg
120        };
121
122        let max_turns = 3;
123        // Build initial chat messages (system + user)
124        let mut messages: Vec<serde_json::Value> = Vec::new();
125        if let Some(sys) = instructions {
126            messages.push(json!({"role": "system", "content": sys}));
127        }
128        messages.push(json!({"role": "user", "content": input}));
129        // Collect OpenAI tool specs for enabled tools
130        let mut tool_specs: Vec<serde_json::Value> = Vec::new();
131        for t in ctx.tools.all() {
132            if t.openai_tool_spec().is_some() && t.is_enabled(ctx).await {
133                if let Some(spec) = t.openai_tool_spec() {
134                    tool_specs.push(spec);
135                }
136            }
137        }
138        // Compatibility: allow disabling passing tools to the LLM via env flag
139        let disable_tools_in_llm = var_bool("VLLM_DISABLE_TOOLS_IN_LLM", false);
140        let mut previous_response_id: Option<String> = None;
141        let disable_tools_next_turn = false;
142        let mut collected_tool_outputs: Vec<(String, String)> = tool_results.clone();
143        for _turn in 0..max_turns {
144            let resp = model
145                .get_response(
146                    instructions,
147                    &combined_input,
148                    None,
149                    Some(&messages),
150                    if tool_specs.is_empty() || disable_tools_in_llm || disable_tools_next_turn {
151                        None
152                    } else {
153                        Some(&tool_specs)
154                    },
155                    None,
156                    None,
157                    None,
158                    false,
159                    previous_response_id.as_deref(),
160                    None,
161                )
162                .await?;
163
164            if let Some(rid) = &resp.id {
165                previous_response_id = Some(rid.clone());
166            }
167
168            if resp.tool_calls.is_empty() {
169                if let Some(text) = &resp.text {
170                    messages.push(json!({"role": "assistant", "content": text}));
171                }
172                return Ok(RunResult {
173                    id: resp.id,
174                    text: resp.text,
175                    tool_outputs: collected_tool_outputs,
176                });
177            }
178
179            // Add assistant message for proper round-trip.
180            let all_have_ids = resp
181                .tool_calls
182                .iter()
183                .all(|tc| tc.id.is_some() || tc.call_id.is_some());
184            if all_have_ids {
185                // Use tool_calls schema; set content to null per Harmony compatibility
186                messages.push(json!({
187                    "role": "assistant",
188                    "content": serde_json::Value::Null,
189                    "tool_calls": resp.tool_calls.iter().map(|tc| json!({
190                        "id": tc.id.clone().or(tc.call_id.clone()),
191                        "type": "function",
192                        "function": {"name": tc.name, "arguments": tc.arguments},
193                        "call_id": tc.call_id,
194                    })).collect::<Vec<_>>()
195                }));
196            } else {
197                // Legacy function_call schema supports only one function call per message.
198                if let Some(tc0) = resp.tool_calls.first() {
199                    messages.push(json!({
200                        "role": "assistant",
201                        "content": serde_json::Value::Null,
202                        "function_call": {"name": tc0.name, "arguments": tc0.arguments},
203                    }));
204                }
205            }
206
207            // Execute requested tool calls if available.
208            let mut executed_any_tool = false;
209            let mut missing_tools: Vec<String> = Vec::new();
210            let mut _new_tool_outputs: Vec<(String, String)> = Vec::new();
211            for tc in resp.tool_calls {
212                if let Some(tool) = ctx.tools.get_by_name(&tc.name) {
213                    if tool.is_enabled(ctx).await {
214                        let out = tool
215                            .call_with_context(ctx, tc.id.as_deref(), &tc.arguments)
216                            .await?;
217                        // Append a proper tool message for the next model turn.
218                        if let Some(link_id) = tc.call_id.clone().or(tc.id.clone()) {
219                            messages.push(json!({
220                                "role": "tool",
221                                "tool_call_id": link_id,
222                                "content": out
223                            }));
224                        } else {
225                            // Legacy function message
226                            messages.push(json!({
227                                "role": "function",
228                                "name": tc.name,
229                                "content": out
230                            }));
231                        }
232                        _new_tool_outputs.push((tc.name.clone(), out.clone()));
233                        executed_any_tool = true;
234                    } else {
235                        missing_tools.push(tc.name.clone());
236                    }
237                } else {
238                    missing_tools.push(tc.name.clone());
239                }
240            }
241            if !_new_tool_outputs.is_empty() {
242                collected_tool_outputs.extend(_new_tool_outputs);
243                tracing::info!(
244                    target: "runner",
245                    total_tools = collected_tool_outputs.len(),
246                    last_batch = %serde_json::json!(collected_tool_outputs),
247                    "collected tool outputs"
248                );
249            }
250
251            if !missing_tools.is_empty() {
252                return Err(AgentError::Other(format!(
253                    "model requested unknown or disabled tools: {}",
254                    missing_tools.join(", ")
255                )));
256            }
257            if !executed_any_tool {
258                return Err(AgentError::Other(
259                    "model returned tool_calls but none could be executed".into(),
260                ));
261            }
262            // combined_input remains the same; messages carry the tool outputs.
263        }
264
265        // Final model call to produce an answer after tool outputs were added to messages.
266        let resp = model
267            .get_response(
268                instructions,
269                &combined_input,
270                None,
271                Some(&messages),
272                if tool_specs.is_empty() || disable_tools_in_llm {
273                    None
274                } else {
275                    Some(&tool_specs)
276                },
277                None,
278                None,
279                None,
280                false,
281                previous_response_id.as_deref(),
282                None,
283            )
284            .await?;
285        let res = RunResult {
286            id: resp.id,
287            text: resp.text,
288            tool_outputs: collected_tool_outputs,
289        };
290        if !res.tool_outputs.is_empty() {
291            tracing::info!(
292                target: "runner",
293                tool_count = res.tool_outputs.len(),
294                tools = %serde_json::json!(res.tool_outputs),
295                "final result with tool outputs"
296            );
297        }
298        Ok(res)
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::traits::AgentContext;
    use crate::client::OpenAiClient;
    use crate::model::openai_chat::OpenAiChat;

    use async_trait::async_trait;
    use std::sync::Arc;

    /// Minimal tool that hands its input straight back as its output.
    struct EchoTool;

    #[async_trait]
    impl crate::tools::traits::Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        async fn call(&self, input: &str) -> Result<String, crate::error::AgentError> {
            Ok(input.to_string())
        }
    }

    /// With `StopOnFirstTool`, the echo tool's output is the final text, and it is the only
    /// recorded tool output.
    #[tokio::test]
    async fn runner_returns_tool_outputs_on_stop_first() {
        // `.env` values (if present) feed `load_from_env`; blank fields get test defaults.
        let _ = dotenvy::dotenv();
        let mut settings = crate::config::load_from_env();
        if settings.base_url.is_empty() {
            settings.base_url = "http://localhost".into();
        }
        if settings.model.is_empty() {
            settings.model = "openai/gpt-oss-120b".into();
        }
        // Clear the key so this unit test never authenticates.
        settings.api_key = String::new();

        let http_client = Arc::new(OpenAiClient::new(settings.clone()));
        let plugin_registry = Arc::new(crate::plugin::loader::PluginRegistry::new());
        let mut tool_registry = crate::tools::registry::ToolRegistry::new();
        tool_registry.register(EchoTool);
        let ctx = AgentContext {
            config: Arc::new(settings.clone()),
            client: http_client,
            plugins: plugin_registry,
            tools: Arc::new(tool_registry),
        };
        let model = OpenAiChat::new(settings).without_auth();

        let outcome = Runner::run_agent_with_model(
            &model,
            &ctx,
            None,
            "hi",
            ToolUseBehavior::StopOnFirstTool,
        )
        .await
        .unwrap();

        assert_eq!(outcome.text.as_deref(), Some("hi"));
        assert_eq!(outcome.tool_outputs.len(), 1);
        let (tool_name, tool_output) = &outcome.tool_outputs[0];
        assert_eq!(tool_name, "echo");
        assert_eq!(tool_output, "hi");
    }
}