Skip to main content

reddb_server/wire/redwire/
auth.rs

1//! Handshake state machine + auth method dispatch.
2//!
3//! Hello / HelloAck payloads are JSON for the initial cut. CBOR
4//! migration tracked as a follow-up — JSON keeps the v2 wire
5//! debuggable from a hex dump and reuses the engine's existing
6//! `crate::serde_json` codec without a new dep.
7//!
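//! Auth methods supported in v2.1:
//!   - `bearer`        — token in AuthResponse, validated against AuthStore
//!   - `anonymous`     — only when AuthStore is disabled; no challenge
//!   - `scram-sha-256` — advertised; runs as a multi-step exchange outside
//!                       the 1-RTT `validate_auth_response` path
//!   - `oauth-jwt`     — advertised; needs an `OAuthValidator` supplied via
//!                       `RedWireConfig.oauth`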
11
12use crate::auth::store::AuthStore;
13use crate::auth::Role;
14use crate::serde_json::{self, Value as JsonValue};
15
/// Auth method names this server advertises during Hello/HelloAck.
///
/// `bearer` and `anonymous` are single round-trip methods that
/// `validate_auth_response` handles end to end. `scram-sha-256` and
/// `oauth-jwt` are advertised as well, but that function refuses them.
/// SCRAM has to be driven through its own multi-step exchange, and OAuth
/// needs an `OAuthValidator` from `RedWireConfig.oauth`.
///
/// Listing every method now keeps the negotiation surface stable.
/// Clients can probe for a method today without the Hello/HelloAck shape
/// changing later, when the server-side wiring is complete.
pub const SUPPORTED_METHODS: &[&str] = &[
    "bearer",
    "anonymous",
    "scram-sha-256",
    "oauth-jwt",
];
27
28/// Outcome of `validate_auth_response`.
29#[derive(Debug, Clone)]
30pub enum AuthOutcome {
31    /// Auth succeeded; session id + role for downstream dispatch.
32    Authenticated {
33        username: String,
34        role: Role,
35        session_id: String,
36    },
37    /// Auth refused; the message is operator-readable.
38    Refused(String),
39}
40
/// Client Hello, decoded from the JSON payload a v2 client sends first
/// (see `Hello::from_payload`).
#[derive(Clone, Debug)]
pub struct Hello {
    /// Protocol versions the client offers.
    pub versions: Vec<u8>,
    /// Auth method names the client offers, in the order the client sent them.
    pub auth_methods: Vec<String>,
    /// Client-advertised `features` value (`0` when the payload omits it).
    pub features: u32,
    /// Optional free-form client identifier, e.g. `reddb-rs/0.1`.
    pub client_name: Option<String>,
}
49
50impl Hello {
51    pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
52        let v: JsonValue =
53            serde_json::from_slice(bytes).map_err(|e| format!("Hello: invalid JSON: {e}"))?;
54        let obj = match v {
55            JsonValue::Object(o) => o,
56            _ => return Err("Hello: payload must be a JSON object".into()),
57        };
58        let versions: Vec<u8> = obj
59            .get("versions")
60            .and_then(|v| v.as_array())
61            .map(|arr| {
62                arr.iter()
63                    .filter_map(|n| n.as_f64().map(|f| f as u8))
64                    .collect()
65            })
66            .unwrap_or_default();
67        let auth_methods: Vec<String> = obj
68            .get("auth_methods")
69            .and_then(|v| v.as_array())
70            .map(|arr| {
71                arr.iter()
72                    .filter_map(|s| s.as_str().map(String::from))
73                    .collect()
74            })
75            .unwrap_or_default();
76        let features = obj
77            .get("features")
78            .and_then(|v| v.as_f64())
79            .map(|f| f as u32)
80            .unwrap_or(0);
81        let client_name = obj
82            .get("client_name")
83            .and_then(|v| v.as_str())
84            .map(String::from);
85        if versions.is_empty() {
86            return Err("Hello: versions[] is empty".into());
87        }
88        if auth_methods.is_empty() {
89            return Err("Hello: auth_methods[] is empty".into());
90        }
91        Ok(Self {
92            versions,
93            auth_methods,
94            features,
95            client_name,
96        })
97    }
98}
99
100/// Build the HelloAck the server sends back. `chosen_auth` is the
101/// strongest method both sides support; `chosen_version` is
102/// `min(client_max, server_max)`.
103///
104/// When `topology` is `Some(_)`, the canonical bytes are
105/// base64-wrapped via `encode_topology_for_hello_ack` and embedded
106/// under the JSON key `"topology"` per issue #166's HelloAck
107/// embedding shape. Old clients that do not understand the key
108/// ignore it cleanly (ADR 0008 §4).
109///
110/// HelloAck travels *before* the AuthResponse, so the caller is
111/// expected to thread an *anonymous* auth context through
112/// `TopologyAdvertiser::advertise` — which collapses the payload
113/// to primary-only per ADR 0008 §3. A post-handshake
114/// re-advertisement (full replica list for an authenticated
115/// principal) rides the gRPC `Topology` RPC.
116pub fn build_hello_ack(
117    chosen_version: u8,
118    chosen_auth: &str,
119    server_features: u32,
120    topology: Option<&reddb_wire::topology::Topology>,
121) -> Vec<u8> {
122    use crate::json_field::SerializedJsonField;
123    // Every caller-influenced or composed string field is wired
124    // through the JSON-envelope guard so the field round-trips
125    // through the canonical RFC-8259 encoder rather than being
126    // string-concatenated. See ADR 0010 §3 and issue #178.
127    //
128    // `chosen_auth` is sourced from the client's Hello (an
129    // `auth_methods[]` entry the server picked), so it is caller-
130    // influenced. `server` is server-owned but composed via
131    // `format!` — wiring through the guard keeps the discipline
132    // uniform. `topology` is base64 over canonical bytes (#166)
133    // and structurally cannot contain delimiters, but the same
134    // guard applies for consistency.
135    let mut obj = crate::serde_json::Map::new();
136    obj.insert(
137        "version".to_string(),
138        JsonValue::Number(chosen_version as f64),
139    );
140    obj.insert(
141        "auth".to_string(),
142        SerializedJsonField::tainted(chosen_auth),
143    );
144    obj.insert(
145        "features".to_string(),
146        JsonValue::Number(server_features as f64),
147    );
148    let server_field = format!("reddb/{}", env!("CARGO_PKG_VERSION"));
149    obj.insert(
150        "server".to_string(),
151        SerializedJsonField::tainted(&server_field),
152    );
153    if let Some(topo) = topology {
154        obj.insert(
155            "topology".to_string(),
156            SerializedJsonField::tainted(&reddb_wire::topology::encode_topology_for_hello_ack(
157                topo,
158            )),
159        );
160    }
161    JsonValue::Object(obj).to_string_compact().into_bytes()
162}
163
/// Choose the auth method for this connection from the client's offer.
///
/// Candidates are tried in the server's preference order, and the first
/// one the client also listed wins. `None` means there is no usable
/// overlap.
///
/// * With an auth backend (`server_anon_ok == false`), the order is
///   `scram-sha-256` (no plaintext secret on the wire), then `oauth-jwt`
///   (federated), then `bearer` (session token / API key). `anonymous` is
///   never chosen in this mode, even when the client offers it.
/// * Without one (`server_anon_ok == true`), `anonymous` goes first so the
///   handshake completes without an AuthStore lookup; bearer validation
///   would fail anyway. The other methods follow in the order above.
pub fn pick_auth_method(client_methods: &[String], server_anon_ok: bool) -> Option<&'static str> {
    const AUTHENTICATED_ORDER: [&str; 3] = ["scram-sha-256", "oauth-jwt", "bearer"];

    let offered = |name: &str| client_methods.iter().any(|m| m == name);

    if server_anon_ok && offered("anonymous") {
        return Some("anonymous");
    }
    AUTHENTICATED_ORDER
        .iter()
        .copied()
        .find(|&name| offered(name))
}
192
193/// Validate the AuthResponse payload for the chosen method.
194pub fn validate_auth_response(
195    method: &str,
196    payload: &[u8],
197    auth_store: Option<&AuthStore>,
198) -> AuthOutcome {
199    match method {
200        "anonymous" => {
201            // Only legitimate when auth is disabled. Caller already
202            // gated this in `pick_auth_method`; double-check here.
203            if let Some(store) = auth_store {
204                if store.is_enabled() {
205                    return AuthOutcome::Refused(
206                        "anonymous auth refused — server has auth enabled".into(),
207                    );
208                }
209            }
210            AuthOutcome::Authenticated {
211                username: "anonymous".to_string(),
212                role: Role::Read,
213                session_id: new_session_id(),
214            }
215        }
216        "bearer" => {
217            let token = parse_bearer_response(payload).unwrap_or_default();
218            let Some(store) = auth_store else {
219                return AuthOutcome::Refused(
220                    "bearer auth refused — server has no auth store configured".into(),
221                );
222            };
223            match store.validate_token(&token) {
224                Some((username, role)) => AuthOutcome::Authenticated {
225                    username,
226                    role,
227                    session_id: new_session_id(),
228                },
229                None => AuthOutcome::Refused("bearer token invalid".into()),
230            }
231        }
232        "scram-sha-256" => AuthOutcome::Refused(
233            "scram-sha-256 must be driven through perform_scram_handshake — \
234             the 1-RTT validate_auth_response path doesn't apply"
235                .to_string(),
236        ),
237        "oauth-jwt" => {
238            // The OAuthValidator handle is expected via the
239            // RedWireConfig.oauth slot — plumbing happens in
240            // session::handle_session. When called here without
241            // it (e.g. test paths that don't set the handle),
242            // the v2 handshake refuses cleanly.
243            AuthOutcome::Refused(
244                "oauth-jwt requires RedWireConfig.oauth to be set. Pass an \
245                 OAuthValidator with the issuer + JWKS configured."
246                    .to_string(),
247            )
248        }
249        other => AuthOutcome::Refused(format!("auth method '{other}' is not supported in v2.1")),
250    }
251}
252
253fn parse_bearer_response(payload: &[u8]) -> Option<String> {
254    let v: JsonValue = serde_json::from_slice(payload).ok()?;
255    let token = v.as_object()?.get("token")?.as_str()?;
256    Some(token.to_string())
257}
258
259/// Build the AuthOk payload the server sends after a successful
260/// auth.
261pub fn build_auth_ok(
262    session_id: &str,
263    username: &str,
264    role: Role,
265    server_features: u32,
266) -> Vec<u8> {
267    use crate::json_field::SerializedJsonField;
268    // `username` is caller-influenced (the client claimed it during
269    // bearer / SCRAM); `session_id` is server-issued but routed
270    // through the guard so the discipline is uniform. ADR 0010 §3 / #178.
271    let mut obj = crate::serde_json::Map::new();
272    obj.insert(
273        "session_id".to_string(),
274        SerializedJsonField::tainted(session_id),
275    );
276    obj.insert(
277        "username".to_string(),
278        SerializedJsonField::tainted(username),
279    );
280    let role_str = role.to_string();
281    obj.insert("role".to_string(), SerializedJsonField::tainted(&role_str));
282    obj.insert(
283        "features".to_string(),
284        JsonValue::Number(server_features as f64),
285    );
286    JsonValue::Object(obj).to_string_compact().into_bytes()
287}
288
289pub fn build_auth_fail(reason: &str) -> Vec<u8> {
290    use crate::json_field::SerializedJsonField;
291    // `reason` is composed from validator output that may include
292    // user-controlled fragments (e.g. token text, JWT claim names);
293    // wire it through the guard. ADR 0010 §3 / #178.
294    let mut obj = crate::serde_json::Map::new();
295    obj.insert("reason".to_string(), SerializedJsonField::tainted(reason));
296    JsonValue::Object(obj).to_string_compact().into_bytes()
297}
298
/// Parse a SCRAM client-first-message (RFC 5802).
///
/// Accepted shape: `n,,n=<saslname>,r=<client_nonce>`. This is the GS2
/// header `n,,` (no channel binding, no authzid) followed by the bare
/// message. Unknown attributes are ignored, with one exception: the
/// reserved mandatory-extension attribute `m=` is refused, because
/// RFC 5802 requires its presence to fail authentication.
///
/// Returns `(username, client_nonce, bare_message)`:
/// - `username` has the saslname escapes decoded (`=2C` → `,`, `=3D` → `=`).
/// - `bare_message` is returned verbatim, because it is part of the
///   AuthMessage both sides sign.
///
/// # Errors
///
/// Fails on any of the following:
/// - non-UTF-8 input;
/// - a GS2 header other than `n,,`;
/// - an `m=` attribute;
/// - a missing or empty `n=` or `r=`;
/// - an `=` in the username that is not `=2C` or `=3D`.
pub fn parse_scram_client_first(payload: &[u8]) -> Result<(String, String, String), String> {
    let s = std::str::from_utf8(payload).map_err(|_| "client-first not UTF-8".to_string())?;
    // Strip the GS2 header. v2.1 only accepts `n,,` (explicit
    // no-channel-binding); `y,,`, `p=...,` and any authzid are rejected.
    let bare = s
        .strip_prefix("n,,")
        .ok_or_else(|| "client-first must start with 'n,,' (no channel binding)".to_string())?;
    let mut user = None;
    let mut nonce = None;
    for part in bare.split(',') {
        if let Some(v) = part.strip_prefix("n=") {
            user = Some(v);
        } else if let Some(v) = part.strip_prefix("r=") {
            nonce = Some(v);
        } else if part.starts_with("m=") {
            return Err(
                "client-first carries an unsupported mandatory extension (m=)".to_string(),
            );
        }
    }
    let raw_user = user.ok_or_else(|| "missing n=<user>".to_string())?;
    let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
    if raw_user.is_empty() {
        return Err("empty n=<user>".to_string());
    }
    if nonce.is_empty() {
        return Err("empty r=<nonce>".to_string());
    }
    let user = decode_scram_saslname(raw_user)?;
    Ok((user, nonce.to_string(), bare.to_string()))
}

/// Decode an RFC 5802 `saslname`. `=2C` stands for `,` and `=3D` for `=`;
/// any other `=` is malformed.
fn decode_scram_saslname(raw: &str) -> Result<String, String> {
    let mut out = String::with_capacity(raw.len());
    let mut rest = raw;
    while let Some(pos) = rest.find('=') {
        out.push_str(&rest[..pos]);
        let escaped = &rest[pos..];
        let ch = if escaped.starts_with("=2C") {
            ','
        } else if escaped.starts_with("=3D") {
            '='
        } else {
            return Err(
                "invalid '=' escape in n=<user> (only =2C and =3D are allowed)".to_string(),
            );
        };
        out.push(ch);
        // Both escapes are three ASCII bytes, so this slice is on a char boundary.
        rest = &escaped[3..];
    }
    out.push_str(rest);
    Ok(out)
}
322
323/// Build the SCRAM server-first-message. Sent in `AuthRequest`.
324/// Format: `r=<client_nonce><server_nonce>,s=<salt_b64>,i=<iter>`.
325pub fn build_scram_server_first(
326    client_nonce: &str,
327    server_nonce: &str,
328    salt: &[u8],
329    iter: u32,
330) -> String {
331    format!(
332        "r={client_nonce}{server_nonce},s={},i={iter}",
333        base64_std(salt)
334    )
335}
336
337/// Parse SCRAM client-final-message.
338/// Format: `c=<channel_binding_b64>,r=<combined_nonce>,p=<proof_b64>`.
339pub fn parse_scram_client_final(payload: &[u8]) -> Result<(String, Vec<u8>, String), String> {
340    let s = std::str::from_utf8(payload).map_err(|_| "client-final not UTF-8".to_string())?;
341    let mut channel_binding = None;
342    let mut nonce = None;
343    let mut proof_b64 = None;
344    for part in s.split(',') {
345        if let Some(v) = part.strip_prefix("c=") {
346            channel_binding = Some(v.to_string());
347        } else if let Some(v) = part.strip_prefix("r=") {
348            nonce = Some(v.to_string());
349        } else if let Some(v) = part.strip_prefix("p=") {
350            proof_b64 = Some(v.to_string());
351        }
352    }
353    let channel_binding =
354        channel_binding.ok_or_else(|| "missing c=<channel-binding>".to_string())?;
355    let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
356    let proof_b64 = proof_b64.ok_or_else(|| "missing p=<proof>".to_string())?;
357    let proof = base64_std_decode(&proof_b64)
358        .ok_or_else(|| "client proof is not valid base64".to_string())?;
359    // c=biws is base64("n,,") — the canonical no-channel-binding GS2 header.
360    if channel_binding != "biws" {
361        return Err(format!(
362            "channel binding must be 'biws' (n,,), got '{channel_binding}'"
363        ));
364    }
365    let no_proof = format!("c={channel_binding},r={nonce}");
366    Ok((nonce, proof, no_proof))
367}
368
369/// Build the AuthOk payload for a successful SCRAM completion.
370/// Carries the server signature so the client can verify the
371/// server also knew the verifier.
372pub fn build_scram_auth_ok(
373    session_id: &str,
374    username: &str,
375    role: Role,
376    server_features: u32,
377    server_signature: &[u8],
378) -> Vec<u8> {
379    let mut obj = crate::serde_json::Map::new();
380    obj.insert(
381        "session_id".to_string(),
382        JsonValue::String(session_id.to_string()),
383    );
384    obj.insert(
385        "username".to_string(),
386        JsonValue::String(username.to_string()),
387    );
388    obj.insert("role".to_string(), JsonValue::String(role.to_string()));
389    obj.insert(
390        "features".to_string(),
391        JsonValue::Number(server_features as f64),
392    );
393    obj.insert(
394        "v".to_string(),
395        JsonValue::String(base64_std(server_signature)),
396    );
397    JsonValue::Object(obj).to_string_compact().into_bytes()
398}
399
400/// Generate a 24-byte server nonce, base64-encoded. Cryptographic
401/// randomness sourced from the engine's existing `random_bytes`
402/// helper so SCRAM doesn't introduce a new RNG path.
403pub fn new_server_nonce() -> String {
404    base64_std(&crate::auth::store::random_bytes(18))
405}
406
407pub(crate) fn new_session_id_for_scram() -> String {
408    new_session_id()
409}
410
// ---------------------------------------------------------------
// Tiny base64 — RFC 4648 standard alphabet. Used for SCRAM payloads,
// the AuthOk signature and (via `base64_url_decode`) JWT segments.
// Traffic is low-frequency, so a hand-rolled codec is fine and avoids
// pulling in another crate.
// ---------------------------------------------------------------

const B64_ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encode `input` as padded RFC 4648 §4 base64.
pub fn base64_std(input: &[u8]) -> String {
    let mut out = String::with_capacity(input.len().div_ceil(3) * 4);
    let chunks = input.chunks_exact(3);
    let rem = chunks.remainder();
    for c in chunks {
        let n = ((c[0] as u32) << 16) | ((c[1] as u32) << 8) | (c[2] as u32);
        out.push(B64_ALPHA[((n >> 18) & 0x3F) as usize] as char);
        out.push(B64_ALPHA[((n >> 12) & 0x3F) as usize] as char);
        out.push(B64_ALPHA[((n >> 6) & 0x3F) as usize] as char);
        out.push(B64_ALPHA[(n & 0x3F) as usize] as char);
    }
    match rem {
        [a] => {
            let n = (*a as u32) << 16;
            out.push(B64_ALPHA[((n >> 18) & 0x3F) as usize] as char);
            out.push(B64_ALPHA[((n >> 12) & 0x3F) as usize] as char);
            out.push('=');
            out.push('=');
        }
        [a, b] => {
            let n = ((*a as u32) << 16) | ((*b as u32) << 8);
            out.push(B64_ALPHA[((n >> 18) & 0x3F) as usize] as char);
            out.push(B64_ALPHA[((n >> 12) & 0x3F) as usize] as char);
            out.push(B64_ALPHA[((n >> 6) & 0x3F) as usize] as char);
            out.push('=');
        }
        _ => {}
    }
    out
}

/// Decode RFC 4648 §4 base64.
///
/// Padding is optional. When present, it must be exactly the amount the
/// data length calls for: `==` after 2 trailing symbols, `=` after 3.
///
/// Returns `None` in any of these cases:
/// - characters outside the alphabet, including an `=` before the end;
/// - excess or mismatched padding;
/// - a single leftover symbol. It carries only 6 bits, so no valid
///   encoding ends that way.
pub fn base64_std_decode(input: &str) -> Option<Vec<u8>> {
    let trimmed = input.trim_end_matches('=');
    let pad = input.len() - trimmed.len();
    let padding_ok = match trimmed.len() % 4 {
        0 => pad == 0,
        2 => pad == 0 || pad == 2,
        3 => pad == 0 || pad == 1,
        // One leftover symbol can never form a whole byte.
        _ => false,
    };
    if !padding_ok {
        return None;
    }
    let mut out = Vec::with_capacity(trimmed.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u8;
    for ch in trimmed.bytes() {
        let v: u32 = match ch {
            b'A'..=b'Z' => (ch - b'A') as u32,
            b'a'..=b'z' => (ch - b'a' + 26) as u32,
            b'0'..=b'9' => (ch - b'0' + 52) as u32,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        // Bits shifted out of the top of `buf` are already emitted; only
        // the low `bits` bits are pending.
        buf = (buf << 6) | v;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xFF) as u8);
        }
    }
    Some(out)
}
473
474/// Parse a compact-serialized JWT into a `DecodedJwt`. RFC 7519
475/// shape: `<base64url(header)>.<base64url(payload)>.<base64url(signature)>`.
476/// The validator does the heavy lifting (signature, claims,
477/// expiry); this function just splits + decodes.
478pub fn parse_jwt(token: &str) -> Result<crate::auth::oauth::DecodedJwt, String> {
479    let parts: Vec<&str> = token.split('.').collect();
480    if parts.len() != 3 {
481        return Err(format!(
482            "expected 3 dot-separated parts, got {}",
483            parts.len()
484        ));
485    }
486    let header_bytes =
487        base64_url_decode(parts[0]).ok_or_else(|| "header is not valid base64url".to_string())?;
488    let payload_bytes =
489        base64_url_decode(parts[1]).ok_or_else(|| "payload is not valid base64url".to_string())?;
490    let signature = base64_url_decode(parts[2])
491        .ok_or_else(|| "signature is not valid base64url".to_string())?;
492
493    let header_json: JsonValue =
494        serde_json::from_slice(&header_bytes).map_err(|e| format!("header JSON: {e}"))?;
495    let payload_json: JsonValue =
496        serde_json::from_slice(&payload_bytes).map_err(|e| format!("payload JSON: {e}"))?;
497
498    let header = jwt_header_from(&header_json)?;
499    let claims = jwt_claims_from(&payload_json);
500
501    let signing_input = format!("{}.{}", parts[0], parts[1]).into_bytes();
502
503    Ok(crate::auth::oauth::DecodedJwt {
504        header,
505        claims,
506        signing_input,
507        signature,
508    })
509}
510
511fn jwt_header_from(v: &JsonValue) -> Result<crate::auth::oauth::JwtHeader, String> {
512    let obj = v
513        .as_object()
514        .ok_or_else(|| "JWT header must be a JSON object".to_string())?;
515    let alg = obj
516        .get("alg")
517        .and_then(|x| x.as_str())
518        .ok_or_else(|| "JWT header missing 'alg'".to_string())?
519        .to_string();
520    let kid = obj.get("kid").and_then(|x| x.as_str()).map(String::from);
521    Ok(crate::auth::oauth::JwtHeader { alg, kid })
522}
523
524fn jwt_claims_from(v: &JsonValue) -> crate::auth::oauth::JwtClaims {
525    let obj = v.as_object().cloned().unwrap_or_default();
526    let mut claims = crate::auth::oauth::JwtClaims::default();
527    if let Some(s) = obj.get("iss").and_then(|x| x.as_str()) {
528        claims.iss = Some(s.to_string());
529    }
530    if let Some(s) = obj.get("sub").and_then(|x| x.as_str()) {
531        claims.sub = Some(s.to_string());
532    }
533    if let Some(s) = obj.get("aud").and_then(|x| x.as_str()) {
534        claims.aud = vec![s.to_string()];
535    } else if let Some(arr) = obj.get("aud").and_then(|x| x.as_array()) {
536        claims.aud = arr
537            .iter()
538            .filter_map(|v| v.as_str().map(String::from))
539            .collect();
540    }
541    if let Some(n) = obj.get("exp").and_then(|x| x.as_f64()) {
542        claims.exp = Some(n as i64);
543    }
544    if let Some(n) = obj.get("nbf").and_then(|x| x.as_f64()) {
545        claims.nbf = Some(n as i64);
546    }
547    if let Some(n) = obj.get("iat").and_then(|x| x.as_f64()) {
548        claims.iat = Some(n as i64);
549    }
550    for (k, v) in obj.iter() {
551        if matches!(k.as_str(), "iss" | "sub" | "aud" | "exp" | "nbf" | "iat") {
552            continue;
553        }
554        if let Some(s) = v.as_str() {
555            claims.extra.insert(k.clone(), s.to_string());
556        }
557    }
558    claims
559}
560
561/// Validate a JWT through the supplied `OAuthValidator`. Returns
562/// `(username, role)` on success, or a refusal reason.
563pub fn validate_oauth_jwt(
564    validator: &crate::auth::oauth::OAuthValidator,
565    raw_token: &str,
566) -> Result<(String, Role), String> {
567    validate_oauth_jwt_full(validator, raw_token).map(|(_tenant, username, role)| (username, role))
568}
569
570/// Tenant-aware variant of [`validate_oauth_jwt`]. Returns
571/// `(tenant, username, role)` so the caller can mint a session pinned
572/// to the tenant carried by the configured `tenant_claim`.
573pub fn validate_oauth_jwt_full(
574    validator: &crate::auth::oauth::OAuthValidator,
575    raw_token: &str,
576) -> Result<(Option<String>, String, Role), String> {
577    let token = parse_jwt(raw_token).map_err(|e| format!("decode JWT: {e}"))?;
578    let now = std::time::SystemTime::now()
579        .duration_since(std::time::UNIX_EPOCH)
580        .map(|d| d.as_secs() as i64)
581        .unwrap_or(0);
582    // sub-claim mode: the JWT subject IS the RedDB username. Roles map
583    // from a `role` custom claim; tenant from the configured tenant
584    // claim (default "tenant"). The lookup closure mirrors the same
585    // claims so `map_to_existing_users=false` deployments still get a
586    // tenant-tagged identity.
587    let identity = validator
588        .validate(&token, now, |sub| {
589            Some(crate::auth::User {
590                username: sub.to_string(),
591                tenant_id: token.claims.extra.get("tenant").cloned(),
592                password_hash: String::new(),
593                scram_verifier: None,
594                role: token
595                    .claims
596                    .extra
597                    .get("role")
598                    .and_then(|s| Role::from_str(s))
599                    .unwrap_or(Role::Read),
600                api_keys: Vec::new(),
601                created_at: 0,
602                updated_at: 0,
603                enabled: true,
604            })
605        })
606        .map_err(|e| format!("{e}"))?;
607    Ok((identity.tenant, identity.username, identity.role))
608}
609
610fn base64_url_decode(input: &str) -> Option<Vec<u8>> {
611    // base64url = '+' → '-', '/' → '_', stripped padding.
612    let mut s = String::with_capacity(input.len() + 4);
613    for ch in input.chars() {
614        match ch {
615            '-' => s.push('+'),
616            '_' => s.push('/'),
617            _ => s.push(ch),
618        }
619    }
620    while !s.len().is_multiple_of(4) {
621        s.push('=');
622    }
623    base64_std_decode(&s)
624}
625
626/// Generate a session id. Format: `rwsess-<unix_micros>-<rand>`.
627/// Not cryptographically random; the security boundary is the
628/// auth method, not session-id unguessability.
629fn new_session_id() -> String {
630    let now_us = std::time::SystemTime::now()
631        .duration_since(std::time::UNIX_EPOCH)
632        .map(|d| d.as_micros())
633        .unwrap_or(0);
634    let rand = crate::utils::now_unix_nanos() & 0xFFFF_FFFF;
635    format!("rwsess-{now_us}-{rand:08x}")
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned method list in the shape `pick_auth_method` expects.
    fn offered(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    #[test]
    fn hello_round_trip() {
        let raw = br#"{"versions":[1],"auth_methods":["bearer","anonymous"],"features":3,"client_name":"reddb-rs/0.1"}"#;
        let hello = Hello::from_payload(raw).expect("well-formed Hello must decode");
        assert_eq!(hello.versions, vec![1]);
        assert_eq!(hello.auth_methods, vec!["bearer", "anonymous"]);
        assert_eq!(hello.features, 3);
        assert_eq!(hello.client_name.as_deref(), Some("reddb-rs/0.1"));
    }

    #[test]
    fn hello_rejects_empty_methods() {
        let raw = br#"{"versions":[1],"auth_methods":[]}"#;
        let result = Hello::from_payload(raw);
        assert!(result.is_err(), "empty auth_methods[] must be rejected");
    }

    #[test]
    fn pick_auth_prefers_anonymous_when_server_has_no_auth_store() {
        // With no auth store, bearer can never validate, so the picker
        // should settle on anonymous and let the handshake complete.
        let methods = offered(&["anonymous", "bearer"]);
        assert_eq!(pick_auth_method(&methods, true), Some("anonymous"));
    }

    #[test]
    fn pick_auth_picks_bearer_when_anonymous_blocked() {
        // Auth is enabled, so anonymous is off the table and bearer wins.
        let methods = offered(&["anonymous", "bearer"]);
        assert_eq!(pick_auth_method(&methods, false), Some("bearer"));
    }

    #[test]
    fn pick_auth_skips_anonymous_when_server_blocks_it() {
        let methods = offered(&["anonymous"]);
        assert_eq!(pick_auth_method(&methods, false), None);
    }

    #[test]
    fn pick_auth_returns_none_when_nothing_overlaps() {
        let methods = offered(&["kerberos", "future-method"]);
        assert_eq!(pick_auth_method(&methods, true), None);
    }

    #[test]
    fn anonymous_validates_only_when_store_disabled() {
        let outcome = validate_auth_response("anonymous", &[], None);
        assert!(
            matches!(outcome, AuthOutcome::Authenticated { .. }),
            "got {outcome:?}"
        );
    }

    #[test]
    fn bearer_without_store_refuses() {
        let outcome = validate_auth_response("bearer", br#"{"token":"x"}"#, None);
        assert!(matches!(outcome, AuthOutcome::Refused(_)), "got {outcome:?}");
    }

    #[test]
    fn hello_ack_omits_topology_field_when_caller_passes_none() {
        // Backwards-compat: callers that have not adopted the advertiser
        // pass `None`, and the envelope keeps its pre-#167 shape.
        let ack = build_hello_ack(1, "bearer", 0, None);
        let text = std::str::from_utf8(&ack).expect("HelloAck is UTF-8 JSON");
        assert!(!text.contains("\"topology\""));
    }

    #[test]
    fn hello_ack_embeds_topology_field_when_caller_passes_payload() {
        // Issue #167: the canonical topology bytes ride base64-wrapped under
        // the `topology` key. Decoding them back through the wire decoder
        // pins byte-for-byte agreement with the canonical encoder (#166).
        let topo = reddb_wire::topology::Topology {
            epoch: 17,
            primary: reddb_wire::topology::Endpoint {
                addr: "primary:5050".into(),
                region: "us-east-1".into(),
            },
            replicas: Vec::new(),
        };
        let ack = build_hello_ack(1, "bearer", 0, Some(&topo));
        let text = std::str::from_utf8(&ack).expect("HelloAck is UTF-8 JSON");
        assert!(text.contains("\"topology\""), "missing topology key in {text}");

        let parsed: JsonValue =
            crate::serde_json::from_slice(&ack).expect("HelloAck parses as JSON");
        let encoded = parsed
            .as_object()
            .and_then(|fields| fields.get("topology"))
            .and_then(|field| field.as_str())
            .expect("topology key must be present and a string");
        let decoded =
            reddb_wire::topology::decode_topology_from_hello_ack(encoded).expect("decode");
        assert_eq!(decoded.expect("v1 known"), topo);
    }
}