firebase_rs_sdk/auth/oauth/provider.rs

1use std::collections::HashMap;
2
3use url::Url;
4
5use super::OAuthRequest;
6use crate::auth::api::Auth;
7use crate::auth::error::{AuthError, AuthResult};
8use crate::auth::model::UserCredential;
9use crate::auth::oauth::RedirectOperation;
10
/// An OAuth identity provider description, configured builder-style.
///
/// Holds the provider's configuration — requested scopes, extra query
/// parameters, and optional display/language hints — and turns it into
/// [`OAuthRequest`] values that the popup or redirect handlers registered on
/// [`Auth`] can process.
#[derive(Clone, Debug)]
pub struct OAuthProvider {
    /// Identifier of the provider, e.g. `google.com`.
    provider_id: String,
    /// Authorization endpoint URL, stored as given and only parsed when a
    /// request is built.
    authorization_endpoint: String,
    /// Requested OAuth scopes, in insertion order.
    scopes: Vec<String>,
    /// Extra query parameters appended to the authorization URL.
    custom_parameters: HashMap<String, String>,
    /// Optional user-facing name of the provider.
    display_name: Option<String>,
    /// Optional language hint, sent as the `hl` query parameter.
    language_code: Option<String>,
}
25
26impl OAuthProvider {
27    /// Creates a new provider with the given ID and authorization endpoint.
28    pub fn new(provider_id: impl Into<String>, authorization_endpoint: impl Into<String>) -> Self {
29        Self {
30            provider_id: provider_id.into(),
31            authorization_endpoint: authorization_endpoint.into(),
32            scopes: Vec::new(),
33            custom_parameters: HashMap::new(),
34            display_name: None,
35            language_code: None,
36        }
37    }
38
39    /// Returns the provider identifier (e.g. `google.com`).
40    pub fn provider_id(&self) -> &str {
41        &self.provider_id
42    }
43
44    /// Returns the full authorization endpoint URL.
45    pub fn authorization_endpoint(&self) -> &str {
46        &self.authorization_endpoint
47    }
48
49    /// Returns the configured OAuth scopes.
50    pub fn scopes(&self) -> &[String] {
51        &self.scopes
52    }
53
54    /// Returns any custom query parameters used when initiating flows.
55    pub fn custom_parameters(&self) -> &HashMap<String, String> {
56        &self.custom_parameters
57    }
58
59    /// Returns an optional user-facing display name for the provider.
60    pub fn display_name(&self) -> Option<&str> {
61        self.display_name.as_deref()
62    }
63
64    /// Returns the preferred language hint for provider UX.
65    pub fn language_code(&self) -> Option<&str> {
66        self.language_code.as_deref()
67    }
68
69    /// Adds a scope to the provider if it has not been added yet.
70    pub fn add_scope(&mut self, scope: impl Into<String>) {
71        let value = scope.into();
72        if !self.scopes.contains(&value) {
73            self.scopes.push(value);
74        }
75    }
76
77    /// Replaces the provider scopes with the provided list.
78    pub fn set_scopes<I, S>(&mut self, scopes: I)
79    where
80        I: IntoIterator<Item = S>,
81        S: Into<String>,
82    {
83        self.scopes.clear();
84        self.scopes.extend(scopes.into_iter().map(Into::into));
85    }
86
87    /// Overwrites the custom parameters included in authorization requests.
88    pub fn set_custom_parameters(&mut self, parameters: HashMap<String, String>) -> &mut Self {
89        self.custom_parameters = parameters;
90        self
91    }
92
93    /// Sets the user-visible display name.
94    pub fn set_display_name(&mut self, value: impl Into<String>) -> &mut Self {
95        self.display_name = Some(value.into());
96        self
97    }
98
99    /// Sets the preferred language hint passed to the provider.
100    pub fn set_language_code(&mut self, value: impl Into<String>) -> &mut Self {
101        self.language_code = Some(value.into());
102        self
103    }
104
105    /// Builds the `OAuthRequest` that will be passed to popup/redirect handlers.
106    pub fn build_request(&self, auth: &Auth) -> AuthResult<OAuthRequest> {
107        let mut url = Url::parse(&self.authorization_endpoint).map_err(|err| {
108            AuthError::InvalidCredential(format!(
109                "Invalid authorization endpoint for provider {}: {err}",
110                self.provider_id
111            ))
112        })?;
113
114        {
115            let mut pairs = url.query_pairs_mut();
116            if !self.scopes.is_empty() {
117                pairs.append_pair("scope", &self.scopes.join(" "));
118            }
119            if let Some(lang) = &self.language_code {
120                pairs.append_pair("hl", lang);
121            }
122            if let Some(auth_domain) = auth.app().options().auth_domain {
123                pairs.append_pair("auth_domain", &auth_domain);
124            }
125            if let Some(api_key) = auth.app().options().api_key {
126                pairs.append_pair("apiKey", &api_key);
127            }
128            for (key, value) in &self.custom_parameters {
129                pairs.append_pair(key, value);
130            }
131        }
132
133        let auth_url: String = url.into();
134        let mut request = OAuthRequest::new(self.provider_id.clone(), auth_url);
135        if let Some(display) = &self.display_name {
136            request = request.with_display_name(display.clone());
137        }
138        if let Some(lang) = &self.language_code {
139            request = request.with_language_code(lang.clone());
140        }
141        request = request.with_custom_parameters(self.custom_parameters.clone());
142        Ok(request)
143    }
144
145    /// Runs the configured popup handler and returns the produced credential.
146    /// Executes the sign-in flow using a popup handler.
147    pub fn sign_in_with_popup(&self, auth: &Auth) -> AuthResult<UserCredential> {
148        let handler = auth.popup_handler().ok_or(AuthError::NotImplemented(
149            "OAuth popup handler not registered",
150        ))?;
151        let request = self.build_request(auth)?;
152        let credential = handler.open_popup(request)?;
153        auth.sign_in_with_oauth_credential(credential)
154    }
155
156    /// Links the current user with this provider using a popup flow.
157    pub fn link_with_popup(&self, auth: &Auth) -> AuthResult<UserCredential> {
158        let handler = auth.popup_handler().ok_or(AuthError::NotImplemented(
159            "OAuth popup handler not registered",
160        ))?;
161        let request = self.build_request(auth)?;
162        let credential = handler.open_popup(request)?;
163        auth.link_with_oauth_credential(credential)
164    }
165
166    /// Delegates to the redirect handler to start a redirect based flow.
167    pub fn sign_in_with_redirect(&self, auth: &Auth) -> AuthResult<()> {
168        let handler = auth.redirect_handler().ok_or(AuthError::NotImplemented(
169            "OAuth redirect handler not registered",
170        ))?;
171        auth.set_pending_redirect_event(&self.provider_id, RedirectOperation::SignIn)?;
172        let request = self.build_request(auth)?;
173        if let Err(err) = handler.initiate_redirect(request) {
174            auth.clear_pending_redirect_event()?;
175            return Err(err);
176        }
177        Ok(())
178    }
179
180    /// Initiates a redirect flow to link the current user with this provider.
181    pub fn link_with_redirect(&self, auth: &Auth) -> AuthResult<()> {
182        let handler = auth.redirect_handler().ok_or(AuthError::NotImplemented(
183            "OAuth redirect handler not registered",
184        ))?;
185        auth.set_pending_redirect_event(&self.provider_id, RedirectOperation::Link)?;
186        let request = self.build_request(auth)?;
187        if let Err(err) = handler.initiate_redirect(request) {
188            auth.clear_pending_redirect_event()?;
189            return Err(err);
190        }
191        Ok(())
192    }
193
194    /// Completes a redirect flow using the registered redirect handler.
195    ///
196    /// The provider does not influence result parsing at this stage; the
197    /// handler is responsible for decoding whichever callback mechanism the
198    /// hosting platform uses.
199    pub fn get_redirect_result(auth: &Auth) -> AuthResult<Option<UserCredential>> {
200        let handler = auth.redirect_handler().ok_or(AuthError::NotImplemented(
201            "OAuth redirect handler not registered",
202        ))?;
203        let pending = auth.take_pending_redirect_event()?;
204        if pending.is_none() {
205            return Ok(None);
206        }
207        let pending = pending.unwrap();
208
209        match handler.complete_redirect()? {
210            Some(credential) => match pending.operation {
211                RedirectOperation::Link => auth.link_with_oauth_credential(credential).map(Some),
212                RedirectOperation::SignIn => {
213                    auth.sign_in_with_oauth_credential(credential).map(Some)
214                }
215            },
216            None => Ok(None),
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::{FirebaseApp, FirebaseAppConfig, FirebaseOptions};
    use crate::component::ComponentContainer;

    use crate::auth::api::Auth;

    use std::sync::{Arc, Mutex};

    /// Builds an `Auth` backed by an app that has an API key and auth domain configured.
    fn build_test_auth() -> Arc<Auth> {
        let options = FirebaseOptions {
            api_key: Some("test-key".into()),
            auth_domain: Some("example.firebaseapp.com".into()),
            ..Default::default()
        };
        let config = FirebaseAppConfig::new("test-app", false);
        let container = ComponentContainer::new("test-app");
        let app = FirebaseApp::new(options, config, container);
        Auth::builder(app).build().unwrap()
    }

    #[test]
    fn build_request_includes_scopes_and_params() {
        let auth = build_test_auth();
        let mut provider = OAuthProvider::new("google.com", "https://example.com/oauth");
        provider.add_scope("profile");
        provider.set_language_code("en");
        provider.set_custom_parameters(
            [("prompt".to_string(), "select_account".to_string())]
                .into_iter()
                .collect(),
        );

        let request = provider.build_request(&auth).unwrap();
        assert!(request.auth_url.contains("scope=profile"));
        assert!(request.auth_url.contains("hl=en"));
        assert!(request.auth_url.contains("apiKey=test-key"));
        assert!(request
            .auth_url
            .contains("auth_domain=example.firebaseapp.com"));
        assert!(request.auth_url.contains("prompt=select_account"));
        assert_eq!(request.provider_id, "google.com");
    }

    #[test]
    fn add_scope_ignores_duplicates() {
        let mut provider = OAuthProvider::new("google.com", "https://example.com/oauth");
        provider.add_scope("profile");
        provider.add_scope("email");
        provider.add_scope("profile");
        assert_eq!(provider.scopes(), ["profile".to_string(), "email".to_string()]);
    }

    /// Redirect handler that records whether a redirect was initiated and can
    /// be configured to fail; it never yields a credential on completion.
    struct RecordingRedirectHandler {
        fail: bool,
        initiated: Arc<Mutex<bool>>,
    }

    impl crate::auth::OAuthRedirectHandler for RecordingRedirectHandler {
        fn initiate_redirect(&self, _request: OAuthRequest) -> AuthResult<()> {
            *self.initiated.lock().unwrap() = true;
            if self.fail {
                Err(AuthError::InvalidCredential("failure".into()))
            } else {
                Ok(())
            }
        }

        fn complete_redirect(&self) -> AuthResult<Option<crate::auth::oauth::AuthCredential>> {
            Ok(None)
        }
    }

    /// Registers a `RecordingRedirectHandler` on `auth` and returns the flag it
    /// sets once a redirect has been initiated.
    fn install_recording_handler(auth: &Auth, fail: bool) -> Arc<Mutex<bool>> {
        let initiated = Arc::new(Mutex::new(false));
        let handler = Arc::new(RecordingRedirectHandler {
            fail,
            initiated: Arc::clone(&initiated),
        });
        auth.set_redirect_handler(handler);
        initiated
    }

    #[test]
    fn link_with_redirect_records_pending_link_event_on_success() {
        let auth = build_test_auth();
        let initiated = install_recording_handler(&auth, false);

        let provider = OAuthProvider::new("google.com", "https://example.com");
        provider.link_with_redirect(&auth).unwrap();
        assert!(*initiated.lock().unwrap());

        let event = auth
            .take_pending_redirect_event()
            .unwrap()
            .expect("a successful redirect must leave a pending event");
        assert_eq!(event.provider_id, "google.com");
        assert_eq!(event.operation, RedirectOperation::Link);
    }

    #[test]
    fn link_with_redirect_clears_pending_event_on_failure() {
        let auth = build_test_auth();
        let initiated = install_recording_handler(&auth, true);

        let provider = OAuthProvider::new("google.com", "https://example.com");
        assert!(provider.link_with_redirect(&auth).is_err());
        assert!(*initiated.lock().unwrap());
        assert!(auth.take_pending_redirect_event().unwrap().is_none());
    }

    #[test]
    fn get_redirect_result_without_pending_event_returns_none() {
        let auth = build_test_auth();
        install_recording_handler(&auth, false);

        assert!(OAuthProvider::get_redirect_result(&auth).unwrap().is_none());
    }

    #[test]
    fn get_redirect_result_consumes_pending_event() {
        let auth = build_test_auth();
        install_recording_handler(&auth, false);

        let provider = OAuthProvider::new("google.com", "https://example.com");
        provider.sign_in_with_redirect(&auth).unwrap();

        // The recording handler never produces a credential, so the result is empty,
        // but the pending event must still have been consumed.
        assert!(OAuthProvider::get_redirect_result(&auth).unwrap().is_none());
        assert!(auth.take_pending_redirect_event().unwrap().is_none());
    }
}