firebase_rs_sdk/auth/api/
token.rs1use reqwest::blocking::Client;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4
5use crate::auth::error::{AuthError, AuthResult};
6
/// Token endpoint that `refresh_id_token` targets when the caller does not
/// supply one; the API key is attached to it as the `key` query parameter.
pub(crate) const DEFAULT_SECURE_TOKEN_ENDPOINT: &str =
    "https://securetoken.googleapis.com/v1/token";
9
10#[derive(Debug, Serialize)]
11struct RefreshTokenRequest<'a> {
12 grant_type: &'static str,
13 refresh_token: &'a str,
14}
15
16#[derive(Debug, Deserialize)]
17pub struct RefreshTokenResponse {
18 #[serde(rename = "access_token")]
19 pub access_token: String,
20 #[serde(rename = "refresh_token")]
21 pub refresh_token: String,
22 #[serde(rename = "id_token")]
23 pub id_token: String,
24 #[serde(rename = "expires_in")]
25 pub expires_in: String,
26 #[serde(rename = "user_id")]
27 pub user_id: String,
28}
29
30#[derive(Debug, Deserialize)]
31struct ErrorResponse {
32 error: Option<ErrorBody>,
33}
34
35#[derive(Debug, Deserialize)]
36struct ErrorBody {
37 message: Option<String>,
38}
39
40pub fn refresh_id_token(
42 client: &Client,
43 api_key: &str,
44 refresh_token: &str,
45) -> AuthResult<RefreshTokenResponse> {
46 refresh_id_token_with_endpoint(
47 client,
48 DEFAULT_SECURE_TOKEN_ENDPOINT,
49 api_key,
50 refresh_token,
51 )
52}
53
54pub(crate) fn refresh_id_token_with_endpoint(
55 client: &Client,
56 endpoint: &str,
57 api_key: &str,
58 refresh_token: &str,
59) -> AuthResult<RefreshTokenResponse> {
60 let url = format!("{endpoint}?key={api_key}");
61 let request = RefreshTokenRequest {
62 grant_type: "refresh_token",
63 refresh_token,
64 };
65
66 let response = client
67 .post(url)
68 .form(&request)
69 .send()
70 .map_err(|err| AuthError::Network(err.to_string()))?;
71
72 if response.status().is_success() {
73 response
74 .json()
75 .map_err(|err| AuthError::Network(err.to_string()))
76 } else {
77 let status = response.status();
78 let body = response.text().unwrap_or_else(|_| "{}".to_string());
79 Err(map_refresh_error(status, &body))
80 }
81}
82
83fn map_refresh_error(status: StatusCode, body: &str) -> AuthError {
84 if let Ok(parsed) = serde_json::from_str::<ErrorResponse>(body) {
85 if let Some(error) = parsed.error {
86 if let Some(message) = error.message {
87 return AuthError::InvalidCredential(message);
88 }
89 }
90 }
91
92 AuthError::Network(format!("Token refresh failed with status {status}"))
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::start_mock_server;
    use httpmock::prelude::*;
    use serde_json::json;

    /// Path under which the mock server exposes the token endpoint.
    const TOKEN_PATH: &str = "/token";
    /// API key both tests pass and expect to see in the `key` query parameter.
    const API_KEY: &str = "test-key";
    /// Refresh token both tests send in the form body.
    const REFRESH_TOKEN: &str = "test-refresh";

    /// Builds the blocking HTTP client used by the tests.
    fn make_client() -> Client {
        Client::new()
    }

    /// A 200 response decodes into `RefreshTokenResponse`. The request must carry the API
    /// key in the query string and a form-encoded refresh-token body.
    #[test]
    fn refresh_id_token_succeeds_with_custom_endpoint() {
        let server = start_mock_server();
        let http = make_client();

        let token_mock = server.mock(|when, then| {
            when.method(POST)
                .path(TOKEN_PATH)
                .query_param("key", API_KEY)
                .header("content-type", "application/x-www-form-urlencoded")
                .body_contains("grant_type=refresh_token")
                .body_contains("refresh_token=test-refresh");
            then.status(200).json_body(json!({
                "access_token": "access",
                "refresh_token": "new-refresh",
                "id_token": "id",
                "expires_in": "3600",
                "user_id": "uid"
            }));
        });

        let endpoint = server.url(TOKEN_PATH);
        let tokens = refresh_id_token_with_endpoint(&http, &endpoint, API_KEY, REFRESH_TOKEN)
            .expect("refresh should succeed");

        token_mock.assert();
        assert_eq!(tokens.access_token, "access");
        assert_eq!(tokens.refresh_token, "new-refresh");
        assert_eq!(tokens.id_token, "id");
    }

    /// A 400 whose JSON body carries `error.message` surfaces as `InvalidCredential`.
    #[test]
    fn refresh_id_token_maps_error_message() {
        let server = start_mock_server();
        let http = make_client();

        let error_mock = server.mock(|when, then| {
            when.method(POST).path(TOKEN_PATH).query_param("key", API_KEY);
            then.status(400)
                .body(r#"{"error":{"message":"TOKEN_EXPIRED"}}"#);
        });

        let endpoint = server.url(TOKEN_PATH);
        let outcome = refresh_id_token_with_endpoint(&http, &endpoint, API_KEY, REFRESH_TOKEN);

        error_mock.assert();
        match outcome {
            Err(AuthError::InvalidCredential(message)) => assert_eq!(message, "TOKEN_EXPIRED"),
            _ => panic!("expected InvalidCredential(\"TOKEN_EXPIRED\")"),
        }
    }
}