firebase_rs_sdk/auth/oauth/
provider.rs1use std::collections::HashMap;
2
3use url::Url;
4
5use super::OAuthRequest;
6use crate::auth::api::Auth;
7use crate::auth::error::{AuthError, AuthResult};
8use crate::auth::model::UserCredential;
9use crate::auth::oauth::RedirectOperation;
10
/// Configuration for a generic OAuth identity provider.
///
/// Holds the provider identifier, the authorization endpoint, and the optional request
/// customisations (scopes, custom query parameters, display name, language code) that are
/// turned into an [`OAuthRequest`] for the popup and redirect sign-in flows.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthProvider {
    /// Provider identifier, e.g. `"google.com"`.
    provider_id: String,
    /// URL of the provider's authorization endpoint; query parameters are appended to it.
    authorization_endpoint: String,
    /// OAuth scopes to request, sent space-separated as the `scope` query parameter.
    scopes: Vec<String>,
    /// Extra query parameters appended verbatim to the authorization URL.
    custom_parameters: HashMap<String, String>,
    /// Optional display name forwarded with the request.
    display_name: Option<String>,
    /// Optional language code, sent as the `hl` query parameter.
    language_code: Option<String>,
}
25
26impl OAuthProvider {
27 pub fn new(provider_id: impl Into<String>, authorization_endpoint: impl Into<String>) -> Self {
29 Self {
30 provider_id: provider_id.into(),
31 authorization_endpoint: authorization_endpoint.into(),
32 scopes: Vec::new(),
33 custom_parameters: HashMap::new(),
34 display_name: None,
35 language_code: None,
36 }
37 }
38
39 pub fn provider_id(&self) -> &str {
41 &self.provider_id
42 }
43
44 pub fn authorization_endpoint(&self) -> &str {
46 &self.authorization_endpoint
47 }
48
49 pub fn scopes(&self) -> &[String] {
51 &self.scopes
52 }
53
54 pub fn custom_parameters(&self) -> &HashMap<String, String> {
56 &self.custom_parameters
57 }
58
59 pub fn display_name(&self) -> Option<&str> {
61 self.display_name.as_deref()
62 }
63
64 pub fn language_code(&self) -> Option<&str> {
66 self.language_code.as_deref()
67 }
68
69 pub fn add_scope(&mut self, scope: impl Into<String>) {
71 let value = scope.into();
72 if !self.scopes.contains(&value) {
73 self.scopes.push(value);
74 }
75 }
76
77 pub fn set_scopes<I, S>(&mut self, scopes: I)
79 where
80 I: IntoIterator<Item = S>,
81 S: Into<String>,
82 {
83 self.scopes.clear();
84 self.scopes.extend(scopes.into_iter().map(Into::into));
85 }
86
87 pub fn set_custom_parameters(&mut self, parameters: HashMap<String, String>) -> &mut Self {
89 self.custom_parameters = parameters;
90 self
91 }
92
93 pub fn set_display_name(&mut self, value: impl Into<String>) -> &mut Self {
95 self.display_name = Some(value.into());
96 self
97 }
98
99 pub fn set_language_code(&mut self, value: impl Into<String>) -> &mut Self {
101 self.language_code = Some(value.into());
102 self
103 }
104
105 pub fn build_request(&self, auth: &Auth) -> AuthResult<OAuthRequest> {
107 let mut url = Url::parse(&self.authorization_endpoint).map_err(|err| {
108 AuthError::InvalidCredential(format!(
109 "Invalid authorization endpoint for provider {}: {err}",
110 self.provider_id
111 ))
112 })?;
113
114 {
115 let mut pairs = url.query_pairs_mut();
116 if !self.scopes.is_empty() {
117 pairs.append_pair("scope", &self.scopes.join(" "));
118 }
119 if let Some(lang) = &self.language_code {
120 pairs.append_pair("hl", lang);
121 }
122 if let Some(auth_domain) = auth.app().options().auth_domain {
123 pairs.append_pair("auth_domain", &auth_domain);
124 }
125 if let Some(api_key) = auth.app().options().api_key {
126 pairs.append_pair("apiKey", &api_key);
127 }
128 for (key, value) in &self.custom_parameters {
129 pairs.append_pair(key, value);
130 }
131 }
132
133 let auth_url: String = url.into();
134 let mut request = OAuthRequest::new(self.provider_id.clone(), auth_url);
135 if let Some(display) = &self.display_name {
136 request = request.with_display_name(display.clone());
137 }
138 if let Some(lang) = &self.language_code {
139 request = request.with_language_code(lang.clone());
140 }
141 request = request.with_custom_parameters(self.custom_parameters.clone());
142 Ok(request)
143 }
144
145 pub fn sign_in_with_popup(&self, auth: &Auth) -> AuthResult<UserCredential> {
148 let handler = auth.popup_handler().ok_or(AuthError::NotImplemented(
149 "OAuth popup handler not registered",
150 ))?;
151 let request = self.build_request(auth)?;
152 let credential = handler.open_popup(request)?;
153 auth.sign_in_with_oauth_credential(credential)
154 }
155
156 pub fn link_with_popup(&self, auth: &Auth) -> AuthResult<UserCredential> {
158 let handler = auth.popup_handler().ok_or(AuthError::NotImplemented(
159 "OAuth popup handler not registered",
160 ))?;
161 let request = self.build_request(auth)?;
162 let credential = handler.open_popup(request)?;
163 auth.link_with_oauth_credential(credential)
164 }
165
166 pub fn sign_in_with_redirect(&self, auth: &Auth) -> AuthResult<()> {
168 let handler = auth.redirect_handler().ok_or(AuthError::NotImplemented(
169 "OAuth redirect handler not registered",
170 ))?;
171 auth.set_pending_redirect_event(&self.provider_id, RedirectOperation::SignIn)?;
172 let request = self.build_request(auth)?;
173 if let Err(err) = handler.initiate_redirect(request) {
174 auth.clear_pending_redirect_event()?;
175 return Err(err);
176 }
177 Ok(())
178 }
179
180 pub fn link_with_redirect(&self, auth: &Auth) -> AuthResult<()> {
182 let handler = auth.redirect_handler().ok_or(AuthError::NotImplemented(
183 "OAuth redirect handler not registered",
184 ))?;
185 auth.set_pending_redirect_event(&self.provider_id, RedirectOperation::Link)?;
186 let request = self.build_request(auth)?;
187 if let Err(err) = handler.initiate_redirect(request) {
188 auth.clear_pending_redirect_event()?;
189 return Err(err);
190 }
191 Ok(())
192 }
193
194 pub fn get_redirect_result(auth: &Auth) -> AuthResult<Option<UserCredential>> {
200 let handler = auth.redirect_handler().ok_or(AuthError::NotImplemented(
201 "OAuth redirect handler not registered",
202 ))?;
203 let pending = auth.take_pending_redirect_event()?;
204 if pending.is_none() {
205 return Ok(None);
206 }
207 let pending = pending.unwrap();
208
209 match handler.complete_redirect()? {
210 Some(credential) => match pending.operation {
211 RedirectOperation::Link => auth.link_with_oauth_credential(credential).map(Some),
212 RedirectOperation::SignIn => {
213 auth.sign_in_with_oauth_credential(credential).map(Some)
214 }
215 },
216 None => Ok(None),
217 }
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::{FirebaseApp, FirebaseAppConfig, FirebaseOptions};
    use crate::component::ComponentContainer;

    use crate::auth::api::Auth;

    use std::sync::{Arc, Mutex};

    /// Builds an `Auth` backed by a throwaway app whose options carry a test API key and
    /// auth domain.
    fn build_test_auth() -> Arc<Auth> {
        let options = FirebaseOptions {
            api_key: Some("test-key".into()),
            auth_domain: Some("example.firebaseapp.com".into()),
            ..Default::default()
        };
        let config = FirebaseAppConfig::new("test-app", false);
        let container = ComponentContainer::new("test-app");
        let app = FirebaseApp::new(options, config, container);
        Auth::builder(app).build().unwrap()
    }

    #[test]
    fn build_request_includes_scopes_and_params() {
        let auth = build_test_auth();
        let mut provider = OAuthProvider::new("google.com", "https://example.com/oauth");
        provider.add_scope("profile");
        provider.set_language_code("en");
        provider.set_custom_parameters(
            [("prompt".to_string(), "select_account".to_string())]
                .into_iter()
                .collect(),
        );

        let request = provider.build_request(&auth).unwrap();
        assert!(request.auth_url.contains("scope=profile"));
        assert!(request.auth_url.contains("hl=en"));
        assert!(request.auth_url.contains("apiKey=test-key"));
        assert!(request
            .auth_url
            .contains("auth_domain=example.firebaseapp.com"));
        assert!(request.auth_url.contains("prompt=select_account"));
        assert_eq!(request.provider_id, "google.com");
    }

    /// Redirect handler that records whether `initiate_redirect` was called and can be
    /// configured to fail it.
    struct RecordingRedirectHandler {
        fail: bool,
        initiated: Arc<Mutex<bool>>,
    }

    impl crate::auth::OAuthRedirectHandler for RecordingRedirectHandler {
        fn initiate_redirect(&self, _request: OAuthRequest) -> AuthResult<()> {
            *self.initiated.lock().unwrap() = true;
            if self.fail {
                Err(AuthError::InvalidCredential("failure".into()))
            } else {
                Ok(())
            }
        }

        fn complete_redirect(&self) -> AuthResult<Option<crate::auth::oauth::AuthCredential>> {
            Ok(None)
        }
    }

    /// Registers a `RecordingRedirectHandler` on `auth` and returns its "initiated" flag.
    fn install_redirect_handler(auth: &Auth, fail: bool) -> Arc<Mutex<bool>> {
        let initiated = Arc::new(Mutex::new(false));
        auth.set_redirect_handler(Arc::new(RecordingRedirectHandler {
            fail,
            initiated: Arc::clone(&initiated),
        }));
        initiated
    }

    #[test]
    fn link_with_redirect_records_pending_event_on_success() {
        let auth = build_test_auth();
        let initiated = install_redirect_handler(&auth, false);

        let provider = OAuthProvider::new("google.com", "https://example.com");
        provider.link_with_redirect(&auth).unwrap();
        assert!(*initiated.lock().unwrap());

        // On success the event stays pending until the redirect result is collected.
        let event = auth
            .take_pending_redirect_event()
            .unwrap()
            .expect("a pending redirect event after a successful redirect");
        assert_eq!(event.provider_id, "google.com");
        assert_eq!(event.operation, RedirectOperation::Link);
    }

    #[test]
    fn link_with_redirect_clears_on_failure() {
        let auth = build_test_auth();
        let initiated = install_redirect_handler(&auth, true);

        let provider = OAuthProvider::new("google.com", "https://example.com");
        assert!(provider.link_with_redirect(&auth).is_err());
        // The handler was reached, and its failure rolled back the pending event.
        assert!(*initiated.lock().unwrap());
        assert!(auth.take_pending_redirect_event().unwrap().is_none());
    }
}