hi_push/fcm/
mod.rs

1mod model;
2
3
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bytecodec::{
8    bytes::{BytesEncoder, RemainingBytesDecoder},
9    Encode,
10    io::{IoDecodeExt, IoEncodeExt},
11};
12use http::Uri;
13use httpcodec::{BodyDecoder, BodyEncoder, HeaderField, HttpVersion, Method, RequestEncoder, RequestTarget, ResponseDecoder};
14use yup_oauth2 as oauth2;
15use hyper_tls::HttpsConnector;
16use hyper_socks2::SocksConnector;
17use hi_hyper_multipart as multipart;
18use hyper::client::connect::dns::GaiResolver;
19use hyper::client::HttpConnector;
20use reqwest::Proxy;
21use serde::Deserialize;
22use yup_oauth2::AccessToken;
23
24pub use model::*;
25
26pub type Config = yup_oauth2::ServiceAccountKey;
27
28type ProxyAuthClient = oauth2::authenticator::Authenticator<HttpsConnector<SocksConnector<HttpConnector<GaiResolver>>>>;
29type CommonAuthClient = oauth2::authenticator::Authenticator<HttpsConnector<HttpConnector<GaiResolver>>>;
30
31enum InnerAuthClient {
32    Common(CommonAuthClient),
33    Proxy(ProxyAuthClient),
34}
35
36pub struct Client {
37    http: reqwest::Client,
38    auth: InnerAuthClient,
39    project_id: String,
40    conf: Config,
41}
42
/// Proxy settings accepted by `Client::with_proxy`.
///
/// `Debug` is implemented by hand so the password never appears in logs.
#[derive(Clone, Copy)]
pub struct ProxyConfig<'a> {
    /// Proxy URI including its scheme, e.g. `socks5://127.0.0.1:1080`.
    pub addr: &'a str,
    /// Optional user name for proxy authentication.
    pub user: Option<&'a str>,
    /// Optional password; only consulted when `user` is set, and treated as empty when absent.
    pub pass: Option<&'a str>,
}

