Skip to main content

bark_rest/
auth.rs

1
2use std::fmt;
3
4use axum::Router;
5use axum::body::Body;
6use axum::extract::State;
7use axum::http::Request;
8use axum::middleware::Next;
9use axum::response::{IntoResponse, Response};
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use base64::Engine;
12
13use crate::ServerState;
14use crate::error::{ErrorResponse, unauthorized};
15
/// `Authorization` header prefix that selects the Bearer scheme (RFC 6750),
/// including the single space that separates the scheme from the token.
const BEARER_PREFIX: &str = "Bearer ";
17
18pub fn authed_router(state: &ServerState, router: Router<ServerState>) -> Router<ServerState> {
19	router.route_layer(axum::middleware::from_fn_with_state(state.clone(), guard_auth))
20}
21
/// A bearer token that is the 32-byte secret itself.
///
/// The token grants full access when it matches the token configured on the
/// server (see `authenticate_request`).
///
/// Equality is implemented by hand rather than derived, so that comparing a
/// client-supplied token against the configured one does not stop at the first
/// differing byte. A short-circuiting comparison would let response timing
/// reveal how long a correct prefix the client guessed. No `Debug` impl is
/// provided, which keeps the secret out of debug output.
#[derive(Clone)]
pub struct AuthToken {
	secret: [u8; 32],
}

impl PartialEq for AuthToken {
	/// Compare two tokens, visiting every byte regardless of where they differ.
	fn eq(&self, other: &Self) -> bool {
		// OR together the XOR of each byte pair. The result is zero iff every
		// pair matched, and the loop has no data-dependent exit. This is
		// best-effort: the optimizer is not formally prevented from
		// reintroducing a branch.
		let diff = self
			.secret
			.iter()
			.zip(other.secret.iter())
			.fold(0u8, |acc, (a, b)| acc | (a ^ b));
		diff == 0
	}
}

impl Eq for AuthToken {}
29
/// Error returned by `AuthToken::decode` when a token string cannot be parsed.
///
/// Carries a human-readable description of what was wrong with the input.
#[derive(Debug)]
pub struct TokenDecodeError(String);

impl fmt::Display for TokenDecodeError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		// Emit the message verbatim; formatter flags (width, fill) are ignored.
		let TokenDecodeError(message) = self;
		f.write_str(message)
	}
}

impl std::error::Error for TokenDecodeError {}
40
41impl AuthToken {
42	// One byte for the version, 32 bytes for the secret.
43	pub const ENCODED_SIZE: usize = 33;
44
45	/// Create an auth token from a 32-byte secret.
46	pub fn new(secret: [u8; 32]) -> Self {
47		AuthToken { secret }
48	}
49
50	/// Base64url-encode the token for transmission.
51	///
52	/// Wire format: `<version byte><32-byte secret>`
53	pub fn encode(&self) -> String {
54		let mut buf = Vec::with_capacity(AuthToken::ENCODED_SIZE);
55		buf.push(0); // version byte
56		buf.extend_from_slice(&self.secret);
57		URL_SAFE_NO_PAD.encode(&buf)
58	}
59
60	/// Decode a base64url-encoded auth token.
61	pub fn decode(encoded: &str) -> Result<Self, TokenDecodeError> {
62		let bytes = URL_SAFE_NO_PAD.decode(encoded.trim())
63			.map_err(|e| TokenDecodeError(format!("invalid base64: {}", e)))?;
64
65		if bytes.is_empty() {
66			return Err(TokenDecodeError("invalid format".into()));
67		}
68
69		let version = bytes[0];
70		if version == 0 {
71			if bytes.len() != AuthToken::ENCODED_SIZE {
72				return Err(TokenDecodeError("invalid format".into()));
73			}
74
75			let secret = &bytes[1..];
76			let secret = secret.try_into()
77				.map_err(|e| TokenDecodeError(format!("invalid secret: {}", e)))?;
78
79			return Ok(AuthToken { secret });
80		}
81
82		return Err(TokenDecodeError("unknown version".into()));
83	}
84}
85
86/// Extract the auth token from the `Authorization: Bearer <token>` header
87/// per RFC 6750.
88///
89/// The `Bearer` prefix is matched case sensitively. Non-Bearer
90/// authorization headers are silently ignored (returns `Ok(None)`).
91///
92/// Returns `Ok(Some(token))` when a valid Bearer token is found,
93/// `Ok(None)` when no auth header is present or the scheme is not Bearer,
94/// or `Err(msg)` when headers are malformed (non-UTF-8 or duplicated).
95fn extract_auth_token(req: &Request<Body>) -> Result<Option<String>, &'static str> {
96	let auth_headers = req.headers().get_all("authorization");
97
98	let mut authorization_header = None;
99	for header in auth_headers {
100		if authorization_header.is_some() {
101			return Err("multiple authorization headers are not allowed");
102		}
103
104		let header_str = header.to_str()
105			.map_err(|_| "authorization header is not valid UTF-8")?;
106
107		if let Some(token) = header_str.strip_prefix(BEARER_PREFIX) {
108			authorization_header = Some(token.to_string());
109		}
110	}
111
112	Ok(authorization_header)
113}
114
115pub fn authenticate_request(
116	State(state): State<ServerState>,
117	req: &Request<Body>,
118) -> Result<(), ErrorResponse> {
119	// If no auth token is configured, allow unauthenticated access.
120	let expected = match state.auth_token() {
121		Some(t) => t,
122		None => return Ok(()),
123	};
124
125	let token_str = match extract_auth_token(req) {
126		Ok(Some(t)) => t,
127		Ok(None) => unauthorized!("missing auth token"),
128		Err(msg) => unauthorized!("{}", msg),
129	};
130
131	let token = match AuthToken::decode(&token_str) {
132		Ok(r) => r,
133		Err(_) => unauthorized!("invalid auth token"),
134	};
135
136	if token != *expected {
137		unauthorized!("invalid auth token");
138	}
139
140	Ok(())
141}
142
143pub(crate) async fn guard_auth(
144	state: State<ServerState>,
145	req: Request<Body>,
146	next: Next,
147) -> Response {
148	match authenticate_request(state, &req) {
149		Ok(()) => next.run(req).await,
150		Err(e) => e.into_response(),
151	}
152}
153
#[cfg(test)]
mod tests {
	use std::collections::HashMap;
	use std::sync::Arc;

