bitrouter_api/router/
models.rs1use std::sync::Arc;
15
16use bitrouter_core::routers::registry::ModelRegistry;
17use bitrouter_core::routers::routing_table::ModelPricing;
18use serde::Serialize;
19use warp::Filter;
20
/// Optional filters accepted by `GET /v1/models`, parsed from the query string.
///
/// A field left as `None` means that filter is not applied, so
/// `ModelQuery::default()` matches every model.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ModelQuery {
    /// Provider the model must be offered by (`?provider=`), compared exactly.
    pub provider: Option<String>,
    /// Substring the model id must contain (`?id=`), compared case-insensitively.
    pub id: Option<String>,
    /// Input modality the model must list (`?input_modality=`), compared exactly.
    pub input_modality: Option<String>,
    /// Output modality the model must list (`?output_modality=`), compared exactly.
    pub output_modality: Option<String>,
}
33
34pub fn models_filter<T>(
36 table: Arc<T>,
37) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone
38where
39 T: ModelRegistry + Send + Sync + 'static,
40{
41 warp::path!("v1" / "models")
42 .and(warp::get())
43 .and(optional_raw_query())
44 .and(warp::any().map(move || table.clone()))
45 .map(handle_list_models)
46}
47
48fn optional_raw_query()
51-> impl Filter<Extract = (Option<String>,), Error = std::convert::Infallible> + Clone {
52 warp::query::raw()
53 .map(Some)
54 .or(warp::any().map(|| None))
55 .unify()
56}
57
58#[derive(Serialize)]
59struct ModelResponse {
60 id: String,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 name: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 description: Option<String>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 max_input_tokens: Option<u64>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 max_output_tokens: Option<u64>,
69 #[serde(skip_serializing_if = "Vec::is_empty")]
70 input_modalities: Vec<String>,
71 #[serde(skip_serializing_if = "Vec::is_empty")]
72 output_modalities: Vec<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pricing: Option<ModelPricing>,
75}
76
77fn parse_query(raw: &str) -> ModelQuery {
78 let mut query = ModelQuery::default();
79 for pair in raw.split('&') {
80 if let Some((key, value)) = pair.split_once('=') {
81 match key {
82 "provider" => query.provider = Some(value.to_owned()),
83 "id" => query.id = Some(value.to_owned()),
84 "input_modality" => query.input_modality = Some(value.to_owned()),
85 "output_modality" => query.output_modality = Some(value.to_owned()),
86 _ => {}
87 }
88 }
89 }
90 query
91}
92
93fn handle_list_models<T: ModelRegistry>(
94 raw_query: Option<String>,
95 table: Arc<T>,
96) -> impl warp::Reply {
97 let query = raw_query.as_deref().map(parse_query).unwrap_or_default();
98 let entries = table.list_models();
99 let id_lower = query.id.as_deref().map(str::to_lowercase);
100
101 let models: Vec<ModelResponse> = entries
102 .into_iter()
103 .filter(|e| {
104 if query
105 .provider
106 .as_deref()
107 .is_some_and(|p| !e.providers.iter().any(|x| x == p))
108 {
109 return false;
110 }
111 if id_lower
112 .as_deref()
113 .is_some_and(|s| !e.id.to_lowercase().contains(s))
114 {
115 return false;
116 }
117 if query
118 .input_modality
119 .as_deref()
120 .is_some_and(|m| !e.input_modalities.iter().any(|x| x == m))
121 {
122 return false;
123 }
124 if query
125 .output_modality
126 .as_deref()
127 .is_some_and(|m| !e.output_modalities.iter().any(|x| x == m))
128 {
129 return false;
130 }
131 true
132 })
133 .map(|e| ModelResponse {
134 id: e.id,
135 name: e.name,
136 description: e.description,
137 max_input_tokens: e.max_input_tokens,
138 max_output_tokens: e.max_output_tokens,
139 input_modalities: e.input_modalities,
140 output_modalities: e.output_modalities,
141 pricing: e.pricing,
142 })
143 .collect();
144 warp::reply::json(&serde_json::json!({ "data": models }))
145}