Skip to main content

reddb_server/wire/redwire/
auth.rs

1//! Handshake state machine + auth method dispatch.
2//!
3//! Hello / HelloAck payloads are JSON for the initial cut. CBOR
4//! migration tracked as a follow-up — JSON keeps the v2 wire
5//! debuggable from a hex dump. The payload contract lives in
6//! `reddb-wire`; this module keeps server auth policy and validation.
7//!
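//! Auth methods negotiated by `pick_auth_method`:
//!   - `bearer`        — token in AuthResponse, validated against AuthStore
//!   - `anonymous`     — only when AuthStore is disabled; no challenge
//!   - `scram-sha-256` — multi-round; driven by `perform_scram_handshake`,
//!                       not the 1-RTT `validate_auth_response` path
//!   - `oauth-jwt`     — JWT checked through an `OAuthValidator`
//!                       (see `validate_oauth_jwt`)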
11
12use crate::auth::store::AuthStore;
13use crate::auth::Role;
14use crate::serde_json::{self, Value as JsonValue};
15use reddb_wire::redwire::handshake::{
16    base64_std, base64_std_decode, build_scram_auth_ok_payload, parse_auth_response_bearer_token,
17};
18
19/// Outcome of `validate_auth_response`.
20#[derive(Debug, Clone)]
21pub enum AuthOutcome {
22    /// Auth succeeded; session id + role for downstream dispatch.
23    Authenticated {
24        username: String,
25        role: Role,
26        tenant: Option<String>,
27        session_id: String,
28    },
29    /// Auth refused; the message is operator-readable.
30    Refused(String),
31}
32
/// Server's policy for picking an auth method given the client's
/// advertised methods.
///
/// The server walks its own priority list, strongest first, and returns the
/// first entry the client also offered. The client's ordering is ignored.
///
/// - Auth enabled (`server_anon_ok = false`):
///   `scram-sha-256` > `oauth-jwt` > `bearer`; `anonymous` is never chosen.
/// - No auth backend (`server_anon_ok = true`):
///   `anonymous` > `scram-sha-256` > `oauth-jwt` > `bearer`. Anonymous wins
///   here because credential validation would fail without a backend
///   anyway, so the handshake succeeds without an AuthStore lookup.
///
/// Returns `None` when the client offered nothing the server accepts.
pub fn pick_auth_method(client_methods: &[String], server_anon_ok: bool) -> Option<&'static str> {
    // SCRAM (no plaintext on the wire) > OAuth-JWT (federated)
    // > bearer (session token / API key). `anonymous` only appears in the
    // table used when the server has no auth backend.
    const AUTH_REQUIRED: &[&str] = &["scram-sha-256", "oauth-jwt", "bearer"];
    const ANONYMOUS_FIRST: &[&str] = &["anonymous", "scram-sha-256", "oauth-jwt", "bearer"];

    let priority = if server_anon_ok {
        ANONYMOUS_FIRST
    } else {
        AUTH_REQUIRED
    };
    priority
        .iter()
        .copied()
        .find(|&method| client_methods.iter().any(|m| m == method))
}
61
62/// Validate the AuthResponse payload for the chosen method.
63pub fn validate_auth_response(
64    method: &str,
65    payload: &[u8],
66    auth_store: Option<&AuthStore>,
67) -> AuthOutcome {
68    match method {
69        "anonymous" => {
70            // Only legitimate when auth is disabled. Caller already
71            // gated this in `pick_auth_method`; double-check here.
72            if let Some(store) = auth_store {
73                if store.is_enabled() {
74                    return AuthOutcome::Refused(
75                        "anonymous auth refused — server has auth enabled".into(),
76                    );
77                }
78            }
79            AuthOutcome::Authenticated {
80                username: "anonymous".to_string(),
81                role: Role::Read,
82                tenant: None,
83                session_id: new_session_id(),
84            }
85        }
86        "bearer" => {
87            let token = parse_auth_response_bearer_token(payload).unwrap_or_default();
88            let Some(store) = auth_store else {
89                return AuthOutcome::Refused(
90                    "bearer auth refused — server has no auth store configured".into(),
91                );
92            };
93            match store.validate_token_full(&token) {
94                Some((user_id, role)) => AuthOutcome::Authenticated {
95                    username: user_id.username,
96                    role,
97                    tenant: user_id.tenant,
98                    session_id: new_session_id(),
99                },
100                None => AuthOutcome::Refused("bearer token invalid".into()),
101            }
102        }
103        "scram-sha-256" => AuthOutcome::Refused(
104            "scram-sha-256 must be driven through perform_scram_handshake — \
105             the 1-RTT validate_auth_response path doesn't apply"
106                .to_string(),
107        ),
108        "oauth-jwt" => {
109            // The OAuthValidator handle is expected via the
110            // RedWireConfig.oauth slot — plumbing happens in
111            // session::handle_session. When called here without
112            // it (e.g. test paths that don't set the handle),
113            // the v2 handshake refuses cleanly.
114            AuthOutcome::Refused(
115                "oauth-jwt requires RedWireConfig.oauth to be set. Pass an \
116                 OAuthValidator with the issuer + JWKS configured."
117                    .to_string(),
118            )
119        }
120        other => AuthOutcome::Refused(format!("auth method '{other}' is not supported in v2.1")),
121    }
122}
123
124/// Build the AuthOk payload the server sends after a successful
125/// auth.
126pub fn build_auth_ok(
127    session_id: &str,
128    username: &str,
129    role: Role,
130    server_features: u32,
131) -> Vec<u8> {
132    let role_str = role.to_string();
133    reddb_wire::redwire::handshake::build_auth_ok_payload(
134        session_id,
135        username,
136        &role_str,
137        server_features,
138    )
139}
140
141/// Build the AuthOk payload for a successful SCRAM completion.
142/// Carries the server signature so the client can verify the
143/// server also knew the verifier.
144pub fn build_scram_auth_ok(
145    session_id: &str,
146    username: &str,
147    role: Role,
148    server_features: u32,
149    server_signature: &[u8],
150) -> Vec<u8> {
151    let role = role.to_string();
152    build_scram_auth_ok_payload(
153        session_id,
154        username,
155        &role,
156        server_features,
157        server_signature,
158    )
159}
160
161/// Generate a 24-byte server nonce, base64-encoded. Cryptographic
162/// randomness sourced from the engine's existing `random_bytes`
163/// helper so SCRAM doesn't introduce a new RNG path.
164pub fn new_server_nonce() -> String {
165    base64_std(&crate::auth::store::random_bytes(18))
166}
167
168pub(crate) fn new_session_id_for_scram() -> String {
169    new_session_id()
170}
171
172/// Parse a compact-serialized JWT into a `DecodedJwt`. RFC 7519
173/// shape: `<base64url(header)>.<base64url(payload)>.<base64url(signature)>`.
174/// The validator does the heavy lifting (signature, claims,
175/// expiry); this function just splits + decodes.
176pub fn parse_jwt(token: &str) -> Result<crate::auth::oauth::DecodedJwt, String> {
177    let parts: Vec<&str> = token.split('.').collect();
178    if parts.len() != 3 {
179        return Err(format!(
180            "expected 3 dot-separated parts, got {}",
181            parts.len()
182        ));
183    }
184    let header_bytes =
185        base64_url_decode(parts[0]).ok_or_else(|| "header is not valid base64url".to_string())?;
186    let payload_bytes =
187        base64_url_decode(parts[1]).ok_or_else(|| "payload is not valid base64url".to_string())?;
188    let signature = base64_url_decode(parts[2])
189        .ok_or_else(|| "signature is not valid base64url".to_string())?;
190
191    let header_json: JsonValue =
192        serde_json::from_slice(&header_bytes).map_err(|e| format!("header JSON: {e}"))?;
193    let payload_json: JsonValue =
194        serde_json::from_slice(&payload_bytes).map_err(|e| format!("payload JSON: {e}"))?;
195
196    let header = jwt_header_from(&header_json)?;
197    let claims = jwt_claims_from(&payload_json);
198
199    let signing_input = format!("{}.{}", parts[0], parts[1]).into_bytes();
200
201    Ok(crate::auth::oauth::DecodedJwt {
202        header,
203        claims,
204        signing_input,
205        signature,
206    })
207}
208
209fn jwt_header_from(v: &JsonValue) -> Result<crate::auth::oauth::JwtHeader, String> {
210    let obj = v
211        .as_object()
212        .ok_or_else(|| "JWT header must be a JSON object".to_string())?;
213    let alg = obj
214        .get("alg")
215        .and_then(|x| x.as_str())
216        .ok_or_else(|| "JWT header missing 'alg'".to_string())?
217        .to_string();
218    let kid = obj.get("kid").and_then(|x| x.as_str()).map(String::from);
219    Ok(crate::auth::oauth::JwtHeader { alg, kid })
220}
221
222fn jwt_claims_from(v: &JsonValue) -> crate::auth::oauth::JwtClaims {
223    let obj = v.as_object().cloned().unwrap_or_default();
224    let mut claims = crate::auth::oauth::JwtClaims::default();
225    if let Some(s) = obj.get("iss").and_then(|x| x.as_str()) {
226        claims.iss = Some(s.to_string());
227    }
228    if let Some(s) = obj.get("sub").and_then(|x| x.as_str()) {
229        claims.sub = Some(s.to_string());
230    }
231    if let Some(s) = obj.get("aud").and_then(|x| x.as_str()) {
232        claims.aud = vec![s.to_string()];
233    } else if let Some(arr) = obj.get("aud").and_then(|x| x.as_array()) {
234        claims.aud = arr
235            .iter()
236            .filter_map(|v| v.as_str().map(String::from))
237            .collect();
238    }
239    if let Some(n) = obj.get("exp").and_then(|x| x.as_f64()) {
240        claims.exp = Some(n as i64);
241    }
242    if let Some(n) = obj.get("nbf").and_then(|x| x.as_f64()) {
243        claims.nbf = Some(n as i64);
244    }
245    if let Some(n) = obj.get("iat").and_then(|x| x.as_f64()) {
246        claims.iat = Some(n as i64);
247    }
248    for (k, v) in obj.iter() {
249        if matches!(k.as_str(), "iss" | "sub" | "aud" | "exp" | "nbf" | "iat") {
250            continue;
251        }
252        if let Some(s) = v.as_str() {
253            claims.extra.insert(k.clone(), s.to_string());
254        }
255    }
256    claims
257}
258
259/// Validate a JWT through the supplied `OAuthValidator`. Returns
260/// `(username, role)` on success, or a refusal reason.
261pub fn validate_oauth_jwt(
262    validator: &crate::auth::oauth::OAuthValidator,
263    raw_token: &str,
264) -> Result<(String, Role), String> {
265    validate_oauth_jwt_full(validator, raw_token).map(|(_tenant, username, role)| (username, role))
266}
267
268/// Tenant-aware variant of [`validate_oauth_jwt`]. Returns
269/// `(tenant, username, role)` so the caller can mint a session pinned
270/// to the tenant carried by the configured `tenant_claim`.
271pub fn validate_oauth_jwt_full(
272    validator: &crate::auth::oauth::OAuthValidator,
273    raw_token: &str,
274) -> Result<(Option<String>, String, Role), String> {
275    let token = parse_jwt(raw_token).map_err(|e| format!("decode JWT: {e}"))?;
276    let now = std::time::SystemTime::now()
277        .duration_since(std::time::UNIX_EPOCH)
278        .map(|d| d.as_secs() as i64)
279        .unwrap_or(0);
280    // sub-claim mode: the JWT subject IS the RedDB username. Roles map
281    // from a `role` custom claim; tenant from the configured tenant
282    // claim (default "tenant"). The lookup closure mirrors the same
283    // claims so `map_to_existing_users=false` deployments still get a
284    // tenant-tagged identity.
285    let identity = validator
286        .validate(&token, now, |sub| {
287            Some(crate::auth::User {
288                username: sub.to_string(),
289                tenant_id: token.claims.extra.get("tenant").cloned(),
290                password_hash: String::new(),
291                scram_verifier: None,
292                role: token
293                    .claims
294                    .extra
295                    .get("role")
296                    .and_then(|s| Role::from_str(s))
297                    .unwrap_or(Role::Read),
298                api_keys: Vec::new(),
299                created_at: 0,
300                updated_at: 0,
301                enabled: true,
302            })
303        })
304        .map_err(|e| format!("{e}"))?;
305    Ok((identity.tenant, identity.username, identity.role))
306}
307
308fn base64_url_decode(input: &str) -> Option<Vec<u8>> {
309    // base64url = '+' → '-', '/' → '_', stripped padding.
310    let mut s = String::with_capacity(input.len() + 4);
311    for ch in input.chars() {
312        match ch {
313            '-' => s.push('+'),
314            '_' => s.push('/'),
315            _ => s.push(ch),
316        }
317    }
318    while !s.len().is_multiple_of(4) {
319        s.push('=');
320    }
321    base64_std_decode(&s)
322}
323
324/// Generate a session id. Format: `rwsess-<unix_micros>-<rand>`.
325/// Not cryptographically random; the security boundary is the
326/// auth method, not session-id unguessability.
327fn new_session_id() -> String {
328    let now_us = std::time::SystemTime::now()
329        .duration_since(std::time::UNIX_EPOCH)
330        .map(|d| d.as_micros())
331        .unwrap_or(0);
332    let rand = crate::utils::now_unix_nanos() & 0xFFFF_FFFF;
333    format!("rwsess-{now_us}-{rand:08x}")
334}
335
#[cfg(test)]
mod tests {
    // Unit tests for the handshake helpers. The hello / hello-ack cases pin
    // the JSON wire shape produced by the `reddb-wire` builders; the rest
    // cover method negotiation, the one-shot auth path, JWT decoding and
    // session-id shape.
    use super::*;
    use reddb_wire::redwire::handshake::{build_hello_ack, Hello};

    /// Encode bytes as unpadded base64url, the JWT segment encoding.
    fn b64url(bytes: &[u8]) -> String {
        base64_std(bytes)
            .trim_end_matches('=')
            .replace('+', "-")
            .replace('/', "_")
    }

    #[test]
    fn hello_round_trip() {
        let payload = br#"{"versions":[1],"auth_methods":["bearer","anonymous"],"features":3,"client_name":"reddb-rs/0.1"}"#;
        let h = Hello::from_payload(payload).unwrap();
        assert_eq!(h.versions, vec![1]);
        assert_eq!(h.auth_methods, vec!["bearer", "anonymous"]);
        assert_eq!(h.features, 3);
        assert_eq!(h.client_name.as_deref(), Some("reddb-rs/0.1"));
    }

    #[test]
    fn hello_rejects_empty_methods() {
        let payload = br#"{"versions":[1],"auth_methods":[]}"#;
        assert!(Hello::from_payload(payload).is_err());
    }

    #[test]
    fn pick_auth_prefers_anonymous_when_server_has_no_auth_store() {
        // Without an auth store, bearer validation can't succeed.
        // Picker should prefer anonymous so the handshake works.
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, true), Some("anonymous"));
    }

    #[test]
    fn pick_auth_picks_bearer_when_anonymous_blocked() {
        // Server has auth enabled (no anonymous) — bearer wins.
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, false), Some("bearer"));
    }

    #[test]
    fn pick_auth_skips_anonymous_when_server_blocks_it() {
        let pref = vec!["anonymous".to_string()];
        assert_eq!(pick_auth_method(&pref, false), None);
    }

    #[test]
    fn pick_auth_returns_none_when_nothing_overlaps() {
        let pref = vec!["kerberos".to_string(), "future-method".to_string()];
        assert_eq!(pick_auth_method(&pref, true), None);
    }

    #[test]
    fn pick_auth_orders_strongest_first_when_auth_enabled() {
        let pref = vec![
            "bearer".to_string(),
            "oauth-jwt".to_string(),
            "scram-sha-256".to_string(),
        ];
        assert_eq!(pick_auth_method(&pref, false), Some("scram-sha-256"));
        let pref = vec!["bearer".to_string(), "oauth-jwt".to_string()];
        assert_eq!(pick_auth_method(&pref, false), Some("oauth-jwt"));
    }

    #[test]
    fn pick_auth_ignores_client_ordering() {
        // Server priority decides, not the order the client lists methods.
        let pref = vec!["bearer".to_string(), "anonymous".to_string()];
        assert_eq!(pick_auth_method(&pref, true), Some("anonymous"));
    }

    #[test]
    fn anonymous_validates_only_when_store_disabled() {
        // No store at all counts as "auth disabled".
        let outcome = validate_auth_response("anonymous", &[], None);
        assert!(matches!(outcome, AuthOutcome::Authenticated { .. }));
    }

    #[test]
    fn bearer_without_store_refuses() {
        let outcome = validate_auth_response("bearer", br#"{"token":"x"}"#, None);
        assert!(matches!(outcome, AuthOutcome::Refused(_)));
    }

    #[test]
    fn multi_round_and_unknown_methods_refuse_on_one_shot_path() {
        for method in ["scram-sha-256", "oauth-jwt", "kerberos"] {
            let outcome = validate_auth_response(method, &[], None);
            assert!(
                matches!(outcome, AuthOutcome::Refused(_)),
                "{method} must be refused"
            );
        }
    }

    #[test]
    fn base64_url_decode_restores_padding_and_alphabet() {
        assert_eq!(base64_url_decode("aGk").as_deref(), Some(&b"hi"[..]));
        assert_eq!(base64_url_decode("-_8").as_deref(), Some(&[0xFB, 0xFF][..]));
    }

    #[test]
    fn parse_jwt_rejects_wrong_segment_count() {
        match parse_jwt("a.b") {
            Err(e) => assert!(e.contains("expected 3 dot-separated parts"), "{e}"),
            Ok(_) => panic!("two-segment token must be rejected"),
        }
    }

    #[test]
    fn parse_jwt_decodes_header_claims_and_signing_input() {
        let header = b64url(br#"{"alg":"RS256","kid":"key-1"}"#);
        let payload = b64url(
            br#"{"sub":"alice","aud":["a","b"],"exp":1700000000,"role":"admin","n":5}"#,
        );
        let signature = b64url(b"sig");
        let token = format!("{header}.{payload}.{signature}");
        let jwt = match parse_jwt(&token) {
            Ok(jwt) => jwt,
            Err(e) => panic!("parse failed: {e}"),
        };
        assert_eq!(jwt.header.alg, "RS256");
        assert_eq!(jwt.header.kid.as_deref(), Some("key-1"));
        assert_eq!(jwt.claims.sub.as_deref(), Some("alice"));
        assert_eq!(jwt.claims.aud, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(jwt.claims.exp, Some(1_700_000_000));
        assert_eq!(
            jwt.claims.extra.get("role").map(String::as_str),
            Some("admin")
        );
        // Non-string custom claims are not carried in `extra`.
        assert!(jwt.claims.extra.get("n").is_none());
        assert_eq!(jwt.signing_input, format!("{header}.{payload}").into_bytes());
        assert_eq!(jwt.signature, b"sig".to_vec());
    }

    #[test]
    fn session_ids_have_expected_shape() {
        let id = new_session_id();
        let rest = id.strip_prefix("rwsess-").expect("rwsess- prefix");
        let (micros, suffix) = rest.split_once('-').expect("micros and suffix");
        assert!(micros.parse::<u128>().is_ok(), "{id}");
        assert_eq!(suffix.len(), 8, "{id}");
        assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()), "{id}");
    }

    #[test]
    fn hello_ack_omits_topology_field_when_caller_passes_none() {
        // Backwards-compat: callers that haven't picked up the
        // advertiser yet pass `None` and the JSON envelope keeps
        // the same shape as pre-#167.
        let bytes = build_hello_ack(1, "bearer", 0, None);
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(!s.contains("\"topology\""));
    }

    #[test]
    fn hello_ack_embeds_topology_field_when_caller_passes_payload() {
        // Issue #167: HelloAck builder inserts the canonical bytes
        // base64-wrapped under JSON key `topology`. Round-trip via
        // the wire decoder pins byte-for-byte equivalence with the
        // canonical encoder (#166).
        let topo = reddb_wire::topology::Topology {
            epoch: 17,
            primary: reddb_wire::topology::Endpoint {
                addr: "primary:5050".into(),
                region: "us-east-1".into(),
            },
            replicas: Vec::new(),
        };
        let bytes = build_hello_ack(1, "bearer", 0, Some(&topo));
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.contains("\"topology\""), "missing topology key in {s}");

        // Extract and round-trip the field through the wire decoder.
        let v: JsonValue = crate::serde_json::from_slice(&bytes).unwrap();
        let field = v
            .as_object()
            .and_then(|o| o.get("topology"))
            .and_then(|t| t.as_str())
            .expect("topology key must be present and a string");
        let decoded = reddb_wire::topology::decode_topology_from_hello_ack(field).expect("decode");
        assert_eq!(decoded.expect("v1 known"), topo);
    }
}