1#![forbid(unsafe_code)]
2
3use std::future::Future;
4use std::net::SocketAddr;
5
6use atuin_server_database::Database;
7use axum::{Router, serve};
8use axum_server::Handle;
9use axum_server::tls_rustls::RustlsConfig;
10use eyre::{Context, Result, eyre};
11
12mod handlers;
13mod metrics;
14mod router;
15mod utils;
16
17pub use settings::Settings;
18pub use settings::example_config;
19
20pub mod settings;
21
22use tokio::net::TcpListener;
23use tokio::signal;
24
25#[cfg(target_family = "unix")]
26async fn shutdown_signal() {
27 let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())
28 .expect("failed to register signal handler");
29 let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt())
30 .expect("failed to register signal handler");
31
32 tokio::select! {
33 _ = term.recv() => {},
34 _ = interrupt.recv() => {},
35 };
36 eprintln!("Shutting down gracefully...");
37}
38
/// Waits for a Ctrl-C event from the Windows console, then prints a notice to
/// stderr that shutdown has begun.
///
/// # Panics
///
/// Panics if the Ctrl-C listener cannot be registered with the runtime.
#[cfg(target_family = "windows")]
async fn shutdown_signal() {
    let mut ctrl_c = signal::windows::ctrl_c().expect("failed to register signal handler");
    ctrl_c.recv().await;

    eprintln!("Shutting down gracefully...");
}
47
48pub async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> {
49 if settings.tls.enable {
50 launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
51 } else {
52 launch_with_tcp_listener::<Db>(
53 settings,
54 TcpListener::bind(addr)
55 .await
56 .context("could not connect to socket")?,
57 shutdown_signal(),
58 )
59 .await
60 }
61}
62
63pub async fn launch_with_tcp_listener<Db: Database>(
64 settings: Settings,
65 listener: TcpListener,
66 shutdown: impl Future<Output = ()> + Send + 'static,
67) -> Result<()> {
68 let r = make_router::<Db>(settings).await?;
69
70 serve(listener, r.into_make_service())
71 .with_graceful_shutdown(shutdown)
72 .await?;
73
74 Ok(())
75}
76
77async fn launch_with_tls<Db: Database>(
78 settings: Settings,
79 addr: SocketAddr,
80 shutdown: impl Future<Output = ()>,
81) -> Result<()> {
82 let crypto_provider = rustls::crypto::ring::default_provider().install_default();
83 if crypto_provider.is_err() {
84 return Err(eyre!("Failed to install default crypto provider"));
85 }
86 let rustls_config = RustlsConfig::from_pem_file(
87 settings.tls.cert_path.clone(),
88 settings.tls.pkey_path.clone(),
89 )
90 .await;
91 if rustls_config.is_err() {
92 return Err(eyre!("Failed to load TLS key and/or certificate"));
93 }
94 let rustls_config = rustls_config.unwrap();
95
96 let r = make_router::<Db>(settings).await?;
97
98 let handle = Handle::new();
99
100 let server = axum_server::bind_rustls(addr, rustls_config)
101 .handle(handle.clone())
102 .serve(r.into_make_service());
103
104 tokio::select! {
105 _ = server => {}
106 _ = shutdown => {
107 handle.graceful_shutdown(None);
108 }
109 }
110
111 Ok(())
112}
113
114pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
117 let listener = TcpListener::bind((host, port))
118 .await
119 .context("failed to bind metrics tcp")?;
120
121 let recorder_handle = metrics::setup_metrics_recorder();
122
123 let router = Router::new().route(
124 "/metrics",
125 axum::routing::get(move || std::future::ready(recorder_handle.render())),
126 );
127
128 serve(listener, router.into_make_service())
129 .with_graceful_shutdown(shutdown_signal())
130 .await?;
131
132 Ok(())
133}
134
135async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> {
136 let db = Db::new(&settings.db_settings)
137 .await
138 .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
139 let r = router::router(db, settings);
140 Ok(r)
141}