1use super::scope::Scope;
7
8use std::borrow::Cow;
9use std::cmp;
10use std::collections::HashMap;
11use std::fmt;
12use std::iter::{Extend, FromIterator};
13use std::rc::Rc;
14use std::sync::{Arc, MutexGuard, RwLockWriteGuard};
15
16use argon2::{self, Config};
17use once_cell::sync::Lazy;
18use rand::{RngCore, thread_rng};
19use serde::{Deserialize, Serialize};
20use url::{Url, ParseError as ParseUrlError};
21
/// Interface to a store of registered clients.
///
/// The three methods cover resolving a client's redirect URL, producing a [`PreGrant`] for a
/// resolved client, and checking a client's credentials.
pub trait Registrar {
    /// Resolves the redirect URL to use for the client named in `bound`.
    ///
    /// Returns the client id together with the resolved URL, or a [`RegistrarError`] if the
    /// request cannot be bound.
    fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError>;

    /// Turns a previously bound client and an optionally requested `scope` into a [`PreGrant`].
    fn negotiate(&self, client: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError>;

    /// Checks the credentials of the client identified by `client_id`.
    ///
    /// `passphrase` is optional; presumably `None` means the client presented no secret.
    fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError>;
}
48
/// A redirect URL as registered for a client.
///
/// The variant selects how the URL is compared against other URLs (see the `PartialEq`
/// implementations in this file).
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegisteredUrl {
    /// A URL kept as its original string; compared string-for-string against other exact URLs.
    Exact(ExactUrl),
    /// A parsed [`Url`]; compared as parsed URLs.
    Semantic(Url),
    /// A URL whose port is dropped when the host is `localhost`.
    IgnorePortOnLocalhost(IgnoreLocalPortUrl),
}
75
/// A URL stored as the exact string it was created from.
///
/// The constructors in this file only accept strings that parse as a [`Url`]; the string itself
/// is kept unmodified.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct ExactUrl(String);
97
98impl<'de> Deserialize<'de> for ExactUrl {
99 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100 where
101 D: serde::Deserializer<'de>,
102 {
103 let string: &str = Deserialize::deserialize(deserializer)?;
104 core::str::FromStr::from_str(&string).map_err(serde::de::Error::custom)
105 }
106}
107
/// A URL that disregards the port when its host is `localhost`.
///
/// For any other host the original string is kept as-is.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal);
136
/// Internal representation of [`IgnoreLocalPortUrl`].
#[derive(Clone, Debug, PartialEq, Eq)]
enum IgnoreLocalPortUrlInternal {
    /// Host is not `localhost`: the original URL string, unmodified.
    Exact(String),
    /// Host is `localhost`: the parsed URL after its port was reset via `set_port(None)`.
    Local(Url),
}
142
143impl Serialize for IgnoreLocalPortUrl {
144 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145 where
146 S: serde::Serializer,
147 {
148 serializer.serialize_str(self.as_str())
149 }
150}
151
152impl<'de> Deserialize<'de> for IgnoreLocalPortUrl {
153 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
154 where
155 D: serde::Deserializer<'de>,
156 {
157 let string: &str = Deserialize::deserialize(deserializer)?;
158 Self::new(string).map_err(serde::de::Error::custom)
159 }
160}
161
/// Input to [`Registrar::bound_redirect`]: a client id with an optional redirect URL.
#[derive(Clone, Debug)]
pub struct ClientUrl<'a> {
    /// The identifier of the client.
    pub client_id: Cow<'a, str>,

    /// The redirect URL supplied with the request, if any.
    pub redirect_uri: Option<Cow<'a, ExactUrl>>,
}
177
/// Output of [`Registrar::bound_redirect`]: a client id with its resolved redirect URL.
#[derive(Clone, Debug)]
pub struct BoundClient<'a> {
    /// The identifier of the client.
    pub client_id: Cow<'a, str>,

    /// The redirect URL resolved for this client.
    pub redirect_uri: Cow<'a, RegisteredUrl>,
}
190
/// Output of [`Registrar::negotiate`]: client id, redirect URL and scope, all owned.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PreGrant {
    /// The identifier of the client.
    pub client_id: String,

    /// The redirect URL of the client.
    pub redirect_uri: RegisteredUrl,

    /// The scope selected for the client.
    pub scope: Scope,
}
206
/// Errors reported by a [`Registrar`] or a [`PasswordPolicy`].
#[derive(Clone, Debug)]
pub enum RegistrarError {
    /// The request was rejected without further detail. In this file it is used for an unknown
    /// client, a redirect URL that is not registered, and failed authentication.
    Unspecified,

