use reqwest;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
use std::path::Path;
use uuid::Uuid;
use vllora_core::metadata::error::DatabaseError;
use vllora_core::metadata::models::model::DbNewModel;
use vllora_core::metadata::pool::DbPool;
use vllora_core::metadata::services::model::ModelServiceImpl;
use vllora_core::types::metadata::services::model::ModelService;
use vllora_core::types::LANGDB_API_URL;
use vllora_llm::types::models::ModelMetadata;
/// Everything that can go wrong while fetching, persisting, exporting or
/// loading model metadata.
///
/// Each variant wraps its underlying error and can be produced from it via
/// `From` (so `?` converts automatically).
#[derive(Debug, thiserror::Error)]
pub enum ModelsLoadError {
    /// JSON (de)serialization of model metadata failed.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),

    /// A filesystem operation failed (for example while writing an exported file).
    #[error("IO error: {0}")]
    IO(#[from] std::io::Error),

    /// The metadata database layer reported an error.
    #[error("Database error: {0}")]
    Database(#[from] DatabaseError),

    /// A `reqwest` error raised while requesting or decoding the model list.
    #[error("Failed to fetch models: {0}")]
    Fetch(#[from] reqwest::Error),
}
/// Resolves the LangDB API base URL.
///
/// The `LANGDB_API_URL` environment variable wins; if it is unset (or not valid
/// Unicode) the compiled-in [`LANGDB_API_URL`] default is used instead.
fn resolve_langdb_api_url() -> String {
    std::env::var("LANGDB_API_URL").unwrap_or_else(|_| LANGDB_API_URL.to_string())
}

/// Downloads the model catalogue (including parameters and benchmarks) from the
/// `/pricing` endpoint under `langdb_api_url`.
///
/// # Errors
///
/// Returns [`ModelsLoadError::Fetch`] if the request cannot be sent, the server
/// answers with a 4xx/5xx status, or the body cannot be decoded as
/// `Vec<ModelMetadata>`.
async fn fetch_models_from_api(
    langdb_api_url: &str,
) -> Result<Vec<ModelMetadata>, ModelsLoadError> {
    let models: Vec<ModelMetadata> = reqwest::Client::new()
        .get(format!(
            "{langdb_api_url}/pricing?include_parameters=true&include_benchmark=true"
        ))
        .send()
        .await?
        // Turn HTTP failures into a status error instead of letting them surface
        // as a misleading JSON decode error on the error-page body.
        .error_for_status()?
        .json()
        .await?;
    Ok(models)
}

/// Fetches the model catalogue from LangDB and synchronises it into the database.
///
/// Each fetched model is stored under its own inference provider. Additionally,
/// one `"langdb"`-provider alias is stored per distinct model name, with a fresh
/// UUID, `endpoint` set to the LangDB API URL, and `model_name_in_provider` set to
/// the model name. Afterwards, every model returned by the model service whose
/// `(model_name, provider_name)` pair was not produced by this sync is marked as
/// deleted.
///
/// Returns the metadata exactly as fetched from the API.
///
/// # Errors
///
/// [`ModelsLoadError::Fetch`] for network/HTTP/decoding failures and
/// [`ModelsLoadError::Database`] if inserting, listing or marking models fails.
pub async fn fetch_and_store_models(
    db_pool: DbPool,
) -> Result<Vec<ModelMetadata>, ModelsLoadError> {
    let langdb_api_url = resolve_langdb_api_url();
    let models = fetch_models_from_api(&langdb_api_url).await?;

    // (model_name, provider_name) pairs considered live after this sync.
    let mut synced_model_identifiers: HashSet<(String, String)> = models
        .iter()
        .map(|m| (m.model.clone(), m.inference_provider.provider.to_string()))
        .collect();

    let mut db_models: Vec<DbNewModel> =
        models.iter().map(|m| DbNewModel::from(m.clone())).collect();

    // One LangDB-routed alias per distinct model name; the first occurrence wins.
    let mut langdb_models = HashMap::<String, DbNewModel>::new();
    for model in &db_models {
        // Check before cloning so duplicates don't pay for a full model clone.
        if langdb_models.contains_key(&model.model_name) {
            continue;
        }
        let mut alias = model.clone();
        alias.id = Some(Uuid::new_v4().to_string());
        alias.endpoint = Some(langdb_api_url.clone());
        alias.provider_name = "langdb".to_string();
        alias.model_name_in_provider = Some(alias.model_name.clone());
        synced_model_identifiers.insert((alias.model_name.clone(), "langdb".to_string()));
        langdb_models.insert(alias.model_name.clone(), alias);
    }
    db_models.extend(langdb_models.into_values());

    let model_service = ModelServiceImpl::new(db_pool);
    model_service.insert_many(db_models)?;

    // NOTE(review): if the API ever returns an empty list, every listed model
    // ends up here and is marked deleted — confirm that is acceptable upstream.
    let stored_models = model_service.list(None)?;
    let models_to_delete: Vec<String> = stored_models
        .iter()
        .filter(|db_model| {
            let identifier = (db_model.model_name.clone(), db_model.provider_name.clone());
            !synced_model_identifiers.contains(&identifier)
        })
        .filter_map(|db_model| db_model.id.clone())
        .collect();
    if !models_to_delete.is_empty() {
        model_service.mark_models_as_deleted(models_to_delete)?;
    }

    Ok(models)
}

/// Fetches the model catalogue from LangDB and writes it as pretty-printed JSON
/// to `output_path`, replacing any existing file, then prints a short summary to
/// stdout.
///
/// Returns the metadata exactly as fetched from the API.
///
/// # Errors
///
/// [`ModelsLoadError::Fetch`] for network/HTTP/decoding failures,
/// [`ModelsLoadError::Json`] if serialization fails and [`ModelsLoadError::IO`]
/// if the file cannot be written.
pub async fn fetch_and_save_models_json(
    output_path: &Path,
) -> Result<Vec<ModelMetadata>, ModelsLoadError> {
    let langdb_api_url = resolve_langdb_api_url();
    let models = fetch_models_from_api(&langdb_api_url).await?;

    let json_content = serde_json::to_string_pretty(&models)?;
    // NOTE(review): `std::fs::write` blocks the executor thread; fine for a
    // one-off export, consider an async file API if this is called on a hot path.
    fs::write(output_path, json_content)?;

    println!(
        "Successfully saved {} models to {}",
        models.len(),
        output_path.display()
    );
    Ok(models)
}
pub fn load_models_from_json(json_content: &str) -> Result<Vec<ModelMetadata>, ModelsLoadError> {
let models: Vec<ModelMetadata> = serde_json::from_str(json_content)?;
Ok(models)
}