Skip to main content

firebase_admin_sdk/messaging/
mod.rs

1//! Firebase Cloud Messaging (FCM) module.
2//!
3//! This module provides functionality for sending messages via FCM (single, batch, multicast)
4//! and managing topic subscriptions.
5//!
6//! # Examples
7//!
8//! ```rust,no_run
9//! use firebase_admin_sdk::messaging::models::{Message, Notification};
10//! # use firebase_admin_sdk::FirebaseApp;
11//! # async fn run(app: FirebaseApp) {
12//! let messaging = app.messaging();
13//!
14//! let message = Message {
15//!     token: Some("device_token".to_string()),
16//!     notification: Some(Notification {
17//!         title: Some("Title".to_string()),
18//!         body: Some("Body".to_string()),
19//!         ..Default::default()
20//!     }),
21//!     ..Default::default()
22//! };
23//!
24//! let result = messaging.send(&message, false).await;
25//! # }
26//! ```
27
28use reqwest::{Client, header};
29use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
30use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
31use crate::core::middleware::AuthMiddleware;
32use crate::core::parse_error_response;
33use crate::messaging::models::{Message, MulticastMessage, TopicManagementResponse, TopicManagementError, BatchResponse, SendResponse, SendResponseInternal};
34use thiserror::Error;
35use serde::{Deserialize, Serialize};
36
37pub mod models;
38
39#[cfg(test)]
40mod tests;
41
42/// Errors that can occur during Messaging operations.
43#[derive(Error, Debug)]
44pub enum MessagingError {
45    /// Wrapper for `reqwest::Error`.
46    #[error("HTTP Request failed: {0}")]
47    RequestError(#[from] reqwest::Error),
48    /// Wrapper for `reqwest_middleware::Error`.
49    #[error("Middleware error: {0}")]
50    MiddlewareError(#[from] reqwest_middleware::Error),
51    /// Errors returned by the FCM API.
52    #[error("API error: {0}")]
53    ApiError(String),
54    /// Wrapper for `serde_json::Error`.
55    #[error("Serialization error: {0}")]
56    SerializationError(#[from] serde_json::Error),
57    /// Error parsing multipart responses for batch requests.
58    #[error("Multipart response parsing error: {0}")]
59    MultipartError(String),
60}
61
62/// Client for interacting with Firebase Cloud Messaging.
63#[derive(Clone)]
64pub struct FirebaseMessaging {
65    client: ClientWithMiddleware,
66    project_id: String,
67    base_url: String,
68    batch_url: String,
69    iid_base_url: String,
70}
71
72// Wrapper for the request body required by FCM v1 API
73#[derive(Serialize)]
74#[serde(rename_all = "camelCase")]
75struct SendRequest<'a> {
76    validate_only: bool,
77    message: &'a Message,
78}
79
80#[derive(Serialize)]
81struct TopicManagementRequest<'a> {
82    to: String,
83    registration_tokens: &'a [&'a str],
84}
85
86#[derive(Deserialize)]
87struct TopicManagementApiResponse {
88    results: Option<Vec<TopicManagementApiResult>>,
89}
90
91#[derive(Deserialize)]
92struct TopicManagementApiResult {
93    error: Option<String>,
94}
95
96impl FirebaseMessaging {
97    /// Creates a new `FirebaseMessaging` instance.
98    ///
99    /// This is typically called via `FirebaseApp::messaging()`.
100    pub fn new(middleware: AuthMiddleware) -> Self {
101        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
102
103        let client = ClientBuilder::new(Client::new())
104            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
105            .with(middleware.clone())
106            .build();
107
108        let project_id = middleware.key.project_id.clone().unwrap_or_default();
109        let base_url = format!("https://fcm.googleapis.com/v1/projects/{}/messages:send", project_id);
110        let batch_url = "https://fcm.googleapis.com/batch".to_string();
111        let iid_base_url = "https://iid.googleapis.com".to_string();
112
113        Self {
114            client,
115            project_id,
116            base_url,
117            batch_url,
118            iid_base_url,
119        }
120    }
121
122    #[cfg(test)]
123    pub(crate) fn new_with_url(middleware: AuthMiddleware, base_url: String, batch_url: String, iid_base_url: String) -> Self {
124        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
125        let client = ClientBuilder::new(Client::new())
126            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
127            .with(middleware.clone())
128            .build();
129        let project_id = middleware.key.project_id.clone().unwrap_or_default();
130        Self { client, project_id, base_url, batch_url, iid_base_url }
131    }
132
133    /// Sends a message to a specific target (token, topic, or condition).
134    ///
135    /// # Arguments
136    ///
137    /// * `message` - The `Message` struct defining the payload and target.
138    /// * `dry_run` - If true, the message will be validated but not sent.
139    pub async fn send(&self, message: &Message, dry_run: bool) -> Result<String, MessagingError> {
140        self.validate_message(message)?;
141        self.send_request(message, dry_run).await
142    }
143
144    /// Validates that the message has exactly one target.
145    fn validate_message(&self, message: &Message) -> Result<(), MessagingError> {
146        let num_targets = [
147            message.token.is_some(),
148            message.topic.is_some(),
149            message.condition.is_some(),
150        ]
151        .iter()
152        .filter(|&&t| t)
153        .count();
154
155        if num_targets != 1 {
156            return Err(MessagingError::ApiError(
157                "Message must have exactly one of token, topic, or condition.".to_string(),
158            ));
159        }
160
161        Ok(())
162    }
163
164    /// Internal method to send the HTTP request.
165    async fn send_request(&self, message: &Message, dry_run: bool) -> Result<String, MessagingError> {
166        let request = SendRequest {
167            validate_only: dry_run,
168            message,
169        };
170
171        let response = self.client
172            .post(&self.base_url)
173            .header(header::CONTENT_TYPE, "application/json")
174            .body(serde_json::to_vec(&request)?)
175            .send()
176            .await?;
177
178        if !response.status().is_success() {
179            return Err(MessagingError::ApiError(parse_error_response(response, "FCM send failed").await));
180        }
181
182        let result: SendResponseInternal = response.json().await?;
183        Ok(result.name)
184    }
185
186    /// Sends a batch of messages.
187    ///
188    /// This uses the FCM batch endpoint to send up to 500 messages in a single HTTP request.
189    ///
190    /// # Arguments
191    ///
192    /// * `messages` - A slice of `Message` structs.
193    /// * `dry_run` - If true, the messages will be validated but not sent.
194    pub async fn send_each(&self, messages: &[Message], dry_run: bool) -> Result<BatchResponse, MessagingError> {
195        for message in messages {
196            self.validate_message(message)?;
197        }
198        self.send_each_request(messages, dry_run).await
199    }
200
201    async fn send_each_request(&self, messages: &[Message], dry_run: bool) -> Result<BatchResponse, MessagingError> {
202        if messages.is_empty() {
203            return Ok(BatchResponse::default());
204        }
205
206        if messages.len() > 500 {
207            return Err(MessagingError::ApiError("Cannot send more than 500 messages in a single batch.".to_string()));
208        }
209
210        let url = self.batch_url.clone();
211        let boundary = format!("batch_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
212
213        let body = self.build_multipart_body(messages, dry_run, &boundary)?;
214
215        let content_type = format!("multipart/mixed; boundary={}", boundary);
216
217        let response = self.client
218            .post(&url)
219            .header(header::CONTENT_TYPE, content_type)
220            .body(body)
221            .send()
222            .await?;
223
224        if !response.status().is_success() {
225            return Err(MessagingError::ApiError(parse_error_response(response, "FCM batch send failed").await));
226        }
227
228        let multipart_boundary = response
229            .headers()
230            .get(header::CONTENT_TYPE)
231            .and_then(|ct| ct.to_str().ok())
232            .and_then(|ct| ct.split("boundary=").nth(1))
233            .map(|s| s.to_string())
234            .ok_or_else(|| MessagingError::MultipartError("Multipart boundary not found in response".to_string()))?;
235
236        let text = response.text().await?;
237        let responses = self.parse_multipart_response(&text, &multipart_boundary)?;
238
239        let success_count = responses.iter().filter(|r| r.success).count();
240        let failure_count = responses.len() - success_count;
241
242        Ok(BatchResponse {
243            success_count,
244            failure_count,
245            responses,
246        })
247    }
248
249    fn build_multipart_body(&self, messages: &[Message], dry_run: bool, boundary: &str) -> Result<Vec<u8>, MessagingError> {
250        let mut body = Vec::new();
251
252        for message in messages {
253            let send_request = SendRequest {
254                validate_only: dry_run,
255                message,
256            };
257
258            let post_url = format!("/v1/projects/{}/messages:send", self.project_id);
259            let request_body = serde_json::to_string(&send_request)?;
260
261            body.extend_from_slice(b"--");
262            body.extend_from_slice(boundary.as_bytes());
263            body.extend_from_slice(b"\r\n");
264            body.extend_from_slice(b"Content-Type: application/http\r\n");
265            body.extend_from_slice(b"Content-Transfer-Encoding: binary\r\n\r\n");
266            body.extend_from_slice(b"POST ");
267            body.extend_from_slice(post_url.as_bytes());
268            body.extend_from_slice(b"\r\n");
269            body.extend_from_slice(b"Content-Type: application/json\r\n");
270            body.extend_from_slice(b"\r\n");
271            body.extend_from_slice(request_body.as_bytes());
272            body.extend_from_slice(b"\r\n");
273        }
274
275        body.extend_from_slice(b"--");
276        body.extend_from_slice(boundary.as_bytes());
277        body.extend_from_slice(b"--\r\n");
278
279        Ok(body)
280    }
281
282    fn parse_multipart_response(&self, body: &str, boundary: &str) -> Result<Vec<SendResponse>, MessagingError> {
283        let boundary = format!("--{}", boundary);
284        let parts: Vec<&str> = body.split(&boundary)
285            .filter(|p| !p.trim().is_empty() && p.trim() != "--")
286            .collect();
287        let mut responses = Vec::new();
288
289        for part in parts {
290            let http_part = part.trim();
291
292            if let Some(inner_response_start) = http_part.find("\r\n\r\n") {
293                let inner_response = &http_part[inner_response_start + 4..];
294
295                if let Some(json_start) = inner_response.find("\r\n\r\n") {
296                    let json_body = inner_response[json_start + 4..].trim();
297
298                    if json_body.is_empty() {
299                        return Err(MessagingError::MultipartError("Empty JSON body in response part".to_string()));
300                    }
301
302                    let status_line = inner_response.lines().next().unwrap_or("");
303                    if status_line.contains("200 OK") {
304                        match serde_json::from_str::<SendResponseInternal>(json_body) {
305                            Ok(send_response) => responses.push(SendResponse {
306                                success: true,
307                                message_id: Some(send_response.name),
308                                error: None,
309                            }),
310                            Err(_) => return Err(MessagingError::MultipartError("Failed to parse successful response part".to_string())),
311                        }
312                    } else { // It's an error response
313                         match serde_json::from_str::<serde_json::Value>(json_body) {
314                            Ok(error_response) => responses.push(SendResponse {
315                                success: false,
316                                message_id: None,
317                                error: Some(error_response.to_string()),
318                            }),
319                            Err(_) => return Err(MessagingError::MultipartError("Failed to parse error response part".to_string())),
320                        }
321                    }
322                } else {
323                     return Err(MessagingError::MultipartError("Invalid inner HTTP response format".to_string()));
324                }
325            } else {
326                return Err(MessagingError::MultipartError("Invalid multipart part format".to_string()));
327            }
328        }
329
330        Ok(responses)
331    }
332
333    /// Sends a multicast message to all specified tokens.
334    ///
335    /// This is a wrapper around `send_each` that constructs individual messages for each token.
336    ///
337    /// # Arguments
338    ///
339    /// * `message` - The `MulticastMessage` containing tokens and payload.
340    /// * `dry_run` - If true, the messages will be validated but not sent.
341    pub async fn send_each_for_multicast(&self, message: &MulticastMessage, dry_run: bool) -> Result<BatchResponse, MessagingError> {
342        let messages: Vec<Message> = message.tokens.iter().map(|token| {
343            Message {
344                token: Some(token.clone()),
345                data: message.data.clone(),
346                notification: message.notification.clone(),
347                android: message.android.clone(),
348                webpush: message.webpush.clone(),
349                apns: message.apns.clone(),
350                fcm_options: message.fcm_options.clone(),
351                ..Default::default()
352            }
353        }).collect();
354
355        self.send_each(&messages, dry_run).await
356    }
357
358    /// Subscribes a list of tokens to a topic.
359    ///
360    /// # Arguments
361    ///
362    /// * `tokens` - A list of device registration tokens.
363    /// * `topic` - The name of the topic.
364    pub async fn subscribe_to_topic(&self, tokens: &[&str], topic: &str) -> Result<TopicManagementResponse, MessagingError> {
365        self.manage_topic(topic, tokens, true).await
366    }
367
368    /// Unsubscribes a list of tokens from a topic.
369    ///
370    /// # Arguments
371    ///
372    /// * `tokens` - A list of device registration tokens.
373    /// * `topic` - The name of the topic.
374    pub async fn unsubscribe_from_topic(&self, tokens: &[&str], topic: &str) -> Result<TopicManagementResponse, MessagingError> {
375        self.manage_topic(topic, tokens, false).await
376    }
377
378    async fn manage_topic(&self, topic: &str, tokens: &[&str], subscribe: bool) -> Result<TopicManagementResponse, MessagingError> {
379        let topic_path = if topic.starts_with("/topics/") {
380            topic.to_string()
381        } else {
382            format!("/topics/{}", topic)
383        };
384
385        let url = if subscribe {
386            format!("{}/iid/v1:batchAdd", self.iid_base_url)
387        } else {
388            format!("{}/iid/v1:batchRemove", self.iid_base_url)
389        };
390
391        let mut response_summary = TopicManagementResponse::default();
392
393        for (batch_idx, chunk) in tokens.chunks(1000).enumerate() {
394            let request = TopicManagementRequest {
395                to: topic_path.clone(),
396                registration_tokens: chunk,
397            };
398
399            let response = self.client
400                .post(&url)
401                .header(header::CONTENT_TYPE, "application/json")
402                // Use access_token_header from AuthMiddleware, but the IID API also requires the standard header.
403                // The AuthMiddleware adds it automatically.
404                .header("access_token_auth", "true") // Some docs suggest this for IID, but standard Bearer should work.
405                .body(serde_json::to_vec(&request)?)
406                .send()
407                .await?;
408
409            if !response.status().is_success() {
410                return Err(MessagingError::ApiError(parse_error_response(response, "Topic management failed").await));
411            }
412
413            let api_response: TopicManagementApiResponse = response.json().await?;
414
415            if let Some(results) = api_response.results {
416                for (i, result) in results.iter().enumerate() {
417                     if let Some(error) = &result.error {
418                         response_summary.failure_count += 1;
419                         response_summary.errors.push(TopicManagementError {
420                             index: batch_idx * 1000 + i,
421                             reason: error.clone(),
422                         });
423                     } else {
424                         response_summary.success_count += 1;
425                     }
426                }
427            }
428        }
429
430        Ok(response_summary)
431    }
432}