Skip to main content

reddb_wire/redwire/
handshake.rs

1//! RedWire handshake payload contracts.
2//!
3//! Authentication policy and credential validation belong in the
4//! server. This module owns only the wire-visible JSON shapes used by
5//! Hello, HelloAck, AuthResponse, AuthOk, and AuthFail.
6
7use serde_json::Value as JsonValue;
8use std::fmt;
9
10use super::{BuildError, Frame, FrameBuilder, MessageKind, MAX_KNOWN_MINOR_VERSION};
11
/// Authentication method names a RedWire v2.1 peer is able to negotiate.
pub const SUPPORTED_METHODS: &[&str] = &[
    "bearer",
    "anonymous",
    "scram-sha-256",
    "oauth-jwt",
];
14
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct Hello {
17    pub versions: Vec<u8>,
18    pub auth_methods: Vec<String>,
19    pub features: u32,
20    pub client_name: Option<String>,
21}
22
23impl Hello {
24    pub fn to_payload(&self) -> Vec<u8> {
25        build_hello_payload(
26            &self.versions,
27            self.auth_methods.iter().map(String::as_str),
28            self.features,
29            self.client_name.as_deref(),
30        )
31    }
32
33    pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
34        let v: JsonValue =
35            serde_json::from_slice(bytes).map_err(|e| format!("Hello: invalid JSON: {e}"))?;
36        let obj = match v {
37            JsonValue::Object(o) => o,
38            _ => return Err("Hello: payload must be a JSON object".into()),
39        };
40        let versions: Vec<u8> = obj
41            .get("versions")
42            .and_then(|v| v.as_array())
43            .map(|arr| {
44                arr.iter()
45                    .filter_map(|n| n.as_u64().map(|u| u as u8))
46                    .collect()
47            })
48            .unwrap_or_default();
49        let auth_methods: Vec<String> = obj
50            .get("auth_methods")
51            .and_then(|v| v.as_array())
52            .map(|arr| {
53                arr.iter()
54                    .filter_map(|s| s.as_str().map(String::from))
55                    .collect()
56            })
57            .unwrap_or_default();
58        let features = obj
59            .get("features")
60            .and_then(|v| v.as_u64())
61            .map(|u| u as u32)
62            .unwrap_or(0);
63        let client_name = obj
64            .get("client_name")
65            .and_then(|v| v.as_str())
66            .map(String::from);
67        if versions.is_empty() {
68            return Err("Hello: versions[] is empty".into());
69        }
70        if auth_methods.is_empty() {
71            return Err("Hello: auth_methods[] is empty".into());
72        }
73        Ok(Self {
74            versions,
75            auth_methods,
76            features,
77            client_name,
78        })
79    }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct HelloAck {
84    pub version: u8,
85    pub auth: String,
86    pub features: u32,
87    pub server: Option<String>,
88    pub topology: Option<String>,
89}
90
91impl HelloAck {
92    pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
93        let obj = object_from_payload("HelloAck", bytes)?;
94        let version = required_u8(&obj, "HelloAck", "version")?;
95        let auth = required_string(&obj, "HelloAck", "auth")?;
96        let features = optional_u32(&obj, "features").unwrap_or(0);
97        let server = optional_string(&obj, "server");
98        let topology = optional_string(&obj, "topology");
99        Ok(Self {
100            version,
101            auth,
102            features,
103            server,
104            topology,
105        })
106    }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq)]
110pub struct AuthOk {
111    pub session_id: String,
112    pub username: Option<String>,
113    pub role: Option<String>,
114    pub features: u32,
115    pub server_signature: Option<String>,
116}
117
118impl AuthOk {
119    pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
120        let obj = object_from_payload("AuthOk", bytes)?;
121        let session_id = required_string(&obj, "AuthOk", "session_id")?;
122        let username = optional_string(&obj, "username");
123        let role = optional_string(&obj, "role");
124        let features = optional_u32(&obj, "features").unwrap_or(0);
125        let server_signature = optional_string(&obj, "v");
126        Ok(Self {
127            session_id,
128            username,
129            role,
130            features,
131            server_signature,
132        })
133    }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub struct AuthFail {
138    pub reason: String,
139}
140
141impl AuthFail {
142    pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
143        let obj = object_from_payload("AuthFail", bytes)?;
144        Ok(Self {
145            reason: required_string(&obj, "AuthFail", "reason")?,
146        })
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq)]
151pub struct AuthResponseKindError {
152    pub expected: &'static str,
153    pub actual: MessageKind,
154}
155
156impl fmt::Display for AuthResponseKindError {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        write!(f, "expected {}", self.expected)
159    }
160}
161
162impl std::error::Error for AuthResponseKindError {}
163
164pub fn build_hello_payload<'a, I>(
165    versions: &[u8],
166    auth_methods: I,
167    features: u32,
168    client_name: Option<&str>,
169) -> Vec<u8>
170where
171    I: IntoIterator<Item = &'a str>,
172{
173    let mut obj = serde_json::Map::new();
174    obj.insert(
175        "versions".to_string(),
176        JsonValue::Array(
177            versions
178                .iter()
179                .map(|version| JsonValue::Number((*version).into()))
180                .collect(),
181        ),
182    );
183    obj.insert(
184        "auth_methods".to_string(),
185        JsonValue::Array(
186            auth_methods
187                .into_iter()
188                .map(|method| JsonValue::String(method.to_string()))
189                .collect(),
190        ),
191    );
192    obj.insert("features".to_string(), JsonValue::Number(features.into()));
193    if let Some(name) = client_name {
194        obj.insert(
195            "client_name".to_string(),
196            JsonValue::String(name.to_string()),
197        );
198    }
199    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
200}
201
202pub fn build_client_hello_payload<'a, I>(
203    auth_methods: I,
204    features: u32,
205    client_name: Option<&str>,
206) -> Vec<u8>
207where
208    I: IntoIterator<Item = &'a str>,
209{
210    build_hello_payload(
211        &[MAX_KNOWN_MINOR_VERSION],
212        auth_methods,
213        features,
214        client_name,
215    )
216}
217
218pub fn build_client_hello_frame<'a, I>(
219    correlation_id: u64,
220    auth_methods: I,
221    features: u32,
222    client_name: Option<&str>,
223) -> Result<Frame, BuildError>
224where
225    I: IntoIterator<Item = &'a str>,
226{
227    FrameBuilder::request(correlation_id)
228        .kind(MessageKind::Hello)
229        .payload(build_client_hello_payload(
230            auth_methods,
231            features,
232            client_name,
233        ))
234        .build()
235}
236
237pub fn choose_hello_minor_version(client_versions: &[u8]) -> Option<u8> {
238    client_versions
239        .iter()
240        .copied()
241        .filter(|version| *version > 0 && *version <= MAX_KNOWN_MINOR_VERSION)
242        .max()
243}
244
245pub fn build_hello_ack(
246    chosen_version: u8,
247    chosen_auth: &str,
248    server_features: u32,
249    topology: Option<&crate::topology::Topology>,
250) -> Vec<u8> {
251    let mut obj = serde_json::Map::new();
252    obj.insert(
253        "version".to_string(),
254        JsonValue::Number(chosen_version.into()),
255    );
256    obj.insert(
257        "auth".to_string(),
258        JsonValue::String(chosen_auth.to_string()),
259    );
260    obj.insert(
261        "features".to_string(),
262        JsonValue::Number(server_features.into()),
263    );
264    obj.insert(
265        "server".to_string(),
266        JsonValue::String(format!("reddb/{}", env!("CARGO_PKG_VERSION"))),
267    );
268    if let Some(topo) = topology {
269        obj.insert(
270            "topology".to_string(),
271            JsonValue::String(crate::topology::encode_topology_for_hello_ack(topo)),
272        );
273    }
274    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
275}
276
277pub fn build_hello_ack_frame(
278    correlation_id: u64,
279    chosen_version: u8,
280    chosen_auth: &str,
281    server_features: u32,
282    topology: Option<&crate::topology::Topology>,
283) -> Result<Frame, BuildError> {
284    FrameBuilder::reply_to(correlation_id)
285        .kind(MessageKind::HelloAck)
286        .payload(build_hello_ack(
287            chosen_version,
288            chosen_auth,
289            server_features,
290            topology,
291        ))
292        .build()
293}
294
/// Anonymous authentication sends an empty `AuthResponse` payload.
pub fn build_auth_response_anonymous_payload() -> Vec<u8> {
    Vec::default()
}
298
299pub fn build_auth_response_bearer_payload(token: &str) -> Vec<u8> {
300    let mut obj = serde_json::Map::new();
301    obj.insert("token".to_string(), JsonValue::String(token.to_string()));
302    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
303}
304
305pub fn parse_auth_response_bearer_token(payload: &[u8]) -> Result<String, String> {
306    let obj = object_from_payload("AuthResponse", payload)?;
307    required_string(&obj, "AuthResponse", "token")
308}
309
310pub fn build_auth_response_oauth_jwt_payload(jwt: &str) -> Vec<u8> {
311    let mut obj = serde_json::Map::new();
312    obj.insert("jwt".to_string(), JsonValue::String(jwt.to_string()));
313    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
314}
315
316pub fn build_auth_response_frame(
317    correlation_id: u64,
318    payload: Vec<u8>,
319) -> Result<Frame, BuildError> {
320    FrameBuilder::request(correlation_id)
321        .kind(MessageKind::AuthResponse)
322        .payload(payload)
323        .build()
324}
325
326pub fn parse_auth_response_oauth_jwt(payload: &[u8]) -> Result<String, String> {
327    let obj = object_from_payload("AuthResponse", payload)?;
328    required_string(&obj, "AuthResponse", "jwt")
329}
330
331pub fn expect_auth_response_payload<'a>(
332    kind: MessageKind,
333    payload: &'a [u8],
334    expected: &'static str,
335) -> Result<&'a [u8], AuthResponseKindError> {
336    if kind == MessageKind::AuthResponse {
337        Ok(payload)
338    } else {
339        Err(AuthResponseKindError {
340            expected,
341            actual: kind,
342        })
343    }
344}
345
346pub fn build_auth_ok_payload(
347    session_id: &str,
348    username: &str,
349    role: &str,
350    server_features: u32,
351) -> Vec<u8> {
352    let mut obj = serde_json::Map::new();
353    obj.insert(
354        "session_id".to_string(),
355        JsonValue::String(session_id.to_string()),
356    );
357    obj.insert(
358        "username".to_string(),
359        JsonValue::String(username.to_string()),
360    );
361    obj.insert("role".to_string(), JsonValue::String(role.to_string()));
362    obj.insert(
363        "features".to_string(),
364        JsonValue::Number(server_features.into()),
365    );
366    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
367}
368
369pub fn build_auth_ok_frame_from_payload(
370    correlation_id: u64,
371    payload: Vec<u8>,
372) -> Result<Frame, BuildError> {
373    FrameBuilder::reply_to(correlation_id)
374        .kind(MessageKind::AuthOk)
375        .payload(payload)
376        .build()
377}
378
379pub fn build_auth_fail_frame(correlation_id: u64, reason: &str) -> Result<Frame, BuildError> {
380    FrameBuilder::reply_to(correlation_id)
381        .kind(MessageKind::AuthFail)
382        .payload(build_auth_fail_payload(reason))
383        .build()
384}
385
386pub fn build_scram_auth_ok_payload(
387    session_id: &str,
388    username: &str,
389    role: &str,
390    server_features: u32,
391    server_signature: &[u8],
392) -> Vec<u8> {
393    let mut obj = serde_json::Map::new();
394    obj.insert(
395        "session_id".to_string(),
396        JsonValue::String(session_id.to_string()),
397    );
398    obj.insert(
399        "username".to_string(),
400        JsonValue::String(username.to_string()),
401    );
402    obj.insert("role".to_string(), JsonValue::String(role.to_string()));
403    obj.insert(
404        "features".to_string(),
405        JsonValue::Number(server_features.into()),
406    );
407    obj.insert(
408        "v".to_string(),
409        JsonValue::String(base64_std(server_signature)),
410    );
411    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
412}
413
414pub fn build_auth_fail_payload(reason: &str) -> Vec<u8> {
415    let mut obj = serde_json::Map::new();
416    obj.insert("reason".to_string(), JsonValue::String(reason.to_string()));
417    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
418}
419
/// Parse a SCRAM client-first-message.
///
/// Format: `n,,n=<user>,r=<client_nonce>` (no channel binding, no authzid).
/// Returns `(username, client_nonce, bare_message)`:
/// - `username` has the RFC 5802 saslname escapes decoded (`=2C` becomes
///   `,` and `=3D` becomes `=`).
/// - `bare_message` is the raw text after the `n,,` header, left undecoded,
///   because it enters the SCRAM `AuthMessage` byte-for-byte.
///
/// # Errors
/// Fails when the payload:
/// - is not UTF-8;
/// - does not start with `n,,`;
/// - lacks `n=` or `r=`, or either of them is empty;
/// - has a `=` in the username that is not one of the two permitted escapes.
pub fn parse_scram_client_first(payload: &[u8]) -> Result<(String, String, String), String> {
    let s = std::str::from_utf8(payload).map_err(|_| "client-first not UTF-8".to_string())?;
    let bare = s
        .strip_prefix("n,,")
        .ok_or_else(|| "client-first must start with 'n,,' (no channel binding)".to_string())?;
    let mut user = None;
    let mut nonce = None;
    for part in bare.split(',') {
        if let Some(v) = part.strip_prefix("n=") {
            user = Some(v);
        } else if let Some(v) = part.strip_prefix("r=") {
            nonce = Some(v);
        }
    }
    let raw_user = user.ok_or_else(|| "missing n=<user>".to_string())?;
    let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
    // RFC 5802 grammar requires a non-empty saslname.
    if raw_user.is_empty() {
        return Err("empty n=<user>".to_string());
    }
    // An empty client nonce would contribute no client-side randomness.
    if nonce.is_empty() {
        return Err("empty r=<nonce>".to_string());
    }
    let user = decode_scram_saslname(raw_user)?;
    Ok((user, nonce.to_string(), bare.to_string()))
}

