1use crate::auth::store::AuthStore;
13use crate::auth::Role;
14use crate::serde_json::{self, Value as JsonValue};
15
/// Auth method names this module knows about.
///
/// The order here is not the negotiation priority — `pick_auth_method` applies
/// its own ordering over these same four names.
pub const SUPPORTED_METHODS: &[&str] = &["bearer", "anonymous", "scram-sha-256", "oauth-jwt"];
27
/// Result of checking a client's auth response (see `validate_auth_response`).
#[derive(Debug, Clone)]
pub enum AuthOutcome {
    /// The client is accepted as `username` with `role`; `session_id` is a
    /// freshly generated session identifier.
    Authenticated {
        username: String,
        role: Role,
        session_id: String,
    },
    /// The client is refused; the payload is a human-readable reason.
    Refused(String),
}
40
/// A client `Hello` frame, normally produced by [`Hello::from_payload`].
#[derive(Debug, Clone)]
pub struct Hello {
    /// Protocol versions offered by the client (JSON key `versions`);
    /// non-empty when produced by `from_payload`.
    pub versions: Vec<u8>,
    /// Auth method names offered by the client (JSON key `auth_methods`);
    /// non-empty when produced by `from_payload`.
    pub auth_methods: Vec<String>,
    /// Client feature bitmask (JSON key `features`); `0` when absent.
    pub features: u32,
    /// Optional free-form client identifier (JSON key `client_name`).
    pub client_name: Option<String>,
}
49
50impl Hello {
51 pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
52 let v: JsonValue =
53 serde_json::from_slice(bytes).map_err(|e| format!("Hello: invalid JSON: {e}"))?;
54 let obj = match v {
55 JsonValue::Object(o) => o,
56 _ => return Err("Hello: payload must be a JSON object".into()),
57 };
58 let versions: Vec<u8> = obj
59 .get("versions")
60 .and_then(|v| v.as_array())
61 .map(|arr| {
62 arr.iter()
63 .filter_map(|n| n.as_f64().map(|f| f as u8))
64 .collect()
65 })
66 .unwrap_or_default();
67 let auth_methods: Vec<String> = obj
68 .get("auth_methods")
69 .and_then(|v| v.as_array())
70 .map(|arr| {
71 arr.iter()
72 .filter_map(|s| s.as_str().map(String::from))
73 .collect()
74 })
75 .unwrap_or_default();
76 let features = obj
77 .get("features")
78 .and_then(|v| v.as_f64())
79 .map(|f| f as u32)
80 .unwrap_or(0);
81 let client_name = obj
82 .get("client_name")
83 .and_then(|v| v.as_str())
84 .map(String::from);
85 if versions.is_empty() {
86 return Err("Hello: versions[] is empty".into());
87 }
88 if auth_methods.is_empty() {
89 return Err("Hello: auth_methods[] is empty".into());
90 }
91 Ok(Self {
92 versions,
93 auth_methods,
94 features,
95 client_name,
96 })
97 }
98}
99
100pub fn build_hello_ack(
117 chosen_version: u8,
118 chosen_auth: &str,
119 server_features: u32,
120 topology: Option<&reddb_wire::topology::Topology>,
121) -> Vec<u8> {
122 use crate::json_field::SerializedJsonField;
123 let mut obj = crate::serde_json::Map::new();
136 obj.insert(
137 "version".to_string(),
138 JsonValue::Number(chosen_version as f64),
139 );
140 obj.insert(
141 "auth".to_string(),
142 SerializedJsonField::tainted(chosen_auth),
143 );
144 obj.insert(
145 "features".to_string(),
146 JsonValue::Number(server_features as f64),
147 );
148 let server_field = format!("reddb/{}", env!("CARGO_PKG_VERSION"));
149 obj.insert(
150 "server".to_string(),
151 SerializedJsonField::tainted(&server_field),
152 );
153 if let Some(topo) = topology {
154 obj.insert(
155 "topology".to_string(),
156 SerializedJsonField::tainted(&reddb_wire::topology::encode_topology_for_hello_ack(
157 topo,
158 )),
159 );
160 }
161 JsonValue::Object(obj).to_string_compact().into_bytes()
162}
163
/// Chooses the auth method for this connection from those the client offered.
///
/// The server's priority decides, not the client's order. When anonymous
/// access is allowed (`server_anon_ok`), `anonymous` is preferred. Otherwise
/// credentialed methods are tried in the order scram-sha-256, oauth-jwt,
/// bearer, and `anonymous` is never chosen.
///
/// Returns `None` when no acceptable method overlaps with `client_methods`.
pub fn pick_auth_method(client_methods: &[String], server_anon_ok: bool) -> Option<&'static str> {
    const ANON_FIRST: [&str; 4] = ["anonymous", "scram-sha-256", "oauth-jwt", "bearer"];
    const CREDENTIALS_FIRST: [&str; 4] = ["scram-sha-256", "oauth-jwt", "bearer", "anonymous"];

    let order = if server_anon_ok {
        ANON_FIRST
    } else {
        CREDENTIALS_FIRST
    };
    order
        .iter()
        .copied()
        .filter(|&candidate| server_anon_ok || candidate != "anonymous")
        .find(|&candidate| client_methods.iter().any(|offered| offered == candidate))
}
192
193pub fn validate_auth_response(
195 method: &str,
196 payload: &[u8],
197 auth_store: Option<&AuthStore>,
198) -> AuthOutcome {
199 match method {
200 "anonymous" => {
201 if let Some(store) = auth_store {
204 if store.is_enabled() {
205 return AuthOutcome::Refused(
206 "anonymous auth refused — server has auth enabled".into(),
207 );
208 }
209 }
210 AuthOutcome::Authenticated {
211 username: "anonymous".to_string(),
212 role: Role::Read,
213 session_id: new_session_id(),
214 }
215 }
216 "bearer" => {
217 let token = parse_bearer_response(payload).unwrap_or_default();
218 let Some(store) = auth_store else {
219 return AuthOutcome::Refused(
220 "bearer auth refused — server has no auth store configured".into(),
221 );
222 };
223 match store.validate_token(&token) {
224 Some((username, role)) => AuthOutcome::Authenticated {
225 username,
226 role,
227 session_id: new_session_id(),
228 },
229 None => AuthOutcome::Refused("bearer token invalid".into()),
230 }
231 }
232 "scram-sha-256" => AuthOutcome::Refused(
233 "scram-sha-256 must be driven through perform_scram_handshake — \
234 the 1-RTT validate_auth_response path doesn't apply"
235 .to_string(),
236 ),
237 "oauth-jwt" => {
238 AuthOutcome::Refused(
244 "oauth-jwt requires RedWireConfig.oauth to be set. Pass an \
245 OAuthValidator with the issuer + JWKS configured."
246 .to_string(),
247 )
248 }
249 other => AuthOutcome::Refused(format!("auth method '{other}' is not supported in v2.1")),
250 }
251}
252
253fn parse_bearer_response(payload: &[u8]) -> Option<String> {
254 let v: JsonValue = serde_json::from_slice(payload).ok()?;
255 let token = v.as_object()?.get("token")?.as_str()?;
256 Some(token.to_string())
257}
258
259pub fn build_auth_ok(
262 session_id: &str,
263 username: &str,
264 role: Role,
265 server_features: u32,
266) -> Vec<u8> {
267 use crate::json_field::SerializedJsonField;
268 let mut obj = crate::serde_json::Map::new();
272 obj.insert(
273 "session_id".to_string(),
274 SerializedJsonField::tainted(session_id),
275 );
276 obj.insert(
277 "username".to_string(),
278 SerializedJsonField::tainted(username),
279 );
280 let role_str = role.to_string();
281 obj.insert("role".to_string(), SerializedJsonField::tainted(&role_str));
282 obj.insert(
283 "features".to_string(),
284 JsonValue::Number(server_features as f64),
285 );
286 JsonValue::Object(obj).to_string_compact().into_bytes()
287}
288
289pub fn build_auth_fail(reason: &str) -> Vec<u8> {
290 use crate::json_field::SerializedJsonField;
291 let mut obj = crate::serde_json::Map::new();
295 obj.insert("reason".to_string(), SerializedJsonField::tainted(reason));
296 JsonValue::Object(obj).to_string_compact().into_bytes()
297}
298
/// Parses a SCRAM-SHA-256 client-first-message (RFC 5802 §7).
///
/// Only the GS2 header `n,,` is accepted, meaning no channel binding and no
/// authzid.
///
/// Returns `(username, client_nonce, client_first_bare)`:
/// - `username` has the RFC 5802 `saslname` escapes decoded (`=2C` → `,`,
///   `=3D` → `=`).
/// - `client_first_bare` is the message after the GS2 header, verbatim. It
///   must be used unmodified when building the AuthMessage.
///
/// # Errors
/// Returns a message when the payload:
/// - is not UTF-8 or lacks the `n,,` header;
/// - has a missing or empty `n=`/`r=` attribute;
/// - contains an invalid `=` escape in the username;
/// - carries a mandatory extension (`m=`), which this server does not support.
pub fn parse_scram_client_first(payload: &[u8]) -> Result<(String, String, String), String> {
    // RFC 5802 §5.1: inside `saslname`, ',' is sent as "=2C" and '=' as "=3D";
    // any other '=' sequence MUST cause authentication to fail.
    fn decode_saslname(raw: &str) -> Result<String, String> {
        let mut decoded = String::with_capacity(raw.len());
        let mut rest = raw;
        while let Some(pos) = rest.find('=') {
            decoded.push_str(&rest[..pos]);
            let escape = &rest[pos..];
            if let Some(after) = escape.strip_prefix("=2C") {
                decoded.push(',');
                rest = after;
            } else if let Some(after) = escape.strip_prefix("=3D") {
                decoded.push('=');
                rest = after;
            } else {
                return Err("invalid '=' escape in n=<user> (expected =2C or =3D)".to_string());
            }
        }
        decoded.push_str(rest);
        Ok(decoded)
    }

    let s = std::str::from_utf8(payload).map_err(|_| "client-first not UTF-8".to_string())?;
    let bare = s
        .strip_prefix("n,,")
        .ok_or_else(|| "client-first must start with 'n,,' (no channel binding)".to_string())?;
    let mut user = None;
    let mut nonce = None;
    for part in bare.split(',') {
        if let Some(v) = part.strip_prefix("n=") {
            user = Some(v);
        } else if let Some(v) = part.strip_prefix("r=") {
            nonce = Some(v);
        } else if part.starts_with("m=") {
            // RFC 5802 §5.1: an unsupported mandatory extension must fail auth.
            return Err("mandatory SCRAM extension (m=) is not supported".to_string());
        }
    }
    let user = user.ok_or_else(|| "missing n=<user>".to_string())?;
    let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
    // The RFC grammar makes both `saslname` and `c-nonce` non-empty.
    if user.is_empty() {
        return Err("n=<user> must not be empty".to_string());
    }
    if nonce.is_empty() {
        return Err("r=<nonce> must not be empty".to_string());
    }
    Ok((decode_saslname(user)?, nonce.to_string(), bare.to_string()))
}
322
323pub fn build_scram_server_first(
326 client_nonce: &str,
327 server_nonce: &str,
328 salt: &[u8],
329 iter: u32,
330) -> String {
331 format!(
332 "r={client_nonce}{server_nonce},s={},i={iter}",
333 base64_std(salt)
334 )
335}
336
337pub fn parse_scram_client_final(payload: &[u8]) -> Result<(String, Vec<u8>, String), String> {
340 let s = std::str::from_utf8(payload).map_err(|_| "client-final not UTF-8".to_string())?;
341 let mut channel_binding = None;
342 let mut nonce = None;
343 let mut proof_b64 = None;
344 for part in s.split(',') {
345 if let Some(v) = part.strip_prefix("c=") {
346 channel_binding = Some(v.to_string());
347 } else if let Some(v) = part.strip_prefix("r=") {
348 nonce = Some(v.to_string());
349 } else if let Some(v) = part.strip_prefix("p=") {
350 proof_b64 = Some(v.to_string());
351 }
352 }
353 let channel_binding =
354 channel_binding.ok_or_else(|| "missing c=<channel-binding>".to_string())?;
355 let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
356 let proof_b64 = proof_b64.ok_or_else(|| "missing p=<proof>".to_string())?;
357 let proof = base64_std_decode(&proof_b64)
358 .ok_or_else(|| "client proof is not valid base64".to_string())?;
359 if channel_binding != "biws" {
361 return Err(format!(
362 "channel binding must be 'biws' (n,,), got '{channel_binding}'"
363 ));
364 }
365 let no_proof = format!("c={channel_binding},r={nonce}");
366 Ok((nonce, proof, no_proof))
367}
368
369pub fn build_scram_auth_ok(
373 session_id: &str,
374 username: &str,
375 role: Role,
376 server_features: u32,
377 server_signature: &[u8],
378) -> Vec<u8> {
379 let mut obj = crate::serde_json::Map::new();
380 obj.insert(
381 "session_id".to_string(),
382 JsonValue::String(session_id.to_string()),
383 );
384 obj.insert(
385 "username".to_string(),
386 JsonValue::String(username.to_string()),
387 );
388 obj.insert("role".to_string(), JsonValue::String(role.to_string()));
389 obj.insert(
390 "features".to_string(),
391 JsonValue::Number(server_features as f64),
392 );
393 obj.insert(
394 "v".to_string(),
395 JsonValue::String(base64_std(server_signature)),
396 );
397 JsonValue::Object(obj).to_string_compact().into_bytes()
398}
399
400pub fn new_server_nonce() -> String {
404 base64_std(&crate::auth::store::random_bytes(18))
405}
406
/// Exposes the module-private session-id generator to the rest of the crate.
///
/// The name suggests it serves the SCRAM handshake. That handshake does not go
/// through `validate_auth_response` (see its `scram-sha-256` arm), so it needs
/// another way to mint a session id.
pub(crate) fn new_session_id_for_scram() -> String {
    new_session_id()
}
410
const B64_ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as standard base64 (RFC 4648 §4, `+`/`/` alphabet) with `=`
/// padding.
pub fn base64_std(input: &[u8]) -> String {
    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Missing bytes of a short final group are treated as zero.
        let first = u32::from(group[0]);
        let second = u32::from(group.get(1).copied().unwrap_or(0));
        let third = u32::from(group.get(2).copied().unwrap_or(0));
        let triple = (first << 16) | (second << 8) | third;
        let sextet = |shift: u32| B64_ALPHA[((triple >> shift) & 0x3F) as usize] as char;

        encoded.push(sextet(18));
        encoded.push(sextet(12));
        encoded.push(if group.len() > 1 { sextet(6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(0) } else { '=' });
    }
    encoded
}
449
/// Decodes standard base64 (RFC 4648 §4, `+`/`/` alphabet).
///
/// Trailing `=` padding is optional (unpadded input is accepted), but at most
/// two padding characters may appear.
///
/// Returns `None` when the input:
/// - contains a character outside the alphabet, including `=` anywhere but the end;
/// - has a length that cannot encode whole bytes (one leftover character);
/// - is non-canonical, i.e. its final character carries non-zero unused bits
///   (RFC 4648 §3.5). Rejecting this means each byte string has exactly one
///   accepted encoding, which matters for client proofs and JWT signatures.
pub fn base64_std_decode(input: &str) -> Option<Vec<u8>> {
    let trimmed = input.trim_end_matches('=');
    if input.len() - trimmed.len() > 2 {
        return None;
    }
    let mut out = Vec::with_capacity(trimmed.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    for ch in trimmed.bytes() {
        let v: u32 = match ch {
            b'A'..=b'Z' => (ch - b'A') as u32,
            b'a'..=b'z' => (ch - b'a' + 26) as u32,
            b'0'..=b'9' => (ch - b'0' + 52) as u32,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        // Bits shifted out of the top are already emitted; only the low
        // `bits` bits of `buf` are still pending.
        buf = (buf << 6) | v;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xFF) as u8);
        }
    }
    // 6 pending bits means a lone trailing character (len % 4 == 1), which
    // cannot encode a byte. Otherwise the pending 0, 2 or 4 bits must be zero.
    if bits == 6 || buf & ((1u32 << bits) - 1) != 0 {
        return None;
    }
    Some(out)
}
473
474pub fn parse_jwt(token: &str) -> Result<crate::auth::oauth::DecodedJwt, String> {
479 let parts: Vec<&str> = token.split('.').collect();
480 if parts.len() != 3 {
481 return Err(format!(
482 "expected 3 dot-separated parts, got {}",
483 parts.len()
484 ));
485 }
486 let header_bytes =
487 base64_url_decode(parts[0]).ok_or_else(|| "header is not valid base64url".to_string())?;
488 let payload_bytes =
489 base64_url_decode(parts[1]).ok_or_else(|| "payload is not valid base64url".to_string())?;
490 let signature = base64_url_decode(parts[2])
491 .ok_or_else(|| "signature is not valid base64url".to_string())?;
492
493 let header_json: JsonValue =
494 serde_json::from_slice(&header_bytes).map_err(|e| format!("header JSON: {e}"))?;
495 let payload_json: JsonValue =
496 serde_json::from_slice(&payload_bytes).map_err(|e| format!("payload JSON: {e}"))?;
497
498 let header = jwt_header_from(&header_json)?;
499 let claims = jwt_claims_from(&payload_json);
500
501 let signing_input = format!("{}.{}", parts[0], parts[1]).into_bytes();
502
503 Ok(crate::auth::oauth::DecodedJwt {
504 header,
505 claims,
506 signing_input,
507 signature,
508 })
509}
510
511fn jwt_header_from(v: &JsonValue) -> Result<crate::auth::oauth::JwtHeader, String> {
512 let obj = v
513 .as_object()
514 .ok_or_else(|| "JWT header must be a JSON object".to_string())?;
515 let alg = obj
516 .get("alg")
517 .and_then(|x| x.as_str())
518 .ok_or_else(|| "JWT header missing 'alg'".to_string())?
519 .to_string();
520 let kid = obj.get("kid").and_then(|x| x.as_str()).map(String::from);
521 Ok(crate::auth::oauth::JwtHeader { alg, kid })
522}
523
524fn jwt_claims_from(v: &JsonValue) -> crate::auth::oauth::JwtClaims {
525 let obj = v.as_object().cloned().unwrap_or_default();
526 let mut claims = crate::auth::oauth::JwtClaims::default();
527 if let Some(s) = obj.get("iss").and_then(|x| x.as_str()) {
528 claims.iss = Some(s.to_string());
529 }
530 if let Some(s) = obj.get("sub").and_then(|x| x.as_str()) {
531 claims.sub = Some(s.to_string());
532 }
533 if let Some(s) = obj.get("aud").and_then(|x| x.as_str()) {
534 claims.aud = vec![s.to_string()];
535 } else if let Some(arr) = obj.get("aud").and_then(|x| x.as_array()) {
536 claims.aud = arr
537 .iter()
538 .filter_map(|v| v.as_str().map(String::from))
539 .collect();
540 }
541 if let Some(n) = obj.get("exp").and_then(|x| x.as_f64()) {
542 claims.exp = Some(n as i64);
543 }
544 if let Some(n) = obj.get("nbf").and_then(|x| x.as_f64()) {
545 claims.nbf = Some(n as i64);
546 }
547 if let Some(n) = obj.get("iat").and_then(|x| x.as_f64()) {
548 claims.iat = Some(n as i64);
549 }
550 for (k, v) in obj.iter() {
551 if matches!(k.as_str(), "iss" | "sub" | "aud" | "exp" | "nbf" | "iat") {
552 continue;
553 }
554 if let Some(s) = v.as_str() {
555 claims.extra.insert(k.clone(), s.to_string());
556 }
557 }
558 claims
559}
560
561pub fn validate_oauth_jwt(
564 validator: &crate::auth::oauth::OAuthValidator,
565 raw_token: &str,
566) -> Result<(String, Role), String> {
567 validate_oauth_jwt_full(validator, raw_token).map(|(_tenant, username, role)| (username, role))
568}
569
570pub fn validate_oauth_jwt_full(
574 validator: &crate::auth::oauth::OAuthValidator,
575 raw_token: &str,
576) -> Result<(Option<String>, String, Role), String> {
577 let token = parse_jwt(raw_token).map_err(|e| format!("decode JWT: {e}"))?;
578 let now = std::time::SystemTime::now()
579 .duration_since(std::time::UNIX_EPOCH)
580 .map(|d| d.as_secs() as i64)
581 .unwrap_or(0);
582 let identity = validator
588 .validate(&token, now, |sub| {
589 Some(crate::auth::User {
590 username: sub.to_string(),
591 tenant_id: token.claims.extra.get("tenant").cloned(),
592 password_hash: String::new(),
593 scram_verifier: None,
594 role: token
595 .claims
596 .extra
597 .get("role")
598 .and_then(|s| Role::from_str(s))
599 .unwrap_or(Role::Read),
600 api_keys: Vec::new(),
601 created_at: 0,
602 updated_at: 0,
603 enabled: true,
604 })
605 })
606 .map_err(|e| format!("{e}"))?;
607 Ok((identity.tenant, identity.username, identity.role))
608}
609
610fn base64_url_decode(input: &str) -> Option<Vec<u8>> {
611 let mut s = String::with_capacity(input.len() + 4);
613 for ch in input.chars() {
614 match ch {
615 '-' => s.push('+'),
616 '_' => s.push('/'),
617 _ => s.push(ch),
618 }
619 }
620 while !s.len().is_multiple_of(4) {
621 s.push('=');
622 }
623 base64_std_decode(&s)
624}
625
626fn new_session_id() -> String {
630 let now_us = std::time::SystemTime::now()
631 .duration_since(std::time::UNIX_EPOCH)
632 .map(|d| d.as_micros())
633 .unwrap_or(0);
634 let rand = crate::utils::now_unix_nanos() & 0xFFFF_FFFF;
635 format!("rwsess-{now_us}-{rand:08x}")
636}
637
/// Unit tests for Hello parsing, auth-method negotiation, 1-RTT auth
/// validation and HelloAck topology embedding.
#[cfg(test)]
mod tests {
    use super::*;

