Skip to main content

brk_server/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    any::Any,
5    time::{Duration, Instant},
6};
7
8#[cfg(feature = "bindgen")]
9use std::path::PathBuf;
10
11use aide::axum::ApiRouter;
12use axum::{
13    Extension, ServiceExt,
14    body::Body,
15    http::{
16        Request, Response, StatusCode, Uri,
17        header::{ALLOW, CONTENT_TYPE, VARY},
18    },
19    middleware::Next,
20    response::{IntoResponse, Redirect},
21    routing::get,
22    serve,
23};
24use brk_query::AsyncQuery;
25use tokio::net::TcpListener;
26use tower_http::{
27    catch_panic::CatchPanicLayer,
28    classify::ServerErrorsFailureClass,
29    compression::{CompressionLayer, CompressionLevel},
30    cors::CorsLayer,
31    normalize_path::NormalizePathLayer,
32    timeout::TimeoutLayer,
33    trace::TraceLayer,
34};
35use tower_layer::Layer;
36use tracing::{error, info};
37
38mod api;
39mod cache;
40mod config;
41mod error;
42mod etag;
43mod extended;
44mod params;
45mod state;
46
47pub use api::ApiRoutes;
48use api::*;
49pub use brk_types::Port;
50pub use brk_website::Website;
51pub use cache::CdnCacheMode;
52use cache::{CacheParams, CacheStrategy};
53pub use config::{DEFAULT_MAX_UTXOS, DEFAULT_MAX_WEIGHT, ServerConfig};
54pub use error::{Error, Result};
55use state::*;
56
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
58
/// Maximum number of bytes of a non-JSON error body that the JSON error
/// middleware buffers before re-wrapping it as a structured JSON error.
///
/// `axum::body::to_bytes` returns an error when the body is larger than this
/// limit. The middleware treats that as an empty body and falls back to the
/// default message for the status code. Oversized bodies are therefore
/// discarded, not truncated.
const MAX_ERROR_BODY_BYTES: usize = 4096;

