1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#![forbid(unsafe_code)]

use std::net::SocketAddr;
use std::sync::Arc;
use std::{future::Future, net::TcpListener};

use atuin_server_database::Database;
use axum::Router;
use axum::Server;
use axum_server::Handle;
use eyre::{Context, Result};

mod handlers;
mod metrics;
mod router;
mod utils;

use rustls::ServerConfig;
pub use settings::example_config;
pub use settings::Settings;

pub mod settings;

use tokio::signal;

/// Resolves once the process receives either a terminate or an interrupt
/// signal, then prints a shutdown notice to stderr.
///
/// # Panics
///
/// Panics if either signal handler cannot be registered.
#[cfg(target_family = "unix")]
async fn shutdown_signal() {
    // Shared registration step so both handlers fail with the same message.
    let register = |kind: signal::unix::SignalKind| {
        signal::unix::signal(kind).expect("failed to register signal handler")
    };

    let mut sigterm = register(signal::unix::SignalKind::terminate());
    let mut sigint = register(signal::unix::SignalKind::interrupt());

    // Whichever signal arrives first ends the wait; the other is ignored.
    tokio::select! {
        _ = sigterm.recv() => {}
        _ = sigint.recv() => {}
    }

    eprintln!("Shutting down gracefully...");
}

/// Resolves once the process receives a Ctrl-C event, then prints a shutdown
/// notice to stderr.
///
/// # Panics
///
/// Panics if the Ctrl-C handler cannot be registered.
#[cfg(target_family = "windows")]
async fn shutdown_signal() {
    let mut ctrl_c = signal::windows::ctrl_c().expect("failed to register signal handler");
    ctrl_c.recv().await;

    eprintln!("Shutting down gracefully...");
}

/// Starts the server on `addr` and runs it until a shutdown signal arrives.
///
/// When `settings.tls.enable` is set the server is started through
/// `launch_with_tls`; otherwise a plain TCP listener is bound to `addr` and
/// handed to [`launch_with_tcp_listener`].
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or if the selected launcher
/// fails.
pub async fn launch<Db: Database>(
    settings: Settings<Db::Settings>,
    addr: SocketAddr,
) -> Result<()> {
    // Plain-TCP path first: bind eagerly so a bad address fails before any
    // router/database setup happens.
    if !settings.tls.enable {
        let listener = TcpListener::bind(addr).context("could not connect to socket")?;
        return launch_with_tcp_listener::<Db>(settings, listener, shutdown_signal()).await;
    }

    launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
}

/// Serves the application over plain TCP on an already-bound `listener`.
///
/// The router is built from `settings` first; the server then runs until
/// `shutdown` resolves, at which point it is shut down gracefully.
///
/// # Errors
///
/// Returns an error if the router cannot be built, if the server cannot be
/// created from `listener`, or if the server fails while running.
pub async fn launch_with_tcp_listener<Db: Database>(
    settings: Settings<Db::Settings>,
    listener: TcpListener,
    shutdown: impl Future<Output = ()>,
) -> Result<()> {
    let app = make_router::<Db>(settings).await?;

    let builder = Server::from_tcp(listener).context("could not launch server")?;
    let serving = builder
        .serve(app.into_make_service())
        .with_graceful_shutdown(shutdown);

    serving.await?;
    Ok(())
}

/// Serves the application over TLS on `addr` until `shutdown` resolves.
///
/// The certificate chain and private key are taken from `settings.tls`; the
/// rustls configuration uses the library's safe defaults and does not request
/// client certificates.
///
/// When `shutdown` resolves, a graceful shutdown is requested and this
/// function keeps waiting until the server has actually finished, mirroring
/// the behaviour of the plain-TCP launcher.
///
/// # Errors
///
/// Returns an error if the certificates or private key cannot be obtained, if
/// rustls rejects them, if the router cannot be built, or if the server
/// terminates with an I/O error (for example when binding `addr` fails).
async fn launch_with_tls<Db: Database>(
    settings: Settings<Db::Settings>,
    addr: SocketAddr,
    shutdown: impl Future<Output = ()>,
) -> Result<()> {
    let certificates = settings.tls.certificates()?;
    let pkey = settings.tls.private_key()?;

    let server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certificates, pkey)?;

    let server_config = Arc::new(server_config);
    let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config);

    let r = make_router::<Db>(settings).await?;

    let handle = Handle::new();

    let server = axum_server::bind_rustls(addr, rustls_config)
        .handle(handle.clone())
        .serve(r.into_make_service());
    // Pinned so the same future can be polled inside `select!` and then driven
    // to completion afterwards instead of being dropped.
    tokio::pin!(server);

    tokio::select! {
        // The server stopped on its own: surface its error rather than
        // reporting success.
        result = &mut server => result.context("tls server terminated with an error")?,
        _ = shutdown => {
            // `None` means no deadline on the grace period.
            handle.graceful_shutdown(None);
            // Keep driving the server so the graceful shutdown can complete;
            // returning here would drop the server future immediately.
            server
                .await
                .context("tls server failed during graceful shutdown")?;
        }
    }

    Ok(())
}

/// Serves the metrics recorder's output at `/metrics` on `host:port` until a
/// shutdown signal arrives.
///
/// Metrics get their own listener, which makes it much easier to ensure they
/// are not accidentally exposed to the public.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound, if the server cannot be
/// created from it, or if the server fails while running.
pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
    // Bind before installing the recorder so a bad address fails early.
    let socket = TcpListener::bind((host, port)).context("failed to bind metrics tcp")?;

    let recorder = metrics::setup_metrics_recorder();
    let render_metrics = move || std::future::ready(recorder.render());
    let app = Router::new().route("/metrics", axum::routing::get(render_metrics));

    let builder = Server::from_tcp(socket).context("could not launch server")?;
    builder
        .serve(app.into_make_service())
        .with_graceful_shutdown(shutdown_signal())
        .await?;

    Ok(())
}

/// Creates the database handle from `settings.db_settings` and builds the
/// application router around it.
///
/// # Errors
///
/// Returns an error, annotated with the debug form of the database settings,
/// if `Db::new` fails.
async fn make_router<Db: Database>(
    settings: Settings<<Db as Database>::Settings>,
) -> Result<Router, eyre::Error> {
    let connection = Db::new(&settings.db_settings).await;
    let database = connection
        .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;

    Ok(router::router(database, settings))
}