    // A fully populated Hello decodes every field.
    #[test]
    fn hello_round_trip() {
        let payload = br#"{"versions":[1],"auth_methods":["bearer","anonymous"],"features":3,"client_name":"reddb-rs/0.1"}"#;
        let h = Hello::from_payload(payload).unwrap();
        assert_eq!(h.versions, vec![1]);
        assert_eq!(h.auth_methods, vec!["bearer", "anonymous"]);
        assert_eq!(h.features, 3);
        assert_eq!(h.client_name.as_deref(), Some("reddb-rs/0.1"));
    }

    // An empty auth_methods array is a hard error, not "no preference".
    #[test]
    fn hello_rejects_empty_methods() {
        let payload = br#"{"versions":[1],"auth_methods":[]}"#;
        assert!(Hello::from_payload(payload).is_err());
    }

    // With anonymous allowed, anonymous wins even when bearer is also offered.
    #[test]
    fn pick_auth_prefers_anonymous_when_server_has_no_auth_store() {
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, true), Some("anonymous"));
    }

    // With anonymous disallowed, negotiation falls through to a credentialed method.
    #[test]
    fn pick_auth_picks_bearer_when_anonymous_blocked() {
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, false), Some("bearer"));
    }

    // A client offering only anonymous gets nothing when the server blocks it.
    #[test]
    fn pick_auth_skips_anonymous_when_server_blocks_it() {
        let pref = vec!["anonymous".to_string()];
        assert_eq!(pick_auth_method(&pref, false), None);
    }

    // Unknown method names never match.
    #[test]
    fn pick_auth_returns_none_when_nothing_overlaps() {
        let pref = vec!["kerberos".to_string(), "future-method".to_string()];
        assert_eq!(pick_auth_method(&pref, true), None);
    }

    // Covers only the no-store case (None); an enabled store is not exercised here.
    #[test]
    fn anonymous_validates_only_when_store_disabled() {
        let outcome = validate_auth_response("anonymous", &[], None);
        assert!(matches!(outcome, AuthOutcome::Authenticated { .. }));
    }

    // Bearer needs a store to check tokens against; without one it is refused.
    #[test]
    fn bearer_without_store_refuses() {
        let outcome = validate_auth_response("bearer", br#"{"token":"x"}"#, None);
        assert!(matches!(outcome, AuthOutcome::Refused(_)));
    }

    // No topology means the key is absent from the JSON, not present as null.
    #[test]
    fn hello_ack_omits_topology_field_when_caller_passes_none() {
        let bytes = build_hello_ack(1, "bearer", 0, None);
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(!s.contains("\"topology\""));
    }

    // The topology is carried as a string field and must round-trip through
    // the wire crate's decoder.
    #[test]
    fn hello_ack_embeds_topology_field_when_caller_passes_payload() {
        let topo = reddb_wire::topology::Topology {
            epoch: 17,
            primary: reddb_wire::topology::Endpoint {
                addr: "primary:5050".into(),
                region: "us-east-1".into(),
            },
            replicas: Vec::new(),
        };
        let bytes = build_hello_ack(1, "bearer", 0, Some(&topo));
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.contains("\"topology\""), "missing topology key in {s}");

        // Parse the ack back and pull the string-encoded topology out of it.
        let v: JsonValue = crate::serde_json::from_slice(&bytes).unwrap();
        let field = v
            .as_object()
            .and_then(|o| o.get("topology"))
            .and_then(|t| t.as_str())
            .expect("topology key must be present and a string");
        let decoded = reddb_wire::topology::decode_topology_from_hello_ack(field).expect("decode");
        assert_eq!(decoded.expect("v1 known"), topo);
    }
}