1
2use std::fmt;
3
4use axum::Router;
5use axum::body::Body;
6use axum::extract::State;
7use axum::http::Request;
8use axum::middleware::Next;
9use axum::response::{IntoResponse, Response};
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use base64::Engine;
12
13use crate::ServerState;
14use crate::error::{ErrorResponse, unauthorized};
15
/// Authentication scheme prefix, including the separating space, that precedes the token in
/// an `Authorization` header value.
const BEARER_PREFIX: &str = "Bearer ";
17
18pub fn authed_router(state: &ServerState, router: Router<ServerState>) -> Router<ServerState> {
19 router.route_layer(axum::middleware::from_fn_with_state(state.clone(), guard_auth))
20}
21
/// A 32-byte shared secret that clients must present as a bearer token.
///
/// Equality is implemented by hand (not derived) so that comparing a client-supplied token
/// against the configured one does not stop at the first differing byte, which would leak
/// timing information about how much of the secret was guessed correctly.
#[derive(Clone)]
pub struct AuthToken {
    secret: [u8; 32],
}

impl PartialEq for AuthToken {
    /// Compares the two secrets in constant time with respect to their contents.
    ///
    /// Every byte pair is always visited; differences are OR-accumulated and only inspected
    /// once at the end. `black_box` discourages the optimizer from re-introducing an early
    /// exit. It is a best-effort hint, not a cryptographic guarantee.
    fn eq(&self, other: &Self) -> bool {
        let mut diff = 0u8;
        for (a, b) in self.secret.iter().zip(other.secret.iter()) {
            diff |= a ^ b;
        }
        std::hint::black_box(diff) == 0
    }
}

impl Eq for AuthToken {}
29
/// Error returned by [`AuthToken::decode`] when a token string cannot be parsed.
///
/// Wraps a human-readable description of what was wrong with the input; that description is
/// exactly what `Display` prints.
#[derive(Debug)]
pub struct TokenDecodeError(String);

