Skip to main content

apr_cli/commands/serve/
auth.rs

1//! HELIX-IDEA-009 — single-key bearer-token auth for `apr serve`.
2//!
3//! Contract: `contracts/apr-serve-api-key-auth-v1.yaml`
4//! Pattern source: `helix-db/src/helix_gateway/key_verification.rs`
5//! (re-implemented; no code lift).
6//!
7//! Comparison goes through `subtle::ConstantTimeEq` over fixed-length
8//! SHA-256 digests so a probing attacker cannot leak the configured key
9//! by timing the response.
10
11use sha2::{Digest, Sha256};
12use subtle::ConstantTimeEq;
13
/// Configured authentication state for an `apr serve` instance.
///
/// `expected_hash == None` means auth is disabled — every request passes.
/// The startup path emits a stderr warning when this happens so operators
/// see the open door at boot.
///
/// `Debug` is implemented by hand so that formatting a gate (for example in
/// a log line) reveals only whether auth is enabled, never the stored
/// digest: an unsalted SHA-256 of a low-entropy key can be brute-forced
/// offline, so the digest should be treated as sensitive.
#[derive(Clone, Default)]
pub struct AuthGate {
    /// SHA-256 of the expected bearer token; `None` disables the gate.
    expected_hash: Option<[u8; 32]>,
}

impl std::fmt::Debug for AuthGate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redacted on purpose: expose the enabled flag, not the digest bytes.
        f.debug_struct("AuthGate")
            .field("enabled", &self.expected_hash.is_some())
            .finish_non_exhaustive()
    }
}
23
24impl AuthGate {
25    /// Construct an explicitly-disabled gate. Test fixtures use this; the
26    /// production startup path uses [`Self::from_env`].
27    #[must_use]
28    pub fn disabled() -> Self {
29        Self {
30            expected_hash: None,
31        }
32    }
33
34    /// Construct an enabled gate from a SHA-256 digest of the expected
35    /// bearer token. Tests use this to avoid touching env vars.
36    #[must_use]
37    pub fn from_hash(expected_hash: [u8; 32]) -> Self {
38        Self {
39            expected_hash: Some(expected_hash),
40        }
41    }
42
43    /// Construct an enabled gate from a plaintext API key.
44    #[must_use]
45    pub fn from_plain_key(key: &str) -> Self {
46        Self::from_hash(sha256_32(key.as_bytes()))
47    }
48
49    /// Read configuration from env. `APR_API_KEY_HASH` (hex-encoded
50    /// SHA-256) wins over `APR_API_KEY` (plaintext, hashed in-process).
51    /// Neither set → disabled with a single stderr warning.
52    #[must_use]
53    pub fn from_env() -> Self {
54        if let Ok(hex) = std::env::var("APR_API_KEY_HASH") {
55            match decode_hex_32(&hex) {
56                Ok(bytes) => return Self::from_hash(bytes),
57                Err(reason) => {
58                    eprintln!(
59                        "[apr serve] APR_API_KEY_HASH set but {reason}; ignoring (auth disabled)",
60                    );
61                    return Self::disabled();
62                }
63            }
64        }
65        if let Ok(plain) = std::env::var("APR_API_KEY") {
66            if !plain.is_empty() {
67                return Self::from_plain_key(&plain);
68            }
69        }
70        eprintln!(
71            "[apr serve] WARNING: no APR_API_KEY or APR_API_KEY_HASH set; HTTP routes are unauthenticated",
72        );
73        Self::disabled()
74    }
75
76    /// True iff the gate has a configured hash.
77    #[must_use]
78    pub fn is_enabled(&self) -> bool {
79        self.expected_hash.is_some()
80    }
81
82    /// Verify a presented `Authorization` header value.
83    ///
84    /// Returns `true` iff the gate is disabled OR the header is exactly
85    /// `Bearer <key>` and `sha256(key) == expected_hash` (constant-time).
86    /// All other shapes — missing header, wrong scheme, empty token —
87    /// return `false`.
88    #[must_use]
89    pub fn check_bearer(&self, header: Option<&str>) -> bool {
90        let Some(expected) = self.expected_hash.as_ref() else {
91            return true;
92        };
93        let Some(value) = header else {
94            return false;
95        };
96        let Some(token) = value.strip_prefix("Bearer ") else {
97            return false;
98        };
99        let presented = sha256_32(token.as_bytes());
100        bool::from(expected.ct_eq(&presented))
101    }
102}
103
104fn sha256_32(input: &[u8]) -> [u8; 32] {
105    let digest = Sha256::digest(input);
106    let mut out = [0u8; 32];
107    out.copy_from_slice(&digest);
108    out
109}
110
/// Decode a 64-character hex string into the 32 bytes of a SHA-256 digest.
///
/// Upper- and lower-case digits are both accepted. Returns a static
/// description of the problem when the length is wrong or a character is
/// not a hex digit.
fn decode_hex_32(hex: &str) -> Result<[u8; 32], &'static str> {
    if hex.len() != 64 {
        return Err("APR_API_KEY_HASH must be 64 hex chars (SHA-256)");
    }
    let mut decoded = [0u8; 32];
    // Each output byte is built from one (high nibble, low nibble) pair.
    for (dst, pair) in decoded.iter_mut().zip(hex.as_bytes().chunks_exact(2)) {
        let high = hex_digit(pair[0])?;
        let low = hex_digit(pair[1])?;
        *dst = (high << 4) | low;
    }
    Ok(decoded)
}

