apr_cli/commands/serve/
auth.rs1use sha2::{Digest, Sha256};
12use subtle::ConstantTimeEq;
13
/// Bearer-token gate for the HTTP server.
///
/// Only the SHA-256 digest of the accepted key is stored. `None` means
/// authentication is disabled and every request is accepted; `Default`
/// therefore yields the disabled gate.
#[derive(Clone, Default)]
pub struct AuthGate {
    /// SHA-256 of the accepted bearer token, or `None` when auth is disabled.
    expected_hash: Option<[u8; 32]>,
}

// `Debug` is hand-written (not derived) so the key digest never lands in logs:
// a leaked SHA-256 of a low-entropy key can be brute-forced offline.
impl std::fmt::Debug for AuthGate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthGate")
            .field("enabled", &self.expected_hash.is_some())
            .finish_non_exhaustive()
    }
}
23
24impl AuthGate {
25 #[must_use]
28 pub fn disabled() -> Self {
29 Self {
30 expected_hash: None,
31 }
32 }
33
34 #[must_use]
37 pub fn from_hash(expected_hash: [u8; 32]) -> Self {
38 Self {
39 expected_hash: Some(expected_hash),
40 }
41 }
42
43 #[must_use]
45 pub fn from_plain_key(key: &str) -> Self {
46 Self::from_hash(sha256_32(key.as_bytes()))
47 }
48
49 #[must_use]
53 pub fn from_env() -> Self {
54 if let Ok(hex) = std::env::var("APR_API_KEY_HASH") {
55 match decode_hex_32(&hex) {
56 Ok(bytes) => return Self::from_hash(bytes),
57 Err(reason) => {
58 eprintln!(
59 "[apr serve] APR_API_KEY_HASH set but {reason}; ignoring (auth disabled)",
60 );
61 return Self::disabled();
62 }
63 }
64 }
65 if let Ok(plain) = std::env::var("APR_API_KEY") {
66 if !plain.is_empty() {
67 return Self::from_plain_key(&plain);
68 }
69 }
70 eprintln!(
71 "[apr serve] WARNING: no APR_API_KEY or APR_API_KEY_HASH set; HTTP routes are unauthenticated",
72 );
73 Self::disabled()
74 }
75
76 #[must_use]
78 pub fn is_enabled(&self) -> bool {
79 self.expected_hash.is_some()
80 }
81
82 #[must_use]
89 pub fn check_bearer(&self, header: Option<&str>) -> bool {
90 let Some(expected) = self.expected_hash.as_ref() else {
91 return true;
92 };
93 let Some(value) = header else {
94 return false;
95 };
96 let Some(token) = value.strip_prefix("Bearer ") else {
97 return false;
98 };
99 let presented = sha256_32(token.as_bytes());
100 bool::from(expected.ct_eq(&presented))
101 }
102}
103
104fn sha256_32(input: &[u8]) -> [u8; 32] {
105 let digest = Sha256::digest(input);
106 let mut out = [0u8; 32];
107 out.copy_from_slice(&digest);
108 out
109}
110
111fn decode_hex_32(hex: &str) -> Result<[u8; 32], &'static str> {
112 if hex.len() != 64 {
113 return Err("APR_API_KEY_HASH must be 64 hex chars (SHA-256)");
114 }
115 let bytes = hex.as_bytes();
116 let mut out = [0u8; 32];
117 for (i, slot) in out.iter_mut().enumerate() {
118 let hi = hex_digit(bytes[i * 2])?;
119 let lo = hex_digit(bytes[i * 2 + 1])?;
120 *slot = (hi << 4) | lo;
121 }
122 Ok(out)
123}
124
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// # Errors
/// Returns a static message for any other byte.
fn hex_digit(b: u8) -> Result<u8, &'static str> {
    char::from(b)
        .to_digit(16)
        // `to_digit(16)` only yields 0..=15, so the cast is lossless.
        .map(|value| value as u8)
        .ok_or("APR_API_KEY_HASH must contain only [0-9a-fA-F]")
}
133
#[cfg(feature = "inference")]
/// Axum middleware that enforces [`AuthGate::check_bearer`] on every request.
///
/// Requests that pass are forwarded to `next` unchanged. All others receive
/// `401 Unauthorized` with a JSON error body and a `WWW-Authenticate: Bearer`
/// header.
pub async fn apply(
    axum::extract::State(gate): axum::extract::State<std::sync::Arc<AuthGate>>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use axum::http::{header, HeaderValue, StatusCode};
    use axum::response::IntoResponse;

    // Scope the header borrow so `req` can be moved into `next.run` below.
    // A header value that is not visible ASCII is treated as absent.
    let authorized = {
        let presented = req
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|raw| raw.to_str().ok());
        gate.check_bearer(presented)
    };

    if authorized {
        next.run(req).await
    } else {
        let payload = serde_json::json!({
            "error": "unauthorized",
            "message": "Missing or invalid Authorization: Bearer <key> header"
        });
        let mut rejection = (StatusCode::UNAUTHORIZED, axum::Json(payload)).into_response();
        rejection
            .headers_mut()
            .insert(header::WWW_AUTHENTICATE, HeaderValue::from_static("Bearer"));
        rejection
    }
}
168
#[cfg(feature = "inference")]
/// Wraps `router` with the [`apply`] bearer-token middleware, using `gate`
/// as its shared state.
///
/// NOTE(review): axum's `Router::layer` only wraps routes already registered
/// on the router. Confirm that callers add every route before calling this.
#[must_use]
pub fn layer<S>(gate: AuthGate, router: axum::Router<S>) -> axum::Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let shared_gate = std::sync::Arc::new(gate);
    let auth_middleware = axum::middleware::from_fn_with_state(shared_gate, apply);
    router.layer(auth_middleware)
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by the single-secret gate tests.
    const SECRET: &str = "s3cr3t";

    /// Builds an `Authorization` header value using the `Bearer` scheme.
    fn bearer(token: &str) -> String {
        format!("Bearer {token}")
    }

    #[test]
    fn disabled_gate_accepts_anything() {
        let gate = AuthGate::disabled();
        for header in [None, Some("Bearer anything"), Some("garbage")] {
            assert!(gate.check_bearer(header), "disabled gate rejected {header:?}");
        }
        assert!(!gate.is_enabled());
    }

    #[test]
    fn enabled_gate_rejects_missing_header() {
        assert!(!AuthGate::from_plain_key(SECRET).check_bearer(None));
    }

    #[test]
    fn enabled_gate_rejects_wrong_scheme() {
        let gate = AuthGate::from_plain_key(SECRET);
        for header in ["Basic dXNlcjpwYXNz", "Bearer"] {
            assert!(!gate.check_bearer(Some(header)), "accepted {header:?}");
        }
    }

    #[test]
    fn enabled_gate_accepts_correct_bearer() {
        let header = bearer(SECRET);
        assert!(AuthGate::from_plain_key(SECRET).check_bearer(Some(header.as_str())));
    }

    #[test]
    fn enabled_gate_rejects_wrong_bearer() {
        let header = bearer("wrong");
        assert!(!AuthGate::from_plain_key(SECRET).check_bearer(Some(header.as_str())));
    }

    #[test]
    fn from_hash_matches_from_plain_key_for_same_secret() {
        let secret = "another-secret";
        let header = bearer(secret);
        let gates = [
            AuthGate::from_plain_key(secret),
            AuthGate::from_hash(sha256_32(secret.as_bytes())),
        ];
        for gate in &gates {
            assert!(gate.check_bearer(Some(header.as_str())));
        }
    }

    #[test]
    fn decode_hex_32_round_trip() {
        use std::fmt::Write as _;

        let digest = sha256_32(b"hello");
        let mut encoded = String::with_capacity(64);
        for byte in digest {
            write!(encoded, "{byte:02x}").expect("formatting into a String is infallible");
        }
        assert_eq!(decode_hex_32(&encoded), Ok(digest));
    }

    #[test]
    fn decode_hex_32_rejects_wrong_length() {
        let candidates = ["deadbeef".to_owned(), "a".repeat(63), "a".repeat(65)];
        for candidate in &candidates {
            assert!(
                decode_hex_32(candidate).is_err(),
                "accepted {} chars",
                candidate.len()
            );
        }
    }

    #[test]
    fn decode_hex_32_rejects_non_hex_char() {
        let bad = format!("Z{}", "0".repeat(63));
        assert!(decode_hex_32(&bad).is_err());
    }
}