1use crate::auth::store::AuthStore;
13use crate::auth::Role;
14use crate::serde_json::{self, Value as JsonValue};
15
/// Every auth method name this module knows about. Note that `validate_auth_response`
/// only completes `anonymous` and `bearer` itself; the other two are refused there.
pub const SUPPORTED_METHODS: &[&str] = &[
    "bearer",
    "anonymous",
    "scram-sha-256",
    "oauth-jwt",
];
27
/// Result of checking a client's auth response on the 1-RTT path
/// (`validate_auth_response`).
#[derive(Debug, Clone)]
pub enum AuthOutcome {
    /// The client was accepted under the identity below.
    Authenticated {
        /// Accepted user name (`"anonymous"` for the anonymous method).
        username: String,
        /// Role granted to this session.
        role: Role,
        /// Tenant the identity belongs to, if the auth source supplied one.
        tenant: Option<String>,
        /// Freshly minted session identifier.
        session_id: String,
    },
    /// The client was refused; the string is a human-readable reason.
    Refused(String),
}
41
42#[derive(Debug, Clone)]
44pub struct Hello {
45 pub versions: Vec<u8>,
46 pub auth_methods: Vec<String>,
47 pub features: u32,
48 pub client_name: Option<String>,
49}
50
51impl Hello {
52 pub fn from_payload(bytes: &[u8]) -> Result<Self, String> {
53 let v: JsonValue =
54 serde_json::from_slice(bytes).map_err(|e| format!("Hello: invalid JSON: {e}"))?;
55 let obj = match v {
56 JsonValue::Object(o) => o,
57 _ => return Err("Hello: payload must be a JSON object".into()),
58 };
59 let versions: Vec<u8> = obj
60 .get("versions")
61 .and_then(|v| v.as_array())
62 .map(|arr| {
63 arr.iter()
64 .filter_map(|n| n.as_f64().map(|f| f as u8))
65 .collect()
66 })
67 .unwrap_or_default();
68 let auth_methods: Vec<String> = obj
69 .get("auth_methods")
70 .and_then(|v| v.as_array())
71 .map(|arr| {
72 arr.iter()
73 .filter_map(|s| s.as_str().map(String::from))
74 .collect()
75 })
76 .unwrap_or_default();
77 let features = obj
78 .get("features")
79 .and_then(|v| v.as_f64())
80 .map(|f| f as u32)
81 .unwrap_or(0);
82 let client_name = obj
83 .get("client_name")
84 .and_then(|v| v.as_str())
85 .map(String::from);
86 if versions.is_empty() {
87 return Err("Hello: versions[] is empty".into());
88 }
89 if auth_methods.is_empty() {
90 return Err("Hello: auth_methods[] is empty".into());
91 }
92 Ok(Self {
93 versions,
94 auth_methods,
95 features,
96 client_name,
97 })
98 }
99}
100
101pub fn build_hello_ack(
118 chosen_version: u8,
119 chosen_auth: &str,
120 server_features: u32,
121 topology: Option<&reddb_wire::topology::Topology>,
122) -> Vec<u8> {
123 use crate::json_field::SerializedJsonField;
124 let mut obj = crate::serde_json::Map::new();
137 obj.insert(
138 "version".to_string(),
139 JsonValue::Number(chosen_version as f64),
140 );
141 obj.insert(
142 "auth".to_string(),
143 SerializedJsonField::tainted(chosen_auth),
144 );
145 obj.insert(
146 "features".to_string(),
147 JsonValue::Number(server_features as f64),
148 );
149 let server_field = format!("reddb/{}", env!("CARGO_PKG_VERSION"));
150 obj.insert(
151 "server".to_string(),
152 SerializedJsonField::tainted(&server_field),
153 );
154 if let Some(topo) = topology {
155 obj.insert(
156 "topology".to_string(),
157 SerializedJsonField::tainted(&reddb_wire::topology::encode_topology_for_hello_ack(
158 topo,
159 )),
160 );
161 }
162 JsonValue::Object(obj).to_string_compact().into_bytes()
163}
164
/// Chooses the auth method for this connection.
///
/// Returns the highest-ranked method that the client also offers, or `None`
/// when nothing overlaps. When `server_anon_ok` is true, `anonymous` ranks
/// first. Otherwise credentialed methods rank first and `anonymous` is never
/// returned.
pub fn pick_auth_method(client_methods: &[String], server_anon_ok: bool) -> Option<&'static str> {
    const ANONYMOUS_FIRST: [&str; 4] = ["anonymous", "scram-sha-256", "oauth-jwt", "bearer"];
    const CREDENTIALS_FIRST: [&str; 4] = ["scram-sha-256", "oauth-jwt", "bearer", "anonymous"];

    let ranking = if server_anon_ok {
        ANONYMOUS_FIRST
    } else {
        CREDENTIALS_FIRST
    };
    ranking.into_iter().find(|&candidate| {
        let offered = client_methods.iter().any(|m| m == candidate);
        let permitted = server_anon_ok || candidate != "anonymous";
        offered && permitted
    })
}
193
194pub fn validate_auth_response(
196 method: &str,
197 payload: &[u8],
198 auth_store: Option<&AuthStore>,
199) -> AuthOutcome {
200 match method {
201 "anonymous" => {
202 if let Some(store) = auth_store {
205 if store.is_enabled() {
206 return AuthOutcome::Refused(
207 "anonymous auth refused — server has auth enabled".into(),
208 );
209 }
210 }
211 AuthOutcome::Authenticated {
212 username: "anonymous".to_string(),
213 role: Role::Read,
214 tenant: None,
215 session_id: new_session_id(),
216 }
217 }
218 "bearer" => {
219 let token = parse_bearer_response(payload).unwrap_or_default();
220 let Some(store) = auth_store else {
221 return AuthOutcome::Refused(
222 "bearer auth refused — server has no auth store configured".into(),
223 );
224 };
225 match store.validate_token_full(&token) {
226 Some((user_id, role)) => AuthOutcome::Authenticated {
227 username: user_id.username,
228 role,
229 tenant: user_id.tenant,
230 session_id: new_session_id(),
231 },
232 None => AuthOutcome::Refused("bearer token invalid".into()),
233 }
234 }
235 "scram-sha-256" => AuthOutcome::Refused(
236 "scram-sha-256 must be driven through perform_scram_handshake — \
237 the 1-RTT validate_auth_response path doesn't apply"
238 .to_string(),
239 ),
240 "oauth-jwt" => {
241 AuthOutcome::Refused(
247 "oauth-jwt requires RedWireConfig.oauth to be set. Pass an \
248 OAuthValidator with the issuer + JWKS configured."
249 .to_string(),
250 )
251 }
252 other => AuthOutcome::Refused(format!("auth method '{other}' is not supported in v2.1")),
253 }
254}
255
256fn parse_bearer_response(payload: &[u8]) -> Option<String> {
257 let v: JsonValue = serde_json::from_slice(payload).ok()?;
258 let token = v.as_object()?.get("token")?.as_str()?;
259 Some(token.to_string())
260}
261
262pub fn build_auth_ok(
265 session_id: &str,
266 username: &str,
267 role: Role,
268 server_features: u32,
269) -> Vec<u8> {
270 use crate::json_field::SerializedJsonField;
271 let mut obj = crate::serde_json::Map::new();
275 obj.insert(
276 "session_id".to_string(),
277 SerializedJsonField::tainted(session_id),
278 );
279 obj.insert(
280 "username".to_string(),
281 SerializedJsonField::tainted(username),
282 );
283 let role_str = role.to_string();
284 obj.insert("role".to_string(), SerializedJsonField::tainted(&role_str));
285 obj.insert(
286 "features".to_string(),
287 JsonValue::Number(server_features as f64),
288 );
289 JsonValue::Object(obj).to_string_compact().into_bytes()
290}
291
292pub fn build_auth_fail(reason: &str) -> Vec<u8> {
293 use crate::json_field::SerializedJsonField;
294 let mut obj = crate::serde_json::Map::new();
298 obj.insert("reason".to_string(), SerializedJsonField::tainted(reason));
299 JsonValue::Object(obj).to_string_compact().into_bytes()
300}
301
/// Parses a SCRAM-SHA-256 client-first message (RFC 5802 §5.1).
///
/// Only the `n,,` GS2 header is accepted: no channel binding and no authzid.
///
/// # Returns
/// `(username, client_nonce, client_first_bare)`.
/// - `username` has its saslname escapes (`=2C` → `,`, `=3D` → `=`) decoded.
/// - `client_first_bare` is the raw text after the GS2 header, exactly as
///   received. It feeds `AuthMessage`, so it must not be normalized.
///
/// # Errors
/// - Input is not UTF-8, or the GS2 header is not `n,,`.
/// - The reserved mandatory-extension attribute `m=` is present (RFC 5802
///   requires the server to fail).
/// - The `n=` username is missing, or contains a `=` that is not part of a
///   valid escape.
/// - The `r=` nonce is missing or empty.
pub fn parse_scram_client_first(payload: &[u8]) -> Result<(String, String, String), String> {
    let s = std::str::from_utf8(payload).map_err(|_| "client-first not UTF-8".to_string())?;
    let bare = s
        .strip_prefix("n,,")
        .ok_or_else(|| "client-first must start with 'n,,' (no channel binding)".to_string())?;
    let mut user = None;
    let mut nonce = None;
    for part in bare.split(',') {
        if part.starts_with("m=") {
            return Err("client-first uses unsupported mandatory extension 'm='".to_string());
        } else if let Some(v) = part.strip_prefix("n=") {
            user = Some(decode_saslname(v)?);
        } else if let Some(v) = part.strip_prefix("r=") {
            nonce = Some(v.to_string());
        }
    }
    let user = user.ok_or_else(|| "missing n=<user>".to_string())?;
    let nonce = nonce
        .filter(|n| !n.is_empty())
        .ok_or_else(|| "missing r=<nonce>".to_string())?;
    Ok((user, nonce, bare.to_string()))
}

