oauth2_broker/flows/auth_code_pkce/
session.rs1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
3use rand::{Rng, distr::Alphanumeric};
4use sha2::{Digest, Sha256};
5use crate::{
7 _prelude::*,
8 auth::{PrincipalId, ScopeSet, TenantId},
9 flows::common,
10 provider::ProviderDescriptor,
11};
12
/// Number of random alphanumeric characters generated for the `state` parameter.
const STATE_LEN: usize = 32;

/// Number of random alphanumeric characters generated for the PKCE code verifier.
///
/// RFC 7636 allows verifiers of 43 to 128 characters; this value sits inside that range.
const PKCE_VERIFIER_LEN: usize = 64;
15
/// PKCE challenge transformation announced through the `code_challenge_method` parameter.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PkceCodeChallengeMethod {
    /// Challenge derived from the verifier with SHA-256; serialized as `S256`.
    S256,
}

impl PkceCodeChallengeMethod {
    /// Returns the string sent on the wire for this method.
    pub fn as_str(self) -> &'static str {
        // Exhaustive on purpose: a new variant must pick its wire value explicitly.
        match self {
            Self::S256 => "S256",
        }
    }
}
30
31#[derive(Clone)]
33pub struct AuthorizationSession {
34 pub tenant: TenantId,
36 pub principal: PrincipalId,
38 pub scope: ScopeSet,
40 pub state: String,
42 pub redirect_uri: Url,
44 pub authorize_url: Url,
46 pkce: PkcePair,
47}
48impl AuthorizationSession {
49 pub(super) fn new(
50 tenant: TenantId,
51 principal: PrincipalId,
52 scope: ScopeSet,
53 redirect_uri: Url,
54 authorize_url: Url,
55 state: String,
56 pkce: PkcePair,
57 ) -> Self {
58 Self { tenant, principal, scope, state, redirect_uri, authorize_url, pkce }
59 }
60
61 pub fn code_challenge(&self) -> &str {
63 &self.pkce.challenge
64 }
65
66 pub fn code_challenge_method(&self) -> PkceCodeChallengeMethod {
68 self.pkce.method
69 }
70
71 pub fn validate_state(&self, returned_state: &str) -> Result<()> {
73 if returned_state == self.state {
74 Ok(())
75 } else {
76 Err(Error::InvalidGrant { reason: "Authorization state mismatch.".into() })
77 }
78 }
79
80 pub(super) fn into_exchange_parts(self) -> (TenantId, PrincipalId, ScopeSet, Url, PkcePair) {
81 let AuthorizationSession { tenant, principal, scope, redirect_uri, pkce, .. } = self;
82
83 (tenant, principal, scope, redirect_uri, pkce)
84 }
85}
86impl Debug for AuthorizationSession {
87 fn fmt(&self, f: &mut Formatter) -> FmtResult {
88 f.debug_struct("AuthorizationSession")
89 .field("tenant", &self.tenant)
90 .field("principal", &self.principal)
91 .field("scope", &self.scope)
92 .field("state", &self.state)
93 .field("redirect_uri", &self.redirect_uri)
94 .field("authorize_url", &self.authorize_url)
95 .field("code_challenge", &self.pkce.challenge)
96 .field("code_challenge_method", &self.pkce.method)
97 .finish()
98 }
99}
100
101#[derive(Clone)]
102pub(super) struct PkcePair {
103 pub(super) verifier: String,
104 challenge: String,
105 method: PkceCodeChallengeMethod,
106}
107impl PkcePair {
108 pub(super) fn generate() -> Self {
109 let verifier = random_string(PKCE_VERIFIER_LEN);
110 let challenge = compute_pkce_challenge(&verifier);
111
112 Self { verifier, challenge, method: PkceCodeChallengeMethod::S256 }
113 }
114}
115
116pub(super) fn build_session(
117 descriptor: &ProviderDescriptor,
118 client_id: &str,
119 tenant: TenantId,
120 principal: PrincipalId,
121 scope: ScopeSet,
122 redirect_uri: Url,
123) -> AuthorizationSession {
124 let state = random_string(STATE_LEN);
125 let pkce = PkcePair::generate();
126 let authorize_url =
127 build_authorize_url(descriptor, client_id, &redirect_uri, &scope, &state, &pkce);
128
129 AuthorizationSession::new(tenant, principal, scope, redirect_uri, authorize_url, state, pkce)
130}
131
132fn build_authorize_url(
133 descriptor: &ProviderDescriptor,
134 client_id: &str,
135 redirect_uri: &Url,
136 scope: &ScopeSet,
137 state: &str,
138 pkce: &PkcePair,
139) -> Url {
140 let mut url = descriptor.endpoints.authorization.clone();
141 let mut pairs = url.query_pairs_mut();
142
143 pairs.append_pair("response_type", "code");
144 pairs.append_pair("client_id", client_id);
145 pairs.append_pair("redirect_uri", redirect_uri.as_str());
146
147 if let Some(scope_value) = common::format_scope(scope, descriptor.quirks.scope_delimiter) {
148 pairs.append_pair("scope", &scope_value);
149 }
150
151 pairs.append_pair("state", state);
152 pairs.append_pair("code_challenge", &pkce.challenge);
153 pairs.append_pair("code_challenge_method", pkce.method.as_str());
154
155 drop(pairs);
156
157 url
158}
159
160fn random_string(len: usize) -> String {
161 rand::rng().sample_iter(Alphanumeric).take(len).map(char::from).collect()
162}
163
164fn compute_pkce_challenge(verifier: &str) -> String {
165 let mut hasher = Sha256::new();
166 hasher.update(verifier.as_bytes());
167 let digest = hasher.finalize();
168 URL_SAFE_NO_PAD.encode(digest)
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session whose stored state is `state`; every other field is a fixed fixture.
    fn session_with_state(state: &str) -> AuthorizationSession {
        AuthorizationSession::new(
            TenantId::new("tenant").expect("Tenant fixture should be valid for PKCE tests."),
            PrincipalId::new("principal")
                .expect("Principal fixture should be valid for PKCE tests."),
            ScopeSet::new(Vec::<&str>::new()).expect("Failed to build empty scope set for test."),
            Url::parse("https://example.com/cb")
                .expect("Redirect URL fixture should parse successfully."),
            Url::parse("https://example.com/auth?state=abc")
                .expect("Authorization URL fixture should parse successfully."),
            state.into(),
            PkcePair::generate(),
        )
    }

    #[test]
    fn state_validation_errors_on_mismatch() {
        let session = session_with_state("expected");

        assert!(session.validate_state("expected").is_ok());

        let err = session.validate_state("other").expect_err("State mismatch should fail.");

        assert!(matches!(err, Error::InvalidGrant { .. }));
    }

    /// Candidates that share a prefix, length, or letters with the stored state must still be
    /// rejected: empty, truncated, extended, case-changed, and last-byte-different.
    #[test]
    fn state_validation_rejects_near_misses() {
        let session = session_with_state("expected");

        for candidate in ["", "expecte", "expectedx", "EXPECTED", "expectee"] {
            let err = session
                .validate_state(candidate)
                .expect_err("Near-miss state should fail validation.");

            assert!(
                matches!(err, Error::InvalidGrant { .. }),
                "candidate {candidate:?} produced an unexpected error variant"
            );
        }
    }
}