1use crate::{
2 config::{self, TlsConfig},
3 handlers::{account, api, home, websocket::WebSocketAccount},
4 Backend, Result, ServerConfig, SslConfig, StorageConfig,
5};
6use axum::{
7 extract::Extension,
8 http::{
9 header::{AUTHORIZATION, CONTENT_TYPE},
10 HeaderValue, Method,
11 },
12 middleware,
13 response::{IntoResponse, Json},
14 routing::{get, post, put},
15 Router,
16};
17use axum_server::{tls_rustls::RustlsConfig, Handle};
18use colored::Colorize;
19use futures::StreamExt;
20use sos_core::{AccountId, UtcDateTime};
21use std::{
22 collections::{HashMap, HashSet},
23 net::SocketAddr,
24 path::PathBuf,
25 sync::Arc,
26};
27use tokio::sync::{Mutex, RwLock, RwLockReadGuard};
28use tower_http::{
29 cors::CorsLayer,
30 trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
31};
32use tracing::Level;
33
34#[cfg(feature = "acme")]
35use tokio_rustls_acme::{caches::DirCache, AcmeConfig};
36
37#[cfg(feature = "listen")]
38use super::handlers::websocket::upgrade;
39
40use sos_core::ExternalFile;
41
42#[cfg(feature = "pairing")]
43use super::handlers::relay::{upgrade as relay_upgrade, RelayState};
44
45pub struct State {
47 pub config: ServerConfig,
49 pub(crate) sockets: HashMap<AccountId, WebSocketAccount>,
51}
52
53impl State {
54 pub fn new(config: ServerConfig) -> Self {
56 Self {
57 config,
58 sockets: Default::default(),
59 }
60 }
61}
62
63pub type ServerState = Arc<RwLock<State>>;
65
66pub type ServerBackend = Arc<RwLock<Backend>>;
68
69pub type TransferOperations = HashSet<ExternalFile>;
71
72pub type ServerTransfer = Arc<RwLock<TransferOperations>>;
74
75pub struct Server {}
77
78impl Server {
79 pub async fn new() -> Result<Self> {
86 Ok(Self {})
87 }
88
89 pub async fn start(
91 &self,
92 state: ServerState,
93 backend: ServerBackend,
94 handle: Handle,
95 ) -> Result<()> {
96 let reader = state.read().await;
97 let origins = Server::read_origins(&reader)?;
98 let ssl = reader.config.net.ssl.clone();
99 let addr = reader.config.bind_address().clone();
100 drop(reader);
101
102 match ssl {
103 Some(SslConfig::Tls(tls)) => {
104 self.run_tls(addr, state, backend, handle, origins, tls)
105 .await
106 }
107 #[cfg(feature = "acme")]
108 Some(SslConfig::Acme(acme)) => {
109 self.run_acme(addr, state, backend, handle, origins, acme)
110 .await
111 }
112 None => self.run(addr, state, backend, handle, origins).await,
113 }
114 }
115
116 async fn run_tls(
118 &self,
119 addr: SocketAddr,
120 state: ServerState,
121 backend: ServerBackend,
122 handle: Handle,
123 origins: Vec<HeaderValue>,
124 tls: TlsConfig,
125 ) -> Result<()> {
126 let storage = {
127 let state = state.read().await;
128 let backend = backend.read().await;
129 (
130 state.config.storage.clone(),
131 backend.paths().documents_dir().to_owned(),
132 )
133 };
134
135 let tls = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
136 let app = Server::router(Arc::clone(&state), backend, origins)?;
137
138 self.startup_message(state, &addr, true, storage).await;
139
140 axum_server::bind_rustls(addr, tls)
141 .handle(handle)
142 .serve(app.into_make_service())
143 .await?;
144 Ok(())
145 }
146
147 #[cfg(feature = "acme")]
149 async fn run_acme(
150 &self,
151 addr: SocketAddr,
152 state: ServerState,
153 backend: ServerBackend,
154 handle: Handle,
155 origins: Vec<HeaderValue>,
156 acme: config::AcmeConfig,
157 ) -> Result<()> {
158 let storage = {
159 let state = state.read().await;
160 let backend = backend.read().await;
161 (
162 state.config.storage.clone(),
163 backend.paths().documents_dir().to_owned(),
164 )
165 };
166
167 let mut acme_state = AcmeConfig::new(acme.domains)
168 .contact(acme.email.iter().map(|e| format!("mailto:{}", e)))
169 .cache_option(Some(DirCache::new(acme.cache)))
170 .directory_lets_encrypt(acme.production)
171 .state();
172
173 let app = Server::router(Arc::clone(&state), backend, origins)?;
174
175 self.startup_message(state, &addr, true, storage).await;
176
177 let rustls_config = rustls::ServerConfig::builder()
178 .with_no_client_auth()
179 .with_cert_resolver(acme_state.resolver());
180 let acceptor = acme_state.axum_acceptor(Arc::new(rustls_config));
181
182 tokio::spawn(async move {
183 loop {
184 match acme_state.next().await.unwrap() {
185 Ok(res) => tracing::info!(result = ?res, "acme"),
186 Err(err) => tracing::error!(error = ?err, "acme"),
187 }
188 }
189 });
190
191 axum_server::bind(addr)
192 .acceptor(acceptor)
193 .handle(handle)
194 .serve(app.into_make_service())
195 .await?;
196
197 Ok(())
198 }
199
200 async fn run(
202 &self,
203 addr: SocketAddr,
204 state: ServerState,
205 backend: ServerBackend,
206 handle: Handle,
207 origins: Vec<HeaderValue>,
208 ) -> Result<()> {
209 let storage = {
210 let state = state.read().await;
211 let backend = backend.read().await;
212 (
213 state.config.storage.clone(),
214 backend.paths().documents_dir().to_owned(),
215 )
216 };
217
218 let app = Server::router(Arc::clone(&state), backend, origins)?;
219 self.startup_message(state, &addr, false, storage).await;
220
221 axum_server::bind(addr)
222 .handle(handle)
223 .serve(app.into_make_service())
224 .await?;
225 Ok(())
226 }
227
228 async fn startup_message(
229 &self,
230 state: ServerState,
231 addr: &SocketAddr,
232 tls: bool,
233 storage: (StorageConfig, PathBuf),
234 ) {
235 let now = UtcDateTime::now().to_rfc3339().unwrap();
236
237 let mut columns = vec![
238 ("Started", now),
239 ("Listen", addr.to_string()),
240 ("TLS enabled", tls.to_string()),
241 ("Directory", storage.1.display().to_string()),
242 ];
243
244 if let Some(db_file) = &storage.0.database_uri {
245 columns.push(("Database", db_file.as_uri_string()));
246 }
247
248 let max_length = columns.iter().map(|s| s.0.len()).max().unwrap();
249 let col_size = max_length + 4;
250 for (key, value) in columns {
251 let padding = col_size - key.len();
252 println!("{}{}{}", key, " ".repeat(padding), value.yellow());
253 }
254
255 {
256 let reader = state.read().await;
257 if let Some(access) = &reader.config.access {
258 if let Some(allow) = &access.allow {
259 for address in allow {
260 println!(
261 "Allow {}",
262 address.to_string().green()
263 );
264 }
265 }
266 if let Some(deny) = &access.deny {
267 for address in deny {
268 println!(
269 "Deny {}",
270 address.to_string().red()
271 );
272 }
273 }
274 }
275 }
276 }
277
278 fn read_origins(
279 reader: &RwLockReadGuard<'_, State>,
280 ) -> Result<Vec<HeaderValue>> {
281 let mut origins = Vec::new();
282 let cors = reader.config.net.cors.as_ref();
283 if let Some(cors) = cors {
284 for url in cors.origins.iter() {
285 origins.push(HeaderValue::from_str(
286 url.as_str().trim_end_matches('/'),
287 )?);
288 }
289 }
290 Ok(origins)
291 }
292
293 fn router(
294 state: ServerState,
295 backend: ServerBackend,
296 origins: Vec<HeaderValue>,
297 ) -> Result<Router> {
298 let cors = CorsLayer::new()
299 .allow_methods(vec![
300 Method::GET,
301 Method::POST,
302 Method::PUT,
303 Method::PATCH,
304 Method::DELETE,
305 ])
306 .allow_credentials(true)
307 .allow_headers(vec![AUTHORIZATION, CONTENT_TYPE])
308 .expose_headers(vec![])
309 .allow_origin(origins);
310
311 let v1 = {
312 let mut router = Router::new()
313 .route("/", get(api))
314 .route("/docs", get(apidocs))
315 .route("/docs/", get(apidocs))
316 .route("/docs/openapi.json", get(openapi))
317 .route(
318 "/sync/account",
319 put(account::create_account)
320 .post(account::update_account)
321 .patch(account::sync_account)
322 .get(account::fetch_account)
323 .head(account::account_exists)
324 .delete(account::delete_account),
325 )
326 .route("/sync/account/status", get(account::sync_status))
327 .route(
328 "/sync/account/events",
329 get(account::event_scan)
330 .post(account::event_diff)
331 .patch(account::event_patch),
332 );
333
334 {
335 use super::handlers::files::{self, file_operation_lock};
336 router = router
337 .route("/sync/files", post(files::compare_files))
338 .route(
339 "/sync/file/{vault_id}/{secret_id}/{file_name}",
340 put(files::receive_file)
341 .post(files::move_file)
342 .get(files::send_file)
343 .delete(files::delete_file)
344 .route_layer(middleware::from_fn(
345 file_operation_lock,
346 )),
347 );
348 }
349
350 #[cfg(feature = "listen")]
351 {
352 use super::handlers::connections;
353 router = router
354 .route("/sync/connections", get(connections))
355 .route("/sync/changes", get(upgrade));
356 }
357
358 #[cfg(feature = "pairing")]
359 {
360 router = router.route("/relay", get(relay_upgrade));
361 }
362
363 router
364 };
365
366 let mut v1 = v1.layer(cors).layer(
367 TraceLayer::new_for_http()
368 .on_request(DefaultOnRequest::new().level(Level::TRACE))
369 .on_response(DefaultOnResponse::new().level(Level::TRACE)),
370 );
371
372 #[cfg(feature = "pairing")]
373 {
374 let relay: RelayState = Arc::new(Mutex::new(HashMap::new()));
375 v1 = v1.layer(Extension(relay));
376 }
377
378 v1 = v1.layer(Extension(backend)).layer(Extension(state));
379
380 {
381 let file_operations: ServerTransfer =
382 Arc::new(RwLock::new(HashSet::new()));
383 v1 = v1.layer(Extension(file_operations));
384 }
385
386 #[allow(unused_mut)]
387 let mut app = Router::new()
388 .route("/", get(home))
389 .nest_service("/api/v1", v1);
390
391 #[cfg(feature = "prometheus")]
392 {
393 let (prometheus_layer, metric_handle) =
394 axum_prometheus::PrometheusMetricLayerBuilder::new()
395 .with_default_metrics()
396 .enable_response_body_size(true)
397 .build_pair();
398
399 app = app
400 .route(
401 "/metrics",
402 get(|| async move { metric_handle.render() }),
403 )
404 .layer(prometheus_layer);
405 }
406
407 Ok(app)
408 }
409}
410
411#[utoipa::path(
413 get,
414 path = "/docs/openapi.json",
415 responses(
416 (
417 status = StatusCode::OK,
418 description = "OpenAPI definition",
419 ),
420 ),
421)]
422pub async fn openapi() -> impl IntoResponse {
423 let value = crate::api_docs::openapi();
424 Json(serde_json::json!(&value))
425}
426
427#[utoipa::path(
429 get,
430 path = "/docs",
431 responses(
432 (
433 status = StatusCode::OK,
434 description = "Render OpenAPI documentation",
435 ),
436 ),
437)]
438pub async fn apidocs() -> impl IntoResponse {
439 use utoipa_rapidoc::RapiDoc;
440 let rapidoc = RapiDoc::new("/api/v1/docs/openapi.json");
441 let html = rapidoc.to_html();
442 ([(CONTENT_TYPE, "text/html")], html)
443}