use crate::client::Client;
use crate::error::Error;
use crate::models::
{
GenerateContentRequest,
GenerateContentResponse,
};
use std::time::Instant;
/// Outcome of running a single model against a shared request.
///
/// A failed call is still recorded ( `success == false` ) so callers can see
/// which model failed and why.
#[ derive( Debug, Clone ) ]
pub struct ModelComparisonResult
{
/// Name of the model that was queried.
pub model_name : String,
/// The model's response; an empty placeholder response when the call failed.
pub response : GenerateContentResponse,
/// Wall-clock duration of this model's call, in milliseconds.
pub response_time_ms : u64,
/// True when the call returned `Ok`.
pub success : bool,
/// Stringified error when the call failed, `None` on success.
pub error_message : Option< String >,
/// Prompt token count from usage metadata, when the API reported it.
pub input_tokens : Option< i32 >,
/// Candidate ( output ) token count from usage metadata, when the API reported it.
pub output_tokens : Option< i32 >,
}
/// Aggregated results of comparing several models on the same request.
#[ derive( Debug, Clone ) ]
pub struct ComparisonResults
{
/// One entry per compared model, in the order the models were given.
pub results : Vec< ModelComparisonResult >,
/// Total wall-clock time for the whole comparison run, in milliseconds.
pub total_time_ms : u64,
/// Name of the fastest successful model, `None` when none succeeded.
pub fastest_model : Option< String >,
/// Name of the slowest successful model, `None` when none succeeded.
pub slowest_model : Option< String >,
}
impl ComparisonResults
{
  /// Returns the fastest successful result, or `None` when no model succeeded.
  #[ must_use ]
  pub fn get_fastest( &self ) -> Option< &ModelComparisonResult >
  {
    self.results
    .iter()
    .filter( | r | r.success )
    .min_by_key( | r | r.response_time_ms )
  }

  /// Returns the slowest successful result, or `None` when no model succeeded.
  #[ must_use ]
  pub fn get_slowest( &self ) -> Option< &ModelComparisonResult >
  {
    self.results
    .iter()
    .filter( | r | r.success )
    .max_by_key( | r | r.response_time_ms )
  }

  /// Mean response time in milliseconds over successful results only.
  ///
  /// Returns `0.0` when no model succeeded.
  #[ must_use ]
  pub fn average_response_time( &self ) -> f64
  {
    // Single pass over successes: accumulate count and total without an
    // intermediate Vec allocation ( clippy `needless_collect` ).
    let ( count, total ) = self.results
    .iter()
    .filter( | r | r.success )
    .fold( ( 0u64, 0u64 ), | ( n, sum ), r | ( n + 1, sum + r.response_time_ms ) );
    if count == 0
    {
      return 0.0;
    }
    total as f64 / count as f64
  }

  /// Fraction of compared models that succeeded, in `[ 0.0, 1.0 ]`.
  ///
  /// Returns `0.0` when no models were compared.
  #[ must_use ]
  pub fn success_rate( &self ) -> f64
  {
    if self.results.is_empty()
    {
      return 0.0;
    }
    let successful = self.results.iter().filter( | r | r.success ).count();
    successful as f64 / self.results.len() as f64
  }
}
/// Runs the same request against multiple models and collects timing/token statistics.
///
/// Borrows the [`Client`] for its lifetime; obtain one via `Client::comparator()`.
#[ derive( Debug ) ]
pub struct ModelComparator< 'a >
{
// Borrowed client used to issue each generate-content call.
client : &'a Client,
}
impl< 'a > ModelComparator< 'a >
{
#[ must_use ]
#[ inline ]
pub fn new( client : &'a Client ) -> Self
{
Self { client }
}
pub async fn compare_models(
&self,
model_names : &[ &str ],
request : &GenerateContentRequest,
) -> Result< ComparisonResults, Error >
{
let start = Instant::now();
let mut results = Vec::with_capacity( model_names.len() );
for model_name in model_names
{
let model_start = Instant::now();
match self.client.models().by_name( model_name ).generate_content( request ).await
{
Ok( response ) =>
{
let elapsed = model_start.elapsed().as_millis() as u64;
let input_tokens = response.usage_metadata.as_ref().and_then( | u | u.prompt_token_count );
let output_tokens = response.usage_metadata.as_ref().and_then( | u | u.candidates_token_count );
results.push( ModelComparisonResult
{
model_name : model_name.to_string(),
response,
response_time_ms : elapsed,
success : true,
error_message : None,
input_tokens,
output_tokens,
} );
}
Err( err ) =>
{
let elapsed = model_start.elapsed().as_millis() as u64;
let empty_response = GenerateContentResponse
{
candidates : vec![],
prompt_feedback : None,
usage_metadata : None,
grounding_metadata : None,
};
results.push( ModelComparisonResult
{
model_name : model_name.to_string(),
response : empty_response,
response_time_ms : elapsed,
success : false,
error_message : Some( err.to_string() ),
input_tokens : None,
output_tokens : None,
} );
}
}
}
let total_time_ms = start.elapsed().as_millis() as u64;
let fastest_model = results
.iter()
.filter( | r | r.success )
.min_by_key( | r | r.response_time_ms )
.map( | r | r.model_name.clone() );
let slowest_model = results
.iter()
.filter( | r | r.success )
.max_by_key( | r | r.response_time_ms )
.map( | r | r.model_name.clone() );
Ok( ComparisonResults
{
results,
total_time_ms,
fastest_model,
slowest_model,
} )
}
pub async fn compare_models_parallel(
&self,
model_names : &[ &str ],
request : &GenerateContentRequest,
) -> Result< ComparisonResults, Error >
{
let start = Instant::now();
let futures : Vec< _ > = model_names
.iter()
.map( | model_name |
{
let request = request.clone();
async move
{
let model_start = Instant::now();
let result = self.client.models().by_name( model_name ).generate_content( &request ).await;
let elapsed = model_start.elapsed().as_millis() as u64;
match result
{
Ok( response ) =>
{
let input_tokens = response.usage_metadata.as_ref().and_then( | u | u.prompt_token_count );
let output_tokens = response.usage_metadata.as_ref().and_then( | u | u.candidates_token_count );
ModelComparisonResult
{
model_name : model_name.to_string(),
response,
response_time_ms : elapsed,
success : true,
error_message : None,
input_tokens,
output_tokens,
}
}
Err( err ) =>
{
let empty_response = GenerateContentResponse
{
candidates : vec![],
prompt_feedback : None,
usage_metadata : None,
grounding_metadata : None,
};
ModelComparisonResult
{
model_name : model_name.to_string(),
response : empty_response,
response_time_ms : elapsed,
success : false,
error_message : Some( err.to_string() ),
input_tokens : None,
output_tokens : None,
}
}
}
}
} )
.collect();
let results = futures::future::join_all( futures ).await;
let total_time_ms = start.elapsed().as_millis() as u64;
let fastest_model = results
.iter()
.filter( | r | r.success )
.min_by_key( | r | r.response_time_ms )
.map( | r | r.model_name.clone() );
let slowest_model = results
.iter()
.filter( | r | r.success )
.max_by_key( | r | r.response_time_ms )
.map( | r | r.model_name.clone() );
Ok( ComparisonResults
{
results,
total_time_ms,
fastest_model,
slowest_model,
} )
}
}
impl Client
{
  /// Returns a [`ModelComparator`] that borrows this client for issuing
  /// the comparison requests.
  #[ must_use ]
  #[ inline ]
  pub fn comparator( &self ) -> ModelComparator< '_ >
  {
    ModelComparator::new( self )
  }
}