barreleye_server/
lib.rs

1use axum::{
2	error_handling::HandleErrorLayer,
3	extract::State,
4	http::{header, Method, Request, StatusCode, Uri},
5	middleware::{self, Next},
6	response::Response,
7	BoxError, Router, Server as AxumServer,
8};
9use console::style;
10use eyre::{Report, Result};
11use hyper::server::{accept::Accept, conn::AddrIncoming};
12use signal::unix::SignalKind;
13use std::{
14	net::SocketAddr,
15	pin::Pin,
16	sync::Arc,
17	task::{Context, Poll},
18	time::Duration,
19};
20use tokio::signal;
21use tower::ServiceBuilder;
22use uuid::Uuid;
23
24use crate::errors::ServerError;
25use barreleye_common::{
26	models::ApiKey, quit, App, AppError, Progress, ProgressReadyType, ProgressStep, Warnings,
27};
28
29mod errors;
30mod handlers;
31
32pub type ServerResult<T> = Result<T, ServerError>;
33
34struct CombinedIncoming {
35	a: AddrIncoming,
36	b: AddrIncoming,
37}
38
39impl Accept for CombinedIncoming {
40	type Conn = <AddrIncoming as Accept>::Conn;
41	type Error = <AddrIncoming as Accept>::Error;
42
43	fn poll_accept(
44		mut self: Pin<&mut Self>,
45		cx: &mut Context<'_>,
46	) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
47		if let Poll::Ready(Some(value)) = Pin::new(&mut self.a).poll_accept(cx) {
48			return Poll::Ready(Some(value));
49		}
50
51		if let Poll::Ready(Some(value)) = Pin::new(&mut self.b).poll_accept(cx) {
52			return Poll::Ready(Some(value));
53		}
54
55		Poll::Pending
56	}
57}
58
59pub struct Server {
60	app: Arc<App>,
61}
62
63impl Server {
64	pub fn new(app: Arc<App>) -> Self {
65		Self { app }
66	}
67
68	async fn auth<B>(
69		State(app): State<Arc<App>>,
70		req: Request<B>,
71		next: Next<B>,
72	) -> ServerResult<Response> {
73		for public_endpoint in vec!["/v0/assets", "/v0/upstream", "/v0/related"].iter() {
74			if req.uri().to_string().starts_with(public_endpoint) {
75				return Ok(next.run(req).await);
76			}
77		}
78
79		let authorization = req
80			.headers()
81			.get(header::AUTHORIZATION)
82			.ok_or(ServerError::Unauthorized)?
83			.to_str()
84			.map_err(|_| ServerError::Unauthorized)?;
85
86		let token = match authorization.split_once(' ') {
87			Some((name, contents)) if name == "Bearer" => contents.to_string(),
88			_ => return Err(ServerError::Unauthorized),
89		};
90
91		let api_key = Uuid::parse_str(&token).map_err(|_| ServerError::Unauthorized)?;
92
93		match ApiKey::get_by_uuid(&app.db, &api_key).await.map_err(|_| ServerError::Unauthorized)? {
94			Some(api_key) if api_key.is_active => Ok(next.run(req).await),
95			_ => Err(ServerError::Unauthorized),
96		}
97	}
98
99	pub async fn start(&self, warnings: Warnings, progress: Progress) -> Result<()> {
100		let settings = self.app.settings.clone();
101
102		async fn handle_404() -> ServerResult<StatusCode> {
103			Err(ServerError::NotFound)
104		}
105
106		async fn handle_timeout_error(
107			method: Method,
108			uri: Uri,
109			_err: BoxError,
110		) -> ServerResult<StatusCode> {
111			Err(ServerError::Internal { error: Report::msg(format!("`{method} {uri}` timed out")) })
112		}
113
114		let app = Router::new()
115			.nest("/", handlers::get_routes())
116			.route_layer(middleware::from_fn_with_state(self.app.clone(), Self::auth))
117			.fallback(handle_404)
118			.layer(
119				ServiceBuilder::new()
120					.layer(HandleErrorLayer::new(handle_timeout_error))
121					.timeout(Duration::from_secs(30)),
122			)
123			.with_state(self.app.clone());
124
125		let ipv4 = SocketAddr::new(settings.server.ip_v4.parse()?, settings.server.port);
126
127		let show_progress = |addr: &str| {
128			progress.show(ProgressStep::Ready(
129				if self.app.is_indexer && self.app.is_server {
130					ProgressReadyType::All(addr.to_string())
131				} else {
132					ProgressReadyType::Server(addr.to_string())
133				},
134				warnings,
135			))
136		};
137
138		if settings.server.ip_v6.is_empty() {
139			show_progress(&style(ipv4).bold().to_string());
140
141			match AxumServer::try_bind(&ipv4) {
142				Err(e) => quit(AppError::ServerStartup {
143					url: ipv4.to_string(),
144					error: e.message().to_string(),
145				}),
146				Ok(server) => {
147					self.app.set_is_ready();
148					server
149						.serve(app.into_make_service())
150						.with_graceful_shutdown(Self::shutdown_signal())
151						.await?
152				}
153			}
154		} else {
155			let ipv6 = SocketAddr::new(settings.server.ip_v6.parse()?, settings.server.port);
156
157			match (AddrIncoming::bind(&ipv4), AddrIncoming::bind(&ipv6)) {
158				(Err(e), _) => quit(AppError::ServerStartup {
159					url: ipv4.to_string(),
160					error: e.message().to_string(),
161				}),
162				(_, Err(e)) => quit(AppError::ServerStartup {
163					url: ipv6.to_string(),
164					error: e.message().to_string(),
165				}),
166				(Ok(a), Ok(b)) => {
167					show_progress(&format!("{} & {}", style(ipv4).bold(), style(ipv6).bold()));
168
169					self.app.set_is_ready();
170					AxumServer::builder(CombinedIncoming { a, b })
171						.serve(app.into_make_service())
172						.with_graceful_shutdown(Self::shutdown_signal())
173						.await?;
174				}
175			}
176		}
177
178		Ok(())
179	}
180
181	async fn shutdown_signal() {
182		let ctrl_c = async {
183			if signal::ctrl_c().await.is_err() {
184				quit(AppError::SignalHandler);
185			}
186		};
187
188		#[cfg(unix)]
189		let terminate = async {
190			match signal::unix::signal(SignalKind::terminate()) {
191				Ok(mut signal) => {
192					signal.recv().await;
193				}
194				_ => quit(AppError::SignalHandler),
195			};
196		};
197
198		#[cfg(not(unix))]
199		let terminate = future::pending::<()>();
200
201		tokio::select! {
202			_ = ctrl_c => {},
203			_ = terminate => {},
204		}
205	}
206}