Skip to main content

brk_server/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    net::SocketAddr,
5    path::PathBuf,
6    sync::Arc,
7    time::{Duration, Instant},
8};
9
10use aide::axum::ApiRouter;
11use axum::{
12    Extension,
13    body::Body,
14    http::{Request, Response, StatusCode, Uri},
15    middleware::Next,
16    response::Redirect,
17    routing::get,
18    serve,
19};
20use brk_query::AsyncQuery;
21use quick_cache::sync::Cache;
22use tokio::net::TcpListener;
23use tower_http::{
24    catch_panic::CatchPanicLayer, classify::ServerErrorsFailureClass,
25    compression::CompressionLayer, cors::CorsLayer, normalize_path::NormalizePathLayer,
26    timeout::TimeoutLayer, trace::TraceLayer,
27};
28use tracing::{error, info};
29
30mod api;
31pub mod cache;
32mod error;
33mod extended;
34mod state;
35
36use api::*;
37pub use brk_types::Port;
38pub use brk_website::Website;
39pub use cache::{CacheParams, CacheStrategy};
40pub use error::{Error, Result};
41use state::*;
42
43pub const VERSION: &str = env!("CARGO_PKG_VERSION");
44
45pub struct Server(AppState);
46
47impl Server {
48    pub fn new(query: &AsyncQuery, data_path: PathBuf, website: Website) -> Self {
49        website.log();
50        Self(AppState {
51            client: query.client().clone(),
52            query: query.clone(),
53            data_path,
54            website,
55            cache: Arc::new(Cache::new(5_000)),
56            started_at: jiff::Timestamp::now(),
57            started_instant: Instant::now(),
58        })
59    }
60
61    pub async fn serve(self, port: Option<Port>) -> brk_error::Result<()> {
62        let state = self.0;
63
64        #[cfg(feature = "bindgen")]
65        let vecs = state.query.inner().vecs();
66
67        let compression_layer = CompressionLayer::new().br(true).gzip(true).zstd(true);
68
69        let connect_info_layer = axum::middleware::from_fn(
70            async |connect_info: axum::extract::ConnectInfo<SocketAddr>,
71                   mut request: Request<Body>,
72                   next: Next|
73                   -> Response<Body> {
74                let mut addr = connect_info.0;
75
76                // When behind a reverse proxy (e.g. cloudflared), the direct
77                // connection comes from loopback but the request is external.
78                // Mark it as non-loopback so it gets the stricter limit.
79                if addr.ip().is_loopback()
80                    && request.headers().contains_key("CF-Connecting-IP")
81                {
82                    addr.set_ip(std::net::Ipv4Addr::UNSPECIFIED.into());
83                }
84
85                request.extensions_mut().insert(addr);
86                next.run(request).await
87            },
88        );
89
90        let response_uri_layer = axum::middleware::from_fn(
91            async |request: Request<Body>, next: Next| -> Response<Body> {
92                let uri = request.uri().clone();
93                let mut response = next.run(request).await;
94                response.extensions_mut().insert(uri);
95                response
96            },
97        );
98
99        let trace_layer = TraceLayer::new_for_http()
100            .on_request(())
101            .on_response(
102                |response: &Response<Body>, latency: Duration, _: &tracing::Span| {
103                    let status = response.status().as_u16();
104                    let uri = response.extensions().get::<Uri>().unwrap();
105                    match response.status() {
106                        StatusCode::OK => info!(status, %uri, ?latency),
107                        StatusCode::NOT_MODIFIED
108                        | StatusCode::TEMPORARY_REDIRECT
109                        | StatusCode::PERMANENT_REDIRECT => info!(status, %uri, ?latency),
110                        _ => error!(status, %uri, ?latency),
111                    }
112                },
113            )
114            .on_body_chunk(())
115            .on_failure(
116                |error: ServerErrorsFailureClass, latency: Duration, _: &tracing::Span| {
117                    error!(?error, ?latency, "request failed");
118                },
119            )
120            .on_eos(());
121
122        let website_router = brk_website::router(state.website.clone());
123        let mut router = ApiRouter::new().add_api_routes();
124        if !state.website.is_enabled() {
125            router = router.route("/", get(Redirect::temporary("/api")));
126        }
127        let router = router
128            .with_state(state)
129            .merge(website_router)
130            .layer(CatchPanicLayer::new())
131            .layer(compression_layer)
132            .layer(response_uri_layer)
133            .layer(trace_layer)
134            .layer(TimeoutLayer::with_status_code(
135                StatusCode::GATEWAY_TIMEOUT,
136                Duration::from_secs(5),
137            ))
138            .layer(CorsLayer::permissive())
139            .layer(connect_info_layer)
140            .layer(NormalizePathLayer::trim_trailing_slash());
141
142        let (listener, port) = match port {
143            Some(port) => {
144                let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
145                (listener, *port)
146            }
147            None => {
148                let base_port: u16 = *Port::DEFAULT;
149                let max_port: u16 = base_port + 100;
150                let mut port = base_port;
151                let listener = loop {
152                    match TcpListener::bind(format!("0.0.0.0:{port}")).await {
153                        Ok(l) => break l,
154                        Err(_) if port < max_port => port += 1,
155                        Err(e) => return Err(e.into()),
156                    }
157                };
158                (listener, port)
159            }
160        };
161
162        info!("Starting server on port {port}...");
163
164        let mut openapi = create_openapi();
165        let router = router.finish_api(&mut openapi);
166
167        #[cfg(feature = "bindgen")]
168        {
169            let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
170                .parent()
171                .and_then(|p| p.parent())
172                .unwrap()
173                .to_path_buf();
174
175            let output_paths = brk_bindgen::ClientOutputPaths::new()
176                .rust(workspace_root.join("crates/brk_client/src/lib.rs"))
177                .javascript(workspace_root.join("modules/brk-client/index.js"))
178                .python(workspace_root.join("packages/brk_client/brk_client/__init__.py"));
179
180            let openapi_json = serde_json::to_string(&openapi).unwrap();
181
182            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
183                brk_bindgen::generate_clients(vecs, &openapi_json, &output_paths)
184            }));
185
186            match result {
187                Ok(Ok(())) => info!("Generated clients"),
188                Ok(Err(e)) => error!("Failed to generate clients: {e}"),
189                Err(_) => error!("Client generation panicked"),
190            }
191        }
192
193        let api_json = Arc::new(ApiJson::new(&openapi));
194
195        let router = router
196            .layer(Extension(Arc::new(openapi)))
197            .layer(Extension(api_json));
198
199        serve(
200            listener,
201            router.into_make_service_with_connect_info::<SocketAddr>(),
202        )
203        .await?;
204
205        Ok(())
206    }
207}