1use crate::db::{RocksDb, cf};
2use crate::types::ServiceId;
3use serde::{Deserialize, Serialize};
4
5pub mod token;
6
7pub fn rate_limit_check(
8 db: &RocksDb,
9 headers: &axum::http::HeaderMap,
10 service_id: ServiceId,
11 window_secs: u64,
12 max_requests: u64,
13) -> Result<(), String> {
14 use std::time::{SystemTime, UNIX_EPOCH};
15 let now = SystemTime::now()
16 .duration_since(UNIX_EPOCH)
17 .unwrap_or_default()
18 .as_secs();
19 let window = now / window_secs;
20 let ip = headers
21 .get(axum::http::header::FORWARDED)
22 .and_then(|h| h.to_str().ok())
23 .unwrap_or("unknown");
24 let key = format!("rl:{service_id}:{ip}");
25
26 let cf = db
27 .cf_handle(cf::OAUTH_RL_CF)
28 .ok_or_else(|| "rl store unavailable".to_string())?;
29
30 let bucket_key = format!("{key}:{window}");
31
32 let current = db
33 .get_cf(&cf, bucket_key.as_bytes())
34 .map_err(|_| "rl read error".to_string())?
35 .and_then(|b| {
36 String::from_utf8(b)
37 .ok()
38 .and_then(|s| s.parse::<u64>().ok())
39 })
40 .unwrap_or(0);
41
42 if current >= max_requests {
43 return Err("rate limit exceeded".into());
44 }
45
46 let new_val = (current + 1).to_string();
47 db.put_cf(&cf, bucket_key.as_bytes(), new_val.as_bytes())
48 .map_err(|_| "rl write error".to_string())?;
49 Ok(())
50}
51
/// Per-service policy governing which JWT bearer assertions are accepted.
///
/// Persisted as JSON in the `SERVICES_OAUTH_POLICY_CF` column family, keyed by
/// the big-endian bytes of the service id (see [`ServiceOAuthPolicy::load`] and
/// [`ServiceOAuthPolicy::save`]).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ServiceOAuthPolicy {
    /// Issuers (`iss`) accepted by [`AssertionVerifier::verify`]; an empty list
    /// makes verification fail with `VerificationError::PolicyViolation`.
    pub allowed_issuers: Vec<String>,
    /// Audiences handed to JWT audience validation; when empty, the `aud`
    /// claim is not validated at all.
    pub required_audiences: Vec<String>,
    /// PEM-encoded public keys used to check assertion signatures; an empty
    /// list makes verification fail with `VerificationError::NotConfigured`.
    pub public_keys_pem: Vec<String>,
    /// NOTE(review): not read by the code in this file; presumably restricts
    /// grantable scopes elsewhere (e.g. the `token` module) — confirm.
    pub allowed_scopes: Option<Vec<String>>,
    /// NOTE(review): not read by the code in this file; presumably enforced at
    /// the token endpoint — confirm.
    pub require_dpop: bool,
    /// NOTE(review): not read by the code in this file; presumably caps the
    /// lifetime of issued access tokens, in seconds — confirm.
    pub max_access_token_ttl_secs: u64,
    /// Maximum allowed `exp - iat` of an assertion, in seconds. With the derived
    /// `Default` value of 0, any assertion whose `exp` is after its `iat` is rejected.
    pub max_assertion_ttl_secs: u64,
}
71
72impl ServiceOAuthPolicy {
73 pub fn load(service_id: ServiceId, db: &RocksDb) -> Result<Option<Self>, crate::Error> {
74 let cf =
75 db.cf_handle(cf::SERVICES_OAUTH_POLICY_CF)
76 .ok_or(crate::Error::UnknownColumnFamily(
77 cf::SERVICES_OAUTH_POLICY_CF,
78 ))?;
79 if let Some(bytes) = db.get_cf(&cf, service_id.to_be_bytes())? {
80 let policy: Self =
81 serde_json::from_slice(&bytes).map_err(|_| crate::Error::UnknownKeyType)?;
82 Ok(Some(policy))
83 } else {
84 Ok(None)
85 }
86 }
87
88 pub fn save(&self, service_id: ServiceId, db: &RocksDb) -> Result<(), crate::Error> {
89 let cf =
90 db.cf_handle(cf::SERVICES_OAUTH_POLICY_CF)
91 .ok_or(crate::Error::UnknownColumnFamily(
92 cf::SERVICES_OAUTH_POLICY_CF,
93 ))?;
94 let bytes = serde_json::to_vec(self).map_err(|_| crate::Error::UnknownKeyType)?;
95 db.put_cf(&cf, service_id.to_be_bytes(), bytes)?;
96 Ok(())
97 }
98}
99
/// Claims deserialized from the payload of a JWT bearer assertion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Subject (`sub`).
    pub sub: String,
    /// Audience (`aud`).
    ///
    /// NOTE(review): RFC 7519 allows `aud` to be an array of strings; a token
    /// carrying an array will fail to deserialize into this `Option<String>` —
    /// confirm that configured issuers only ever send a single string.
    pub aud: Option<String>,
    /// Issued-at time (`iat`), compared against seconds since the Unix epoch.
    pub iat: Option<u64>,
    /// Expiration time (`exp`), compared against seconds since the Unix epoch.
    pub exp: Option<u64>,
    /// JWT ID (`jti`), used for replay detection.
    pub jti: Option<String>,
    /// Space-separated scope list (`scope`), if present.
    #[serde(default)]
    pub scope: Option<String>,
}
112
/// Outcome of a successful [`AssertionVerifier::verify`] call.
#[derive(Debug, Clone)]
pub struct VerifiedAssertion {
    /// The assertion's `iss` claim.
    pub issuer: String,
    /// The assertion's `sub` claim.
    pub subject: String,
    /// The assertion's `aud` claim, if present.
    pub audience: Option<String>,
    /// Scopes from the `scope` claim split on spaces; `None` when the claim is
    /// absent or contains no non-empty scope.
    pub scopes: Option<Vec<String>>,
}
121
/// Verifies JWT bearer assertions against a [`ServiceOAuthPolicy`], using the
/// database to remember seen `jti` values for replay detection.
pub struct AssertionVerifier<'a> {
    // Store holding the `OAUTH_JTI_CF` column family of seen `jti` values.
    db: &'a RocksDb,
}
126
127impl<'a> AssertionVerifier<'a> {
128 pub fn new(db: &'a RocksDb) -> Self {
129 Self { db }
130 }
131
132 pub fn verify(
134 &self,
135 jwt: &str,
136 policy: &ServiceOAuthPolicy,
137 ) -> Result<VerifiedAssertion, VerificationError> {
138 if policy.allowed_issuers.is_empty() {
139 return Err(VerificationError::PolicyViolation(
140 "no issuers configured".into(),
141 ));
142 }
143
144 let header = jsonwebtoken::decode_header(jwt)
146 .map_err(|e| VerificationError::InvalidRequest(format!("invalid jwt header: {e}")))?;
147 let alg = header.alg;
148 if !matches!(
149 alg,
150 jsonwebtoken::Algorithm::RS256 | jsonwebtoken::Algorithm::ES256
151 ) {
152 return Err(VerificationError::InvalidRequest("unsupported alg".into()));
153 }
154
155 let mut validation = jsonwebtoken::Validation::new(alg);
157 validation.set_required_spec_claims(&["iss", "sub", "exp", "iat"]);
158 if !policy.required_audiences.is_empty() {
159 validation.set_audience(&policy.required_audiences);
160 } else {
161 validation.validate_aud = false;
162 }
163
164 if policy.public_keys_pem.is_empty() {
165 return Err(VerificationError::NotConfigured);
166 }
167
168 let mut last_err: Option<jsonwebtoken::errors::Error> = None;
170 let mut decoded: Option<jsonwebtoken::TokenData<AssertionClaims>> = None;
171 for pem in &policy.public_keys_pem {
172 let pem_str = pem.trim();
173 let key_res = pem::parse(pem_str)
174 .ok()
175 .and_then(|p| match (alg, p.tag.as_str()) {
176 (jsonwebtoken::Algorithm::RS256, "PUBLIC KEY") => {
177 Some(jsonwebtoken::DecodingKey::from_rsa_der(&p.contents))
178 }
179 (jsonwebtoken::Algorithm::RS256, "RSA PUBLIC KEY") => {
180 Some(jsonwebtoken::DecodingKey::from_rsa_der(&p.contents))
181 }
182 (jsonwebtoken::Algorithm::ES256, "PUBLIC KEY") => {
183 Some(jsonwebtoken::DecodingKey::from_ec_der(&p.contents))
184 }
185 _ => None,
186 });
187
188 if let Some(k) = key_res {
189 match jsonwebtoken::decode::<AssertionClaims>(jwt, &k, &validation) {
190 Ok(t) => {
191 decoded = Some(t);
192 break;
193 }
194 Err(e) => {
195 last_err = Some(e);
196 }
197 }
198 }
199 }
200
201 let token = decoded.ok_or_else(|| {
202 VerificationError::InvalidGrant(format!(
203 "signature verification failed: {}",
204 last_err
205 .map(|e| e.to_string())
206 .unwrap_or_else(|| "unknown".into())
207 ))
208 })?;
209
210 let claims = token.claims;
211
212 if !policy.allowed_issuers.iter().any(|i| i == &claims.iss) {
214 return Err(VerificationError::InvalidGrant("issuer not allowed".into()));
215 }
216
217 let now = std::time::SystemTime::now()
219 .duration_since(std::time::UNIX_EPOCH)
220 .unwrap_or_default()
221 .as_secs();
222 let skew: u64 = 60;
223 let iat = claims.iat.unwrap_or(0);
224 let exp = claims.exp.unwrap_or(0);
225 if iat > now + skew {
226 return Err(VerificationError::InvalidGrant(
227 "token used before issued (iat)".into(),
228 ));
229 }
230 if exp + skew <= now {
231 return Err(VerificationError::InvalidGrant("token expired".into()));
232 }
233 if exp > iat && (exp - iat) > policy.max_assertion_ttl_secs {
234 return Err(VerificationError::InvalidGrant(
235 "assertion TTL exceeds policy".into(),
236 ));
237 }
238
239 let jti = claims
241 .jti
242 .clone()
243 .ok_or_else(|| VerificationError::InvalidRequest("missing jti".into()))?;
244 if self.is_replayed(&jti, exp)? {
245 return Err(VerificationError::InvalidGrant("assertion replayed".into()));
246 }
247 self.remember_jti(&jti, exp)?;
248
249 let scopes = claims
251 .scope
252 .as_ref()
253 .map(|s| {
254 s.split(' ')
255 .filter(|t| !t.is_empty())
256 .map(|t| t.to_string())
257 .collect::<Vec<_>>()
258 })
259 .filter(|v| !v.is_empty());
260
261 Ok(VerifiedAssertion {
262 issuer: claims.iss,
263 subject: claims.sub,
264 audience: claims.aud.clone(),
265 scopes,
266 })
267 }
268
269 fn is_replayed(&self, jti: &str, now_or_exp: u64) -> Result<bool, VerificationError> {
270 let cf = self
271 .db
272 .cf_handle(cf::OAUTH_JTI_CF)
273 .ok_or_else(|| VerificationError::InvalidRequest("replay store unavailable".into()))?;
274 if let Some(bytes) = self
275 .db
276 .get_cf(&cf, jti.as_bytes())
277 .map_err(|_| VerificationError::InvalidRequest("replay read error".into()))?
278 {
279 if let Some(stored) = std::str::from_utf8(&bytes)
280 .ok()
281 .and_then(|s| s.parse::<u64>().ok())
282 {
283 return Ok(stored >= now_or_exp);
284 }
285 }
286 Ok(false)
287 }
288
289 fn remember_jti(&self, jti: &str, exp: u64) -> Result<(), VerificationError> {
290 let cf = self
291 .db
292 .cf_handle(cf::OAUTH_JTI_CF)
293 .ok_or_else(|| VerificationError::InvalidRequest("replay store unavailable".into()))?;
294 let val = exp.to_string();
295 self.db
296 .put_cf(&cf, jti.as_bytes(), val.as_bytes())
297 .map_err(|_| VerificationError::InvalidRequest("replay write error".into()))?;
298 Ok(())
299 }
300}
301
/// Errors produced while verifying a JWT bearer assertion.
///
/// Each variant's display string starts with a snake_case error code
/// (`invalid_request`, `invalid_grant`, `policy_violation`, `not_configured`).
#[derive(Debug, thiserror::Error)]
pub enum VerificationError {
    /// The request could not be processed. Examples: an unparseable JWT header,
    /// an unsupported `alg`, a missing `jti`. Replay-store failures also use
    /// this variant.
    #[error("invalid_request: {0}")]
    InvalidRequest(String),
    /// The assertion was rejected (signature, issuer, time or replay checks).
    #[error("invalid_grant: {0}")]
    InvalidGrant(String),
    /// The service policy cannot be used as configured (e.g. no allowed issuers).
    #[error("policy_violation: {0}")]
    PolicyViolation(String),
    /// The service policy has no public keys to verify against.
    #[error("not_configured")]
    NotConfigured,
}