// oauth2_broker/flows/auth_code_pkce/session.rs

1// crates.io
2use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
3use rand::{Rng, distr::Alphanumeric};
4use sha2::{Digest, Sha256};
5// self
6use crate::{
7	_prelude::*,
8	auth::{PrincipalId, ScopeSet, TenantId},
9	flows::common,
10	provider::ProviderDescriptor,
11};
12
/// Number of alphanumeric characters in each generated OAuth `state` value.
const STATE_LEN: usize = 32;
/// Number of alphanumeric characters in each generated PKCE code verifier.
///
/// RFC 7636 §4.1 requires a verifier of 43 to 128 characters; the assertion below enforces that
/// bound at compile time so an edit to this constant cannot silently produce non-compliant flows.
const PKCE_VERIFIER_LEN: usize = 64;

// Reject an RFC 7636-violating verifier length during compilation instead of at the provider.
const _: () = assert!(PKCE_VERIFIER_LEN >= 43 && PKCE_VERIFIER_LEN <= 128);
15
/// PKCE code challenge methods an [`AuthorizationSession`] can advertise.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PkceCodeChallengeMethod {
	/// Challenge is the SHA-256 digest of the verifier (RFC 7636 `S256`).
	S256,
}
impl PkceCodeChallengeMethod {
	/// Wire identifier for this method, as sent in the `code_challenge_method` parameter.
	pub fn as_str(self) -> &'static str {
		match self {
			Self::S256 => "S256",
		}
	}
}
30
31/// Authorization Code + PKCE handshake metadata returned by [`Broker::start_authorization`].
32#[derive(Clone)]
33pub struct AuthorizationSession {
34	/// Tenant identifier tied to the session.
35	pub tenant: TenantId,
36	/// Principal identifier tied to the session.
37	pub principal: PrincipalId,
38	/// Requested scope set (prior to any provider overrides during exchange).
39	pub scope: ScopeSet,
40	/// Opaque state value that must round-trip via the redirect handler.
41	pub state: String,
42	/// Redirect URI supplied when constructing the authorize URL.
43	pub redirect_uri: Url,
44	/// Fully-formed HTTPS authorize URL that callers should send end-users to.
45	pub authorize_url: Url,
46	pkce: PkcePair,
47}
48impl AuthorizationSession {
49	pub(super) fn new(
50		tenant: TenantId,
51		principal: PrincipalId,
52		scope: ScopeSet,
53		redirect_uri: Url,
54		authorize_url: Url,
55		state: String,
56		pkce: PkcePair,
57	) -> Self {
58		Self { tenant, principal, scope, state, redirect_uri, authorize_url, pkce }
59	}
60
61	/// PKCE code challenge derived from the secret verifier.
62	pub fn code_challenge(&self) -> &str {
63		&self.pkce.challenge
64	}
65
66	/// PKCE challenge method (currently always `S256`).
67	pub fn code_challenge_method(&self) -> PkceCodeChallengeMethod {
68		self.pkce.method
69	}
70
71	/// Validates the returned `state` parameter after the authorization redirect.
72	pub fn validate_state(&self, returned_state: &str) -> Result<()> {
73		if returned_state == self.state {
74			Ok(())
75		} else {
76			Err(Error::InvalidGrant { reason: "Authorization state mismatch.".into() })
77		}
78	}
79
80	pub(super) fn into_exchange_parts(self) -> (TenantId, PrincipalId, ScopeSet, Url, PkcePair) {
81		let AuthorizationSession { tenant, principal, scope, redirect_uri, pkce, .. } = self;
82
83		(tenant, principal, scope, redirect_uri, pkce)
84	}
85}
86impl Debug for AuthorizationSession {
87	fn fmt(&self, f: &mut Formatter) -> FmtResult {
88		f.debug_struct("AuthorizationSession")
89			.field("tenant", &self.tenant)
90			.field("principal", &self.principal)
91			.field("scope", &self.scope)
92			.field("state", &self.state)
93			.field("redirect_uri", &self.redirect_uri)
94			.field("authorize_url", &self.authorize_url)
95			.field("code_challenge", &self.pkce.challenge)
96			.field("code_challenge_method", &self.pkce.method)
97			.finish()
98	}
99}
100
101#[derive(Clone)]
102pub(super) struct PkcePair {
103	pub(super) verifier: String,
104	challenge: String,
105	method: PkceCodeChallengeMethod,
106}
107impl PkcePair {
108	pub(super) fn generate() -> Self {
109		let verifier = random_string(PKCE_VERIFIER_LEN);
110		let challenge = compute_pkce_challenge(&verifier);
111
112		Self { verifier, challenge, method: PkceCodeChallengeMethod::S256 }
113	}
114}
115
116pub(super) fn build_session(
117	descriptor: &ProviderDescriptor,
118	client_id: &str,
119	tenant: TenantId,
120	principal: PrincipalId,
121	scope: ScopeSet,
122	redirect_uri: Url,
123) -> AuthorizationSession {
124	let state = random_string(STATE_LEN);
125	let pkce = PkcePair::generate();
126	let authorize_url =
127		build_authorize_url(descriptor, client_id, &redirect_uri, &scope, &state, &pkce);
128
129	AuthorizationSession::new(tenant, principal, scope, redirect_uri, authorize_url, state, pkce)
130}
131
132fn build_authorize_url(
133	descriptor: &ProviderDescriptor,
134	client_id: &str,
135	redirect_uri: &Url,
136	scope: &ScopeSet,
137	state: &str,
138	pkce: &PkcePair,
139) -> Url {
140	let mut url = descriptor.endpoints.authorization.clone();
141	let mut pairs = url.query_pairs_mut();
142
143	pairs.append_pair("response_type", "code");
144	pairs.append_pair("client_id", client_id);
145	pairs.append_pair("redirect_uri", redirect_uri.as_str());
146
147	if let Some(scope_value) = common::format_scope(scope, descriptor.quirks.scope_delimiter) {
148		pairs.append_pair("scope", &scope_value);
149	}
150
151	pairs.append_pair("state", state);
152	pairs.append_pair("code_challenge", &pkce.challenge);
153	pairs.append_pair("code_challenge_method", pkce.method.as_str());
154
155	drop(pairs);
156
157	url
158}
159
160fn random_string(len: usize) -> String {
161	rand::rng().sample_iter(Alphanumeric).take(len).map(char::from).collect()
162}
163
164fn compute_pkce_challenge(verifier: &str) -> String {
165	let mut hasher = Sha256::new();
166	hasher.update(verifier.as_bytes());
167	let digest = hasher.finalize();
168	URL_SAFE_NO_PAD.encode(digest)
169}
170
#[cfg(test)]
mod tests {
	// self
	use super::*;

