1pub mod discovery;
65pub mod error;
66pub mod issuer;
67pub mod token;
68
69pub use crate::error::Error;
70
71use biscuit::jwa::{self, SignatureAlgorithm};
72use biscuit::jwk::{AlgorithmParameters, JWKSet};
73use biscuit::jws::{Compact, Secret};
74use biscuit::{Empty, SingleOrMultiple};
75use chrono::{Duration, NaiveDate, Utc};
76use inth_oauth2::token::Token as _t;
77use reqwest::Url;
78use serde::{Deserialize, Serialize};
79use validator::Validate;
80use validator_derive::Validate;
81
82use crate::discovery::{Config, Discovered};
83use crate::error::{Decode, Expiry, Mismatch, Missing, Validation};
84use crate::token::{Claims, Token};
85
86type IdToken = Compact<Claims, Empty>;
87
88pub struct Client {
90 oauth: inth_oauth2::Client<Discovered>,
91 jwks: JWKSet<Empty>,
92}
93
/// Expands to `Err(error::Jose::WrongKeyType { .. }.into())`.
///
/// Both arguments are rendered with their `Debug` representation; `$expected`
/// is formatted before `$actual`.
macro_rules! wrong_key {
    ($expected:expr, $actual:expr) => {{
        let expected = format!("{:?}", $expected);
        let actual = format!("{:?}", $actual);
        Err(error::Jose::WrongKeyType { expected, actual }.into())
    }};
}
104
105impl Client {
106 pub fn discover(id: String, secret: String, redirect: Url, issuer: Url) -> Result<Self, Error> {
108 discovery::secure(&redirect)?;
109 let client = reqwest::Client::new();
110 let config = discovery::discover(&client, issuer)?;
111 let jwks = discovery::jwks(&client, config.jwks_uri.clone())?;
112 let provider = Discovered(config);
113 Ok(Self::new(id, secret, redirect, provider, jwks))
114 }
115
116 pub fn new(
119 id: String,
120 secret: String,
121 redirect: Url,
122 provider: Discovered,
123 jwks: JWKSet<Empty>,
124 ) -> Self {
125 Client {
126 oauth: inth_oauth2::Client::new(provider, id, secret, Some(redirect.into_string())),
127 jwks,
128 }
129 }
130
131 pub fn redirect_url(&self) -> &str {
133 self.oauth
134 .redirect_uri
135 .as_ref()
136 .expect("We always require a redirect to construct client!")
137 }
138
139 pub fn request_token(&self, client: &reqwest::Client, auth_code: &str) -> Result<Token, Error> {
141 self.oauth
142 .request_token(client, auth_code)
143 .map_err(Error::from)
144 }
145
146 pub fn config(&self) -> &Config {
148 &self.oauth.provider.0
149 }
150
151 pub fn auth_url(&self, options: &Options) -> Url {
155 let scope = match options.scope {
156 Some(ref scope) => {
157 if !scope.contains("openid") {
158 String::from("openid ") + scope
159 } else {
160 scope.clone()
161 }
162 }
163 None => String::from("openid"),
165 };
166
167 let mut url = self
168 .oauth
169 .auth_uri(Some(&scope), options.state.as_ref().map(String::as_str));
170 {
171 let mut query = url.query_pairs_mut();
172 if let Some(ref nonce) = options.nonce {
173 query.append_pair("nonce", nonce.as_str());
174 }
175 if let Some(ref display) = options.display {
176 query.append_pair("display", display.as_str());
177 }
178 if let Some(ref prompt) = options.prompt {
179 let s = prompt
180 .iter()
181 .map(|s| s.as_str())
182 .collect::<Vec<_>>()
183 .join(" ");
184 query.append_pair("prompt", s.as_str());
185 }
186 if let Some(max_age) = options.max_age {
187 query.append_pair("max_age", max_age.num_seconds().to_string().as_str());
188 }
189 if let Some(ref ui_locales) = options.ui_locales {
190 query.append_pair("ui_locales", ui_locales.as_str());
191 }
192 if let Some(ref claims_locales) = options.claims_locales {
193 query.append_pair("claims_locales", claims_locales.as_str());
194 }
195 if let Some(ref id_token_hint) = options.id_token_hint {
196 query.append_pair("id_token_hint", id_token_hint.as_str());
197 }
198 if let Some(ref login_hint) = options.login_hint {
199 query.append_pair("login_hint", login_hint.as_str());
200 }
201 if let Some(ref acr_values) = options.acr_values {
202 query.append_pair("acr_values", acr_values.as_str());
203 }
204 }
205 url
206 }
207
208 pub fn authenticate(
210 &self,
211 auth_code: &str,
212 nonce: Option<&str>,
213 max_age: Option<&Duration>,
214 ) -> Result<Token, Error> {
215 let client = reqwest::Client::new();
216 let mut token = self.request_token(&client, auth_code)?;
217 self.decode_token(&mut token.id_token)?;
218 self.validate_token(&token.id_token, nonce, max_age)?;
219 Ok(token)
220 }
221
222 pub fn decode_token(&self, token: &mut IdToken) -> Result<(), Error> {
231 if let Compact::Decoded { .. } = *token {
233 return Ok(());
234 }
235
236 let header = token.unverified_header()?;
237 let key = if self.jwks.keys.len() > 1 {
239 let token_kid = header.registered.key_id.ok_or(Decode::MissingKid)?;
240 self.jwks
241 .find(&token_kid)
242 .ok_or(Decode::MissingKey(token_kid))?
243 } else {
244 self.jwks.keys.first().as_ref().ok_or(Decode::EmptySet)?
247 };
248
249 if let Some(alg) = key.common.algorithm.as_ref() {
250 if let &jwa::Algorithm::Signature(sig) = alg {
251 if header.registered.algorithm != sig {
252 return wrong_key!(sig, header.registered.algorithm);
253 }
254 } else {
255 return wrong_key!(SignatureAlgorithm::default(), alg);
256 }
257 }
258
259 let alg = header.registered.algorithm;
260 match key.algorithm {
261 AlgorithmParameters::OctectKey { ref value, .. } => match alg {
263 SignatureAlgorithm::HS256
264 | SignatureAlgorithm::HS384
265 | SignatureAlgorithm::HS512 => {
266 *token = token.decode(&Secret::Bytes(value.clone()), alg)?;
267 Ok(())
268 }
269 _ => wrong_key!("HS256 | HS384 | HS512", alg),
270 },
271 AlgorithmParameters::RSA(ref params) => match alg {
272 SignatureAlgorithm::RS256
273 | SignatureAlgorithm::RS384
274 | SignatureAlgorithm::RS512 => {
275 let pkcs = Secret::RSAModulusExponent {
276 n: params.n.clone(),
277 e: params.e.clone(),
278 };
279 *token = token.decode(&pkcs, alg)?;
280 Ok(())
281 }
282 _ => wrong_key!("RS256 | RS384 | RS512", alg),
283 },
284 AlgorithmParameters::EllipticCurve(_) => unimplemented!("No support for EC keys yet"),
285 }
286 }
287
288 pub fn validate_token(
302 &self,
303 token: &IdToken,
304 nonce: Option<&str>,
305 max_age: Option<&Duration>,
306 ) -> Result<(), Error> {
307 let claims = token.payload()?;
308
309 if claims.iss != self.config().issuer {
310 let expected = self.config().issuer.as_str().to_string();
311 let actual = claims.iss.as_str().to_string();
312 return Err(Validation::Mismatch(Mismatch::Issuer { expected, actual }).into());
313 }
314
315 match nonce {
316 Some(expected) => match claims.nonce {
317 Some(ref actual) => {
318 if expected != actual {
319 let expected = expected.to_string();
320 let actual = actual.to_string();
321 return Err(
322 Validation::Mismatch(Mismatch::Nonce { expected, actual }).into()
323 );
324 }
325 }
326 None => return Err(Validation::Missing(Missing::Nonce).into()),
327 },
328 None => {
329 if claims.nonce.is_some() {
330 return Err(Validation::Missing(Missing::Nonce).into());
331 }
332 }
333 }
334
335 if !claims.aud.contains(&self.oauth.client_id) {
336 return Err(Validation::Missing(Missing::Audience).into());
337 }
338 if let SingleOrMultiple::Multiple(_) = claims.aud {
340 if let None = claims.azp {
341 return Err(Validation::Missing(Missing::AuthorizedParty).into());
342 }
343 }
344 if let Some(ref actual) = claims.azp {
346 if actual != &self.oauth.client_id {
347 let expected = self.oauth.client_id.to_string();
348 let actual = actual.to_string();
349 return Err(
350 Validation::Mismatch(Mismatch::AuthorizedParty { expected, actual }).into(),
351 );
352 }
353 }
354
355 let now = Utc::now();
356 if now.timestamp() < 1504758600 {
358 panic!("chrono::Utc::now() can never be before this was written!")
359 }
360 if claims.exp <= now.timestamp() {
361 return Err(Validation::Expired(Expiry::Expires(
362 chrono::naive::NaiveDateTime::from_timestamp(claims.exp, 0),
363 ))
364 .into());
365 }
366
367 if let Some(max) = max_age {
368 match claims.auth_time {
369 Some(time) => {
370 let age = chrono::Duration::seconds(now.timestamp() - time);
371 if age >= *max {
372 return Err(error::Validation::Expired(Expiry::MaxAge(age)).into());
373 }
374 }
375 None => return Err(Validation::Missing(Missing::AuthTime).into()),
376 }
377 }
378
379 Ok(())
380 }
381
382 pub fn request_userinfo(
392 &self,
393 client: &reqwest::Client,
394 token: &Token,
395 ) -> Result<Userinfo, Error> {
396 match self.config().userinfo_endpoint {
397 Some(ref url) => {
398 discovery::secure(&url)?;
399 let claims = token.id_token.payload()?;
400 let auth_code = token.access_token().to_string();
401 let mut resp = client
402 .get(url.clone())
403 .header_011(reqwest::hyper_011::header::Authorization(
406 reqwest::hyper_011::header::Bearer { token: auth_code },
407 ))
408 .send()?;
409 let info: Userinfo = resp.json()?;
410 if claims.sub != info.sub {
411 let expected = info.sub.clone();
412 let actual = claims.sub.clone();
413 return Err(error::Userinfo::MismatchSubject { expected, actual }.into());
414 }
415 Ok(info)
416 }
417 None => Err(error::Userinfo::NoUrl.into()),
418 }
419 }
420}
421
422#[derive(Default)]
425pub struct Options {
426 pub scope: Option<String>,
430 pub state: Option<String>,
431 pub nonce: Option<String>,
432 pub display: Option<Display>,
433 pub prompt: Option<std::collections::HashSet<Prompt>>,
434 pub max_age: Option<Duration>,
435 pub ui_locales: Option<String>,
436 pub claims_locales: Option<String>,
437 pub id_token_hint: Option<String>,
438 pub login_hint: Option<String>,
439 pub acr_values: Option<String>,
440}
441
442#[derive(Debug, Deserialize, Serialize, Validate)]
445pub struct Userinfo {
446 pub sub: String,
447 #[serde(default)]
448 pub name: Option<String>,
449 #[serde(default)]
450 pub given_name: Option<String>,
451 #[serde(default)]
452 pub family_name: Option<String>,
453 #[serde(default)]
454 pub middle_name: Option<String>,
455 #[serde(default)]
456 pub nickname: Option<String>,
457 #[serde(default)]
458 pub preferred_username: Option<String>,
459 #[serde(default)]
460 #[serde(with = "url_serde")]
461 pub profile: Option<Url>,
462 #[serde(default)]
463 #[serde(with = "url_serde")]
464 pub picture: Option<Url>,
465 #[serde(default)]
466 #[serde(with = "url_serde")]
467 pub website: Option<Url>,
468 #[serde(default)]
469 #[validate(email)]
470 pub email: Option<String>,
471 #[serde(default)]
472 pub email_verified: bool,
473 #[serde(default)]
475 pub gender: Option<String>,
476 #[serde(default)]
478 pub birthdate: Option<NaiveDate>,
479 #[serde(default)]
481 pub zoneinfo: Option<String>,
482 #[serde(default)]
484 pub locale: Option<String>,
485 #[serde(default)]
487 pub phone_number: Option<String>,
488 #[serde(default)]
489 pub phone_number_verified: bool,
490 #[serde(default)]
491 pub address: Option<Address>,
492 #[serde(default)]
493 pub updated_at: Option<i64>,
494}
495
/// How the provider should present its authentication and consent pages.
///
/// Sent as the `display` query parameter by `Client::auth_url`. The enum now
/// derives the standard value traits so callers can copy, compare, hash and
/// debug-print it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Display {
    Page,
    Popup,
    Touch,
    Wap,
}

