bitrouter_api/router/
models.rs1use std::sync::Arc;
15
16use bitrouter_core::routers::routing_table::RoutingTable;
17use serde::Serialize;
18use warp::Filter;
19
/// Optional filters accepted by `GET /v1/models`, parsed from the request's
/// query string.
///
/// Every field stays `None` unless the corresponding query parameter was
/// supplied.
#[derive(Default, Debug)]
pub struct ModelQuery {
    /// `provider=...`: keep only models whose provider equals this exactly.
    pub provider: Option<String>,
    /// `id=...`: keep only models whose id contains this, ignoring case.
    pub id: Option<String>,
    /// `input_modality=...`: keep only models listing this input modality.
    pub input_modality: Option<String>,
    /// `output_modality=...`: keep only models listing this output modality.
    pub output_modality: Option<String>,
}
32
33pub fn models_filter<T>(
35 table: Arc<T>,
36) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone
37where
38 T: RoutingTable + Send + Sync + 'static,
39{
40 warp::path!("v1" / "models")
41 .and(warp::get())
42 .and(optional_raw_query())
43 .and(warp::any().map(move || table.clone()))
44 .map(handle_list_models)
45}
46
47fn optional_raw_query()
50-> impl Filter<Extract = (Option<String>,), Error = std::convert::Infallible> + Clone {
51 warp::query::raw()
52 .map(Some)
53 .or(warp::any().map(|| None))
54 .unify()
55}
56
57#[derive(Serialize)]
58struct ModelResponse {
59 id: String,
60 provider: String,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 name: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 description: Option<String>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 max_input_tokens: Option<u64>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 max_output_tokens: Option<u64>,
69 #[serde(skip_serializing_if = "Vec::is_empty")]
70 input_modalities: Vec<String>,
71 #[serde(skip_serializing_if = "Vec::is_empty")]
72 output_modalities: Vec<String>,
73}
74
75fn parse_query(raw: &str) -> ModelQuery {
76 let mut query = ModelQuery::default();
77 for pair in raw.split('&') {
78 if let Some((key, value)) = pair.split_once('=') {
79 match key {
80 "provider" => query.provider = Some(value.to_owned()),
81 "id" => query.id = Some(value.to_owned()),
82 "input_modality" => query.input_modality = Some(value.to_owned()),
83 "output_modality" => query.output_modality = Some(value.to_owned()),
84 _ => {}
85 }
86 }
87 }
88 query
89}
90
91fn handle_list_models<T: RoutingTable>(
92 raw_query: Option<String>,
93 table: Arc<T>,
94) -> impl warp::Reply {
95 let query = raw_query.as_deref().map(parse_query).unwrap_or_default();
96 let entries = table.list_models();
97 let id_lower = query.id.as_deref().map(str::to_lowercase);
98
99 let models: Vec<ModelResponse> = entries
100 .into_iter()
101 .filter(|e| {
102 if query.provider.as_deref().is_some_and(|p| e.provider != p) {
103 return false;
104 }
105 if id_lower
106 .as_deref()
107 .is_some_and(|s| !e.id.to_lowercase().contains(s))
108 {
109 return false;
110 }
111 if query
112 .input_modality
113 .as_deref()
114 .is_some_and(|m| !e.input_modalities.iter().any(|x| x == m))
115 {
116 return false;
117 }
118 if query
119 .output_modality
120 .as_deref()
121 .is_some_and(|m| !e.output_modalities.iter().any(|x| x == m))
122 {
123 return false;
124 }
125 true
126 })
127 .map(|e| ModelResponse {
128 id: e.id,
129 provider: e.provider,
130 name: e.name,
131 description: e.description,
132 max_input_tokens: e.max_input_tokens,
133 max_output_tokens: e.max_output_tokens,
134 input_modalities: e.input_modalities,
135 output_modalities: e.output_modalities,
136 })
137 .collect();
138 warp::reply::json(&serde_json::json!({ "models": models }))
139}