1use crate::error::AuthenticationFailedError;
3use crate::{Error, IdToken, Provider};
4
/// Value sent as the `response_mode` parameter of the authorization request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OidcResponseMode {
    /// Wire value `query`.
    Query,
    /// Wire value `form_post`.
    FormPost,
    /// Wire value `fragment`.
    Fragment,
}

impl OidcResponseMode {
    /// Returns the wire value for the `response_mode` parameter.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Query => "query",
            Self::FormPost => "form_post",
            Self::Fragment => "fragment",
        }
    }
}

impl std::fmt::Display for OidcResponseMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

// Kept for backward compatibility: existing code relies on `&mode` coercing to `&str`.
// Prefer `as_str()` / `Display` in new code.
impl std::ops::Deref for OidcResponseMode {
    type Target = str;
    fn deref(&self) -> &str {
        self.as_str()
    }
}
37
/// Value sent as the optional `prompt` parameter of the authorization request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OidcPrompt {
    /// Wire value `none`.
    NoPrompt,
    /// Wire value `login`.
    Login,
    /// Wire value `consent`.
    Consent,
    /// Wire value `select_account`.
    SelectAccount,
}

impl OidcPrompt {
    /// Returns the wire value for the `prompt` parameter.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::NoPrompt => "none",
            Self::Login => "login",
            Self::Consent => "consent",
            Self::SelectAccount => "select_account",
        }
    }
}

impl std::fmt::Display for OidcPrompt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

