1use openauth_oauth::oauth2::{
2 ClientAuthentication, ClientId, OAuth2Tokens, OAuth2UserInfo, OAuthError, ProviderOptions,
3};
4use serde_json::{json, Value};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10pub type GenericOAuthTokenFuture =
11 Pin<Box<dyn Future<Output = Result<OAuth2Tokens, OAuthError>> + Send>>;
12pub type GenericOAuthGetToken =
13 Arc<dyn Fn(GenericOAuthTokenRequest) -> GenericOAuthTokenFuture + Send + Sync>;
14pub type GenericOAuthUserInfoFuture =
15 Pin<Box<dyn Future<Output = Result<Option<OAuth2UserInfo>, OAuthError>> + Send>>;
16pub type GenericOAuthGetUserInfo =
17 Arc<dyn Fn(OAuth2Tokens) -> GenericOAuthUserInfoFuture + Send + Sync>;
18pub type GenericOAuthMapProfileFuture =
19 Pin<Box<dyn Future<Output = Result<OAuth2UserInfo, OAuthError>> + Send>>;
20pub type GenericOAuthMapProfileToUser =
21 Arc<dyn Fn(OAuth2UserInfo) -> GenericOAuthMapProfileFuture + Send + Sync>;
22pub type GenericOAuthParams = BTreeMap<String, String>;
23pub type GenericOAuthParamsFuture =
24 Pin<Box<dyn Future<Output = Result<GenericOAuthParams, OAuthError>> + Send>>;
25pub type GenericOAuthParamsCallback =
26 Arc<dyn Fn(GenericOAuthParamsContext) -> GenericOAuthParamsFuture + Send + Sync>;
27
/// Which stage of the OAuth flow a [`GenericOAuthParamsContext`] was built for.
///
/// Besides `Copy` and `Eq`, the enum derives `Hash` so flows can key hash-based collections
/// (e.g. per-flow parameter tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GenericOAuthFlow {
    /// Presumably the start of a sign-in through the provider — TODO confirm against callers.
    SignIn,
    /// Presumably the start of linking the provider to an existing account — TODO confirm.
    Link,
    /// Presumably the handling of the provider's redirect back — TODO confirm.
    Callback,
}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub struct GenericOAuthParamsContext {
37 pub provider_id: String,
38 pub flow: GenericOAuthFlow,
39 pub redirect_uri: String,
40}
41
/// Input handed to a custom [`GenericOAuthGetToken`] token-exchange hook.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct GenericOAuthTokenRequest {
    /// Authorization code to exchange (presumably the `code` received on the callback).
    pub code: String,
    /// Redirect URI to present alongside the code.
    pub redirect_uri: String,
    /// Presumably the PKCE code verifier, when one was generated — TODO confirm.
    pub code_verifier: Option<String>,
    /// Optional device identifier; meaning is provider-specific — TODO confirm.
    pub device_id: Option<String>,
}
49
50#[derive(Clone, Default)]
51pub struct GenericOAuthOptions {
52 pub config: Vec<GenericOAuthConfig>,
53}
54
55impl std::fmt::Debug for GenericOAuthOptions {
56 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 formatter
58 .debug_struct("GenericOAuthOptions")
59 .field("config", &self.config)
60 .finish()
61 }
62}
63
64impl GenericOAuthOptions {
65 pub(crate) fn to_json(&self) -> Value {
66 json!({
67 "config": self.config.iter().map(GenericOAuthConfig::public_json).collect::<Vec<_>>(),
68 })
69 }
70
71 pub(crate) fn find(&self, provider_id: &str) -> Option<&GenericOAuthConfig> {
72 self.config
73 .iter()
74 .find(|config| config.provider_id == provider_id)
75 }
76}
77
/// Configuration for one generic OAuth provider.
///
/// Build with [`GenericOAuthConfig::new`] (explicit endpoints) or
/// [`GenericOAuthConfig::discovery`] (discovery URL); unspecified fields come from `Default`.
#[derive(Clone)]
pub struct GenericOAuthConfig {
    /// Identifier used to look this provider up in `GenericOAuthOptions`.
    pub provider_id: String,
    /// Discovery URL; set by `GenericOAuthConfig::discovery` (presumably an OpenID Connect
    /// discovery document — TODO confirm).
    pub discovery_url: Option<String>,
    /// Issuer value; not consumed in this file beyond `Debug`/public JSON — TODO confirm use.
    pub issuer: Option<String>,
    /// Flag exposed as `requireIssuerValidation`; its enforcement lives elsewhere.
    pub require_issuer_validation: bool,
    /// Authorization endpoint; forwarded as `ProviderOptions::authorization_endpoint`.
    pub authorization_url: Option<String>,
    /// Token endpoint; set by `GenericOAuthConfig::new`.
    pub token_url: Option<String>,
    /// User-info endpoint, if the provider exposes one.
    pub user_info_url: Option<String>,
    /// OAuth client id; forwarded as `ClientId::Single`.
    pub client_id: String,
    /// OAuth client secret; redacted in `Debug` and omitted from the public JSON.
    pub client_secret: Option<String>,
    /// Configured scopes; appended after request-supplied scopes when merging.
    pub scopes: Vec<String>,
    /// Redirect URI override; forwarded to `ProviderOptions`.
    pub redirect_uri: Option<String>,
    /// Response type override; not consumed in this file.
    pub response_type: Option<String>,
    /// Response mode; forwarded to `ProviderOptions`.
    pub response_mode: Option<String>,
    /// Prompt value; forwarded to `ProviderOptions`.
    pub prompt: Option<String>,
    /// Whether PKCE is enabled (default `false`); consumed elsewhere.
    pub pkce: bool,
    /// Access type value; not consumed in this file.
    pub access_type: Option<String>,
    /// Static extra parameters, presumably added to the authorization URL — TODO confirm.
    pub authorization_url_params: BTreeMap<String, String>,
    /// Static extra parameters, presumably added to the token request — TODO confirm.
    pub token_url_params: BTreeMap<String, String>,
    /// Optional hook computing additional authorization-URL parameters per request.
    pub authorization_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Optional hook computing additional token-request parameters per request.
    pub token_url_params_callback: Option<GenericOAuthParamsCallback>,
    /// Forwarded as `ProviderOptions::disable_implicit_sign_up`.
    pub disable_implicit_sign_up: bool,
    /// Forwarded as `ProviderOptions::disable_sign_up`.
    pub disable_sign_up: bool,
    /// Client authentication method; defaults to `ClientAuthentication::Post`.
    pub authentication: ClientAuthentication,
    /// Extra headers, presumably sent with the discovery request — TODO confirm.
    pub discovery_headers: BTreeMap<String, String>,
    /// Extra headers, presumably sent with authorization/token requests — TODO confirm.
    pub authorization_headers: BTreeMap<String, String>,
    /// Forwarded as `ProviderOptions::override_user_info_on_sign_in`.
    pub override_user_info: bool,
    /// Optional custom token-exchange hook.
    pub get_token: Option<GenericOAuthGetToken>,
    /// Optional custom user-info hook.
    pub get_user_info: Option<GenericOAuthGetUserInfo>,
    /// Optional hook mapping the fetched profile to user info.
    pub map_profile_to_user: Option<GenericOAuthMapProfileToUser>,
}
110
111impl std::fmt::Debug for GenericOAuthConfig {
112 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 formatter
114 .debug_struct("GenericOAuthConfig")
115 .field("provider_id", &self.provider_id)
116 .field("discovery_url", &self.discovery_url)
117 .field("issuer", &self.issuer)
118 .field("require_issuer_validation", &self.require_issuer_validation)
119 .field("authorization_url", &self.authorization_url)
120 .field("token_url", &self.token_url)
121 .field("user_info_url", &self.user_info_url)
122 .field("client_id", &self.client_id)
123 .field(
124 "client_secret",
125 &self.client_secret.as_ref().map(|_| "<redacted>"),
126 )
127 .field("scopes", &self.scopes)
128 .field("redirect_uri", &self.redirect_uri)
129 .field("response_type", &self.response_type)
130 .field("response_mode", &self.response_mode)
131 .field("prompt", &self.prompt)
132 .field("pkce", &self.pkce)
133 .field("access_type", &self.access_type)
134 .field("authorization_url_params", &self.authorization_url_params)
135 .field("token_url_params", &self.token_url_params)
136 .field(
137 "authorization_url_params_callback",
138 &self.authorization_url_params_callback.is_some(),
139 )
140 .field(
141 "token_url_params_callback",
142 &self.token_url_params_callback.is_some(),
143 )
144 .field("disable_implicit_sign_up", &self.disable_implicit_sign_up)
145 .field("disable_sign_up", &self.disable_sign_up)
146 .field("authentication", &self.authentication)
147 .field("discovery_headers", &self.discovery_headers)
148 .field("authorization_headers", &self.authorization_headers)
149 .field("override_user_info", &self.override_user_info)
150 .field("get_token", &self.get_token.is_some())
151 .field("get_user_info", &self.get_user_info.is_some())
152 .field("map_profile_to_user", &self.map_profile_to_user.is_some())
153 .finish()
154 }
155}
156
157impl GenericOAuthConfig {
158 pub fn new(
159 provider_id: impl Into<String>,
160 client_id: impl Into<String>,
161 client_secret: Option<impl Into<String>>,
162 authorization_url: impl Into<String>,
163 token_url: impl Into<String>,
164 ) -> Self {
165 Self {
166 provider_id: provider_id.into(),
167 client_id: client_id.into(),
168 client_secret: client_secret.map(Into::into),
169 authorization_url: Some(authorization_url.into()),
170 token_url: Some(token_url.into()),
171 ..Self::default()
172 }
173 }
174
175 pub fn discovery(
176 provider_id: impl Into<String>,
177 client_id: impl Into<String>,
178 client_secret: Option<impl Into<String>>,
179 discovery_url: impl Into<String>,
180 ) -> Self {
181 Self {
182 provider_id: provider_id.into(),
183 client_id: client_id.into(),
184 client_secret: client_secret.map(Into::into),
185 discovery_url: Some(discovery_url.into()),
186 ..Self::default()
187 }
188 }
189
190 pub(crate) fn provider_options(&self) -> ProviderOptions {
191 ProviderOptions {
192 client_id: Some(ClientId::Single(self.client_id.clone())),
193 client_secret: self.client_secret.clone(),
194 scope: self.scopes.clone(),
195 redirect_uri: self.redirect_uri.clone(),
196 authorization_endpoint: self.authorization_url.clone(),
197 disable_implicit_sign_up: self.disable_implicit_sign_up,
198 disable_sign_up: self.disable_sign_up,
199 prompt: self.prompt.clone(),
200 response_mode: self.response_mode.clone(),
201 override_user_info_on_sign_in: self.override_user_info,
202 ..ProviderOptions::default()
203 }
204 }
205
206 pub(crate) fn scopes(&self, request_scopes: Vec<String>) -> Vec<String> {
207 if request_scopes.is_empty() {
208 return self.scopes.clone();
209 }
210 let mut scopes = request_scopes;
211 scopes.extend(self.scopes.clone());
212 scopes
213 }
214
215 fn public_json(&self) -> Value {
216 json!({
217 "providerId": self.provider_id,
218 "discoveryUrl": self.discovery_url,
219 "issuer": self.issuer,
220 "requireIssuerValidation": self.require_issuer_validation,
221 "authorizationUrl": self.authorization_url,
222 "tokenUrl": self.token_url,
223 "userInfoUrl": self.user_info_url,
224 "clientId": self.client_id,
225 "scopes": self.scopes,
226 "redirectURI": self.redirect_uri,
227 "pkce": self.pkce,
228 "disableImplicitSignUp": self.disable_implicit_sign_up,
229 "disableSignUp": self.disable_sign_up,
230 "overrideUserInfo": self.override_user_info,
231 })
232 }
233}
234
235impl Default for GenericOAuthConfig {
236 fn default() -> Self {
237 Self {
238 provider_id: String::new(),
239 discovery_url: None,
240 issuer: None,
241 require_issuer_validation: false,
242 authorization_url: None,
243 token_url: None,
244 user_info_url: None,
245 client_id: String::new(),
246 client_secret: None,
247 scopes: Vec::new(),
248 redirect_uri: None,
249 response_type: None,
250 response_mode: None,
251 prompt: None,
252 pkce: false,
253 access_type: None,
254 authorization_url_params: BTreeMap::new(),
255 token_url_params: BTreeMap::new(),
256 authorization_url_params_callback: None,
257 token_url_params_callback: None,
258 disable_implicit_sign_up: false,
259 disable_sign_up: false,
260 authentication: ClientAuthentication::Post,
261 discovery_headers: BTreeMap::new(),
262 authorization_headers: BTreeMap::new(),
263 override_user_info: false,
264 get_token: None,
265 get_user_info: None,
266 map_profile_to_user: None,
267 }
268 }
269}