drawbridge_server/auth/
oidc.rs

1// SPDX-FileCopyrightText: 2022 Profian Inc. <opensource@profian.com>
2// SPDX-License-Identifier: Apache-2.0
3
4use super::super::{GetError, OidcConfig, Store, User};
5
6use drawbridge_type::{UserContext, UserRecord};
7
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10
11use anyhow::{anyhow, bail, Context};
12use axum::extract::rejection::{TypedHeaderRejection, TypedHeaderRejectionReason};
13use axum::extract::{Extension, FromRequest, RequestParts};
14use axum::headers::authorization::Bearer;
15use axum::headers::Authorization;
16use axum::http::StatusCode;
17use axum::response::{IntoResponse, Response};
18use axum::{async_trait, TypedHeader};
19use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
20use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
21use openidconnect::core::CoreProviderMetadata;
22use openidconnect::ureq::http_client;
23use openidconnect::IssuerUrl;
24use serde::{Deserialize, Deserializer};
25use tracing::{error, info, trace, warn};
26
27pub struct Verifier {
28    keyset: HashMap<String, DecodingKey>,
29    validator: Validation,
30}
31
32impl std::fmt::Debug for Verifier {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("Verifier")
35            .field("validator", &self.validator)
36            .finish()
37    }
38}
39
40#[derive(Clone, Debug, Deserialize)]
41struct VerifiedInfo {
42    #[serde(rename = "sub")]
43    subject: String,
44    #[serde(rename = "scope", deserialize_with = "deserialize_scopes")]
45    scopes: HashSet<String>,
46}
47
48#[allow(single_use_lifetimes)]
49fn deserialize_scopes<'de, D>(deserializer: D) -> Result<HashSet<String>, D::Error>
50where
51    D: Deserializer<'de>,
52{
53    let s: &str = Deserialize::deserialize(deserializer)?;
54    Ok(HashSet::from_iter(s.split(' ').map(|s| s.to_owned())))
55}
56
57impl Verifier {
58    pub fn new(config: OidcConfig) -> Result<Self, anyhow::Error> {
59        let mut validator = Validation::new(Algorithm::RS256);
60        validator.set_audience(&[config.audience]);
61        validator.set_issuer(&[config.issuer.as_str()]);
62        validator.set_required_spec_claims(&["exp", "iat", "scope", "aud"]);
63        validator.validate_exp = true;
64
65        let oidc_md =
66            CoreProviderMetadata::discover(&IssuerUrl::from_url(config.issuer), http_client)
67                .context("failed to discover provider metadata")?;
68        let jwks = oidc_md.jwks();
69        let jwks = serde_json::to_string(&jwks).context("failed to serialize jwks")?;
70        let keyset: JwkSet = serde_json::from_str(&jwks).context("failed to parse jwks")?;
71        let keyset = keyset
72            .keys
73            .into_iter()
74            .map(|jwk| {
75                let kid = jwk.common.key_id.ok_or_else(|| anyhow!("missing kid"))?;
76                let key = match jwk.algorithm {
77                    AlgorithmParameters::RSA(ref rsa) => {
78                        DecodingKey::from_rsa_components(&rsa.n, &rsa.e)
79                            .context("Error creating DecodingKey")
80                    }
81                    _ => bail!("Unsupported algorithm encountered: {:?}", jwk.algorithm),
82                }?;
83                Ok((kid, key))
84            })
85            .collect::<Result<HashMap<String, DecodingKey>, anyhow::Error>>()
86            .context("failed to parse jwks")?;
87
88        Ok(Self { keyset, validator })
89    }
90
91    fn verify_token(&self, token: &str) -> Result<VerifiedInfo, anyhow::Error> {
92        let header = decode_header(token).context("Error decoding header")?;
93        let kid = match header.kid {
94            Some(k) => k,
95            None => bail!("Token doesn't have a `kid` header field"),
96        };
97        let key = self
98            .keyset
99            .get(&kid)
100            .ok_or_else(|| anyhow!("No key found for kid: {}", kid))?;
101        let decoded_token =
102            decode::<VerifiedInfo>(token, key, &self.validator).context("Error decoding token")?;
103        Ok(decoded_token.claims)
104    }
105}
106
/// Resource category that a token scope grants access to.
#[derive(Debug, Clone, Copy)]
pub enum ScopeContext {
    User,
    Repository,
    Tag,
}

// Renders the resource part that follows the `<level>:` prefix in a scope string.
impl std::fmt::Display for ScopeContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let resource = match *self {
            Self::User => "drawbridge_users",
            Self::Repository => "drawbridge_repositories",
            Self::Tag => "drawbridge_tags",
        };
        f.write_str(resource)
    }
}
123
/// Access level an operation requires.
#[derive(Debug, Clone, Copy)]
pub enum ScopeLevel {
    Read,
    Write,
}