// Kept for backward compatibility: existing code relies on `&prompt` coercing to `&str`.
// Prefer `as_str()` / `Display` in new code.
impl std::ops::Deref for OidcPrompt {
    type Target = str;
    fn deref(&self) -> &str {
        self.as_str()
    }
}
59
60#[derive(Clone, Debug)]
62pub struct Client<P: Provider> {
63 client_id: String,
64 client_secret: String,
65 redirect_uri: String,
66 response_mode: OidcResponseMode,
67 provider: P,
68}
69
70impl<P: Provider> Client<P> {
71 pub fn auth_url(&self, session: &Session, prompt: Option<OidcPrompt>) -> url::Url {
75 let mut authurl = self.provider.authorization_endpoint();
77 authurl
78 .query_pairs_mut()
79 .append_pair("scope", "openid profile email")
80 .append_pair("response_type", "code")
81 .append_pair("client_id", &self.client_id)
82 .append_pair("nonce", &session.nonce())
83 .append_pair("state", &session.state())
84 .append_pair("response_mode", &self.response_mode)
85 .append_pair("redirect_uri", &self.redirect_uri)
86 .append_pair("code_challenge_method", "S256")
87 .append_pair("code_challenge", &session.pkce_challenge());
88
89 if let Some(prompt) = prompt {
90 authurl.query_pairs_mut().append_pair("prompt", &prompt);
91 }
92
93 authurl
94 }
95
96 pub async fn authenticate<T>(
112 &self,
113 state: &str,
114 code: &str,
115 session: &Session,
116 ) -> Result<IdToken<T>, Error>
117 where
118 T: serde::de::DeserializeOwned,
119 {
120 if state != session.state() {
122 log::warn!("state mismatch");
123 return Err(Error::BadRequest);
124 }
125
126 let code_verifier = session.pkce_verifier();
128 let params = vec![
129 ("grant_type", "authorization_code"),
130 ("code", code),
131 ("client_id", &self.client_id),
132 ("client_secret", &self.client_secret),
133 ("redirect_uri", &self.redirect_uri),
134 ("code_verifier", &code_verifier),
135 ];
136
137 let response = reqwest::Client::new()
139 .post(self.provider.token_endpoint().clone())
140 .form(¶ms)
141 .send()
142 .await?;
143
144 if let Err(err) = response.error_for_status_ref() {
145 let err_body = response.text().await?;
147 log::warn!("Token endpoint returns error {}", err_body);
148
149 Err(err.into())
150 } else {
151 let token_response = response.json::<OidcTokenEndpointResponse>().await?;
153 log::debug!("Token endpoint returns {:?}", token_response);
154
155 let id_token = IdToken::<T>::decode_without_jws_validation(&token_response.id_token)?;
159
160 self.validate_claims(&id_token, session)?;
161 Ok(id_token)
162 }
163 }
164
165 fn validate_claims<T>(
168 &self,
169 id_token: &IdToken<T>,
170 session: &Session,
171 ) -> Result<(), AuthenticationFailedError> {
172 use std::time::SystemTime;
173
174 if !self.provider.validate_iss(&id_token.iss) {
175 log::info!("Invalid iss {}", id_token.iss);
176 return Err(AuthenticationFailedError::ClaimValidationError);
177 }
178
179 if id_token.aud != self.client_id {
180 log::info!("Invalid aud {}", id_token.aud);
181 return Err(AuthenticationFailedError::ClaimValidationError);
182 }
183
184 if &id_token.nonce != &session.nonce() {
185 log::info!("Invalid nonce {}", id_token.nonce);
186 return Err(AuthenticationFailedError::ClaimValidationError);
187 }
188
189 let now = SystemTime::now()
190 .duration_since(SystemTime::UNIX_EPOCH)
191 .map_or(0, |t| t.as_secs());
192 if id_token.iat > now + 60 || now > id_token.exp {
193 log::info!(
195 "Invalid iat {} or exp {} : now = {}",
196 id_token.iat,
197 id_token.exp,
198 now
199 );
200 return Err(AuthenticationFailedError::ClaimValidationError);
201 }
202
203 Ok(())
204 }
205}
206
207pub struct ClientBuilder<P: Provider> {
209 client_id: Option<String>,
210 client_secret: Option<String>,
211 redirect_uri: Option<String>,
212 response_mode: OidcResponseMode,
213 provider: P,
214}
215
216impl<P: Provider> ClientBuilder<P> {
217 pub(crate) fn from_provider(provider: P) -> Self {
219 Self {
220 client_id: None,
221 client_secret: None,
222 redirect_uri: None,
223 response_mode: OidcResponseMode::Query,
224 provider,
225 }
226 }
227
228 pub fn build(self) -> Option<Client<P>> {
230 match self {
231 Self {
232 client_id: Some(client_id),
233 client_secret: Some(client_secret),
234 redirect_uri: Some(redirect_uri),
235 response_mode,
236 provider,
237 } => Some(Client {
238 client_id,
239 client_secret,
240 redirect_uri,
241 response_mode,
242 provider,
243 }),
244 _ => {
245 None
247 }
248 }
249 }
250
251 pub fn client_id(self, client_id: &str) -> Self {
253 let mut builder = self;
254 builder.client_id = Some(client_id.to_string());
255 builder
256 }
257
258 pub fn client_secret(self, client_secret: &str) -> Self {
260 let mut builder = self;
261 builder.client_secret = Some(client_secret.to_string());
262 builder
263 }
264
265 pub fn redirect_uri(self, redirect_uri: &str) -> Self {
267 let mut builder = self;
268 builder.redirect_uri = Some(redirect_uri.to_string());
269 builder
270 }
271
272 pub fn response_mode(self, response_mode: OidcResponseMode) -> Self {
274 let mut builder = self;
275 builder.response_mode = response_mode;
276 builder
277 }
278}
279
/// Random material for one login attempt.
///
/// The buffer is four 36-byte segments, in order: session key, `state`, `nonce` and
/// PKCE verifier.
pub struct Session {
    rand_bytes: [u8; 4 * 36],
}
285
286impl Session {
287 pub fn new_session() -> Result<Session, crate::Error> {
289 let mut rand_bytes = [0u8; 144];
291 getrandom::fill(&mut rand_bytes).map_err(|e| {
292 log::error!("getrandom() failed with {:?}", e);
293 crate::Error::InternalError
294 })?;
295 Ok(Session { rand_bytes })
296 }
297
298 pub fn save_session(&self) -> (String, String) {
303 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
304 return (self.key(), URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..]));
305 }
306
307 pub fn load_session(
311 session_key: &str,
312 session_value: &str,
313 ) -> Result<Self, base64::DecodeSliceError> {
314 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
315 let mut rand_bytes = [0u8; 144];
316
317 URL_SAFE_NO_PAD.decode_slice(session_key, &mut rand_bytes[..36])?;
319 URL_SAFE_NO_PAD.decode_slice(session_value, &mut rand_bytes[36..])?;
320
321 Ok(Self { rand_bytes })
322 }
323
324 pub fn key(&self) -> String {
326 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
327 URL_SAFE_NO_PAD.encode(&self.rand_bytes[..36])
328 }
329
330 fn state(&self) -> String {
332 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
333 URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..72])
334 }
335
336 fn nonce(&self) -> String {
338 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
339 URL_SAFE_NO_PAD.encode(&self.rand_bytes[72..108])
340 }
341
342 fn pkce_challenge(&self) -> String {
344 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
345 use sha2::{Digest, Sha256};
346
347 let challenge_byte = Sha256::digest(&self.pkce_verifier().as_bytes());
349
350 URL_SAFE_NO_PAD.encode(&challenge_byte)
351 }
352
353 fn pkce_verifier(&self) -> String {
355 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
356 URL_SAFE_NO_PAD.encode(&self.rand_bytes[108..144])
358 }
359}
360
361#[derive(Debug, serde::Deserialize)]
363struct OidcTokenEndpointResponse {
364 id_token: String,
366}