1use super::super::{GetError, OidcConfig, Store, User};
5
6use drawbridge_type::{UserContext, UserRecord};
7
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10
11use anyhow::{anyhow, bail, Context};
12use axum::extract::rejection::{TypedHeaderRejection, TypedHeaderRejectionReason};
13use axum::extract::{Extension, FromRequest, RequestParts};
14use axum::headers::authorization::Bearer;
15use axum::headers::Authorization;
16use axum::http::StatusCode;
17use axum::response::{IntoResponse, Response};
18use axum::{async_trait, TypedHeader};
19use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
20use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
21use openidconnect::core::CoreProviderMetadata;
22use openidconnect::ureq::http_client;
23use openidconnect::IssuerUrl;
24use serde::{Deserialize, Deserializer};
25use tracing::{error, info, trace, warn};
26
27pub struct Verifier {
28 keyset: HashMap<String, DecodingKey>,
29 validator: Validation,
30}
31
32impl std::fmt::Debug for Verifier {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("Verifier")
35 .field("validator", &self.validator)
36 .finish()
37 }
38}
39
40#[derive(Clone, Debug, Deserialize)]
41struct VerifiedInfo {
42 #[serde(rename = "sub")]
43 subject: String,
44 #[serde(rename = "scope", deserialize_with = "deserialize_scopes")]
45 scopes: HashSet<String>,
46}
47
48#[allow(single_use_lifetimes)]
49fn deserialize_scopes<'de, D>(deserializer: D) -> Result<HashSet<String>, D::Error>
50where
51 D: Deserializer<'de>,
52{
53 let s: &str = Deserialize::deserialize(deserializer)?;
54 Ok(HashSet::from_iter(s.split(' ').map(|s| s.to_owned())))
55}
56
57impl Verifier {
58 pub fn new(config: OidcConfig) -> Result<Self, anyhow::Error> {
59 let mut validator = Validation::new(Algorithm::RS256);
60 validator.set_audience(&[config.audience]);
61 validator.set_issuer(&[config.issuer.as_str()]);
62 validator.set_required_spec_claims(&["exp", "iat", "scope", "aud"]);
63 validator.validate_exp = true;
64
65 let oidc_md =
66 CoreProviderMetadata::discover(&IssuerUrl::from_url(config.issuer), http_client)
67 .context("failed to discover provider metadata")?;
68 let jwks = oidc_md.jwks();
69 let jwks = serde_json::to_string(&jwks).context("failed to serialize jwks")?;
70 let keyset: JwkSet = serde_json::from_str(&jwks).context("failed to parse jwks")?;
71 let keyset = keyset
72 .keys
73 .into_iter()
74 .map(|jwk| {
75 let kid = jwk.common.key_id.ok_or_else(|| anyhow!("missing kid"))?;
76 let key = match jwk.algorithm {
77 AlgorithmParameters::RSA(ref rsa) => {
78 DecodingKey::from_rsa_components(&rsa.n, &rsa.e)
79 .context("Error creating DecodingKey")
80 }
81 _ => bail!("Unsupported algorithm encountered: {:?}", jwk.algorithm),
82 }?;
83 Ok((kid, key))
84 })
85 .collect::<Result<HashMap<String, DecodingKey>, anyhow::Error>>()
86 .context("failed to parse jwks")?;
87
88 Ok(Self { keyset, validator })
89 }
90
91 fn verify_token(&self, token: &str) -> Result<VerifiedInfo, anyhow::Error> {
92 let header = decode_header(token).context("Error decoding header")?;
93 let kid = match header.kid {
94 Some(k) => k,
95 None => bail!("Token doesn't have a `kid` header field"),
96 };
97 let key = self
98 .keyset
99 .get(&kid)
100 .ok_or_else(|| anyhow!("No key found for kid: {}", kid))?;
101 let decoded_token =
102 decode::<VerifiedInfo>(token, key, &self.validator).context("Error decoding token")?;
103 Ok(decoded_token.claims)
104 }
105}
106
/// Resource kind a scope applies to.
///
/// The `Display` form is the resource part of a scope string.
#[derive(Clone, Copy, Debug)]
pub enum ScopeContext {
    /// Rendered as `drawbridge_users`.
    User,
    /// Rendered as `drawbridge_repositories`.
    Repository,
    /// Rendered as `drawbridge_tags`.
    Tag,
}

