use std::collections::HashSet;
use onwards::{ChainStep, NextAction, StepDescriptor, StepKind, StepState};
use serde_json::{Value, json};
/// The parts of a parent request body that the chain-building functions in
/// this file read; produced by `parse_parent_request`.
pub(crate) struct ParsedRequest {
/// String value of the body's `model` field.
pub model: String,
/// Conversation as a list of message objects: translated from `input`, or
/// copied from `messages` when `input` is absent.
pub initial_messages: Vec<Value>,
/// The body's `tools` value after `normalize_tools`; `None` when the key is absent.
pub tools: Option<Value>,
/// The body's `stream` flag; `false` when the key is absent or not a boolean.
pub stream: bool,
}
/// Parses the raw parent request body into a [`ParsedRequest`].
///
/// The conversation comes from `input` whenever that key is present: a string
/// becomes a single user message and an array is run through
/// `translate_input_items`. Only when `input` is absent is a `messages` array
/// used, copied as-is.
///
/// # Errors
///
/// Returns a descriptive string when the body is not valid JSON, when `model`
/// is missing or not a string, when `input` is neither a string nor an array,
/// when neither `input` nor a `messages` array exists, or when an `input`
/// item fails translation.
pub(crate) fn parse_parent_request(body: &str) -> Result<ParsedRequest, String> {
    let root: Value = match serde_json::from_str(body) {
        Ok(parsed) => parsed,
        Err(e) => return Err(format!("parent body parse: {e}")),
    };

    let Some(model) = root.get("model").and_then(Value::as_str) else {
        return Err("parent body missing 'model'".to_string());
    };

    // `input` takes precedence: `messages` is only consulted when `input` is absent.
    let messages_array = root.get("messages").and_then(Value::as_array);
    let initial_messages = match (root.get("input"), messages_array) {
        (Some(Value::String(text)), _) => vec![json!({"role": "user", "content": text})],
        (Some(Value::Array(items)), _) => translate_input_items(items)?,
        (Some(_), _) => return Err("'input' must be string or array".into()),
        (None, Some(messages)) => messages.clone(),
        (None, None) => return Err("parent body missing 'input' or 'messages'".into()),
    };

    Ok(ParsedRequest {
        model: model.to_owned(),
        initial_messages,
        tools: root.get("tools").map(normalize_tools),
        stream: root.get("stream").and_then(Value::as_bool).unwrap_or(false),
    })
}
/// Translates an `input` item array into chat-style message objects.
///
/// Items are dispatched on their `type` string; an item without one is
/// treated as `"message"`.
///
/// * `message` — copied with the `type` key removed; must carry a string `role`.
/// * `function_call` — becomes a `tool_calls` entry, appended to the previous
///   output message when that message is an assistant message, otherwise
///   wrapped in a new assistant message with `null` content.
/// * `function_call_output` — becomes a `tool` message.
/// * `reasoning` and unrecognised types — dropped, with a log line.
///
/// # Errors
///
/// Returns a string naming the offending index when an item is not a JSON
/// object or lacks a required string field.
fn translate_input_items(items: &[Value]) -> Result<Vec<Value>, String> {
    /// Borrows `item` as an object, or reports which kind of item was malformed.
    fn object_of<'a>(
        item: &'a Value,
        idx: usize,
        kind: &str,
    ) -> Result<&'a serde_json::Map<String, Value>, String> {
        item.as_object()
            .ok_or_else(|| format!("input[{idx}]: '{kind}' item must be a JSON object"))
    }

    /// Reads a required string field from an item's object.
    fn required_str<'a>(
        fields: &'a serde_json::Map<String, Value>,
        idx: usize,
        kind: &str,
        key: &str,
    ) -> Result<&'a str, String> {
        fields
            .get(key)
            .and_then(Value::as_str)
            .ok_or_else(|| format!("input[{idx}]: '{kind}' missing '{key}'"))
    }

    let mut messages: Vec<Value> = Vec::new();

    for (idx, item) in items.iter().enumerate() {
        let kind = item.get("type").and_then(Value::as_str).unwrap_or("message");
        match kind {
            "message" => {
                let fields = object_of(item, idx, "message")?;
                if !fields.get("role").is_some_and(Value::is_string) {
                    return Err(format!("input[{idx}]: 'message' item missing 'role'"));
                }
                // Keep every field except the `type` discriminator.
                let translated: serde_json::Map<String, Value> = fields
                    .iter()
                    .filter(|(key, _)| key.as_str() != "type")
                    .map(|(key, value)| (key.clone(), value.clone()))
                    .collect();
                messages.push(Value::Object(translated));
            }
            "function_call" => {
                let fields = object_of(item, idx, "function_call")?;
                let call_id = required_str(fields, idx, "function_call", "call_id")?;
                let name = required_str(fields, idx, "function_call", "name")?;
                // `arguments` always ends up as a string: absent/null becomes
                // "{}", any other non-string value is JSON-encoded.
                let arguments = match fields.get("arguments") {
                    None | Some(Value::Null) => "{}".to_string(),
                    Some(Value::String(raw)) => raw.clone(),
                    Some(structured) => {
                        serde_json::to_string(structured).unwrap_or_else(|_| "{}".to_string())
                    }
                };
                let tool_call = json!({
                    "id": call_id,
                    "type": "function",
                    "function": {"name": name, "arguments": arguments},
                });

                // Fold into the previous message when it is an assistant
                // message whose `tool_calls` is (or can be created as) an array.
                let trailing_calls = messages
                    .last_mut()
                    .filter(|prev| prev.get("role").and_then(Value::as_str) == Some("assistant"))
                    .and_then(Value::as_object_mut)
                    .map(|prev| prev.entry("tool_calls").or_insert_with(|| json!([])))
                    .and_then(Value::as_array_mut);
                match trailing_calls {
                    Some(calls) => calls.push(tool_call),
                    None => messages.push(json!({
                        "role": "assistant",
                        "content": Value::Null,
                        "tool_calls": [tool_call],
                    })),
                }
            }
            "function_call_output" => {
                let fields = object_of(item, idx, "function_call_output")?;
                let call_id = required_str(fields, idx, "function_call_output", "call_id")?;
                // `output` always ends up as a string: absent/null becomes "",
                // any other non-string value is JSON-encoded.
                let content = match fields.get("output") {
                    None | Some(Value::Null) => String::new(),
                    Some(Value::String(text)) => text.clone(),
                    Some(structured) => serde_json::to_string(structured).unwrap_or_default(),
                };
                messages.push(json!({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": content,
                }));
            }
            "reasoning" => {
                tracing::debug!(idx, "dropping 'reasoning' input item during /v1/responses translation");
            }
            other => {
                tracing::warn!(idx, item_type = %other, "unknown Open Responses input item type; dropping");
            }
        }
    }

    Ok(messages)
}
/// Rewrites flat function tool definitions into the wrapped
/// `{"type": "function", "function": {...}}` form.
///
/// Anything this function does not recognise is forwarded unchanged:
///
/// * a `tools` value that is not an array,
/// * array entries that are not JSON objects,
/// * entries that already have a `function` key,
/// * entries whose `type` is a string other than `"function"`.
///
/// A remaining entry (type `"function"`, or no string `type` at all) has every
/// key except `type` moved under a new `function` object.
fn normalize_tools(tools: &Value) -> Value {
    let Value::Array(items) = tools else {
        return tools.clone();
    };

    let normalized: Vec<Value> = items
        .iter()
        .map(|item| {
            // A non-object entry cannot be a flat function definition. Forward
            // it untouched rather than replacing it with an empty function
            // wrapper, which would discard the caller's original value.
            let Some(fields) = item.as_object() else {
                return item.clone();
            };
            if fields.contains_key("function") {
                return item.clone();
            }
            let declared_type = fields.get("type").and_then(Value::as_str).unwrap_or("function");
            if declared_type != "function" {
                return item.clone();
            }
            let function: serde_json::Map<String, Value> = fields
                .iter()
                .filter(|(key, _)| key.as_str() != "type")
                .map(|(key, value)| (key.clone(), value.clone()))
                .collect();
            json!({"type": "function", "function": function})
        })
        .collect();

    Value::Array(normalized)
}
pub(crate) fn build_messages_from_chain(initial: &[Value], chain: &[ChainStep]) -> Vec<Value> {
let mut messages: Vec<Value> = initial.to_vec();
let mut i = 0;
while i < chain.len() {
let step = &chain[i];
if !matches!(step.state, StepState::Completed) {
i += 1;
continue;
}
match step.kind {
StepKind::ModelCall => {
if let Some(payload) = &step.response_payload
&& let Some(message) = extract_assistant_message(payload)
{
messages.push(message);
}
i += 1;
}
StepKind::ToolCall => {
let call_id = step
.response_payload
.as_ref()
.map(|_p| "unknown".to_string())
.unwrap_or_else(|| format!("step_{}", step.sequence));
let content = step
.response_payload
.as_ref()
.map(|p| serde_json::to_string(p).unwrap_or_default())
.unwrap_or_default();
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": content,
}));
i += 1;
}
}
}
messages
}
/// Returns a copy of `choices[0].message` from a model response, or `None`
/// when `choices` is missing, not an array, empty, or its first entry has no
/// `message` key.
fn extract_assistant_message(model_response: &Value) -> Option<Value> {
    let choices = model_response.get("choices")?.as_array()?;
    let first_choice = choices.first()?;
    first_choice.get("message").cloned()
}
/// Turns the `tool_calls` of a model response's first-choice message into
/// tool-call step descriptors, in the order they appear.
///
/// Entries without a `function` object or a string `function.name` are
/// skipped. Each descriptor's request payload holds:
///
/// * `name` — the function name,
/// * `args` — string arguments parsed as JSON (`{}` when they do not parse),
///   non-string arguments copied as-is, `{}` when absent,
/// * `call_id` — the entry's string `id`, or `"call_unknown"` without one.
///
/// Returns an empty vector when there is no assistant message or no
/// `tool_calls` array.
pub(crate) fn extract_tool_calls(model_response: &Value) -> Vec<StepDescriptor> {
    let message = match extract_assistant_message(model_response) {
        Some(message) => message,
        None => return Vec::new(),
    };
    let calls = match message.get("tool_calls").and_then(Value::as_array) {
        Some(calls) => calls,
        None => return Vec::new(),
    };

    let mut descriptors = Vec::with_capacity(calls.len());
    for call in calls {
        let Some(function) = call.get("function") else {
            continue;
        };
        let Some(name) = function.get("name").and_then(Value::as_str) else {
            continue;
        };
        let args: Value = match function.get("arguments") {
            None => json!({}),
            Some(Value::String(encoded)) => {
                serde_json::from_str(encoded).unwrap_or_else(|_| json!({}))
            }
            Some(structured) => structured.clone(),
        };
        let call_id = call.get("id").and_then(Value::as_str).unwrap_or("call_unknown");

        descriptors.push(StepDescriptor {
            kind: StepKind::ToolCall,
            request_payload: json!({
                "name": name,
                "args": args,
                "call_id": call_id,
            }),
        });
    }
    descriptors
}
pub(crate) fn prepare_initial_model_call(parsed: &ParsedRequest) -> StepDescriptor {
let mut payload = json!({
"model": parsed.model,
"messages": parsed.initial_messages,
"stream": parsed.stream,
});
if let Some(tools) = &parsed.tools {
payload["tools"] = tools.clone();
}
StepDescriptor {
kind: StepKind::ModelCall,
request_payload: payload,
}
}
pub(crate) fn prepare_followup_model_call(parsed: &ParsedRequest, chain: &[ChainStep]) -> StepDescriptor {
let messages = build_messages_from_chain(&parsed.initial_messages, chain);
let mut payload = json!({
"model": parsed.model,
"messages": messages,
"stream": parsed.stream,
});
if let Some(tools) = &parsed.tools {
payload["tools"] = tools.clone();
}
StepDescriptor {
kind: StepKind::ModelCall,
request_payload: payload,
}
}
pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep], resolved_tool_names: &HashSet<String>) -> NextAction {
if chain.is_empty() {
return NextAction::AppendSteps(vec![prepare_initial_model_call(parsed)]);
}
let last = match chain
.iter()
.rev()
.find(|s| matches!(s.state, StepState::Completed | StepState::Failed))
{
Some(s) => s,
None => {
return NextAction::Fail(json!({
"type": "step_abandoned",
"message": "a step was in flight when this worker took over; the previous worker exited before completing it",
}));
}
};
if matches!(last.state, StepState::Failed) {
return NextAction::Fail(last.error.clone().unwrap_or_else(|| json!({"type": "step_failed"})));
}
match last.kind {
StepKind::ModelCall => {
let response = last.response_payload.as_ref().cloned().unwrap_or_else(|| json!({}));
let tool_calls = extract_tool_calls(&response);
if tool_calls.is_empty() {
return NextAction::Complete(response);
}
let all_registered = tool_calls
.iter()
.all(|step| tool_call_name(&step.request_payload).is_some_and(|name| resolved_tool_names.contains(name)));
if all_registered {
NextAction::AppendSteps(tool_calls)
} else {
NextAction::Complete(response)
}
}
StepKind::ToolCall => {
NextAction::AppendSteps(vec![prepare_followup_model_call(parsed, chain)])
}
}
}
/// Reads the `name` string from a tool-call request payload, if present.
fn tool_call_name(payload: &Value) -> Option<&str> {
    payload.get("name")?.as_str()
}
// Unit tests for request parsing, input-item translation, tool normalization
// and the next-action decision.
#[cfg(test)]
mod tests {
use super::*;
// Builds a ChainStep with the given id, sequence, kind, state and response
// payload; the link ids and the error are left unset.
fn step(id: &str, seq: i64, kind: StepKind, state: StepState, response: Option<Value>) -> ChainStep {
ChainStep {
id: id.into(),
kind,
state,
sequence: seq,
prev_step_id: None,
parent_step_id: None,
response_payload: response,
error: None,
}
}
// A string `input` becomes a single user message.
#[test]
fn parses_string_input() {
let body = r#"{"model":"gpt-4o","input":"hi"}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.model, "gpt-4o");
assert_eq!(p.initial_messages, vec![json!({"role":"user","content":"hi"})]);
}
// A `messages` array is accepted when `input` is absent.
#[test]
fn parses_messages_form() {
let body = r#"{"model":"x","messages":[{"role":"user","content":"hello"}]}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 1);
}
// `message` items keep role/content and lose the `type` key.
#[test]
fn translates_message_input_items() {
let body = r#"{
"model": "m",
"input": [
{"type":"message","role":"system","content":"Be terse."},
{"type":"message","role":"user","content":"Find one fact."}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 2);
assert_eq!(p.initial_messages[0], json!({"role":"system","content":"Be terse."}));
assert_eq!(p.initial_messages[1], json!({"role":"user","content":"Find one fact."}));
for msg in &p.initial_messages {
assert!(msg.get("type").is_none(), "translated message must drop 'type' field");
}
}
// A `function_call` after a user message becomes a new assistant message
// with null content and one `tool_calls` entry.
#[test]
fn translates_function_call_to_assistant_tool_calls_message() {
let body = r#"{
"model": "m",
"input": [
{"type":"message","role":"user","content":"go"},
{"type":"function_call","call_id":"call_a","name":"search","arguments":"{\"query\":\"x\"}"}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 2);
let assistant = &p.initial_messages[1];
assert_eq!(assistant["role"], "assistant");
assert_eq!(assistant["content"], Value::Null);
let tool_calls = assistant["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0]["id"], "call_a");
assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "search");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"query\":\"x\"}");
}
// A `function_call_output` becomes a `tool` message carrying the call id.
#[test]
fn translates_function_call_output_to_tool_message() {
let body = r#"{
"model": "m",
"input": [
{"type":"function_call","call_id":"call_a","name":"f","arguments":"{}"},
{"type":"function_call_output","call_id":"call_a","output":"{\"results\":[]}"}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 2);
let tool_msg = &p.initial_messages[1];
assert_eq!(tool_msg["role"], "tool");
assert_eq!(tool_msg["tool_call_id"], "call_a");
assert_eq!(tool_msg["content"], "{\"results\":[]}");
}
// Back-to-back `function_call` items share one assistant message.
#[test]
fn collapses_consecutive_function_calls_into_one_assistant_message() {
let body = r#"{
"model": "m",
"input": [
{"type":"function_call","call_id":"call_a","name":"a","arguments":"{}"},
{"type":"function_call","call_id":"call_b","name":"b","arguments":"{}"}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 1);
let tool_calls = p.initial_messages[0]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls.len(), 2);
assert_eq!(tool_calls[0]["id"], "call_a");
assert_eq!(tool_calls[1]["id"], "call_b");
}
// A `function_call` right after an assistant message is attached to that
// message, whose content is kept.
#[test]
fn folds_function_calls_into_preceding_assistant_message() {
let body = r#"{
"model": "m",
"input": [
{"type":"message","role":"user","content":"hi"},
{"type":"message","role":"assistant","content":"thinking..."},
{"type":"function_call","call_id":"call_a","name":"f","arguments":"{}"}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 2);
let assistant = &p.initial_messages[1];
assert_eq!(assistant["role"], "assistant");
assert_eq!(assistant["content"], "thinking...");
let tool_calls = assistant["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0]["id"], "call_a");
}
// End-to-end translation of a conversation mixing messages, a function
// call and its output; every resulting message must have a role.
#[test]
fn full_multi_turn_tool_conversation_translates_correctly() {
let body = r#"{
"model": "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8",
"service_tier": "flex",
"max_output_tokens": 64,
"input": [
{"type":"message","role":"system","content":"Be terse."},
{"type":"message","role":"user","content":"Find one fact."},
{"type":"message","role":"system","content":"ctx"},
{"type":"function_call","call_id":"call_a","name":"search","arguments":"{\"query\":\"x\"}"},
{"type":"function_call_output","call_id":"call_a","output":"{\"results\":[]}"}
],
"tools":[{"type":"function","function":{"name":"search","description":"s","parameters":{"type":"object"}}}]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 5);
for (idx, msg) in p.initial_messages.iter().enumerate() {
let role = msg.get("role").and_then(|r| r.as_str());
assert!(role.is_some(), "messages[{idx}] must have a role; got {msg}");
}
assert_eq!(p.initial_messages[0]["role"], "system");
assert_eq!(p.initial_messages[1]["role"], "user");
assert_eq!(p.initial_messages[2]["role"], "system");
assert_eq!(p.initial_messages[3]["role"], "assistant");
assert_eq!(p.initial_messages[3]["tool_calls"][0]["function"]["name"], "search");
assert_eq!(p.initial_messages[4]["role"], "tool");
assert_eq!(p.initial_messages[4]["tool_call_id"], "call_a");
}
// `reasoning` items produce no message.
#[test]
fn drops_reasoning_input_items() {
let body = r#"{
"model": "m",
"input": [
{"type":"message","role":"user","content":"hi"},
{"type":"reasoning","summary":["thought"]},
{"type":"message","role":"assistant","content":"hello"}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 2);
assert_eq!(p.initial_messages[0]["role"], "user");
assert_eq!(p.initial_messages[1]["role"], "assistant");
}
// A bare string inside the `input` array is treated as a malformed
// `message` item and rejected.
#[test]
fn non_object_input_items_return_clear_errors() {
let err = match parse_parent_request(r#"{"model":"m","input":["bare string"]}"#) {
Ok(_) => panic!("expected Err for bare string in input"),
Err(e) => e,
};
assert!(err.contains("must be a JSON object"), "got: {err}");
assert!(err.contains("'message'"), "got: {err}");
}
// A `message` item without a role is rejected.
#[test]
fn missing_role_on_message_item_returns_error() {
let body = r#"{
"model": "m",
"input": [{"type":"message","content":"hi"}]
}"#;
let err = match parse_parent_request(body) {
Ok(_) => panic!("expected Err, got Ok"),
Err(e) => e,
};
assert!(err.contains("missing 'role'"), "got: {err}");
}
// A flat function tool is wrapped under a `function` key.
#[test]
fn normalizes_spec_flat_tools_into_wrapped_form() {
let body = r#"{
"model": "m",
"input": "hi",
"tools": [
{"type":"function","name":"search","description":"s","parameters":{"type":"object"}}
]
}"#;
let p = parse_parent_request(body).unwrap();
let tools = p.tools.unwrap();
let arr = tools.as_array().unwrap();
assert_eq!(arr.len(), 1);
assert_eq!(arr[0]["type"], "function");
let function = &arr[0]["function"];
assert_eq!(function["name"], "search");
assert_eq!(function["description"], "s");
assert_eq!(function["parameters"], json!({"type": "object"}));
assert!(arr[0].get("name").is_none(), "wrapped tool must not have top-level 'name'");
}
// A tool that already has a `function` key is left as-is.
#[test]
fn passes_already_wrapped_tools_through_unchanged() {
let body = r#"{
"model": "m",
"input": "hi",
"tools": [
{"type":"function","function":{"name":"search","description":"s","parameters":{"type":"object"}}}
]
}"#;
let p = parse_parent_request(body).unwrap();
let tools = p.tools.unwrap();
let arr = tools.as_array().unwrap();
assert_eq!(arr[0]["function"]["name"], "search");
assert!(arr[0].get("name").is_none());
}
// Null `arguments` are translated to the string "{}".
#[test]
fn null_arguments_on_function_call_become_empty_object() {
let body = r#"{
"model": "m",
"input": [
{"type":"function_call","call_id":"c","name":"f","arguments":null}
]
}"#;
let p = parse_parent_request(body).unwrap();
let tool_calls = p.initial_messages[0]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls[0]["function"]["arguments"], "{}");
}
// Null `output` is translated to an empty content string.
#[test]
fn null_output_on_function_call_output_becomes_empty_string() {
let body = r#"{
"model": "m",
"input": [
{"type":"function_call","call_id":"c","name":"f","arguments":"{}"},
{"type":"function_call_output","call_id":"c","output":null}
]
}"#;
let p = parse_parent_request(body).unwrap();
let tool_msg = &p.initial_messages[1];
assert_eq!(tool_msg["role"], "tool");
assert_eq!(tool_msg["content"], "");
}
// Wrapping a flat tool keeps every key other than `type`.
#[test]
fn normalize_tools_preserves_unknown_fields() {
let body = r#"{
"model": "m",
"input": "hi",
"tools": [
{"type":"function","name":"f","description":"d","parameters":{"type":"object"},"strict":true,"x_vendor":{"hint":"v"}}
]
}"#;
let p = parse_parent_request(body).unwrap();
let arr = p.tools.unwrap();
let function = &arr[0]["function"];
assert_eq!(function["name"], "f");
assert_eq!(function["description"], "d");
assert_eq!(function["strict"], true);
assert_eq!(function["x_vendor"]["hint"], "v");
assert!(function.get("type").is_none());
assert_eq!(arr[0]["type"], "function");
}
// An empty `input` array is valid and yields no messages.
#[test]
fn empty_input_array_translates_to_empty_messages() {
let body = r#"{"model":"m","input":[]}"#;
let p = parse_parent_request(body).unwrap();
assert!(p.initial_messages.is_empty());
}
// Object-valued `arguments` are JSON-encoded into a string.
#[test]
fn raw_object_arguments_are_json_serialized() {
let body = r#"{
"model": "m",
"input": [
{"type":"function_call","call_id":"c","name":"f","arguments":{"query":"x"}}
]
}"#;
let p = parse_parent_request(body).unwrap();
let tool_calls = p.initial_messages[0]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"query\":\"x\"}");
}
// Items without a `type` are handled as messages, and extra fields such
// as `name`, `tool_call_id` and `tool_calls` survive translation.
#[test]
fn message_translation_preserves_chat_completions_fields() {
let body = r#"{
"model": "m",
"input": [
{"type":"message","role":"user","content":"hi","name":"alice"},
{"role":"tool","tool_call_id":"call_a","content":"{\"ok\":1}"},
{"role":"assistant","content":null,"tool_calls":[{"id":"call_a","type":"function","function":{"name":"f","arguments":"{}"}}]}
]
}"#;
let p = parse_parent_request(body).unwrap();
assert_eq!(p.initial_messages.len(), 3);
assert_eq!(p.initial_messages[0]["role"], "user");
assert_eq!(p.initial_messages[0]["name"], "alice");
assert!(p.initial_messages[0].get("type").is_none());
assert_eq!(p.initial_messages[1]["role"], "tool");
assert_eq!(p.initial_messages[1]["tool_call_id"], "call_a");
assert_eq!(p.initial_messages[2]["role"], "assistant");
let tool_calls = p.initial_messages[2]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0]["id"], "call_a");
}
// Tools whose `type` is not "function" are forwarded untouched.
#[test]
fn normalize_tools_passes_through_non_function_tool_types() {
let body = r#"{
"model": "m",
"input": "hi",
"tools": [
{"type":"web_search"},
{"type":"file_search","vector_store_ids":["vs_1"]},
{"type":"function","name":"f","parameters":{"type":"object"}}
]
}"#;
let p = parse_parent_request(body).unwrap();
let arr = p.tools.unwrap();
let arr = arr.as_array().unwrap();
assert_eq!(arr.len(), 3);
assert_eq!(arr[0], json!({"type": "web_search"}));
assert_eq!(arr[1], json!({"type": "file_search", "vector_store_ids": ["vs_1"]}));
assert_eq!(arr[2]["type"], "function");
assert_eq!(arr[2]["function"]["name"], "f");
assert!(arr[2].get("name").is_none());
}
// Wrapped and flat tools may appear together; only the flat one is rewritten.
#[test]
fn normalizes_mixed_wrapped_and_spec_flat_tools_in_one_array() {
let body = r#"{
"model": "m",
"input": "hi",
"tools": [
{"type":"function","function":{"name":"wrapped","description":"w"}},
{"type":"function","name":"flat","description":"f","parameters":{"type":"object"}}
]
}"#;
let p = parse_parent_request(body).unwrap();
let arr = p.tools.unwrap();
let arr = arr.as_array().unwrap();
assert_eq!(arr.len(), 2);
assert_eq!(arr[0]["function"]["name"], "wrapped");
assert_eq!(arr[1]["function"]["name"], "flat");
assert!(arr[1].get("name").is_none());
assert!(arr[1].get("parameters").is_none());
}
// Builds the set of registered tool names passed to decide_next_action.
fn names(items: &[&str]) -> HashSet<String> {
items.iter().map(|s| s.to_string()).collect()
}
// With no steps yet, the first action is a single model call.
#[test]
fn empty_chain_emits_initial_model_call() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![json!({"role":"user","content":"hi"})],
tools: None,
stream: false,
};
match decide_next_action(&parsed, &[], &HashSet::new()) {
NextAction::AppendSteps(steps) => {
assert_eq!(steps.len(), 1);
assert!(matches!(steps[0].kind, StepKind::ModelCall));
assert_eq!(steps[0].request_payload["model"], "m");
}
_ => panic!("expected AppendSteps"),
}
}
// When every requested tool is registered, one tool step per call is appended.
#[test]
fn model_call_with_registered_tool_calls_emits_fan_out() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let response = json!({
"choices": [{
"message": {
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "a", "arguments": "{\"x\":1}"}},
{"id": "call_2", "type": "function", "function": {"name": "b", "arguments": "{}"}},
]
}
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response))];
match decide_next_action(&parsed, &chain, &names(&["a", "b"])) {
NextAction::AppendSteps(steps) => {
assert_eq!(steps.len(), 2);
assert_eq!(steps[0].request_payload["name"], "a");
assert_eq!(steps[0].request_payload["args"]["x"], 1);
assert_eq!(steps[0].request_payload["call_id"], "call_1");
assert_eq!(steps[1].request_payload["name"], "b");
}
_ => panic!("expected AppendSteps"),
}
}
// A tool call for an unregistered tool completes with the raw response.
#[test]
fn model_call_with_unregistered_tool_completes_for_client_dispatch() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let response = json!({
"choices": [{
"message": {
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "read_pages", "arguments": "{\"id\":1}"}},
]
}
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))];
match decide_next_action(&parsed, &chain, &HashSet::new()) {
NextAction::Complete(v) => assert_eq!(v, response),
other => panic!("expected Complete for unregistered tool, got {other:?}"),
}
}
// One unregistered tool among registered ones still completes with the
// raw response.
#[test]
fn model_call_with_mixed_registered_and_unregistered_completes() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let response = json!({
"choices": [{
"message": {
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "weather", "arguments": "{}"}},
{"id": "call_2", "type": "function", "function": {"name": "client_only", "arguments": "{}"}},
]
}
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))];
match decide_next_action(&parsed, &chain, &names(&["weather"])) {
NextAction::Complete(v) => assert_eq!(v, response),
other => panic!("expected Complete for mixed tool_calls, got {other:?}"),
}
}
// A model response without tool calls completes the request.
#[test]
fn model_call_without_tool_calls_completes() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let response = json!({
"choices": [{
"message": {"role": "assistant", "content": "the answer is 42"}
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))];
match decide_next_action(&parsed, &chain, &HashSet::new()) {
NextAction::Complete(v) => assert_eq!(v, response),
_ => panic!("expected Complete"),
}
}
// After a completed tool step, a follow-up model call is appended whose
// messages are user, assistant, tool.
#[test]
fn after_tool_call_emits_followup_model_call() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![json!({"role":"user","content":"hi"})],
tools: None,
stream: false,
};
let model_response = json!({
"choices": [{
"message": {
"role": "assistant",
"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "a", "arguments": "{}"}}]
}
}]
});
let chain = vec![
step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(model_response)),
step("s2", 2, StepKind::ToolCall, StepState::Completed, Some(json!({"result": 1}))),
];
match decide_next_action(&parsed, &chain, &names(&["a"])) {
NextAction::AppendSteps(steps) => {
assert_eq!(steps.len(), 1);
assert!(matches!(steps[0].kind, StepKind::ModelCall));
let messages = steps[0].request_payload["messages"].as_array().unwrap();
assert_eq!(messages.len(), 3);
assert_eq!(messages[0]["role"], "user");
assert_eq!(messages[1]["role"], "assistant");
assert_eq!(messages[2]["role"], "tool");
}
_ => panic!("expected AppendSteps"),
}
}
// A failed step's error is returned as the Fail payload.
#[test]
fn failed_step_propagates_as_fail() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let mut s = step("s1", 1, StepKind::ModelCall, StepState::Failed, None);
s.error = Some(json!({"type": "upstream_500"}));
match decide_next_action(&parsed, &[s], &HashSet::new()) {
NextAction::Fail(v) => assert_eq!(v, json!({"type": "upstream_500"})),
_ => panic!("expected Fail"),
}
}
}