#![cfg(feature = "openai")]
use rstructor::{ApiErrorKind, Instructor, LLMClient, OpenAIClient, RStructorError};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
// Structured-output fixture for the `materialize` tests. The `validate`
// attribute wires `validate_movie` in as a check on every parsed `Movie`.
// Plain `//` comments are used on purpose so nothing here is visible to the
// derive macros (unlike `///`, which becomes a `#[doc]` attribute).
#[derive(Instructor, Serialize, Deserialize, Debug, PartialEq)]
#[llm(validate = "validate_movie")]
struct Movie {
    title: String,
    // Release year; `validate_movie` rejects anything before 1888.
    year: u16,
}
/// Validation hook for [`Movie`].
///
/// # Errors
///
/// Returns `RStructorError::ValidationError` when the year is earlier than
/// 1888, which these tests treat as "before cinema existed".
fn validate_movie(m: &Movie) -> rstructor::Result<()> {
    if m.year >= 1888 {
        Ok(())
    } else {
        Err(RStructorError::ValidationError(
            "year predates cinema".to_string(),
        ))
    }
}
fn chat_completion(content: &str) -> String {
json!({
"choices": [{
"message": { "role": "assistant", "content": content },
"finish_reason": "stop",
}]
})
.to_string()
}
/// Builds an [`OpenAIClient`] that talks to the mock `server`, using a dummy
/// API key and the `gpt-4o-mini` model.
fn client(server: &mockito::Server) -> OpenAIClient {
    let fresh = OpenAIClient::new("test-key").unwrap();
    let pointed_at_mock = fresh.base_url(server.url());
    pointed_at_mock.model("gpt-4o-mini")
}
#[tokio::test]
async fn materialize_parses_a_real_response() {
    // `materialize` should turn the assistant's JSON content into a `Movie`.
    let mut srv = mockito::Server::new_async().await;
    let reply = chat_completion(r#"{"title":"Inception","year":2010}"#);
    let mock = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(reply)
        .expect(1)
        .create_async()
        .await;

    let expected = Movie {
        title: "Inception".into(),
        year: 2010,
    };
    let movie: Movie = client(&srv)
        .materialize("Describe Inception")
        .await
        .unwrap();

    assert_eq!(movie, expected);
    mock.assert_async().await;
}
#[tokio::test]
async fn reask_loop_recovers_from_validation_failure() {
    // A first answer that fails `validate_movie` must trigger a re-ask, and the
    // second (valid) answer is what `materialize` returns. Both mocks share one
    // route, so the test relies on them being consumed in creation order.
    let mut srv = mockito::Server::new_async().await;

    // Turn 1: year 1700 is rejected by `validate_movie` (< 1888).
    let rejected_turn = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(chat_completion(r#"{"title":"Old","year":1700}"#))
        .expect(1)
        .create_async()
        .await;

    // Turn 2: a year that passes validation.
    let accepted_turn = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(chat_completion(r#"{"title":"Metropolis","year":1927}"#))
        .expect(1)
        .create_async()
        .await;

    let movie: Movie = client(&srv).materialize("a film").await.unwrap();
    assert_eq!(movie.year, 1927);

    rejected_turn.assert_async().await;
    accepted_turn.assert_async().await;
}
#[tokio::test]
async fn retryable_status_is_retried() {
    // A 429 must be retried, and the follow-up 200 is what `materialize`
    // returns. `retry-after: 0` presumably keeps any back-off short.
    let mut srv = mockito::Server::new_async().await;

    let throttled = srv
        .mock("POST", "/chat/completions")
        .with_status(429)
        .with_header("retry-after", "0")
        .with_body("{}")
        .expect(1)
        .create_async()
        .await;

    let success = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(chat_completion(r#"{"title":"Dune","year":2021}"#))
        .expect(1)
        .create_async()
        .await;

    let movie: Movie = client(&srv).materialize("a film").await.unwrap();
    assert_eq!(movie.title, "Dune");

    throttled.assert_async().await;
    success.assert_async().await;
}
#[tokio::test]
async fn auth_error_is_surfaced_and_not_retried() {
    // A 401 must surface as `ApiErrorKind::AuthenticationFailed`. `expect(1)`
    // plus `assert_async` pin that the request was sent only once.
    let mut srv = mockito::Server::new_async().await;
    let unauthorized = srv
        .mock("POST", "/chat/completions")
        .with_status(401)
        .with_body(r#"{"error":{"message":"invalid api key"}}"#)
        .expect(1)
        .create_async()
        .await;

    let outcome = client(&srv).materialize::<Movie>("a film").await;
    let err = outcome.unwrap_err();
    let kind = err.api_error_kind();
    assert!(
        matches!(kind, Some(ApiErrorKind::AuthenticationFailed)),
        "expected AuthenticationFailed, got {err:?}"
    );

    unauthorized.assert_async().await;
}
#[tokio::test]
async fn generate_with_metadata_parses_content_and_usage() {
let mut server = mockito::Server::new_async().await;
let body = json!({
"choices": [{
"message": { "role": "assistant", "content": "hello there" },
"finish_reason": "stop",
}],
"usage": { "prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8 },
"model": "gpt-4o-mini",
})
.to_string();
let captured: std::sync::Arc<std::sync::Mutex<Vec<Value>>> =
std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let sink = captured.clone();
let m = server
.mock("POST", "/chat/completions")
.match_request(move |req| {
if let Ok(b) = req.utf8_lossy_body()
&& let Ok(v) = serde_json::from_str::<Value>(&b)
{
sink.lock().unwrap().push(v);
}
true
})
.with_status(200)
.with_body(body)
.expect(1)
.create_async()
.await;
let result = client(&server).generate_with_metadata("hi").await.unwrap();
assert_eq!(result.text, "hello there");
let usage = result.usage.expect("usage should be parsed");
assert_eq!(usage.input_tokens, 3);
assert_eq!(usage.output_tokens, 5);
assert_eq!(usage.total_tokens(), 8);
m.assert_async().await;
let bodies = captured.lock().unwrap();
assert_eq!(bodies.len(), 1, "expected exactly one request");
assert!(
bodies[0].get("response_format").is_none(),
"response_format must be absent for plain generation, got {}",
bodies[0]
);
}
#[tokio::test]
async fn generate_returns_text_content() {
    // `generate` hands back the assistant message content as-is.
    let mut srv = mockito::Server::new_async().await;
    let reply = chat_completion("plain answer");
    let mock = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(reply)
        .expect(1)
        .create_async()
        .await;

    let answer = client(&srv).generate("hi").await.unwrap();

    assert_eq!(answer, "plain answer");
    mock.assert_async().await;
}
#[tokio::test]
async fn generate_empty_choices_is_unexpected_response() {
    // A 200 with an empty `choices` array must be reported as
    // `ApiErrorKind::UnexpectedResponse`.
    let mut srv = mockito::Server::new_async().await;
    let no_choices = json!({ "choices": [] }).to_string();
    let mock = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(no_choices)
        .expect(1)
        .create_async()
        .await;

    let err = client(&srv).generate("hi").await.unwrap_err();
    let kind = err.api_error_kind();
    assert!(
        matches!(kind, Some(ApiErrorKind::UnexpectedResponse { .. })),
        "expected UnexpectedResponse, got {err:?}"
    );

    mock.assert_async().await;
}
#[tokio::test]
async fn generate_null_content_is_unexpected_response() {
    // An assistant message whose `content` is null must be reported as
    // `ApiErrorKind::UnexpectedResponse`.
    let mut srv = mockito::Server::new_async().await;
    let message = json!({ "role": "assistant", "content": Value::Null });
    let reply = json!({
        "choices": [{ "message": message, "finish_reason": "stop" }]
    })
    .to_string();
    let mock = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(reply)
        .expect(1)
        .create_async()
        .await;

    let err = client(&srv).generate("hi").await.unwrap_err();
    let kind = err.api_error_kind();
    assert!(
        matches!(kind, Some(ApiErrorKind::UnexpectedResponse { .. })),
        "expected UnexpectedResponse, got {err:?}"
    );

    mock.assert_async().await;
}
#[tokio::test]
async fn gpt5_sends_reasoning_effort_and_forces_temperature_one() {
    // gpt-5 requests must carry `reasoning_effort: "medium"` and a temperature
    // of 1.0. A non-default temperature (0.2) is configured on purpose: without
    // it, a client that merely *defaults* to 1.0 would pass this test, and the
    // "forces" part of the contract would go unchecked.
    let mut server = mockito::Server::new_async().await;
    // The matcher pins only these two fields of the request body.
    let m = server
        .mock("POST", "/chat/completions")
        .match_body(mockito::Matcher::PartialJson(json!({
            "reasoning_effort": "medium",
            "temperature": 1.0,
        })))
        .with_status(200)
        .with_body(chat_completion("ok"))
        .expect(1)
        .create_async()
        .await;
    let text = OpenAIClient::new("test-key")
        .unwrap()
        .base_url(server.url())
        .model("gpt-5")
        .temperature(0.2)
        .generate("hi")
        .await
        .unwrap();
    assert_eq!(text, "ok");
    m.assert_async().await;
}
#[tokio::test]
async fn non_gpt5_omits_reasoning_effort_and_passes_temperature_through() {
let mut server = mockito::Server::new_async().await;
let captured: std::sync::Arc<std::sync::Mutex<Vec<Value>>> =
std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let sink = captured.clone();
let m = server
.mock("POST", "/chat/completions")
.match_request(move |req| {
if let Ok(body) = req.utf8_lossy_body()
&& let Ok(v) = serde_json::from_str::<Value>(&body)
{
sink.lock().unwrap().push(v);
}
true
})
.with_status(200)
.with_body(chat_completion("ok"))
.expect(1)
.create_async()
.await;
let text = OpenAIClient::new("test-key")
.unwrap()
.base_url(server.url())
.model("gpt-4o-mini")
.temperature(0.2)
.generate("hi")
.await
.unwrap();
assert_eq!(text, "ok");
m.assert_async().await;
let bodies = captured.lock().unwrap();
assert_eq!(bodies.len(), 1, "expected exactly one request");
let body = &bodies[0];
assert!(
body.get("reasoning_effort").is_none(),
"reasoning_effort must be omitted for non-gpt-5, got {body}"
);
assert_eq!(
body["temperature"],
json!(0.2),
"configured temperature must pass through unchanged"
);
}
#[tokio::test]
async fn list_models_keeps_only_chat_models() {
let mut server = mockito::Server::new_async().await;
let body = json!({
"data": [
{ "id": "gpt-4o" },
{ "id": "o3" },
{ "id": "o4-mini" },
{ "id": "o1-pro" },
{ "id": "whisper-1" },
{ "id": "text-embedding-3-small" },
{ "id": "dall-e-3" },
]
})
.to_string();
let m = server
.mock("GET", "/models")
.with_status(200)
.with_body(body)
.expect(1)
.create_async()
.await;
let models = client(&server).list_models().await.unwrap();
let ids: Vec<&str> = models.iter().map(|m| m.id.as_str()).collect();
assert_eq!(ids, vec!["gpt-4o", "o3", "o4-mini", "o1-pro"]);
m.assert_async().await;
}
#[tokio::test]
async fn list_models_no_data_returns_empty() {
    // A `/models` response without a `data` key yields an empty list rather
    // than an error.
    let mut srv = mockito::Server::new_async().await;
    let empty_object = "{}";
    let mock = srv
        .mock("GET", "/models")
        .with_status(200)
        .with_body(empty_object)
        .expect(1)
        .create_async()
        .await;

    let models = client(&srv).list_models().await.unwrap();

    assert!(models.is_empty(), "expected empty list, got {models:?}");
    mock.assert_async().await;
}
#[tokio::test]
async fn usage_model_name_falls_back_to_client_model() {
    // The response carries `usage` but no `model`, so the usage record must
    // report the model the client was configured with.
    let configured_model = "gpt-4o-mini";
    let mut srv = mockito::Server::new_async().await;
    let reply = json!({
        "choices": [{
            "message": { "role": "assistant", "content": "hi" },
            "finish_reason": "stop",
        }],
        "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 },
    })
    .to_string();
    let mock = srv
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(reply)
        .expect(1)
        .create_async()
        .await;

    let result = OpenAIClient::new("test-key")
        .unwrap()
        .base_url(srv.url())
        .model(configured_model)
        .generate_with_metadata("hi")
        .await
        .unwrap();

    let usage = result.usage.expect("usage should be parsed");
    assert_eq!(usage.model, configured_model);
    mock.assert_async().await;
}
#[cfg(feature = "tools")]
/// Serializes a Chat Completions response in which the assistant (with null
/// text content) requests one function call, finished with `"tool_calls"`.
///
/// `args` is passed through verbatim as the call's JSON-encoded `arguments`.
fn tool_call_response(call_id: &str, name: &str, args: &str) -> String {
    // Built innermost-first: function -> tool call -> message -> envelope.
    let function = json!({ "name": name, "arguments": args });
    let tool_call = json!({ "id": call_id, "type": "function", "function": function });
    let message = json!({
        "role": "assistant",
        "content": null,
        "tool_calls": [tool_call],
    });
    json!({
        "choices": [{ "message": message, "finish_reason": "tool_calls" }]
    })
    .to_string()
}
#[cfg(feature = "tools")]
/// Builds an `add` tool that returns `{"sum": a + b}` and raises `flag` each
/// time it is invoked, so tests can tell whether the real closure ran.
fn recording_add_tool(
    flag: std::sync::Arc<std::sync::atomic::AtomicBool>,
) -> rstructor::FnTool<
    AddArgs,
    impl Fn(AddArgs) -> std::future::Ready<rstructor::Result<Value>> + Clone,
> {
    let on_call = move |args: AddArgs| -> std::future::Ready<rstructor::Result<Value>> {
        flag.store(true, std::sync::atomic::Ordering::SeqCst);
        let sum = args.a + args.b;
        std::future::ready(Ok(json!({ "sum": sum })))
    };
    rstructor::FnTool::new("add", "Add two integers", on_call)
}
#[cfg(feature = "tools")]
// Argument type for the `add` test tool. The `#[llm(description)]` strings
// presumably surface in the tool's JSON schema (TODO confirm). Plain `//`
// comments are used so nothing here is visible to the derive macros.
#[derive(Instructor, Serialize, Deserialize)]
struct AddArgs {
    #[llm(description = "First addend")]
    a: i64,
    #[llm(description = "Second addend")]
    b: i64,
}
#[cfg(feature = "tools")]
#[tokio::test]
async fn tool_loop_full_round_trip() {
    // End-to-end tool loop: turn 1 answers with an `add` tool call, the real
    // closure runs, and turn 2 (which must carry a `role: "tool"` message for
    // call `c1` containing the sum) returns the final text.
    use rstructor::{RequestExt, Toolbox};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};
    let mut server = mockito::Server::new_async().await;
    // Parseable request bodies are recorded here. mockito may evaluate both
    // matchers for a request, so entries can repeat; the assertions below only
    // search this list.
    let captured: Arc<Mutex<Vec<Value>>> = Arc::new(Mutex::new(Vec::new()));
    let sink1 = Arc::clone(&captured);
    let first = server
        .mock("POST", "/chat/completions")
        // Turn 1: the conversation has no tool results yet.
        .match_request(move |req| {
            // Parse defensively: this closure runs inside mockito's request
            // matching, so a panic here would fire there rather than in the
            // test body. An unreadable or non-JSON body simply doesn't match.
            let Some(v) = req
                .utf8_lossy_body()
                .ok()
                .and_then(|raw| serde_json::from_str::<Value>(&raw).ok())
            else {
                return false;
            };
            let matched = !messages_contain_tool_role(&v);
            sink1.lock().unwrap().push(v);
            matched
        })
        .with_status(200)
        .with_body(tool_call_response("c1", "add", r#"{"a":2,"b":3}"#))
        .expect(1)
        .create_async()
        .await;
    let sink2 = Arc::clone(&captured);
    let second = server
        .mock("POST", "/chat/completions")
        // Turn 2: the conversation now carries the tool's result.
        .match_request(move |req| {
            let Some(v) = req
                .utf8_lossy_body()
                .ok()
                .and_then(|raw| serde_json::from_str::<Value>(&raw).ok())
            else {
                return false;
            };
            let matched = messages_contain_tool_role(&v);
            sink2.lock().unwrap().push(v);
            matched
        })
        .with_status(200)
        .with_body(chat_completion("the sum is 5"))
        .expect(1)
        .create_async()
        .await;
    let invoked = Arc::new(AtomicBool::new(false));
    let toolbox = Toolbox::new().with(recording_add_tool(Arc::clone(&invoked)));
    let answer = client(&server)
        .with_tools(&toolbox)
        .run("add 2 and 3")
        .await
        .unwrap();
    assert_eq!(answer, "the sum is 5");
    assert!(
        invoked.load(Ordering::SeqCst),
        "the real tool closure must have run"
    );
    first.assert_async().await;
    second.assert_async().await;
    let bodies = captured.lock().unwrap();
    let second_body = bodies
        .iter()
        .find(|v| messages_contain_tool_role(v))
        .expect("a request carrying the tool result must exist");
    let messages = second_body["messages"].as_array().unwrap();
    let tool_msg = messages
        .iter()
        .find(|m| m["role"] == json!("tool"))
        .expect("a role:tool message must be present");
    assert_eq!(tool_msg["tool_call_id"], json!("c1"));
    let content = tool_msg["content"].as_str().unwrap();
    assert!(
        content.contains("\"sum\":5"),
        "tool result content should carry the sum, got {content}"
    );
}
#[cfg(feature = "tools")]
/// Returns `true` when `body["messages"]` is an array containing at least one
/// entry whose `role` is the string `"tool"`; `false` otherwise (including when
/// `messages` is missing or not an array).
fn messages_contain_tool_role(body: &Value) -> bool {
    let Some(msgs) = body.get("messages").and_then(Value::as_array) else {
        return false;
    };
    msgs.iter()
        .any(|msg| msg.get("role").and_then(Value::as_str) == Some("tool"))
}
#[cfg(feature = "tools")]
#[tokio::test]
async fn tool_loop_unknown_tool_continues() {
    // A call to a tool that isn't in the toolbox must not abort the loop. The
    // next request carries a `role: "tool"` message naming the unknown tool,
    // and the registered `add` tool must not run.
    use rstructor::{RequestExt, Toolbox};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};
    let mut server = mockito::Server::new_async().await;
    // Parseable request bodies are recorded here (possibly more than once per
    // request, since mockito may evaluate both matchers).
    let captured: Arc<Mutex<Vec<Value>>> = Arc::new(Mutex::new(Vec::new()));
    let sink1 = Arc::clone(&captured);
    let first = server
        .mock("POST", "/chat/completions")
        // Turn 1: no tool results in the conversation yet.
        .match_request(move |req| {
            // Parse defensively: a panic here would fire inside mockito's
            // request matching rather than in the test body.
            let Some(v) = req
                .utf8_lossy_body()
                .ok()
                .and_then(|raw| serde_json::from_str::<Value>(&raw).ok())
            else {
                return false;
            };
            let matched = !messages_contain_tool_role(&v);
            sink1.lock().unwrap().push(v);
            matched
        })
        .with_status(200)
        .with_body(tool_call_response("c1", "does_not_exist", "{}"))
        .expect(1)
        .create_async()
        .await;
    let sink2 = Arc::clone(&captured);
    let second = server
        .mock("POST", "/chat/completions")
        // Turn 2: the error result for the unknown tool has been appended.
        .match_request(move |req| {
            let Some(v) = req
                .utf8_lossy_body()
                .ok()
                .and_then(|raw| serde_json::from_str::<Value>(&raw).ok())
            else {
                return false;
            };
            let matched = messages_contain_tool_role(&v);
            sink2.lock().unwrap().push(v);
            matched
        })
        .with_status(200)
        .with_body(chat_completion("recovered"))
        .expect(1)
        .create_async()
        .await;
    let invoked = Arc::new(AtomicBool::new(false));
    let toolbox = Toolbox::new().with(recording_add_tool(Arc::clone(&invoked)));
    let answer = client(&server)
        .with_tools(&toolbox)
        .run("call a missing tool")
        .await
        .unwrap();
    assert_eq!(answer, "recovered");
    assert!(
        !invoked.load(Ordering::SeqCst),
        "the real add tool must NOT have run for an unknown tool"
    );
    first.assert_async().await;
    second.assert_async().await;
    let bodies = captured.lock().unwrap();
    let second_body = bodies
        .iter()
        .find(|v| messages_contain_tool_role(v))
        .expect("a request carrying the error result must exist");
    let messages = second_body["messages"].as_array().unwrap();
    let tool_msg = messages
        .iter()
        .find(|m| m["role"] == json!("tool"))
        .expect("a role:tool message must be present");
    let content = tool_msg["content"].as_str().unwrap();
    assert!(
        content.contains("unknown tool: does_not_exist"),
        "error content should name the unknown tool, got {content}"
    );
}
#[cfg(feature = "tools")]
#[tokio::test]
async fn tool_loop_tool_error_is_swallowed() {
    // A tool that returns `Err` must not abort the loop. Its error text is fed
    // back to the model as the `role: "tool"` content of the next request.
    use rstructor::{FnTool, RequestExt, Toolbox};
    use std::sync::{Arc, Mutex};
    let mut server = mockito::Server::new_async().await;
    // Parseable request bodies are recorded here (possibly more than once per
    // request, since mockito may evaluate both matchers).
    let captured: Arc<Mutex<Vec<Value>>> = Arc::new(Mutex::new(Vec::new()));
    let sink1 = Arc::clone(&captured);
    let first = server
        .mock("POST", "/chat/completions")
        // Turn 1: no tool results in the conversation yet.
        .match_request(move |req| {
            // Parse defensively: a panic here would fire inside mockito's
            // request matching rather than in the test body.
            let Some(v) = req
                .utf8_lossy_body()
                .ok()
                .and_then(|raw| serde_json::from_str::<Value>(&raw).ok())
            else {
                return false;
            };
            let matched = !messages_contain_tool_role(&v);
            sink1.lock().unwrap().push(v);
            matched
        })
        .with_status(200)
        .with_body(tool_call_response("c1", "boom", r#"{"a":1,"b":1}"#))
        .expect(1)
        .create_async()
        .await;
    let sink2 = Arc::clone(&captured);
    let second = server
        .mock("POST", "/chat/completions")
        // Turn 2: the failing tool's error has been appended as a tool result.
        .match_request(move |req| {
            let Some(v) = req
                .utf8_lossy_body()
                .ok()
                .and_then(|raw| serde_json::from_str::<Value>(&raw).ok())
            else {
                return false;
            };
            let matched = messages_contain_tool_role(&v);
            sink2.lock().unwrap().push(v);
            matched
        })
        .with_status(200)
        .with_body(chat_completion("handled"))
        .expect(1)
        .create_async()
        .await;
    let boom = FnTool::new("boom", "always fails", |_args: AddArgs| {
        std::future::ready(Err(RStructorError::ValidationError(
            "tool blew up".to_string(),
        )))
    });
    let toolbox = Toolbox::new().with(boom);
    let answer = client(&server)
        .with_tools(&toolbox)
        .run("trigger the failing tool")
        .await
        .unwrap();
    assert_eq!(answer, "handled");
    first.assert_async().await;
    second.assert_async().await;
    let bodies = captured.lock().unwrap();
    let second_body = bodies
        .iter()
        .find(|v| messages_contain_tool_role(v))
        .expect("a request carrying the error result must exist");
    let messages = second_body["messages"].as_array().unwrap();
    let tool_msg = messages
        .iter()
        .find(|m| m["role"] == json!("tool"))
        .expect("a role:tool message must be present");
    let content = tool_msg["content"].as_str().unwrap();
    assert!(
        content.contains("error"),
        "swallowed tool error should appear in the content, got {content}"
    );
    assert!(
        content.contains("tool blew up"),
        "the tool's error message should be preserved, got {content}"
    );
}
#[cfg(feature = "tools")]
#[tokio::test]
async fn tool_loop_exhaustion_errors() {
    // With `max_iterations(2)` and a server that answers every turn with yet
    // another tool call, `run` must give up with a `ValidationError` saying it
    // did not converge and mentioning the budget.
    use rstructor::{RequestExt, Toolbox};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    let mut server = mockito::Server::new_async().await;
    let always_tool = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_body(tool_call_response("c1", "add", r#"{"a":1,"b":1}"#))
        .expect(2)
        .create_async()
        .await;
    let invoked = Arc::new(AtomicBool::new(false));
    let toolbox = Toolbox::new().with(recording_add_tool(Arc::clone(&invoked)));
    let err = client(&server)
        .with_tools(&toolbox)
        .max_iterations(2)
        .run("loop forever")
        .await
        .unwrap_err();
    let msg = err.to_string();
    assert!(
        matches!(err, RStructorError::ValidationError(_)),
        "expected ValidationError, got {err:?}"
    );
    assert!(
        msg.contains("did not converge"),
        "error should say it did not converge, got: {msg}"
    );
    assert!(
        msg.contains('2'),
        "error should mention the iteration budget (2), got: {msg}"
    );
    // The second request (pinned by `expect(2)`) should carry the result of
    // the first `add` call, so the registered tool must have actually run.
    // This guards against a loop that burns its budget without dispatching.
    assert!(
        invoked.load(Ordering::SeqCst),
        "the add tool should have run before the budget was exhausted"
    );
    always_tool.assert_async().await;
}