blueprint_auth/oauth/
mod.rs

1use crate::db::{RocksDb, cf};
2use crate::types::ServiceId;
3use serde::{Deserialize, Serialize};
4
5pub mod token;
6
7pub fn rate_limit_check(
8    db: &RocksDb,
9    headers: &axum::http::HeaderMap,
10    service_id: ServiceId,
11    window_secs: u64,
12    max_requests: u64,
13) -> Result<(), String> {
14    use std::time::{SystemTime, UNIX_EPOCH};
15    let now = SystemTime::now()
16        .duration_since(UNIX_EPOCH)
17        .unwrap_or_default()
18        .as_secs();
19    let window = now / window_secs;
20    let ip = headers
21        .get(axum::http::header::FORWARDED)
22        .and_then(|h| h.to_str().ok())
23        .unwrap_or("unknown");
24    let key = format!("rl:{service_id}:{ip}");
25
26    let cf = db
27        .cf_handle(cf::OAUTH_RL_CF)
28        .ok_or_else(|| "rl store unavailable".to_string())?;
29
30    let bucket_key = format!("{key}:{window}");
31
32    let current = db
33        .get_cf(&cf, bucket_key.as_bytes())
34        .map_err(|_| "rl read error".to_string())?
35        .and_then(|b| {
36            String::from_utf8(b)
37                .ok()
38                .and_then(|s| s.parse::<u64>().ok())
39        })
40        .unwrap_or(0);
41
42    if current >= max_requests {
43        return Err("rate limit exceeded".into());
44    }
45
46    let new_val = (current + 1).to_string();
47    db.put_cf(&cf, bucket_key.as_bytes(), new_val.as_bytes())
48        .map_err(|_| "rl write error".to_string())?;
49    Ok(())
50}
51
/// Per-service OAuth policy.
///
/// Persisted as JSON by [`ServiceOAuthPolicy::save`] and read back by
/// [`ServiceOAuthPolicy::load`]. No field carries `#[serde(default)]`, so every non-`Option`
/// field must be present in the stored JSON for it to deserialize.
///
/// Beware the derived [`Default`]: it sets both TTL limits to `0`. With
/// `max_assertion_ttl_secs == 0`, assertion verification rejects every assertion whose `exp`
/// is later than its `iat`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ServiceOAuthPolicy {
    /// Allowed assertion issuers (`iss`). Assertion verification refuses to run when this is
    /// empty.
    pub allowed_issuers: Vec<String>,
    /// Required audiences (`aud`); an assertion must name one of them. If empty, assertion
    /// verification in this module does not validate `aud` at all.
    /// NOTE(review): this was previously described as "defaults to proxy origin policy"; that
    /// fallback is not implemented in this module — confirm where it is enforced.
    pub required_audiences: Vec<String>,
    /// Allowed public keys in PEM (RSA/EC) used to verify assertions for this service.
    /// At least one key must validate the signature; an empty list means verification is
    /// not configured.
    pub public_keys_pem: Vec<String>,
    /// Allowed scopes; if `None`, scopes are not permitted and will be dropped.
    pub allowed_scopes: Option<Vec<String>>,
    /// Whether DPoP is required for access token use at the proxy.
    pub require_dpop: bool,
    /// Maximum access token TTL in seconds.
    pub max_access_token_ttl_secs: u64,
    /// Maximum assertion lifetime in seconds, measured as `exp - iat`.
    pub max_assertion_ttl_secs: u64,
}
71
72impl ServiceOAuthPolicy {
73    pub fn load(service_id: ServiceId, db: &RocksDb) -> Result<Option<Self>, crate::Error> {
74        let cf =
75            db.cf_handle(cf::SERVICES_OAUTH_POLICY_CF)
76                .ok_or(crate::Error::UnknownColumnFamily(
77                    cf::SERVICES_OAUTH_POLICY_CF,
78                ))?;
79        if let Some(bytes) = db.get_cf(&cf, service_id.to_be_bytes())? {
80            let policy: Self =
81                serde_json::from_slice(&bytes).map_err(|_| crate::Error::UnknownKeyType)?;
82            Ok(Some(policy))
83        } else {
84            Ok(None)
85        }
86    }
87
88    pub fn save(&self, service_id: ServiceId, db: &RocksDb) -> Result<(), crate::Error> {
89        let cf =
90            db.cf_handle(cf::SERVICES_OAUTH_POLICY_CF)
91                .ok_or(crate::Error::UnknownColumnFamily(
92                    cf::SERVICES_OAUTH_POLICY_CF,
93                ))?;
94        let bytes = serde_json::to_vec(self).map_err(|_| crate::Error::UnknownKeyType)?;
95        db.put_cf(&cf, service_id.to_be_bytes(), bytes)?;
96        Ok(())
97    }
98}
99
/// OAuth assertion claims (subset) deserialized from a JWT bearer assertion.
///
/// NOTE(review): `aud` is modeled as a single string. RFC 7519 also allows an array of
/// audiences, which would fail to deserialize into this struct — confirm issuers only send
/// a string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionClaims {
    /// Issuer of the assertion.
    pub iss: String,
    /// Subject the assertion is about.
    pub sub: String,
    /// Audience, if present.
    pub aud: Option<String>,
    /// Issued-at time, compared against seconds since the Unix epoch.
    pub iat: Option<u64>,
    /// Expiry time, compared against seconds since the Unix epoch.
    pub exp: Option<u64>,
    /// Assertion ID, used for replay protection.
    pub jti: Option<String>,
    /// Space-delimited list of requested scopes.
    #[serde(default)]
    pub scope: Option<String>,
}
112
/// Result of assertion verification: the caller-facing subset of a validated assertion.
#[derive(Debug, Clone)]
pub struct VerifiedAssertion {
    /// The assertion's `iss` claim.
    pub issuer: String,
    /// The assertion's `sub` claim.
    pub subject: String,
    /// The assertion's `aud` claim, if present.
    pub audience: Option<String>,
    /// Scopes taken from the space-delimited `scope` claim; `None` when no scope remains.
    pub scopes: Option<Vec<String>>,
}
121
122/// Assertion verifier
123pub struct AssertionVerifier<'a> {
124    db: &'a RocksDb,
125}
126
127impl<'a> AssertionVerifier<'a> {
128    pub fn new(db: &'a RocksDb) -> Self {
129        Self { db }
130    }
131
132    /// Verify JWT assertion against policy and return verified data
133    pub fn verify(
134        &self,
135        jwt: &str,
136        policy: &ServiceOAuthPolicy,
137    ) -> Result<VerifiedAssertion, VerificationError> {
138        if policy.allowed_issuers.is_empty() {
139            return Err(VerificationError::PolicyViolation(
140                "no issuers configured".into(),
141            ));
142        }
143
144        // Decode header to inspect alg/kid
145        let header = jsonwebtoken::decode_header(jwt)
146            .map_err(|e| VerificationError::InvalidRequest(format!("invalid jwt header: {e}")))?;
147        let alg = header.alg;
148        if !matches!(
149            alg,
150            jsonwebtoken::Algorithm::RS256 | jsonwebtoken::Algorithm::ES256
151        ) {
152            return Err(VerificationError::InvalidRequest("unsupported alg".into()));
153        }
154
155        // Build validation rules
156        let mut validation = jsonwebtoken::Validation::new(alg);
157        validation.set_required_spec_claims(&["iss", "sub", "exp", "iat"]);
158        if !policy.required_audiences.is_empty() {
159            validation.set_audience(&policy.required_audiences);
160        } else {
161            validation.validate_aud = false;
162        }
163
164        if policy.public_keys_pem.is_empty() {
165            return Err(VerificationError::NotConfigured);
166        }
167
168        // Try each provided public key until signature validates
169        let mut last_err: Option<jsonwebtoken::errors::Error> = None;
170        let mut decoded: Option<jsonwebtoken::TokenData<AssertionClaims>> = None;
171        for pem in &policy.public_keys_pem {
172            let pem_str = pem.trim();
173            let key_res = pem::parse(pem_str)
174                .ok()
175                .and_then(|p| match (alg, p.tag.as_str()) {
176                    (jsonwebtoken::Algorithm::RS256, "PUBLIC KEY") => {
177                        Some(jsonwebtoken::DecodingKey::from_rsa_der(&p.contents))
178                    }
179                    (jsonwebtoken::Algorithm::RS256, "RSA PUBLIC KEY") => {
180                        Some(jsonwebtoken::DecodingKey::from_rsa_der(&p.contents))
181                    }
182                    (jsonwebtoken::Algorithm::ES256, "PUBLIC KEY") => {
183                        Some(jsonwebtoken::DecodingKey::from_ec_der(&p.contents))
184                    }
185                    _ => None,
186                });
187
188            if let Some(k) = key_res {
189                match jsonwebtoken::decode::<AssertionClaims>(jwt, &k, &validation) {
190                    Ok(t) => {
191                        decoded = Some(t);
192                        break;
193                    }
194                    Err(e) => {
195                        last_err = Some(e);
196                    }
197                }
198            }
199        }
200
201        let token = decoded.ok_or_else(|| {
202            VerificationError::InvalidGrant(format!(
203                "signature verification failed: {}",
204                last_err
205                    .map(|e| e.to_string())
206                    .unwrap_or_else(|| "unknown".into())
207            ))
208        })?;
209
210        let claims = token.claims;
211
212        // iss
213        if !policy.allowed_issuers.iter().any(|i| i == &claims.iss) {
214            return Err(VerificationError::InvalidGrant("issuer not allowed".into()));
215        }
216
217        // Validate iat/exp and assertion TTL
218        let now = std::time::SystemTime::now()
219            .duration_since(std::time::UNIX_EPOCH)
220            .unwrap_or_default()
221            .as_secs();
222        let skew: u64 = 60;
223        let iat = claims.iat.unwrap_or(0);
224        let exp = claims.exp.unwrap_or(0);
225        if iat > now + skew {
226            return Err(VerificationError::InvalidGrant(
227                "token used before issued (iat)".into(),
228            ));
229        }
230        if exp + skew <= now {
231            return Err(VerificationError::InvalidGrant("token expired".into()));
232        }
233        if exp > iat && (exp - iat) > policy.max_assertion_ttl_secs {
234            return Err(VerificationError::InvalidGrant(
235                "assertion TTL exceeds policy".into(),
236            ));
237        }
238
239        // Replay protection
240        let jti = claims
241            .jti
242            .clone()
243            .ok_or_else(|| VerificationError::InvalidRequest("missing jti".into()))?;
244        if self.is_replayed(&jti, exp)? {
245            return Err(VerificationError::InvalidGrant("assertion replayed".into()));
246        }
247        self.remember_jti(&jti, exp)?;
248
249        // Scopes
250        let scopes = claims
251            .scope
252            .as_ref()
253            .map(|s| {
254                s.split(' ')
255                    .filter(|t| !t.is_empty())
256                    .map(|t| t.to_string())
257                    .collect::<Vec<_>>()
258            })
259            .filter(|v| !v.is_empty());
260
261        Ok(VerifiedAssertion {
262            issuer: claims.iss,
263            subject: claims.sub,
264            audience: claims.aud.clone(),
265            scopes,
266        })
267    }
268
269    fn is_replayed(&self, jti: &str, now_or_exp: u64) -> Result<bool, VerificationError> {
270        let cf = self
271            .db
272            .cf_handle(cf::OAUTH_JTI_CF)
273            .ok_or_else(|| VerificationError::InvalidRequest("replay store unavailable".into()))?;
274        if let Some(bytes) = self
275            .db
276            .get_cf(&cf, jti.as_bytes())
277            .map_err(|_| VerificationError::InvalidRequest("replay read error".into()))?
278        {
279            if let Some(stored) = std::str::from_utf8(&bytes)
280                .ok()
281                .and_then(|s| s.parse::<u64>().ok())
282            {
283                return Ok(stored >= now_or_exp);
284            }
285        }
286        Ok(false)
287    }
288
289    fn remember_jti(&self, jti: &str, exp: u64) -> Result<(), VerificationError> {
290        let cf = self
291            .db
292            .cf_handle(cf::OAUTH_JTI_CF)
293            .ok_or_else(|| VerificationError::InvalidRequest("replay store unavailable".into()))?;
294        let val = exp.to_string();
295        self.db
296            .put_cf(&cf, jti.as_bytes(), val.as_bytes())
297            .map_err(|_| VerificationError::InvalidRequest("replay write error".into()))?;
298        Ok(())
299    }
300}
301
/// Errors returned by assertion verification.
///
/// The `Display` output is prefixed with a short code (`invalid_request`, `invalid_grant`,
/// `policy_violation`, `not_configured`). The first two correspond to OAuth 2.0 token
/// endpoint error codes.
#[derive(Debug, thiserror::Error)]
pub enum VerificationError {
    /// The request or assertion is malformed. Examples: undecodable header, unsupported
    /// `alg`, missing `jti`. Also used when the replay store cannot be accessed.
    #[error("invalid_request: {0}")]
    InvalidRequest(String),
    /// The assertion is well-formed but not acceptable. Examples: signature or claim
    /// validation failure, disallowed issuer, time-window or TTL violation, replay.
    #[error("invalid_grant: {0}")]
    InvalidGrant(String),
    /// The service policy itself cannot be used, e.g. no issuers are configured.
    #[error("policy_violation: {0}")]
    PolicyViolation(String),
    /// No public keys are configured for the service.
    #[error("not_configured")]
    NotConfigured,
}