1use openauth_oauth::oauth2::{
2 ClientAuthentication, ClientId, OAuth2Tokens, OAuth2UserInfo, OAuthError, ProviderOptions,
3 SocialIdTokenRequest,
4};
5use serde_json::{json, Value};
6use std::collections::BTreeMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
/// Boxed, `Send` future resolving to OAuth2 tokens or an [`OAuthError`].
pub type GenericOAuthTokenFuture =
    Pin<Box<dyn Future<Output = Result<OAuth2Tokens, OAuthError>> + Send>>;
/// Shared, thread-safe callback stored in `GenericOAuthConfig::get_token`.
///
/// Receives the authorization `code`, the `redirect_uri` and the optional
/// `code_verifier` / `device_id`, and yields a token future.
/// NOTE(review): presumably replaces the default token-endpoint exchange when
/// set — confirm at the call site.
pub type GenericOAuthGetToken =
    Arc<dyn Fn(GenericOAuthTokenRequest) -> GenericOAuthTokenFuture + Send + Sync>;
/// Boxed, `Send` future resolving to optional user info (`None` = no profile)
/// or an [`OAuthError`].
pub type GenericOAuthUserInfoFuture =
    Pin<Box<dyn Future<Output = Result<Option<OAuth2UserInfo>, OAuthError>> + Send>>;
/// Shared callback stored in `GenericOAuthConfig::get_user_info`; turns the
/// obtained tokens into user info.
pub type GenericOAuthGetUserInfo =
    Arc<dyn Fn(OAuth2Tokens) -> GenericOAuthUserInfoFuture + Send + Sync>;
/// Boxed, `Send` future resolving to a mapped [`OAuth2UserInfo`] or an error.
pub type GenericOAuthMapProfileFuture =
    Pin<Box<dyn Future<Output = Result<OAuth2UserInfo, OAuthError>> + Send>>;
/// Shared callback stored in `GenericOAuthConfig::map_profile_to_user`; maps
/// one [`OAuth2UserInfo`] to another.
pub type GenericOAuthMapProfileToUser =
    Arc<dyn Fn(OAuth2UserInfo) -> GenericOAuthMapProfileFuture + Send + Sync>;
/// Shared callback stored in `GenericOAuthConfig::refresh_access_token`.
/// NOTE(review): the `String` argument is presumably the refresh token — confirm.
pub type GenericOAuthRefreshAccessToken =
    Arc<dyn Fn(String) -> GenericOAuthTokenFuture + Send + Sync>;
/// Boxed, `Send` future resolving to whether an ID token is accepted.
pub type GenericOAuthVerifyIdTokenFuture =
    Pin<Box<dyn Future<Output = Result<bool, OAuthError>> + Send>>;
/// Shared callback stored in `GenericOAuthConfig::verify_id_token`.
pub type GenericOAuthVerifyIdToken =
    Arc<dyn Fn(SocialIdTokenRequest) -> GenericOAuthVerifyIdTokenFuture + Send + Sync>;
/// Boxed, `Send` future that completes once revocation finishes (or fails).
pub type GenericOAuthRevokeTokenFuture =
    Pin<Box<dyn Future<Output = Result<(), OAuthError>> + Send>>;
/// Shared callback stored in `GenericOAuthConfig::revoke_token`.
/// NOTE(review): the `String` argument is presumably the token to revoke — confirm.
pub type GenericOAuthRevokeToken =
    Arc<dyn Fn(String) -> GenericOAuthRevokeTokenFuture + Send + Sync>;
/// Ordered map of extra request parameters (name -> value).
pub type GenericOAuthParams = BTreeMap<String, String>;
/// Boxed, `Send` future resolving to extra request parameters.
pub type GenericOAuthParamsFuture =
    Pin<Box<dyn Future<Output = Result<GenericOAuthParams, OAuthError>> + Send>>;