	use tokio::sync::RwLock;

	use super::*;

	fn test_token() -> AuthToken {
		AuthToken::new([42u8; 32])
	}

	/// Build a `ServerState` with the given configured auth token (or none).
	fn make_state_with(auth_token: Option<AuthToken>) -> State<ServerState> {
		State(ServerState {
			wallet: Arc::new(parking_lot::RwLock::new(None)),
			on_wallet_create: None,
			auth_token,
			on_wallet_delete: None,

			websocket_tickets: Arc::new(RwLock::new(HashMap::new())),
		})
	}

	fn make_state(token: AuthToken) -> State<ServerState> {
		make_state_with(Some(token))
	}

	#[test]
	fn roundtrip_and_whitespace() {
		let token = test_token();
		let encoded = token.encode();

		let decoded = AuthToken::decode(&encoded).unwrap();
		assert_eq!(token.secret, decoded.secret);

		// decode trims surrounding whitespace (important for file-loaded tokens)
		let padded = format!("  {} \n", encoded);
		assert_eq!(AuthToken::decode(&padded).unwrap().secret, token.secret);
	}

	#[test]
	fn encoded_length() {
		let encoded = test_token().encode();
		// 33 raw bytes -> 44 unpadded base64url characters
		assert_eq!(encoded.len(), 44);
		assert_eq!(URL_SAFE_NO_PAD.decode(&encoded).unwrap().len(), AuthToken::ENCODED_SIZE);
	}

	#[test]
	fn decode_rejects_malformed_input() {
		assert!(AuthToken::decode("").is_err(), "empty string");
		assert!(AuthToken::decode("not!valid!base64").is_err(), "invalid base64");
		assert!(AuthToken::decode("AAAAAA").is_err(), "wrong length");

		// unknown version byte
		let mut raw = URL_SAFE_NO_PAD.decode(test_token().encode()).unwrap();
		raw[0] = 1;
		assert!(AuthToken::decode(&URL_SAFE_NO_PAD.encode(&raw)).is_err(), "unknown version");
	}

	#[test]
	fn extract_auth_token_from_headers() {
		let req = |name: &str, val: &str| Request::builder()
			.header(name, val).body(Body::empty()).unwrap();

		// Authorization: Bearer header (RFC 6750)
		assert_eq!(extract_auth_token(&req("authorization", "Bearer tok")).unwrap(), Some("tok".into()));

		// no auth headers
		let empty = Request::builder().body(Body::empty()).unwrap();
		assert_eq!(extract_auth_token(&empty).unwrap(), None);

		// unsupported scheme
		assert_eq!(extract_auth_token(&req("authorization", "Basic dXNlcjpwYXNz")).unwrap(), None);

		// the scheme prefix is matched case-sensitively
		assert_eq!(extract_auth_token(&req("authorization", "bearer tok")).unwrap(), None);
	}

	#[test]
	fn extract_auth_token_rejects_malformed_headers() {
		// the same header sent twice
		let dup = Request::builder()
			.header("authorization", "Bearer a")
			.header("authorization", "Bearer b")
			.body(Body::empty())
			.unwrap();
		assert!(extract_auth_token(&dup).is_err(), "duplicate headers");

		// header bytes outside visible ASCII
		let value = axum::http::HeaderValue::from_bytes(b"Bearer \xff").unwrap();
		let bad = Request::builder()
			.header("authorization", value)
			.body(Body::empty())
			.unwrap();
		assert!(extract_auth_token(&bad).is_err(), "non-ASCII header");
	}

	#[test]
	fn authenticate_request_accepts_and_rejects() {
		let token = test_token();
		let req = |hdr: Option<&str>| {
			let mut b = Request::builder();
			if let Some(v) = hdr { b = b.header("authorization", format!("Bearer {}", v)); }
			b.body(Body::empty()).unwrap()
		};

		// valid token passes
		let res = authenticate_request(make_state(token.clone()), &req(Some(&token.encode())));
		assert!(res.is_ok(), "valid token should pass: {:?}", res);

		// missing, wrong, and garbage tokens all fail
		let state = make_state(token.clone());
		let no_hdr = Request::builder().body(Body::empty()).unwrap();
		assert!(authenticate_request(state.clone(), &no_hdr).is_err());
		assert!(authenticate_request(state.clone(), &req(Some(&AuthToken::new([0u8; 32]).encode()))).is_err());
		assert!(authenticate_request(state.clone(), &req(Some("not-a-valid-token"))).is_err());

		// duplicated headers are rejected even when both carry the valid token
		let encoded = token.encode();
		let dup = Request::builder()
			.header("authorization", format!("Bearer {}", encoded))
			.header("authorization", format!("Bearer {}", encoded))
			.body(Body::empty())
			.unwrap();
		assert!(authenticate_request(state, &dup).is_err());
	}

	#[test]
	fn no_configured_token_allows_unauthenticated_access() {
		let no_hdr = Request::builder().body(Body::empty()).unwrap();
		assert!(authenticate_request(make_state_with(None), &no_hdr).is_ok());
	}
}