use std::collections::HashMap;
use std::time::{Duration, Instant};
use herolib_ai::{AiClient, Message, Model, Provider, ProviderConfig};
/// Outcome of one chat request sent to one model on one provider
/// (produced by `test_model_on_provider`, summarised in the final report).
#[derive(Debug)]
struct ModelTestResult {
    /// Model that was exercised.
    model: Model,
    /// Provider the request was sent to.
    provider: Provider,
    /// Provider-specific model identifier; kept for reporting only.
    provider_model_id: String,
    /// `true` when the chat call returned `Ok`.
    success: bool,
    /// Wall-clock milliseconds measured around the chat call
    /// (recorded for both successes and failures).
    response_time_ms: u64,
    /// Error text when `success` is `false`, otherwise `None`.
    error: Option<String>,
    /// Possibly truncated response text when `success` is `true`, otherwise `None`.
    response_preview: Option<String>,
}
/// Outcome of fetching one provider's model list.
///
/// When `available` is `true`, `models` holds the model IDs the provider
/// returned and `error` is `None`. When it is `false`, `models` is empty and
/// `error` holds the failure message.
// `provider` and `error` are filled in for completeness and `Debug` output but
// are not read anywhere in this file, so the dead-code lint is silenced here
// (once — the attribute was previously duplicated).
#[allow(dead_code)]
#[derive(Debug)]
struct ProviderModelsResult {
    provider: Provider,
    available: bool,
    models: Vec<String>,
    error: Option<String>,
}
/// Fetches the IDs of all models listed by OpenRouter.
///
/// Sends `GET https://openrouter.ai/api/v1/models` with a bearer token (30 s
/// global timeout) and collects the string `id` of every entry in the
/// response's `data` array. Entries without a string `id` are skipped.
///
/// # Errors
///
/// Returns a human-readable message when the HTTP request fails, the body
/// cannot be read, the body is not valid JSON, or the JSON has no `data` array.
fn fetch_openrouter_models(api_key: &str) -> Result<Vec<String>, String> {
    let agent = ureq::Agent::new_with_config(
        ureq::Agent::config_builder()
            .timeout_global(Some(Duration::from_secs(30)))
            .build(),
    );
    let response = agent
        .get("https://openrouter.ai/api/v1/models")
        .header("Authorization", &format!("Bearer {}", api_key))
        .call()
        .map_err(|e| format!("HTTP error: {}", e))?;
    let body = response
        .into_body()
        .read_to_string()
        .map_err(|e| format!("Read error: {}", e))?;
    let json: serde_json::Value =
        serde_json::from_str(&body).map_err(|e| format!("JSON parse error: {}", e))?;
    // A body without a `data` array is not a model list (e.g. an error payload).
    // Report it instead of pretending the provider has zero models, which would
    // make every model mapping look missing during validation.
    let entries = json["data"]
        .as_array()
        .ok_or_else(|| "Unexpected response: no `data` array".to_string())?;
    Ok(entries
        .iter()
        .filter_map(|m| m["id"].as_str().map(str::to_string))
        .collect())
}
/// Fetches the IDs of all models listed by Groq.
///
/// Sends `GET https://api.groq.com/openai/v1/models` with a bearer token (30 s
/// global timeout) and collects the string `id` of every entry in the
/// response's `data` array. Entries without a string `id` are skipped.
///
/// # Errors
///
/// Returns a human-readable message when the HTTP request fails, the body
/// cannot be read, the body is not valid JSON, or the JSON has no `data` array.
fn fetch_groq_models(api_key: &str) -> Result<Vec<String>, String> {
    let agent = ureq::Agent::new_with_config(
        ureq::Agent::config_builder()
            .timeout_global(Some(Duration::from_secs(30)))
            .build(),
    );
    let response = agent
        .get("https://api.groq.com/openai/v1/models")
        .header("Authorization", &format!("Bearer {}", api_key))
        .call()
        .map_err(|e| format!("HTTP error: {}", e))?;
    let body = response
        .into_body()
        .read_to_string()
        .map_err(|e| format!("Read error: {}", e))?;
    let json: serde_json::Value =
        serde_json::from_str(&body).map_err(|e| format!("JSON parse error: {}", e))?;
    // A body without a `data` array is not a model list (e.g. an error payload).
    // Report it instead of pretending the provider has zero models, which would
    // make every model mapping look missing during validation.
    let entries = json["data"]
        .as_array()
        .ok_or_else(|| "Unexpected response: no `data` array".to_string())?;
    Ok(entries
        .iter()
        .filter_map(|m| m["id"].as_str().map(str::to_string))
        .collect())
}
/// Fetches the IDs of all models listed by SambaNova.
///
/// Sends `GET https://api.sambanova.ai/v1/models` with a bearer token (30 s
/// global timeout) and collects the string `id` of every entry in the
/// response's `data` array. Entries without a string `id` are skipped.
///
/// # Errors
///
/// Returns a human-readable message when the HTTP request fails, the body
/// cannot be read, the body is not valid JSON, or the JSON has no `data` array.
fn fetch_sambanova_models(api_key: &str) -> Result<Vec<String>, String> {
    let agent = ureq::Agent::new_with_config(
        ureq::Agent::config_builder()
            .timeout_global(Some(Duration::from_secs(30)))
            .build(),
    );
    let response = agent
        .get("https://api.sambanova.ai/v1/models")
        .header("Authorization", &format!("Bearer {}", api_key))
        .call()
        .map_err(|e| format!("HTTP error: {}", e))?;
    let body = response
        .into_body()
        .read_to_string()
        .map_err(|e| format!("Read error: {}", e))?;
    let json: serde_json::Value =
        serde_json::from_str(&body).map_err(|e| format!("JSON parse error: {}", e))?;
    // A body without a `data` array is not a model list (e.g. an error payload).
    // Report it instead of pretending the provider has zero models, which would
    // make every model mapping look missing during validation.
    let entries = json["data"]
        .as_array()
        .ok_or_else(|| "Unexpected response: no `data` array".to_string())?;
    Ok(entries
        .iter()
        .filter_map(|m| m["id"].as_str().map(str::to_string))
        .collect())
}
/// Sends a short "who are you?" chat to `model` on a single `provider` and
/// records the outcome.
///
/// The client is configured with only this provider and
/// `with_timeout(60)` (presumably seconds — confirm against herolib_ai). The
/// request goes through `chat_with_options(model, messages, Some(0.1), Some(50))`
/// (presumably temperature and max tokens — confirm). Latency is measured
/// around the chat call only.
///
/// `model_id` is the provider-specific identifier. It is only copied into the
/// result for reporting and is not passed to the client here.
///
/// This function never fails. A request error is captured in the returned
/// result with `success == false`.
fn test_model_on_provider(
    provider: Provider,
    api_key: &str,
    model: Model,
    model_id: &str,
) -> ModelTestResult {
    let client =
        AiClient::new().with_provider(ProviderConfig::new(provider, api_key).with_timeout(60));
    let messages = vec![
        Message::system("You are a helpful assistant. Answer in one short sentence."),
        Message::user("Who are you? Just say your name/model."),
    ];
    let start = Instant::now();
    let result = client.chat_with_options(model, messages, Some(0.1), Some(50));
    let elapsed = start.elapsed().as_millis() as u64;

    let (success, error, response_preview) = match result {
        Ok(response) => {
            let content = response.content().unwrap_or("");
            (true, None, Some(truncate_preview(content, 100)))
        }
        Err(e) => (false, Some(e.to_string()), None),
    };

    ModelTestResult {
        model,
        provider,
        provider_model_id: model_id.to_string(),
        success,
        response_time_ms: elapsed,
        error,
        response_preview,
    }
}

