use actix_web::{web, HttpResponse, Result};
use serde::Deserialize;
use std::collections::HashMap;
use vllora_core::metadata::pool::DbPool;
use vllora_core::metadata::services::provider_credential::ProviderCredentialsServiceImpl;
use vllora_core::types::metadata::project::Project;
use vllora_core::types::metadata::services::model::ModelService;
use vllora_core::types::metadata::services::provider_credential::ProviderCredentialsService;
use vllora_core::GatewayApiError;
use vllora_llm::types::models::group_models_by_name_with_endpoints;
use vllora_llm::types::models::ModelMetadata;
/// Query-string options accepted by [`list_gateway_pricing`].
///
/// Both flags are optional in the request; thanks to `#[serde(default)]` a
/// missing flag deserializes to `false`.
#[derive(Debug, Deserialize)]
pub struct PricingQueryParams {
    /// When `false`, the handler clears each model's `parameters` field
    /// before building the response.
    #[serde(default)]
    pub include_parameters: bool,
    /// When `false`, the handler clears each model's `benchmark_info` field
    /// before building the response.
    #[serde(default)]
    pub include_benchmark: bool,
}
/// Returns the gateway's model list as JSON, grouped by
/// `group_models_by_name_with_endpoints`.
///
/// Models are fetched through `model_service.list`, which is called with
/// `None` as the project id. Each model is converted into [`ModelMetadata`];
/// its `parameters` and `benchmark_info` fields are cleared unless the
/// matching query flag is set. The grouping step also receives a map of
/// provider name to credential availability, looked up for the request's
/// project.
///
/// # Errors
///
/// Returns `GatewayApiError::CustomError` with the message
/// `"Failed to fetch models: <error>"` when listing the models fails.
pub async fn list_gateway_pricing(
    query: web::Query<PricingQueryParams>,
    model_service: web::Data<Box<dyn ModelService>>,
    db_pool: web::Data<DbPool>,
    project: web::ReqData<Project>,
) -> Result<HttpResponse, GatewayApiError> {
    // Bail out early if the model listing fails; nothing else is attempted.
    let db_models = model_service.list(None).map_err(|e| {
        GatewayApiError::CustomError(format!("Failed to fetch models: {}", e))
    })?;

    // Drops the optional payload sections the caller did not ask for.
    let trim_optional_fields = |mut metadata: ModelMetadata| {
        if !query.include_parameters {
            metadata.parameters = None;
        }
        if !query.include_benchmark {
            metadata.benchmark_info = None;
        }
        metadata
    };

    let models: Vec<ModelMetadata> = db_models
        .into_iter()
        .map(Into::<ModelMetadata>::into)
        .map(trim_optional_fields)
        .collect();

    // Credential availability is looked up for the project attached to the request.
    let request_project_id = project.id.to_string();
    let credentials_by_provider =
        get_provider_credentials_map(&db_pool, Some(&request_project_id));

    let grouped_models = group_models_by_name_with_endpoints(models, &credentials_by_provider);
    Ok(HttpResponse::Ok().json(grouped_models))
}
/// Builds a lookup from provider name to whether that provider has
/// credentials, as reported by `ProviderCredentialsService::list_providers`
/// for the given `project_id`.
///
/// This is best-effort: if listing the providers fails, the error is
/// discarded and the default value of the provider list is used instead
/// (presumably an empty list, which yields an empty map).
fn get_provider_credentials_map(
    db_pool: &DbPool,
    project_id: Option<&str>,
) -> HashMap<String, bool> {
    let service = ProviderCredentialsServiceImpl::new(db_pool.clone());
    let providers = service.list_providers(project_id).unwrap_or_default();

    let mut has_credentials_by_provider: HashMap<String, bool> = HashMap::new();
    for provider in providers {
        // A repeated provider name overwrites the earlier entry.
        has_credentials_by_provider.insert(provider.name, provider.has_credentials);
    }
    has_credentials_by_provider
}