impl std::fmt::Debug for ProxyConfig<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProxyConfig")
            .field("addr", &self.addr)
            .field("user", &self.user)
            // Show only whether a password is present, never its value.
            .field("pass", &self.pass.map(|_| "<redacted>"))
            .finish()
    }
}
48
49impl Client {
50    pub async fn new<'a>(
51        conf: Config,
52    ) -> Result<Self, super::Error> {
53        let project_id = conf.project_id.clone().unwrap();
54
55        let mut connector = HttpConnector::new();
56        connector.enforce_http(false);
57
58        let conn = HttpsConnector::new_with_connector(connector);
59        let http = hyper::Client::builder().build(conn);
60        let auth = oauth2::ServiceAccountAuthenticator::builder(conf.clone())
61            .hyper_client(http.clone())
62            .build()
63            .await
64            .map_err(|e| super::RetryError::Auth(e.to_string()))?;
65
66        Ok(Self {
67            http: reqwest::Client::builder().build().unwrap(),
68            auth: InnerAuthClient::Common(auth),
69            conf,
70            project_id,
71        })
72    }
73}
74
75impl Client {
76    const DEFAULT_MESSAGING_ENDPOINT: &'static str = "https://fcm.googleapis.com/v1";
77    const DEFAULT_BATCH_ENDPOINT: &'static str = "https://fcm.googleapis.com/batch";
78
79    #[inline]
80    fn build_parent(&self) -> String {
81        format!("projects/{}", self.project_id)
82    }
83
84    #[inline]
85    fn build_single_url(&self) -> String {
86        format!("{}/projects/{}/messages:send", Self::DEFAULT_MESSAGING_ENDPOINT, self.project_id)
87    }
88
89    #[inline]
90    fn build_batch_url(&self) -> String {
91        format!("{}", Self::DEFAULT_BATCH_ENDPOINT)
92    }
93
94    pub async fn with_proxy<'f>(&'f mut self, config: ProxyConfig<'f>) -> &'f mut Self {
95        let mut connector = HttpConnector::new();
96        connector.enforce_http(false);
97
98        let auth = config.user.map_or(None, |e| {
99            Some(hyper_socks2::Auth::new(e, config.pass.unwrap_or_default()))
100        });
101        let conn = SocksConnector {
102            proxy_addr: Uri::from_str(config.addr).unwrap(), // scheme is required by HttpConnector
103            auth,
104            connector,
105        }
106            .with_tls()
107            .unwrap();
108        let cli = hyper::Client::builder().build(conn);
109        let auth = oauth2::ServiceAccountAuthenticator::builder(self.conf.clone())
110            .hyper_client(cli.clone())
111            .build()
112            .await
113            .map_err(|e| super::RetryError::Auth(e.to_string())).unwrap();
114
115        let mut proxy = Proxy::all(config.addr).unwrap();
116        if let Some(user) = config.user {
117            proxy = proxy.basic_auth(user, config.pass.unwrap_or_default());
118        }
119
120        self.http = reqwest::Client::builder().proxy(proxy).build().unwrap();
121        self.auth = InnerAuthClient::Proxy(auth);
122        self
123    }
124
125    pub async fn multicast_send<'b>(&self, msg: &MulticastMessage<'b>) -> Result<BatchResponse, super::Error> {
126        let mut form = multipart::Form::new();
127
128        for (index, &token) in msg.tokens.iter().enumerate() {
129            let mut encoder = RequestEncoder::new(BodyEncoder::new(BytesEncoder::new()));
130
131            let text = serde_json::to_string(&SendMessageRequest {
132                message: Some(Message {
133                    android: msg.android.clone(),
134                    apns: msg.apns.clone(),
135                    condition: None,
136                    data: msg.data.clone(),
137                    fcm_options: None,
138                    name: None,
139                    notification: msg.notification.clone(),
140                    token: Some(token.to_string()),
141                    topic: None,
142                    webpush: msg.webpush.clone(),
143                }),
144                validate_only: None,
145            }).unwrap();
146
147            let mut buf = Vec::new();
148            let mut req = httpcodec::Request::new(
149                Method::new("POST").unwrap(),
150                RequestTarget::new(&*self.build_single_url()).unwrap(),
151                HttpVersion::V1_1, text);
152
153            req.header_mut().add_field(HeaderField::new("Content-Type", "application/json").unwrap());
154            req.header_mut().add_field(HeaderField::new("User-Agent", "").unwrap());
155
156            encoder.start_encoding(req).unwrap();
157
158            encoder.encode_all(&mut buf).unwrap();
159
160            let length = buf.len();
161
162            let mut part = multipart::Part::text(String::from_utf8(buf).unwrap());
163
164            let headers = part.headers_mut();
165
166            headers.insert("Content-Length", length.to_string().parse().unwrap());
167            headers.insert("Content-Type", "application/http".parse().unwrap());
168            headers.insert("Content-Id", (index + 1).to_string().parse().unwrap());
169            headers.insert("Content-Transfer-Encoding", "binary".parse().unwrap());
170
171            form = form.part(index.to_string(), part);
172        }
173
174        let token = self.get_token().await?;
175
176        let boundary = form.boundary().to_string();
177
178        let url = self.build_batch_url();
179
180        let mut resp = self.http.post(url)
181            .header("Content-Type", format!("multipart/mixed; boundary={}", boundary))
182            .bearer_auth(token.as_str())
183            .body(form.stream())
184            .send()
185            .await
186            .unwrap();
187        let headers = std::mem::take(resp.headers_mut());
188
189        let ct = headers.get("Content-Type").clone().map_or("", |e| e.to_str().unwrap());
190        let boundary = ct.split("=").collect::<Vec<_>>().pop().unwrap_or_default();
191
192        let mut mr = multer::Multipart::new(resp.bytes_stream(), boundary);
193
194        let mut decoder =
195            ResponseDecoder::<BodyDecoder<RemainingBytesDecoder>>::default();
196
197        let mut return_resp = BatchResponse {
198            success_count: 0,
199            failure_count: 0,
200            responses: vec![],
201        };
202        while let Some(field) = mr.next_field().await.unwrap() {
203            let index = field.index();
204            let text = field.text().await.unwrap();
205            let res = decoder.decode_exact(text.as_bytes()).unwrap();
206
207            let (res, body) = res.take_body();
208
209            let resp = serde_json::from_slice::<FcmResponse>(body.as_slice()).unwrap();
210
211            match resp {
212                FcmResponse::Ok { name } => {
213                    return_resp.success_count += 1;
214                    return_resp.responses.push(SendResponse::Ok { message_id: name });
215                }
216                FcmResponse::Error { mut error } => {
217                    return_resp.failure_count += 1;
218                    let detail = error.details.pop().unwrap();
219                    return_resp.responses.push(SendResponse::Error {
220                        token: msg.tokens.get(index).unwrap().to_string(),
221                        error: detail.error_code,
222                    });
223                }
224            }
225        }
226        Ok(return_resp)
227    }
228
229    async fn get_token(&self) -> Result<AccessToken, super::Error> {
230        match &self.auth {
231            InnerAuthClient::Common(cli) => {
232                cli.token(&["https://www.googleapis.com/auth/firebase.messaging"]).await.map_err(|e| super::RetryError::Auth(e.to_string()).into())
233            }
234            InnerAuthClient::Proxy(cli) => {
235                cli.token(&["https://www.googleapis.com/auth/firebase.messaging"]).await.map_err(|e| super::RetryError::Auth(e.to_string()).into())
236            }
237        }
238    }
239}
240
241#[derive(Deserialize)]
242pub enum SendResponse {
243    Ok {
244        message_id: String,
245    },
246    Error {
247        token: String,
248        error: String,
249    },
250}
251
252#[derive(Deserialize)]
253pub struct BatchResponse {
254    pub success_count: i64,
255    pub failure_count: i64,
256    pub responses: Vec<SendResponse>,
257}
258
259#[derive(Deserialize, Debug)]
260#[serde(untagged)]
261enum FcmResponse {
262    Ok {
263        name: String,
264    },
265    Error {
266        error: FcmError,
267    },
268}
269
270#[derive(Deserialize, Debug)]
271struct FcmErrorItem {
272    #[serde(rename = "@type")]
273    _type: String,
274    #[serde(rename = "errorCode")]
275    error_code: String,
276}
277
278#[derive(Deserialize, Debug)]
279struct FcmError {
280    details: Vec<FcmErrorItem>,
281}
282
283
284#[async_trait]
285impl<'b> super::Pusher<'b, MulticastMessage<'b>, BatchResponse> for Client {
286
287    const TOKEN_LIMIT: usize = 1000;
288
289    async fn push(&self, msg: &'b MulticastMessage) -> Result<BatchResponse, crate::Error> {
290        self.multicast_send(msg).await
291    }
292}
293
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    /// Live smoke test: sends a multicast message to two (invalid) tokens via the FCM batch API.
    ///
    /// Needs network access plus the `GOOGLE_CLIENT_ID`, `GOOGLE_CLIENT_EMAIL`,
    /// `GOOGLE_PRIVATE_ID`, `GOOGLE_PRIVATE_KEY` and `GOOGLE_PROJECT_ID` environment variables.
    /// It is therefore ignored by default; run it with `cargo test -- --ignored`.
    /// Traffic goes through the SOCKS proxy in `FCM_SOCKS_PROXY`, defaulting to
    /// `socks5://127.0.0.1:7890`.
    #[tokio::test]
    #[ignore]
    async fn test_fcm() {
        use super::*;

        let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap();
        let client_email = std::env::var("GOOGLE_CLIENT_EMAIL").unwrap();
        let private_id = std::env::var("GOOGLE_PRIVATE_ID").unwrap();
        // Never print the private key: test output ends up in CI logs.
        let private_key = std::env::var("GOOGLE_PRIVATE_KEY").unwrap();
        let project_id = std::env::var("GOOGLE_PROJECT_ID").unwrap();
        let proxy_addr = std::env::var("FCM_SOCKS_PROXY")
            .unwrap_or_else(|_| "socks5://127.0.0.1:7890".to_string());
        // The certificate URL is keyed by the service account's URL-encoded e-mail.
        let client_x509_cert_url = format!(
            "https://www.googleapis.com/robot/v1/metadata/x509/{}",
            client_email.replace('@', "%40")
        );

        let mut fcm = Client::new(Config {
            key_type: "service_account".to_string().into(),
            client_id: client_id.into(),
            private_key_id: private_id.into(),
            private_key,
            token_uri: "https://oauth2.googleapis.com/token".to_string(),
            auth_uri: "https://accounts.google.com/o/oauth2/auth".to_string().into(),
            project_id: project_id.into(),
            client_email,
            auth_provider_x509_cert_url: Some(
                "https://www.googleapis.com/oauth2/v1/certs".to_string(),
            ),
            client_x509_cert_url: Some(client_x509_cert_url),
        })
        .await
        .unwrap();

        fcm.with_proxy(ProxyConfig {
            addr: &proxy_addr,
            user: None,
            pass: None,
        })
        .await;

        let tokens = vec![
            "fn-3aZfYyceipSqyB-iigS:APA91bEYAGIVMeDmrvZqV5T8C_5UUCrb9xlvupRuKOyHgHDJnYkuwnKfOPCoQKBIQ4IhEJdNPBlaTapVG-iBAYPZ8GegROoeQTetlvmmKPBQrH9hrVRTTaOW69qBm7ZoDy1ewPGqD5RC",
            "fn-3aZfYyceipSqyB-iigS:APA91bEYAGIVMeDmrvZqV5T8C_5UUCrb9xlvupRuKOyHgHDJnYkuwnKfOPCoQKBIQ4IhEJdNPBlaTapVG-iBAYPZ8GegROoeQTetlvmmKPBQrH9hrVRTTaOW69qBm7ZoDy1ewPGqD5RA",
        ];

        let res = fcm
            .multicast_send(&MulticastMessage {
                tokens: tokens.clone(),
                data: Some(HashMap::new()),
                ..Default::default()
            })
            .await
            .unwrap();

        // Every token must produce exactly one response, successful or not.
        assert_eq!(res.responses.len(), tokens.len());
    }
}