/// Returns at most `max_chars` characters of `text`, appending `...` when
/// anything was cut off.
///
/// Truncation happens on `char` boundaries. Byte slicing such as
/// `&text[..100]` panics when the cut lands inside a multi-byte UTF-8
/// character (accents, CJK, emoji), which model replies can easily contain.
fn truncate_preview(text: &str, max_chars: usize) -> String {
    match text.char_indices().nth(max_chars) {
        // `cut` is the byte offset of the first character that does not fit.
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}
/// Prints a top-level section title after a blank line, framed above and
/// below by a 70-column rule of `=` characters.
fn print_header(title: &str) {
    let rule = "=".repeat(70);
    println!("\n{0}\n{1}\n{0}", rule, title);
}
/// Prints a secondary section title after a blank line, framed above and
/// below by a 50-column rule of `-` characters.
fn print_subheader(title: &str) {
    let rule = "-".repeat(50);
    println!("\n{0}\n{1}\n{0}", rule, title);
}
/// Checks every configured provider and every known model.
///
/// Runs in three phases:
/// 1. Fetches the model list of each provider that has an API key.
/// 2. If at least one list was fetched, checks that every provider-specific
///    model ID from `Model::all()` appears in that provider's list.
/// 3. Sends a short chat request for every (model, provider) mapping whose
///    provider has an API key, then prints a summary report.
///
/// API keys are read from `GROQ_API_KEY`, `OPENROUTER_API_KEY` and
/// `SAMBANOVA_API_KEY`; empty values are ignored. The process exits with
/// status 1 if no key is set or if any Phase 3 request fails. Phase 2
/// mismatches are only reported and do not affect the exit status.
fn main() {
    // Needed for `flush()`: `print!` does not flush stdout (which is line
    // buffered), so without an explicit flush the "Querying ..." /
    // "model (id)... " progress text would only appear after the blocking
    // network call it announces had already finished.
    use std::io::Write;

    println!("herolib-ai Model Test Utility");
    println!("Testing model availability across all providers\n");

    // Environment variable -> provider it authenticates.
    let key_vars = [
        ("GROQ_API_KEY", Provider::Groq),
        ("OPENROUTER_API_KEY", Provider::OpenRouter),
        ("SAMBANOVA_API_KEY", Provider::SambaNova),
    ];
    let mut api_keys: HashMap<Provider, String> = HashMap::new();
    for &(var, provider) in &key_vars {
        // Unset and empty variables both mean "provider not configured".
        if let Some(key) = std::env::var(var).ok().filter(|k| !k.is_empty()) {
            api_keys.insert(provider, key);
        }
    }
    if api_keys.is_empty() {
        eprintln!("ERROR: No API keys configured!");
        eprintln!("Set at least one of:");
        for &(var, _) in &key_vars {
            eprintln!(" - {}", var);
        }
        std::process::exit(1);
    }
    println!("Configured providers:");
    for provider in api_keys.keys() {
        println!(" - {}", provider);
    }

    print_header("Phase 1: Querying Provider Model Lists");
    let mut provider_models: HashMap<Provider, ProviderModelsResult> = HashMap::new();
    let mut model_list_available = false;
    for (provider, api_key) in &api_keys {
        print!("Querying {}... ", provider);
        // Best effort: a failed flush only delays the progress text.
        let _ = std::io::stdout().flush();
        let result = match provider {
            Provider::Groq => fetch_groq_models(api_key),
            Provider::OpenRouter => fetch_openrouter_models(api_key),
            Provider::SambaNova => fetch_sambanova_models(api_key),
        };
        match result {
            Ok(models) => {
                println!("OK ({} models)", models.len());
                model_list_available = true;
                provider_models.insert(
                    *provider,
                    ProviderModelsResult {
                        provider: *provider,
                        available: true,
                        models,
                        error: None,
                    },
                );
            }
            Err(e) => {
                println!("FAILED: {}", e);
                provider_models.insert(
                    *provider,
                    ProviderModelsResult {
                        provider: *provider,
                        available: false,
                        models: vec![],
                        error: Some(e),
                    },
                );
            }
        }
    }

    if model_list_available {
        print_header("Phase 2: Validating Model Mappings");
        let mut mapping_errors: Vec<String> = Vec::new();
        for model in Model::all() {
            let info = model.info();
            println!("\n{} ({}):", model.name(), info.description);
            for mapping in &info.providers {
                if let Some(provider_result) = provider_models.get(&mapping.provider) {
                    if provider_result.available {
                        let found = provider_result.models.iter().any(|m| m == mapping.model_id);
                        if found {
                            println!(" {} {} -> OK", mapping.provider, mapping.model_id);
                        } else {
                            println!(" {} {} -> NOT FOUND", mapping.provider, mapping.model_id);
                            mapping_errors.push(format!(
                                "{}: {} not found on {}",
                                model.name(),
                                mapping.model_id,
                                mapping.provider
                            ));
                        }
                    } else {
                        println!(
                            " {} {} -> SKIPPED (provider unavailable)",
                            mapping.provider, mapping.model_id
                        );
                    }
                } else {
                    println!(
                        " {} {} -> SKIPPED (no API key)",
                        mapping.provider, mapping.model_id
                    );
                }
            }
        }
        if !mapping_errors.is_empty() {
            print_subheader("Mapping Errors");
            for error in &mapping_errors {
                println!(" ERROR: {}", error);
            }
        }
    } else {
        println!("\nSkipping model validation (no provider model lists available)");
    }

    print_header("Phase 3: Testing Models with 'whoami' Query");
    let mut test_results: Vec<ModelTestResult> = Vec::new();
    let mut tested_count = 0;
    let mut success_count = 0;
    for model in Model::all() {
        let info = model.info();
        print_subheader(&format!("Testing: {}", model.name()));
        for mapping in &info.providers {
            if let Some(api_key) = api_keys.get(&mapping.provider) {
                print!(" {} ({})... ", mapping.provider, mapping.model_id);
                // Best effort: show which request is in flight before it blocks.
                let _ = std::io::stdout().flush();
                let result =
                    test_model_on_provider(mapping.provider, api_key, *model, mapping.model_id);
                tested_count += 1;
                if result.success {
                    success_count += 1;
                    println!("OK ({}ms)", result.response_time_ms);
                    if let Some(preview) = &result.response_preview {
                        println!(" Response: {}", preview);
                    }
                } else {
                    println!("FAILED");
                    if let Some(error) = &result.error {
                        println!(" Error: {}", error);
                    }
                }
                test_results.push(result);
            }
        }
    }

    print_header("Final Report");
    println!("\nProvider Status:");
    for (provider, result) in &provider_models {
        let status = if result.available { "OK" } else { "FAILED" };
        println!(
            " {}: {} ({} models)",
            provider,
            status,
            result.models.len()
        );
    }
    println!("\nTest Summary:");
    println!(" Total tests: {}", tested_count);
    println!(" Successful: {}", success_count);
    println!(" Failed: {}", tested_count - success_count);
    println!(
        " Success rate: {:.1}%",
        if tested_count > 0 {
            (success_count as f64 / tested_count as f64) * 100.0
        } else {
            0.0
        }
    );
    println!("\nDetailed Results by Model:");
    for model in Model::all() {
        let model_results: Vec<&ModelTestResult> =
            test_results.iter().filter(|r| r.model == *model).collect();
        if model_results.is_empty() {
            continue;
        }
        let successes = model_results.iter().filter(|r| r.success).count();
        let total = model_results.len();
        println!(" {} - {}/{} providers OK", model.name(), successes, total);
        for result in model_results {
            let status = if result.success { "OK" } else { "FAIL" };
            println!(
                " {} {}: {} ({}ms)",
                result.provider, result.provider_model_id, status, result.response_time_ms
            );
        }
    }
    if success_count < tested_count {
        println!("\nSome tests failed!");
        std::process::exit(1);
    } else {
        println!("\nAll tests passed!");
    }
}