1use axum::{
2 error_handling::HandleErrorLayer,
3 extract::State,
4 http::{header, Method, Request, StatusCode, Uri},
5 middleware::{self, Next},
6 response::Response,
7 BoxError, Router, Server as AxumServer,
8};
9use console::style;
10use eyre::{Report, Result};
11use hyper::server::{accept::Accept, conn::AddrIncoming};
12use signal::unix::SignalKind;
13use std::{
14 net::SocketAddr,
15 pin::Pin,
16 sync::Arc,
17 task::{Context, Poll},
18 time::Duration,
19};
20use tokio::signal;
21use tower::ServiceBuilder;
22use uuid::Uuid;
23
24use crate::errors::ServerError;
25use barreleye_common::{
26 models::ApiKey, quit, App, AppError, Progress, ProgressReadyType, ProgressStep, Warnings,
27};
28
29mod errors;
30mod handlers;
31
32pub type ServerResult<T> = Result<T, ServerError>;
33
34struct CombinedIncoming {
35 a: AddrIncoming,
36 b: AddrIncoming,
37}
38
39impl Accept for CombinedIncoming {
40 type Conn = <AddrIncoming as Accept>::Conn;
41 type Error = <AddrIncoming as Accept>::Error;
42
43 fn poll_accept(
44 mut self: Pin<&mut Self>,
45 cx: &mut Context<'_>,
46 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
47 if let Poll::Ready(Some(value)) = Pin::new(&mut self.a).poll_accept(cx) {
48 return Poll::Ready(Some(value));
49 }
50
51 if let Poll::Ready(Some(value)) = Pin::new(&mut self.b).poll_accept(cx) {
52 return Poll::Ready(Some(value));
53 }
54
55 Poll::Pending
56 }
57}
58
59pub struct Server {
60 app: Arc<App>,
61}
62
63impl Server {
64 pub fn new(app: Arc<App>) -> Self {
65 Self { app }
66 }
67
68 async fn auth<B>(
69 State(app): State<Arc<App>>,
70 req: Request<B>,
71 next: Next<B>,
72 ) -> ServerResult<Response> {
73 for public_endpoint in vec!["/v0/assets", "/v0/upstream", "/v0/related"].iter() {
74 if req.uri().to_string().starts_with(public_endpoint) {
75 return Ok(next.run(req).await);
76 }
77 }
78
79 let authorization = req
80 .headers()
81 .get(header::AUTHORIZATION)
82 .ok_or(ServerError::Unauthorized)?
83 .to_str()
84 .map_err(|_| ServerError::Unauthorized)?;
85
86 let token = match authorization.split_once(' ') {
87 Some((name, contents)) if name == "Bearer" => contents.to_string(),
88 _ => return Err(ServerError::Unauthorized),
89 };
90
91 let api_key = Uuid::parse_str(&token).map_err(|_| ServerError::Unauthorized)?;
92
93 match ApiKey::get_by_uuid(&app.db, &api_key).await.map_err(|_| ServerError::Unauthorized)? {
94 Some(api_key) if api_key.is_active => Ok(next.run(req).await),
95 _ => Err(ServerError::Unauthorized),
96 }
97 }
98
99 pub async fn start(&self, warnings: Warnings, progress: Progress) -> Result<()> {
100 let settings = self.app.settings.clone();
101
102 async fn handle_404() -> ServerResult<StatusCode> {
103 Err(ServerError::NotFound)
104 }
105
106 async fn handle_timeout_error(
107 method: Method,
108 uri: Uri,
109 _err: BoxError,
110 ) -> ServerResult<StatusCode> {
111 Err(ServerError::Internal { error: Report::msg(format!("`{method} {uri}` timed out")) })
112 }
113
114 let app = Router::new()
115 .nest("/", handlers::get_routes())
116 .route_layer(middleware::from_fn_with_state(self.app.clone(), Self::auth))
117 .fallback(handle_404)
118 .layer(
119 ServiceBuilder::new()
120 .layer(HandleErrorLayer::new(handle_timeout_error))
121 .timeout(Duration::from_secs(30)),
122 )
123 .with_state(self.app.clone());
124
125 let ipv4 = SocketAddr::new(settings.server.ip_v4.parse()?, settings.server.port);
126
127 let show_progress = |addr: &str| {
128 progress.show(ProgressStep::Ready(
129 if self.app.is_indexer && self.app.is_server {
130 ProgressReadyType::All(addr.to_string())
131 } else {
132 ProgressReadyType::Server(addr.to_string())
133 },
134 warnings,
135 ))
136 };
137
138 if settings.server.ip_v6.is_empty() {
139 show_progress(&style(ipv4).bold().to_string());
140
141 match AxumServer::try_bind(&ipv4) {
142 Err(e) => quit(AppError::ServerStartup {
143 url: ipv4.to_string(),
144 error: e.message().to_string(),
145 }),
146 Ok(server) => {
147 self.app.set_is_ready();
148 server
149 .serve(app.into_make_service())
150 .with_graceful_shutdown(Self::shutdown_signal())
151 .await?
152 }
153 }
154 } else {
155 let ipv6 = SocketAddr::new(settings.server.ip_v6.parse()?, settings.server.port);
156
157 match (AddrIncoming::bind(&ipv4), AddrIncoming::bind(&ipv6)) {
158 (Err(e), _) => quit(AppError::ServerStartup {
159 url: ipv4.to_string(),
160 error: e.message().to_string(),
161 }),
162 (_, Err(e)) => quit(AppError::ServerStartup {
163 url: ipv6.to_string(),
164 error: e.message().to_string(),
165 }),
166 (Ok(a), Ok(b)) => {
167 show_progress(&format!("{} & {}", style(ipv4).bold(), style(ipv6).bold()));
168
169 self.app.set_is_ready();
170 AxumServer::builder(CombinedIncoming { a, b })
171 .serve(app.into_make_service())
172 .with_graceful_shutdown(Self::shutdown_signal())
173 .await?;
174 }
175 }
176 }
177
178 Ok(())
179 }
180
181 async fn shutdown_signal() {
182 let ctrl_c = async {
183 if signal::ctrl_c().await.is_err() {
184 quit(AppError::SignalHandler);
185 }
186 };
187
188 #[cfg(unix)]
189 let terminate = async {
190 match signal::unix::signal(SignalKind::terminate()) {
191 Ok(mut signal) => {
192 signal.recv().await;
193 }
194 _ => quit(AppError::SignalHandler),
195 };
196 };
197
198 #[cfg(not(unix))]
199 let terminate = future::pending::<()>();
200
201 tokio::select! {
202 _ = ctrl_c => {},
203 _ = terminate => {},
204 }
205 }
206}