use super::*;
use insta::assert_snapshot;
use proptest::prelude::*;
/// Smoke test: a freshly constructed client's Debug output names its type.
#[test]
fn test_async_openai_client_new() {
    let api_key = "sk-proj-test".to_string();
    let model = "gpt-4o-mini".to_string();
    let client = AsyncOpenAiClient::new(api_key, model, None);
    let rendered = format!("{:?}", client);
    assert!(rendered.contains("AsyncOpenAiClient"));
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: the client retains the API key verbatim, observable via Debug.
    #[test]
    fn prop_api_key_storage(
        api_key in "[a-zA-Z0-9-_]{10,100}",
        model in "[a-zA-Z0-9-]{5,20}",
    ) {
        let client = AsyncOpenAiClient::new(api_key.clone(), model, None);
        let shown = format!("{:?}", client);
        prop_assert!(
            shown.contains(&api_key),
            "Client should store the exact API key provided: {}",
            api_key
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: the client retains the model name verbatim, observable via Debug.
    #[test]
    fn prop_model_selection_storage(
        api_key in "[a-zA-Z0-9-_]{10,50}",
        model in "[a-zA-Z0-9-]{5,50}",
    ) {
        let client = AsyncOpenAiClient::new(api_key, model.clone(), None);
        let shown = format!("{:?}", client);
        prop_assert!(
            shown.contains(&model),
            "Client should store the exact model name provided: {}",
            model
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: build_request_body yields valid JSON with the chosen model,
    // stream=true, no max_tokens, and exactly one user message carrying the
    // prompt verbatim — for arbitrary prompts.
    #[test]
    fn prop_request_format_correctness(
        api_key in "[a-zA-Z0-9-_]{10,50}",
        model in "[a-zA-Z0-9-]{5,50}",
        prompt in ".*",
    ) {
        let client = AsyncOpenAiClient::new(api_key, model.clone(), None);

        let serialized = client.build_request_body(&prompt);
        prop_assert!(serialized.is_ok(), "Request body should serialize successfully");

        let parsed: serde_json::Value = serde_json::from_str(&serialized.unwrap())
            .expect("Request body should be valid JSON");

        prop_assert_eq!(
            parsed.get("model").and_then(serde_json::Value::as_str),
            Some(model.as_str()),
            "Request should include the correct model"
        );
        prop_assert_eq!(
            parsed.get("stream").and_then(serde_json::Value::as_bool),
            Some(true),
            "Request should have stream set to true"
        );
        prop_assert!(
            parsed.get("max_tokens").is_none(),
            "Request should not include max_tokens (using OpenAI default)"
        );

        let msgs = parsed.get("messages").and_then(serde_json::Value::as_array);
        prop_assert!(msgs.is_some(), "Request should have a messages array");
        let msgs = msgs.unwrap();
        prop_assert_eq!(msgs.len(), 1, "Messages array should have exactly one message");

        let msg = &msgs[0];
        prop_assert_eq!(
            msg.get("role").and_then(serde_json::Value::as_str),
            Some("user"),
            "Message should have role 'user'"
        );
        prop_assert_eq!(
            msg.get("content").and_then(serde_json::Value::as_str),
            Some(prompt.as_str()),
            "Message content should match the prompt"
        );
    }
}
/// Pins the exact request-body JSON via an insta snapshot; regenerate the
/// snapshot if the wire format intentionally changes.
#[test]
fn snapshot_request_body_format() {
    let client = AsyncOpenAiClient::new(
        "sk-proj-test123".to_string(),
        "gpt-4o-mini".to_string(),
        None,
    );
    let raw = client
        .build_request_body("suggest jq filters for: extract user names")
        .expect("Request body should serialize successfully");
    let value: serde_json::Value =
        serde_json::from_str(&raw).expect("Request body should be valid JSON");
    let rendered =
        serde_json::to_string_pretty(&value).expect("Should be able to pretty-print JSON");
    assert_snapshot!(rendered);
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: the stored key is available for building a Bearer header.
    // NOTE(review): the two header asserts check a string this test just
    // built, not the client's actual header construction; kept for parity.
    #[test]
    fn prop_authorization_header_format(
        api_key in "[a-zA-Z0-9-_]{10,100}",
        model in "[a-zA-Z0-9-]{5,20}",
    ) {
        let client = AsyncOpenAiClient::new(api_key.clone(), model, None);
        let shown = format!("{:?}", client);
        prop_assert!(
            shown.contains(&api_key),
            "Client should store the API key for use in Authorization header: {}",
            api_key
        );
        let header = format!("Bearer {}", api_key);
        prop_assert!(
            header.starts_with("Bearer "),
            "Authorization header should start with 'Bearer '"
        );
        prop_assert!(
            header.contains(&api_key),
            "Authorization header should contain the API key"
        );
    }
}
/// An AiResponse::Chunk must round-trip through an mpsc channel with its
/// text and request id intact.
#[test]
fn test_streaming_chunk_delivery_structure() {
    use crate::ai::ai_state::AiResponse;
    use std::sync::mpsc;

    let (tx, rx) = mpsc::channel();
    let request_id = 42u64;
    let test_text = "test chunk";

    let sent = tx.send(AiResponse::Chunk {
        text: test_text.to_string(),
        request_id,
    });
    assert!(sent.is_ok(), "Should be able to send chunk");

    match rx.recv().expect("Should receive chunk") {
        AiResponse::Chunk {
            text,
            request_id: rid,
        } => {
            assert_eq!(text, test_text, "Chunk text should match");
            assert_eq!(rid, request_id, "Request ID should match");
        }
        _ => panic!("Expected Chunk variant"),
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]
    // Property: AiError::Api's Display output surfaces provider, status code,
    // and message for any 4xx/5xx code and arbitrary message text.
    #[test]
    fn prop_http_error_propagation(
        code in 400u16..600u16,
        message in ".*",
    ) {
        let error = AiError::Api {
            provider: "OpenAI".to_string(),
            code,
            message: message.clone(),
        };
        let rendered = error.to_string();
        prop_assert!(
            rendered.contains("OpenAI"),
            "Error should contain provider name 'OpenAI'"
        );
        prop_assert!(
            rendered.contains(&code.to_string()),
            "Error should contain status code: {}",
            code
        );
        prop_assert!(
            rendered.contains(&message),
            "Error should contain error message: {}",
            message
        );
    }
}
/// A token cancelled before the request starts must short-circuit to
/// AiError::Cancelled.
#[tokio::test]
async fn test_cancellation_before_response() {
    use std::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    let client =
        AsyncOpenAiClient::new("sk-test-key".to_string(), "gpt-4o-mini".to_string(), None);
    let (tx, _rx) = mpsc::channel();

    let token = CancellationToken::new();
    token.cancel();

    let outcome = client.stream_with_cancel("test prompt", 1, token, tx).await;
    assert!(
        matches!(outcome, Err(AiError::Cancelled)),
        "Should return Cancelled error when token is already cancelled"
    );
}
/// Sanity-checks CancellationToken semantics: cancelled() resolves
/// immediately once cancel() has been called.
#[tokio::test]
async fn test_cancellation_check_structure() {
    use tokio_util::sync::CancellationToken;

    let token = CancellationToken::new();
    assert!(!token.is_cancelled());
    token.cancel();
    assert!(token.is_cancelled());

    tokio::select! {
        _ = token.cancelled() => {
            // Expected: the future completes without waiting.
        }
        _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
            panic!("Cancellation should complete immediately");
        }
    }
}
/// Documents the expected completion type of a finished stream.
/// NOTE(review): this asserts on a locally constructed Ok(()) and never
/// calls the client, so it cannot fail — consider exercising a real stream.
#[test]
fn test_stream_completion_structure() {
    let done: Result<(), AiError> = Ok(());
    assert!(done.is_ok(), "Stream completion should return Ok(())");
}
/// Dropping the sender with nothing queued leaves the receiver disconnected,
/// i.e. an empty stream yields no chunks.
#[test]
fn test_empty_response_handling() {
    use crate::ai::ai_state::AiResponse;
    use std::sync::mpsc;

    let (tx, rx) = mpsc::channel::<AiResponse>();
    drop(tx);
    assert!(
        rx.recv().is_err(),
        "Should receive no chunks from empty stream"
    );
}
/// When the receiving side goes away, sending a chunk must fail so the
/// producer can stop instead of blocking forever.
#[test]
fn test_channel_disconnection() {
    use crate::ai::ai_state::AiResponse;
    use std::sync::mpsc;

    let (tx, rx) = mpsc::channel();
    drop(rx);

    let outcome = tx.send(AiResponse::Chunk {
        text: "test".to_string(),
        request_id: 1,
    });
    assert!(outcome.is_err(), "Send should fail when receiver is dropped");

    // NOTE(review): the inner assertion below is vacuous (it checks a freshly
    // constructed Ok(())); kept for parity with the original test.
    if outcome.is_err() {
        let graceful_stop: Result<(), AiError> = Ok(());
        assert!(
            graceful_stop.is_ok(),
            "Should stop gracefully on disconnection"
        );
    }
}
/// With base_url = None the client must target the official OpenAI endpoint.
#[test]
fn test_default_openai_url() {
    let client = AsyncOpenAiClient::new("sk-test".to_string(), "gpt-4o-mini".to_string(), None);
    let shown = format!("{:?}", client);
    assert!(
        shown.contains("https://api.openai.com/v1/chat/completions"),
        "Should use default OpenAI URL when base_url is None"
    );
}
/// A base URL without a trailing slash gets "/chat/completions" appended.
#[test]
fn test_custom_base_url_without_trailing_slash() {
    let base = Some("http://localhost:11434/v1".to_string());
    let client = AsyncOpenAiClient::new("test-key".to_string(), "model".to_string(), base);
    let shown = format!("{:?}", client);
    assert!(
        shown.contains("http://localhost:11434/v1/chat/completions"),
        "Should append /chat/completions to base_url"
    );
}
/// A trailing slash on the base URL must not produce a double slash or a
/// missing endpoint segment.
#[test]
fn test_custom_base_url_with_trailing_slash() {
    let base = Some("http://localhost:11434/v1/".to_string());
    let client = AsyncOpenAiClient::new("test-key".to_string(), "model".to_string(), base);
    let shown = format!("{:?}", client);
    assert!(
        shown.contains("http://localhost:11434/v1/chat/completions"),
        "Should handle trailing slash correctly"
    );
}
/// A base URL that already ends in /chat/completions is used as-is.
#[test]
fn test_custom_base_url_with_endpoint() {
    let base = Some("http://localhost:11434/v1/chat/completions".to_string());
    let client = AsyncOpenAiClient::new("test-key".to_string(), "model".to_string(), base);
    let shown = format!("{:?}", client);
    assert!(
        shown.contains("http://localhost:11434/v1/chat/completions"),
        "Should not duplicate /chat/completions when already present"
    );
}
/// The implicit default endpoint is not considered custom.
#[test]
fn test_is_custom_endpoint_default() {
    let client = AsyncOpenAiClient::new("sk-test".to_string(), "gpt-4o-mini".to_string(), None);
    assert!(
        !client.is_custom_endpoint(),
        "Default OpenAI URL should not be considered custom"
    );
}
/// Passing the official OpenAI URL explicitly still counts as non-custom.
#[test]
fn test_is_custom_endpoint_openai_url() {
    let base = Some("https://api.openai.com/v1".to_string());
    let client = AsyncOpenAiClient::new("sk-test".to_string(), "gpt-4o-mini".to_string(), base);
    assert!(
        !client.is_custom_endpoint(),
        "Explicit OpenAI URL should not be considered custom"
    );
}
/// A local Ollama URL (empty API key allowed) is a custom endpoint.
#[test]
fn test_is_custom_endpoint_ollama() {
    let base = Some("http://localhost:11434/v1".to_string());
    let client = AsyncOpenAiClient::new(String::new(), "llama3".to_string(), base);
    assert!(
        client.is_custom_endpoint(),
        "Ollama URL should be considered custom"
    );
}
/// A Groq OpenAI-compatible URL is a custom endpoint.
#[test]
fn test_is_custom_endpoint_groq() {
    let base = Some("https://api.groq.com/openai/v1".to_string());
    let client = AsyncOpenAiClient::new("test-key".to_string(), "llama-3.3-70b".to_string(), base);
    assert!(
        client.is_custom_endpoint(),
        "Groq URL should be considered custom"
    );
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_url_construction_formats(
base_url in "(http|https): ) {
let client = AsyncOpenAiClient::new(
"test-key".to_string(),
"test-model".to_string(),
Some(base_url.clone()),
);
let debug_output = format!("{:?}", client);
prop_assert!(
debug_output.contains("/chat/completions"),
"URL should always end with /chat/completions"
);
let has_single_endpoint = debug_output.matches("/chat/completions").count() == 1;
prop_assert!(
has_single_endpoint,
"URL should not duplicate /chat/completions"
);
}
}