oauth2_broker/provider/descriptor/
builder.rs1use std::iter::IntoIterator;
3use crate::{
5 _prelude::*,
6 auth::ProviderId,
7 provider::{
8 ClientAuthMethod, GrantType, ProviderDescriptor, ProviderEndpoints, ProviderQuirks,
9 SupportedGrants,
10 },
11};
12
13#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, ThisError)]
15pub enum ProviderDescriptorError {
16 #[error("Missing authorization endpoint.")]
18 MissingAuthorizationEndpoint,
19 #[error("Missing token endpoint.")]
21 MissingTokenEndpoint,
22 #[error("Descriptor must enable at least one grant type.")]
24 NoSupportedGrants,
25 #[error("The `pkce_required` flag requires enabling the authorization_code grant.")]
27 PkceRequiredWithoutAuthorizationCode,
28 #[error("The {endpoint} endpoint must use HTTPS: {url}.")]
30 InsecureEndpoint {
31 endpoint: &'static str,
33 url: String,
35 },
36 #[error("Scope delimiter must be a printable character.")]
38 InvalidScopeDelimiter {
39 delimiter: char,
41 },
42}
43
44#[derive(Debug)]
46pub struct ProviderDescriptorBuilder {
47 pub id: ProviderId,
49 pub authorization_endpoint: Option<Url>,
51 pub token_endpoint: Option<Url>,
53 pub revocation_endpoint: Option<Url>,
55 pub supported_grants: SupportedGrants,
57 pub preferred_client_auth_method: ClientAuthMethod,
59 pub quirks: ProviderQuirks,
61}
62impl ProviderDescriptorBuilder {
63 pub fn new(id: ProviderId) -> Self {
65 Self {
66 id,
67 authorization_endpoint: None,
68 token_endpoint: None,
69 revocation_endpoint: None,
70 supported_grants: SupportedGrants::default(),
71 preferred_client_auth_method: ClientAuthMethod::default(),
72 quirks: ProviderQuirks::default(),
73 }
74 }
75
76 pub fn authorization_endpoint(mut self, url: Url) -> Self {
78 self.authorization_endpoint = Some(url);
79
80 self
81 }
82
83 pub fn token_endpoint(mut self, url: Url) -> Self {
85 self.token_endpoint = Some(url);
86
87 self
88 }
89
90 pub fn revocation_endpoint(mut self, url: Url) -> Self {
92 self.revocation_endpoint = Some(url);
93
94 self
95 }
96
97 pub fn support_grant(mut self, grant: GrantType) -> Self {
99 self.supported_grants = self.supported_grants.enable(grant);
100
101 self
102 }
103
104 pub fn support_grants<I>(mut self, grants: I) -> Self
106 where
107 I: IntoIterator<Item = GrantType>,
108 {
109 for grant in grants.into_iter() {
110 self.supported_grants = self.supported_grants.enable(grant);
111 }
112
113 self
114 }
115
116 pub fn preferred_client_auth_method(mut self, method: ClientAuthMethod) -> Self {
118 self.preferred_client_auth_method = method;
119
120 self
121 }
122
123 pub fn quirks(mut self, quirks: ProviderQuirks) -> Self {
125 self.quirks = quirks;
126
127 self
128 }
129
130 pub fn build(self) -> Result<ProviderDescriptor, ProviderDescriptorError> {
132 let authorization = self
133 .authorization_endpoint
134 .ok_or(ProviderDescriptorError::MissingAuthorizationEndpoint)?;
135 let token = self.token_endpoint.ok_or(ProviderDescriptorError::MissingTokenEndpoint)?;
136 let endpoints =
137 ProviderEndpoints { authorization, token, revocation: self.revocation_endpoint };
138 let descriptor = ProviderDescriptor {
139 id: self.id,
140 endpoints,
141 supported_grants: self.supported_grants,
142 preferred_client_auth_method: self.preferred_client_auth_method,
143 quirks: self.quirks,
144 };
145
146 descriptor.validate()?;
147
148 Ok(descriptor)
149 }
150}
151
152impl ProviderDescriptor {
153 fn validate(&self) -> Result<(), ProviderDescriptorError> {
155 if self.supported_grants.is_empty() {
156 return Err(ProviderDescriptorError::NoSupportedGrants);
157 }
158 if self.quirks.pkce_required && !self.supports(GrantType::AuthorizationCode) {
159 return Err(ProviderDescriptorError::PkceRequiredWithoutAuthorizationCode);
160 }
161
162 validate_endpoint("authorization", &self.endpoints.authorization)?;
163 validate_endpoint("token", &self.endpoints.token)?;
164
165 if let Some(revocation) = self.endpoints.revocation.as_ref() {
166 validate_endpoint("revocation", revocation)?;
167 }
168
169 validate_scope_delimiter(self.quirks.scope_delimiter)?;
170
171 Ok(())
172 }
173}
174
175fn validate_endpoint(name: &'static str, url: &Url) -> Result<(), ProviderDescriptorError> {
176 if url.scheme() != "https" {
177 Err(ProviderDescriptorError::InsecureEndpoint { endpoint: name, url: url.to_string() })
178 } else {
179 Ok(())
180 }
181}
182
183fn validate_scope_delimiter(delimiter: char) -> Result<(), ProviderDescriptorError> {
184 if delimiter.is_control() {
185 Err(ProviderDescriptorError::InvalidScopeDelimiter { delimiter })
186 } else {
187 Ok(())
188 }
189}