atuin_server/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::future::Future;
4use std::net::SocketAddr;
5
6use atuin_server_database::Database;
7use axum::{Router, serve};
8use axum_server::Handle;
9use axum_server::tls_rustls::RustlsConfig;
10use eyre::{Context, Result, eyre};
11
12mod handlers;
13mod metrics;
14mod router;
15mod utils;
16
17pub use settings::Settings;
18pub use settings::example_config;
19
20pub mod settings;
21
22use tokio::net::TcpListener;
23use tokio::signal;
24
25#[cfg(target_family = "unix")]
26async fn shutdown_signal() {
27    let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())
28        .expect("failed to register signal handler");
29    let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt())
30        .expect("failed to register signal handler");
31
32    tokio::select! {
33        _ = term.recv() => {},
34        _ = interrupt.recv() => {},
35    };
36    eprintln!("Shutting down gracefully...");
37}
38
/// Wait for a Ctrl-C event, then announce the shutdown on stderr.
///
/// # Panics
/// Panics if the Ctrl-C handler cannot be registered.
#[cfg(target_family = "windows")]
async fn shutdown_signal() {
    let mut ctrl_c = signal::windows::ctrl_c().expect("failed to register signal handler");
    // Only the arrival of the event matters, not the value `recv` yields.
    ctrl_c.recv().await;
    eprintln!("Shutting down gracefully...");
}
47
48pub async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> {
49    if settings.tls.enable {
50        launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
51    } else {
52        launch_with_tcp_listener::<Db>(
53            settings,
54            TcpListener::bind(addr)
55                .await
56                .context("could not connect to socket")?,
57            shutdown_signal(),
58        )
59        .await
60    }
61}
62
63pub async fn launch_with_tcp_listener<Db: Database>(
64    settings: Settings,
65    listener: TcpListener,
66    shutdown: impl Future<Output = ()> + Send + 'static,
67) -> Result<()> {
68    let r = make_router::<Db>(settings).await?;
69
70    serve(listener, r.into_make_service())
71        .with_graceful_shutdown(shutdown)
72        .await?;
73
74    Ok(())
75}
76
77async fn launch_with_tls<Db: Database>(
78    settings: Settings,
79    addr: SocketAddr,
80    shutdown: impl Future<Output = ()>,
81) -> Result<()> {
82    let crypto_provider = rustls::crypto::ring::default_provider().install_default();
83    if crypto_provider.is_err() {
84        return Err(eyre!("Failed to install default crypto provider"));
85    }
86    let rustls_config = RustlsConfig::from_pem_file(
87        settings.tls.cert_path.clone(),
88        settings.tls.pkey_path.clone(),
89    )
90    .await;
91    if rustls_config.is_err() {
92        return Err(eyre!("Failed to load TLS key and/or certificate"));
93    }
94    let rustls_config = rustls_config.unwrap();
95
96    let r = make_router::<Db>(settings).await?;
97
98    let handle = Handle::new();
99
100    let server = axum_server::bind_rustls(addr, rustls_config)
101        .handle(handle.clone())
102        .serve(r.into_make_service());
103
104    tokio::select! {
105        _ = server => {}
106        _ = shutdown => {
107            handle.graceful_shutdown(None);
108        }
109    }
110
111    Ok(())
112}
113
114// The separate listener means it's much easier to ensure metrics are not accidentally exposed to
115// the public.
116pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
117    let listener = TcpListener::bind((host, port))
118        .await
119        .context("failed to bind metrics tcp")?;
120
121    let recorder_handle = metrics::setup_metrics_recorder();
122
123    let router = Router::new().route(
124        "/metrics",
125        axum::routing::get(move || std::future::ready(recorder_handle.render())),
126    );
127
128    serve(listener, router.into_make_service())
129        .with_graceful_shutdown(shutdown_signal())
130        .await?;
131
132    Ok(())
133}
134
135async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> {
136    let db = Db::new(&settings.db_settings)
137        .await
138        .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
139    let r = router::router(db, settings);
140    Ok(r)
141}