firebase_admin_sdk/messaging/
mod.rs1use reqwest::{Client, header};
29use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
30use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
31use crate::core::middleware::AuthMiddleware;
32use crate::core::parse_error_response;
33use crate::messaging::models::{Message, MulticastMessage, TopicManagementResponse, TopicManagementError, BatchResponse, SendResponse, SendResponseInternal};
34use thiserror::Error;
35use serde::{Deserialize, Serialize};
36
37pub mod models;
38
39#[cfg(test)]
40mod tests;
41
42#[derive(Error, Debug)]
44pub enum MessagingError {
45 #[error("HTTP Request failed: {0}")]
47 RequestError(#[from] reqwest::Error),
48 #[error("Middleware error: {0}")]
50 MiddlewareError(#[from] reqwest_middleware::Error),
51 #[error("API error: {0}")]
53 ApiError(String),
54 #[error("Serialization error: {0}")]
56 SerializationError(#[from] serde_json::Error),
57 #[error("Multipart response parsing error: {0}")]
59 MultipartError(String),
60}
61
62#[derive(Clone)]
64pub struct FirebaseMessaging {
65 client: ClientWithMiddleware,
66 project_id: String,
67 base_url: String,
68 batch_url: String,
69 iid_base_url: String,
70}
71
72#[derive(Serialize)]
74#[serde(rename_all = "camelCase")]
75struct SendRequest<'a> {
76 validate_only: bool,
77 message: &'a Message,
78}
79
80#[derive(Serialize)]
81struct TopicManagementRequest<'a> {
82 to: String,
83 registration_tokens: &'a [&'a str],
84}
85
86#[derive(Deserialize)]
87struct TopicManagementApiResponse {
88 results: Option<Vec<TopicManagementApiResult>>,
89}
90
91#[derive(Deserialize)]
92struct TopicManagementApiResult {
93 error: Option<String>,
94}
95
96impl FirebaseMessaging {
97 pub fn new(middleware: AuthMiddleware) -> Self {
101 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
102
103 let client = ClientBuilder::new(Client::new())
104 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
105 .with(middleware.clone())
106 .build();
107
108 let project_id = middleware.key.project_id.clone().unwrap_or_default();
109 let base_url = format!("https://fcm.googleapis.com/v1/projects/{}/messages:send", project_id);
110 let batch_url = "https://fcm.googleapis.com/batch".to_string();
111 let iid_base_url = "https://iid.googleapis.com".to_string();
112
113 Self {
114 client,
115 project_id,
116 base_url,
117 batch_url,
118 iid_base_url,
119 }
120 }
121
122 #[cfg(test)]
123 pub(crate) fn new_with_url(middleware: AuthMiddleware, base_url: String, batch_url: String, iid_base_url: String) -> Self {
124 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
125 let client = ClientBuilder::new(Client::new())
126 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
127 .with(middleware.clone())
128 .build();
129 let project_id = middleware.key.project_id.clone().unwrap_or_default();
130 Self { client, project_id, base_url, batch_url, iid_base_url }
131 }
132
133 pub async fn send(&self, message: &Message, dry_run: bool) -> Result<String, MessagingError> {
140 self.validate_message(message)?;
141 self.send_request(message, dry_run).await
142 }
143
144 fn validate_message(&self, message: &Message) -> Result<(), MessagingError> {
146 let num_targets = [
147 message.token.is_some(),
148 message.topic.is_some(),
149 message.condition.is_some(),
150 ]
151 .iter()
152 .filter(|&&t| t)
153 .count();
154
155 if num_targets != 1 {
156 return Err(MessagingError::ApiError(
157 "Message must have exactly one of token, topic, or condition.".to_string(),
158 ));
159 }
160
161 Ok(())
162 }
163
164 async fn send_request(&self, message: &Message, dry_run: bool) -> Result<String, MessagingError> {
166 let request = SendRequest {
167 validate_only: dry_run,
168 message,
169 };
170
171 let response = self.client
172 .post(&self.base_url)
173 .header(header::CONTENT_TYPE, "application/json")
174 .body(serde_json::to_vec(&request)?)
175 .send()
176 .await?;
177
178 if !response.status().is_success() {
179 return Err(MessagingError::ApiError(parse_error_response(response, "FCM send failed").await));
180 }
181
182 let result: SendResponseInternal = response.json().await?;
183 Ok(result.name)
184 }
185
186 pub async fn send_each(&self, messages: &[Message], dry_run: bool) -> Result<BatchResponse, MessagingError> {
195 for message in messages {
196 self.validate_message(message)?;
197 }
198 self.send_each_request(messages, dry_run).await
199 }
200
201 async fn send_each_request(&self, messages: &[Message], dry_run: bool) -> Result<BatchResponse, MessagingError> {
202 if messages.is_empty() {
203 return Ok(BatchResponse::default());
204 }
205
206 if messages.len() > 500 {
207 return Err(MessagingError::ApiError("Cannot send more than 500 messages in a single batch.".to_string()));
208 }
209
210 let url = self.batch_url.clone();
211 let boundary = format!("batch_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
212
213 let body = self.build_multipart_body(messages, dry_run, &boundary)?;
214
215 let content_type = format!("multipart/mixed; boundary={}", boundary);
216
217 let response = self.client
218 .post(&url)
219 .header(header::CONTENT_TYPE, content_type)
220 .body(body)
221 .send()
222 .await?;
223
224 if !response.status().is_success() {
225 return Err(MessagingError::ApiError(parse_error_response(response, "FCM batch send failed").await));
226 }
227
228 let multipart_boundary = response
229 .headers()
230 .get(header::CONTENT_TYPE)
231 .and_then(|ct| ct.to_str().ok())
232 .and_then(|ct| ct.split("boundary=").nth(1))
233 .map(|s| s.to_string())
234 .ok_or_else(|| MessagingError::MultipartError("Multipart boundary not found in response".to_string()))?;
235
236 let text = response.text().await?;
237 let responses = self.parse_multipart_response(&text, &multipart_boundary)?;
238
239 let success_count = responses.iter().filter(|r| r.success).count();
240 let failure_count = responses.len() - success_count;
241
242 Ok(BatchResponse {
243 success_count,
244 failure_count,
245 responses,
246 })
247 }
248
249 fn build_multipart_body(&self, messages: &[Message], dry_run: bool, boundary: &str) -> Result<Vec<u8>, MessagingError> {
250 let mut body = Vec::new();
251
252 for message in messages {
253 let send_request = SendRequest {
254 validate_only: dry_run,
255 message,
256 };
257
258 let post_url = format!("/v1/projects/{}/messages:send", self.project_id);
259 let request_body = serde_json::to_string(&send_request)?;
260
261 body.extend_from_slice(b"--");
262 body.extend_from_slice(boundary.as_bytes());
263 body.extend_from_slice(b"\r\n");
264 body.extend_from_slice(b"Content-Type: application/http\r\n");
265 body.extend_from_slice(b"Content-Transfer-Encoding: binary\r\n\r\n");
266 body.extend_from_slice(b"POST ");
267 body.extend_from_slice(post_url.as_bytes());
268 body.extend_from_slice(b"\r\n");
269 body.extend_from_slice(b"Content-Type: application/json\r\n");
270 body.extend_from_slice(b"\r\n");
271 body.extend_from_slice(request_body.as_bytes());
272 body.extend_from_slice(b"\r\n");
273 }
274
275 body.extend_from_slice(b"--");
276 body.extend_from_slice(boundary.as_bytes());
277 body.extend_from_slice(b"--\r\n");
278
279 Ok(body)
280 }
281
282 fn parse_multipart_response(&self, body: &str, boundary: &str) -> Result<Vec<SendResponse>, MessagingError> {
283 let boundary = format!("--{}", boundary);
284 let parts: Vec<&str> = body.split(&boundary)
285 .filter(|p| !p.trim().is_empty() && p.trim() != "--")
286 .collect();
287 let mut responses = Vec::new();
288
289 for part in parts {
290 let http_part = part.trim();
291
292 if let Some(inner_response_start) = http_part.find("\r\n\r\n") {
293 let inner_response = &http_part[inner_response_start + 4..];
294
295 if let Some(json_start) = inner_response.find("\r\n\r\n") {
296 let json_body = inner_response[json_start + 4..].trim();
297
298 if json_body.is_empty() {
299 return Err(MessagingError::MultipartError("Empty JSON body in response part".to_string()));
300 }
301
302 let status_line = inner_response.lines().next().unwrap_or("");
303 if status_line.contains("200 OK") {
304 match serde_json::from_str::<SendResponseInternal>(json_body) {
305 Ok(send_response) => responses.push(SendResponse {
306 success: true,
307 message_id: Some(send_response.name),
308 error: None,
309 }),
310 Err(_) => return Err(MessagingError::MultipartError("Failed to parse successful response part".to_string())),
311 }
312 } else { match serde_json::from_str::<serde_json::Value>(json_body) {
314 Ok(error_response) => responses.push(SendResponse {
315 success: false,
316 message_id: None,
317 error: Some(error_response.to_string()),
318 }),
319 Err(_) => return Err(MessagingError::MultipartError("Failed to parse error response part".to_string())),
320 }
321 }
322 } else {
323 return Err(MessagingError::MultipartError("Invalid inner HTTP response format".to_string()));
324 }
325 } else {
326 return Err(MessagingError::MultipartError("Invalid multipart part format".to_string()));
327 }
328 }
329
330 Ok(responses)
331 }
332
333 pub async fn send_each_for_multicast(&self, message: &MulticastMessage, dry_run: bool) -> Result<BatchResponse, MessagingError> {
342 let messages: Vec<Message> = message.tokens.iter().map(|token| {
343 Message {
344 token: Some(token.clone()),
345 data: message.data.clone(),
346 notification: message.notification.clone(),
347 android: message.android.clone(),
348 webpush: message.webpush.clone(),
349 apns: message.apns.clone(),
350 fcm_options: message.fcm_options.clone(),
351 ..Default::default()
352 }
353 }).collect();
354
355 self.send_each(&messages, dry_run).await
356 }
357
358 pub async fn subscribe_to_topic(&self, tokens: &[&str], topic: &str) -> Result<TopicManagementResponse, MessagingError> {
365 self.manage_topic(topic, tokens, true).await
366 }
367
368 pub async fn unsubscribe_from_topic(&self, tokens: &[&str], topic: &str) -> Result<TopicManagementResponse, MessagingError> {
375 self.manage_topic(topic, tokens, false).await
376 }
377
378 async fn manage_topic(&self, topic: &str, tokens: &[&str], subscribe: bool) -> Result<TopicManagementResponse, MessagingError> {
379 let topic_path = if topic.starts_with("/topics/") {
380 topic.to_string()
381 } else {
382 format!("/topics/{}", topic)
383 };
384
385 let url = if subscribe {
386 format!("{}/iid/v1:batchAdd", self.iid_base_url)
387 } else {
388 format!("{}/iid/v1:batchRemove", self.iid_base_url)
389 };
390
391 let mut response_summary = TopicManagementResponse::default();
392
393 for (batch_idx, chunk) in tokens.chunks(1000).enumerate() {
394 let request = TopicManagementRequest {
395 to: topic_path.clone(),
396 registration_tokens: chunk,
397 };
398
399 let response = self.client
400 .post(&url)
401 .header(header::CONTENT_TYPE, "application/json")
402 .header("access_token_auth", "true") .body(serde_json::to_vec(&request)?)
406 .send()
407 .await?;
408
409 if !response.status().is_success() {
410 return Err(MessagingError::ApiError(parse_error_response(response, "Topic management failed").await));
411 }
412
413 let api_response: TopicManagementApiResponse = response.json().await?;
414
415 if let Some(results) = api_response.results {
416 for (i, result) in results.iter().enumerate() {
417 if let Some(error) = &result.error {
418 response_summary.failure_count += 1;
419 response_summary.errors.push(TopicManagementError {
420 index: batch_idx * 1000 + i,
421 reason: error.clone(),
422 });
423 } else {
424 response_summary.success_count += 1;
425 }
426 }
427 }
428 }
429
430 Ok(response_summary)
431 }
432}