Skip to main content

brk_server/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    any::Any,
5    net::SocketAddr,
6    path::PathBuf,
7    sync::Arc,
8    time::{Duration, Instant},
9};
10
11use aide::axum::ApiRouter;
12use axum::{
13    Extension, ServiceExt,
14    body::Body,
15    http::{
16        Request, Response, StatusCode, Uri,
17        header::{CONTENT_TYPE, VARY},
18    },
19    middleware::Next,
20    response::{IntoResponse, Redirect},
21    routing::get,
22    serve,
23};
24use brk_query::AsyncQuery;
25use quick_cache::sync::Cache;
26use tokio::net::TcpListener;
27use tower_http::{
28    catch_panic::CatchPanicLayer, classify::ServerErrorsFailureClass,
29    compression::CompressionLayer, cors::CorsLayer, normalize_path::NormalizePathLayer,
30    timeout::TimeoutLayer, trace::TraceLayer,
31};
32use tower_layer::Layer;
33use tracing::{error, info};
34
35mod api;
36pub mod cache;
37mod error;
38mod extended;
39mod state;
40
41pub use api::ApiRoutes;
42use api::*;
43pub use brk_types::Port;
44pub use brk_website::Website;
45pub use cache::{CacheParams, CacheStrategy};
46pub use error::{Error, Result};
47use state::*;
48
49pub const VERSION: &str = env!("CARGO_PKG_VERSION");
50
51pub struct Server(AppState);
52
53impl Server {
54    pub fn new(query: &AsyncQuery, data_path: PathBuf, website: Website) -> Self {
55        website.log();
56        Self(AppState {
57            query: query.clone(),
58            data_path,
59            website,
60            cache: Arc::new(Cache::new(1_000)),
61            started_at: jiff::Timestamp::now(),
62            started_instant: Instant::now(),
63        })
64    }
65
66    pub async fn serve(self, port: Option<Port>) -> brk_error::Result<()> {
67        let state = self.0;
68
69        #[cfg(feature = "bindgen")]
70        let vecs = state.query.inner().vecs();
71
72        let compression_layer = CompressionLayer::new().br(true).gzip(true).zstd(true);
73
74        let connect_info_layer = axum::middleware::from_fn(
75            async |connect_info: axum::extract::ConnectInfo<SocketAddr>,
76                   mut request: Request<Body>,
77                   next: Next|
78                   -> Response<Body> {
79                let mut addr = connect_info.0;
80
81                // When behind a reverse proxy (e.g. cloudflared), the direct
82                // connection comes from loopback but the request is external.
83                // Mark it as non-loopback so it gets the stricter limit.
84                if addr.ip().is_loopback() && request.headers().contains_key("CF-Connecting-IP") {
85                    addr.set_ip(std::net::Ipv4Addr::UNSPECIFIED.into());
86                }
87
88                request.extensions_mut().insert(addr);
89                next.run(request).await
90            },
91        );
92
93        let response_time_layer = axum::middleware::from_fn(
94            async |request: Request<Body>, next: Next| -> Response<Body> {
95                let uri = request.uri().clone();
96                let start = Instant::now();
97                let mut response = next.run(request).await;
98                response.extensions_mut().insert(uri);
99                response.headers_mut().insert(
100                    "X-Response-Time",
101                    format!("{}us", start.elapsed().as_micros())
102                        .parse()
103                        .unwrap(),
104                );
105                response
106            },
107        );
108
109        // Wrap non-JSON error responses in structured JSON
110        let json_error_layer = axum::middleware::from_fn(
111            async |request: Request<Body>, next: Next| -> Response<Body> {
112                let response = next.run(request).await;
113                let status = response.status();
114                if status.is_success()
115                    || status.is_redirection()
116                    || status.is_informational()
117                    || response.headers().get(CONTENT_TYPE).is_some_and(|v| {
118                        let b = v.as_bytes();
119                        b.starts_with(b"application/") && b.ends_with(b"json")
120                    })
121                {
122                    return response;
123                }
124
125                let (parts, body) = response.into_parts();
126                let bytes = axum::body::to_bytes(body, 4096).await.unwrap_or_default();
127                let msg = String::from_utf8_lossy(&bytes);
128                let (code, msg) = match parts.status {
129                    StatusCode::NOT_FOUND => (
130                        "not_found",
131                        if msg.is_empty() {
132                            "Not found".into()
133                        } else {
134                            msg
135                        },
136                    ),
137                    StatusCode::METHOD_NOT_ALLOWED => (
138                        "method_not_allowed",
139                        "Only GET requests are supported".into(),
140                    ),
141                    StatusCode::GATEWAY_TIMEOUT => ("timeout", "Request timed out".into()),
142                    s if s.is_client_error() => (
143                        "bad_request",
144                        if msg.is_empty() {
145                            "Bad request".into()
146                        } else {
147                            msg
148                        },
149                    ),
150                    _ => (
151                        "internal_error",
152                        if msg.is_empty() {
153                            "Internal server error".into()
154                        } else {
155                            msg
156                        },
157                    ),
158                };
159                let msg = msg.into_owned();
160                let mut response = Error::new(parts.status, code, msg).into_response();
161                response.extensions_mut().extend(parts.extensions);
162                response
163            },
164        );
165
166        let trace_layer = TraceLayer::new_for_http()
167            .on_request(())
168            .on_response(
169                |response: &Response<Body>, latency: Duration, _: &tracing::Span| {
170                    let status = response.status().as_u16();
171                    let unknown = Uri::from_static("/unknown");
172                    let uri = response.extensions().get::<Uri>().unwrap_or(&unknown);
173                    match response.status() {
174                        StatusCode::OK => info!(status, %uri, ?latency),
175                        StatusCode::NOT_MODIFIED
176                        | StatusCode::TEMPORARY_REDIRECT
177                        | StatusCode::PERMANENT_REDIRECT => info!(status, %uri, ?latency),
178                        _ => error!(status, %uri, ?latency),
179                    }
180                },
181            )
182            .on_body_chunk(())
183            .on_failure(
184                |error: ServerErrorsFailureClass, latency: Duration, _: &tracing::Span| {
185                    error!(?error, ?latency, "request failed");
186                },
187            )
188            .on_eos(());
189
190        let website_router = brk_website::router(state.website.clone());
191        let mut router = ApiRouter::new().add_api_routes();
192        if !state.website.is_enabled() {
193            router = router.route("/", get(Redirect::temporary("/api")));
194        }
195        let router = router
196            .with_state(state)
197            .merge(website_router)
198            .layer(compression_layer)
199            .layer(response_time_layer)
200            .layer(trace_layer)
201            .layer(CatchPanicLayer::custom(|panic: Box<dyn Any + Send>| {
202                let msg = panic
203                    .downcast_ref::<String>()
204                    .map(|s| s.as_str())
205                    .or_else(|| panic.downcast_ref::<&str>().copied())
206                    .unwrap_or("Unknown panic");
207                Error::internal(msg).into_response()
208            }))
209            .layer(TimeoutLayer::with_status_code(
210                StatusCode::GATEWAY_TIMEOUT,
211                Duration::from_secs(5),
212            ))
213            .layer(json_error_layer)
214            .layer(CorsLayer::permissive())
215            .layer(axum::middleware::from_fn(
216                async |request: Request<Body>, next: Next| -> Response<Body> {
217                    let mut response = next.run(request).await;
218                    // Consolidate multiple Vary headers into one
219                    let vary: Vec<&str> = response
220                        .headers()
221                        .get_all(VARY)
222                        .iter()
223                        .filter_map(|v| v.to_str().ok())
224                        .collect();
225                    if vary.len() > 1 {
226                        let merged = vary.join(", ");
227                        response.headers_mut().insert(VARY, merged.parse().unwrap());
228                    }
229                    response
230                },
231            ))
232            .layer(connect_info_layer);
233
234        let (listener, port) = match port {
235            Some(port) => {
236                let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
237                (listener, *port)
238            }
239            None => {
240                let base_port: u16 = *Port::DEFAULT;
241                let max_port: u16 = base_port + 100;
242                let mut port = base_port;
243                let listener = loop {
244                    match TcpListener::bind(format!("0.0.0.0:{port}")).await {
245                        Ok(l) => break l,
246                        Err(_) if port < max_port => port += 1,
247                        Err(e) => return Err(e.into()),
248                    }
249                };
250                (listener, port)
251            }
252        };
253
254        info!("Starting server on port {port}...");
255
256        let (router, openapi) = finish_openapi(router);
257
258        #[cfg(feature = "bindgen")]
259        {
260            let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
261                .parent()
262                .and_then(|p| p.parent())
263                .unwrap()
264                .to_path_buf();
265
266            let output_paths = brk_bindgen::ClientOutputPaths::new()
267                .rust(workspace_root.join("crates/brk_client/src/lib.rs"))
268                .javascript(workspace_root.join("modules/brk-client/index.js"))
269                .python(workspace_root.join("packages/brk_client/brk_client/__init__.py"));
270
271            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
272                generate_bindings(vecs, &openapi, &output_paths)
273            }));
274
275            match result {
276                Ok(Ok(())) => info!("Generated clients"),
277                Ok(Err(e)) => error!("Failed to generate clients: {e}"),
278                Err(_) => error!("Client generation panicked"),
279            }
280        }
281
282        let api_json = Arc::new(ApiJson::new(&openapi));
283
284        let router = router
285            .layer(Extension(Arc::new(openapi)))
286            .layer(Extension(api_json));
287
288        // NormalizePath must wrap the router (not be a layer) to run before route matching
289        let app = NormalizePathLayer::trim_trailing_slash().layer(router);
290
291        serve(
292            listener,
293            ServiceExt::<Request<Body>>::into_make_service_with_connect_info::<SocketAddr>(app),
294        )
295        .await?;
296
297        Ok(())
298    }
299}
300
301/// Finalize a router and extract the OpenAPI spec.
302pub fn finish_openapi<S: Clone + Send + Sync + 'static>(
303    router: ApiRouter<S>,
304) -> (axum::Router<S>, aide::openapi::OpenApi) {
305    let mut openapi = create_openapi();
306    let router = router.finish_api(&mut openapi);
307    (router, openapi)
308}
309
#[cfg(feature = "bindgen")]
/// Serializes `openapi` to JSON and hands it, together with `vecs`, to the
/// client generator, which writes to the locations in `output_paths`.
///
/// # Errors
///
/// A serialization failure is reported as an [`std::io::ErrorKind::InvalidData`]
/// error; generator errors are passed through unchanged.
pub fn generate_bindings(
    vecs: &brk_query::Vecs,
    openapi: &aide::openapi::OpenApi,
    output_paths: &brk_bindgen::ClientOutputPaths,
) -> std::io::Result<()> {
    serde_json::to_string(openapi)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))
        .and_then(|spec_json| brk_bindgen::generate_clients(vecs, &spec_json, output_paths))
}