impl Display {
    /// Returns the wire value used for the `display` query parameter.
    fn as_str(&self) -> &'static str {
        match *self {
            Display::Page => "page",
            Display::Popup => "popup",
            Display::Touch => "touch",
            Display::Wap => "wap",
        }
    }
}
515
/// Values for the `prompt` query parameter built by `Client::auth_url`.
///
/// Several values may be combined through `Options::prompt`, a set of these.
/// `Debug`, `Clone` and `Copy` are derived in addition to the hashing traits
/// the set needs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Prompt {
    None,
    Login,
    Consent,
    SelectAccount,
}

impl Prompt {
    /// Returns the wire value used in the `prompt` query parameter.
    fn as_str(&self) -> &'static str {
        // Explicit paths: a glob import of the variants would shadow `Option::None`.
        match *self {
            Prompt::None => "none",
            Prompt::Login => "login",
            Prompt::Consent => "consent",
            Prompt::SelectAccount => "select_account",
        }
    }
}
536
537#[derive(Debug, Deserialize, Serialize)]
539pub struct Address {
540 #[serde(default)]
541 pub formatted: Option<String>,
542 #[serde(default)]
543 pub street_address: Option<String>,
544 #[serde(default)]
545 pub locality: Option<String>,
546 #[serde(default)]
547 pub region: Option<String>,
548 #[serde(default)]
550 pub postal_code: Option<String>,
551 #[serde(default)]
552 pub country: Option<String>,
553}
554
#[cfg(test)]
mod tests {
    use crate::issuer;
    use crate::Client;
    use reqwest::Url;

    /// `Options` must remain constructible through `Default`.
    #[test]
    fn default_options() {
        let _opts = super::Options::default();
    }

    /// Generates a test named after an issuer constructor in `crate::issuer`.
    /// The test runs discovery against that issuer (network access required)
    /// and builds an authorization URL with default options.
    macro_rules! test {
        ($issuer:ident) => {
            #[test]
            fn $issuer() {
                let redirect = Url::parse("https://example.com/re").unwrap();
                let client = Client::discover(
                    "test".to_string(),
                    "a secret to everybody".to_string(),
                    redirect,
                    issuer::$issuer(),
                )
                .unwrap();
                let _url = client.auth_url(&Default::default());
            }
        };
    }

    test!(google);
    test!(microsoft);
    test!(paypal);
    test!(salesforce);
    test!(yahoo);
}