1use serde_json::Value as JsonValue;
8use std::fmt;
9
10use super::{BuildError, Frame, FrameBuilder, MessageKind, MAX_KNOWN_MINOR_VERSION};
11
/// Names of the four authentication methods listed as supported.
///
/// NOTE(review): presumably these are the strings a peer may put in
/// `Hello::auth_methods` and receive back in `HelloAck::auth`; the negotiation
/// code that consumes this list is not in this file — confirm there.
pub const SUPPORTED_METHODS: &[&str] = &["bearer", "anonymous", "scram-sha-256", "oauth-jwt"];
14
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct Hello {
17 pub versions: Vec<u8>,
18 pub auth_methods: Vec<String>,
19 pub features: u32,
20 pub client_name: Option<String>,
21}
22
23impl Hello {
24 pub fn to_payload(&self) -> Vec<u8> {
25 build_hello_payload(
26 &self.versions,
27 self.auth_methods.iter().map(String::as_str),
28 self.features,
29 self.client_name.as_deref(),
30 )
31 }
32
33 pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
34 let v: JsonValue =
35 serde_json::from_slice(bytes).map_err(|e| format!("Hello: invalid JSON: {e}"))?;
36 let obj = match v {
37 JsonValue::Object(o) => o,
38 _ => return Err("Hello: payload must be a JSON object".into()),
39 };
40 let versions: Vec<u8> = obj
41 .get("versions")
42 .and_then(|v| v.as_array())
43 .map(|arr| {
44 arr.iter()
45 .filter_map(|n| n.as_u64().map(|u| u as u8))
46 .collect()
47 })
48 .unwrap_or_default();
49 let auth_methods: Vec<String> = obj
50 .get("auth_methods")
51 .and_then(|v| v.as_array())
52 .map(|arr| {
53 arr.iter()
54 .filter_map(|s| s.as_str().map(String::from))
55 .collect()
56 })
57 .unwrap_or_default();
58 let features = obj
59 .get("features")
60 .and_then(|v| v.as_u64())
61 .map(|u| u as u32)
62 .unwrap_or(0);
63 let client_name = obj
64 .get("client_name")
65 .and_then(|v| v.as_str())
66 .map(String::from);
67 if versions.is_empty() {
68 return Err("Hello: versions[] is empty".into());
69 }
70 if auth_methods.is_empty() {
71 return Err("Hello: auth_methods[] is empty".into());
72 }
73 Ok(Self {
74 versions,
75 auth_methods,
76 features,
77 client_name,
78 })
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct HelloAck {
84 pub version: u8,
85 pub auth: String,
86 pub features: u32,
87 pub server: Option<String>,
88 pub topology: Option<String>,
89}
90
91impl HelloAck {
92 pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
93 let obj = object_from_payload("HelloAck", bytes)?;
94 let version = required_u8(&obj, "HelloAck", "version")?;
95 let auth = required_string(&obj, "HelloAck", "auth")?;
96 let features = optional_u32(&obj, "features").unwrap_or(0);
97 let server = optional_string(&obj, "server");
98 let topology = optional_string(&obj, "topology");
99 Ok(Self {
100 version,
101 auth,
102 features,
103 server,
104 topology,
105 })
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq)]
110pub struct AuthOk {
111 pub session_id: String,
112 pub username: Option<String>,
113 pub role: Option<String>,
114 pub features: u32,
115 pub server_signature: Option<String>,
116}
117
118impl AuthOk {
119 pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
120 let obj = object_from_payload("AuthOk", bytes)?;
121 let session_id = required_string(&obj, "AuthOk", "session_id")?;
122 let username = optional_string(&obj, "username");
123 let role = optional_string(&obj, "role");
124 let features = optional_u32(&obj, "features").unwrap_or(0);
125 let server_signature = optional_string(&obj, "v");
126 Ok(Self {
127 session_id,
128 username,
129 role,
130 features,
131 server_signature,
132 })
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub struct AuthFail {
138 pub reason: String,
139}
140
141impl AuthFail {
142 pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
143 let obj = object_from_payload("AuthFail", bytes)?;
144 Ok(Self {
145 reason: required_string(&obj, "AuthFail", "reason")?,
146 })
147 }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq)]
151pub struct AuthResponseKindError {
152 pub expected: &'static str,
153 pub actual: MessageKind,
154}
155
156impl fmt::Display for AuthResponseKindError {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 write!(f, "expected {}", self.expected)
159 }
160}
161
162impl std::error::Error for AuthResponseKindError {}
163
164pub fn build_hello_payload<'a, I>(
165 versions: &[u8],
166 auth_methods: I,
167 features: u32,
168 client_name: Option<&str>,
169) -> Vec<u8>
170where
171 I: IntoIterator<Item = &'a str>,
172{
173 let mut obj = serde_json::Map::new();
174 obj.insert(
175 "versions".to_string(),
176 JsonValue::Array(
177 versions
178 .iter()
179 .map(|version| JsonValue::Number((*version).into()))
180 .collect(),
181 ),
182 );
183 obj.insert(
184 "auth_methods".to_string(),
185 JsonValue::Array(
186 auth_methods
187 .into_iter()
188 .map(|method| JsonValue::String(method.to_string()))
189 .collect(),
190 ),
191 );
192 obj.insert("features".to_string(), JsonValue::Number(features.into()));
193 if let Some(name) = client_name {
194 obj.insert(
195 "client_name".to_string(),
196 JsonValue::String(name.to_string()),
197 );
198 }
199 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
200}
201
202pub fn build_client_hello_payload<'a, I>(
203 auth_methods: I,
204 features: u32,
205 client_name: Option<&str>,
206) -> Vec<u8>
207where
208 I: IntoIterator<Item = &'a str>,
209{
210 build_hello_payload(
211 &[MAX_KNOWN_MINOR_VERSION],
212 auth_methods,
213 features,
214 client_name,
215 )
216}
217
218pub fn build_client_hello_frame<'a, I>(
219 correlation_id: u64,
220 auth_methods: I,
221 features: u32,
222 client_name: Option<&str>,
223) -> Result<Frame, BuildError>
224where
225 I: IntoIterator<Item = &'a str>,
226{
227 FrameBuilder::request(correlation_id)
228 .kind(MessageKind::Hello)
229 .payload(build_client_hello_payload(
230 auth_methods,
231 features,
232 client_name,
233 ))
234 .build()
235}
236
237pub fn choose_hello_minor_version(client_versions: &[u8]) -> Option<u8> {
238 client_versions
239 .iter()
240 .copied()
241 .filter(|version| *version > 0 && *version <= MAX_KNOWN_MINOR_VERSION)
242 .max()
243}
244
245pub fn build_hello_ack(
246 chosen_version: u8,
247 chosen_auth: &str,
248 server_features: u32,
249 topology: Option<&crate::topology::Topology>,
250) -> Vec<u8> {
251 let mut obj = serde_json::Map::new();
252 obj.insert(
253 "version".to_string(),
254 JsonValue::Number(chosen_version.into()),
255 );
256 obj.insert(
257 "auth".to_string(),
258 JsonValue::String(chosen_auth.to_string()),
259 );
260 obj.insert(
261 "features".to_string(),
262 JsonValue::Number(server_features.into()),
263 );
264 obj.insert(
265 "server".to_string(),
266 JsonValue::String(format!("reddb/{}", env!("CARGO_PKG_VERSION"))),
267 );
268 if let Some(topo) = topology {
269 obj.insert(
270 "topology".to_string(),
271 JsonValue::String(crate::topology::encode_topology_for_hello_ack(topo)),
272 );
273 }
274 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
275}
276
277pub fn build_hello_ack_frame(
278 correlation_id: u64,
279 chosen_version: u8,
280 chosen_auth: &str,
281 server_features: u32,
282 topology: Option<&crate::topology::Topology>,
283) -> Result<Frame, BuildError> {
284 FrameBuilder::reply_to(correlation_id)
285 .kind(MessageKind::HelloAck)
286 .payload(build_hello_ack(
287 chosen_version,
288 chosen_auth,
289 server_features,
290 topology,
291 ))
292 .build()
293}
294
/// Returns the payload for an anonymous auth response: an empty byte vector.
pub fn build_auth_response_anonymous_payload() -> Vec<u8> {
    Vec::default()
}
298
299pub fn build_auth_response_bearer_payload(token: &str) -> Vec<u8> {
300 let mut obj = serde_json::Map::new();
301 obj.insert("token".to_string(), JsonValue::String(token.to_string()));
302 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
303}
304
305pub fn parse_auth_response_bearer_token(payload: &[u8]) -> Result<String, String> {
306 let obj = object_from_payload("AuthResponse", payload)?;
307 required_string(&obj, "AuthResponse", "token")
308}
309
310pub fn build_auth_response_oauth_jwt_payload(jwt: &str) -> Vec<u8> {
311 let mut obj = serde_json::Map::new();
312 obj.insert("jwt".to_string(), JsonValue::String(jwt.to_string()));
313 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
314}
315
316pub fn build_auth_response_frame(
317 correlation_id: u64,
318 payload: Vec<u8>,
319) -> Result<Frame, BuildError> {
320 FrameBuilder::request(correlation_id)
321 .kind(MessageKind::AuthResponse)
322 .payload(payload)
323 .build()
324}
325
326pub fn parse_auth_response_oauth_jwt(payload: &[u8]) -> Result<String, String> {
327 let obj = object_from_payload("AuthResponse", payload)?;
328 required_string(&obj, "AuthResponse", "jwt")
329}
330
331pub fn expect_auth_response_payload<'a>(
332 kind: MessageKind,
333 payload: &'a [u8],
334 expected: &'static str,
335) -> Result<&'a [u8], AuthResponseKindError> {
336 if kind == MessageKind::AuthResponse {
337 Ok(payload)
338 } else {
339 Err(AuthResponseKindError {
340 expected,
341 actual: kind,
342 })
343 }
344}
345
346pub fn build_auth_ok_payload(
347 session_id: &str,
348 username: &str,
349 role: &str,
350 server_features: u32,
351) -> Vec<u8> {
352 let mut obj = serde_json::Map::new();
353 obj.insert(
354 "session_id".to_string(),
355 JsonValue::String(session_id.to_string()),
356 );
357 obj.insert(
358 "username".to_string(),
359 JsonValue::String(username.to_string()),
360 );
361 obj.insert("role".to_string(), JsonValue::String(role.to_string()));
362 obj.insert(
363 "features".to_string(),
364 JsonValue::Number(server_features.into()),
365 );
366 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
367}
368
369pub fn build_auth_ok_frame_from_payload(
370 correlation_id: u64,
371 payload: Vec<u8>,
372) -> Result<Frame, BuildError> {
373 FrameBuilder::reply_to(correlation_id)
374 .kind(MessageKind::AuthOk)
375 .payload(payload)
376 .build()
377}
378
379pub fn build_auth_fail_frame(correlation_id: u64, reason: &str) -> Result<Frame, BuildError> {
380 FrameBuilder::reply_to(correlation_id)
381 .kind(MessageKind::AuthFail)
382 .payload(build_auth_fail_payload(reason))
383 .build()
384}
385
386pub fn build_scram_auth_ok_payload(
387 session_id: &str,
388 username: &str,
389 role: &str,
390 server_features: u32,
391 server_signature: &[u8],
392) -> Vec<u8> {
393 let mut obj = serde_json::Map::new();
394 obj.insert(
395 "session_id".to_string(),
396 JsonValue::String(session_id.to_string()),
397 );
398 obj.insert(
399 "username".to_string(),
400 JsonValue::String(username.to_string()),
401 );
402 obj.insert("role".to_string(), JsonValue::String(role.to_string()));
403 obj.insert(
404 "features".to_string(),
405 JsonValue::Number(server_features.into()),
406 );
407 obj.insert(
408 "v".to_string(),
409 JsonValue::String(base64_std(server_signature)),
410 );
411 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
412}
413
414pub fn build_auth_fail_payload(reason: &str) -> Vec<u8> {
415 let mut obj = serde_json::Map::new();
416 obj.insert("reason".to_string(), JsonValue::String(reason.to_string()));
417 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
418}
419
/// Parses a SCRAM client-first message.
///
/// The message must be UTF-8 and start with the `n,,` header. Returns
/// `(user, nonce, bare)`, where `bare` is everything after that header and
/// `user` / `nonce` come from its `n=` / `r=` attributes. If an attribute is
/// repeated, the last occurrence wins.
///
/// # Errors
///
/// Returns a message string when the input is not UTF-8, lacks the `n,,`
/// prefix, or is missing the `n=` or `r=` attribute.
pub fn parse_scram_client_first(payload: &[u8]) -> Result<(String, String, String), String> {
    /// Value of the last comma-separated attribute that starts with `prefix`.
    fn last_attr<'t>(message: &'t str, prefix: &str) -> Option<&'t str> {
        message
            .rsplit(',')
            .find_map(|attr| attr.strip_prefix(prefix))
    }

    let Ok(text) = std::str::from_utf8(payload) else {
        return Err("client-first not UTF-8".to_string());
    };
    let Some(bare) = text.strip_prefix("n,,") else {
        return Err("client-first must start with 'n,,' (no channel binding)".to_string());
    };
    let Some(user) = last_attr(bare, "n=") else {
        return Err("missing n=<user>".to_string());
    };
    let Some(nonce) = last_attr(bare, "r=") else {
        return Err("missing r=<nonce>".to_string());
    };
    Ok((user.to_string(), nonce.to_string(), bare.to_string()))
}
442
443pub fn build_scram_server_first(
447 client_nonce: &str,
448 server_nonce: &str,
449 salt: &[u8],
450 iter: u32,
451) -> String {
452 format!(
453 "r={client_nonce}{server_nonce},s={},i={iter}",
454 base64_std(salt)
455 )
456}
457
458pub fn parse_scram_client_final(payload: &[u8]) -> Result<(String, Vec<u8>, String), String> {
462 let s = std::str::from_utf8(payload).map_err(|_| "client-final not UTF-8".to_string())?;
463 let mut channel_binding = None;
464 let mut nonce = None;
465 let mut proof_b64 = None;
466 for part in s.split(',') {
467 if let Some(v) = part.strip_prefix("c=") {
468 channel_binding = Some(v.to_string());
469 } else if let Some(v) = part.strip_prefix("r=") {
470 nonce = Some(v.to_string());
471 } else if let Some(v) = part.strip_prefix("p=") {
472 proof_b64 = Some(v.to_string());
473 }
474 }
475 let channel_binding =
476 channel_binding.ok_or_else(|| "missing c=<channel-binding>".to_string())?;
477 let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
478 let proof_b64 = proof_b64.ok_or_else(|| "missing p=<proof>".to_string())?;
479 let proof = base64_std_decode(&proof_b64)
480 .ok_or_else(|| "client proof is not valid base64".to_string())?;
481 if channel_binding != "biws" {
482 return Err(format!(
483 "channel binding must be 'biws' (n,,), got '{channel_binding}'"
484 ));
485 }
486 let no_proof = format!("c={channel_binding},r={nonce}");
487 Ok((nonce, proof, no_proof))
488}
489
/// The 64 output symbols used by [`base64_std`], indexed by 6-bit value.
const B64_ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as base64 using `B64_ALPHA`, padding the final group with
/// `=` so the output length is always a multiple of four.
pub fn base64_std(input: &[u8]) -> String {
    // Symbol for the 6-bit field of `word` that starts at bit `shift`.
    let sextet = |word: u32, shift: u32| B64_ALPHA[((word >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Pack up to three bytes into the top 24 bits; missing bytes stay zero.
        let word = group
            .iter()
            .zip([16u32, 8, 0])
            .fold(0u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));
        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        encoded.push(if group.len() > 1 { sextet(word, 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(word, 0) } else { '=' });
    }
    encoded
}
522
/// Decodes base64 text written with the `A-Z a-z 0-9 + /` alphabet.
///
/// All trailing `=` characters are stripped before decoding, so padding is
/// optional. Leftover bits that do not fill a whole byte are dropped.
/// Returns `None` on any character outside the alphabet, including a `=`
/// that is not at the end.
pub fn base64_std_decode(input: &str) -> Option<Vec<u8>> {
    let body = input.trim_end_matches('=');
    let mut decoded = Vec::with_capacity(body.len() * 3 / 4);
    let mut acc: u32 = 0;
    let mut pending_bits: u8 = 0;
    for symbol in body.bytes() {
        let sextet = if symbol.is_ascii_uppercase() {
            symbol - b'A'
        } else if symbol.is_ascii_lowercase() {
            symbol - b'a' + 26
        } else if symbol.is_ascii_digit() {
            symbol - b'0' + 52
        } else if symbol == b'+' {
            62
        } else if symbol == b'/' {
            63
        } else {
            return None;
        };
        acc = (acc << 6) | u32::from(sextet);
        pending_bits += 6;
        if pending_bits >= 8 {
            // Emit the oldest complete byte; bits shifted out above it are
            // already consumed.
            pending_bits -= 8;
            decoded.push(((acc >> pending_bits) & 0xFF) as u8);
        }
    }
    Some(decoded)
}
546
547fn object_from_payload(
548 name: &str,
549 bytes: &[u8],
550) -> Result<serde_json::Map<String, JsonValue>, String> {
551 let v: JsonValue =
552 serde_json::from_slice(bytes).map_err(|e| format!("{name}: invalid JSON: {e}"))?;
553 match v {
554 JsonValue::Object(o) => Ok(o),
555 _ => Err(format!("{name}: payload must be a JSON object")),
556 }
557}
558
559fn required_string(
560 obj: &serde_json::Map<String, JsonValue>,
561 name: &str,
562 field: &str,
563) -> Result<String, String> {
564 obj.get(field)
565 .and_then(JsonValue::as_str)
566 .map(String::from)
567 .ok_or_else(|| format!("{name}: missing {field} string"))
568}
569
570fn optional_string(obj: &serde_json::Map<String, JsonValue>, field: &str) -> Option<String> {
571 obj.get(field).and_then(JsonValue::as_str).map(String::from)
572}
573
574fn required_u8(
575 obj: &serde_json::Map<String, JsonValue>,
576 name: &str,
577 field: &str,
578) -> Result<u8, String> {
579 let n = obj
580 .get(field)
581 .and_then(JsonValue::as_u64)
582 .ok_or_else(|| format!("{name}: missing {field} number"))?;
583 u8::try_from(n).map_err(|_| format!("{name}: {field} out of range for u8"))
584}
585
586fn optional_u32(obj: &serde_json::Map<String, JsonValue>, field: &str) -> Option<u32> {
587 obj.get(field)
588 .and_then(JsonValue::as_u64)
589 .and_then(|n| u32::try_from(n).ok())
590}
591
// Unit tests pinning the JSON payload shapes, frame kinds and SCRAM message
// formats produced and parsed by this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::topology::{Endpoint, ReplicaInfo, Topology};

    // A hand-written client payload decodes into the expected `Hello` fields.
    #[test]
    fn hello_parses_client_payload() {
        let payload =
            br#"{"versions":[1],"auth_methods":["bearer"],"features":1,"client_name":"x"}"#;
        let hello = Hello::from_payload(payload).unwrap();
        assert_eq!(hello.versions, vec![1]);
        assert_eq!(hello.auth_methods, vec!["bearer"]);
        assert_eq!(hello.features, 1);
        assert_eq!(hello.client_name.as_deref(), Some("x"));
    }

    // Builder output round-trips through `Hello::from_payload`.
    #[test]
    fn hello_builds_client_payload() {
        let bytes = build_hello_payload(&[1], ["anonymous", "bearer"], 7, Some("client"));
        let hello = Hello::from_payload(&bytes).unwrap();
        assert_eq!(hello.versions, vec![1]);
        assert_eq!(hello.auth_methods, vec!["anonymous", "bearer"]);
        assert_eq!(hello.features, 7);
        assert_eq!(hello.client_name.as_deref(), Some("client"));
    }

    // The client builder offers exactly `MAX_KNOWN_MINOR_VERSION`.
    #[test]
    fn client_hello_payload_uses_current_minor_version() {
        let bytes = build_client_hello_payload(["anonymous"], 0, Some("client"));
        let hello = Hello::from_payload(&bytes).unwrap();
        assert_eq!(hello.versions, vec![MAX_KNOWN_MINOR_VERSION]);
        assert_eq!(hello.auth_methods, vec!["anonymous"]);
        assert_eq!(hello.client_name.as_deref(), Some("client"));
    }

    // Negotiation ignores 0 and versions above the known maximum, and returns
    // `None` when nothing usable is offered.
    #[test]
    fn hello_minor_version_negotiation_picks_highest_supported_nonzero_version() {
        assert_eq!(
            choose_hello_minor_version(&[0, MAX_KNOWN_MINOR_VERSION]),
            Some(MAX_KNOWN_MINOR_VERSION)
        );
        assert_eq!(
            choose_hello_minor_version(&[
                MAX_KNOWN_MINOR_VERSION.saturating_add(1),
                MAX_KNOWN_MINOR_VERSION,
                1,
            ]),
            Some(MAX_KNOWN_MINOR_VERSION)
        );
        assert_eq!(choose_hello_minor_version(&[]), None);
        assert_eq!(choose_hello_minor_version(&[0]), None);
        assert_eq!(
            choose_hello_minor_version(&[MAX_KNOWN_MINOR_VERSION.saturating_add(1)]),
            None
        );
    }

    // A payload missing either `versions` or `auth_methods` is rejected.
    #[test]
    fn hello_requires_versions_and_auth_methods() {
        assert!(Hello::from_payload(br#"{"auth_methods":["bearer"]}"#).is_err());
        assert!(Hello::from_payload(br#"{"versions":[1]}"#).is_err());
    }

    // When a topology is supplied it appears as a string field and survives
    // parsing back into `HelloAck`.
    #[test]
    fn hello_ack_can_embed_topology() {
        let topology = Topology {
            epoch: 7,
            primary: Endpoint {
                addr: "127.0.0.1:5050".to_string(),
                region: "local".to_string(),
            },
            replicas: vec![ReplicaInfo {
                addr: "127.0.0.1:5051".to_string(),
                region: "local".to_string(),
                healthy: true,
                lag_ms: 3,
                last_applied_lsn: 9,
                rebootstrapping: false,
            }],
        };
        let bytes = build_hello_ack(1, "bearer", 0, Some(&topology));
        let json: JsonValue = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["version"], 1);
        assert!(json["topology"].as_str().is_some());
        let ack = HelloAck::from_payload(&bytes).unwrap();
        assert_eq!(ack.version, 1);
        assert_eq!(ack.auth, "bearer");
        assert_eq!(ack.features, 0);
        assert!(ack.topology.is_some());
    }

    // Anonymous payload is empty; bearer and OAuth payloads use the `token`
    // and `jwt` keys respectively.
    #[test]
    fn auth_response_builders_are_pinned() {
        assert!(build_auth_response_anonymous_payload().is_empty());

        let bearer: JsonValue =
            serde_json::from_slice(&build_auth_response_bearer_payload("token")).unwrap();
        assert_eq!(bearer["token"], "token");

        let oauth: JsonValue =
            serde_json::from_slice(&build_auth_response_oauth_jwt_payload("jwt")).unwrap();
        assert_eq!(oauth["jwt"], "jwt");
    }

    // Matching kind passes the payload through; a mismatch reports the actual
    // kind and renders as "expected <label>".
    #[test]
    fn auth_response_kind_expectation_is_pinned() {
        assert_eq!(
            expect_auth_response_payload(MessageKind::AuthResponse, b"proof", "AuthResponse")
                .unwrap(),
            b"proof"
        );

        let err =
            expect_auth_response_payload(MessageKind::Hello, b"{}", "AuthResponse").unwrap_err();
        assert_eq!(err.actual, MessageKind::Hello);
        assert_eq!(err.to_string(), "expected AuthResponse");
    }

    // `AuthOk` (with and without a SCRAM signature) and `AuthFail` round-trip
    // through their builders; "c2ln" is the base64 of b"sig".
    #[test]
    fn auth_ok_and_fail_parse_payloads() {
        let ok = AuthOk::from_payload(&build_auth_ok_payload("s1", "alice", "admin", 3)).unwrap();
        assert_eq!(ok.session_id, "s1");
        assert_eq!(ok.username.as_deref(), Some("alice"));
        assert_eq!(ok.role.as_deref(), Some("admin"));
        assert_eq!(ok.features, 3);
        assert_eq!(ok.server_signature.as_deref(), None);

        let scram_ok = AuthOk::from_payload(&build_scram_auth_ok_payload(
            "s1", "alice", "admin", 3, b"sig",
        ))
        .unwrap();
        assert_eq!(scram_ok.server_signature.as_deref(), Some("c2ln"));

        let fail = AuthFail::from_payload(&build_auth_fail_payload("nope")).unwrap();
        assert_eq!(fail.reason, "nope");
    }

    // Each frame builder sets the expected message kind and correlation id and
    // carries a payload the matching parser accepts.
    #[test]
    fn handshake_frame_builders_pin_message_kinds() {
        let hello_ack = build_hello_ack_frame(7, 1, "anonymous", 3, None).unwrap();
        assert_eq!(hello_ack.kind, MessageKind::HelloAck);
        assert_eq!(hello_ack.correlation_id, 7);
        assert_eq!(
            HelloAck::from_payload(&hello_ack.payload).unwrap().auth,
            "anonymous"
        );

        let auth_ok =
            build_auth_ok_frame_from_payload(8, build_auth_ok_payload("s1", "alice", "admin", 3))
                .unwrap();
        assert_eq!(auth_ok.kind, MessageKind::AuthOk);
        assert_eq!(auth_ok.correlation_id, 8);
        assert_eq!(
            AuthOk::from_payload(&auth_ok.payload)
                .unwrap()
                .username
                .as_deref(),
            Some("alice")
        );

        let auth_fail = build_auth_fail_frame(9, "nope").unwrap();
        assert_eq!(auth_fail.kind, MessageKind::AuthFail);
        assert_eq!(auth_fail.correlation_id, 9);
        assert_eq!(
            AuthFail::from_payload(&auth_fail.payload).unwrap().reason,
            "nope"
        );
    }

    // Parsers read back what the builders wrote, and the bearer parser rejects
    // a payload that only has a `jwt` key.
    #[test]
    fn auth_response_parsers_are_pinned() {
        assert_eq!(
            parse_auth_response_bearer_token(&build_auth_response_bearer_payload("token")).unwrap(),
            "token"
        );
        assert_eq!(
            parse_auth_response_oauth_jwt(&build_auth_response_oauth_jwt_payload("jwt")).unwrap(),
            "jwt"
        );
        assert!(parse_auth_response_bearer_token(br#"{"jwt":"x"}"#).is_err());
    }

    // Exercises client-first parsing, server-first formatting and client-final
    // parsing with fixed sample values.
    #[test]
    fn scram_wire_messages_round_trip() {
        let (user, nonce, bare) = parse_scram_client_first(b"n,,n=alice,r=client").unwrap();
        assert_eq!(user, "alice");
        assert_eq!(nonce, "client");
        assert_eq!(bare, "n=alice,r=client");

        let server_first = build_scram_server_first("client", "server", b"salt", 4096);
        assert_eq!(server_first, "r=clientserver,s=c2FsdA==,i=4096");

        let proof = base64_std(b"proof");
        let final_msg = format!("c=biws,r=clientserver,p={proof}");
        let (combined, decoded_proof, without_proof) =
            parse_scram_client_final(final_msg.as_bytes()).unwrap();
        assert_eq!(combined, "clientserver");
        assert_eq!(decoded_proof, b"proof");
        assert_eq!(without_proof, "c=biws,r=clientserver");
    }
}