	/// Builds a session whose stored state is `state`; every other field is a fixed fixture.
	fn session_with_state(state: &str) -> AuthorizationSession {
		AuthorizationSession::new(
			TenantId::new("tenant").expect("Tenant fixture should be valid for PKCE tests."),
			PrincipalId::new("principal")
				.expect("Principal fixture should be valid for PKCE tests."),
			ScopeSet::new(Vec::<&str>::new()).expect("Failed to build empty scope set for test."),
			Url::parse("https://example.com/cb")
				.expect("Redirect URL fixture should parse successfully."),
			Url::parse("https://example.com/auth?state=abc")
				.expect("Authorization URL fixture should parse successfully."),
			state.into(),
			PkcePair::generate(),
		)
	}

	#[test]
	fn state_validation_errors_on_mismatch() {
		let session = session_with_state("expected");

		assert!(session.validate_state("expected").is_ok());

		// Cover a different value, empty input, a same-length near miss, and a longer value.
		for returned in ["other", "", "expecteD", "expected-but-longer"] {
			let err = session.validate_state(returned).expect_err("State mismatch should fail.");

			assert!(
				matches!(err, Error::InvalidGrant { .. }),
				"Unexpected error variant for returned state {returned:?}."
			);
		}
	}

	#[test]
	fn pkce_challenge_matches_rfc7636_appendix_b() {
		// Test vector from RFC 7636 Appendix B.
		let challenge = compute_pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk");

		assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
	}

	#[test]
	fn generated_pkce_pair_is_rfc7636_compliant() {
		let pkce = PkcePair::generate();

		assert_eq!(pkce.verifier.len(), PKCE_VERIFIER_LEN);
		assert!((43..=128).contains(&pkce.verifier.len()));
		assert!(pkce.verifier.bytes().all(|b| b.is_ascii_alphanumeric()));
		assert_eq!(pkce.challenge, compute_pkce_challenge(&pkce.verifier));
		assert_eq!(pkce.method, PkceCodeChallengeMethod::S256);
	}

	#[test]
	fn random_string_has_requested_length_and_charset() {
		let value = random_string(STATE_LEN);

		assert_eq!(value.len(), STATE_LEN);
		assert!(value.bytes().all(|b| b.is_ascii_alphanumeric()));
		assert!(random_string(0).is_empty());
	}
}