use std::sync::Arc;
use bitrouter_core::routers::registry::ModelRegistry;
use bitrouter_core::routers::routing_table::ModelPricing;
use serde::Serialize;
use warp::Filter;
/// Filters accepted by `GET /v1/models`, parsed from the request query string.
///
/// Every field is optional; `None` means "do not filter on this attribute".
/// All present filters are combined with AND by the list handler.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ModelQuery {
    /// Provider name that must appear (exact match) in a model's provider list.
    pub provider: Option<String>,
    /// Substring the model id must contain, compared case-insensitively.
    pub id: Option<String>,
    /// Modality that must appear (exact match) in a model's input modalities.
    pub input_modality: Option<String>,
    /// Modality that must appear (exact match) in a model's output modalities.
    pub output_modality: Option<String>,
}
/// Builds the warp route serving `GET /v1/models`.
///
/// The route captures the (optional) raw query string and a shared handle to
/// the model registry, then hands both to `handle_list_models`, which renders
/// the JSON reply.
pub fn models_filter<T>(
    table: Arc<T>,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone
where
    T: ModelRegistry + Send + Sync + 'static,
{
    // Each request gets its own cheap reference-counted handle to the registry.
    let registry = warp::any().map(move || Arc::clone(&table));
    let endpoint = warp::path!("v1" / "models").and(warp::get());

    endpoint
        .and(optional_raw_query())
        .and(registry)
        .map(handle_list_models)
}
fn optional_raw_query()
-> impl Filter<Extract = (Option<String>,), Error = std::convert::Infallible> + Clone {
warp::query::raw()
.map(Some)
.or(warp::any().map(|| None))
.unify()
}
/// One model entry of the `GET /v1/models` reply; `handle_list_models` puts a
/// list of these under the top-level `"data"` key.
///
/// Optional metadata is omitted from the JSON entirely when it is `None` (or,
/// for the modality lists, empty) instead of being emitted as `null` / `[]`.
#[derive(Serialize)]
struct ModelResponse {
/// Model identifier; always serialized.
id: String,
/// Display name, if the registry entry provides one.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
/// Free-form description, if the registry entry provides one.
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
/// Input token limit, if known.
#[serde(skip_serializing_if = "Option::is_none")]
max_input_tokens: Option<u64>,
/// Output token limit, if known.
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u64>,
/// Input modalities; omitted from the JSON when the list is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
input_modalities: Vec<String>,
/// Output modalities; omitted from the JSON when the list is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
output_modalities: Vec<String>,
/// Pricing copied verbatim from the registry entry, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pricing: Option<ModelPricing>,
}
/// Parses a raw `application/x-www-form-urlencoded` query string into a
/// [`ModelQuery`].
///
/// - Keys and values are percent-decoded and `+` is read as a space, matching
///   how browsers and HTTP clients encode query parameters (so
///   `id=openai%2Fgpt-4o` filters on `openai/gpt-4o`).
/// - Unknown keys and segments without `=` are ignored.
/// - Empty values (`provider=`) are ignored, so they do not constrain results.
/// - When a key repeats, the last non-empty value wins.
fn parse_query(raw: &str) -> ModelQuery {
    let mut query = ModelQuery::default();
    for pair in raw.split('&') {
        // Bare flags and empty segments (e.g. from `&&`) carry no value.
        let Some((key, value)) = pair.split_once('=') else {
            continue;
        };
        let value = decode_query_component(value);
        // Otherwise `provider=` would match no model at all, while `id=`
        // would match every model; treat both as "no filter".
        if value.is_empty() {
            continue;
        }
        let slot = match decode_query_component(key).as_str() {
            "provider" => &mut query.provider,
            "id" => &mut query.id,
            "input_modality" => &mut query.input_modality,
            "output_modality" => &mut query.output_modality,
            _ => continue,
        };
        *slot = Some(value);
    }
    query
}

/// Decodes one form-urlencoded query component.
///
/// `+` becomes a space and `%XX` (two hex digits) becomes the byte `0xXX`.
/// A `%` that is not followed by two hex digits is kept literally, as in the
/// WHATWG percent-decode algorithm. Decoded bytes that are not valid UTF-8
/// are replaced with U+FFFD.
fn decode_query_component(raw: &str) -> String {
    let bytes = raw.as_bytes();
    let hex_at = |i: usize| bytes.get(i).and_then(|&b| char::from(b).to_digit(16));
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => match (hex_at(i + 1), hex_at(i + 2)) {
                (Some(hi), Some(lo)) => {
                    // Both digits are < 16, so the combined value is <= 0xFF.
                    out.push((hi * 16 + lo) as u8);
                    i += 3;
                }
                _ => {
                    out.push(b'%');
                    i += 1;
                }
            },
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8(out)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}
fn handle_list_models<T: ModelRegistry>(
raw_query: Option<String>,
table: Arc<T>,
) -> impl warp::Reply {
let query = raw_query.as_deref().map(parse_query).unwrap_or_default();
let entries = table.list_models();
let id_lower = query.id.as_deref().map(str::to_lowercase);
let models: Vec<ModelResponse> = entries
.into_iter()
.filter(|e| {
if query
.provider
.as_deref()
.is_some_and(|p| !e.providers.iter().any(|x| x == p))
{
return false;
}
if id_lower
.as_deref()
.is_some_and(|s| !e.id.to_lowercase().contains(s))
{
return false;
}
if query
.input_modality
.as_deref()
.is_some_and(|m| !e.input_modalities.iter().any(|x| x == m))
{
return false;
}
if query
.output_modality
.as_deref()
.is_some_and(|m| !e.output_modalities.iter().any(|x| x == m))
{
return false;
}
true
})
.map(|e| ModelResponse {
id: e.id,
name: e.name,
description: e.description,
max_input_tokens: e.max_input_tokens,
max_output_tokens: e.max_output_tokens,
input_modalities: e.input_modalities,
output_modalities: e.output_modalities,
pricing: e.pricing,
})
.collect();
warp::reply::json(&serde_json::json!({ "data": models }))
}