1use async_trait::async_trait;
2use atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse};
3use axum::{
4 Router,
5 extract::{FromRequestParts, Request},
6 http::{self, request::Parts},
7 middleware::Next,
8 response::{IntoResponse, Response},
9 routing::{delete, get, patch, post},
10};
11use eyre::Result;
12use tower::ServiceBuilder;
13use tower_http::trace::TraceLayer;
14
15use super::handlers;
16use crate::{
17 handlers::{ErrorResponseStatus, RespExt},
18 metrics,
19 settings::Settings,
20};
21use atuin_server_database::{Database, DbError, models::User};
22
23pub struct UserAuth(pub User);
24
25#[async_trait]
26impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth
27where
28 DB: Database,
29{
30 type Rejection = ErrorResponseStatus<'static>;
31
32 async fn from_request_parts(
33 req: &mut Parts,
34 state: &AppState<DB>,
35 ) -> Result<Self, Self::Rejection> {
36 let auth_header = req
37 .headers
38 .get(http::header::AUTHORIZATION)
39 .ok_or_else(|| {
40 ErrorResponse::reply("missing authorization header")
41 .with_status(http::StatusCode::BAD_REQUEST)
42 })?;
43 let auth_header = auth_header.to_str().map_err(|_| {
44 ErrorResponse::reply("invalid authorization header encoding")
45 .with_status(http::StatusCode::BAD_REQUEST)
46 })?;
47 let (typ, token) = auth_header.split_once(' ').ok_or_else(|| {
48 ErrorResponse::reply("invalid authorization header encoding")
49 .with_status(http::StatusCode::BAD_REQUEST)
50 })?;
51
52 if typ != "Token" {
53 return Err(
54 ErrorResponse::reply("invalid authorization header encoding")
55 .with_status(http::StatusCode::BAD_REQUEST),
56 );
57 }
58
59 let user = state
60 .database
61 .get_session_user(token)
62 .await
63 .map_err(|e| match e {
64 DbError::NotFound => ErrorResponse::reply("session not found")
65 .with_status(http::StatusCode::FORBIDDEN),
66 DbError::Other(e) => {
67 tracing::error!(error = ?e, "could not query user session");
68 ErrorResponse::reply("could not query user session")
69 .with_status(http::StatusCode::INTERNAL_SERVER_ERROR)
70 }
71 })?;
72
73 Ok(UserAuth(user))
74 }
75}
76
77async fn teapot() -> impl IntoResponse {
78 (http::StatusCode::NOT_FOUND, "404 not found")
81}
82
83async fn clacks_overhead(request: Request, next: Next) -> Response {
84 let mut response = next.run(request).await;
85
86 let gnu_terry_value = "GNU Terry Pratchett, Kris Nova";
87 let gnu_terry_header = "X-Clacks-Overhead";
88
89 response
90 .headers_mut()
91 .insert(gnu_terry_header, gnu_terry_value.parse().unwrap());
92 response
93}
94
95async fn semver(request: Request, next: Next) -> Response {
97 let mut response = next.run(request).await;
98 response
99 .headers_mut()
100 .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap());
101
102 response
103}
104
105#[derive(Clone)]
106pub struct AppState<DB: Database> {
107 pub database: DB,
108 pub settings: Settings,
109}
110
111pub fn router<DB: Database>(database: DB, settings: Settings) -> Router {
112 let routes = Router::new()
113 .route("/", get(handlers::index))
114 .route("/healthz", get(handlers::health::health_check))
115 .route("/sync/count", get(handlers::history::count))
116 .route("/sync/history", get(handlers::history::list))
117 .route("/sync/calendar/:focus", get(handlers::history::calendar))
118 .route("/sync/status", get(handlers::status::status))
119 .route("/history", post(handlers::history::add))
120 .route("/history", delete(handlers::history::delete))
121 .route("/user/:username", get(handlers::user::get))
122 .route("/account", delete(handlers::user::delete))
123 .route("/account/password", patch(handlers::user::change_password))
124 .route("/register", post(handlers::user::register))
125 .route("/login", post(handlers::user::login))
126 .route("/record", post(handlers::record::post))
127 .route("/record", get(handlers::record::index))
128 .route("/record/next", get(handlers::record::next))
129 .route("/api/v0/me", get(handlers::v0::me::get))
130 .route("/api/v0/account/verify", post(handlers::user::verify_user))
131 .route(
132 "/api/v0/account/send-verification",
133 post(handlers::user::send_verification),
134 )
135 .route("/api/v0/record", post(handlers::v0::record::post))
136 .route("/api/v0/record", get(handlers::v0::record::index))
137 .route("/api/v0/record/next", get(handlers::v0::record::next))
138 .route("/api/v0/store", delete(handlers::v0::store::delete));
139
140 let path = settings.path.as_str();
141 if path.is_empty() {
142 routes
143 } else {
144 Router::new().nest(path, routes)
145 }
146 .fallback(teapot)
147 .with_state(AppState { database, settings })
148 .layer(
149 ServiceBuilder::new()
150 .layer(axum::middleware::from_fn(clacks_overhead))
151 .layer(TraceLayer::new_for_http())
152 .layer(axum::middleware::from_fn(metrics::track_metrics))
153 .layer(axum::middleware::from_fn(semver)),
154 )
155}