/// Decodes an RFC 5802 `saslname`: `=2C` becomes `,` and `=3D` becomes `=`.
/// Any other `=` is an error, as the RFC requires.
fn decode_saslname(raw: &str) -> Result<String, String> {
    let mut out = String::with_capacity(raw.len());
    let mut rest = raw;
    while let Some(idx) = rest.find('=') {
        out.push_str(&rest[..idx]);
        let tail = &rest[idx..];
        if let Some(after) = tail.strip_prefix("=2C") {
            out.push(',');
            rest = after;
        } else if let Some(after) = tail.strip_prefix("=3D") {
            out.push('=');
            rest = after;
        } else {
            return Err(
                "username contains an invalid '=' escape (expected =2C or =3D)".to_string(),
            );
        }
    }
    out.push_str(rest);
    Ok(out)
}
325
326pub fn build_scram_server_first(
329 client_nonce: &str,
330 server_nonce: &str,
331 salt: &[u8],
332 iter: u32,
333) -> String {
334 format!(
335 "r={client_nonce}{server_nonce},s={},i={iter}",
336 base64_std(salt)
337 )
338}
339
340pub fn parse_scram_client_final(payload: &[u8]) -> Result<(String, Vec<u8>, String), String> {
343 let s = std::str::from_utf8(payload).map_err(|_| "client-final not UTF-8".to_string())?;
344 let mut channel_binding = None;
345 let mut nonce = None;
346 let mut proof_b64 = None;
347 for part in s.split(',') {
348 if let Some(v) = part.strip_prefix("c=") {
349 channel_binding = Some(v.to_string());
350 } else if let Some(v) = part.strip_prefix("r=") {
351 nonce = Some(v.to_string());
352 } else if let Some(v) = part.strip_prefix("p=") {
353 proof_b64 = Some(v.to_string());
354 }
355 }
356 let channel_binding =
357 channel_binding.ok_or_else(|| "missing c=<channel-binding>".to_string())?;
358 let nonce = nonce.ok_or_else(|| "missing r=<nonce>".to_string())?;
359 let proof_b64 = proof_b64.ok_or_else(|| "missing p=<proof>".to_string())?;
360 let proof = base64_std_decode(&proof_b64)
361 .ok_or_else(|| "client proof is not valid base64".to_string())?;
362 if channel_binding != "biws" {
364 return Err(format!(
365 "channel binding must be 'biws' (n,,), got '{channel_binding}'"
366 ));
367 }
368 let no_proof = format!("c={channel_binding},r={nonce}");
369 Ok((nonce, proof, no_proof))
370}
371
372pub fn build_scram_auth_ok(
376 session_id: &str,
377 username: &str,
378 role: Role,
379 server_features: u32,
380 server_signature: &[u8],
381) -> Vec<u8> {
382 let mut obj = crate::serde_json::Map::new();
383 obj.insert(
384 "session_id".to_string(),
385 JsonValue::String(session_id.to_string()),
386 );
387 obj.insert(
388 "username".to_string(),
389 JsonValue::String(username.to_string()),
390 );
391 obj.insert("role".to_string(), JsonValue::String(role.to_string()));
392 obj.insert(
393 "features".to_string(),
394 JsonValue::Number(server_features as f64),
395 );
396 obj.insert(
397 "v".to_string(),
398 JsonValue::String(base64_std(server_signature)),
399 );
400 JsonValue::Object(obj).to_string_compact().into_bytes()
401}
402
403pub fn new_server_nonce() -> String {
407 base64_std(&crate::auth::store::random_bytes(18))
408}
409
410pub(crate) fn new_session_id_for_scram() -> String {
411 new_session_id()
412}
413
/// Standard base64 alphabet (RFC 4648 §4), indexed by 6-bit value.
const B64_ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as padded standard base64 (`+`/`/` alphabet, `=` padding).
pub fn base64_std(input: &[u8]) -> String {
    // Character for the 6-bit group sitting `shift` bits up in a 24-bit word.
    let sextet = |word: u32, shift: u32| B64_ALPHA[((word >> shift) & 0x3F) as usize] as char;
    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Pack up to three bytes big-endian into a 24-bit word; absent
        // trailing bytes stay zero.
        let word = group
            .iter()
            .zip([16u32, 8, 0])
            .fold(0u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));
        // k bytes carry 8k bits, which fill k + 1 sextets; the remaining
        // slots of the 4-character quantum are '=' padding.
        let meaningful = group.len() + 1;
        for (slot, shift) in [18u32, 12, 6, 0].into_iter().enumerate() {
            encoded.push(if slot < meaningful {
                sextet(word, shift)
            } else {
                '='
            });
        }
    }
    encoded
}
452
/// Decodes standard base64 (RFC 4648 §4: `A–Z a–z 0–9 + /`) into bytes.
///
/// Trailing `=` padding is optional. Returns `None` when:
/// - a character is outside the alphabet (this includes `=` anywhere but the end);
/// - there are more than two trailing `=`;
/// - the unpadded length is `1 (mod 4)`. A lone trailing sextet holds only
///   6 bits and cannot encode a byte, so the input is truncated or corrupt.
///
/// Non-zero leftover bits in the final sextet are ignored, not rejected.
pub fn base64_std_decode(input: &str) -> Option<Vec<u8>> {
    let trimmed = input.trim_end_matches('=');
    let padding = input.len() - trimmed.len();
    if padding > 2 || trimmed.len() % 4 == 1 {
        return None;
    }
    let mut out = Vec::with_capacity(trimmed.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u8;
    for ch in trimmed.bytes() {
        let v: u32 = match ch {
            b'A'..=b'Z' => (ch - b'A') as u32,
            b'a'..=b'z' => (ch - b'a' + 26) as u32,
            b'0'..=b'9' => (ch - b'0' + 52) as u32,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        // Only the low `bits + 6` bits of `buf` are still pending. Anything
        // shifted off the top was already emitted.
        buf = (buf << 6) | v;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xFF) as u8);
        }
    }
    Some(out)
}
476
477pub fn parse_jwt(token: &str) -> Result<crate::auth::oauth::DecodedJwt, String> {
482 let parts: Vec<&str> = token.split('.').collect();
483 if parts.len() != 3 {
484 return Err(format!(
485 "expected 3 dot-separated parts, got {}",
486 parts.len()
487 ));
488 }
489 let header_bytes =
490 base64_url_decode(parts[0]).ok_or_else(|| "header is not valid base64url".to_string())?;
491 let payload_bytes =
492 base64_url_decode(parts[1]).ok_or_else(|| "payload is not valid base64url".to_string())?;
493 let signature = base64_url_decode(parts[2])
494 .ok_or_else(|| "signature is not valid base64url".to_string())?;
495
496 let header_json: JsonValue =
497 serde_json::from_slice(&header_bytes).map_err(|e| format!("header JSON: {e}"))?;
498 let payload_json: JsonValue =
499 serde_json::from_slice(&payload_bytes).map_err(|e| format!("payload JSON: {e}"))?;
500
501 let header = jwt_header_from(&header_json)?;
502 let claims = jwt_claims_from(&payload_json);
503
504 let signing_input = format!("{}.{}", parts[0], parts[1]).into_bytes();
505
506 Ok(crate::auth::oauth::DecodedJwt {
507 header,
508 claims,
509 signing_input,
510 signature,
511 })
512}
513
514fn jwt_header_from(v: &JsonValue) -> Result<crate::auth::oauth::JwtHeader, String> {
515 let obj = v
516 .as_object()
517 .ok_or_else(|| "JWT header must be a JSON object".to_string())?;
518 let alg = obj
519 .get("alg")
520 .and_then(|x| x.as_str())
521 .ok_or_else(|| "JWT header missing 'alg'".to_string())?
522 .to_string();
523 let kid = obj.get("kid").and_then(|x| x.as_str()).map(String::from);
524 Ok(crate::auth::oauth::JwtHeader { alg, kid })
525}
526
527fn jwt_claims_from(v: &JsonValue) -> crate::auth::oauth::JwtClaims {
528 let obj = v.as_object().cloned().unwrap_or_default();
529 let mut claims = crate::auth::oauth::JwtClaims::default();
530 if let Some(s) = obj.get("iss").and_then(|x| x.as_str()) {
531 claims.iss = Some(s.to_string());
532 }
533 if let Some(s) = obj.get("sub").and_then(|x| x.as_str()) {
534 claims.sub = Some(s.to_string());
535 }
536 if let Some(s) = obj.get("aud").and_then(|x| x.as_str()) {
537 claims.aud = vec![s.to_string()];
538 } else if let Some(arr) = obj.get("aud").and_then(|x| x.as_array()) {
539 claims.aud = arr
540 .iter()
541 .filter_map(|v| v.as_str().map(String::from))
542 .collect();
543 }
544 if let Some(n) = obj.get("exp").and_then(|x| x.as_f64()) {
545 claims.exp = Some(n as i64);
546 }
547 if let Some(n) = obj.get("nbf").and_then(|x| x.as_f64()) {
548 claims.nbf = Some(n as i64);
549 }
550 if let Some(n) = obj.get("iat").and_then(|x| x.as_f64()) {
551 claims.iat = Some(n as i64);
552 }
553 for (k, v) in obj.iter() {
554 if matches!(k.as_str(), "iss" | "sub" | "aud" | "exp" | "nbf" | "iat") {
555 continue;
556 }
557 if let Some(s) = v.as_str() {
558 claims.extra.insert(k.clone(), s.to_string());
559 }
560 }
561 claims
562}
563
564pub fn validate_oauth_jwt(
567 validator: &crate::auth::oauth::OAuthValidator,
568 raw_token: &str,
569) -> Result<(String, Role), String> {
570 validate_oauth_jwt_full(validator, raw_token).map(|(_tenant, username, role)| (username, role))
571}
572
573pub fn validate_oauth_jwt_full(
577 validator: &crate::auth::oauth::OAuthValidator,
578 raw_token: &str,
579) -> Result<(Option<String>, String, Role), String> {
580 let token = parse_jwt(raw_token).map_err(|e| format!("decode JWT: {e}"))?;
581 let now = std::time::SystemTime::now()
582 .duration_since(std::time::UNIX_EPOCH)
583 .map(|d| d.as_secs() as i64)
584 .unwrap_or(0);
585 let identity = validator
591 .validate(&token, now, |sub| {
592 Some(crate::auth::User {
593 username: sub.to_string(),
594 tenant_id: token.claims.extra.get("tenant").cloned(),
595 password_hash: String::new(),
596 scram_verifier: None,
597 role: token
598 .claims
599 .extra
600 .get("role")
601 .and_then(|s| Role::from_str(s))
602 .unwrap_or(Role::Read),
603 api_keys: Vec::new(),
604 created_at: 0,
605 updated_at: 0,
606 enabled: true,
607 system_owned: false,
608 })
609 })
610 .map_err(|e| format!("{e}"))?;
611 Ok((identity.tenant, identity.username, identity.role))
612}
613
614fn base64_url_decode(input: &str) -> Option<Vec<u8>> {
615 let mut s = String::with_capacity(input.len() + 4);
617 for ch in input.chars() {
618 match ch {
619 '-' => s.push('+'),
620 '_' => s.push('/'),
621 _ => s.push(ch),
622 }
623 }
624 while !s.len().is_multiple_of(4) {
625 s.push('=');
626 }
627 base64_std_decode(&s)
628}
629
630fn new_session_id() -> String {
634 let now_us = std::time::SystemTime::now()
635 .duration_since(std::time::UNIX_EPOCH)
636 .map(|d| d.as_micros())
637 .unwrap_or(0);
638 let rand = crate::utils::now_unix_nanos() & 0xFFFF_FFFF;
639 format!("rwsess-{now_us}-{rand:08x}")
640}
641
// Unit tests for Hello parsing, auth-method negotiation, 1-RTT auth
// validation and HelloAck topology embedding.
#[cfg(test)]
mod tests {
    use super::*;