// Renders the level as the prefix used in scope strings (`read` / `write`).
impl std::fmt::Display for ScopeLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Read => "read",
            Self::Write => "write",
        })
    }
}

impl ScopeLevel {
    /// Returns every scope-level prefix that satisfies a request at this level.
    ///
    /// `manage` satisfies both reads and writes; `read` and `write` only satisfy themselves.
    fn sufficient_levels(&self) -> &[&str] {
        const READ_OR_MANAGE: &[&str] = &["read", "manage"];
        const WRITE_OR_MANAGE: &[&str] = &["write", "manage"];
        match self {
            Self::Read => READ_OR_MANAGE,
            Self::Write => WRITE_OR_MANAGE,
        }
    }
}
147
148#[repr(transparent)]
149#[derive(Clone, Debug)]
150pub struct Claims(VerifiedInfo);
151
152impl Claims {
153    pub fn subject(&self) -> &str {
154        &self.0.subject
155    }
156
157    fn check_scope(
158        &self,
159        context: ScopeContext,
160        level: ScopeLevel,
161    ) -> Result<(), (StatusCode, String)> {
162        for level in level.sufficient_levels() {
163            let scope = format!("{level}:{context}");
164            if self.0.scopes.contains(&scope) {
165                return Ok(());
166            }
167        }
168        Err((
169            StatusCode::UNAUTHORIZED,
170            format!("Token is missing a scope for level {level}, context {context}"),
171        ))
172    }
173
174    /// Asserts that the token has a scope that satisfies the given context and level.
175    #[allow(clippy::result_large_err)]
176    pub fn assert_scope(
177        &self,
178        context: ScopeContext,
179        level: ScopeLevel,
180    ) -> Result<(), impl IntoResponse> {
181        self.check_scope(context, level)
182            .map_err(|e| e.into_response())
183    }
184
185    /// Assert that the client is the user identified by `cx`, and that the token has a scope that
186    /// satisfies the given context and level.
187    pub async fn assert_user<'a>(
188        &self,
189        store: &'a Store,
190        cx: &UserContext,
191        scope_context: ScopeContext,
192        scope_level: ScopeLevel,
193    ) -> Result<User<'a>, impl IntoResponse> {
194        let subj = self.subject();
195        let oidc_record = UserRecord {
196            subject: subj.to_string(),
197        };
198
199        let user = store.user(cx);
200        let owner_record: UserRecord = user.get_content_json().await.map_err(|e|{
201            match e {
202                GetError::NotFound => (StatusCode::UNAUTHORIZED, format!("User `{cx}` not found")).into_response(),
203                _ => {
204            warn!(target: "app::auth::oidc", ?oidc_record, error = ?e, "failed to get user by OpenID Connect subject");
205e.into_response()
206                },
207            }})?;
208
209        if oidc_record != owner_record {
210            warn!(target: "app::auth::oidc", ?oidc_record, user = ?cx, ?owner_record, "User access not authorized");
211            return Err((
212                StatusCode::UNAUTHORIZED,
213                format!("You are logged in as `{subj}`, and not authorized for user `{cx}`"),
214            )
215                .into_response());
216        }
217
218        self.check_scope(scope_context, scope_level)
219            .map_err(|e| e.into_response())?;
220
221        Ok(user)
222    }
223}
224
225#[async_trait]
226impl<B: Send> FromRequest<B> for Claims {
227    type Rejection = Response;
228
229    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
230        let TypedHeader(Authorization::<Bearer>(token)) =
231            req.extract()
232                .await
233                .map_err(|e: TypedHeaderRejection| match e.reason() {
234                    TypedHeaderRejectionReason::Missing => {
235                        (StatusCode::UNAUTHORIZED, "Bearer token header missing").into_response()
236                    }
237                    _ => e.into_response(),
238                })?;
239        warn!(target: "app::auth::oidc", ?token, "got token");
240
241        let Extension(verifier) = req
242            .extract::<Extension<Arc<Verifier>>>()
243            .await
244            .map_err(|e| {
245                error!(target: "app::auth::oidc", "OpenID Connect verifier extension missing");
246                e.into_response()
247            })?;
248
249        trace!(target: "app:auth::oidc", "verifying token");
250
251        let claims = verifier
252            .verify_token(token.token())
253            .map_err(|e| {
254                error!(target: "app::auth::oidc", error = ?e, "failed to verify token");
255                (StatusCode::UNAUTHORIZED, "Invalid token provided").into_response()
256            })
257            .map(Self);
258        info!(target: "app::auth::oidc", ?claims, "verified token");
259        claims
260    }
261}