Skip to main content

assay_auth/
oidc.rs

1//! OIDC client — discovery, PKCE, callback, userinfo.
2//!
3//! Plan 12c task 5.1 reference. We wrap the [`openidconnect`] 4 typed
4//! `CoreClient` per upstream so callers don't have to thread its
5//! type-state generics through every handler. Each provider is
6//! discovered once at registration time (`<issuer>/.well-known/openid-configuration`)
7//! and the resulting client is cached behind a slug key.
8//!
9//! The phase-5 surface is intentionally library-only:
10//!
11//! - [`OidcRegistry`] — slug-keyed registry of discovered providers
12//! - [`OidcClient`] — wraps one upstream's discovered metadata + RP creds
13//! - [`UpstreamProvider`] — POD record (slug + issuer + client id/secret +
14//!   scopes); matches the `auth.upstream_providers` row shape that admin
15//!   CRUD will land in a later plan
16//! - [`UpstreamUserInfo`] — verified result of one login round-trip
17//!
18//! Engine boot constructs an empty registry; populated providers come
19//! from a future admin API or seed config (out of phase 5 scope).
20
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use openidconnect::core::{
25    CoreAuthenticationFlow, CoreClient, CoreProviderMetadata, CoreUserInfoClaims,
26};
27use openidconnect::reqwest as oidc_reqwest;
28use openidconnect::{
29    AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointMaybeSet, EndpointNotSet,
30    EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
31    RedirectUrl, Scope, SubjectIdentifier, TokenResponse,
32};
33use parking_lot::RwLock;
34use url::Url;
35
36use crate::error::{Error, Result};
37
/// Plain-data record describing one upstream identity provider.
///
/// Mirrors the planned `auth.upstream_providers` table shape (see plan 12d)
/// so the admin API can `INSERT … RETURNING *` and feed the row directly
/// into [`OidcRegistry::add`] without a translation step.
///
/// `Debug` is implemented by hand instead of derived so that
/// `client_secret` never leaks into logs, panic messages, or `assert_eq!`
/// failure output.
#[derive(Clone, PartialEq, Eq)]
pub struct UpstreamProvider {
    /// Stable slug used in routes (`/login/{slug}`) and as the
    /// `auth.user_upstream.provider` column value. Lower-snake-case
    /// matches the rest of the codebase's naming.
    pub slug: String,
    /// Issuer URL — the value the discovery doc lives under
    /// (`<issuer>/.well-known/openid-configuration`).
    pub issuer: String,
    /// RP client id registered with the upstream.
    pub client_id: String,
    /// RP client secret registered with the upstream. Stored as
    /// plaintext here because phase 5 has no secret-at-rest envelope yet
    /// — admin CRUD lands with the encryption story. An empty string is
    /// treated by [`OidcRegistry::add`] as "no secret".
    pub client_secret: String,
    /// Scopes requested at authorize time. Common set:
    /// `["openid", "email", "profile"]`. `openid` is added implicitly
    /// by [`openidconnect`]; we forward the rest unchanged.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for UpstreamProvider {
    /// Formats every field verbatim except `client_secret`, which is shown
    /// only as `<redacted>` (or `<empty>` when no secret is configured).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let secret = if self.client_secret.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("UpstreamProvider")
            .field("slug", &self.slug)
            .field("issuer", &self.issuer)
            .field("client_id", &self.client_id)
            // `format_args!` keeps the placeholder unquoted so it cannot be
            // mistaken for a literal secret value.
            .field("client_secret", &format_args!("{}", secret))
            .field("scopes", &self.scopes)
            .finish()
    }
}
62
63/// Verified userinfo returned by [`OidcClient::complete_login`]. Carries
64/// the canonical fields the rest of the auth stack needs to upsert into
65/// `auth.users` + `auth.user_upstream`. `raw_claims` carries the
66/// id_token's full claim set so callers can pluck custom claims (e.g.
67/// `groups`, `roles`) without a second parse.
68#[derive(Clone, Debug)]
69pub struct UpstreamUserInfo {
70    pub provider: String,
71    pub subject: String,
72    pub email: Option<String>,
73    pub email_verified: bool,
74    pub name: Option<String>,
75    pub picture: Option<String>,
76    pub raw_claims: serde_json::Value,
77}
78
79/// A single discovered upstream — wraps the [`openidconnect`] typed
80/// client and the PoD metadata used to construct it.
81///
82/// The CoreClient generic state after `from_provider_metadata +
83/// set_redirect_uri` is `<EndpointSet, EndpointNotSet, EndpointNotSet,
84/// EndpointNotSet, EndpointMaybeSet, EndpointMaybeSet>` — auth URL set
85/// (so `authorize_url` works), token + userinfo MaybeSet (we error at
86/// runtime if the upstream's discovery doc is missing one).
87pub struct OidcClient {
88    inner: CoreClient<
89        EndpointSet,
90        EndpointNotSet,
91        EndpointNotSet,
92        EndpointNotSet,
93        EndpointMaybeSet,
94        EndpointMaybeSet,
95    >,
96    /// Original PoD record for round-trip / introspection
97    /// (e.g. admin "what's configured" pages).
98    provider: UpstreamProvider,
99    /// Owned redirect URL — `set_redirect_uri` consumed it on the
100    /// builder, but operators sometimes want it back without re-parsing.
101    redirect_uri: RedirectUrl,
102}
103
104impl OidcClient {
105    /// Borrow the original PoD record.
106    pub fn provider(&self) -> &UpstreamProvider {
107        &self.provider
108    }
109
110    /// Borrow the configured redirect URI.
111    pub fn redirect_uri(&self) -> &RedirectUrl {
112        &self.redirect_uri
113    }
114
115    /// Step 1 of the authorization-code-+-PKCE flow. Generates a PKCE
116    /// pair, asks the [`openidconnect`] client for the redirect URL,
117    /// returns the URL alongside the verifier + nonce for round-trip
118    /// (callers persist them, typically in the session).
119    ///
120    /// `state` lets callers pin a known CSRF value (e.g. the session id)
121    /// rather than the library-generated random one — useful when the
122    /// callback handler uses `state` to look the in-progress login up.
123    /// Pass [`CsrfToken::new_random`] via `CsrfToken::new(...)` if you
124    /// don't have one already.
125    pub fn start_login(
126        &self,
127        state: CsrfToken,
128    ) -> StartedLogin {
129        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
130        let mut request = self.inner.authorize_url(
131            CoreAuthenticationFlow::AuthorizationCode,
132            move || state,
133            Nonce::new_random,
134        );
135        for scope in &self.provider.scopes {
136            // `openid` scope is added by openidconnect when
137            // `use_openid_scope` is true (default after
138            // `from_provider_metadata`); skip a duplicate so the URL
139            // stays clean.
140            if scope == "openid" {
141                continue;
142            }
143            request = request.add_scope(Scope::new(scope.clone()));
144        }
145        let (url, csrf_token, nonce) = request.set_pkce_challenge(pkce_challenge).url();
146        StartedLogin {
147            url,
148            csrf_token,
149            nonce,
150            pkce_verifier,
151        }
152    }
153
154    /// Step 2 — exchange the upstream's `code` for tokens, validate the
155    /// id_token against the cached JWKS + nonce, and (when the upstream
156    /// publishes a userinfo endpoint) supplement the claims with a
157    /// userinfo call.
158    ///
159    /// `pkce_verifier` and `nonce` must be the values returned from
160    /// [`OidcClient::start_login`] for the same login — callers persist
161    /// them server-side keyed by `state`.
162    pub async fn complete_login(
163        &self,
164        code: String,
165        pkce_verifier: PkceCodeVerifier,
166        nonce: Nonce,
167    ) -> Result<UpstreamUserInfo> {
168        let http = build_oidc_http_client()?;
169        let token_response = self
170            .inner
171            .exchange_code(AuthorizationCode::new(code))
172            .map_err(|e| Error::Oidc(format!("exchange_code config: {e}")))?
173            .set_pkce_verifier(pkce_verifier)
174            .request_async(&http)
175            .await
176            .map_err(|e| Error::Oidc(format!("token exchange: {e}")))?;
177
178        let id_token = token_response
179            .id_token()
180            .ok_or_else(|| Error::Oidc("upstream returned no id_token".to_string()))?;
181        let id_token_verifier = self.inner.id_token_verifier();
182        let claims = id_token
183            .claims(&id_token_verifier, &nonce)
184            .map_err(|e| Error::Oidc(format!("id_token verify: {e}")))?;
185
186        let subject = claims.subject().to_string();
187        let mut email = claims.email().map(|e| e.to_string());
188        let mut email_verified = claims.email_verified().unwrap_or(false);
189        let mut name = claims
190            .name()
191            .and_then(|map| map.get(None))
192            .map(|n| n.to_string());
193        let mut picture = claims
194            .picture()
195            .and_then(|map| map.get(None))
196            .map(|u| u.to_string());
197
198        // Best-effort userinfo fetch. Some upstreams omit email/name from
199        // the id_token and only expose them via /userinfo. If the
200        // upstream doesn't publish a userinfo endpoint or the call fails,
201        // we keep what the id_token gave us — login still works, the
202        // missing fields just show up as None.
203        let mut raw_claims = serde_json::to_value(claims)
204            .unwrap_or_else(|_| serde_json::json!({"sub": subject}));
205        if let Ok(req) = self
206            .inner
207            .user_info(token_response.access_token().clone(), Some(SubjectIdentifier::new(subject.clone())))
208            && let Ok(userinfo) = req.request_async(&http).await
209        {
210            let user_claims: CoreUserInfoClaims = userinfo;
211            if email.is_none() {
212                email = user_claims.email().map(|e| e.to_string());
213            }
214            if !email_verified {
215                email_verified = user_claims.email_verified().unwrap_or(email_verified);
216            }
217            if name.is_none() {
218                name = user_claims
219                    .name()
220                    .and_then(|map| map.get(None))
221                    .map(|n| n.to_string());
222            }
223            if picture.is_none() {
224                picture = user_claims
225                    .picture()
226                    .and_then(|map| map.get(None))
227                    .map(|u| u.to_string());
228            }
229            // Merge userinfo into raw_claims so downstream code that
230            // wants e.g. `groups` from userinfo can pluck it out.
231            if let Ok(userinfo_value) = serde_json::to_value(&user_claims) {
232                merge_json(&mut raw_claims, userinfo_value);
233            }
234        }
235
236        Ok(UpstreamUserInfo {
237            provider: self.provider.slug.clone(),
238            subject,
239            email,
240            email_verified,
241            name,
242            picture,
243            raw_claims,
244        })
245    }
246}
247
248/// Result of [`OidcClient::start_login`]. The HTTP layer redirects the
249/// user to `url` and persists the rest server-side (typically in the
250/// session payload, keyed by `csrf_token` so the callback can look the
251/// in-progress login up via the `state` query param).
252pub struct StartedLogin {
253    pub url: Url,
254    pub csrf_token: CsrfToken,
255    pub nonce: Nonce,
256    pub pkce_verifier: PkceCodeVerifier,
257}
258
259/// Slug-keyed registry of discovered upstreams.
260///
261/// Cheap to clone — interior is `Arc<RwLock<…>>` so HTTP handlers can
262/// share a single registry while admin endpoints add / remove providers
263/// at runtime.
264#[derive(Clone, Default)]
265pub struct OidcRegistry {
266    inner: Arc<RwLock<HashMap<String, Arc<OidcClient>>>>,
267}
268
269impl OidcRegistry {
270    /// Empty registry — engine boot creates one of these and feeds it to
271    /// [`crate::ctx::AuthCtx::with_oidc`]. Providers are added later via
272    /// admin CRUD or seed config.
273    pub fn new() -> Self {
274        Self::default()
275    }
276
277    /// Discover and cache one upstream. Performs a network round-trip to
278    /// `<issuer>/.well-known/openid-configuration` plus the JWKS fetch,
279    /// so call this from boot or from an admin endpoint, not from a
280    /// per-request handler.
281    ///
282    /// `redirect_uri` is the absolute URL the upstream redirects back
283    /// to after login (typically `<public_url>/login/<slug>/callback`).
284    pub async fn add(&self, provider: UpstreamProvider, redirect_uri: Url) -> Result<()> {
285        let issuer = IssuerUrl::new(provider.issuer.clone())
286            .map_err(|e| Error::Oidc(format!("issuer url {}: {e}", provider.issuer)))?;
287        let http = build_oidc_http_client()?;
288        let metadata = CoreProviderMetadata::discover_async(issuer, &http)
289            .await
290            .map_err(|e| Error::Oidc(format!("discover {}: {e}", provider.slug)))?;
291        let redirect = RedirectUrl::new(redirect_uri.to_string())
292            .map_err(|e| Error::Oidc(format!("redirect_uri {redirect_uri}: {e}")))?;
293        let client_secret = if provider.client_secret.is_empty() {
294            None
295        } else {
296            Some(ClientSecret::new(provider.client_secret.clone()))
297        };
298        let inner = CoreClient::from_provider_metadata(
299            metadata,
300            ClientId::new(provider.client_id.clone()),
301            client_secret,
302        )
303        .set_redirect_uri(redirect.clone());
304        let client = OidcClient {
305            inner,
306            provider: provider.clone(),
307            redirect_uri: redirect,
308        };
309        self.inner
310            .write()
311            .insert(provider.slug.clone(), Arc::new(client));
312        Ok(())
313    }
314
315    /// Look up a discovered upstream by slug. Returns the same `Arc`
316    /// stored at registration time so callers can hold the client for
317    /// the duration of a long-running flow.
318    pub fn client(&self, slug: &str) -> Option<Arc<OidcClient>> {
319        self.inner.read().get(slug).cloned()
320    }
321
322    /// List the slugs of every registered provider (for admin /
323    /// debugging UIs).
324    pub fn slugs(&self) -> Vec<String> {
325        self.inner.read().keys().cloned().collect()
326    }
327
328    /// Remove a provider from the registry. Returns `true` if a row was
329    /// dropped. Pending in-flight logins keep working because they hold
330    /// an `Arc<OidcClient>` from before the removal.
331    pub fn remove(&self, slug: &str) -> bool {
332        self.inner.write().remove(slug).is_some()
333    }
334
335    /// Number of registered providers — handy for tests + metrics.
336    pub fn len(&self) -> usize {
337        self.inner.read().len()
338    }
339
340    /// Whether the registry is empty.
341    pub fn is_empty(&self) -> bool {
342        self.inner.read().is_empty()
343    }
344}
345
346/// Build the reqwest client `openidconnect` uses for discovery, token
347/// exchange, JWKS fetches, and userinfo. We disable redirects on the
348/// security advice in the [`openidconnect`] crate docs (SSRF mitigation)
349/// and use rustls — matches the rest of assay's HTTP stack.
350fn build_oidc_http_client() -> Result<oidc_reqwest::Client> {
351    oidc_reqwest::ClientBuilder::new()
352        .redirect(oidc_reqwest::redirect::Policy::none())
353        .build()
354        .map_err(|e| Error::Oidc(format!("build oidc http client: {e}")))
355}
356
357/// Recursive merge of two JSON values — used so userinfo claims top up
358/// the id_token claims without overwriting them. Object fields merge
359/// recursively; everything else is replaced.
360fn merge_json(target: &mut serde_json::Value, src: serde_json::Value) {
361    match (target, src) {
362        (serde_json::Value::Object(a), serde_json::Value::Object(b)) => {
363            for (k, v) in b {
364                merge_json(a.entry(k).or_insert(serde_json::Value::Null), v);
365            }
366        }
367        (slot, src) => {
368            // Don't clobber a non-null target with a null source — the
369            // id_token's value wins when userinfo doesn't add anything.
370            if !src.is_null() {
371                *slot = src;
372            }
373        }
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider record with throwaway credentials.
    fn fixture(slug: &str, issuer: &str, scopes: &[&str]) -> UpstreamProvider {
        UpstreamProvider {
            slug: slug.to_string(),
            issuer: issuer.to_string(),
            client_id: "client".to_string(),
            client_secret: "secret".to_string(),
            scopes: scopes.iter().map(|s| s.to_string()).collect(),
        }
    }

    /// A freshly created registry reports no providers through any accessor.
    #[test]
    fn registry_starts_empty() {
        let registry = OidcRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.client("google").is_none());
        assert!(registry.slugs().is_empty());
    }

    /// Objects merge key by key, and a `null` source never clobbers an
    /// existing value.
    #[test]
    fn merge_json_merges_objects_and_keeps_existing_on_null() {
        let mut base = serde_json::json!({"email": "a@x", "groups": ["a"]});
        let overlay = serde_json::json!({"email": serde_json::Value::Null, "name": "Alice"});
        merge_json(&mut base, overlay);
        assert_eq!(base["email"], "a@x");
        assert_eq!(base["name"], "Alice");
        assert_eq!(base["groups"], serde_json::json!(["a"]));
    }

    /// Cloning a provider record yields an equal value.
    #[test]
    fn upstream_provider_record_is_clonable() {
        let original = fixture("google", "https://accounts.google.com", &["openid", "email"]);
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    /// Discovery against an unreachable URL should fail with `Error::Oidc`,
    /// not panic. We don't network out from unit tests; this just exercises
    /// the error path.
    #[tokio::test]
    async fn discover_unreachable_issuer_returns_oidc_error() {
        let registry = OidcRegistry::new();
        let ghost = fixture("ghost", "http://127.0.0.1:1/oidc", &["openid"]);
        let callback = Url::parse("https://example.com/login/ghost/callback").unwrap();
        let outcome = registry.add(ghost, callback).await;
        assert!(matches!(outcome, Err(Error::Oidc(_))));
    }
}
430}