1#![doc = include_str!("../README.md")]
2
3use std::{
4 net::SocketAddr,
5 path::PathBuf,
6 sync::Arc,
7 time::{Duration, Instant},
8};
9
10use aide::axum::ApiRouter;
11use axum::{
12 Extension,
13 body::Body,
14 http::{Request, Response, StatusCode, Uri},
15 middleware::Next,
16 response::Redirect,
17 routing::get,
18 serve,
19};
20use brk_query::AsyncQuery;
21use quick_cache::sync::Cache;
22use tokio::net::TcpListener;
23use tower_http::{
24 catch_panic::CatchPanicLayer, classify::ServerErrorsFailureClass,
25 compression::CompressionLayer, cors::CorsLayer, normalize_path::NormalizePathLayer,
26 timeout::TimeoutLayer, trace::TraceLayer,
27};
28use tracing::{error, info};
29
30mod api;
31pub mod cache;
32mod error;
33mod extended;
34mod state;
35
36use api::*;
37pub use brk_types::Port;
38pub use brk_website::Website;
39pub use cache::{CacheParams, CacheStrategy};
40pub use error::{Error, Result};
41use state::*;
42
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
44
/// HTTP server wrapping the shared application state.
///
/// Build one with [`Server::new`] and run it with [`Server::serve`].
pub struct Server(AppState);
46
47impl Server {
48 pub fn new(query: &AsyncQuery, data_path: PathBuf, website: Website) -> Self {
49 website.log();
50 Self(AppState {
51 client: query.client().clone(),
52 query: query.clone(),
53 data_path,
54 website,
55 cache: Arc::new(Cache::new(5_000)),
56 started_at: jiff::Timestamp::now(),
57 started_instant: Instant::now(),
58 })
59 }
60
61 pub async fn serve(self, port: Option<Port>) -> brk_error::Result<()> {
62 let state = self.0;
63
64 #[cfg(feature = "bindgen")]
65 let vecs = state.query.inner().vecs();
66
67 let compression_layer = CompressionLayer::new().br(true).gzip(true).zstd(true);
68
69 let connect_info_layer = axum::middleware::from_fn(
70 async |connect_info: axum::extract::ConnectInfo<SocketAddr>,
71 mut request: Request<Body>,
72 next: Next|
73 -> Response<Body> {
74 let mut addr = connect_info.0;
75
76 if addr.ip().is_loopback()
80 && request.headers().contains_key("CF-Connecting-IP")
81 {
82 addr.set_ip(std::net::Ipv4Addr::UNSPECIFIED.into());
83 }
84
85 request.extensions_mut().insert(addr);
86 next.run(request).await
87 },
88 );
89
90 let response_uri_layer = axum::middleware::from_fn(
91 async |request: Request<Body>, next: Next| -> Response<Body> {
92 let uri = request.uri().clone();
93 let mut response = next.run(request).await;
94 response.extensions_mut().insert(uri);
95 response
96 },
97 );
98
99 let trace_layer = TraceLayer::new_for_http()
100 .on_request(())
101 .on_response(
102 |response: &Response<Body>, latency: Duration, _: &tracing::Span| {
103 let status = response.status().as_u16();
104 let uri = response.extensions().get::<Uri>().unwrap();
105 match response.status() {
106 StatusCode::OK => info!(status, %uri, ?latency),
107 StatusCode::NOT_MODIFIED
108 | StatusCode::TEMPORARY_REDIRECT
109 | StatusCode::PERMANENT_REDIRECT => info!(status, %uri, ?latency),
110 _ => error!(status, %uri, ?latency),
111 }
112 },
113 )
114 .on_body_chunk(())
115 .on_failure(
116 |error: ServerErrorsFailureClass, latency: Duration, _: &tracing::Span| {
117 error!(?error, ?latency, "request failed");
118 },
119 )
120 .on_eos(());
121
122 let website_router = brk_website::router(state.website.clone());
123 let mut router = ApiRouter::new().add_api_routes();
124 if !state.website.is_enabled() {
125 router = router.route("/", get(Redirect::temporary("/api")));
126 }
127 let router = router
128 .with_state(state)
129 .merge(website_router)
130 .layer(CatchPanicLayer::new())
131 .layer(compression_layer)
132 .layer(response_uri_layer)
133 .layer(trace_layer)
134 .layer(TimeoutLayer::with_status_code(
135 StatusCode::GATEWAY_TIMEOUT,
136 Duration::from_secs(5),
137 ))
138 .layer(CorsLayer::permissive())
139 .layer(connect_info_layer)
140 .layer(NormalizePathLayer::trim_trailing_slash());
141
142 let (listener, port) = match port {
143 Some(port) => {
144 let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
145 (listener, *port)
146 }
147 None => {
148 let base_port: u16 = *Port::DEFAULT;
149 let max_port: u16 = base_port + 100;
150 let mut port = base_port;
151 let listener = loop {
152 match TcpListener::bind(format!("0.0.0.0:{port}")).await {
153 Ok(l) => break l,
154 Err(_) if port < max_port => port += 1,
155 Err(e) => return Err(e.into()),
156 }
157 };
158 (listener, port)
159 }
160 };
161
162 info!("Starting server on port {port}...");
163
164 let mut openapi = create_openapi();
165 let router = router.finish_api(&mut openapi);
166
167 #[cfg(feature = "bindgen")]
168 {
169 let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
170 .parent()
171 .and_then(|p| p.parent())
172 .unwrap()
173 .to_path_buf();
174
175 let output_paths = brk_bindgen::ClientOutputPaths::new()
176 .rust(workspace_root.join("crates/brk_client/src/lib.rs"))
177 .javascript(workspace_root.join("modules/brk-client/index.js"))
178 .python(workspace_root.join("packages/brk_client/brk_client/__init__.py"));
179
180 let openapi_json = serde_json::to_string(&openapi).unwrap();
181
182 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
183 brk_bindgen::generate_clients(vecs, &openapi_json, &output_paths)
184 }));
185
186 match result {
187 Ok(Ok(())) => info!("Generated clients"),
188 Ok(Err(e)) => error!("Failed to generate clients: {e}"),
189 Err(_) => error!("Client generation panicked"),
190 }
191 }
192
193 let api_json = Arc::new(ApiJson::new(&openapi));
194
195 let router = router
196 .layer(Extension(Arc::new(openapi)))
197 .layer(Extension(api_json));
198
199 serve(
200 listener,
201 router.into_make_service_with_connect_info::<SocketAddr>(),
202 )
203 .await?;
204
205 Ok(())
206 }
207}