#[cfg(test)]
mod tests {
use crate::client::RetryConfig;
use crate::utils::retry::execute_with_retry_builder;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};
#[tokio::test]
async fn test_retry_after_delta_seconds_delays_next_attempt() {
    let server = MockServer::start().await;
    // First request: 429 with `Retry-After: 1` (delta-seconds), served once.
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(429).insert_header("retry-after", "1"))
        .up_to_n_times(1)
        .mount(&server)
        .await;
    // Any subsequent request succeeds.
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;
    let cfg = RetryConfig {
        max_retries: 2,
        initial_backoff_ms: 50,
        max_backoff_ms: 5000,
        retry_on_status_codes: vec![429],
        total_timeout: Duration::from_secs(10),
        max_retry_interval: Duration::from_secs(30),
    };
    let http = reqwest::Client::new();
    let started = Instant::now();
    let outcome =
        execute_with_retry_builder(&cfg, "retry_after_test", || http.get(server.uri())).await;
    let took = started.elapsed();
    assert!(outcome.is_ok(), "Should succeed after retry: {:?}", outcome);
    // 600ms floor gives slack below the requested 1s while still ruling out
    // the 50ms base backoff being used instead of the header value.
    assert!(
        took >= Duration::from_millis(600),
        "Retry-After header should cause a delay of ~1s, got: {:?}",
        took
    );
}
#[tokio::test]
async fn test_retry_after_zero_seconds_does_not_hang() {
    let server = MockServer::start().await;
    // `Retry-After: 0` must lead to a prompt retry, not a stall.
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(429).insert_header("retry-after", "0"))
        .up_to_n_times(1)
        .mount(&server)
        .await;
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;
    let cfg = RetryConfig {
        max_retries: 2,
        initial_backoff_ms: 50,
        max_backoff_ms: 500,
        retry_on_status_codes: vec![429],
        total_timeout: Duration::from_secs(5),
        max_retry_interval: Duration::from_secs(30),
    };
    let http = reqwest::Client::new();
    let started = Instant::now();
    let outcome =
        execute_with_retry_builder(&cfg, "retry_after_zero", || http.get(server.uri())).await;
    let took = started.elapsed();
    assert!(outcome.is_ok());
    assert!(
        took < Duration::from_secs(2),
        "Retry-After: 0 should not delay significantly, got: {:?}",
        took
    );
}
#[tokio::test]
async fn test_respects_retry_after_http_date() {
    let server = MockServer::start().await;
    // Retry-After may also be an HTTP-date; aim ~2s into the future.
    let resume_at = std::time::SystemTime::now() + Duration::from_secs(2);
    let header_value = httpdate::fmt_http_date(resume_at);
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(429).insert_header("retry-after", &*header_value))
        .up_to_n_times(1)
        .mount(&server)
        .await;
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;
    let cfg = RetryConfig {
        max_retries: 2,
        initial_backoff_ms: 50,
        max_backoff_ms: 10000,
        retry_on_status_codes: vec![429],
        total_timeout: Duration::from_secs(15),
        max_retry_interval: Duration::from_secs(30),
    };
    let http = reqwest::Client::new();
    let started = Instant::now();
    let outcome =
        execute_with_retry_builder(&cfg, "retry_after_http_date", || http.get(server.uri()))
            .await;
    let took = started.elapsed();
    assert!(outcome.is_ok(), "Should succeed after retry: {:?}", outcome);
    // The 400ms floor separates a date-driven wait from the 50ms base backoff
    // while tolerating clock rounding of the HTTP-date (second resolution).
    assert!(
        took >= Duration::from_millis(400),
        "Retry-After HTTP-date should cause notable delay vs 50ms base backoff, got: {:?}",
        took
    );
}
#[tokio::test]
async fn test_retry_after_capped_at_max_backoff() {
    let server = MockServer::start().await;
    // An absurdly large Retry-After must be clamped to max_backoff_ms.
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(429).insert_header("retry-after", "99999"))
        .up_to_n_times(1)
        .mount(&server)
        .await;
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;
    let cfg = RetryConfig {
        max_retries: 2,
        initial_backoff_ms: 50,
        max_backoff_ms: 500,
        retry_on_status_codes: vec![429],
        total_timeout: Duration::from_secs(5),
        max_retry_interval: Duration::from_secs(30),
    };
    let http = reqwest::Client::new();
    let started = Instant::now();
    let outcome =
        execute_with_retry_builder(&cfg, "retry_after_capped", || http.get(server.uri())).await;
    let took = started.elapsed();
    assert!(outcome.is_ok(), "Should succeed after retry: {:?}", outcome);
    assert!(
        took < Duration::from_secs(2),
        "Retry-After 99999 should be capped to max_backoff_ms (500ms), got: {:?}",
        took
    );
}
#[tokio::test]
async fn test_non_retryable_400_not_retried() {
    let server = MockServer::start().await;
    // Count every HTTP attempt made through the builder closure.
    let hits = Arc::new(AtomicUsize::new(0));
    let hits_in_builder = hits.clone();
    Mock::given(matchers::method("GET"))
        .respond_with(
            ResponseTemplate::new(400)
                .set_body_string(r#"{"error": {"message": "Bad request", "code": 400}}"#),
        )
        .mount(&server)
        .await;
    let cfg = RetryConfig {
        max_retries: 3,
        initial_backoff_ms: 50,
        max_backoff_ms: 500,
        retry_on_status_codes: vec![429, 500, 502, 503, 504],
        total_timeout: Duration::from_secs(5),
        max_retry_interval: Duration::from_secs(30),
    };
    let http = reqwest::Client::new();
    let endpoint = server.uri();
    let outcome = execute_with_retry_builder(&cfg, "no_retry_400", || {
        hits_in_builder.fetch_add(1, Ordering::SeqCst);
        http.get(&endpoint)
    })
    .await;
    assert!(
        outcome.is_ok(),
        "execute_with_retry_builder returns Ok for non-retryable status"
    );
    assert_eq!(
        outcome.unwrap().status().as_u16(),
        400,
        "Response must be returned as-is without retry"
    );
    assert_eq!(
        hits.load(Ordering::SeqCst),
        1,
        "400 must not trigger any retries — expected exactly 1 HTTP call"
    );
}
#[tokio::test]
// A 404 is not in retry_on_status_codes, so the client must perform exactly
// one request and hand the response back unchanged.
async fn test_non_retryable_404_not_retried() {
    let mock_server = MockServer::start().await;
    let call_count = Arc::new(AtomicUsize::new(0));
    let call_count_clone = call_count.clone();
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(404))
        .mount(&mock_server)
        .await;
    let config = RetryConfig {
        max_retries: 3,
        initial_backoff_ms: 50,
        max_backoff_ms: 500,
        retry_on_status_codes: vec![429, 500, 502, 503, 504],
        total_timeout: Duration::from_secs(5),
        max_retry_interval: Duration::from_secs(30),
    };
    let client = reqwest::Client::new();
    let url = mock_server.uri();
    // Keep the result instead of discarding it: the original test only counted
    // calls, so a regression that returned an error or a mangled status would
    // still pass. Mirror test_non_retryable_400_not_retried's assertions.
    let result = execute_with_retry_builder(&config, "no_retry_404", || {
        call_count_clone.fetch_add(1, Ordering::SeqCst);
        client.get(&url)
    })
    .await;
    assert!(
        result.is_ok(),
        "execute_with_retry_builder returns Ok for non-retryable status"
    );
    assert_eq!(
        result.unwrap().status().as_u16(),
        404,
        "Response must be returned as-is without retry"
    );
    assert_eq!(
        call_count.load(Ordering::SeqCst),
        1,
        "404 must not trigger retries"
    );
}
#[tokio::test]
async fn test_retry_budget_exhausted_returns_last_response() {
let mock_server = MockServer::start().await;
Mock::given(matchers::method("GET"))
.respond_with(
ResponseTemplate::new(503).set_body_string(
r#"{"error": {"message": "Service unavailable", "code": 503}}"#,
),
)
.mount(&mock_server)
.await;
let config = RetryConfig {
max_retries: 2,
initial_backoff_ms: 50,
max_backoff_ms: 100,
retry_on_status_codes: vec![503],
total_timeout: Duration::from_secs(10),
max_retry_interval: Duration::from_secs(30),
};
let client = reqwest::Client::new();
let result = execute_with_retry_builder(&config, "budget_exhausted", || {
client.get(mock_server.uri())
})
.await;
match result {
Ok(r) => assert_eq!(r.status().as_u16(), 503),
Err(crate::error::Error::TimeoutError(_)) => {
}
Err(other) => panic!("Unexpected error after retry budget exhausted: {:?}", other),
}
}
#[tokio::test]
async fn test_retry_exhausts_exactly_max_retries_attempts() {
    let server = MockServer::start().await;
    // Permanent 500s: expect the initial attempt plus max_retries retries.
    let attempts = Arc::new(AtomicUsize::new(0));
    let attempts_in_builder = attempts.clone();
    Mock::given(matchers::method("GET"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&server)
        .await;
    let cfg = RetryConfig {
        max_retries: 3,
        initial_backoff_ms: 50,
        max_backoff_ms: 200,
        retry_on_status_codes: vec![500],
        total_timeout: Duration::from_secs(10),
        max_retry_interval: Duration::from_secs(30),
    };
    let http = reqwest::Client::new();
    let endpoint = server.uri();
    // Final outcome is irrelevant here; only the attempt count is under test.
    let _ = execute_with_retry_builder(&cfg, "count_retries", || {
        attempts_in_builder.fetch_add(1, Ordering::SeqCst);
        http.get(&endpoint)
    })
    .await;
    assert_eq!(
        attempts.load(Ordering::SeqCst),
        4,
        "Should make exactly 1 initial + 3 retry attempts"
    );
}
#[test]
fn test_sse_done_signal_recognized() {
    // The terminal SSE event is literally `data: [DONE]`.
    let payload = "data: [DONE]".trim_start_matches("data:").trim();
    assert_eq!(payload, "[DONE]");
}
#[test]
fn test_sse_data_prefix_stripped_correctly() {
    // After removing the `data:` prefix, what remains must be bare JSON.
    let raw = "data: {\"id\":\"test\",\"choices\":[]}";
    let stripped = raw.trim_start_matches("data:").trim();
    assert!(
        stripped.starts_with('{'),
        "data_part must be raw JSON after strip"
    );
}
#[test]
fn test_sse_comment_line_identified() {
    // Per the SSE spec, lines beginning with ':' are comments (keep-alives).
    let comment = ": keep-alive";
    assert!(comment.starts_with(':'), "SSE comment lines start with ':'");
}
#[test]
fn test_sse_empty_line_skipped() {
    // A whitespace-only line counts as an event separator, i.e. empty.
    let blank = " ";
    assert!(
        blank.trim().is_empty(),
        "Whitespace-only lines must be treated as empty"
    );
}
#[test]
fn test_streaming_chunk_deserializes_delta_content() {
    use crate::types::chat::{ChatCompletionChunk, MessageContent};
    // A representative mid-stream chunk carrying delta text and no finish reason.
    let payload = r#"{
"id": "chatcmpl-abc",
"object": "chat.completion.chunk",
"created": 1700000000,
"model": "openai/gpt-4",
"choices": [{
"index": 0,
"delta": {"role": "assistant", "content": "Hello"},
"finish_reason": null
}]
}"#;
    let parsed: ChatCompletionChunk = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.id, "chatcmpl-abc");
    assert_eq!(parsed.choices.len(), 1);
    let first = &parsed.choices[0];
    match first.delta.content.as_ref() {
        Some(MessageContent::Text(text)) => assert_eq!(text, "Hello"),
        other => panic!("Expected Text content, got: {:?}", other),
    }
    assert!(first.finish_reason.is_none());
}
#[test]
fn test_streaming_chunk_deserializes_finish_reason() {
    use crate::types::chat::ChatCompletionChunk;
    // Final streaming chunk: empty delta plus a "stop" finish reason.
    let payload = r#"{
"id": "chatcmpl-abc",
"object": "chat.completion.chunk",
"created": 1700000000,
"model": "openai/gpt-4",
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}"#;
    let parsed: ChatCompletionChunk = serde_json::from_str(payload).unwrap();
    let last = &parsed.choices[0];
    assert_eq!(last.finish_reason.as_deref(), Some("stop"));
    assert!(last.delta.content.is_none());
}
#[test]
fn test_malformed_sse_data_fails_gracefully() {
    use crate::types::chat::ChatCompletionChunk;
    // Truncated JSON must surface a deserialization error, never a partial value.
    let truncated = r#"{"id": "chatcmpl-abc", "choices": [{ /* truncated */"#;
    assert!(
        serde_json::from_str::<ChatCompletionChunk>(truncated).is_err(),
        "Malformed JSON must fail deserialization, not silently succeed"
    );
}
#[test]
fn test_max_line_length_constant_value() {
    // NOTE(review): this mirrors the streaming line-length cap locally rather
    // than referencing the production constant — confirm they stay in sync.
    let max_line: usize = 64 * 1024;
    assert_eq!(max_line, 65536);
    assert!(max_line < 1024 * 1024);
}
#[test]
fn test_max_total_chunks_constant_value() {
    // NOTE(review): mirrors the streaming chunk-count cap locally rather than
    // referencing the production constant — confirm they stay in sync.
    let cap: usize = 10_000;
    assert!(
        cap > 1000,
        "Max chunks must be large enough for realistic completions"
    );
    assert!(
        cap < 1_000_000,
        "Max chunks must be bounded to prevent memory exhaustion"
    );
}
#[tokio::test]
// End-to-end streaming test: the mock serves two SSE `data:` events followed by
// the `data: [DONE]` sentinel; the client stream must yield exactly two parsed
// chunks, in order, and then terminate.
async fn test_streaming_via_wiremock_two_chunks_and_done() {
use crate::api::chat::ChatApi;
use crate::client::{ClientConfig, RetryConfig, SecureApiKey};
use crate::types::chat::{ChatCompletionRequest, Message};
use futures::StreamExt;
let mock_server = MockServer::start().await;
// Two well-formed chunks ("Hello", " world"/finish_reason=stop) plus [DONE],
// each terminated by the SSE blank-line event separator (\n\n).
let sse_body = concat!(
"data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1700000000,",
"\"model\":\"openai/gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},",
"\"finish_reason\":null}]}\n\n",
"data: {\"id\":\"c2\",\"object\":\"chat.completion.chunk\",\"created\":1700000001,",
"\"model\":\"openai/gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" world\"},",
"\"finish_reason\":\"stop\"}]}\n\n",
"data: [DONE]\n\n"
);
// .expect(1): the endpoint must be hit exactly once (no retries on 200).
Mock::given(matchers::method("POST"))
.and(matchers::path("/api/v1/chat/completions"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_string(sse_body),
)
.expect(1)
.mount(&mock_server)
.await;
// Point the client at the mock server's /api/v1/ base.
let config = ClientConfig {
api_key: Some(SecureApiKey::new("sk-test123456789012345678901234567890").unwrap()),
base_url: url::Url::parse(&format!("{}/api/v1/", mock_server.uri())).unwrap(),
timeout: std::time::Duration::from_secs(10),
http_referer: None,
site_title: None,
user_id: None,
retry_config: RetryConfig::default(),
max_response_bytes: 10 * 1024 * 1024,
};
let client = reqwest::Client::new();
let api = ChatApi::new(client, &config).unwrap();
// Minimal valid request; every optional field explicitly None.
let request = ChatCompletionRequest {
model: "openai/gpt-4".to_string(),
messages: vec![Message::text(crate::types::chat::ChatRole::User, "hi")],
stream: None,
response_format: None,
tools: None,
tool_choice: None,
provider: None,
models: None,
transforms: None,
route: None,
user: None,
max_tokens: None,
temperature: None,
top_p: None,
top_k: None,
frequency_penalty: None,
presence_penalty: None,
repetition_penalty: None,
min_p: None,
top_a: None,
seed: None,
stop: None,
logit_bias: None,
logprobs: None,
top_logprobs: None,
prediction: None,
parallel_tool_calls: None,
verbosity: None,
debug: None,
plugins: None,
reasoning: None,
};
// Drain the stream, failing fast on any error item.
let mut stream = api.chat_completion_stream(request);
let mut chunks = Vec::new();
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => chunks.push(chunk),
Err(e) => panic!("Stream error: {:?}", e),
}
}
use crate::types::chat::MessageContent;
// [DONE] must not be surfaced as a chunk — exactly the two data events remain.
assert_eq!(chunks.len(), 2, "Should have received exactly 2 chunks");
match chunks[0].choices[0].delta.content.as_ref() {
Some(MessageContent::Text(s)) => assert_eq!(s, "Hello"),
other => panic!("Expected Text content in chunk 0, got: {:?}", other),
}
match chunks[1].choices[0].delta.content.as_ref() {
Some(MessageContent::Text(s)) => assert_eq!(s, " world"),
other => panic!("Expected Text content in chunk 1, got: {:?}", other),
}
assert_eq!(chunks[1].choices[0].finish_reason.as_deref(), Some("stop"));
}
#[tokio::test]
// An empty model name must be rejected during request validation: the stream
// yields exactly one Err item and then ends, with no HTTP request issued
// (base_url points at the real host, so any network call would be a bug).
async fn test_streaming_validation_error_before_network_call() {
use crate::api::chat::ChatApi;
use crate::client::{ClientConfig, RetryConfig, SecureApiKey};
use crate::types::chat::{ChatCompletionRequest, Message};
use futures::StreamExt;
let config = ClientConfig {
api_key: Some(SecureApiKey::new("sk-test123456789012345678901234567890").unwrap()),
base_url: url::Url::parse("https://openrouter.ai/api/v1/").unwrap(),
timeout: std::time::Duration::from_secs(10),
http_referer: None,
site_title: None,
user_id: None,
retry_config: RetryConfig::default(),
max_response_bytes: 10 * 1024 * 1024,
};
let client = reqwest::Client::new();
let api = ChatApi::new(client, &config).unwrap();
// Invalid on purpose: model is the empty string.
let request = ChatCompletionRequest {
model: "".to_string(), messages: vec![Message::text(crate::types::chat::ChatRole::User, "hi")],
stream: None,
response_format: None,
tools: None,
tool_choice: None,
provider: None,
models: None,
transforms: None,
route: None,
user: None,
max_tokens: None,
temperature: None,
top_p: None,
top_k: None,
frequency_penalty: None,
presence_penalty: None,
repetition_penalty: None,
min_p: None,
top_a: None,
seed: None,
stop: None,
logit_bias: None,
logprobs: None,
top_logprobs: None,
prediction: None,
parallel_tool_calls: None,
verbosity: None,
debug: None,
plugins: None,
reasoning: None,
};
let mut stream = api.chat_completion_stream(request);
let first = stream.next().await;
assert!(
first.is_some(),
"Stream must yield at least one item (the validation error)"
);
assert!(
first.unwrap().is_err(),
"First item from stream with invalid model must be an error"
);
// After the validation error the stream must be fully exhausted.
assert!(
stream.next().await.is_none(),
"Stream must be exhausted after validation error"
);
}
#[tokio::test]
async fn test_embeddings_wiremock_happy_path() {
    use crate::api::embeddings::EmbeddingsApi;
    use crate::client::{ClientConfig, RetryConfig, SecureApiKey};
    use crate::types::embeddings::{EmbeddingInput, EmbeddingRequest};
    let server = MockServer::start().await;
    // Canned single-embedding response in the OpenAI-compatible shape.
    let payload = serde_json::json!({
        "object": "list",
        "data": [
            {"embedding": [0.1, 0.2, 0.3], "index": 0, "object": "embedding"}
        ],
        "model": "openai/text-embedding-3-small",
        "usage": {"prompt_tokens": 3, "total_tokens": 3}
    });
    // Require the auth header and exactly one call.
    Mock::given(matchers::method("POST"))
        .and(matchers::path("/api/v1/embeddings"))
        .and(matchers::header_exists("authorization"))
        .respond_with(ResponseTemplate::new(200).set_body_json(&payload))
        .expect(1)
        .mount(&server)
        .await;
    let cfg = ClientConfig {
        api_key: Some(SecureApiKey::new("sk-test123456789012345678901234567890").unwrap()),
        base_url: url::Url::parse(&format!("{}/api/v1/", server.uri())).unwrap(),
        timeout: std::time::Duration::from_secs(10),
        http_referer: None,
        site_title: None,
        user_id: None,
        retry_config: RetryConfig::default(),
        max_response_bytes: 10 * 1024 * 1024,
    };
    let api = EmbeddingsApi::new(reqwest::Client::new(), &cfg).unwrap();
    let req = EmbeddingRequest {
        model: "openai/text-embedding-3-small".to_string(),
        input: EmbeddingInput::Single("hello world".to_string()),
        encoding_format: None,
        provider: None,
    };
    let reply = api.create(req).await.unwrap();
    assert_eq!(reply.data.len(), 1);
    assert_eq!(reply.data[0].embedding, vec![0.1, 0.2, 0.3]);
    assert_eq!(reply.data[0].index, 0);
}
#[tokio::test]
async fn test_embeddings_wiremock_batch_reversed_indices() {
    use crate::api::embeddings::EmbeddingsApi;
    use crate::client::{ClientConfig, RetryConfig, SecureApiKey};
    let server = MockServer::start().await;
    // Response deliberately lists index 1 before index 0; embed_batch must
    // return embeddings reordered by index.
    let payload = serde_json::json!({
        "object": "list",
        "data": [
            {"embedding": [0.3, 0.4], "index": 1, "object": "embedding"},
            {"embedding": [0.1, 0.2], "index": 0, "object": "embedding"}
        ],
        "model": "openai/text-embedding-3-small",
        "usage": {"prompt_tokens": 6, "total_tokens": 6}
    });
    Mock::given(matchers::method("POST"))
        .and(matchers::path("/api/v1/embeddings"))
        .respond_with(ResponseTemplate::new(200).set_body_json(&payload))
        .expect(1)
        .mount(&server)
        .await;
    let cfg = ClientConfig {
        api_key: Some(SecureApiKey::new("sk-test123456789012345678901234567890").unwrap()),
        base_url: url::Url::parse(&format!("{}/api/v1/", server.uri())).unwrap(),
        timeout: std::time::Duration::from_secs(10),
        http_referer: None,
        site_title: None,
        user_id: None,
        retry_config: RetryConfig::default(),
        max_response_bytes: 10 * 1024 * 1024,
    };
    let api = EmbeddingsApi::new(reqwest::Client::new(), &cfg).unwrap();
    let embeddings = api
        .embed_batch(
            "openai/text-embedding-3-small",
            vec!["first".to_string(), "second".to_string()],
        )
        .await
        .unwrap();
    assert_eq!(embeddings.len(), 2);
    assert_eq!(embeddings[0], vec![0.1, 0.2], "Index 0 embedding must be first");
    assert_eq!(
        embeddings[1],
        vec![0.3, 0.4],
        "Index 1 embedding must be second"
    );
}
#[tokio::test]
async fn test_embeddings_wiremock_validation_rejects_empty_batch() {
    use crate::api::embeddings::EmbeddingsApi;
    use crate::client::{ClientConfig, RetryConfig, SecureApiKey};
    use crate::types::embeddings::{EmbeddingInput, EmbeddingRequest};
    // No mock server on purpose: validation must fail before any network I/O.
    let cfg = ClientConfig {
        api_key: Some(SecureApiKey::new("sk-test123456789012345678901234567890").unwrap()),
        base_url: url::Url::parse("https://openrouter.ai/api/v1/").unwrap(),
        timeout: std::time::Duration::from_secs(10),
        http_referer: None,
        site_title: None,
        user_id: None,
        retry_config: RetryConfig::default(),
        max_response_bytes: 10 * 1024 * 1024,
    };
    let api = EmbeddingsApi::new(reqwest::Client::new(), &cfg).unwrap();
    let req = EmbeddingRequest {
        model: "openai/text-embedding-3-small".to_string(),
        input: EmbeddingInput::Batch(vec![]),
        encoding_format: None,
        provider: None,
    };
    let outcome = api.create(req).await;
    assert!(outcome.is_err());
    match outcome.unwrap_err() {
        crate::error::Error::ValidationError(msg) => {
            assert!(msg.contains("empty"), "Error must mention empty: {msg}");
        }
        other => panic!("Expected ValidationError, got: {:?}", other),
    }
}
#[tokio::test]
async fn test_embeddings_wiremock_validation_rejects_whitespace_item_in_batch() {
    use crate::api::embeddings::EmbeddingsApi;
    use crate::client::{ClientConfig, RetryConfig, SecureApiKey};
    use crate::types::embeddings::{EmbeddingInput, EmbeddingRequest};
    // No mock server on purpose: validation must fail before any network I/O.
    let cfg = ClientConfig {
        api_key: Some(SecureApiKey::new("sk-test123456789012345678901234567890").unwrap()),
        base_url: url::Url::parse("https://openrouter.ai/api/v1/").unwrap(),
        timeout: std::time::Duration::from_secs(10),
        http_referer: None,
        site_title: None,
        user_id: None,
        retry_config: RetryConfig::default(),
        max_response_bytes: 10 * 1024 * 1024,
    };
    let api = EmbeddingsApi::new(reqwest::Client::new(), &cfg).unwrap();
    // Second batch item is whitespace-only and must be rejected.
    let req = EmbeddingRequest {
        model: "openai/text-embedding-3-small".to_string(),
        input: EmbeddingInput::Batch(vec!["valid".to_string(), " ".to_string()]),
        encoding_format: None,
        provider: None,
    };
    let outcome = api.create(req).await;
    assert!(outcome.is_err());
    match outcome.unwrap_err() {
        crate::error::Error::ValidationError(_) => {}
        other => panic!(
            "Expected ValidationError for whitespace-only batch item, got: {:?}",
            other
        ),
    }
}
}