use super::*;
use insta::assert_snapshot;
use proptest::prelude::*;
/// Smoke test: constructing a client succeeds and its `Debug` rendering names the type.
#[test]
fn test_async_gemini_client_new() {
    let api_key = String::from("AIza-test-key");
    let model = String::from("gemini-2.0-flash");
    let client = AsyncGeminiClient::new(api_key, model);

    let debug_repr = format!("{:?}", client);
    assert!(debug_repr.contains("AsyncGeminiClient"));
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]

    /// Property: `AsyncGeminiClient::new` keeps the API key verbatim, and `api_key()`
    /// hands back exactly what was passed in.
    #[test]
    fn prop_api_key_storage(
        key in "[a-zA-Z0-9_-]{10,100}",
        model_name in "[a-zA-Z0-9-]{5,20}",
    ) {
        // Keep a copy for comparison; the original string is moved into the client.
        let expected = key.clone();
        let client = AsyncGeminiClient::new(key, model_name);

        prop_assert_eq!(
            client.api_key(),
            &expected,
            "Client should store the exact API key provided"
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]

    /// Property: the model name given to `AsyncGeminiClient::new` is returned unchanged
    /// by `model()`.
    #[test]
    fn prop_model_selection_storage(
        key in "[a-zA-Z0-9_-]{10,50}",
        model_name in "[a-zA-Z0-9-]{5,50}",
    ) {
        // Keep a copy for comparison; the original string is moved into the client.
        let expected = model_name.clone();
        let client = AsyncGeminiClient::new(key, model_name);

        prop_assert_eq!(
            client.model(),
            &expected,
            "Client should store the exact model name provided"
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]

    /// Property: for any prompt, `build_request_body` yields valid JSON of the shape
    /// `{"contents": [{"role": "user", "parts": [{"text": <prompt>}]}]}`, with no `stream`
    /// field (streaming is selected via the URL query parameter instead).
    ///
    /// The round trip `text == prompt` also checks that the prompt is escaped correctly
    /// inside the JSON string.
    #[test]
    fn prop_request_format_correctness(
        api_key in "[a-zA-Z0-9_-]{10,50}",
        model in "[a-zA-Z0-9-]{5,50}",
        // `(?s)` lets `.` also match `\n`. With a plain `.*` no generated prompt ever
        // contains a newline, which leaves multi-line prompts (and the JSON escaping of LF)
        // untested.
        prompt in "(?s).*",
    ) {
        let client = AsyncGeminiClient::new(api_key, model);

        let result = client.build_request_body(&prompt);
        prop_assert!(result.is_ok(), "Request body should serialize successfully");
        let body = result.unwrap();

        // Report invalid JSON as a test-case failure rather than a panic, like every other
        // check in this property.
        let parsed = serde_json::from_str::<serde_json::Value>(&body);
        prop_assert!(parsed.is_ok(), "Request body should be valid JSON");
        let json = parsed.unwrap();

        prop_assert!(
            json.get("stream").is_none(),
            "Request should not include stream field (Gemini uses query param)"
        );

        let contents = json.get("contents").and_then(|v| v.as_array());
        prop_assert!(contents.is_some(), "Request should have a contents array");
        let contents = contents.unwrap();
        prop_assert_eq!(contents.len(), 1, "Contents array should have exactly one element");

        let content = &contents[0];
        prop_assert_eq!(
            content.get("role").and_then(|v| v.as_str()),
            Some("user"),
            "Content should have role 'user'"
        );

        let parts = content.get("parts").and_then(|v| v.as_array());
        prop_assert!(parts.is_some(), "Content should have a parts array");
        let parts = parts.unwrap();
        prop_assert_eq!(parts.len(), 1, "Parts array should have exactly one element");
        prop_assert_eq!(
            parts[0].get("text").and_then(|v| v.as_str()),
            Some(prompt.as_str()),
            "Part text should match the prompt"
        );
    }
}
/// Snapshot of the JSON request body built for a representative prompt.
///
/// The body is parsed into a `serde_json::Value` and re-serialized with pretty printing, so
/// the stored snapshot is multi-line and diff-friendly. The asserted binding is still named
/// `pretty_json` because insta records the asserted expression in the snapshot file header.
#[test]
fn snapshot_request_body_format() {
    let client = AsyncGeminiClient::new("AIza-test123".to_string(), "gemini-2.0-flash".to_string());

    let raw_body = client
        .build_request_body("suggest jq filters for: extract user names")
        .expect("Request body should serialize successfully");

    let pretty_json = serde_json::from_str::<serde_json::Value>(&raw_body)
        .map(|value| {
            serde_json::to_string_pretty(&value).expect("Should be able to pretty-print JSON")
        })
        .expect("Request body should be valid JSON");

    assert_snapshot!(pretty_json);
}
/// Checks that `build_url` targets the Gemini `streamGenerateContent` endpoint for the
/// configured model, with `alt=sse` and the API key passed as query parameters.
///
/// The URL is split at `?` so that the path is compared exactly and the query is checked
/// parameter by parameter. Plain substring checks would also accept a URL whose parameters
/// were appended to the path without a `?` separator.
#[test]
fn test_build_url_format() {
    let client =
        AsyncGeminiClient::new("AIza-test-key".to_string(), "gemini-2.0-flash".to_string());
    let url = client.build_url();

    let (path, query) = url
        .split_once('?')
        .expect("URL should carry its parameters in a query string");
    assert_eq!(
        path,
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent"
    );

    // Parameter order is not part of the contract, so only membership is checked.
    let params: Vec<&str> = query.split('&').collect();
    assert!(params.contains(&"alt=sse"), "missing alt=sse in {}", url);
    assert!(params.contains(&"key=AIza-test-key"), "missing API key in {}", url);
}
/// A cancellation token that is already cancelled before the call starts must make
/// `stream_with_cancel` return `AiError::Cancelled`.
#[tokio::test]
async fn test_cancellation_before_response() {
    use std::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    let client =
        AsyncGeminiClient::new("AIza-test-key".to_string(), "gemini-2.0-flash".to_string());

    // Bind the receiver (instead of discarding it with `_`) so the channel stays connected
    // for the whole call; only the cancellation token should cause the early exit.
    let (sender, _receiver) = mpsc::channel();

    let token = CancellationToken::new();
    token.cancel();

    let outcome = client
        .stream_with_cancel("test prompt", 1, token, sender)
        .await;
    let was_cancelled = matches!(outcome, Err(AiError::Cancelled));
    assert!(
        was_cancelled,
        "Should return Cancelled error when token is already cancelled"
    );
}
/// Pins the channel behavior for `AiResponse` messages: once the `Receiver` has been
/// dropped, `Sender::send` returns `Err` instead of blocking or panicking.
///
/// NOTE(review): this exercises `std::sync::mpsc` only; it does not call any
/// `AsyncGeminiClient` code.
#[test]
fn test_channel_disconnection() {
    use crate::ai::ai_state::AiResponse;
    use std::sync::mpsc;

    let (sender, receiver) = mpsc::channel::<AiResponse>();
    drop(receiver);

    let chunk = AiResponse::Chunk {
        request_id: 1,
        text: "test".to_string(),
    };
    let send_result = sender.send(chunk);
    assert!(send_result.is_err(), "Send should fail when receiver is dropped");
}