use open_ai_rust::{ChatMessage, Client, OpenAiError, OpenAiModel, PayLoadBuilder};
use serde_json::json;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Builds a test `Client` aimed at the mock `server` with the dummy API key `"test"`.
///
/// No `max_retries` override is applied, so the builder's default retry policy is used.
fn client(server: &MockServer) -> Client {
    let mock_uri = server.uri();
    Client::builder()
        .api_key("test")
        .base_url(mock_uri)
        .build_unchecked()
}
/// Builds a test `Client` aimed at the mock `server` with the dummy API key `"test"`
/// and `max_retries(0)`, so a failing request is surfaced rather than retried.
fn client_no_retry(server: &MockServer) -> Client {
    let mock_uri = server.uri();
    Client::builder()
        .api_key("test")
        .base_url(mock_uri)
        .max_retries(0)
        .build_unchecked()
}
/// Minimal chat request used by every test: model `GPT4oMini`, one user message `"hi"`.
fn chat_payload() -> open_ai_rust::ChatPayLoad {
    let conversation = vec![ChatMessage::user("hi")];
    PayLoadBuilder::new(OpenAiModel::GPT4oMini)
        .messages(conversation)
        .build()
}
/// Mounts a `POST /chat/completions` mock that answers with `status` and the JSON
/// `body`, sends one chat request through a no-retry client, and returns the error.
///
/// Panics (via `unwrap_err`) if the request unexpectedly succeeds.
async fn expect_api_status(status: u16, body: serde_json::Value) -> OpenAiError {
    let server = MockServer::start().await;
    let canned = ResponseTemplate::new(status).set_body_json(body);
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(canned)
        .mount(&server)
        .await;
    let no_retry = client_no_retry(&server);
    let outcome = no_retry.chat().create(chat_payload()).await;
    outcome.unwrap_err()
}
/// A 401 carrying a full OpenAI error envelope must surface as `OpenAiError::Api`
/// with the HTTP status and the envelope's `code`, `type` and `message`.
#[tokio::test]
async fn error_401_unauthorized_surfaces_as_api_error() {
    let envelope = json!({
        "error": {
            "message": "Invalid API key",
            "type": "invalid_request_error",
            "code": "invalid_api_key"
        }
    });
    let err = expect_api_status(401, envelope).await;
    let (status, code, type_, message) = match err {
        OpenAiError::Api {
            status,
            code,
            type_,
            message,
            ..
        } => (status, code, type_, message),
        e => panic!("expected Api error, got {e:?}"),
    };
    assert_eq!(status, 401);
    assert_eq!(code.as_deref(), Some("invalid_api_key"));
    assert_eq!(type_.as_deref(), Some("invalid_request_error"));
    assert!(message.contains("Invalid API key"));
}
/// A 403 (permission error envelope) must surface as `OpenAiError::Api` with status 403.
#[tokio::test]
async fn error_403_permission_denied() {
    let envelope = json!({
        "error": { "message": "Country not supported", "type": "permission_error" }
    });
    let err = expect_api_status(403, envelope).await;
    if let OpenAiError::Api { status, .. } = err {
        assert_eq!(status, 403);
    } else {
        panic!("expected Api error");
    }
}
/// A 404 is not retried even with `max_retries(3)`: the mock expects exactly one
/// request, which wiremock verifies when `server` is dropped. The surfaced error
/// must carry status 404 and the envelope's `code`.
#[tokio::test]
async fn error_404_not_found_does_not_retry() {
    let server = MockServer::start().await;
    let not_found = ResponseTemplate::new(404).set_body_json(json!({
        "error": { "message": "model not found", "code": "model_not_found" }
    }));
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(not_found)
        .expect(1)
        .mount(&server)
        .await;
    let retrying = Client::builder()
        .api_key("test")
        .base_url(server.uri())
        .max_retries(3)
        .build_unchecked();
    match retrying.chat().create(chat_payload()).await.unwrap_err() {
        OpenAiError::Api { status, code, .. } => {
            assert_eq!(status, 404);
            assert_eq!(code.as_deref(), Some("model_not_found"));
        }
        _ => panic!("expected Api error"),
    }
}
/// Two 429 responses followed by a success: with `max_retries(5)` the client must
/// keep retrying and finally return the `"ok"` completion.
#[tokio::test]
async fn error_429_rate_limit_retries_until_success() {
    let server = MockServer::start().await;
    let rate_limited = ResponseTemplate::new(429).set_body_json(json!({
        "error": { "message": "rate limited", "type": "rate_limit_exceeded" }
    }));
    let completion = ResponseTemplate::new(200).set_body_json(json!({
        "id": "x", "object": "chat.completion", "created": 1, "model": "m",
        "choices": [{ "finish_reason": "stop", "index": 0, "message": { "role": "assistant", "content": "ok" } }],
        "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 },
    }));
    // Mounted first and capped at two responses; afterwards the success mock answers.
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(rate_limited)
        .up_to_n_times(2)
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(completion)
        .mount(&server)
        .await;
    let retrying = Client::builder()
        .api_key("test")
        .base_url(server.uri())
        .max_retries(5)
        .build_unchecked();
    let response = retrying.chat().create(chat_payload()).await.unwrap();
    assert_eq!(response.get_last_msg_text().as_deref(), Some("ok"));
}
#[tokio::test]
async fn error_429_exhausts_retries_then_surfaces() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/chat/completions"))
.respond_with(ResponseTemplate::new(429).set_body_json(json!({
"error": { "message": "still rate limited" }
})))
.expect(1)
.mount(&server)
.await;
let c = client_no_retry(&server);
let err = c.chat().create(chat_payload()).await.unwrap_err();
match err {
OpenAiError::Api { status: 429, .. } => {}
e => panic!("expected 429 Api error, got {e:?}"),
}
}
/// Two plain-text 500 responses followed by a success: with `max_retries(5)` the
/// client must retry and finally return the `"ok"` completion.
#[tokio::test]
async fn error_500_internal_retries() {
    let server = MockServer::start().await;
    let boom = ResponseTemplate::new(500).set_body_string("boom");
    let completion = ResponseTemplate::new(200).set_body_json(json!({
        "id": "x", "object": "chat.completion", "created": 1, "model": "m",
        "choices": [{ "finish_reason": "stop", "index": 0, "message": { "role": "assistant", "content": "ok" } }],
        "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 },
    }));
    // Mounted first and capped at two responses; afterwards the success mock answers.
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(boom)
        .up_to_n_times(2)
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(completion)
        .mount(&server)
        .await;
    let retrying = Client::builder()
        .api_key("test")
        .base_url(server.uri())
        .max_retries(5)
        .build_unchecked();
    let response = retrying.chat().create(chat_payload()).await.unwrap();
    assert_eq!(response.get_last_msg_text().as_deref(), Some("ok"));
}
/// For each of 502, 503 and 504, a single failing response followed by a success
/// must still yield the `"ok"` completion when the client has `max_retries(2)`.
/// Each status gets its own fresh mock server.
#[tokio::test]
async fn error_502_503_504_treated_as_retryable() {
    for status in [502_u16, 503, 504] {
        let server = MockServer::start().await;
        let upstream_failure = ResponseTemplate::new(status).set_body_string("upstream");
        let completion = ResponseTemplate::new(200).set_body_json(json!({
            "id": "x", "object": "chat.completion", "created": 1, "model": "m",
            "choices": [{ "finish_reason": "stop", "index": 0, "message": { "role": "assistant", "content": "ok" } }],
            "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 },
        }));
        // Mounted first and capped at one response; the retry reaches the success mock.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(upstream_failure)
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(completion)
            .mount(&server)
            .await;
        let retrying = Client::builder()
            .api_key("test")
            .base_url(server.uri())
            .max_retries(2)
            .build_unchecked();
        let response = retrying.chat().create(chat_payload()).await.unwrap();
        assert_eq!(
            response.get_last_msg_text().as_deref(),
            Some("ok"),
            "status {status}"
        );
    }
}
/// Client errors (400, 401, 403, 404, 422) are never retried, even with `max_retries(5)`.
///
/// Each mock expects exactly one request. Wiremock verifies this when that
/// iteration's `server` is dropped. The surfaced `OpenAiError::Api` must echo the
/// status that was sent.
#[tokio::test]
async fn error_400_4xx_does_not_retry() {
    for status in [400_u16, 401, 403, 404, 422] {
        let server = MockServer::start().await;
        let client_error = ResponseTemplate::new(status).set_body_json(json!({
            "error": { "message": "client error" }
        }));
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(client_error)
            .expect(1)
            .mount(&server)
            .await;
        let retrying = Client::builder()
            .api_key("test")
            .base_url(server.uri())
            .max_retries(5)
            .build_unchecked();
        match retrying.chat().create(chat_payload()).await.unwrap_err() {
            OpenAiError::Api { status: s, .. } => assert_eq!(s, status),
            _ => panic!("expected Api error for status {status}"),
        }
    }
}
/// A 200 whose body is not valid JSON, despite declaring `application/json`, must be
/// reported as `OpenAiError::Decode`.
#[tokio::test]
async fn malformed_json_in_success_body_yields_decode_error() {
    let server = MockServer::start().await;
    let broken = ResponseTemplate::new(200)
        .insert_header("content-type", "application/json")
        .set_body_string("{ this is not json");
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(broken)
        .mount(&server)
        .await;
    let c = client_no_retry(&server);
    let e = c.chat().create(chat_payload()).await.unwrap_err();
    if !matches!(e, OpenAiError::Decode(_)) {
        panic!("expected Decode error, got {e:?}");
    }
}
/// A 500 with an HTML (non-envelope) body still surfaces as `OpenAiError::Api`.
/// The raw text ends up in `message`, and no `code` is extracted.
#[tokio::test]
async fn error_body_without_envelope_is_surfaced_with_raw_text() {
    let server = MockServer::start().await;
    let html_failure = ResponseTemplate::new(500)
        .insert_header("content-type", "text/html")
        .set_body_string("<html>Internal Server Error</html>");
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(html_failure)
        .expect(1)
        .mount(&server)
        .await;
    let c = client_no_retry(&server);
    match c.chat().create(chat_payload()).await.unwrap_err() {
        OpenAiError::Api {
            status,
            message,
            code,
            ..
        } => {
            assert_eq!(status, 500);
            assert!(message.contains("Internal Server Error"));
            assert!(code.is_none(), "no code parsed from non-envelope body");
        }
        _ => panic!("expected Api error"),
    }
}
/// A 200 with an empty body must be reported as `OpenAiError::Decode`.
#[tokio::test]
async fn empty_success_body_yields_decode_error() {
    let server = MockServer::start().await;
    let empty = ResponseTemplate::new(200).set_body_string("");
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(empty)
        .mount(&server)
        .await;
    let c = client_no_retry(&server);
    let e = c.chat().create(chat_payload()).await.unwrap_err();
    if !matches!(e, OpenAiError::Decode(_)) {
        panic!("expected Decode, got {e:?}");
    }
}
/// A 200 whose JSON has only `id` and `object` must be reported as
/// `OpenAiError::Decode`. It lacks `choices`, `usage` and the other
/// completion fields.
#[tokio::test]
async fn truncated_response_missing_required_field_yields_decode_error() {
    let server = MockServer::start().await;
    let truncated = ResponseTemplate::new(200).set_body_json(json!({
        "id": "x",
        "object": "chat.completion",
    }));
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(truncated)
        .mount(&server)
        .await;
    let c = client_no_retry(&server);
    let e = c.chat().create(chat_payload()).await.unwrap_err();
    if !matches!(e, OpenAiError::Decode(_)) {
        panic!("expected Decode for truncated, got {e:?}");
    }
}
/// An error envelope carrying only `message` must still parse into `OpenAiError::Api`.
/// The message must be intact, and both `code` and `type_` must be `None`.
#[tokio::test]
async fn error_envelope_with_only_message_field_still_parses() {
    let server = MockServer::start().await;
    let sparse = ResponseTemplate::new(500).set_body_json(json!({
        "error": { "message": "only message, no code or type" }
    }));
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(sparse)
        .expect(1)
        .mount(&server)
        .await;
    let c = client_no_retry(&server);
    let err = c.chat().create(chat_payload()).await.unwrap_err();
    let (status, message, code, type_) = match err {
        OpenAiError::Api {
            status,
            message,
            code,
            type_,
            ..
        } => (status, message, code, type_),
        _ => panic!("expected Api error"),
    };
    assert_eq!(status, 500);
    assert_eq!(message, "only message, no code or type");
    assert!(code.is_none());
    assert!(type_.is_none());
}
/// An SSE `data:` event holding invalid JSON must make the first stream item an
/// `Err(OpenAiError::Stream(..))` whose message mentions "failed to decode".
#[tokio::test]
async fn streaming_malformed_json_in_event_yields_stream_error() {
    use futures_util::StreamExt;
    let sse = String::from("data: { not valid json\n\ndata: [DONE]\n\n");
    let server = MockServer::start().await;
    let event_stream = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse.into_bytes(), "text/event-stream");
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(event_stream)
        .mount(&server)
        .await;
    let c = client(&server);
    let mut events = c.chat().create_stream(chat_payload()).await.unwrap();
    let first = events.next().await.unwrap();
    if let Err(OpenAiError::Stream(msg)) = &first {
        assert!(msg.contains("failed to decode"), "got: {msg}");
    } else {
        panic!("expected Stream error, got {first:?}");
    }
}
/// A stream that sends a single chunk and then closes without `data: [DONE]` must
/// yield that chunk (delta text `"hi"`), and then end (`None`) rather than erroring
/// or hanging.
#[tokio::test]
async fn streaming_premature_eof_terminates_without_done() {
    use futures_util::StreamExt;
    let chunk = json!({
        "id": "x", "object": "chat.completion.chunk",
        "created": 1, "model": "m",
        "choices": [{ "index": 0, "delta": { "content": "hi" }, "finish_reason": null }]
    });
    let sse = format!("data: {chunk}\n\n");
    let server = MockServer::start().await;
    let event_stream = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse.into_bytes(), "text/event-stream");
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(event_stream)
        .mount(&server)
        .await;
    let c = client(&server);
    let mut events = c.chat().create_stream(chat_payload()).await.unwrap();
    let first = events.next().await.unwrap().unwrap();
    assert_eq!(first.delta_text(), "hi");
    let ended = events.next().await.is_none();
    assert!(ended, "expected stream to end without DONE");
}
/// `Client::from_env` must fail with `OpenAiError::Config` naming `OPENAI_API_KEY`
/// when that variable is absent.
///
/// The test temporarily removes the variable from the process environment. A drop
/// guard puts back any previous value, including a non-UTF-8 one, even if an
/// assertion below panics.
#[test]
fn from_env_without_api_key_returns_config_error() {
    /// Re-sets `key` to `prev` on drop; does nothing if there was no previous value.
    struct RestoreEnv {
        key: &'static str,
        prev: Option<std::ffi::OsString>,
    }
    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            if let Some(v) = self.prev.take() {
                std::env::set_var(self.key, v);
            }
        }
    }

    const KEY: &str = "OPENAI_API_KEY";
    // `var_os` (not `var`) so a non-UTF-8 value is still captured and restored.
    let _restore = RestoreEnv {
        key: KEY,
        prev: std::env::var_os(KEY),
    };
    std::env::remove_var(KEY);
    let err = Client::from_env().unwrap_err();
    match err {
        OpenAiError::Config(msg) => assert!(msg.contains("OPENAI_API_KEY"), "got: {msg}"),
        e => panic!("expected Config, got {e:?}"),
    }
}
/// `Client::builder().build()` with no API key set must fail with `OpenAiError::Config`.
#[test]
fn builder_build_without_api_key_returns_config_error() {
    let e = Client::builder().build().unwrap_err();
    if !matches!(e, OpenAiError::Config(_)) {
        panic!("expected Config, got {e:?}");
    }
}
/// An Azure client built with `Client::azure` must keep the resource endpoint it
/// was given, so requests are routed to that Azure resource.
#[test]
fn azure_requires_deployment_name_for_request_url() {
    let c = Client::azure(
        "k",
        "https://r.openai.azure.com",
        "dep",
        "2024-10-01-preview",
    );
    let base = c.base_url();
    // NOTE(review): assumes `base_url()` returns a string-like value (`&str`/`String`).
    // The deployment segment may only be added per request, so only the host is
    // pinned here.
    assert!(base.contains("r.openai.azure.com"), "got: {base}");
}
/// Converting an `OpenAiError` into a `String` must preserve its message text.
#[test]
fn error_converts_into_string() {
    let s: String = OpenAiError::config("bad").into();
    assert!(s.contains("bad"), "got: {s}");
}