/// Decode RFC 5802 saslname escapes (`=2C` becomes `,`, `=3D` becomes `=`).
///
/// Any other use of `=` is rejected, as the RFC requires. The hex digits are
/// matched case-insensitively, following RFC 5234's rules for ABNF literals.
fn decode_scram_saslname(raw: &str) -> Result<String, String> {
    let mut out = String::with_capacity(raw.len());
    let mut rest = raw;
    while let Some(eq) = rest.find('=') {
        out.push_str(&rest[..eq]);
        // '=' is a single ASCII byte, so `eq + 1` is always a char boundary.
        let after = &rest[eq + 1..];
        let decoded = match after.get(..2) {
            Some(code) if code.eq_ignore_ascii_case("2C") => ',',
            Some(code) if code.eq_ignore_ascii_case("3D") => '=',
            _ => return Err("invalid '=' escape in n=<user>".to_string()),
        };
        out.push(decoded);
        rest = &after[2..];
    }
    out.push_str(rest);
    Ok(out)
}
442
443/// Build the SCRAM server-first-message.
444///
445/// Format: `r=<client_nonce><server_nonce>,s=<salt_b64>,i=<iter>`.
446pub fn build_scram_server_first(
447    client_nonce: &str,
448    server_nonce: &str,
449    salt: &[u8],
450    iter: u32,
451) -> String {
452    format!(
453        "r={client_nonce}{server_nonce},s={},i={iter}",
454        base64_std(salt)
455    )
456}
457
458/// Parse SCRAM client-final-message.
459///
460/// Format: `c=<channel_binding_b64>,r=<combined_nonce>,p=<proof_b64>`.
461pub fn parse_scram_client_final(payload: &[u8]) -> Result<(String, Vec<u8>, String), String> {
462    let s = std::str::from_utf8(payload).map_err(|_| "client-final not UTF-8".to_string())?;
463    let mut channel_binding = None;
464    let mut nonce = None;
465    let mut proof_b64 = None;
466    for part in s.split(',') {
467        if let Some(v) = part.strip_prefix("c=") {
468            channel_binding = Some(v.to_string());
469        } else if let Some(v) = part.strip_prefix("r=") {
470            nonce = Some(v.to_string());
471        } else if let Some(v) = part.strip_prefix("p=") {
472            proof_b64 = Some(v.to_string());
473        }
474    }
475    let channel_binding =
476        channel_binding.ok_or_else(|| "missing c=<channel-binding>".to_string())?;
477    let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
478    let proof_b64 = proof_b64.ok_or_else(|| "missing p=<proof>".to_string())?;
479    let proof = base64_std_decode(&proof_b64)
480        .ok_or_else(|| "client proof is not valid base64".to_string())?;
481    if channel_binding != "biws" {
482        return Err(format!(
483            "channel binding must be 'biws' (n,,), got '{channel_binding}'"
484        ));
485    }
486    let no_proof = format!("c={channel_binding},r={nonce}");
487    Ok((nonce, proof, no_proof))
488}
489
/// The RFC 4648 standard base64 alphabet.
const B64_ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encode `input` as standard base64 with `=` padding.
pub fn base64_std(input: &[u8]) -> String {
    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        let b0 = group[0] as u32;
        let b1 = group.get(1).copied().unwrap_or(0) as u32;
        let b2 = group.get(2).copied().unwrap_or(0) as u32;
        let triple = (b0 << 16) | (b1 << 8) | b2;
        // A group of 1, 2 or 3 bytes yields 2, 3 or 4 alphabet characters.
        // The remaining slots of the 4-character quantum are `=` padding.
        let emitted = group.len() + 1;
        for slot in 0..4usize {
            if slot < emitted {
                let index = (triple >> (18 - 6 * slot)) & 0x3F;
                encoded.push(B64_ALPHA[index as usize] as char);
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}
522
/// Decode standard-alphabet base64 (RFC 4648 §4).
///
/// Padding is optional. When present, it must be one or two `=` characters
/// and the padded input length must be a multiple of four. The function
/// returns `None` for:
/// - characters outside the alphabet, including `=` anywhere but the end;
/// - excess or misplaced padding;
/// - a length that leaves a single dangling 6-bit group, which cannot encode
///   a whole byte.
///
/// Unused low bits of the final group are not checked.
pub fn base64_std_decode(input: &str) -> Option<Vec<u8>> {
    let trimmed = input.trim_end_matches('=');
    let padding = input.len() - trimmed.len();
    if padding > 2 || (padding > 0 && input.len() % 4 != 0) {
        return None;
    }
    // 4n+1 characters carry 6 surplus bits, which is never valid base64.
    if trimmed.len() % 4 == 1 {
        return None;
    }
    let mut out = Vec::with_capacity(trimmed.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u8;
    for ch in trimmed.bytes() {
        let v: u32 = match ch {
            b'A'..=b'Z' => (ch - b'A') as u32,
            b'a'..=b'z' => (ch - b'a' + 26) as u32,
            b'0'..=b'9' => (ch - b'0' + 52) as u32,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        // Only the low 14 bits of `buf` are ever read, so high bits may be
        // shifted out harmlessly.
        buf = (buf << 6) | v;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xFF) as u8);
        }
    }
    Some(out)
}
546
547fn object_from_payload(
548    name: &str,
549    bytes: &[u8],
550) -> Result<serde_json::Map<String, JsonValue>, String> {
551    let v: JsonValue =
552        serde_json::from_slice(bytes).map_err(|e| format!("{name}: invalid JSON: {e}"))?;
553    match v {
554        JsonValue::Object(o) => Ok(o),
555        _ => Err(format!("{name}: payload must be a JSON object")),
556    }
557}
558
559fn required_string(
560    obj: &serde_json::Map<String, JsonValue>,
561    name: &str,
562    field: &str,
563) -> Result<String, String> {
564    obj.get(field)
565        .and_then(JsonValue::as_str)
566        .map(String::from)
567        .ok_or_else(|| format!("{name}: missing {field} string"))
568}
569
570fn optional_string(obj: &serde_json::Map<String, JsonValue>, field: &str) -> Option<String> {
571    obj.get(field).and_then(JsonValue::as_str).map(String::from)
572}
573
574fn required_u8(
575    obj: &serde_json::Map<String, JsonValue>,
576    name: &str,
577    field: &str,
578) -> Result<u8, String> {
579    let n = obj
580        .get(field)
581        .and_then(JsonValue::as_u64)
582        .ok_or_else(|| format!("{name}: missing {field} number"))?;
583    u8::try_from(n).map_err(|_| format!("{name}: {field} out of range for u8"))
584}
585
586fn optional_u32(obj: &serde_json::Map<String, JsonValue>, field: &str) -> Option<u32> {
587    obj.get(field)
588        .and_then(JsonValue::as_u64)
589        .and_then(|n| u32::try_from(n).ok())
590}
591
#[cfg(test)]
mod tests {
    // Pins the wire shapes produced and accepted by the handshake helpers:
    // key names, defaults, negotiation rules and the SCRAM text formats.
    use super::*;
    use crate::topology::{Endpoint, ReplicaInfo, Topology};

    // A well-formed client Hello decodes every field.
    #[test]
    fn hello_parses_client_payload() {
        let payload =
            br#"{"versions":[1],"auth_methods":["bearer"],"features":1,"client_name":"x"}"#;
        let hello = Hello::from_payload(payload).unwrap();
        assert_eq!(hello.versions, vec![1]);
        assert_eq!(hello.auth_methods, vec!["bearer"]);
        assert_eq!(hello.features, 1);
        assert_eq!(hello.client_name.as_deref(), Some("x"));
    }

    // build_hello_payload output round-trips through Hello::from_payload.
    #[test]
    fn hello_builds_client_payload() {
        let bytes = build_hello_payload(&[1], ["anonymous", "bearer"], 7, Some("client"));
        let hello = Hello::from_payload(&bytes).unwrap();
        assert_eq!(hello.versions, vec![1]);
        assert_eq!(hello.auth_methods, vec!["anonymous", "bearer"]);
        assert_eq!(hello.features, 7);
        assert_eq!(hello.client_name.as_deref(), Some("client"));
    }

    // The client helper advertises only MAX_KNOWN_MINOR_VERSION.
    #[test]
    fn client_hello_payload_uses_current_minor_version() {
        let bytes = build_client_hello_payload(["anonymous"], 0, Some("client"));
        let hello = Hello::from_payload(&bytes).unwrap();
        assert_eq!(hello.versions, vec![MAX_KNOWN_MINOR_VERSION]);
        assert_eq!(hello.auth_methods, vec!["anonymous"]);
        assert_eq!(hello.client_name.as_deref(), Some("client"));
    }

    // Negotiation ignores 0 and versions above MAX_KNOWN_MINOR_VERSION and
    // picks the highest remaining one.
    #[test]
    fn hello_minor_version_negotiation_picks_highest_supported_nonzero_version() {
        assert_eq!(
            choose_hello_minor_version(&[0, MAX_KNOWN_MINOR_VERSION]),
            Some(MAX_KNOWN_MINOR_VERSION)
        );
        assert_eq!(
            choose_hello_minor_version(&[
                MAX_KNOWN_MINOR_VERSION.saturating_add(1),
                MAX_KNOWN_MINOR_VERSION,
                1,
            ]),
            Some(MAX_KNOWN_MINOR_VERSION)
        );
        assert_eq!(choose_hello_minor_version(&[]), None);
        assert_eq!(choose_hello_minor_version(&[0]), None);
        assert_eq!(
            choose_hello_minor_version(&[MAX_KNOWN_MINOR_VERSION.saturating_add(1)]),
            None
        );
    }

    // A Hello missing either list is rejected.
    #[test]
    fn hello_requires_versions_and_auth_methods() {
        assert!(Hello::from_payload(br#"{"auth_methods":["bearer"]}"#).is_err());
        assert!(Hello::from_payload(br#"{"versions":[1]}"#).is_err());
    }

    // A topology passed to build_hello_ack appears as a string under
    // `topology` and survives HelloAck::from_payload.
    #[test]
    fn hello_ack_can_embed_topology() {
        let topology = Topology {
            epoch: 7,
            primary: Endpoint {
                addr: "127.0.0.1:5050".to_string(),
                region: "local".to_string(),
            },
            replicas: vec![ReplicaInfo {
                addr: "127.0.0.1:5051".to_string(),
                region: "local".to_string(),
                healthy: true,
                lag_ms: 3,
                last_applied_lsn: 9,
                rebootstrapping: false,
            }],
        };
        let bytes = build_hello_ack(1, "bearer", 0, Some(&topology));
        let json: JsonValue = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["version"], 1);
        assert!(json["topology"].as_str().is_some());
        let ack = HelloAck::from_payload(&bytes).unwrap();
        assert_eq!(ack.version, 1);
        assert_eq!(ack.auth, "bearer");
        assert_eq!(ack.features, 0);
        assert!(ack.topology.is_some());
    }

    // Pins the AuthResponse payload shapes: empty for anonymous, `token`
    // for bearer, `jwt` for OAuth.
    #[test]
    fn auth_response_builders_are_pinned() {
        assert!(build_auth_response_anonymous_payload().is_empty());

        let bearer: JsonValue =
            serde_json::from_slice(&build_auth_response_bearer_payload("token")).unwrap();
        assert_eq!(bearer["token"], "token");

        let oauth: JsonValue =
            serde_json::from_slice(&build_auth_response_oauth_jwt_payload("jwt")).unwrap();
        assert_eq!(oauth["jwt"], "jwt");
    }

    // The kind check passes AuthResponse payloads through, and the error
    // Display text is exactly "expected <label>".
    #[test]
    fn auth_response_kind_expectation_is_pinned() {
        assert_eq!(
            expect_auth_response_payload(MessageKind::AuthResponse, b"proof", "AuthResponse")
                .unwrap(),
            b"proof"
        );

        let err =
            expect_auth_response_payload(MessageKind::Hello, b"{}", "AuthResponse").unwrap_err();
        assert_eq!(err.actual, MessageKind::Hello);
        assert_eq!(err.to_string(), "expected AuthResponse");
    }

    // AuthOk and AuthFail builders round-trip. The SCRAM variant carries the
    // base64 signature under `v` ("sig" encodes to "c2ln").
    #[test]
    fn auth_ok_and_fail_parse_payloads() {
        let ok = AuthOk::from_payload(&build_auth_ok_payload("s1", "alice", "admin", 3)).unwrap();
        assert_eq!(ok.session_id, "s1");
        assert_eq!(ok.username.as_deref(), Some("alice"));
        assert_eq!(ok.role.as_deref(), Some("admin"));
        assert_eq!(ok.features, 3);
        assert_eq!(ok.server_signature.as_deref(), None);

        let scram_ok = AuthOk::from_payload(&build_scram_auth_ok_payload(
            "s1", "alice", "admin", 3, b"sig",
        ))
        .unwrap();
        assert_eq!(scram_ok.server_signature.as_deref(), Some("c2ln"));

        let fail = AuthFail::from_payload(&build_auth_fail_payload("nope")).unwrap();
        assert_eq!(fail.reason, "nope");
    }

    // Each reply-frame builder sets the expected MessageKind, echoes the
    // correlation id, and carries a decodable payload.
    #[test]
    fn handshake_frame_builders_pin_message_kinds() {
        let hello_ack = build_hello_ack_frame(7, 1, "anonymous", 3, None).unwrap();
        assert_eq!(hello_ack.kind, MessageKind::HelloAck);
        assert_eq!(hello_ack.correlation_id, 7);
        assert_eq!(
            HelloAck::from_payload(&hello_ack.payload).unwrap().auth,
            "anonymous"
        );

        let auth_ok =
            build_auth_ok_frame_from_payload(8, build_auth_ok_payload("s1", "alice", "admin", 3))
                .unwrap();
        assert_eq!(auth_ok.kind, MessageKind::AuthOk);
        assert_eq!(auth_ok.correlation_id, 8);
        assert_eq!(
            AuthOk::from_payload(&auth_ok.payload)
                .unwrap()
                .username
                .as_deref(),
            Some("alice")
        );

        let auth_fail = build_auth_fail_frame(9, "nope").unwrap();
        assert_eq!(auth_fail.kind, MessageKind::AuthFail);
        assert_eq!(auth_fail.correlation_id, 9);
        assert_eq!(
            AuthFail::from_payload(&auth_fail.payload).unwrap().reason,
            "nope"
        );
    }

    // The AuthResponse parsers read their own key and reject a payload that
    // carries a different one.
    #[test]
    fn auth_response_parsers_are_pinned() {
        assert_eq!(
            parse_auth_response_bearer_token(&build_auth_response_bearer_payload("token")).unwrap(),
            "token"
        );
        assert_eq!(
            parse_auth_response_oauth_jwt(&build_auth_response_oauth_jwt_payload("jwt")).unwrap(),
            "jwt"
        );
        assert!(parse_auth_response_bearer_token(br#"{"jwt":"x"}"#).is_err());
    }

    // End-to-end SCRAM text messages: client-first parsing, server-first
    // formatting, and client-final parsing with proof decoding.
    #[test]
    fn scram_wire_messages_round_trip() {
        let (user, nonce, bare) = parse_scram_client_first(b"n,,n=alice,r=client").unwrap();
        assert_eq!(user, "alice");
        assert_eq!(nonce, "client");
        assert_eq!(bare, "n=alice,r=client");

        let server_first = build_scram_server_first("client", "server", b"salt", 4096);
        assert_eq!(server_first, "r=clientserver,s=c2FsdA==,i=4096");

        let proof = base64_std(b"proof");
        let final_msg = format!("c=biws,r=clientserver,p={proof}");
        let (combined, decoded_proof, without_proof) =
            parse_scram_client_final(final_msg.as_bytes()).unwrap();
        assert_eq!(combined, "clientserver");
        assert_eq!(decoded_proof, b"proof");
        assert_eq!(without_proof, "c=biws,r=clientserver");
    }
}