/// Value (0..=15) of a single ASCII hex digit, or an error for any other byte.
fn hex_digit(b: u8) -> Result<u8, &'static str> {
    // `char::to_digit(16)` accepts exactly 0-9, a-f and A-F; bytes >= 0x80
    // map to non-ASCII chars, which it rejects.
    char::from(b)
        .to_digit(16)
        .and_then(|nibble| u8::try_from(nibble).ok())
        .ok_or("APR_API_KEY_HASH must contain only [0-9a-fA-F]")
}
133
/// Axum middleware function that lets authorized requests through and
/// rejects everything else with `401 Unauthorized`, a
/// `WWW-Authenticate: Bearer` header and a JSON error envelope.
///
/// Wired into the per-router builder via
/// `axum::middleware::from_fn_with_state(Arc::new(gate), apply)`. The
/// gate is shared (Arc) so all routes on a router observe the same
/// configuration.
#[cfg(feature = "inference")]
pub async fn apply(
    axum::extract::State(gate): axum::extract::State<std::sync::Arc<AuthGate>>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use axum::http::{header, HeaderValue, StatusCode};
    use axum::response::IntoResponse;

    // A header value that `to_str` rejects (non-visible-ASCII bytes) is
    // handled exactly like a missing header.
    let authorized = gate.check_bearer(
        req.headers()
            .get(header::AUTHORIZATION)
            .and_then(|raw| raw.to_str().ok()),
    );

    if authorized {
        next.run(req).await
    } else {
        let envelope = serde_json::json!({
            "error": "unauthorized",
            "message": "Missing or invalid Authorization: Bearer <key> header"
        });
        let mut rejection = (StatusCode::UNAUTHORIZED, axum::Json(envelope)).into_response();
        rejection
            .headers_mut()
            .insert(header::WWW_AUTHENTICATE, HeaderValue::from_static("Bearer"));
        rejection
    }
}
168
/// Attach the auth gate to an axum `Router` of any state type `S`.
///
/// Every router builder calls this so they all share one implementation.
/// The gate is moved into a single `Arc`, so all routes on the router see
/// the same snapshot of the configuration, even if env vars change after
/// startup.
#[cfg(feature = "inference")]
#[must_use]
pub fn layer<S>(gate: AuthGate, router: axum::Router<S>) -> axum::Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let shared_gate = std::sync::Arc::new(gate);
    let auth_middleware = axum::middleware::from_fn_with_state(shared_gate, apply);
    router.layer(auth_middleware)
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret shared by the enabled-gate tests below.
    const SECRET: &str = "s3cr3t";

    /// Lower-case hex rendering of a digest, i.e. what an operator would
    /// put in `APR_API_KEY_HASH`.
    fn to_hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn disabled_gate_accepts_anything() {
        let gate = AuthGate::disabled();
        assert!(!gate.is_enabled());
        for header in [None, Some("Bearer anything"), Some("garbage")] {
            assert!(gate.check_bearer(header), "disabled gate rejected {header:?}");
        }
    }

    #[test]
    fn enabled_gate_rejects_missing_header() {
        assert!(!AuthGate::from_plain_key(SECRET).check_bearer(None));
    }

    #[test]
    fn enabled_gate_rejects_wrong_scheme() {
        let gate = AuthGate::from_plain_key(SECRET);
        for header in ["Basic dXNlcjpwYXNz", "Bearer"] {
            assert!(!gate.check_bearer(Some(header)), "accepted {header:?}");
        }
    }

    #[test]
    fn enabled_gate_accepts_correct_bearer() {
        let header = format!("Bearer {SECRET}");
        assert!(AuthGate::from_plain_key(SECRET).check_bearer(Some(&header)));
    }

    #[test]
    fn enabled_gate_rejects_wrong_bearer() {
        assert!(!AuthGate::from_plain_key(SECRET).check_bearer(Some("Bearer wrong")));
    }

    #[test]
    fn from_hash_matches_from_plain_key_for_same_secret() {
        let plain = "another-secret";
        let header = format!("Bearer {plain}");
        let gates = [
            AuthGate::from_plain_key(plain),
            AuthGate::from_hash(sha256_32(plain.as_bytes())),
        ];
        for gate in &gates {
            assert!(gate.check_bearer(Some(header.as_str())));
        }
    }

    #[test]
    fn decode_hex_32_round_trip() {
        let digest = sha256_32(b"hello");
        assert_eq!(decode_hex_32(&to_hex(&digest)), Ok(digest));
    }

    #[test]
    fn decode_hex_32_rejects_wrong_length() {
        for candidate in ["deadbeef".to_owned(), "a".repeat(63), "a".repeat(65)] {
            assert!(
                decode_hex_32(&candidate).is_err(),
                "accepted {} chars",
                candidate.len()
            );
        }
    }

    #[test]
    fn decode_hex_32_rejects_non_hex_char() {
        let bad = format!("Z{}", "0".repeat(63));
        assert!(decode_hex_32(&bad).is_err());
    }
}