firebase_rs_sdk/data_connect/
transport.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
6use reqwest::Url;
7use serde::Deserialize;
8use serde_json::Value;
9
10use crate::data_connect::config::{DataConnectOptions, TransportOptions};
11use crate::data_connect::error::{
12    internal_error, operation_error, unauthorized, DataConnectErrorPathSegment,
13    DataConnectOperationFailureResponse, DataConnectOperationFailureResponseErrorInfo,
14    DataConnectResult,
15};
16
/// Identifies which SDK layer is issuing Data Connect requests.
///
/// `RestTransport::goog_api_client_header` appends a suffix to the
/// `X-Goog-Api-Client` header for every variant except [`CallerSdkType::Base`].
/// The default value is [`CallerSdkType::Base`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum CallerSdkType {
    /// The core SDK called directly; adds no suffix to the client header.
    #[default]
    Base,
    /// A generated SDK.
    Generated,
    /// The TanStack React core integration.
    TanstackReactCore,
    /// A generated React SDK.
    GeneratedReact,
    /// The TanStack Angular core integration.
    TanstackAngularCore,
    /// A generated Angular SDK.
    GeneratedAngular,
}
32
33#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
34#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
35pub trait RequestTokenProvider: Send + Sync {
36    async fn auth_token(&self) -> DataConnectResult<Option<String>>;
37    async fn app_check_headers(&self) -> DataConnectResult<Option<AppCheckHeaders>>;
38}
39
/// App Check values returned by a [`RequestTokenProvider`].
///
/// The default value has an empty `token` and no `heartbeat`.
#[derive(Debug, Clone, Default)]
pub struct AppCheckHeaders {
    /// Token sent as the `X-Firebase-AppCheck` header by `RestTransport`
    /// (skipped when empty).
    pub token: String,
    /// Optional value sent as the `X-Firebase-Client` header by `RestTransport`
    /// (skipped when absent or empty).
    pub heartbeat: Option<String>,
}
45
46#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
47#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
48pub trait DataConnectTransport: Send + Sync {
49    async fn invoke_query(&self, operation: &str, variables: &Value) -> DataConnectResult<Value>;
50    async fn invoke_mutation(&self, operation: &str, variables: &Value)
51        -> DataConnectResult<Value>;
52    fn use_emulator(&self, options: TransportOptions);
53    fn set_generated_sdk(&self, enabled: bool);
54    fn set_caller_sdk_type(&self, caller: CallerSdkType);
55}
56
57pub struct RestTransport {
58    client: reqwest::Client,
59    options: DataConnectOptions,
60    api_key: Option<String>,
61    app_id: Option<String>,
62    token_provider: Arc<dyn RequestTokenProvider>,
63    state: Mutex<TransportState>,
64    generated_sdk: AtomicBool,
65    caller_sdk_type: Mutex<CallerSdkType>,
66}
67
68struct TransportState {
69    transport: TransportOptions,
70    is_emulator: bool,
71}
72
73impl RestTransport {
74    pub fn new(
75        options: DataConnectOptions,
76        api_key: Option<String>,
77        app_id: Option<String>,
78        token_provider: Arc<dyn RequestTokenProvider>,
79    ) -> DataConnectResult<Self> {
80        Ok(Self {
81            client: reqwest::Client::new(),
82            options,
83            api_key,
84            app_id,
85            token_provider,
86            state: Mutex::new(TransportState {
87                transport: TransportOptions::default(),
88                is_emulator: false,
89            }),
90            generated_sdk: AtomicBool::new(false),
91            caller_sdk_type: Mutex::new(CallerSdkType::Base),
92        })
93    }
94
95    fn endpoint_url(&self, action: &str) -> DataConnectResult<Url> {
96        let state = self.state.lock().unwrap();
97        let base = state.transport.base_url();
98        let path = format!("{base}/v1/{}:{}", self.options.resource_path(), action);
99        let mut url = Url::parse(&path).map_err(|err| internal_error(err.to_string()))?;
100        if let Some(key) = &self.api_key {
101            url.query_pairs_mut().append_pair("key", key);
102        }
103        Ok(url)
104    }
105
106    fn goog_api_client_header(&self) -> String {
107        let sdk_version = env!("CARGO_PKG_VERSION");
108        let mut header = format!("gl-rs/ fire/{sdk_version}");
109        if self.generated_sdk.load(Ordering::SeqCst) {
110            header.push_str(" rs/gen");
111        }
112        match &*self.caller_sdk_type.lock().unwrap() {
113            CallerSdkType::Base => {}
114            CallerSdkType::Generated => header.push_str(" js/gen"),
115            CallerSdkType::TanstackReactCore => header.push_str(" js/tanstack-react"),
116            CallerSdkType::GeneratedReact => header.push_str(" js/gen-react"),
117            CallerSdkType::TanstackAngularCore => header.push_str(" js/tanstack-angular"),
118            CallerSdkType::GeneratedAngular => header.push_str(" js/gen-angular"),
119        }
120        header
121    }
122
123    async fn perform_request(
124        &self,
125        action: &str,
126        operation: &str,
127        variables: &Value,
128    ) -> DataConnectResult<Value> {
129        let mut body = serde_json::Map::new();
130        body.insert(
131            "name".to_string(),
132            Value::String(format!(
133                "projects/{}/locations/{}/services/{}/connectors/{}",
134                self.options.project_id,
135                self.options.connector.location,
136                self.options.connector.service,
137                self.options.connector.connector,
138            )),
139        );
140        body.insert(
141            "operationName".to_string(),
142            Value::String(operation.to_string()),
143        );
144        body.insert("variables".to_string(), variables.clone());
145
146        let mut headers = HeaderMap::new();
147        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
148        headers.insert(
149            "X-Goog-Api-Client",
150            HeaderValue::from_str(&self.goog_api_client_header())
151                .map_err(|err| internal_error(err.to_string()))?,
152        );
153
154        if let Some(app_id) = &self.app_id {
155            if !app_id.is_empty() {
156                headers.insert(
157                    "X-Firebase-GMPID",
158                    HeaderValue::from_str(app_id).map_err(|err| internal_error(err.to_string()))?,
159                );
160            }
161        }
162
163        if let Some(token) = self.token_provider.auth_token().await? {
164            if !token.is_empty() {
165                headers.insert(
166                    "X-Firebase-Auth-Token",
167                    HeaderValue::from_str(&token).map_err(|err| internal_error(err.to_string()))?,
168                );
169            }
170        }
171
172        if let Some(app_check) = self.token_provider.app_check_headers().await? {
173            if !app_check.token.is_empty() {
174                headers.insert(
175                    "X-Firebase-AppCheck",
176                    HeaderValue::from_str(&app_check.token)
177                        .map_err(|err| internal_error(err.to_string()))?,
178                );
179            }
180            if let Some(heartbeat) = &app_check.heartbeat {
181                if !heartbeat.is_empty() {
182                    headers.insert(
183                        "X-Firebase-Client",
184                        HeaderValue::from_str(heartbeat)
185                            .map_err(|err| internal_error(err.to_string()))?,
186                    );
187                }
188            }
189        }
190
191        let url = self.endpoint_url(action)?;
192        let response = self
193            .client
194            .post(url)
195            .headers(headers)
196            .json(&body)
197            .send()
198            .await
199            .map_err(|err| internal_error(err.to_string()))?;
200
201        if response.status().as_u16() == 401 {
202            return Err(unauthorized("Request unauthorized"));
203        }
204        if !response.status().is_success() {
205            return Err(internal_error(format!(
206                "Data Connect request failed with status {}",
207                response.status()
208            )));
209        }
210
211        let graph_response: GraphQlResponse = response
212            .json()
213            .await
214            .map_err(|err| internal_error(err.to_string()))?;
215        if !graph_response.errors.is_empty() {
216            let response = DataConnectOperationFailureResponse {
217                data: graph_response.data,
218                errors: graph_response
219                    .errors
220                    .into_iter()
221                    .map(|error| DataConnectOperationFailureResponseErrorInfo {
222                        message: error
223                            .message
224                            .unwrap_or_else(|| "Unknown Data Connect error".to_string()),
225                        path: error
226                            .path
227                            .unwrap_or_default()
228                            .into_iter()
229                            .filter_map(|segment| match segment {
230                                Value::String(field) => {
231                                    Some(DataConnectErrorPathSegment::Field(field))
232                                }
233                                Value::Number(num) => num
234                                    .as_i64()
235                                    .map(|idx| DataConnectErrorPathSegment::Index(idx)),
236                                _ => None,
237                            })
238                            .collect(),
239                    })
240                    .collect(),
241            };
242            return Err(operation_error(
243                format!("Data Connect error executing {operation}"),
244                response,
245            ));
246        }
247
248        Ok(graph_response.data.unwrap_or(Value::Null))
249    }
250}
251
252#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
253#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
254impl DataConnectTransport for RestTransport {
255    async fn invoke_query(&self, operation: &str, variables: &Value) -> DataConnectResult<Value> {
256        self.perform_request("executeQuery", operation, variables)
257            .await
258    }
259
260    async fn invoke_mutation(
261        &self,
262        operation: &str,
263        variables: &Value,
264    ) -> DataConnectResult<Value> {
265        self.perform_request("executeMutation", operation, variables)
266            .await
267    }
268
269    fn use_emulator(&self, options: TransportOptions) {
270        let mut state = self.state.lock().unwrap();
271        state.transport = options;
272        state.is_emulator = true;
273    }
274
275    fn set_generated_sdk(&self, enabled: bool) {
276        self.generated_sdk.store(enabled, Ordering::SeqCst);
277    }
278
279    fn set_caller_sdk_type(&self, caller: CallerSdkType) {
280        *self.caller_sdk_type.lock().unwrap() = caller;
281    }
282}
283
284#[derive(Deserialize)]
285struct GraphQlResponse {
286    #[serde(default)]
287    data: Option<Value>,
288    #[serde(default)]
289    errors: Vec<GraphQlError>,
290}
291
292#[derive(Deserialize)]
293struct GraphQlError {
294    message: Option<String>,
295    path: Option<Vec<Value>>,
296}