impl fmt::Display for TokenDecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for TokenDecodeError {}
40
41impl AuthToken {
42 pub const ENCODED_SIZE: usize = 33;
44
45 pub fn new(secret: [u8; 32]) -> Self {
47 AuthToken { secret }
48 }
49
50 pub fn encode(&self) -> String {
54 let mut buf = Vec::with_capacity(AuthToken::ENCODED_SIZE);
55 buf.push(0); buf.extend_from_slice(&self.secret);
57 URL_SAFE_NO_PAD.encode(&buf)
58 }
59
60 pub fn decode(encoded: &str) -> Result<Self, TokenDecodeError> {
62 let bytes = URL_SAFE_NO_PAD.decode(encoded.trim())
63 .map_err(|e| TokenDecodeError(format!("invalid base64: {}", e)))?;
64
65 if bytes.is_empty() {
66 return Err(TokenDecodeError("invalid format".into()));
67 }
68
69 let version = bytes[0];
70 if version == 0 {
71 if bytes.len() != AuthToken::ENCODED_SIZE {
72 return Err(TokenDecodeError("invalid format".into()));
73 }
74
75 let secret = &bytes[1..];
76 let secret = secret.try_into()
77 .map_err(|e| TokenDecodeError(format!("invalid secret: {}", e)))?;
78
79 return Ok(AuthToken { secret });
80 }
81
82 return Err(TokenDecodeError("unknown version".into()));
83 }
84}
85
86fn extract_auth_token(req: &Request<Body>) -> Result<Option<String>, &'static str> {
96 let auth_headers = req.headers().get_all("authorization");
97
98 let mut authorization_header = None;
99 for header in auth_headers {
100 if authorization_header.is_some() {
101 return Err("multiple authorization headers are not allowed");
102 }
103
104 let header_str = header.to_str()
105 .map_err(|_| "authorization header is not valid UTF-8")?;
106
107 if let Some(token) = header_str.strip_prefix(BEARER_PREFIX) {
108 authorization_header = Some(token.to_string());
109 }
110 }
111
112 Ok(authorization_header)
113}
114
115pub fn authenticate_request(
116 State(state): State<ServerState>,
117 req: &Request<Body>,
118) -> Result<(), ErrorResponse> {
119 let expected = match state.auth_token() {
121 Some(t) => t,
122 None => return Ok(()),
123 };
124
125 let token_str = match extract_auth_token(req) {
126 Ok(Some(t)) => t,
127 Ok(None) => unauthorized!("missing auth token"),
128 Err(msg) => unauthorized!("{}", msg),
129 };
130
131 let token = match AuthToken::decode(&token_str) {
132 Ok(r) => r,
133 Err(_) => unauthorized!("invalid auth token"),
134 };
135
136 if token != *expected {
137 unauthorized!("invalid auth token");
138 }
139
140 Ok(())
141}
142
143pub(crate) async fn guard_auth(
144 state: State<ServerState>,
145 req: Request<Body>,
146 next: Next,
147) -> Response {
148 match authenticate_request(state, &req) {
149 Ok(()) => next.run(req).await,
150 Err(e) => e.into_response(),
151 }
152}
153
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use tokio::sync::RwLock;

    use super::*;

    /// Fixed token shared by the tests (every secret byte is 42).
    fn test_token() -> AuthToken {
        AuthToken::new([42u8; 32])
    }

    /// Builds a `ServerState` whose expected auth token is `token`; every other field is set
    /// to an empty or `None` value.
    fn make_state(token: AuthToken) -> State<ServerState> {
        State(ServerState {
            wallet: Arc::new(parking_lot::RwLock::new(None)),
            on_wallet_create: None,
            auth_token: Some(token),
            on_wallet_delete: None,

            websocket_tickets: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// `encode` followed by `decode` yields the same secret, and `decode` ignores
    /// surrounding whitespace.
    #[test]
    fn roundtrip_and_whitespace() {
        let token = test_token();
        let encoded = token.encode();

        let decoded = AuthToken::decode(&encoded).unwrap();
        assert_eq!(token.secret, decoded.secret);

        let padded = format!(" {} \n", encoded);
        assert_eq!(AuthToken::decode(&padded).unwrap().secret, token.secret);
    }

    /// `decode` rejects empty input, non-base64 input, a payload of the wrong length, and an
    /// unrecognised version byte.
    #[test]
    fn decode_rejects_malformed_input() {
        assert!(AuthToken::decode("").is_err(), "empty string");
        assert!(AuthToken::decode("not!valid!base64").is_err(), "invalid base64");
        // "AAAAAA" is valid base64 for 4 zero bytes: a version-0 tag with a too-short payload.
        assert!(AuthToken::decode("AAAAAA").is_err(), "wrong length");

        // Take a valid token, bump its version byte to 1, and re-encode it.
        let mut raw = URL_SAFE_NO_PAD.decode(test_token().encode()).unwrap();
        raw[0] = 1;
        assert!(AuthToken::decode(&URL_SAFE_NO_PAD.encode(&raw)).is_err(), "unknown version");
    }

    /// Covers a `Bearer` header, a request with no `Authorization` header, and a non-Bearer
    /// scheme (which yields `None` rather than an error).
    #[test]
    fn extract_auth_token_from_headers() {
        let req = |name: &str, val: &str| Request::builder()
            .header(name, val).body(Body::empty()).unwrap();

        assert_eq!(extract_auth_token(&req("authorization", "Bearer tok")).unwrap(), Some("tok".into()));

        let empty = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_auth_token(&empty).unwrap(), None);

        assert_eq!(extract_auth_token(&req("authorization", "Basic dXNlcjpwYXNz")).unwrap(), None);
    }

    /// Exercises `authenticate_request`, the check that `guard_auth` delegates to.
    ///
    /// The configured token passes. A missing header, a different well-formed token, and an
    /// undecodable token are all rejected.
    #[test]
    fn guard_auth_accepts_and_rejects() {
        let token = test_token();
        // Builds a request, optionally carrying an `authorization: Bearer <v>` header.
        let req = |hdr: Option<&str>| {
            let mut b = Request::builder();
            if let Some(v) = hdr { b = b.header("authorization", format!("Bearer {}", v)); }
            b.body(Body::empty()).unwrap()
        };

        let res = authenticate_request(make_state(token.clone()), &req(Some(&token.encode())));
        assert!(res.is_ok(), "valid token should pass: {:?}", res);

        let state = make_state(token);
        let no_hdr = Request::builder().body(Body::empty()).unwrap();
        assert!(authenticate_request(state.clone(), &no_hdr).is_err());
        assert!(authenticate_request(state.clone(), &req(Some(&AuthToken::new([0u8; 32]).encode()))).is_err());
        assert!(authenticate_request(state, &req(Some("not-a-valid-token"))).is_err());
    }
}