/// Per-request timeout. Requests that exceed it are answered with
/// `504 Gateway Timeout`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
65
/// Reports whether a `Content-Type` value denotes a JSON body.
///
/// Matches `application/json` and structured-syntax types of the form
/// `application/<name>+json`. Parameters such as `; charset=utf-8` and
/// surrounding whitespace are ignored. Type and subtype are compared
/// ASCII-case-insensitively, because media types are case-insensitive
/// (`Application/JSON` is valid).
///
/// Used to skip JSON-error rewriting for bodies that are already JSON.
fn is_json_content_type(s: &str) -> bool {
    // `split` always yields at least one item; parameters start at the first ';'.
    let mime = s.split(';').next().unwrap_or_default().trim();
    let Some((ty, subtype)) = mime.split_once('/') else {
        return false;
    };
    ty.eq_ignore_ascii_case("application")
        && (subtype.eq_ignore_ascii_case("json")
            || subtype
                .rsplit_once('+')
                .is_some_and(|(_, suffix)| suffix.eq_ignore_ascii_case("json")))
}
72
73pub struct Server(AppState);
74
75impl Server {
76    pub fn new(query: &AsyncQuery, config: ServerConfig) -> Self {
77        config.website.log();
78        cache::init(config.cdn_cache_mode);
79        Self(AppState {
80            query: query.clone(),
81            data_path: config.data_path,
82            website: config.website,
83            started_at: jiff::Timestamp::now(),
84            started_instant: Instant::now(),
85            max_weight: config.max_weight,
86            max_utxos: config.max_utxos,
87        })
88    }
89
90    pub async fn serve(self, port: Option<Port>) -> brk_error::Result<()> {
91        let state = self.0;
92
93        #[cfg(feature = "bindgen")]
94        let vecs = state.query.inner().vecs();
95
96        let compression_layer = CompressionLayer::new()
97            .br(true)
98            .gzip(true)
99            .zstd(true)
100            .quality(CompressionLevel::Precise(3));
101
102        let response_time_layer = axum::middleware::from_fn(
103            async |request: Request<Body>, next: Next| -> Response<Body> {
104                let uri = request.uri().clone();
105                let method = request.method().clone();
106                let start = Instant::now();
107                let mut response = next.run(request).await;
108                response.extensions_mut().insert(uri);
109                response.extensions_mut().insert(method);
110                response.headers_mut().insert(
111                    "X-Response-Time",
112                    format!("{}us", start.elapsed().as_micros())
113                        .parse()
114                        .unwrap(),
115                );
116                response
117            },
118        );
119
120        // Wrap non-JSON error responses in structured JSON
121        let json_error_layer = axum::middleware::from_fn(
122            async |request: Request<Body>, next: Next| -> Response<Body> {
123                let response = next.run(request).await;
124                let status = response.status();
125                if status.is_success()
126                    || status.is_redirection()
127                    || status.is_informational()
128                    || response
129                        .headers()
130                        .get(CONTENT_TYPE)
131                        .is_some_and(|v| v.to_str().is_ok_and(is_json_content_type))
132                {
133                    return response;
134                }
135
136                let (parts, body) = response.into_parts();
137                let bytes = axum::body::to_bytes(body, MAX_ERROR_BODY_BYTES)
138                    .await
139                    .unwrap_or_default();
140                let msg = String::from_utf8_lossy(&bytes);
141                let (code, msg) = match parts.status {
142                    StatusCode::NOT_FOUND => (
143                        "not_found",
144                        if msg.is_empty() {
145                            "Not found".into()
146                        } else {
147                            msg
148                        },
149                    ),
150                    StatusCode::METHOD_NOT_ALLOWED => (
151                        "method_not_allowed",
152                        "Only GET requests are supported".into(),
153                    ),
154                    StatusCode::GATEWAY_TIMEOUT => ("timeout", "Request timed out".into()),
155                    s if s.is_client_error() => (
156                        "bad_request",
157                        if msg.is_empty() {
158                            "Bad request".into()
159                        } else {
160                            msg
161                        },
162                    ),
163                    _ => (
164                        "internal_error",
165                        if msg.is_empty() {
166                            "Internal server error".into()
167                        } else {
168                            msg
169                        },
170                    ),
171                };
172                let msg = msg.into_owned();
173                let mut response = Error::new(parts.status, code, msg).into_response();
174                response.extensions_mut().extend(parts.extensions);
175                if let Some(allow) = parts.headers.get(ALLOW) {
176                    response.headers_mut().insert(ALLOW, allow.clone());
177                }
178                response
179            },
180        );
181
182        let trace_layer = TraceLayer::new_for_http()
183            .on_request(())
184            .on_response(
185                |response: &Response<Body>, latency: Duration, _: &tracing::Span| {
186                    let status = response.status().as_u16();
187                    let unknown_uri = Uri::from_static("/unknown");
188                    let unknown_method = axum::http::Method::default();
189                    let uri = response.extensions().get::<Uri>().unwrap_or(&unknown_uri);
190                    let method = response
191                        .extensions()
192                        .get::<axum::http::Method>()
193                        .unwrap_or(&unknown_method);
194                    match response.status() {
195                        StatusCode::OK => info!(%method, status, %uri, ?latency),
196                        StatusCode::NOT_MODIFIED
197                        | StatusCode::TEMPORARY_REDIRECT
198                        | StatusCode::PERMANENT_REDIRECT => info!(%method, status, %uri, ?latency),
199                        _ => error!(%method, status, %uri, ?latency),
200                    }
201                },
202            )
203            .on_body_chunk(())
204            .on_failure(
205                |error: ServerErrorsFailureClass, latency: Duration, _: &tracing::Span| {
206                    error!(?error, ?latency, "request failed");
207                },
208            )
209            .on_eos(());
210
211        let website_router = brk_website::router(state.website.clone());
212        let mut router = ApiRouter::new().add_api_routes();
213        if !state.website.is_enabled() {
214            router = router.route("/", get(Redirect::temporary("/api")));
215        }
216        let router = router
217            .with_state(state)
218            .merge(website_router)
219            .layer(TimeoutLayer::with_status_code(
220                StatusCode::GATEWAY_TIMEOUT,
221                REQUEST_TIMEOUT,
222            ))
223            .layer(json_error_layer)
224            .layer(compression_layer)
225            .layer(CorsLayer::permissive())
226            .layer(axum::middleware::from_fn(
227                async |request: Request<Body>, next: Next| -> Response<Body> {
228                    let mut response = next.run(request).await;
229                    // Consolidate multiple Vary headers into one
230                    let vary: Vec<&str> = response
231                        .headers()
232                        .get_all(VARY)
233                        .iter()
234                        .filter_map(|v| v.to_str().ok())
235                        .collect();
236                    if vary.len() > 1 {
237                        let merged = vary.join(", ");
238                        response.headers_mut().insert(VARY, merged.parse().unwrap());
239                    }
240                    response
241                },
242            ))
243            .layer(CatchPanicLayer::custom(|panic: Box<dyn Any + Send>| {
244                let msg = panic
245                    .downcast_ref::<String>()
246                    .map(|s| s.as_str())
247                    .or_else(|| panic.downcast_ref::<&str>().copied())
248                    .unwrap_or("Unknown panic");
249                Error::internal(msg).into_response()
250            }))
251            .layer(response_time_layer)
252            .layer(trace_layer);
253
254        let (listener, port) = match port {
255            Some(port) => {
256                let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
257                (listener, *port)
258            }
259            None => {
260                let base_port: u16 = *Port::DEFAULT;
261                let max_port: u16 = base_port + 100;
262                let mut port = base_port;
263                let listener = loop {
264                    match TcpListener::bind(format!("0.0.0.0:{port}")).await {
265                        Ok(l) => break l,
266                        Err(_) if port < max_port => port += 1,
267                        Err(e) => return Err(e.into()),
268                    }
269                };
270                (listener, port)
271            }
272        };
273
274        info!("Starting server on port {port}...");
275
276        let (router, openapi) = finish_openapi(router);
277
278        #[cfg(feature = "bindgen")]
279        {
280            let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
281                .parent()
282                .and_then(|p| p.parent())
283                .unwrap()
284                .to_path_buf();
285
286            let output_paths = brk_bindgen::ClientOutputPaths::new()
287                .rust(workspace_root.join("crates/brk_client/src/lib.rs"))
288                .javascript(workspace_root.join("modules/brk-client/index.js"))
289                .python(workspace_root.join("packages/brk_client/brk_client/__init__.py"));
290
291            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
292                generate_bindings(vecs, &openapi, &output_paths)
293            }));
294
295            match result {
296                Ok(Ok(())) => info!("Generated clients"),
297                Ok(Err(e)) => error!("Failed to generate clients: {e}"),
298                Err(_) => error!("Client generation panicked"),
299            }
300        }
301
302        let router = router
303            .layer(Extension(OpenApiJson::new(&openapi)))
304            .layer(Extension(ApiJson::new(&openapi)));
305
306        // NormalizePath must wrap the router (not be a layer) to run before route matching
307        let app = NormalizePathLayer::trim_trailing_slash().layer(router);
308
309        serve(
310            listener,
311            ServiceExt::<Request<Body>>::into_make_service(app),
312        )
313        .await?;
314
315        Ok(())
316    }
317}
318
319/// Finalize a router and extract the OpenAPI spec.
320pub fn finish_openapi<S: Clone + Send + Sync + 'static>(
321    router: ApiRouter<S>,
322) -> (axum::Router<S>, aide::openapi::OpenApi) {
323    let mut openapi = create_openapi();
324    let router = router.finish_api(&mut openapi);
325    (router, openapi)
326}
327
/// Serializes `openapi` to JSON and passes it, together with `vecs`, to
/// `brk_bindgen::generate_clients`, which writes clients to `output_paths`.
///
/// # Errors
/// A JSON serialization failure is reported as `std::io::ErrorKind::InvalidData`.
/// Errors from `brk_bindgen::generate_clients` are returned unchanged.
#[cfg(feature = "bindgen")]
pub fn generate_bindings(
    vecs: &brk_query::Vecs,
    openapi: &aide::openapi::OpenApi,
    output_paths: &brk_bindgen::ClientOutputPaths,
) -> std::io::Result<()> {
    let spec_json = match serde_json::to_string(openapi) {
        Ok(json) => json,
        Err(err) => {
            return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err));
        }
    };
    brk_bindgen::generate_clients(vecs, &spec_json, output_paths)
}
338
#[cfg(test)]
mod tests {
    use super::is_json_content_type;

    /// Content-Type values that must be recognized as JSON.
    const JSON_TYPES: [&str; 5] = [
        "application/json",
        "application/json; charset=utf-8",
        "  application/json  ",
        "application/problem+json",
        "application/vnd.api+json; charset=utf-8",
    ];

    /// Content-Type values that must not be recognized as JSON.
    const NON_JSON_TYPES: [&str; 5] = [
        "text/plain",
        "application/xml",
        "application/json+xml",
        "",
        "text/json",
    ];

    #[test]
    fn json_content_type_matches() {
        for content_type in JSON_TYPES {
            assert!(
                is_json_content_type(content_type),
                "{content_type:?} should be JSON"
            );
        }
    }

    #[test]
    fn json_content_type_rejects_non_json() {
        for content_type in NON_JSON_TYPES {
            assert!(
                !is_json_content_type(content_type),
                "{content_type:?} should not be JSON"
            );
        }
    }
}