bitrouter_api/router/
models.rs1use std::sync::Arc;
15
16use bitrouter_core::routers::registry::ModelRegistry;
17use bitrouter_core::routers::routing_table::ModelPricing;
18use serde::Serialize;
19use warp::Filter;
20
/// Optional filters accepted by `GET /v1/models`.
///
/// Every field is `None` when the corresponding query parameter was not
/// supplied, in which case that filter is not applied.
#[derive(Debug, Default)]
pub struct ModelQuery {
    /// Keep only models listing this provider (exact match).
    pub provider: Option<String>,
    /// Keep only models whose id contains this text (case-insensitive).
    pub id: Option<String>,
    /// Keep only models listing this input modality (exact match).
    pub input_modality: Option<String>,
    /// Keep only models listing this output modality (exact match).
    pub output_modality: Option<String>,
}
33
34pub fn models_filter<T>(
36 table: Arc<T>,
37) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone
38where
39 T: ModelRegistry + Send + Sync + 'static,
40{
41 warp::path!("v1" / "models")
42 .and(warp::get())
43 .and(optional_raw_query())
44 .and(warp::any().map(move || table.clone()))
45 .map(handle_list_models)
46}
47
48fn optional_raw_query()
51-> impl Filter<Extract = (Option<String>,), Error = std::convert::Infallible> + Clone {
52 warp::query::raw()
53 .map(Some)
54 .or(warp::any().map(|| None))
55 .unify()
56}
57
58#[derive(Serialize)]
59struct ModelResponse {
60 id: String,
61 providers: Vec<String>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 name: Option<String>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 description: Option<String>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 max_input_tokens: Option<u64>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 max_output_tokens: Option<u64>,
70 #[serde(skip_serializing_if = "Vec::is_empty")]
71 input_modalities: Vec<String>,
72 #[serde(skip_serializing_if = "Vec::is_empty")]
73 output_modalities: Vec<String>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 pricing: Option<ModelPricing>,
76}
77
78fn parse_query(raw: &str) -> ModelQuery {
79 let mut query = ModelQuery::default();
80 for pair in raw.split('&') {
81 if let Some((key, value)) = pair.split_once('=') {
82 match key {
83 "provider" => query.provider = Some(value.to_owned()),
84 "id" => query.id = Some(value.to_owned()),
85 "input_modality" => query.input_modality = Some(value.to_owned()),
86 "output_modality" => query.output_modality = Some(value.to_owned()),
87 _ => {}
88 }
89 }
90 }
91 query
92}
93
94fn handle_list_models<T: ModelRegistry>(
95 raw_query: Option<String>,
96 table: Arc<T>,
97) -> impl warp::Reply {
98 let query = raw_query.as_deref().map(parse_query).unwrap_or_default();
99 let entries = table.list_models();
100 let id_lower = query.id.as_deref().map(str::to_lowercase);
101
102 let models: Vec<ModelResponse> = entries
103 .into_iter()
104 .filter(|e| {
105 if query
106 .provider
107 .as_deref()
108 .is_some_and(|p| !e.providers.iter().any(|x| x == p))
109 {
110 return false;
111 }
112 if id_lower
113 .as_deref()
114 .is_some_and(|s| !e.id.to_lowercase().contains(s))
115 {
116 return false;
117 }
118 if query
119 .input_modality
120 .as_deref()
121 .is_some_and(|m| !e.input_modalities.iter().any(|x| x == m))
122 {
123 return false;
124 }
125 if query
126 .output_modality
127 .as_deref()
128 .is_some_and(|m| !e.output_modalities.iter().any(|x| x == m))
129 {
130 return false;
131 }
132 true
133 })
134 .map(|e| ModelResponse {
135 id: e.id,
136 providers: e.providers,
137 name: e.name,
138 description: e.description,
139 max_input_tokens: e.max_input_tokens,
140 max_output_tokens: e.max_output_tokens,
141 input_modalities: e.input_modalities,
142 output_modalities: e.output_modalities,
143 pricing: e.pricing,
144 })
145 .collect();
146 warp::reply::json(&serde_json::json!({ "data": models }))
147}