reddb_server/wire/redwire/
auth.rs1use crate::auth::store::AuthStore;
13use crate::auth::Role;
14use crate::serde_json::{self, Value as JsonValue};
15use reddb_wire::redwire::handshake::{
16 base64_std, base64_std_decode, build_scram_auth_ok_payload, parse_auth_response_bearer_token,
17};
18
19#[derive(Debug, Clone)]
21pub enum AuthOutcome {
22 Authenticated {
24 username: String,
25 role: Role,
26 tenant: Option<String>,
27 session_id: String,
28 },
29 Refused(String),
31}
32
/// Chooses the auth method for a connection from the methods the client offered.
///
/// The server's preference order decides, not the client's. When
/// `server_anon_ok` is true, `"anonymous"` is tried first. Otherwise the strong
/// methods are tried first and `"anonymous"` is never returned, even if the
/// client offered it.
///
/// Returns `None` when no offered method is acceptable.
pub fn pick_auth_method(client_methods: &[String], server_anon_ok: bool) -> Option<&'static str> {
    const ANON_FIRST: [&str; 4] = ["anonymous", "scram-sha-256", "oauth-jwt", "bearer"];
    const STRONG_FIRST: [&str; 4] = ["scram-sha-256", "oauth-jwt", "bearer", "anonymous"];

    let order: &[&'static str] = if server_anon_ok {
        &ANON_FIRST
    } else {
        &STRONG_FIRST
    };

    order.iter().copied().find(|candidate| {
        let offered = client_methods.iter().any(|m| m.as_str() == *candidate);
        // "anonymous" is only eligible when the server explicitly allows it.
        let permitted = server_anon_ok || *candidate != "anonymous";
        offered && permitted
    })
}
61
62pub fn validate_auth_response(
64 method: &str,
65 payload: &[u8],
66 auth_store: Option<&AuthStore>,
67) -> AuthOutcome {
68 match method {
69 "anonymous" => {
70 if let Some(store) = auth_store {
73 if store.is_enabled() {
74 return AuthOutcome::Refused(
75 "anonymous auth refused — server has auth enabled".into(),
76 );
77 }
78 }
79 AuthOutcome::Authenticated {
80 username: "anonymous".to_string(),
81 role: Role::Read,
82 tenant: None,
83 session_id: new_session_id(),
84 }
85 }
86 "bearer" => {
87 let token = parse_auth_response_bearer_token(payload).unwrap_or_default();
88 let Some(store) = auth_store else {
89 return AuthOutcome::Refused(
90 "bearer auth refused — server has no auth store configured".into(),
91 );
92 };
93 match store.validate_token_full(&token) {
94 Some((user_id, role)) => AuthOutcome::Authenticated {
95 username: user_id.username,
96 role,
97 tenant: user_id.tenant,
98 session_id: new_session_id(),
99 },
100 None => AuthOutcome::Refused("bearer token invalid".into()),
101 }
102 }
103 "scram-sha-256" => AuthOutcome::Refused(
104 "scram-sha-256 must be driven through perform_scram_handshake — \
105 the 1-RTT validate_auth_response path doesn't apply"
106 .to_string(),
107 ),
108 "oauth-jwt" => {
109 AuthOutcome::Refused(
115 "oauth-jwt requires RedWireConfig.oauth to be set. Pass an \
116 OAuthValidator with the issuer + JWKS configured."
117 .to_string(),
118 )
119 }
120 other => AuthOutcome::Refused(format!("auth method '{other}' is not supported in v2.1")),
121 }
122}
123
124pub fn build_auth_ok(
127 session_id: &str,
128 username: &str,
129 role: Role,
130 server_features: u32,
131) -> Vec<u8> {
132 let role_str = role.to_string();
133 reddb_wire::redwire::handshake::build_auth_ok_payload(
134 session_id,
135 username,
136 &role_str,
137 server_features,
138 )
139}
140
141pub fn build_scram_auth_ok(
145 session_id: &str,
146 username: &str,
147 role: Role,
148 server_features: u32,
149 server_signature: &[u8],
150) -> Vec<u8> {
151 let role = role.to_string();
152 build_scram_auth_ok_payload(
153 session_id,
154 username,
155 &role,
156 server_features,
157 server_signature,
158 )
159}
160
161pub fn new_server_nonce() -> String {
165 base64_std(&crate::auth::store::random_bytes(18))
166}
167
168pub(crate) fn new_session_id_for_scram() -> String {
169 new_session_id()
170}
171
172pub fn parse_jwt(token: &str) -> Result<crate::auth::oauth::DecodedJwt, String> {
177 let parts: Vec<&str> = token.split('.').collect();
178 if parts.len() != 3 {
179 return Err(format!(
180 "expected 3 dot-separated parts, got {}",
181 parts.len()
182 ));
183 }
184 let header_bytes =
185 base64_url_decode(parts[0]).ok_or_else(|| "header is not valid base64url".to_string())?;
186 let payload_bytes =
187 base64_url_decode(parts[1]).ok_or_else(|| "payload is not valid base64url".to_string())?;
188 let signature = base64_url_decode(parts[2])
189 .ok_or_else(|| "signature is not valid base64url".to_string())?;
190
191 let header_json: JsonValue =
192 serde_json::from_slice(&header_bytes).map_err(|e| format!("header JSON: {e}"))?;
193 let payload_json: JsonValue =
194 serde_json::from_slice(&payload_bytes).map_err(|e| format!("payload JSON: {e}"))?;
195
196 let header = jwt_header_from(&header_json)?;
197 let claims = jwt_claims_from(&payload_json);
198
199 let signing_input = format!("{}.{}", parts[0], parts[1]).into_bytes();
200
201 Ok(crate::auth::oauth::DecodedJwt {
202 header,
203 claims,
204 signing_input,
205 signature,
206 })
207}
208
209fn jwt_header_from(v: &JsonValue) -> Result<crate::auth::oauth::JwtHeader, String> {
210 let obj = v
211 .as_object()
212 .ok_or_else(|| "JWT header must be a JSON object".to_string())?;
213 let alg = obj
214 .get("alg")
215 .and_then(|x| x.as_str())
216 .ok_or_else(|| "JWT header missing 'alg'".to_string())?
217 .to_string();
218 let kid = obj.get("kid").and_then(|x| x.as_str()).map(String::from);
219 Ok(crate::auth::oauth::JwtHeader { alg, kid })
220}
221
222fn jwt_claims_from(v: &JsonValue) -> crate::auth::oauth::JwtClaims {
223 let obj = v.as_object().cloned().unwrap_or_default();
224 let mut claims = crate::auth::oauth::JwtClaims::default();
225 if let Some(s) = obj.get("iss").and_then(|x| x.as_str()) {
226 claims.iss = Some(s.to_string());
227 }
228 if let Some(s) = obj.get("sub").and_then(|x| x.as_str()) {
229 claims.sub = Some(s.to_string());
230 }
231 if let Some(s) = obj.get("aud").and_then(|x| x.as_str()) {
232 claims.aud = vec![s.to_string()];
233 } else if let Some(arr) = obj.get("aud").and_then(|x| x.as_array()) {
234 claims.aud = arr
235 .iter()
236 .filter_map(|v| v.as_str().map(String::from))
237 .collect();
238 }
239 if let Some(n) = obj.get("exp").and_then(|x| x.as_f64()) {
240 claims.exp = Some(n as i64);
241 }
242 if let Some(n) = obj.get("nbf").and_then(|x| x.as_f64()) {
243 claims.nbf = Some(n as i64);
244 }
245 if let Some(n) = obj.get("iat").and_then(|x| x.as_f64()) {
246 claims.iat = Some(n as i64);
247 }
248 for (k, v) in obj.iter() {
249 if matches!(k.as_str(), "iss" | "sub" | "aud" | "exp" | "nbf" | "iat") {
250 continue;
251 }
252 if let Some(s) = v.as_str() {
253 claims.extra.insert(k.clone(), s.to_string());
254 }
255 }
256 claims
257}
258
259pub fn validate_oauth_jwt(
262 validator: &crate::auth::oauth::OAuthValidator,
263 raw_token: &str,
264) -> Result<(String, Role), String> {
265 validate_oauth_jwt_full(validator, raw_token).map(|(_tenant, username, role)| (username, role))
266}
267
268pub fn validate_oauth_jwt_full(
272 validator: &crate::auth::oauth::OAuthValidator,
273 raw_token: &str,
274) -> Result<(Option<String>, String, Role), String> {
275 let token = parse_jwt(raw_token).map_err(|e| format!("decode JWT: {e}"))?;
276 let now = std::time::SystemTime::now()
277 .duration_since(std::time::UNIX_EPOCH)
278 .map(|d| d.as_secs() as i64)
279 .unwrap_or(0);
280 let identity = validator
286 .validate(&token, now, |sub| {
287 Some(crate::auth::User {
288 username: sub.to_string(),
289 tenant_id: token.claims.extra.get("tenant").cloned(),
290 password_hash: String::new(),
291 scram_verifier: None,
292 role: token
293 .claims
294 .extra
295 .get("role")
296 .and_then(|s| Role::from_str(s))
297 .unwrap_or(Role::Read),
298 api_keys: Vec::new(),
299 created_at: 0,
300 updated_at: 0,
301 enabled: true,
302 })
303 })
304 .map_err(|e| format!("{e}"))?;
305 Ok((identity.tenant, identity.username, identity.role))
306}
307
308fn base64_url_decode(input: &str) -> Option<Vec<u8>> {
309 let mut s = String::with_capacity(input.len() + 4);
311 for ch in input.chars() {
312 match ch {
313 '-' => s.push('+'),
314 '_' => s.push('/'),
315 _ => s.push(ch),
316 }
317 }
318 while !s.len().is_multiple_of(4) {
319 s.push('=');
320 }
321 base64_std_decode(&s)
322}
323
324fn new_session_id() -> String {
328 let now_us = std::time::SystemTime::now()
329 .duration_since(std::time::UNIX_EPOCH)
330 .map(|d| d.as_micros())
331 .unwrap_or(0);
332 let rand = crate::utils::now_unix_nanos() & 0xFFFF_FFFF;
333 format!("rwsess-{now_us}-{rand:08x}")
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use reddb_wire::redwire::handshake::{build_hello_ack, Hello};

    /// A well-formed Hello payload decodes every field it carries.
    #[test]
    fn hello_round_trip() {
        let payload = br#"{"versions":[1],"auth_methods":["bearer","anonymous"],"features":3,"client_name":"reddb-rs/0.1"}"#;
        let h = Hello::from_payload(payload).unwrap();
        assert_eq!(h.versions, vec![1]);
        assert_eq!(h.auth_methods, vec!["bearer", "anonymous"]);
        assert_eq!(h.features, 3);
        assert_eq!(h.client_name.as_deref(), Some("reddb-rs/0.1"));
    }

    /// A Hello that offers no auth methods is rejected at decode time.
    #[test]
    fn hello_rejects_empty_methods() {
        let payload = br#"{"versions":[1],"auth_methods":[]}"#;
        assert!(Hello::from_payload(payload).is_err());
    }

    /// With anonymous allowed, "anonymous" is first in the server's priority order.
    #[test]
    fn pick_auth_prefers_anonymous_when_server_has_no_auth_store() {
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, true), Some("anonymous"));
    }

    /// With anonymous disallowed, the next method both sides support wins.
    #[test]
    fn pick_auth_picks_bearer_when_anonymous_blocked() {
        let pref = vec!["anonymous".to_string(), "bearer".to_string()];
        assert_eq!(pick_auth_method(&pref, false), Some("bearer"));
    }

    /// A client that only offers anonymous gets nothing from a locked-down server.
    #[test]
    fn pick_auth_skips_anonymous_when_server_blocks_it() {
        let pref = vec!["anonymous".to_string()];
        assert_eq!(pick_auth_method(&pref, false), None);
    }

    /// Unknown method names never match the server's list.
    #[test]
    fn pick_auth_returns_none_when_nothing_overlaps() {
        let pref = vec!["kerberos".to_string(), "future-method".to_string()];
        assert_eq!(pick_auth_method(&pref, true), None);
    }

    /// Anonymous auth succeeds when no auth store is configured.
    // NOTE(review): only the `None` store case is exercised; the "store
    // present but disabled" and "store enabled" cases are not covered here.
    #[test]
    fn anonymous_validates_only_when_store_disabled() {
        let outcome = validate_auth_response("anonymous", &[], None);
        assert!(matches!(outcome, AuthOutcome::Authenticated { .. }));
    }

    /// Bearer auth is refused outright when there is no store to check tokens against.
    #[test]
    fn bearer_without_store_refuses() {
        let outcome = validate_auth_response("bearer", br#"{"token":"x"}"#, None);
        assert!(matches!(outcome, AuthOutcome::Refused(_)));
    }

    /// `build_hello_ack` leaves out the `topology` key when no topology is given.
    #[test]
    fn hello_ack_omits_topology_field_when_caller_passes_none() {
        let bytes = build_hello_ack(1, "bearer", 0, None);
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(!s.contains("\"topology\""));
    }

    /// A supplied topology is embedded as a string field that decodes back to
    /// an equal `Topology`.
    #[test]
    fn hello_ack_embeds_topology_field_when_caller_passes_payload() {
        let topo = reddb_wire::topology::Topology {
            epoch: 17,
            primary: reddb_wire::topology::Endpoint {
                addr: "primary:5050".into(),
                region: "us-east-1".into(),
            },
            replicas: Vec::new(),
        };
        let bytes = build_hello_ack(1, "bearer", 0, Some(&topo));
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.contains("\"topology\""), "missing topology key in {s}");

        // Pull the `topology` value out as a string and round-trip it through the decoder.
        let v: JsonValue = crate::serde_json::from_slice(&bytes).unwrap();
        let field = v
            .as_object()
            .and_then(|o| o.get("topology"))
            .and_then(|t| t.as_str())
            .expect("topology key must be present and a string");
        let decoded = reddb_wire::topology::decode_topology_from_hello_ack(field).expect("decode");
        assert_eq!(decoded.expect("v1 known"), topo);
    }
}