    /// The underlying primitive failed. In this file it is used when a stored hash is not valid
    /// UTF-8 or when argon2 verification itself returns an error.
    PrimitiveError,
}
224
/// A client description that has not yet been encoded with a [`PasswordPolicy`].
///
/// For confidential clients `client_type` still holds the passphrase exactly as it was given to
/// [`Client::confidential`]; [`Client::encode`] replaces it with the policy's stored form.
#[derive(Clone, Debug)]
pub struct Client {
    // Identifier under which the client is registered.
    client_id: String,
    // Primary redirect URL.
    redirect_uri: RegisteredUrl,
    // Further accepted redirect URLs; empty unless set via `with_additional_redirect_uris`.
    additional_redirect_uris: Vec<RegisteredUrl>,
    // Scope handed out for this client by `ClientMap::negotiate`.
    default_scope: Scope,
    // Public, or confidential with the not-yet-encoded passphrase.
    client_type: ClientType,
}
242
/// A client in its stored, serializable form, as produced by [`Client::encode`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EncodedClient {
    /// The identifier of the client.
    pub client_id: String,

    /// The primary redirect URL of the client.
    pub redirect_uri: RegisteredUrl,

    /// Further redirect URLs accepted for the client.
    pub additional_redirect_uris: Vec<RegisteredUrl>,

    /// The scope handed out for this client by `ClientMap::negotiate`.
    pub default_scope: Scope,

    /// The client type. When created through [`Client::encode`], the `passdata` of a
    /// confidential client is the output of [`PasswordPolicy::store`].
    pub encoded_client: ClientType,
}
268
/// Pairs a borrowed [`EncodedClient`] with the [`PasswordPolicy`] used to check its credentials.
pub struct RegisteredClient<'a> {
    // The stored client record.
    client: &'a EncodedClient,
    // Policy that verifies a provided passphrase against the stored data.
    policy: &'a dyn PasswordPolicy,
}
274
/// Whether a client authenticates with a passphrase.
///
/// `Debug` is implemented by hand in this file and never prints `passdata`.
#[derive(Clone, Serialize, Deserialize)]
pub enum ClientType {
    /// A client without a passphrase.
    Public,