/// Shared callback computing extra parameters from a [`GenericOAuthParamsContext`].
///
/// Used by `GenericOAuthConfig::authorization_url_params_callback` and
/// `GenericOAuthConfig::token_url_params_callback`.
pub type GenericOAuthParamsCallback =
    Arc<dyn Fn(GenericOAuthParamsContext) -> GenericOAuthParamsFuture + Send + Sync>;
38
/// The OAuth flow on whose behalf extra request parameters are being computed.
///
/// Passed to parameter callbacks through [`GenericOAuthParamsContext::flow`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GenericOAuthFlow {
    /// Signing in with the provider.
    SignIn,
    /// NOTE(review): presumably linking the provider to an existing account — confirm.
    Link,
    /// Handling the provider's redirect back to the application.
    Callback,
}
45
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct GenericOAuthParamsContext {
48 pub provider_id: String,
49 pub flow: GenericOAuthFlow,
50 pub redirect_uri: String,
51}
52
/// Arguments passed to a custom [`GenericOAuthGetToken`] callback.
///
/// `Default` yields empty strings and `None` for the optional fields.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct GenericOAuthTokenRequest {
    /// Authorization code returned by the provider.
    pub code: String,
    /// Redirect URI used for the authorization request.
    pub redirect_uri: String,
    /// PKCE code verifier, when one was generated.
    pub code_verifier: Option<String>,
    /// Optional device identifier. NOTE(review): semantics defined by the caller — confirm.
    pub device_id: Option<String>,
}
60
61#[derive(Clone, Default)]
62pub struct GenericOAuthOptions {
63 pub config: Vec<GenericOAuthConfig>,
64}
65
66impl std::fmt::Debug for GenericOAuthOptions {
67 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 formatter
69 .debug_struct("GenericOAuthOptions")
70 .field("config", &self.config)
71 .finish()
72 }
73}
74
75impl GenericOAuthOptions {
76 pub(crate) fn to_json(&self) -> Value {
77 json!({
78 "config": self.config.iter().map(GenericOAuthConfig::public_json).collect::<Vec<_>>(),
79 })
80 }
81
82 pub(crate) fn find(&self, provider_id: &str) -> Option<&GenericOAuthConfig> {
83 self.config
84 .iter()
85 .find(|config| config.provider_id == provider_id)
86 }
87}
88
/// Configuration for a single generic OAuth provider.
///
/// Construct with [`GenericOAuthConfig::new`] (explicit authorization/token
/// endpoints) or [`GenericOAuthConfig::discovery`] (discovery URL), then adjust
/// the public fields; anything not set comes from the `Default` impl.
#[derive(Clone)]
pub struct GenericOAuthConfig {
    /// Identifier the provider is looked up by (`GenericOAuthOptions::find`).
    pub provider_id: String,
    /// Discovery document URL; set by `GenericOAuthConfig::discovery`.
    pub discovery_url: Option<String>,
    /// Expected issuer. NOTE(review): checked outside this file — confirm usage.
    pub issuer: Option<String>,
    /// Whether issuer validation is required; `false` by default.
    pub require_issuer_validation: bool,
    /// Authorization endpoint; forwarded as `ProviderOptions::authorization_endpoint`.
    pub authorization_url: Option<String>,
    /// Token endpoint; set by `GenericOAuthConfig::new`.
    pub token_url: Option<String>,
    /// User-info endpoint; `None` by default.
    pub user_info_url: Option<String>,
    /// OAuth client identifier; forwarded as a single `ClientId`.
    pub client_id: String,
    /// OAuth client secret; redacted in `Debug` output and omitted from public JSON.
    pub client_secret: Option<String>,
    /// Default scopes; combined with per-request scopes by `GenericOAuthConfig::scopes`.
    pub scopes: Vec<String>,
    /// Redirect URI override; forwarded to `ProviderOptions::redirect_uri`.
    pub redirect_uri: Option<String>,
    /// Response type override; not consumed in this file.
    pub response_type: Option<String>,
    /// Response mode; forwarded to `ProviderOptions::response_mode`.
    pub response_mode: Option<String>,
    /// Prompt value; forwarded to `ProviderOptions::prompt`.
    pub prompt: Option<String>,
    /// NOTE(review): presumably enables PKCE (`code_verifier`); `false` by default.
    pub pkce: bool,
    /// Access type value; not consumed in this file.
    pub access_type: Option<String>,
    /// Presumably static extra parameters for the authorization URL.
    pub authorization_url_params: BTreeMap<String, String>,
    /// Presumably static extra parameters for the token request.
    pub token_url_params: BTreeMap<String, String>,
    /// Async callback producing extra authorization-URL parameters per flow.
    pub authorization_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Async callback producing extra token-request parameters per flow.
    pub token_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Forwarded to `ProviderOptions::disable_implicit_sign_up`.
    pub disable_implicit_sign_up: bool,
    /// Forwarded to `ProviderOptions::disable_sign_up`.
    pub disable_sign_up: bool,
    /// Client authentication method; defaults to `ClientAuthentication::Post`.
    pub authentication: ClientAuthentication,
    /// Extra headers, presumably sent when fetching the discovery document.
    pub discovery_headers: BTreeMap<String, String>,
    /// Extra headers associated with authorization. NOTE(review): confirm which requests carry them.
    pub authorization_headers: BTreeMap<String, String>,
    /// Forwarded to `ProviderOptions::override_user_info_on_sign_in`.
    pub override_user_info: bool,
    /// Optional custom token exchange hook.
    pub get_token: Option<GenericOAuthGetToken>,
    /// Optional custom user-info hook.
    pub get_user_info: Option<GenericOAuthGetUserInfo>,
    /// Optional profile mapping hook.
    pub map_profile_to_user: Option<GenericOAuthMapProfileToUser>,
    /// Optional custom access-token refresh hook.
    pub refresh_access_token: Option<GenericOAuthRefreshAccessToken>,
    /// Optional custom ID-token verification hook.
    pub verify_id_token: Option<GenericOAuthVerifyIdToken>,
    /// Optional custom token revocation hook.
    pub revoke_token: Option<GenericOAuthRevokeToken>,
}
124
125impl std::fmt::Debug for GenericOAuthConfig {
126 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 formatter
128 .debug_struct("GenericOAuthConfig")
129 .field("provider_id", &self.provider_id)
130 .field("discovery_url", &self.discovery_url)
131 .field("issuer", &self.issuer)
132 .field("require_issuer_validation", &self.require_issuer_validation)
133 .field("authorization_url", &self.authorization_url)
134 .field("token_url", &self.token_url)
135 .field("user_info_url", &self.user_info_url)
136 .field("client_id", &self.client_id)
137 .field(
138 "client_secret",
139 &self.client_secret.as_ref().map(|_| "<redacted>"),
140 )
141 .field("scopes", &self.scopes)
142 .field("redirect_uri", &self.redirect_uri)
143 .field("response_type", &self.response_type)
144 .field("response_mode", &self.response_mode)
145 .field("prompt", &self.prompt)
146 .field("pkce", &self.pkce)
147 .field("access_type", &self.access_type)
148 .field("authorization_url_params", &self.authorization_url_params)
149 .field("token_url_params", &self.token_url_params)
150 .field(
151 "authorization_url_params_callback",
152 &self.authorization_url_params_callback.is_some(),
153 )
154 .field(
155 "token_url_params_callback",
156 &self.token_url_params_callback.is_some(),
157 )
158 .field("disable_implicit_sign_up", &self.disable_implicit_sign_up)
159 .field("disable_sign_up", &self.disable_sign_up)
160 .field("authentication", &self.authentication)
161 .field("discovery_headers", &self.discovery_headers)
162 .field("authorization_headers", &self.authorization_headers)
163 .field("override_user_info", &self.override_user_info)
164 .field("get_token", &self.get_token.is_some())
165 .field("get_user_info", &self.get_user_info.is_some())
166 .field("map_profile_to_user", &self.map_profile_to_user.is_some())
167 .field("refresh_access_token", &self.refresh_access_token.is_some())
168 .field("verify_id_token", &self.verify_id_token.is_some())
169 .field("revoke_token", &self.revoke_token.is_some())
170 .finish()
171 }
172}
173
174impl GenericOAuthConfig {
175 pub fn new(
176 provider_id: impl Into<String>,
177 client_id: impl Into<String>,
178 client_secret: Option<impl Into<String>>,
179 authorization_url: impl Into<String>,
180 token_url: impl Into<String>,
181 ) -> Self {
182 Self {
183 provider_id: provider_id.into(),
184 client_id: client_id.into(),
185 client_secret: client_secret.map(Into::into),
186 authorization_url: Some(authorization_url.into()),
187 token_url: Some(token_url.into()),
188 ..Self::default()
189 }
190 }
191
192 pub fn discovery(
193 provider_id: impl Into<String>,
194 client_id: impl Into<String>,
195 client_secret: Option<impl Into<String>>,
196 discovery_url: impl Into<String>,
197 ) -> Self {
198 Self {
199 provider_id: provider_id.into(),
200 client_id: client_id.into(),
201 client_secret: client_secret.map(Into::into),
202 discovery_url: Some(discovery_url.into()),
203 ..Self::default()
204 }
205 }
206
207 pub(crate) fn provider_options(&self) -> ProviderOptions {
208 ProviderOptions {
209 client_id: Some(ClientId::Single(self.client_id.clone())),
210 client_secret: self.client_secret.clone(),
211 scope: self.scopes.clone(),
212 redirect_uri: self.redirect_uri.clone(),
213 authorization_endpoint: self.authorization_url.clone(),
214 disable_implicit_sign_up: self.disable_implicit_sign_up,
215 disable_sign_up: self.disable_sign_up,
216 prompt: self.prompt.clone(),
217 response_mode: self.response_mode.clone(),
218 override_user_info_on_sign_in: self.override_user_info,
219 ..ProviderOptions::default()
220 }
221 }
222
223 pub(crate) fn scopes(&self, request_scopes: Vec<String>) -> Vec<String> {
224 if request_scopes.is_empty() {
225 return self.scopes.clone();
226 }
227 let mut scopes = request_scopes;
228 scopes.extend(self.scopes.clone());
229 scopes
230 }
231
232 fn public_json(&self) -> Value {
233 json!({
234 "providerId": self.provider_id,
235 "discoveryUrl": self.discovery_url,
236 "issuer": self.issuer,
237 "requireIssuerValidation": self.require_issuer_validation,
238 "authorizationUrl": self.authorization_url,
239 "tokenUrl": self.token_url,
240 "userInfoUrl": self.user_info_url,
241 "clientId": self.client_id,
242 "scopes": self.scopes,
243 "redirectURI": self.redirect_uri,
244 "pkce": self.pkce,
245 "disableImplicitSignUp": self.disable_implicit_sign_up,
246 "disableSignUp": self.disable_sign_up,
247 "overrideUserInfo": self.override_user_info,
248 })
249 }
250}
251
252impl Default for GenericOAuthConfig {
253 fn default() -> Self {
254 Self {
255 provider_id: String::new(),
256 discovery_url: None,
257 issuer: None,
258 require_issuer_validation: false,
259 authorization_url: None,
260 token_url: None,
261 user_info_url: None,
262 client_id: String::new(),
263 client_secret: None,
264 scopes: Vec::new(),
265 redirect_uri: None,
266 response_type: None,
267 response_mode: None,
268 prompt: None,
269 pkce: false,
270 access_type: None,
271 authorization_url_params: BTreeMap::new(),
272 token_url_params: BTreeMap::new(),
273 authorization_url_params_callback: None,
274 token_url_params_callback: None,
275 disable_implicit_sign_up: false,
276 disable_sign_up: false,
277 authentication: ClientAuthentication::Post,
278 discovery_headers: BTreeMap::new(),
279 authorization_headers: BTreeMap::new(),
280 override_user_info: false,
281 get_token: None,
282 get_user_info: None,
283 map_profile_to_user: None,
284 refresh_access_token: None,
285 verify_id_token: None,
286 revoke_token: None,
287 }
288 }
289}