impl std::fmt::Display for ScopeContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::User => "drawbridge_users",
            Self::Repository => "drawbridge_repositories",
            Self::Tag => "drawbridge_tags",
        };
        f.write_str(name)
    }
}
123
/// Access level requested for an operation.
#[derive(Clone, Copy, Debug)]
pub enum ScopeLevel {
    /// Rendered as `read`.
    Read,
    /// Rendered as `write`.
    Write,
}

impl std::fmt::Display for ScopeLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Read => "read",
            Self::Write => "write",
        };
        f.write_str(name)
    }
}

impl ScopeLevel {
    /// Returns the scope level names that satisfy this level: the level itself
    /// followed by `manage`.
    fn sufficient_levels(&self) -> &[&str] {
        const FOR_READ: &[&str] = &["read", "manage"];
        const FOR_WRITE: &[&str] = &["write", "manage"];
        match *self {
            Self::Read => FOR_READ,
            Self::Write => FOR_WRITE,
        }
    }
}
147
148#[repr(transparent)]
149#[derive(Clone, Debug)]
150pub struct Claims(VerifiedInfo);
151
152impl Claims {
153 pub fn subject(&self) -> &str {
154 &self.0.subject
155 }
156
157 fn check_scope(
158 &self,
159 context: ScopeContext,
160 level: ScopeLevel,
161 ) -> Result<(), (StatusCode, String)> {
162 for level in level.sufficient_levels() {
163 let scope = format!("{level}:{context}");
164 if self.0.scopes.contains(&scope) {
165 return Ok(());
166 }
167 }
168 Err((
169 StatusCode::UNAUTHORIZED,
170 format!("Token is missing a scope for level {level}, context {context}"),
171 ))
172 }
173
174 #[allow(clippy::result_large_err)]
176 pub fn assert_scope(
177 &self,
178 context: ScopeContext,
179 level: ScopeLevel,
180 ) -> Result<(), impl IntoResponse> {
181 self.check_scope(context, level)
182 .map_err(|e| e.into_response())
183 }
184
185 pub async fn assert_user<'a>(
188 &self,
189 store: &'a Store,
190 cx: &UserContext,
191 scope_context: ScopeContext,
192 scope_level: ScopeLevel,
193 ) -> Result<User<'a>, impl IntoResponse> {
194 let subj = self.subject();
195 let oidc_record = UserRecord {
196 subject: subj.to_string(),
197 };
198
199 let user = store.user(cx);
200 let owner_record: UserRecord = user.get_content_json().await.map_err(|e|{
201 match e {
202 GetError::NotFound => (StatusCode::UNAUTHORIZED, format!("User `{cx}` not found")).into_response(),
203 _ => {
204 warn!(target: "app::auth::oidc", ?oidc_record, error = ?e, "failed to get user by OpenID Connect subject");
205e.into_response()
206 },
207 }})?;
208
209 if oidc_record != owner_record {
210 warn!(target: "app::auth::oidc", ?oidc_record, user = ?cx, ?owner_record, "User access not authorized");
211 return Err((
212 StatusCode::UNAUTHORIZED,
213 format!("You are logged in as `{subj}`, and not authorized for user `{cx}`"),
214 )
215 .into_response());
216 }
217
218 self.check_scope(scope_context, scope_level)
219 .map_err(|e| e.into_response())?;
220
221 Ok(user)
222 }
223}
224
225#[async_trait]
226impl<B: Send> FromRequest<B> for Claims {
227 type Rejection = Response;
228
229 async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
230 let TypedHeader(Authorization::<Bearer>(token)) =
231 req.extract()
232 .await
233 .map_err(|e: TypedHeaderRejection| match e.reason() {
234 TypedHeaderRejectionReason::Missing => {
235 (StatusCode::UNAUTHORIZED, "Bearer token header missing").into_response()
236 }
237 _ => e.into_response(),
238 })?;
239 warn!(target: "app::auth::oidc", ?token, "got token");
240
241 let Extension(verifier) = req
242 .extract::<Extension<Arc<Verifier>>>()
243 .await
244 .map_err(|e| {
245 error!(target: "app::auth::oidc", "OpenID Connect verifier extension missing");
246 e.into_response()
247 })?;
248
249 trace!(target: "app:auth::oidc", "verifying token");
250
251 let claims = verifier
252 .verify_token(token.token())
253 .map_err(|e| {
254 error!(target: "app::auth::oidc", error = ?e, "failed to verify token");
255 (StatusCode::UNAUTHORIZED, "Invalid token provided").into_response()
256 })
257 .map(Self);
258 info!(target: "app::auth::oidc", ?claims, "verified token");
259 claims
260 }
261}