Skip to main content

openauth_plugins/generic_oauth/
config.rs

1use openauth_oauth::oauth2::{
2    ClientAuthentication, ClientId, OAuth2Tokens, OAuth2UserInfo, OAuthError, ProviderOptions,
3};
4use serde_json::{json, Value};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10pub type GenericOAuthTokenFuture =
11    Pin<Box<dyn Future<Output = Result<OAuth2Tokens, OAuthError>> + Send>>;
12pub type GenericOAuthGetToken =
13    Arc<dyn Fn(GenericOAuthTokenRequest) -> GenericOAuthTokenFuture + Send + Sync>;
14pub type GenericOAuthUserInfoFuture =
15    Pin<Box<dyn Future<Output = Result<Option<OAuth2UserInfo>, OAuthError>> + Send>>;
16pub type GenericOAuthGetUserInfo =
17    Arc<dyn Fn(OAuth2Tokens) -> GenericOAuthUserInfoFuture + Send + Sync>;
18pub type GenericOAuthMapProfileFuture =
19    Pin<Box<dyn Future<Output = Result<OAuth2UserInfo, OAuthError>> + Send>>;
20pub type GenericOAuthMapProfileToUser =
21    Arc<dyn Fn(OAuth2UserInfo) -> GenericOAuthMapProfileFuture + Send + Sync>;
22pub type GenericOAuthParams = BTreeMap<String, String>;
23pub type GenericOAuthParamsFuture =
24    Pin<Box<dyn Future<Output = Result<GenericOAuthParams, OAuthError>> + Send>>;
25pub type GenericOAuthParamsCallback =
26    Arc<dyn Fn(GenericOAuthParamsContext) -> GenericOAuthParamsFuture + Send + Sync>;
27
/// The flow a [`GenericOAuthParamsContext`] was created for, so parameter
/// callbacks can vary their output per flow.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GenericOAuthFlow {
    /// Sign-in flow.
    SignIn,
    /// Account-linking flow.
    Link,
    /// Provider callback handling.
    Callback,
}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub struct GenericOAuthParamsContext {
37    pub provider_id: String,
38    pub flow: GenericOAuthFlow,
39    pub redirect_uri: String,
40}
41
/// Input handed to a custom [`GenericOAuthGetToken`] hook.
///
/// `Debug` is implemented by hand so the authorization code and the PKCE code
/// verifier never reach logs. Both are confidential, short-lived credentials
/// (RFC 6749 §10.5, RFC 7636 §4.1). A non-empty `code` is shown as
/// `"<redacted>"`, and a present `code_verifier` as `Some("<redacted>")`.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct GenericOAuthTokenRequest {
    /// Authorization code to exchange.
    pub code: String,
    /// Redirect URI associated with the authorization request.
    pub redirect_uri: String,
    /// PKCE code verifier, when PKCE is in use.
    pub code_verifier: Option<String>,
    /// Optional device identifier forwarded to the hook (meaning not established in this file).
    pub device_id: Option<String>,
}

