use crate::get_option_value;
use crate::support::data::{IMAGE_URL_JPG_DUCK, get_b64_duck};
use crate::support::{
Check, Result, StreamExtract, assert_contains, contains_checks, extract_stream_end, get_big_content,
seed_chat_req_simple, seed_chat_req_tool_simple, validate_checks,
};
use genai::adapter::AdapterKind;
use genai::chat::{
	CacheControl, ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, ContentPart, JsonSpec, ToolResponse,
};
use genai::resolver::{AuthData, AuthResolver};
use genai::{Client, ModelIden};
use serde_json::{Value, json};
use value_ext::JsonValueExt;
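/// Runs a simple chat request and asserts non-empty content plus consistent token usage.
/// Optional `checks` may additionally require reasoning content and/or reasoning usage.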
pub async fn common_test_chat_simple_ok(model: &str, checks: Option<Check>) -> Result<()> {
validate_checks(checks.clone(), Check::REASONING | Check::REASONING_USAGE)?;
let client = Client::default();
let chat_req = seed_chat_req_simple();
let chat_res = client.exec_chat(model, chat_req, None).await?;
let content = chat_res.content_text_as_str().ok_or("Should have content")?;
assert!(!content.trim().is_empty(), "Content should not be empty");
let usage = &chat_res.usage;
let prompt_tokens = get_option_value!(usage.prompt_tokens);
let completion_tokens = get_option_value!(usage.completion_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
assert!(
total_tokens == prompt_tokens + completion_tokens,
"total_tokens should be equal to prompt_token + comletion_token"
);
if contains_checks(checks.clone(), Check::REASONING_USAGE) {
let reasoning_tokens = usage
.completion_tokens_details
.as_ref()
.and_then(|v| v.reasoning_tokens)
.ok_or("should have reasoning_tokens")?;
assert!(reasoning_tokens > 0, "reasoning_usage should be > 0");
}
if contains_checks(checks, Check::REASONING) {
let reasoning_content = chat_res
.reasoning_content
.as_deref()
.ok_or("SHOULD have extracted some reasoning_content")?;
assert!(!reasoning_content.is_empty(), "reasoning_content should not be empty");
assert!(
reasoning_content.len() > content.len(),
"Reasoning content should be > than the content"
);
}
Ok(())
}
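/// Verifies that multiple system messages (including one set via `with_system`) still yield
/// a non-empty response with consistent token usage.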
pub async fn common_test_chat_multi_system_ok(model: &str) -> Result<()> {
let client = Client::default();
let chat_req = ChatRequest::new(vec![
ChatMessage::system("Be very concise"),
ChatMessage::system("Explain with bullet points"),
ChatMessage::user("Why is the sky blue?"),
])
.with_system("And end with 'Thank you'");
let chat_res = client.exec_chat(model, chat_req, None).await?;
assert!(
!get_option_value!(chat_res.content).is_empty(),
"Content should not be empty"
);
let usage = chat_res.usage;
let prompt_tokens = get_option_value!(usage.prompt_tokens);
let completion_tokens = get_option_value!(usage.completion_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
assert!(
total_tokens == prompt_tokens + completion_tokens,
"total_tokens should be equal to prompt_tokens + completion_tokens"
);
Ok(())
}
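/// Verifies JSON mode: the response content must parse as valid JSON.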
pub async fn common_test_chat_json_mode_ok(model: &str, checks: Option<Check>) -> Result<()> {
validate_checks(checks.clone(), Check::USAGE)?;
let client = Client::default();
let chat_req = ChatRequest::new(vec![
ChatMessage::system(
r#"Turn the user content into the most probable JSON content.
Reply in a JSON format."#,
),
ChatMessage::user(
r#"
| Model | Maker
| gpt-4o | OpenAI
| gpt-4o-mini | OpenAI
| llama-3.1-70B | Meta
"#,
),
]);
let chat_options = ChatOptions::default().with_response_format(ChatResponseFormat::JsonMode);
let chat_res = client.exec_chat(model, chat_req, Some(&chat_options)).await?;
if contains_checks(checks, Check::USAGE) {
let usage = &chat_res.usage;
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
}
let content = chat_res.content_text_into_string().ok_or("SHOULD HAVE CONTENT")?;
	let json: serde_json::Value = serde_json::from_str(&content).map_err(|err| format!("Was not valid JSON: {err}"))?;
	let _pretty_json = serde_json::to_string_pretty(&json).map_err(|err| format!("Could not pretty-print JSON: {err}"))?;
Ok(())
}
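/// Verifies structured output: the response must conform to the provided JSON schema.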
pub async fn common_test_chat_json_structured_ok(model: &str, checks: Option<Check>) -> Result<()> {
validate_checks(checks.clone(), Check::USAGE)?;
let client = Client::default();
let chat_req = ChatRequest::new(vec![
ChatMessage::system(
r#"Turn the user content into the most probable JSON content.
Reply in a JSON format."#,
),
ChatMessage::user(
r#"
| Model | Maker
| gpt-4o | OpenAI
| gpt-4o-mini | OpenAI
| llama-3.1-70B | Meta
"#,
),
]);
let json_schema = json!({
"type": "object",
"properties": {
"all_models": {
"type": "array",
"items": {
"type": "object",
"properties": {
"maker": { "type": "string" },
"model_name": { "type": "string" }
},
"required": ["maker", "model_name"]
}
}
},
"required": ["all_models"]
});
let chat_options = ChatOptions::default().with_response_format(JsonSpec::new("some-schema", json_schema));
let chat_res = client.exec_chat(model, chat_req, Some(&chat_options)).await?;
if contains_checks(checks, Check::USAGE) {
let usage = &chat_res.usage;
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
}
let content = chat_res.content_text_into_string().ok_or("SHOULD HAVE CONTENT")?;
let json_response: serde_json::Value =
serde_json::from_str(&content).map_err(|err| format!("Was not valid JSON: {err}"))?;
let models: Vec<Value> = json_response.x_get("all_models")?;
assert_eq!(3, models.len(), "Number of models");
let first_maker: String = models.first().ok_or("No models")?.x_get("maker")?;
assert_eq!("OpenAI", first_maker, "First maker");
Ok(())
}
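/// Verifies that a request with temperature 0 still returns non-empty content.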
pub async fn common_test_chat_temperature_ok(model: &str) -> Result<()> {
let client = Client::default();
let chat_req = seed_chat_req_simple();
let chat_options = ChatOptions::default().with_temperature(0.);
let chat_res = client.exec_chat(model, chat_req, Some(&chat_options)).await?;
assert!(
!chat_res.content_text_as_str().unwrap_or("").is_empty(),
"Content should not be empty"
);
Ok(())
}
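/// Verifies stop sequences: the answer must stop before emitting "London".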
pub async fn common_test_chat_stop_sequences_ok(model: &str) -> Result<()> {
let client = Client::default();
let chat_req = ChatRequest::from_user("What is the capital of England?");
let chat_options = ChatOptions::default().with_stop_sequences(vec!["London".to_string()]);
let chat_res = client.exec_chat(model, chat_req, Some(&chat_options)).await?;
let ai_content_lower = chat_res
.content_text_as_str()
.ok_or("Should have a AI response")?
.to_lowercase();
assert!(
!ai_content_lower.contains("london"),
"Content should not contain 'London'"
);
Ok(())
}
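/// Verifies that reasoning content gets normalized out of the main content when
/// `with_normalize_reasoning_content(true)` is set.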
pub async fn common_test_chat_reasoning_normalize_ok(model: &str) -> Result<()> {
let client = Client::builder()
.with_chat_options(ChatOptions::default().with_normalize_reasoning_content(true))
.build();
let chat_req = seed_chat_req_simple();
let chat_res = client.exec_chat(model, chat_req, None).await?;
let content = chat_res.content_text_as_str().ok_or("Should have content")?;
assert!(!content.trim().is_empty(), "Content should not be empty");
let reasoning_content = chat_res.reasoning_content.as_deref().ok_or("Should have reasoning_content")?;
assert!(
reasoning_content.len() > content.len(),
"reasoning_content should be > than the content"
);
let usage = chat_res.usage;
let prompt_tokens = get_option_value!(usage.prompt_tokens);
let completion_tokens = get_option_value!(usage.completion_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
assert!(
total_tokens == prompt_tokens + completion_tokens,
"total_tokens should be equal to prompt_token + completion_tokens"
);
Ok(())
}
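/// Verifies implicit prompt caching: a second identical request should report cached tokens.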
pub async fn common_test_chat_cache_implicit_simple_ok(model: &str) -> Result<()> {
let client = Client::default();
let big_content = get_big_content()?;
let chat_req = ChatRequest::new(vec![
ChatMessage::user(big_content),
ChatMessage::user("Give a very short summary of what each of those files are about."),
]);
	// First request warms up the provider-side implicit cache.
	let _warmup_res = client.exec_chat(model, chat_req.clone(), None).await?;
	// Give the provider a moment to register the cached prompt.
	tokio::time::sleep(std::time::Duration::from_millis(500)).await;
	let chat_res = client.exec_chat(model, chat_req, None).await?;
let content = chat_res.content_text_as_str().ok_or("Should have content")?;
assert!(!content.trim().is_empty(), "Content should not be empty");
let usage = &chat_res.usage;
	let _prompt_tokens = get_option_value!(usage.prompt_tokens);
	let _completion_tokens = get_option_value!(usage.completion_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
let prompt_tokens_details = usage
.prompt_tokens_details
.as_ref()
.ok_or("Should have prompt_tokens_details")?;
let cached_tokens = get_option_value!(prompt_tokens_details.cached_tokens);
	assert!(cached_tokens > 0, "cached_tokens should be > 0");
assert!(total_tokens > 0, "total_tokens should be > 0");
Ok(())
}
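/// Verifies explicit caching of a user message via `CacheControl::Ephemeral`.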
pub async fn common_test_chat_cache_explicit_user_ok(model: &str) -> Result<()> {
let client = Client::default();
let big_content = get_big_content()?;
let chat_req = ChatRequest::new(vec![
ChatMessage::system("Give a very short summary of what each of those files are about"),
ChatMessage::user(big_content).with_options(CacheControl::Ephemeral),
]);
let chat_res = client.exec_chat(model, chat_req, None).await?;
let content = chat_res.content_text_as_str().ok_or("Should have content")?;
assert!(!content.trim().is_empty(), "Content should not be empty");
let usage = &chat_res.usage;
	let _prompt_tokens = get_option_value!(usage.prompt_tokens);
	let _completion_tokens = get_option_value!(usage.completion_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
let prompt_tokens_details = usage
.prompt_tokens_details
.as_ref()
.ok_or("Should have prompt_tokens_details")?;
let cache_creation_tokens = get_option_value!(prompt_tokens_details.cache_creation_tokens);
let cached_tokens = get_option_value!(prompt_tokens_details.cached_tokens);
assert!(
cache_creation_tokens > 0 || cached_tokens > 0,
"one of cache_creation_tokens or cached_tokens should be greater than 0"
);
assert!(total_tokens > 0, "total_tokens should be > 0");
Ok(())
}
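/// Verifies explicit caching of a system message via `CacheControl::Ephemeral`.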
pub async fn common_test_chat_cache_explicit_system_ok(model: &str) -> Result<()> {
let client = Client::default();
let big_content = get_big_content()?;
let chat_req = ChatRequest::new(vec![
ChatMessage::system("You are a senior developer which has the following code base:"),
ChatMessage::system(big_content).with_options(CacheControl::Ephemeral),
ChatMessage::user("can you give a summary of each file (very concise)"),
]);
let chat_res = client.exec_chat(model, chat_req, None).await?;
let content = chat_res.content_text_as_str().ok_or("Should have content")?;
assert!(!content.trim().is_empty(), "Content should not be empty");
let usage = &chat_res.usage;
	let _prompt_tokens = get_option_value!(usage.prompt_tokens);
	let _completion_tokens = get_option_value!(usage.completion_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
let prompt_tokens_details = usage
.prompt_tokens_details
.as_ref()
.ok_or("Should have prompt_tokens_details")?;
let cache_creation_tokens = get_option_value!(prompt_tokens_details.cache_creation_tokens);
let cached_tokens = get_option_value!(prompt_tokens_details.cached_tokens);
assert!(
cache_creation_tokens > 0 || cached_tokens > 0,
"one of cache_creation_tokens or cached_tokens should be greater than 0"
);
assert!(total_tokens > 0, "total_tokens should be > 0");
Ok(())
}
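/// Verifies basic streaming: content is streamed, and nothing is captured by default.
/// Optional `checks` may additionally require streamed reasoning content.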
pub async fn common_test_chat_stream_simple_ok(model: &str, checks: Option<Check>) -> Result<()> {
validate_checks(checks.clone(), Check::REASONING)?;
let client = Client::default();
let chat_req = seed_chat_req_simple();
let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
let StreamExtract {
stream_end,
content,
reasoning_content,
} = extract_stream_end(chat_res.stream).await?;
let content = content.ok_or("extract_stream_end SHOULD have extracted some content")?;
assert!(!content.is_empty(), "Content streamed should not be empty");
assert!(
stream_end.captured_usage.is_none(),
"StreamEnd should not have any meta_usage"
);
assert!(
stream_end.captured_content.is_none(),
"StreamEnd should not have any captured_content"
);
if contains_checks(checks, Check::REASONING) {
let reasoning_content =
reasoning_content.ok_or("extract_stream_end SHOULD have extracted some reasoning_content")?;
assert!(!reasoning_content.is_empty(), "reasoning_content should not be empty");
assert!(
reasoning_content.len() > content.len(),
"Reasoning content should be > than the content"
);
}
Ok(())
}
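/// Verifies that `with_capture_content(true)` captures the streamed content and nothing else.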
pub async fn common_test_chat_stream_capture_content_ok(model: &str) -> Result<()> {
let client = Client::builder()
.with_chat_options(ChatOptions::default().with_capture_content(true))
.build();
let chat_req = seed_chat_req_simple();
let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
	let StreamExtract { stream_end, .. } = extract_stream_end(chat_res.stream).await?;
assert!(
stream_end.captured_usage.is_none(),
"StreamEnd should not have any meta_usage"
);
let captured_content = get_option_value!(stream_end.captured_content);
assert!(!captured_content.is_empty(), "captured_content.length should be > 0");
assert!(
stream_end.captured_reasoning_content.is_none(),
"The captured_reasoning_content should be None"
);
Ok(())
}
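/// Verifies that usage, content, and reasoning content are all captured when the
/// corresponding capture options are enabled.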
pub async fn common_test_chat_stream_capture_all_ok(model: &str, checks: Option<Check>) -> Result<()> {
validate_checks(checks.clone(), Check::REASONING)?;
let client = Client::builder()
.with_chat_options(
ChatOptions::default()
.with_capture_usage(true)
.with_capture_content(true)
.with_capture_reasoning_content(true),
)
.build();
let chat_req = seed_chat_req_simple();
let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
	let StreamExtract { stream_end, .. } = extract_stream_end(chat_res.stream).await?;
let meta_usage = get_option_value!(stream_end.captured_usage);
assert!(
get_option_value!(meta_usage.prompt_tokens) > 0,
"prompt_token should be > 0"
);
assert!(
get_option_value!(meta_usage.completion_tokens) > 0,
"completion_tokens should be > 0"
);
assert!(
get_option_value!(meta_usage.total_tokens) > 0,
"total_tokens should be > 0"
);
let captured_content = get_option_value!(stream_end.captured_content);
let captured_content = captured_content.text_as_str().ok_or("Captured content should have a text")?;
assert!(!captured_content.is_empty(), "captured_content.length should be > 0");
if contains_checks(checks, Check::REASONING) {
let reasoning_content = stream_end
.captured_reasoning_content
.ok_or("captured_reasoning_content SHOULD have extracted some reasoning_content")?;
assert!(!reasoning_content.is_empty(), "reasoning_content should not be empty");
assert!(
reasoning_content.len() > captured_content.len(),
"Reasoning content should be > than the content"
);
}
Ok(())
}
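/// Verifies image input via URL (the model should identify the duck in the picture).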
pub async fn common_test_chat_image_url_ok(model: &str) -> Result<()> {
let client = Client::default();
let mut chat_req = ChatRequest::default().with_system("Answer in one sentence");
chat_req = chat_req.append_message(ChatMessage::user(vec![
ContentPart::from_text("What is in this picture?"),
ContentPart::from_image_url("image/jpeg", IMAGE_URL_JPG_DUCK),
]));
let chat_res = client.exec_chat(model, chat_req, None).await?;
let res = chat_res.content_text_as_str().ok_or("Should have text result")?;
assert_contains(res, "duck");
Ok(())
}
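/// Verifies image input via base64 (the model should identify the duck in the picture).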
pub async fn common_test_chat_image_b64_ok(model: &str) -> Result<()> {
let client = Client::default();
let mut chat_req = ChatRequest::default().with_system("Answer in one sentence");
chat_req = chat_req.append_message(ChatMessage::user(vec![
ContentPart::from_text("What is in this picture?"),
ContentPart::from_image_base64("image/jpeg", get_b64_duck()?),
]));
let chat_res = client.exec_chat(model, chat_req, None).await?;
let res = chat_res.content_text_as_str().ok_or("Should have text result")?;
assert_contains(res, "duck");
Ok(())
}
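/// Verifies a simple tool call: the model should call the tool with the expected arguments.
/// With `complete_check`, the optional `unit` argument is asserted as well.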
pub async fn common_test_tool_simple_ok(model: &str, complete_check: bool) -> Result<()> {
let client = Client::default();
let chat_req = seed_chat_req_tool_simple();
let chat_res = client.exec_chat(model, chat_req, None).await?;
let mut tool_calls = chat_res.tool_calls().ok_or("Should have tool calls")?;
let tool_call = tool_calls.pop().ok_or("Should have at least one tool call")?;
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("city")?, "Paris");
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("country")?, "France");
if complete_check {
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("unit")?, "C");
}
Ok(())
}
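/// Verifies the full tool flow: tool call, appended `ToolResponse`, and final answer.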
pub async fn common_test_tool_full_flow_ok(model: &str, complete_check: bool) -> Result<()> {
let client = Client::default();
	let chat_req = seed_chat_req_tool_simple();
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
let tool_calls = chat_res.into_tool_calls().ok_or("Should have tool calls in chat_res")?;
let first_tool_call = tool_calls.first().ok_or("Should have at least one tool call")?;
let first_tool_call_id = &first_tool_call.call_id;
let tool_response = ToolResponse::new(first_tool_call_id, r#"{"weather": "Sunny", "temperature": "32C"}"#);
let chat_req = chat_req.append_message(tool_calls).append_message(tool_response);
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
let content = chat_res
.content_text_as_str()
.ok_or("Last response should be message")?
.to_lowercase();
assert!(content.contains("paris"), "Should contain 'Paris'");
assert!(content.contains("32"), "Should contain '32'");
if complete_check {
assert!(content.contains("sunny"), "Should contain 'sunny'");
}
Ok(())
}
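/// Verifies that a custom `AuthResolver` returning the given `AuthData` can authenticate the request.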
pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> Result<()> {
	let auth_resolver = AuthResolver::from_resolver_fn(move |_model_iden: ModelIden| Ok(Some(auth_data)));
let client = Client::builder().with_auth_resolver(auth_resolver).build();
let chat_req = seed_chat_req_simple();
let chat_res = client.exec_chat(model, chat_req, None).await?;
assert!(
!get_option_value!(chat_res.content).is_empty(),
"Content should not be empty"
);
let usage = chat_res.usage;
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
Ok(())
}
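/// Verifies that the adapter's model list contains the expected model name.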
pub async fn common_test_list_models(adapter_kind: AdapterKind, contains: &str) -> Result<()> {
let client = Client::default();
let models = client.all_model_names(adapter_kind).await?;
assert_contains(&models, contains);
Ok(())
}