firebase_admin_sdk/messaging/
mod.rs

1//! Firebase Cloud Messaging (FCM) module.
2//!
3//! This module provides functionality for sending messages via FCM (single, batch, multicast)
4//! and managing topic subscriptions.
5//!
6//! # Examples
7//!
8//! ```rust,ignore
9//! use firebase_admin_sdk::messaging::models::{Message, Notification};
10//! # use firebase_admin_sdk::FirebaseApp;
11//! # async fn run(app: FirebaseApp) {
12//! let messaging = app.messaging();
13//!
14//! let message = Message {
15//!     token: Some("device_token".to_string()),
16//!     notification: Some(Notification {
17//!         title: Some("Title".to_string()),
18//!         body: Some("Body".to_string()),
19//!         ..Default::default()
20//!     }),
21//!     ..Default::default()
22//! };
23//!
24//! let result = messaging.send(&message, false).await;
25//! # }
26//! ```
27
28use reqwest::{Client, header};
29use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
30use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
31use crate::core::middleware::AuthMiddleware;
32use crate::core::parse_error_response;
33use crate::messaging::models::{Message, MulticastMessage, TopicManagementResponse, TopicManagementError, BatchResponse, SendResponse, SendResponseInternal};
34use thiserror::Error;
35use serde::{Deserialize, Serialize};
36
37pub mod models;
38
39#[cfg(test)]
40mod tests;
41
42/// Errors that can occur during Messaging operations.
43#[derive(Error, Debug)]
44pub enum MessagingError {
45    /// Wrapper for `reqwest::Error`.
46    #[error("HTTP Request failed: {0}")]
47    RequestError(#[from] reqwest::Error),
48    /// Wrapper for `reqwest_middleware::Error`.
49    #[error("Middleware error: {0}")]
50    MiddlewareError(#[from] reqwest_middleware::Error),
51    /// Errors returned by the FCM API.
52    #[error("API error: {0}")]
53    ApiError(String),
54    /// Wrapper for `serde_json::Error`.
55    #[error("Serialization error: {0}")]
56    SerializationError(#[from] serde_json::Error),
57    /// Error parsing multipart responses for batch requests.
58    #[error("Multipart response parsing error: {0}")]
59    MultipartError(String),
60}
61
62/// Client for interacting with Firebase Cloud Messaging.
63#[derive(Clone)]
64pub struct FirebaseMessaging {
65    client: ClientWithMiddleware,
66    project_id: String,
67    base_url: String,
68}
69
70// Wrapper for the request body required by FCM v1 API
71#[derive(Serialize)]
72#[serde(rename_all = "camelCase")]
73struct SendRequest<'a> {
74    validate_only: bool,
75    message: &'a Message,
76}
77
78#[derive(Serialize)]
79struct TopicManagementRequest<'a> {
80    to: String,
81    registration_tokens: &'a [&'a str],
82}
83
84#[derive(Deserialize)]
85struct TopicManagementApiResponse {
86    results: Option<Vec<TopicManagementApiResult>>,
87}
88
89#[derive(Deserialize)]
90struct TopicManagementApiResult {
91    error: Option<String>,
92}
93
94impl FirebaseMessaging {
95    /// Creates a new `FirebaseMessaging` instance.
96    ///
97    /// This is typically called via `FirebaseApp::messaging()`.
98    pub fn new(middleware: AuthMiddleware) -> Self {
99        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
100
101        let client = ClientBuilder::new(Client::new())
102            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
103            .with(middleware.clone())
104            .build();
105
106        let project_id = middleware.key.project_id.clone().unwrap_or_default();
107        let base_url = format!("https://fcm.googleapis.com/v1/projects/{}/messages:send", project_id);
108
109        Self {
110            client,
111            project_id,
112            base_url,
113        }
114    }
115
116    #[cfg(test)]
117    pub(crate) fn new_with_url(middleware: AuthMiddleware, base_url: String) -> Self {
118        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
119        let client = ClientBuilder::new(Client::new())
120            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
121            .with(middleware.clone())
122            .build();
123        let project_id = middleware.key.project_id.clone().unwrap_or_default();
124        Self { client, project_id, base_url }
125    }
126
127    /// Sends a message to a specific target (token, topic, or condition).
128    ///
129    /// # Arguments
130    ///
131    /// * `message` - The `Message` struct defining the payload and target.
132    /// * `dry_run` - If true, the message will be validated but not sent.
133    pub async fn send(&self, message: &Message, dry_run: bool) -> Result<String, MessagingError> {
134        self.validate_message(message)?;
135        self.send_request(message, dry_run).await
136    }
137
138    /// Validates that the message has exactly one target.
139    fn validate_message(&self, message: &Message) -> Result<(), MessagingError> {
140        let num_targets = [
141            message.token.is_some(),
142            message.topic.is_some(),
143            message.condition.is_some(),
144        ]
145        .iter()
146        .filter(|&&t| t)
147        .count();
148
149        if num_targets != 1 {
150            return Err(MessagingError::ApiError(
151                "Message must have exactly one of token, topic, or condition.".to_string(),
152            ));
153        }
154
155        Ok(())
156    }
157
158    /// Internal method to send the HTTP request.
159    async fn send_request(&self, message: &Message, dry_run: bool) -> Result<String, MessagingError> {
160        let request = SendRequest {
161            validate_only: dry_run,
162            message,
163        };
164
165        let response = self.client
166            .post(&self.base_url)
167            .header(header::CONTENT_TYPE, "application/json")
168            .body(serde_json::to_vec(&request)?)
169            .send()
170            .await?;
171
172        if !response.status().is_success() {
173            return Err(MessagingError::ApiError(parse_error_response(response, "FCM send failed").await));
174        }
175
176        let result: SendResponseInternal = response.json().await?;
177        Ok(result.name)
178    }
179
180    /// Sends a batch of messages.
181    ///
182    /// This uses the FCM batch endpoint to send up to 500 messages in a single HTTP request.
183    ///
184    /// # Arguments
185    ///
186    /// * `messages` - A slice of `Message` structs.
187    /// * `dry_run` - If true, the messages will be validated but not sent.
188    pub async fn send_each(&self, messages: &[Message], dry_run: bool) -> Result<BatchResponse, MessagingError> {
189        for message in messages {
190            self.validate_message(message)?;
191        }
192        self.send_each_request(messages, dry_run).await
193    }
194
195    async fn send_each_request(&self, messages: &[Message], dry_run: bool) -> Result<BatchResponse, MessagingError> {
196        if messages.is_empty() {
197            return Ok(BatchResponse::default());
198        }
199
200        if messages.len() > 500 {
201            return Err(MessagingError::ApiError("Cannot send more than 500 messages in a single batch.".to_string()));
202        }
203
204        let url = format!("https://fcm.googleapis.com/batch");
205        let boundary = format!("batch_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
206
207        let body = self.build_multipart_body(messages, dry_run, &boundary)?;
208
209        let content_type = format!("multipart/mixed; boundary={}", boundary);
210
211        let response = self.client
212            .post(&url)
213            .header(header::CONTENT_TYPE, content_type)
214            .body(body)
215            .send()
216            .await?;
217
218        if !response.status().is_success() {
219            return Err(MessagingError::ApiError(parse_error_response(response, "FCM batch send failed").await));
220        }
221
222        let multipart_boundary = response
223            .headers()
224            .get(header::CONTENT_TYPE)
225            .and_then(|ct| ct.to_str().ok())
226            .and_then(|ct| ct.split("boundary=").nth(1))
227            .map(|s| s.to_string())
228            .ok_or_else(|| MessagingError::MultipartError("Multipart boundary not found in response".to_string()))?;
229
230        let text = response.text().await?;
231        let responses = self.parse_multipart_response(&text, &multipart_boundary)?;
232
233        let success_count = responses.iter().filter(|r| r.success).count();
234        let failure_count = responses.len() - success_count;
235
236        Ok(BatchResponse {
237            success_count,
238            failure_count,
239            responses,
240        })
241    }
242
243    fn build_multipart_body(&self, messages: &[Message], dry_run: bool, boundary: &str) -> Result<Vec<u8>, MessagingError> {
244        let mut body = Vec::new();
245
246        for message in messages {
247            let send_request = SendRequest {
248                validate_only: dry_run,
249                message,
250            };
251
252            let post_url = format!("/v1/projects/{}/messages:send", self.project_id);
253            let request_body = serde_json::to_string(&send_request)?;
254
255            body.extend_from_slice(b"--");
256            body.extend_from_slice(boundary.as_bytes());
257            body.extend_from_slice(b"\r\n");
258            body.extend_from_slice(b"Content-Type: application/http\r\n");
259            body.extend_from_slice(b"Content-Transfer-Encoding: binary\r\n\r\n");
260            body.extend_from_slice(b"POST ");
261            body.extend_from_slice(post_url.as_bytes());
262            body.extend_from_slice(b"\r\n");
263            body.extend_from_slice(b"Content-Type: application/json\r\n");
264            body.extend_from_slice(b"\r\n");
265            body.extend_from_slice(request_body.as_bytes());
266            body.extend_from_slice(b"\r\n");
267        }
268
269        body.extend_from_slice(b"--");
270        body.extend_from_slice(boundary.as_bytes());
271        body.extend_from_slice(b"--\r\n");
272
273        Ok(body)
274    }
275
276    fn parse_multipart_response(&self, body: &str, boundary: &str) -> Result<Vec<SendResponse>, MessagingError> {
277        let boundary = format!("--{}", boundary);
278        let parts: Vec<&str> = body.split(&boundary)
279            .filter(|p| !p.trim().is_empty() && p.trim() != "--")
280            .collect();
281        let mut responses = Vec::new();
282
283        for part in parts {
284            let http_part = part.trim();
285
286            if let Some(inner_response_start) = http_part.find("\r\n\r\n") {
287                let inner_response = &http_part[inner_response_start + 4..];
288
289                if let Some(json_start) = inner_response.find("\r\n\r\n") {
290                    let json_body = inner_response[json_start + 4..].trim();
291
292                    if json_body.is_empty() {
293                        return Err(MessagingError::MultipartError("Empty JSON body in response part".to_string()));
294                    }
295
296                    let status_line = inner_response.lines().next().unwrap_or("");
297                    if status_line.contains("200 OK") {
298                        match serde_json::from_str::<SendResponseInternal>(json_body) {
299                            Ok(send_response) => responses.push(SendResponse {
300                                success: true,
301                                message_id: Some(send_response.name),
302                                error: None,
303                            }),
304                            Err(_) => return Err(MessagingError::MultipartError("Failed to parse successful response part".to_string())),
305                        }
306                    } else { // It's an error response
307                         match serde_json::from_str::<serde_json::Value>(json_body) {
308                            Ok(error_response) => responses.push(SendResponse {
309                                success: false,
310                                message_id: None,
311                                error: Some(error_response.to_string()),
312                            }),
313                            Err(_) => return Err(MessagingError::MultipartError("Failed to parse error response part".to_string())),
314                        }
315                    }
316                } else {
317                     return Err(MessagingError::MultipartError("Invalid inner HTTP response format".to_string()));
318                }
319            } else {
320                return Err(MessagingError::MultipartError("Invalid multipart part format".to_string()));
321            }
322        }
323
324        Ok(responses)
325    }
326
327    /// Sends a multicast message to all specified tokens.
328    ///
329    /// This is a wrapper around `send_each` that constructs individual messages for each token.
330    ///
331    /// # Arguments
332    ///
333    /// * `message` - The `MulticastMessage` containing tokens and payload.
334    /// * `dry_run` - If true, the messages will be validated but not sent.
335    pub async fn send_each_for_multicast(&self, message: &MulticastMessage, dry_run: bool) -> Result<BatchResponse, MessagingError> {
336        let messages: Vec<Message> = message.tokens.iter().map(|token| {
337            Message {
338                token: Some(token.clone()),
339                data: message.data.clone(),
340                notification: message.notification.clone(),
341                android: message.android.clone(),
342                webpush: message.webpush.clone(),
343                apns: message.apns.clone(),
344                fcm_options: message.fcm_options.clone(),
345                ..Default::default()
346            }
347        }).collect();
348
349        self.send_each(&messages, dry_run).await
350    }
351
352    /// Subscribes a list of tokens to a topic.
353    ///
354    /// # Arguments
355    ///
356    /// * `tokens` - A list of device registration tokens.
357    /// * `topic` - The name of the topic.
358    pub async fn subscribe_to_topic(&self, tokens: &[&str], topic: &str) -> Result<TopicManagementResponse, MessagingError> {
359        self.manage_topic(topic, tokens, true).await
360    }
361
362    /// Unsubscribes a list of tokens from a topic.
363    ///
364    /// # Arguments
365    ///
366    /// * `tokens` - A list of device registration tokens.
367    /// * `topic` - The name of the topic.
368    pub async fn unsubscribe_from_topic(&self, tokens: &[&str], topic: &str) -> Result<TopicManagementResponse, MessagingError> {
369        self.manage_topic(topic, tokens, false).await
370    }
371
372    async fn manage_topic(&self, topic: &str, tokens: &[&str], subscribe: bool) -> Result<TopicManagementResponse, MessagingError> {
373        let topic_path = if topic.starts_with("/topics/") {
374            topic.to_string()
375        } else {
376            format!("/topics/{}", topic)
377        };
378
379        let url = if subscribe {
380            "https://iid.googleapis.com/iid/v1:batchAdd"
381        } else {
382            "https://iid.googleapis.com/iid/v1:batchRemove"
383        };
384
385        let mut response_summary = TopicManagementResponse::default();
386
387        for (batch_idx, chunk) in tokens.chunks(1000).enumerate() {
388            let request = TopicManagementRequest {
389                to: topic_path.clone(),
390                registration_tokens: chunk,
391            };
392
393            let response = self.client
394                .post(url)
395                .header(header::CONTENT_TYPE, "application/json")
396                // Use access_token_header from AuthMiddleware, but the IID API also requires the standard header.
397                // The AuthMiddleware adds it automatically.
398                .header("access_token_auth", "true") // Some docs suggest this for IID, but standard Bearer should work.
399                .body(serde_json::to_vec(&request)?)
400                .send()
401                .await?;
402
403            if !response.status().is_success() {
404                return Err(MessagingError::ApiError(parse_error_response(response, "Topic management failed").await));
405            }
406
407            let api_response: TopicManagementApiResponse = response.json().await?;
408
409            if let Some(results) = api_response.results {
410                for (i, result) in results.iter().enumerate() {
411                     if let Some(error) = &result.error {
412                         response_summary.failure_count += 1;
413                         response_summary.errors.push(TopicManagementError {
414                             index: batch_idx * 1000 + i,
415                             reason: error.clone(),
416                         });
417                     } else {
418                         response_summary.success_count += 1;
419                     }
420                }
421            }
422        }
423
424        Ok(response_summary)
425    }
426}