1#![doc = include_str!("../README.md")]
2
3use std::{
4 any::Any,
5 net::SocketAddr,
6 path::PathBuf,
7 sync::Arc,
8 time::{Duration, Instant},
9};
10
11use aide::axum::ApiRouter;
12use axum::{
13 Extension, ServiceExt,
14 body::Body,
15 http::{
16 Request, Response, StatusCode, Uri,
17 header::{CONTENT_TYPE, VARY},
18 },
19 middleware::Next,
20 response::{IntoResponse, Redirect},
21 routing::get,
22 serve,
23};
24use brk_query::AsyncQuery;
25use quick_cache::sync::Cache;
26use tokio::net::TcpListener;
27use tower_http::{
28 catch_panic::CatchPanicLayer, classify::ServerErrorsFailureClass,
29 compression::CompressionLayer, cors::CorsLayer, normalize_path::NormalizePathLayer,
30 timeout::TimeoutLayer, trace::TraceLayer,
31};
32use tower_layer::Layer;
33use tracing::{error, info};
34
35mod api;
36pub mod cache;
37mod error;
38mod extended;
39mod state;
40
41pub use api::ApiRoutes;
42use api::*;
43pub use brk_types::Port;
44pub use brk_website::Website;
45pub use cache::{CacheParams, CacheStrategy};
46pub use error::{Error, Result};
47use state::*;
48
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
50
/// HTTP server wrapping the shared [`AppState`] (query handle, data path, website,
/// response cache and start time). Build it with [`Server::new`] and run it with
/// [`Server::serve`].
pub struct Server(AppState);
52
53impl Server {
54 pub fn new(query: &AsyncQuery, data_path: PathBuf, website: Website) -> Self {
55 website.log();
56 Self(AppState {
57 query: query.clone(),
58 data_path,
59 website,
60 cache: Arc::new(Cache::new(1_000)),
61 started_at: jiff::Timestamp::now(),
62 started_instant: Instant::now(),
63 })
64 }
65
66 pub async fn serve(self, port: Option<Port>) -> brk_error::Result<()> {
67 let state = self.0;
68
69 #[cfg(feature = "bindgen")]
70 let vecs = state.query.inner().vecs();
71
72 let compression_layer = CompressionLayer::new().br(true).gzip(true).zstd(true);
73
74 let connect_info_layer = axum::middleware::from_fn(
75 async |connect_info: axum::extract::ConnectInfo<SocketAddr>,
76 mut request: Request<Body>,
77 next: Next|
78 -> Response<Body> {
79 let mut addr = connect_info.0;
80
81 if addr.ip().is_loopback() && request.headers().contains_key("CF-Connecting-IP") {
85 addr.set_ip(std::net::Ipv4Addr::UNSPECIFIED.into());
86 }
87
88 request.extensions_mut().insert(addr);
89 next.run(request).await
90 },
91 );
92
93 let response_time_layer = axum::middleware::from_fn(
94 async |request: Request<Body>, next: Next| -> Response<Body> {
95 let uri = request.uri().clone();
96 let start = Instant::now();
97 let mut response = next.run(request).await;
98 response.extensions_mut().insert(uri);
99 response.headers_mut().insert(
100 "X-Response-Time",
101 format!("{}us", start.elapsed().as_micros())
102 .parse()
103 .unwrap(),
104 );
105 response
106 },
107 );
108
109 let json_error_layer = axum::middleware::from_fn(
111 async |request: Request<Body>, next: Next| -> Response<Body> {
112 let response = next.run(request).await;
113 let status = response.status();
114 if status.is_success()
115 || status.is_redirection()
116 || status.is_informational()
117 || response.headers().get(CONTENT_TYPE).is_some_and(|v| {
118 let b = v.as_bytes();
119 b.starts_with(b"application/") && b.ends_with(b"json")
120 })
121 {
122 return response;
123 }
124
125 let (parts, body) = response.into_parts();
126 let bytes = axum::body::to_bytes(body, 4096).await.unwrap_or_default();
127 let msg = String::from_utf8_lossy(&bytes);
128 let (code, msg) = match parts.status {
129 StatusCode::NOT_FOUND => (
130 "not_found",
131 if msg.is_empty() {
132 "Not found".into()
133 } else {
134 msg
135 },
136 ),
137 StatusCode::METHOD_NOT_ALLOWED => (
138 "method_not_allowed",
139 "Only GET requests are supported".into(),
140 ),
141 StatusCode::GATEWAY_TIMEOUT => ("timeout", "Request timed out".into()),
142 s if s.is_client_error() => (
143 "bad_request",
144 if msg.is_empty() {
145 "Bad request".into()
146 } else {
147 msg
148 },
149 ),
150 _ => (
151 "internal_error",
152 if msg.is_empty() {
153 "Internal server error".into()
154 } else {
155 msg
156 },
157 ),
158 };
159 let msg = msg.into_owned();
160 let mut response = Error::new(parts.status, code, msg).into_response();
161 response.extensions_mut().extend(parts.extensions);
162 response
163 },
164 );
165
166 let trace_layer = TraceLayer::new_for_http()
167 .on_request(())
168 .on_response(
169 |response: &Response<Body>, latency: Duration, _: &tracing::Span| {
170 let status = response.status().as_u16();
171 let unknown = Uri::from_static("/unknown");
172 let uri = response.extensions().get::<Uri>().unwrap_or(&unknown);
173 match response.status() {
174 StatusCode::OK => info!(status, %uri, ?latency),
175 StatusCode::NOT_MODIFIED
176 | StatusCode::TEMPORARY_REDIRECT
177 | StatusCode::PERMANENT_REDIRECT => info!(status, %uri, ?latency),
178 _ => error!(status, %uri, ?latency),
179 }
180 },
181 )
182 .on_body_chunk(())
183 .on_failure(
184 |error: ServerErrorsFailureClass, latency: Duration, _: &tracing::Span| {
185 error!(?error, ?latency, "request failed");
186 },
187 )
188 .on_eos(());
189
190 let website_router = brk_website::router(state.website.clone());
191 let mut router = ApiRouter::new().add_api_routes();
192 if !state.website.is_enabled() {
193 router = router.route("/", get(Redirect::temporary("/api")));
194 }
195 let router = router
196 .with_state(state)
197 .merge(website_router)
198 .layer(compression_layer)
199 .layer(response_time_layer)
200 .layer(trace_layer)
201 .layer(CatchPanicLayer::custom(|panic: Box<dyn Any + Send>| {
202 let msg = panic
203 .downcast_ref::<String>()
204 .map(|s| s.as_str())
205 .or_else(|| panic.downcast_ref::<&str>().copied())
206 .unwrap_or("Unknown panic");
207 Error::internal(msg).into_response()
208 }))
209 .layer(TimeoutLayer::with_status_code(
210 StatusCode::GATEWAY_TIMEOUT,
211 Duration::from_secs(5),
212 ))
213 .layer(json_error_layer)
214 .layer(CorsLayer::permissive())
215 .layer(axum::middleware::from_fn(
216 async |request: Request<Body>, next: Next| -> Response<Body> {
217 let mut response = next.run(request).await;
218 let vary: Vec<&str> = response
220 .headers()
221 .get_all(VARY)
222 .iter()
223 .filter_map(|v| v.to_str().ok())
224 .collect();
225 if vary.len() > 1 {
226 let merged = vary.join(", ");
227 response.headers_mut().insert(VARY, merged.parse().unwrap());
228 }
229 response
230 },
231 ))
232 .layer(connect_info_layer);
233
234 let (listener, port) = match port {
235 Some(port) => {
236 let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
237 (listener, *port)
238 }
239 None => {
240 let base_port: u16 = *Port::DEFAULT;
241 let max_port: u16 = base_port + 100;
242 let mut port = base_port;
243 let listener = loop {
244 match TcpListener::bind(format!("0.0.0.0:{port}")).await {
245 Ok(l) => break l,
246 Err(_) if port < max_port => port += 1,
247 Err(e) => return Err(e.into()),
248 }
249 };
250 (listener, port)
251 }
252 };
253
254 info!("Starting server on port {port}...");
255
256 let (router, openapi) = finish_openapi(router);
257
258 #[cfg(feature = "bindgen")]
259 {
260 let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
261 .parent()
262 .and_then(|p| p.parent())
263 .unwrap()
264 .to_path_buf();
265
266 let output_paths = brk_bindgen::ClientOutputPaths::new()
267 .rust(workspace_root.join("crates/brk_client/src/lib.rs"))
268 .javascript(workspace_root.join("modules/brk-client/index.js"))
269 .python(workspace_root.join("packages/brk_client/brk_client/__init__.py"));
270
271 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
272 generate_bindings(vecs, &openapi, &output_paths)
273 }));
274
275 match result {
276 Ok(Ok(())) => info!("Generated clients"),
277 Ok(Err(e)) => error!("Failed to generate clients: {e}"),
278 Err(_) => error!("Client generation panicked"),
279 }
280 }
281
282 let api_json = Arc::new(ApiJson::new(&openapi));
283
284 let router = router
285 .layer(Extension(Arc::new(openapi)))
286 .layer(Extension(api_json));
287
288 let app = NormalizePathLayer::trim_trailing_slash().layer(router);
290
291 serve(
292 listener,
293 ServiceExt::<Request<Body>>::into_make_service_with_connect_info::<SocketAddr>(app),
294 )
295 .await?;
296
297 Ok(())
298 }
299}
300
301pub fn finish_openapi<S: Clone + Send + Sync + 'static>(
303 router: ApiRouter<S>,
304) -> (axum::Router<S>, aide::openapi::OpenApi) {
305 let mut openapi = create_openapi();
306 let router = router.finish_api(&mut openapi);
307 (router, openapi)
308}
309
/// Serialises `openapi` to JSON and passes it to `brk_bindgen::generate_clients`, along with
/// `vecs` and `output_paths`.
///
/// # Errors
///
/// A JSON serialisation failure is reported as [`std::io::ErrorKind::InvalidData`]. Any
/// error from `generate_clients` is returned unchanged.
#[cfg(feature = "bindgen")]
pub fn generate_bindings(
    vecs: &brk_query::Vecs,
    openapi: &aide::openapi::OpenApi,
    output_paths: &brk_bindgen::ClientOutputPaths,
) -> std::io::Result<()> {
    serde_json::to_string(openapi)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))
        .and_then(|spec_json| brk_bindgen::generate_clients(vecs, &spec_json, output_paths))
}