    // A fully populated Hello parses every field.
    #[test]
    fn hello_round_trip() {
        let payload = br#"{"versions":[1],"auth_methods":["bearer","anonymous"],"features":3,"client_name":"reddb-rs/0.1"}"#;
        let h = Hello::from_payload(payload).unwrap();
        assert_eq!(h.versions, vec![1]);
        assert_eq!(h.auth_methods, vec!["bearer", "anonymous"]);
        assert_eq!(h.features, 3);
        assert_eq!(h.client_name.as_deref(), Some("reddb-rs/0.1"));
    }

    // An empty auth_methods[] is a parse error.
    #[test]
    fn hello_rejects_empty_methods() {
        let payload = br#"{"versions":[1],"auth_methods":[]}"#;
        assert!(Hello::from_payload(payload).is_err());
    }

    // With anonymous allowed (`true`), anonymous outranks bearer.
    #[test]
    fn pick_auth_prefers_anonymous_when_server_has_no_auth_store() {
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, true), Some("anonymous"));
    }

    // With anonymous disallowed (`false`), the credentialed method wins.
    #[test]
    fn pick_auth_picks_bearer_when_anonymous_blocked() {
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, false), Some("bearer"));
    }

    // Anonymous alone, while disallowed, yields no method.
    #[test]
    fn pick_auth_skips_anonymous_when_server_blocks_it() {
        let pref = vec!["anonymous".to_string()];
        assert_eq!(pick_auth_method(&pref, false), None);
    }

    // Methods the server does not know about never match.
    #[test]
    fn pick_auth_returns_none_when_nothing_overlaps() {
        let pref = vec!["kerberos".to_string(), "future-method".to_string()];
        assert_eq!(pick_auth_method(&pref, true), None);
    }

    // Exercises only the "no AuthStore at all" (`None`) case, despite the name.
    #[test]
    fn anonymous_validates_only_when_store_disabled() {
        let outcome = validate_auth_response("anonymous", &[], None);
        assert!(matches!(outcome, AuthOutcome::Authenticated { .. }));
    }

    // Bearer auth needs a store; without one it is refused even with a token.
    #[test]
    fn bearer_without_store_refuses() {
        let outcome = validate_auth_response("bearer", br#"{"token":"x"}"#, None);
        assert!(matches!(outcome, AuthOutcome::Refused(_)));
    }

    // No topology argument means no "topology" key in the ack.
    #[test]
    fn hello_ack_omits_topology_field_when_caller_passes_none() {
        let bytes = build_hello_ack(1, "bearer", 0, None);
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(!s.contains("\"topology\""));
    }

    // A supplied topology is embedded as a string field and decodes back to
    // an equal value.
    #[test]
    fn hello_ack_embeds_topology_field_when_caller_passes_payload() {
        let topo = reddb_wire::topology::Topology {
            epoch: 17,
            primary: reddb_wire::topology::Endpoint {
                addr: "primary:5050".into(),
                region: "us-east-1".into(),
            },
            replicas: Vec::new(),
        };
        let bytes = build_hello_ack(1, "bearer", 0, Some(&topo));
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.contains("\"topology\""), "missing topology key in {s}");

        let v: JsonValue = crate::serde_json::from_slice(&bytes).unwrap();
        let field = v
            .as_object()
            .and_then(|o| o.get("topology"))
            .and_then(|t| t.as_str())
            .expect("topology key must be present and a string");
        // The decoder appears to return Result<Option<Topology>>; the inner
        // Option presumably is None for unknown encodings — confirm.
        let decoded = reddb_wire::topology::decode_topology_from_hello_ack(field).expect("decode");
        assert_eq!(decoded.expect("v1 known"), topo);
    }
}