    /// A client that authenticates with a passphrase.
    Confidential {
        /// Passphrase data: the raw passphrase inside a [`Client`], the output of
        /// [`PasswordPolicy::store`] once encoded via [`Client::encode`].
        passdata: Vec<u8>,
    },
}
287
/// An in-memory [`Registrar`] backed by a hash map.
#[derive(Default)]
pub struct ClientMap {
    // Encoded clients keyed by their client id.
    clients: HashMap<String, EncodedClient>,
    // Custom policy; when `None`, `DEFAULT_PASSWORD_POLICY` is used instead.
    password_policy: Option<Box<dyn PasswordPolicy>>,
}
294
295impl fmt::Debug for ClientType {
296 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
297 match self {
298 ClientType::Public => write!(f, "<public>"),
299 ClientType::Confidential { .. } => write!(f, "<confidential>"),
300 }
301 }
302}
303
304impl RegisteredUrl {
305 pub fn as_str(&self) -> &str {
307 match self {
308 RegisteredUrl::Exact(exact) => &exact.0,
309 RegisteredUrl::Semantic(url) => url.as_str(),
310 RegisteredUrl::IgnorePortOnLocalhost(url) => url.as_str(),
311 }
312 }
313
314 pub fn to_url(&self) -> Url {
316 match self {
317 RegisteredUrl::Exact(exact) => exact.to_url(),
318 RegisteredUrl::Semantic(url) => url.clone(),
319 RegisteredUrl::IgnorePortOnLocalhost(url) => url.to_url(),
320 }
321 }
322
323 pub fn into_url(self) -> Url {
325 self.into()
326 }
327}
328
329impl From<Url> for RegisteredUrl {
330 fn from(url: Url) -> Self {
331 RegisteredUrl::Semantic(url)
332 }
333}
334
335impl From<ExactUrl> for RegisteredUrl {
336 fn from(url: ExactUrl) -> Self {
337 RegisteredUrl::Exact(url)
338 }
339}
340
341impl From<IgnoreLocalPortUrl> for RegisteredUrl {
342 fn from(url: IgnoreLocalPortUrl) -> Self {
343 RegisteredUrl::IgnorePortOnLocalhost(url)
344 }
345}
346
347impl From<RegisteredUrl> for Url {
348 fn from(url: RegisteredUrl) -> Self {
349 match url {
350 RegisteredUrl::Exact(exact) => exact.0.parse().expect("was validated"),
351 RegisteredUrl::Semantic(url) => url,
352 RegisteredUrl::IgnorePortOnLocalhost(url) => url.to_url(),
353 }
354 }
355}
356
357impl cmp::PartialEq<ExactUrl> for RegisteredUrl {
360 fn eq(&self, exact: &ExactUrl) -> bool {
361 match self {
362 RegisteredUrl::Exact(url) => url == exact,
363 RegisteredUrl::Semantic(url) => *url == exact.to_url(),
364 RegisteredUrl::IgnorePortOnLocalhost(url) => url == &IgnoreLocalPortUrl::from(exact),
365 }
366 }
367}
368
369impl cmp::PartialEq<IgnoreLocalPortUrl> for RegisteredUrl {
370 fn eq(&self, ign_lport: &IgnoreLocalPortUrl) -> bool {
371 match self {
372 RegisteredUrl::Exact(url) => ign_lport == &IgnoreLocalPortUrl::from(url),
373 RegisteredUrl::Semantic(url) => ign_lport == &IgnoreLocalPortUrl::from(url.clone()),
374 RegisteredUrl::IgnorePortOnLocalhost(url) => ign_lport == url,
375 }
376 }
377}
378
379impl cmp::PartialEq<Url> for RegisteredUrl {
381 fn eq(&self, semantic: &Url) -> bool {
382 self.to_url() == *semantic
383 }
384}
385
386impl fmt::Display for RegisteredUrl {
387 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
388 match self {
389 RegisteredUrl::Exact(url) => write!(f, "{}", url.to_url()),
390 RegisteredUrl::Semantic(url) => write!(f, "{}", url),
391 RegisteredUrl::IgnorePortOnLocalhost(url) => write!(f, "{}", url.to_url()),
392 }
393 }
394}
395
396impl ExactUrl {
397 pub fn new(url: String) -> Result<Self, ParseUrlError> {
399 let _: Url = url.parse()?;
400 Ok(ExactUrl(url))
401 }
402
403 pub fn as_str(&self) -> &str {
405 &self.0
406 }
407
408 pub fn to_url(&self) -> Url {
410 self.0.parse().expect("was validated")
412 }
413}
414
415impl core::str::FromStr for ExactUrl {
416 type Err = ParseUrlError;
417 fn from_str(st: &str) -> Result<Self, Self::Err> {
418 let _: Url = st.parse()?;
419 Ok(ExactUrl(st.to_string()))
420 }
421}
422
423impl IgnoreLocalPortUrl {
424 pub fn new<'a, S: Into<Cow<'a, str>>>(url: S) -> Result<Self, ParseUrlError> {
426 let url: Cow<'a, str> = url.into();
427 let mut parsed: Url = url.parse()?;
428 match parsed.host_str() {
429 Some("localhost") => {
430 let _ = parsed.set_port(None);
431 Ok(IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Local(parsed)))
432 }
433 _ => Ok(IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Exact(
434 url.into_owned(),
435 ))),
436 }
437 }
438
439 pub fn as_str(&self) -> &str {
441 match &self.0 {
442 IgnoreLocalPortUrlInternal::Exact(url) => url.as_str(),
443 IgnoreLocalPortUrlInternal::Local(url) => url.as_str(),
444 }
445 }
446
447 pub fn to_url(&self) -> Url {
449 match &self.0 {
451 IgnoreLocalPortUrlInternal::Exact(url) => url.parse().expect("was validated"),
452 IgnoreLocalPortUrlInternal::Local(url) => url.clone(),
453 }
454 }
455}
456
457impl From<ExactUrl> for IgnoreLocalPortUrl {
458 #[inline]
459 fn from(exact_url: ExactUrl) -> Self {
460 IgnoreLocalPortUrl::new(exact_url.0).expect("was validated")
461 }
462}
463
464impl<'e> From<&'e ExactUrl> for IgnoreLocalPortUrl {
465 #[inline]
466 fn from(exact_url: &'e ExactUrl) -> Self {
467 IgnoreLocalPortUrl::new(exact_url.as_str()).expect("was validated")
468 }
469}
470
471impl From<Url> for IgnoreLocalPortUrl {
472 fn from(mut url: Url) -> Self {
473 if url.host_str() == Some("localhost") {
474 let _ = url.set_port(None);
479 IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Local(url))
480 } else {
481 IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Exact(url.into()))
482 }
483 }
484}
485
486impl core::str::FromStr for IgnoreLocalPortUrl {
487 type Err = ParseUrlError;
488 #[inline]
489 fn from_str(st: &str) -> Result<Self, Self::Err> {
490 IgnoreLocalPortUrl::new(st)
491 }
492}
493
494impl Client {
495 pub fn public(client_id: &str, redirect_uri: RegisteredUrl, default_scope: Scope) -> Client {
497 Client {
498 client_id: client_id.to_string(),
499 redirect_uri,
500 additional_redirect_uris: vec![],
501 default_scope,
502 client_type: ClientType::Public,
503 }
504 }
505
506 pub fn confidential(
508 client_id: &str, redirect_uri: RegisteredUrl, default_scope: Scope, passphrase: &[u8],
509 ) -> Client {
510 Client {
511 client_id: client_id.to_string(),
512 redirect_uri,
513 additional_redirect_uris: vec![],
514 default_scope,
515 client_type: ClientType::Confidential {
516 passdata: passphrase.to_owned(),
517 },
518 }
519 }
520
521 pub fn with_additional_redirect_uris(mut self, uris: Vec<RegisteredUrl>) -> Self {
523 self.additional_redirect_uris = uris;
524 self
525 }
526
527 pub fn encode(self, policy: &dyn PasswordPolicy) -> EncodedClient {
533 let encoded_client = match self.client_type {
534 ClientType::Public => ClientType::Public,
535 ClientType::Confidential { passdata: passphrase } => ClientType::Confidential {
536 passdata: policy.store(&self.client_id, &passphrase),
537 },
538 };
539
540 EncodedClient {
541 client_id: self.client_id,
542 redirect_uri: self.redirect_uri,
543 additional_redirect_uris: self.additional_redirect_uris,
544 default_scope: self.default_scope,
545 encoded_client,
546 }
547 }
548}
549
550impl<'a> RegisteredClient<'a> {
551 pub fn new(client: &'a EncodedClient, policy: &'a dyn PasswordPolicy) -> Self {
556 RegisteredClient { client, policy }
557 }
558
559 pub fn check_authentication(&self, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
563 match (passphrase, &self.client.encoded_client) {
564 (None, &ClientType::Public) => Ok(()),
565 (Some(provided), &ClientType::Confidential { passdata: ref stored }) => {
566 self.policy.check(&self.client.client_id, provided, stored)
567 }
568 _ => Err(RegistrarError::Unspecified),
569 }
570 }
571}
572
573impl cmp::PartialOrd<Self> for PreGrant {
574 fn partial_cmp(&self, rhs: &PreGrant) -> Option<cmp::Ordering> {
576 if (&self.client_id, &self.redirect_uri) != (&rhs.client_id, &rhs.redirect_uri) {
577 None
578 } else {
579 self.scope.partial_cmp(&rhs.scope)
580 }
581 }
582}
583
/// Strategy for turning a client passphrase into stored data and verifying it again.
pub trait PasswordPolicy: Send + Sync {
    /// Produces the data to store for `passphrase` of the client `client_id`.
    fn store(&self, client_id: &str, passphrase: &[u8]) -> Vec<u8>;

