atuin_server/
router.rs

1use async_trait::async_trait;
2use atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse};
3use axum::{
4    Router,
5    extract::{FromRequestParts, Request},
6    http::{self, request::Parts},
7    middleware::Next,
8    response::{IntoResponse, Response},
9    routing::{delete, get, patch, post},
10};
11use eyre::Result;
12use tower::ServiceBuilder;
13use tower_http::trace::TraceLayer;
14
15use super::handlers;
16use crate::{
17    handlers::{ErrorResponseStatus, RespExt},
18    metrics,
19    settings::Settings,
20};
21use atuin_server_database::{Database, DbError, models::User};
22
23pub struct UserAuth(pub User);
24
25#[async_trait]
26impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth
27where
28    DB: Database,
29{
30    type Rejection = ErrorResponseStatus<'static>;
31
32    async fn from_request_parts(
33        req: &mut Parts,
34        state: &AppState<DB>,
35    ) -> Result<Self, Self::Rejection> {
36        let auth_header = req
37            .headers
38            .get(http::header::AUTHORIZATION)
39            .ok_or_else(|| {
40                ErrorResponse::reply("missing authorization header")
41                    .with_status(http::StatusCode::BAD_REQUEST)
42            })?;
43        let auth_header = auth_header.to_str().map_err(|_| {
44            ErrorResponse::reply("invalid authorization header encoding")
45                .with_status(http::StatusCode::BAD_REQUEST)
46        })?;
47        let (typ, token) = auth_header.split_once(' ').ok_or_else(|| {
48            ErrorResponse::reply("invalid authorization header encoding")
49                .with_status(http::StatusCode::BAD_REQUEST)
50        })?;
51
52        if typ != "Token" {
53            return Err(
54                ErrorResponse::reply("invalid authorization header encoding")
55                    .with_status(http::StatusCode::BAD_REQUEST),
56            );
57        }
58
59        let user = state
60            .database
61            .get_session_user(token)
62            .await
63            .map_err(|e| match e {
64                DbError::NotFound => ErrorResponse::reply("session not found")
65                    .with_status(http::StatusCode::FORBIDDEN),
66                DbError::Other(e) => {
67                    tracing::error!(error = ?e, "could not query user session");
68                    ErrorResponse::reply("could not query user session")
69                        .with_status(http::StatusCode::INTERNAL_SERVER_ERROR)
70                }
71            })?;
72
73        Ok(UserAuth(user))
74    }
75}
76
77async fn teapot() -> impl IntoResponse {
78    // This used to return 418: 🫖
79    // Much as it was fun, it wasn't as useful or informative as it should be
80    (http::StatusCode::NOT_FOUND, "404 not found")
81}
82
83async fn clacks_overhead(request: Request, next: Next) -> Response {
84    let mut response = next.run(request).await;
85
86    let gnu_terry_value = "GNU Terry Pratchett, Kris Nova";
87    let gnu_terry_header = "X-Clacks-Overhead";
88
89    response
90        .headers_mut()
91        .insert(gnu_terry_header, gnu_terry_value.parse().unwrap());
92    response
93}
94
95/// Ensure that we only try and sync with clients on the same major version
96async fn semver(request: Request, next: Next) -> Response {
97    let mut response = next.run(request).await;
98    response
99        .headers_mut()
100        .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap());
101
102    response
103}
104
105#[derive(Clone)]
106pub struct AppState<DB: Database> {
107    pub database: DB,
108    pub settings: Settings,
109}
110
111pub fn router<DB: Database>(database: DB, settings: Settings) -> Router {
112    let routes = Router::new()
113        .route("/", get(handlers::index))
114        .route("/healthz", get(handlers::health::health_check))
115        .route("/sync/count", get(handlers::history::count))
116        .route("/sync/history", get(handlers::history::list))
117        .route("/sync/calendar/:focus", get(handlers::history::calendar))
118        .route("/sync/status", get(handlers::status::status))
119        .route("/history", post(handlers::history::add))
120        .route("/history", delete(handlers::history::delete))
121        .route("/user/:username", get(handlers::user::get))
122        .route("/account", delete(handlers::user::delete))
123        .route("/account/password", patch(handlers::user::change_password))
124        .route("/register", post(handlers::user::register))
125        .route("/login", post(handlers::user::login))
126        .route("/record", post(handlers::record::post))
127        .route("/record", get(handlers::record::index))
128        .route("/record/next", get(handlers::record::next))
129        .route("/api/v0/me", get(handlers::v0::me::get))
130        .route("/api/v0/account/verify", post(handlers::user::verify_user))
131        .route(
132            "/api/v0/account/send-verification",
133            post(handlers::user::send_verification),
134        )
135        .route("/api/v0/record", post(handlers::v0::record::post))
136        .route("/api/v0/record", get(handlers::v0::record::index))
137        .route("/api/v0/record/next", get(handlers::v0::record::next))
138        .route("/api/v0/store", delete(handlers::v0::store::delete));
139
140    let path = settings.path.as_str();
141    if path.is_empty() {
142        routes
143    } else {
144        Router::new().nest(path, routes)
145    }
146    .fallback(teapot)
147    .with_state(AppState { database, settings })
148    .layer(
149        ServiceBuilder::new()
150            .layer(axum::middleware::from_fn(clacks_overhead))
151            .layer(TraceLayer::new_for_http())
152            .layer(axum::middleware::from_fn(metrics::track_metrics))
153            .layer(axum::middleware::from_fn(semver)),
154    )
155}