impl std::fmt::Debug for GenericOAuthTokenRequest {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep an empty code visible as "" so a default/blank request is still recognizable.
        let code: &str = if self.code.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        formatter
            .debug_struct("GenericOAuthTokenRequest")
            .field("code", &code)
            .field("redirect_uri", &self.redirect_uri)
            .field(
                "code_verifier",
                &self.code_verifier.as_ref().map(|_| "<redacted>"),
            )
            .field("device_id", &self.device_id)
            .finish()
    }
}
49
50#[derive(Clone, Default)]
51pub struct GenericOAuthOptions {
52    pub config: Vec<GenericOAuthConfig>,
53}
54
55impl std::fmt::Debug for GenericOAuthOptions {
56    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        formatter
58            .debug_struct("GenericOAuthOptions")
59            .field("config", &self.config)
60            .finish()
61    }
62}
63
64impl GenericOAuthOptions {
65    pub(crate) fn to_json(&self) -> Value {
66        json!({
67            "config": self.config.iter().map(GenericOAuthConfig::public_json).collect::<Vec<_>>(),
68        })
69    }
70
71    pub(crate) fn find(&self, provider_id: &str) -> Option<&GenericOAuthConfig> {
72        self.config
73            .iter()
74            .find(|config| config.provider_id == provider_id)
75    }
76}
77
78#[derive(Clone)]
79pub struct GenericOAuthConfig {
80    pub provider_id: String,
81    pub discovery_url: Option<String>,
82    pub issuer: Option<String>,
83    pub require_issuer_validation: bool,
84    pub authorization_url: Option<String>,
85    pub token_url: Option<String>,
86    pub user_info_url: Option<String>,
87    pub client_id: String,
88    pub client_secret: Option<String>,
89    pub scopes: Vec<String>,
90    pub redirect_uri: Option<String>,
91    pub response_type: Option<String>,
92    pub response_mode: Option<String>,
93    pub prompt: Option<String>,
94    pub pkce: bool,
95    pub access_type: Option<String>,
96    pub authorization_url_params: BTreeMap<String, String>,
97    pub token_url_params: BTreeMap<String, String>,
98    pub authorization_url_params_callback: Option<GenericOAuthParamsCallback>,
99    pub token_url_params_callback: Option<GenericOAuthParamsCallback>,
100    pub disable_implicit_sign_up: bool,
101    pub disable_sign_up: bool,
102    pub authentication: ClientAuthentication,
103    pub discovery_headers: BTreeMap<String, String>,
104    pub authorization_headers: BTreeMap<String, String>,
105    pub override_user_info: bool,
106    pub get_token: Option<GenericOAuthGetToken>,
107    pub get_user_info: Option<GenericOAuthGetUserInfo>,
108    pub map_profile_to_user: Option<GenericOAuthMapProfileToUser>,
109}
110
111impl std::fmt::Debug for GenericOAuthConfig {
112    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        formatter
114            .debug_struct("GenericOAuthConfig")
115            .field("provider_id", &self.provider_id)
116            .field("discovery_url", &self.discovery_url)
117            .field("issuer", &self.issuer)
118            .field("require_issuer_validation", &self.require_issuer_validation)
119            .field("authorization_url", &self.authorization_url)
120            .field("token_url", &self.token_url)
121            .field("user_info_url", &self.user_info_url)
122            .field("client_id", &self.client_id)
123            .field(
124                "client_secret",
125                &self.client_secret.as_ref().map(|_| "<redacted>"),
126            )
127            .field("scopes", &self.scopes)
128            .field("redirect_uri", &self.redirect_uri)
129            .field("response_type", &self.response_type)
130            .field("response_mode", &self.response_mode)
131            .field("prompt", &self.prompt)
132            .field("pkce", &self.pkce)
133            .field("access_type", &self.access_type)
134            .field("authorization_url_params", &self.authorization_url_params)
135            .field("token_url_params", &self.token_url_params)
136            .field(
137                "authorization_url_params_callback",
138                &self.authorization_url_params_callback.is_some(),
139            )
140            .field(
141                "token_url_params_callback",
142                &self.token_url_params_callback.is_some(),
143            )
144            .field("disable_implicit_sign_up", &self.disable_implicit_sign_up)
145            .field("disable_sign_up", &self.disable_sign_up)
146            .field("authentication", &self.authentication)
147            .field("discovery_headers", &self.discovery_headers)
148            .field("authorization_headers", &self.authorization_headers)
149            .field("override_user_info", &self.override_user_info)
150            .field("get_token", &self.get_token.is_some())
151            .field("get_user_info", &self.get_user_info.is_some())
152            .field("map_profile_to_user", &self.map_profile_to_user.is_some())
153            .finish()
154    }
155}
156
157impl GenericOAuthConfig {
158    pub fn new(
159        provider_id: impl Into<String>,
160        client_id: impl Into<String>,
161        client_secret: Option<impl Into<String>>,
162        authorization_url: impl Into<String>,
163        token_url: impl Into<String>,
164    ) -> Self {
165        Self {
166            provider_id: provider_id.into(),
167            client_id: client_id.into(),
168            client_secret: client_secret.map(Into::into),
169            authorization_url: Some(authorization_url.into()),
170            token_url: Some(token_url.into()),
171            ..Self::default()
172        }
173    }
174
175    pub fn discovery(
176        provider_id: impl Into<String>,
177        client_id: impl Into<String>,
178        client_secret: Option<impl Into<String>>,
179        discovery_url: impl Into<String>,
180    ) -> Self {
181        Self {
182            provider_id: provider_id.into(),
183            client_id: client_id.into(),
184            client_secret: client_secret.map(Into::into),
185            discovery_url: Some(discovery_url.into()),
186            ..Self::default()
187        }
188    }
189
190    pub(crate) fn provider_options(&self) -> ProviderOptions {
191        ProviderOptions {
192            client_id: Some(ClientId::Single(self.client_id.clone())),
193            client_secret: self.client_secret.clone(),
194            scope: self.scopes.clone(),
195            redirect_uri: self.redirect_uri.clone(),
196            authorization_endpoint: self.authorization_url.clone(),
197            disable_implicit_sign_up: self.disable_implicit_sign_up,
198            disable_sign_up: self.disable_sign_up,
199            prompt: self.prompt.clone(),
200            response_mode: self.response_mode.clone(),
201            override_user_info_on_sign_in: self.override_user_info,
202            ..ProviderOptions::default()
203        }
204    }
205
206    pub(crate) fn scopes(&self, request_scopes: Vec<String>) -> Vec<String> {
207        if request_scopes.is_empty() {
208            return self.scopes.clone();
209        }
210        let mut scopes = request_scopes;
211        scopes.extend(self.scopes.clone());
212        scopes
213    }
214
215    fn public_json(&self) -> Value {
216        json!({
217            "providerId": self.provider_id,
218            "discoveryUrl": self.discovery_url,
219            "issuer": self.issuer,
220            "requireIssuerValidation": self.require_issuer_validation,
221            "authorizationUrl": self.authorization_url,
222            "tokenUrl": self.token_url,
223            "userInfoUrl": self.user_info_url,
224            "clientId": self.client_id,
225            "scopes": self.scopes,
226            "redirectURI": self.redirect_uri,
227            "pkce": self.pkce,
228            "disableImplicitSignUp": self.disable_implicit_sign_up,
229            "disableSignUp": self.disable_sign_up,
230            "overrideUserInfo": self.override_user_info,
231        })
232    }
233}
234
235impl Default for GenericOAuthConfig {
236    fn default() -> Self {
237        Self {
238            provider_id: String::new(),
239            discovery_url: None,
240            issuer: None,
241            require_issuer_validation: false,
242            authorization_url: None,
243            token_url: None,
244            user_info_url: None,
245            client_id: String::new(),
246            client_secret: None,
247            scopes: Vec::new(),
248            redirect_uri: None,
249            response_type: None,
250            response_mode: None,
251            prompt: None,
252            pkce: false,
253            access_type: None,
254            authorization_url_params: BTreeMap::new(),
255            token_url_params: BTreeMap::new(),
256            authorization_url_params_callback: None,
257            token_url_params_callback: None,
258            disable_implicit_sign_up: false,
259            disable_sign_up: false,
260            authentication: ClientAuthentication::Post,
261            discovery_headers: BTreeMap::new(),
262            authorization_headers: BTreeMap::new(),
263            override_user_info: false,
264            get_token: None,
265            get_user_info: None,
266            map_profile_to_user: None,
267        }
268    }
269}