    /// Verifies `passphrase` of the client `client_id` against previously `stored` data.
    fn check(&self, client_id: &str, passphrase: &[u8], stored: &[u8]) -> Result<(), RegistrarError>;
}
594
/// A [`PasswordPolicy`] based on argon2 hashing.
#[derive(Clone, Debug, Default)]
pub struct Argon2 {
    // Private unit field: outside this module the struct cannot be built with a literal,
    // only through `Default`.
    _private: (),
}
600
601impl PasswordPolicy for Argon2 {
602 fn store(&self, client_id: &str, passphrase: &[u8]) -> Vec<u8> {
603 let config = Config {
604 ad: client_id.as_bytes(),
605 secret: &[],
606 ..Config::rfc9106_low_mem()
607 };
608
609 let mut salt = vec![0; 32];
610 thread_rng()
611 .try_fill_bytes(salt.as_mut_slice())
612 .expect("Failed to generate password salt");
613
614 let encoded = argon2::hash_encoded(passphrase, &salt, &config);
615 encoded.unwrap().as_bytes().to_vec()
616 }
617
618 fn check(&self, client_id: &str, passphrase: &[u8], stored: &[u8]) -> Result<(), RegistrarError> {
619 let hash = String::from_utf8(stored.to_vec()).map_err(|_| RegistrarError::PrimitiveError)?;
620 let valid = argon2::verify_encoded_ext(&hash, passphrase, &[], client_id.as_bytes())
621 .map_err(|_| RegistrarError::PrimitiveError)?;
622 match valid {
623 true => Ok(()),
624 false => Err(RegistrarError::Unspecified),
625 }
626 }
627}
628
// Policy used by `ClientMap` when no custom policy has been set; created on first access.
static DEFAULT_PASSWORD_POLICY: Lazy<Argon2> = Lazy::new(Argon2::default);
634
635impl ClientMap {
636 pub fn new() -> ClientMap {
638 ClientMap::default()
639 }
640
641 pub fn register_client(&mut self, client: Client) {
643 let password_policy = Self::current_policy(&self.password_policy);
644 self.clients
645 .insert(client.client_id.clone(), client.encode(password_policy));
646 }
647
648 pub fn set_password_policy<P: PasswordPolicy + 'static>(&mut self, new_policy: P) {
650 self.password_policy = Some(Box::new(new_policy))
651 }
652
653 fn current_policy<'a>(policy: &'a Option<Box<dyn PasswordPolicy>>) -> &'a dyn PasswordPolicy {
655 policy
656 .as_ref()
657 .map(|boxed| &**boxed)
658 .unwrap_or(&*DEFAULT_PASSWORD_POLICY)
659 }
660}
661
662impl Extend<Client> for ClientMap {
663 fn extend<I>(&mut self, iter: I)
664 where
665 I: IntoIterator<Item = Client>,
666 {
667 iter.into_iter().for_each(|client| self.register_client(client))
668 }
669}
670
671impl FromIterator<Client> for ClientMap {
672 fn from_iter<I>(iter: I) -> Self
673 where
674 I: IntoIterator<Item = Client>,
675 {
676 let mut into = ClientMap::new();
677 into.extend(iter);
678 into
679 }
680}
681
682impl<'s, R: Registrar + ?Sized> Registrar for &'s R {
683 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
684 (**self).bound_redirect(bound)
685 }
686
687 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
688 (**self).negotiate(bound, scope)
689 }
690
691 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
692 (**self).check(client_id, passphrase)
693 }
694}
695
696impl<'s, R: Registrar + ?Sized> Registrar for &'s mut R {
697 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
698 (**self).bound_redirect(bound)
699 }
700
701 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
702 (**self).negotiate(bound, scope)
703 }
704
705 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
706 (**self).check(client_id, passphrase)
707 }
708}
709
710impl<R: Registrar + ?Sized> Registrar for Box<R> {
711 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
712 (**self).bound_redirect(bound)
713 }
714
715 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
716 (**self).negotiate(bound, scope)
717 }
718
719 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
720 (**self).check(client_id, passphrase)
721 }
722}
723
724impl<R: Registrar + ?Sized> Registrar for Rc<R> {
725 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
726 (**self).bound_redirect(bound)
727 }
728
729 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
730 (**self).negotiate(bound, scope)
731 }
732
733 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
734 (**self).check(client_id, passphrase)
735 }
736}
737
738impl<R: Registrar + ?Sized> Registrar for Arc<R> {
739 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
740 (**self).bound_redirect(bound)
741 }
742
743 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
744 (**self).negotiate(bound, scope)
745 }
746
747 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
748 (**self).check(client_id, passphrase)
749 }
750}
751
752impl<'s, R: Registrar + ?Sized + 's> Registrar for MutexGuard<'s, R> {
753 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
754 (**self).bound_redirect(bound)
755 }
756
757 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
758 (**self).negotiate(bound, scope)
759 }
760
761 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
762 (**self).check(client_id, passphrase)
763 }
764}
765
766impl<'s, R: Registrar + ?Sized + 's> Registrar for RwLockWriteGuard<'s, R> {
767 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
768 (**self).bound_redirect(bound)
769 }
770
771 fn negotiate(&self, bound: BoundClient, scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
772 (**self).negotiate(bound, scope)
773 }
774
775 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
776 (**self).check(client_id, passphrase)
777 }
778}
779
780impl Registrar for ClientMap {
781 fn bound_redirect<'a>(&self, bound: ClientUrl<'a>) -> Result<BoundClient<'a>, RegistrarError> {
782 let client = match self.clients.get(bound.client_id.as_ref()) {
783 None => return Err(RegistrarError::Unspecified),
784 Some(stored) => stored,
785 };
786
787 let registered_url = match bound.redirect_uri {
789 None => client.redirect_uri.clone(),
790 Some(url) => {
791 let original = std::iter::once(&client.redirect_uri);
792 let alternatives = client.additional_redirect_uris.iter();
793 if original
794 .chain(alternatives)
795 .any(|registered| *registered == *url.as_ref())
796 {
797 RegisteredUrl::Exact((*url).clone())
798 } else {
799 return Err(RegistrarError::Unspecified);
800 }
801 }
802 };
803
804 Ok(BoundClient {
805 client_id: bound.client_id,
806 redirect_uri: Cow::Owned(registered_url),
807 })
808 }
809
810 fn negotiate(&self, bound: BoundClient, _scope: Option<Scope>) -> Result<PreGrant, RegistrarError> {
812 let client = self
813 .clients
814 .get(bound.client_id.as_ref())
815 .expect("Bound client appears to not have been constructed with this registrar");
816 Ok(PreGrant {
817 client_id: bound.client_id.into_owned(),
818 redirect_uri: bound.redirect_uri.into_owned(),
819 scope: client.default_scope.clone(),
820 })
821 }
822
823 fn check(&self, client_id: &str, passphrase: Option<&[u8]>) -> Result<(), RegistrarError> {
824 let password_policy = Self::current_policy(&self.password_policy);
825
826 self.clients
827 .get(client_id)
828 .ok_or(RegistrarError::Unspecified)
829 .and_then(|client| {
830 RegisteredClient::new(client, password_policy).check_authentication(passphrase)
831 })?;
832
833 Ok(())
834 }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// A test suite for registrars which support simple registrations of arbitrary clients.
    ///
    /// Registers one public and one confidential client through `register` and checks that
    /// `Registrar::check` accepts and rejects the expected credentials.
    pub fn simple_test_suite<Reg, RegFn>(registrar: &mut Reg, register: RegFn)
    where
        Reg: Registrar,
        RegFn: Fn(&mut Reg, Client),
    {
        // NOTE(review): the two id literals look swapped relative to the variable names. This
        // is harmless because the ids are only used as opaque, distinct keys.
        let public_id = "PrivateClientId";
        let client_url = "https://example.com";

        let private_id = "PublicClientId";
        let private_passphrase = b"WOJJCcS8WyS2aGmJK6ZADg==";

        let public_client = Client::public(
            public_id,
            client_url.parse::<Url>().unwrap().into(),
            "default".parse().unwrap(),
        );

        register(registrar, public_client);

        {
            // A public client authenticates only without a passphrase.
            registrar
                .check(public_id, None)
                .expect("Authorization of public client has changed");
            registrar
                .check(public_id, Some(b""))
                .err()
                .expect("Authorization with password succeeded");
        }

        let private_client = Client::confidential(
            private_id,
            client_url.parse::<Url>().unwrap().into(),
            "default".parse().unwrap(),
            private_passphrase,
        );

        register(registrar, private_client);

        {
            // A confidential client authenticates only with the registered passphrase.
            registrar
                .check(private_id, Some(private_passphrase))
                .expect("Authorization with right password did not succeed");
            registrar
                .check(private_id, Some(b"Not the private passphrase"))
                .err()
                .expect("Authorization succeed with wrong password");
        }
    }

    // A public client accepts no passphrase and rejects any passphrase, even an empty one.
    #[test]
    fn public_client() {
        let policy = Argon2::default();
        let client = Client::public(
            "ClientId",
            "https://example.com".parse::<Url>().unwrap().into(),
            "default".parse().unwrap(),
        )
        .encode(&policy);
        let client = RegisteredClient::new(&client, &policy);

        // Providing no authentication data is ok
        assert!(client.check_authentication(None).is_ok());
        // Any authentication data is rejected
        assert!(client.check_authentication(Some(b"")).is_err());
    }

    // A confidential client accepts exactly its own passphrase.
    #[test]
    fn confidential_client() {
        let policy = Argon2::default();
        let pass = b"AB3fAj6GJpdxmEVeNCyPoA==";
        let client = Client::confidential(
            "ClientId",
            "https://example.com".parse::<Url>().unwrap().into(),
            "default".parse().unwrap(),
            pass,
        )
        .encode(&policy);
        let client = RegisteredClient::new(&client, &policy);
        assert!(client.check_authentication(None).is_err());
        assert!(client.check_authentication(Some(pass)).is_ok());
        assert!(client.check_authentication(Some(b"not the passphrase")).is_err());
        assert!(client.check_authentication(Some(b"")).is_err());
    }

    // Both the primary and the additional redirect URL are accepted; an unregistered path and
    // a different port on a non-local host are rejected.
    #[test]
    fn with_additional_redirect_uris() {
        let client_id = "ClientId";
        let redirect_uri: Url = "https://example.com/foo".parse().unwrap();
        let additional_redirect_uris: Vec<RegisteredUrl> =
            vec!["https://example.com/bar".parse::<Url>().unwrap().into()];
        let default_scope = "default".parse().unwrap();
        let client = Client::public(client_id, redirect_uri.into(), default_scope)
            .with_additional_redirect_uris(additional_redirect_uris);
        let mut client_map = ClientMap::new();
        client_map.register_client(client);

        assert_eq!(
            client_map
                .bound_redirect(ClientUrl {
                    client_id: Cow::from(client_id),
                    redirect_uri: Some(Cow::Borrowed(&"https://example.com/foo".parse().unwrap()))
                })
                .unwrap()
                .redirect_uri,
            Cow::<Url>::Owned("https://example.com/foo".parse().unwrap())
        );

        assert_eq!(
            client_map
                .bound_redirect(ClientUrl {
                    client_id: Cow::from(client_id),
                    redirect_uri: Some(Cow::Borrowed(&"https://example.com/bar".parse().unwrap()))
                })
                .unwrap()
                .redirect_uri,
            Cow::<Url>::Owned("https://example.com/bar".parse().unwrap())
        );

        assert!(client_map
            .bound_redirect(ClientUrl {
                client_id: Cow::from(client_id),
                redirect_uri: Some(Cow::Borrowed(&"https://example.com/baz".parse().unwrap()))
            })
            .is_err());

        assert!(client_map
            .bound_redirect(ClientUrl {
                client_id: Cow::from(client_id),
                redirect_uri: Some(Cow::Borrowed(&"https://example.com:1234/foo".parse().unwrap()))
            })
            .is_err());
    }

    // A client registered with `IgnorePortOnLocalhost` accepts any port on `localhost`, but
    // no other path or host.
    #[test]
    fn localhost_redirect_uris() {
        let client_id = "ClientId";
        let redirect_uri: Url = "http://localhost/foo".parse().unwrap();
        let default_scope = "default".parse().unwrap();
        let client = Client::public(
            client_id,
            RegisteredUrl::IgnorePortOnLocalhost(redirect_uri.into()),
            default_scope,
        );
        let mut client_map = ClientMap::new();
        client_map.register_client(client);

        // Accepted: same path on localhost, with or without a port. The bound URL is the
        // requested one.
        for url in &["http://localhost/foo", "http://localhost:1234/foo"] {
            assert_eq!(
                client_map
                    .bound_redirect(ClientUrl {
                        client_id: Cow::from(client_id),
                        redirect_uri: Some(Cow::Borrowed(&url.parse().unwrap()))
                    })
                    .unwrap()
                    .redirect_uri,
                Cow::<Url>::Owned(url.parse().unwrap())
            );
        }

        // Rejected: different path, different host, and the loopback address in place of
        // the `localhost` name.
        for url in &[
            "http://localhost/bar",
            "http://localhost:1234/bar",
            "http://example.com/foo",
            "http://example.com:1234/foo",
            "http://127.0.0.1/foo",
            "http://127.0.0.1:1234/foo",
        ] {
            assert!(client_map
                .bound_redirect(ClientUrl {
                    client_id: Cow::from(client_id),
                    redirect_uri: Some(Cow::Borrowed(&url.parse().unwrap()))
                })
                .is_err());
        }
    }

    // Runs the generic registrar suite against `ClientMap`.
    #[test]
    fn client_map() {
        let mut client_map = ClientMap::new();
        simple_test_suite(&mut client_map, ClientMap::register_client);
    }

    // On localhost, URLs that differ only in their port compare equal; scheme, path and host
    // still matter.
    #[test]
    fn ignore_local_port_url_eq_local() {
        let url = IgnoreLocalPortUrl::new("https://localhost/cb").unwrap();
        let url2 = IgnoreLocalPortUrl::new("https://localhost:8000/cb").unwrap();

        let aliases = [
            "https://localhost/cb",
            "https://localhost:1313/cb",
            "https://localhost:8000/cb",
            "https://localhost:8080/cb",
            "https://localhost:4343/cb",
            "https://localhost:08000/cb",
            "https://localhost:0000008000/cb",
        ];

        let others = [
            "http://localhost/cb",
            "https://localhost/cb/",
            "https://127.0.0.1/cb",
            "http://127.0.0.1/cb",
        ];

        for alias in aliases.iter().map(|&a| IgnoreLocalPortUrl::new(a).unwrap()) {
            assert_eq!(url, alias);
            assert_eq!(url2, alias);
        }

        for other in others.iter().map(|&o| IgnoreLocalPortUrl::new(o).unwrap()) {
            assert_ne!(url, other);
            assert_ne!(url2, other);
        }
    }

    // For hosts other than localhost the original string is kept and compared verbatim, so
    // even an explicitly written default port makes a URL different.
    #[test]
    fn ignore_local_port_url_eq_not_local() {
        let url1 = IgnoreLocalPortUrl::new("https://example.com/cb").unwrap();
        let url2 = IgnoreLocalPortUrl::new("http://example.com/cb/").unwrap();

        let not_url1 = ["https://example.com/cb/", "https://example.com:443/cb"];

        let not_url2 = ["https://example.com/cb", "https://example.com:80/cb/"];

        assert_eq!(
            url1,
            IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Exact(
                "https://example.com/cb".to_string()
            ))
        );
        assert_eq!(
            url2,
            IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Exact(
                "http://example.com/cb/".to_string()
            ))
        );

        for different_url in not_url1.iter().map(|&a| IgnoreLocalPortUrl::new(a).unwrap()) {
            assert_ne!(url1, different_url);
        }

        for different_url in not_url2.iter().map(|&a| IgnoreLocalPortUrl::new(a).unwrap()) {
            assert_ne!(url2, different_url);
        }
    }

    // Parsing picks the `Local` representation (port removed) for localhost and the `Exact`
    // representation (string untouched) for every other host, including `localhost.com`.
    #[test]
    fn ignore_local_port_url_new() {
        let locals = &[
            "https://localhost/callback",
            "https://localhost:8080/callback",
            "http://localhost:8080/callback",
            "https://localhost:8080",
        ];

        let exacts = &[
            "https://example.com:8000/callback",
            "https://example.com:8080/callback",
            "https://example.com:8888/callback",
            "https://localhost.com:8888/callback",
        ];

        for &local in locals {
            let mut parsed: Url = local.parse().unwrap();
            let _ = parsed.set_port(None);

            assert_eq!(
                local.parse(),
                Ok(IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Local(parsed)))
            );
        }

        for &exact in exacts {
            assert_eq!(
                exact.parse(),
                Ok(IgnoreLocalPortUrl(IgnoreLocalPortUrlInternal::Exact(
                    exact.into()
                )))
            );
        }
    }

    // Serializing and deserializing an `IgnoreLocalPortUrl` yields an equal value.
    #[test]
    fn roundtrip_serialization_ignore_local_port_url() {
        let url = "https://localhost/callback"
            .parse::<IgnoreLocalPortUrl>()
            .unwrap();
        let serialized = rmp_serde::to_vec(&url).unwrap();
        let deserialized = rmp_serde::from_slice::<IgnoreLocalPortUrl>(&serialized).unwrap();
        assert_eq!(url, deserialized);
    }

    // A string that is not a URL must be rejected when deserializing an `ExactUrl`.
    #[test]
    fn deserialize_invalid_exact_url() {
        let url = "/callback";
        let serialized = rmp_serde::to_vec(&url).unwrap();
        let deserialized = rmp_serde::from_slice::<ExactUrl>(&serialized);
        assert!(deserialized.is_err());
    }

    // Serializing and deserializing an `ExactUrl` yields an equal value.
    #[test]
    fn roundtrip_serialization_exact_url() {
        let url = "https://example.com/callback".parse::<ExactUrl>().unwrap();
        let serialized = rmp_serde::to_vec(&url).unwrap();
        let deserialized = rmp_serde::from_slice::<ExactUrl>(&serialized).unwrap();
        assert_eq!(url, deserialized);
    }
}