firebase_client/
client.rs

1pub mod firebase_client {
2    use google_authz::{Client, Credentials};
3    use hyper::{body, client::HttpConnector, Body, Request, Response, Uri};
4    use hyper_rustls::HttpsConnector;
5    use std::convert::TryFrom;
6
7    use crate::{error::FirebaseClientError, notification::FirebasePayload};
8
9    #[derive(Debug)]
10
11    pub struct FirebaseClient {
12        client: Client<HttpsConnector<HttpConnector>>,
13        uri: Uri,
14    }
15
16    impl FirebaseClient {
17        pub fn new<T: AsRef<str>>(
18            client: hyper::Client<HttpsConnector<HttpConnector>>,
19            credentials: Credentials,
20            project_id: T,
21        ) -> Result<FirebaseClient, FirebaseClientError> {
22            let authz_client = Client::new_with(client, credentials);
23
24            let uri = Uri::try_from(format!(
25                "https://fcm.googleapis.com/v1/projects/{}/messages:send",
26                project_id.as_ref()
27            ))?;
28
29            Ok(FirebaseClient {
30                client: authz_client,
31                uri,
32            })
33        }
34
35        pub fn new_default<T: AsRef<str>>(
36            client: hyper::Client<HttpsConnector<HttpConnector>>,
37            credentials_file_path: T,
38            project_id: T,
39        ) -> Result<FirebaseClient, FirebaseClientError> {
40            let authz_client = {
41                let credentials = Credentials::from_file(
42                    credentials_file_path.as_ref(),
43                    &["https://www.googleapis.com/auth/firebase.messaging"],
44                );
45                Client::new_with(client, credentials)
46            };
47
48            let uri = Uri::try_from(format!(
49                "https://fcm.googleapis.com/v1/projects/{}/messages:send",
50                project_id.as_ref()
51            ))?;
52
53            Ok(FirebaseClient {
54                client: authz_client,
55                uri,
56            })
57        }
58
59        pub async fn send_notification_raw(
60            &mut self,
61            notification_as_str: String,
62        ) -> Result<(), FirebaseClientError> {
63            let response = {
64                let http_request = Request::builder()
65                    .method("POST")
66                    .uri(self.uri.clone())
67                    .body(notification_as_str.into())?;
68                self.client.request(http_request).await?
69            };
70
71            if response.status() == 200 || response.status() == 204 {
72                Ok(())
73            } else {
74                let status_code = response.status();
75                let body_as_str = read_response_body(response)
76                    .await
77                    .map_err(FirebaseClientError::ReadBodyError)?;
78
79                Err(FirebaseClientError::HttpRequestError {
80                    status_code,
81                    body: body_as_str,
82                })
83            }
84        }
85        pub async fn send_notification(
86            &mut self,
87            firebase_payload: FirebasePayload,
88        ) -> Result<(), FirebaseClientError> {
89            let serialized_payload: String = serde_json::to_string(&firebase_payload)?;
90
91            self.send_notification_raw(serialized_payload).await
92        }
93    }
94    pub async fn read_response_body(res: Response<Body>) -> Result<String, hyper::Error> {
95        let bytes = body::to_bytes(res.into_body()).await?;
96        Ok(String::from_utf8(bytes.to_vec()).expect("response was not valid utf-8"))
97    }
98}
99
#[cfg(test)]
pub mod test {
    use dotenv::dotenv;
    use hyper::Body;
    use hyper_rustls::HttpsConnector;
    use serde_json::json;

    use crate::notification::NotificationBuilder;

    use super::firebase_client::FirebaseClient;

    /// Reads a required environment variable, naming it in the panic message if absent.
    fn required_env(name: &str) -> String {
        std::env::var(name).unwrap_or_else(|_| panic!("environment variable {} must be set", name))
    }

    /// Loads `.env`, builds a live `FirebaseClient`, and returns it with the target token.
    ///
    /// These tests talk to the real FCM service; they need `CREDENTIALS_FILE_PATH`,
    /// `PROJECT_ID` and `TEST_TOKEN` to be set.
    fn setup() -> (FirebaseClient, String) {
        dotenv().ok();

        let credentials_file_path = required_env("CREDENTIALS_FILE_PATH");
        let project_id = required_env("PROJECT_ID");
        let test_token = required_env("TEST_TOKEN");

        let https = HttpsConnector::with_native_roots();
        let client = hyper::Client::builder().build::<_, Body>(https);
        let firebase_client =
            FirebaseClient::new_default(client, &credentials_file_path, &project_id)
                .expect("failed to build FirebaseClient");

        (firebase_client, test_token)
    }

    #[tokio::test]
    pub async fn test_send_notification_serialized() {
        let (mut firebase_client, test_token) = setup();

        let payload = json!({
            "message": {
                "token": test_token,
                "notification": {
                    "title": "TEST_TITLE",
                    "body": "TEST_MESSAGE"
                }
            }
        });

        // Check the result so an FCM rejection actually fails the test.
        firebase_client
            .send_notification_raw(payload.to_string())
            .await
            .expect("FCM rejected the raw notification");
    }

    #[tokio::test]
    pub async fn test_send_notification() {
        let (mut firebase_client, test_token) = setup();

        let firebase_notification = NotificationBuilder::new("TEST_TITLE", &test_token)
            .message("TEST_MESSAGE")
            .data(json!({
                "url": "https://firebase.google.com/docs/cloud-messaging/migrate-v1"
            }))
            .android_channel_id("channel_urgent")
            .build();

        firebase_client
            .send_notification(firebase_notification)
            .await
            .expect("FCM rejected the notification");
    }
}