firebase_rs_sdk/data_connect/
transport.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
6use reqwest::Url;
7use serde::Deserialize;
8use serde_json::Value;
9
10use crate::data_connect::config::{DataConnectOptions, TransportOptions};
11use crate::data_connect::error::{
12 internal_error, operation_error, unauthorized, DataConnectErrorPathSegment,
13 DataConnectOperationFailureResponse, DataConnectOperationFailureResponseErrorInfo,
14 DataConnectResult,
15};
16
/// Identifies which flavour of SDK is issuing requests through the transport.
///
/// The selected variant is reported to the backend as a suffix of the `X-Goog-Api-Client`
/// request header built by `RestTransport`. `Base` is the default and adds no suffix.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum CallerSdkType {
    /// The core SDK used directly (the default); adds no header suffix.
    #[default]
    Base,
    /// Reported with the ` js/gen` suffix.
    Generated,
    /// Reported with the ` js/tanstack-react` suffix.
    TanstackReactCore,
    /// Reported with the ` js/gen-react` suffix.
    GeneratedReact,
    /// Reported with the ` js/tanstack-angular` suffix.
    TanstackAngularCore,
    /// Reported with the ` js/gen-angular` suffix.
    GeneratedAngular,
}
32
33#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
34#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
35pub trait RequestTokenProvider: Send + Sync {
36 async fn auth_token(&self) -> DataConnectResult<Option<String>>;
37 async fn app_check_headers(&self) -> DataConnectResult<Option<AppCheckHeaders>>;
38}
39
/// App Check material attached to outgoing Data Connect requests.
///
/// The REST transport sends `token` as `X-Firebase-AppCheck` and `heartbeat` as
/// `X-Firebase-Client`. It omits either header when the corresponding value is empty.
#[derive(Debug, Default, Clone)]
pub struct AppCheckHeaders {
    /// App Check token; an empty string means no `X-Firebase-AppCheck` header is sent.
    pub token: String,
    /// Optional heartbeat payload; `None` or an empty string means no `X-Firebase-Client`
    /// header is sent.
    pub heartbeat: Option<String>,
}
45
46#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
47#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
48pub trait DataConnectTransport: Send + Sync {
49 async fn invoke_query(&self, operation: &str, variables: &Value) -> DataConnectResult<Value>;
50 async fn invoke_mutation(&self, operation: &str, variables: &Value)
51 -> DataConnectResult<Value>;
52 fn use_emulator(&self, options: TransportOptions);
53 fn set_generated_sdk(&self, enabled: bool);
54 fn set_caller_sdk_type(&self, caller: CallerSdkType);
55}
56
57pub struct RestTransport {
58 client: reqwest::Client,
59 options: DataConnectOptions,
60 api_key: Option<String>,
61 app_id: Option<String>,
62 token_provider: Arc<dyn RequestTokenProvider>,
63 state: Mutex<TransportState>,
64 generated_sdk: AtomicBool,
65 caller_sdk_type: Mutex<CallerSdkType>,
66}
67
68struct TransportState {
69 transport: TransportOptions,
70 is_emulator: bool,
71}
72
73impl RestTransport {
74 pub fn new(
75 options: DataConnectOptions,
76 api_key: Option<String>,
77 app_id: Option<String>,
78 token_provider: Arc<dyn RequestTokenProvider>,
79 ) -> DataConnectResult<Self> {
80 Ok(Self {
81 client: reqwest::Client::new(),
82 options,
83 api_key,
84 app_id,
85 token_provider,
86 state: Mutex::new(TransportState {
87 transport: TransportOptions::default(),
88 is_emulator: false,
89 }),
90 generated_sdk: AtomicBool::new(false),
91 caller_sdk_type: Mutex::new(CallerSdkType::Base),
92 })
93 }
94
95 fn endpoint_url(&self, action: &str) -> DataConnectResult<Url> {
96 let state = self.state.lock().unwrap();
97 let base = state.transport.base_url();
98 let path = format!("{base}/v1/{}:{}", self.options.resource_path(), action);
99 let mut url = Url::parse(&path).map_err(|err| internal_error(err.to_string()))?;
100 if let Some(key) = &self.api_key {
101 url.query_pairs_mut().append_pair("key", key);
102 }
103 Ok(url)
104 }
105
106 fn goog_api_client_header(&self) -> String {
107 let sdk_version = env!("CARGO_PKG_VERSION");
108 let mut header = format!("gl-rs/ fire/{sdk_version}");
109 if self.generated_sdk.load(Ordering::SeqCst) {
110 header.push_str(" rs/gen");
111 }
112 match &*self.caller_sdk_type.lock().unwrap() {
113 CallerSdkType::Base => {}
114 CallerSdkType::Generated => header.push_str(" js/gen"),
115 CallerSdkType::TanstackReactCore => header.push_str(" js/tanstack-react"),
116 CallerSdkType::GeneratedReact => header.push_str(" js/gen-react"),
117 CallerSdkType::TanstackAngularCore => header.push_str(" js/tanstack-angular"),
118 CallerSdkType::GeneratedAngular => header.push_str(" js/gen-angular"),
119 }
120 header
121 }
122
123 async fn perform_request(
124 &self,
125 action: &str,
126 operation: &str,
127 variables: &Value,
128 ) -> DataConnectResult<Value> {
129 let mut body = serde_json::Map::new();
130 body.insert(
131 "name".to_string(),
132 Value::String(format!(
133 "projects/{}/locations/{}/services/{}/connectors/{}",
134 self.options.project_id,
135 self.options.connector.location,
136 self.options.connector.service,
137 self.options.connector.connector,
138 )),
139 );
140 body.insert(
141 "operationName".to_string(),
142 Value::String(operation.to_string()),
143 );
144 body.insert("variables".to_string(), variables.clone());
145
146 let mut headers = HeaderMap::new();
147 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
148 headers.insert(
149 "X-Goog-Api-Client",
150 HeaderValue::from_str(&self.goog_api_client_header())
151 .map_err(|err| internal_error(err.to_string()))?,
152 );
153
154 if let Some(app_id) = &self.app_id {
155 if !app_id.is_empty() {
156 headers.insert(
157 "X-Firebase-GMPID",
158 HeaderValue::from_str(app_id).map_err(|err| internal_error(err.to_string()))?,
159 );
160 }
161 }
162
163 if let Some(token) = self.token_provider.auth_token().await? {
164 if !token.is_empty() {
165 headers.insert(
166 "X-Firebase-Auth-Token",
167 HeaderValue::from_str(&token).map_err(|err| internal_error(err.to_string()))?,
168 );
169 }
170 }
171
172 if let Some(app_check) = self.token_provider.app_check_headers().await? {
173 if !app_check.token.is_empty() {
174 headers.insert(
175 "X-Firebase-AppCheck",
176 HeaderValue::from_str(&app_check.token)
177 .map_err(|err| internal_error(err.to_string()))?,
178 );
179 }
180 if let Some(heartbeat) = &app_check.heartbeat {
181 if !heartbeat.is_empty() {
182 headers.insert(
183 "X-Firebase-Client",
184 HeaderValue::from_str(heartbeat)
185 .map_err(|err| internal_error(err.to_string()))?,
186 );
187 }
188 }
189 }
190
191 let url = self.endpoint_url(action)?;
192 let response = self
193 .client
194 .post(url)
195 .headers(headers)
196 .json(&body)
197 .send()
198 .await
199 .map_err(|err| internal_error(err.to_string()))?;
200
201 if response.status().as_u16() == 401 {
202 return Err(unauthorized("Request unauthorized"));
203 }
204 if !response.status().is_success() {
205 return Err(internal_error(format!(
206 "Data Connect request failed with status {}",
207 response.status()
208 )));
209 }
210
211 let graph_response: GraphQlResponse = response
212 .json()
213 .await
214 .map_err(|err| internal_error(err.to_string()))?;
215 if !graph_response.errors.is_empty() {
216 let response = DataConnectOperationFailureResponse {
217 data: graph_response.data,
218 errors: graph_response
219 .errors
220 .into_iter()
221 .map(|error| DataConnectOperationFailureResponseErrorInfo {
222 message: error
223 .message
224 .unwrap_or_else(|| "Unknown Data Connect error".to_string()),
225 path: error
226 .path
227 .unwrap_or_default()
228 .into_iter()
229 .filter_map(|segment| match segment {
230 Value::String(field) => {
231 Some(DataConnectErrorPathSegment::Field(field))
232 }
233 Value::Number(num) => num
234 .as_i64()
235 .map(|idx| DataConnectErrorPathSegment::Index(idx)),
236 _ => None,
237 })
238 .collect(),
239 })
240 .collect(),
241 };
242 return Err(operation_error(
243 format!("Data Connect error executing {operation}"),
244 response,
245 ));
246 }
247
248 Ok(graph_response.data.unwrap_or(Value::Null))
249 }
250}
251
252#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
253#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
254impl DataConnectTransport for RestTransport {
255 async fn invoke_query(&self, operation: &str, variables: &Value) -> DataConnectResult<Value> {
256 self.perform_request("executeQuery", operation, variables)
257 .await
258 }
259
260 async fn invoke_mutation(
261 &self,
262 operation: &str,
263 variables: &Value,
264 ) -> DataConnectResult<Value> {
265 self.perform_request("executeMutation", operation, variables)
266 .await
267 }
268
269 fn use_emulator(&self, options: TransportOptions) {
270 let mut state = self.state.lock().unwrap();
271 state.transport = options;
272 state.is_emulator = true;
273 }
274
275 fn set_generated_sdk(&self, enabled: bool) {
276 self.generated_sdk.store(enabled, Ordering::SeqCst);
277 }
278
279 fn set_caller_sdk_type(&self, caller: CallerSdkType) {
280 *self.caller_sdk_type.lock().unwrap() = caller;
281 }
282}
283
284#[derive(Deserialize)]
285struct GraphQlResponse {
286 #[serde(default)]
287 data: Option<Value>,
288 #[serde(default)]
289 errors: Vec<GraphQlError>,
290}
291
292#[derive(Deserialize)]
293struct GraphQlError {
294 message: Option<String>,
295 path: Option<Vec<Value>>,
296}