apple_apns/
client.rs

1use std::time::Duration;
2
3use reqwest::tls::Version;
4#[cfg(feature = "rustls")]
5use reqwest::{Certificate, Identity};
6use reqwest_middleware::ClientWithMiddleware;
7use serde::Serialize;
8use url::Url;
9use uuid::Uuid;
10
11use crate::endpoint::Endpoint;
12use crate::header::APNS_ID;
13use crate::payload::*;
14use crate::reason::Reason;
15use crate::request::Request;
16use crate::result::{Error, Result};
17#[cfg(feature = "jwt")]
18use crate::token::TokenFactory;
19
20/// Default user agent.
21pub const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
22
/// Authentication options.
///
/// APNs accepts either a TLS client certificate ([`Authentication::Certificate`])
/// or a provider token sent as a bearer token on every request
/// ([`Authentication::Token`]).
///
/// The [`Debug`](std::fmt::Debug) output redacts all key material.
#[cfg(any(feature = "rustls", feature = "jwt"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "jwt"))))]
#[derive(Clone)]
pub enum Authentication<'a> {
    /// Certificate-based authentication: the provider certificate is presented
    /// to APNs while the TLS connection is being established. For more
    /// information, see [Establishing a Certificate-Based Connection to
    /// APNs](https://developer.apple.com/documentation/usernotifications/setting_up_a_remote_notification_server/establishing_a_certificate-based_connection_to_apns).
    #[cfg(feature = "rustls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    Certificate {
        /// PEM-encoded client identity, passed to `reqwest::Identity::from_pem`;
        /// it must contain the private key and at least one certificate.
        client_pem: &'a [u8],
    },

    /// Token-based authentication: each request carries an
    /// `authorization: bearer <provider_token>` header, where the provider
    /// token is a JWT built from the fields below. APNs ignores this header
    /// when certificate-based authentication is used. For more information, see
    /// [Establishing a Token-Based Connection to
    /// APNs](https://developer.apple.com/documentation/usernotifications/setting_up_a_remote_notification_server/establishing_a_token-based_connection_to_apns).
    #[cfg(feature = "jwt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "jwt")))]
    Token {
        /// Key ID of the token signing key.
        key_id: &'a str,
        /// PEM-encoded private key used to sign the provider token.
        key_pem: &'a [u8],
        /// Team ID of the developer account issuing the token.
        team_id: &'a str,
    },
}

#[cfg(any(feature = "rustls", feature = "jwt"))]
impl std::fmt::Debug for Authentication<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Key material is redacted on purpose: `ClientBuilder` derives `Debug`,
        // and a derived impl here would dump private keys into logs.
        match self {
            #[cfg(feature = "rustls")]
            Self::Certificate { .. } => f
                .debug_struct("Certificate")
                .field("client_pem", &format_args!("<redacted>"))
                .finish(),
            #[cfg(feature = "jwt")]
            Self::Token {
                key_id, team_id, ..
            } => f
                .debug_struct("Token")
                .field("key_id", key_id)
                .field("key_pem", &format_args!("<redacted>"))
                .field("team_id", team_id)
                .finish(),
        }
    }
}
50
/// Certificate authority options.
///
/// The certificate is added as an extra trusted root via
/// `reqwest::ClientBuilder::add_root_certificate`.
#[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
#[derive(Debug, Clone)]
pub enum CertificateAuthority<'a> {
    /// PEM-encoded certificate, parsed with `reqwest::Certificate::from_pem`.
    Pem(&'a [u8]),
    /// DER-encoded certificate, parsed with `reqwest::Certificate::from_der`.
    Der(&'a [u8]),
}
59
60/// [`Client`] builder.
61#[derive(Debug, Clone)]
62pub struct ClientBuilder<'a> {
63    pub endpoint: Endpoint,
64    pub user_agent: &'a str,
65
66    #[cfg(feature = "rustls")]
67    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
68    pub ca: Option<CertificateAuthority<'a>>,
69
70    #[cfg(any(feature = "rustls", feature = "jwt"))]
71    #[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "jwt"))))]
72    pub authentication: Option<Authentication<'a>>,
73}
74
75impl<'a> Default for ClientBuilder<'a> {
76    fn default() -> Self {
77        Self {
78            endpoint: Endpoint::default(),
79            user_agent: USER_AGENT,
80
81            #[cfg(feature = "rustls")]
82            ca: None,
83
84            #[cfg(any(feature = "rustls", feature = "jwt"))]
85            authentication: None,
86        }
87    }
88}
89
90impl<'a> ClientBuilder<'a> {
91    /// Creates a new [`ClientBuilder`].
92    pub fn new() -> Self {
93        Default::default()
94    }
95
96    /// Builds a `Client`.
97    pub fn build(self) -> Result<Client> {
98        let client = self.reqwest_client_builder()?.build();
99        self.with_reqwest_middleware_client(client)
100    }
101
102    /// Builds a `Client` with middleware.
103    ///
104    /// ```rust
105    /// use reqwest_tracing::{SpanBackendWithUrl, TracingMiddleware};
106    ///
107    /// let _client = apple_apns::ClientBuilder::new().build_with_middleware(|builder| {
108    ///     Ok(builder.with(TracingMiddleware::<SpanBackendWithUrl>::new()))
109    /// }).unwrap();
110    /// ```
111    pub fn build_with_middleware<F>(self, f: F) -> Result<Client>
112    where
113        F: FnOnce(reqwest_middleware::ClientBuilder) -> Result<reqwest_middleware::ClientBuilder>,
114    {
115        let builder = self.reqwest_client_builder()?;
116        let builder = f(builder)?;
117        self.with_reqwest_middleware_client(builder.build())
118    }
119
120    fn with_reqwest_middleware_client(&self, client: ClientWithMiddleware) -> Result<Client> {
121        let base_url = self.endpoint.as_url().clone();
122
123        #[cfg(feature = "jwt")]
124        let token_factory = if let Some(Authentication::Token {
125            key_id,
126            key_pem,
127            team_id,
128        }) = self.authentication
129        {
130            Some(TokenFactory::new(key_id, key_pem, team_id)?)
131        } else {
132            None
133        };
134
135        Ok(Client {
136            base_url,
137            client,
138            #[cfg(feature = "jwt")]
139            token_factory,
140        })
141    }
142
143    fn reqwest_client_builder(&self) -> Result<reqwest_middleware::ClientBuilder> {
144        #[allow(unused_mut)]
145        let mut builder = reqwest::Client::builder()
146            .user_agent(self.user_agent)
147            .pool_idle_timeout(None)
148            .http2_keep_alive_interval(Some(Duration::from_secs(60 * 60)))
149            .http2_keep_alive_timeout(Duration::from_secs(60))
150            .http2_keep_alive_while_idle(true)
151            .min_tls_version(Version::TLS_1_2);
152
153        #[cfg(not(feature = "http1"))]
154        {
155            builder = builder.http2_prior_knowledge();
156        }
157
158        #[cfg(feature = "rustls")]
159        {
160            // Force rustls
161            builder = builder.use_rustls_tls();
162
163            // Add root certificate
164            if let Some(ca) = &self.ca {
165                let cert = match ca {
166                    CertificateAuthority::Pem(pem) => Certificate::from_pem(pem)?,
167                    CertificateAuthority::Der(der) => Certificate::from_der(der)?,
168                };
169                builder = builder.add_root_certificate(cert);
170            }
171
172            // Configure certificate authentication
173            if let Some(Authentication::Certificate { client_pem }) = self.authentication {
174                let identity = Identity::from_pem(client_pem)?;
175                builder = builder.identity(identity);
176            }
177        }
178
179        let client = builder.build()?;
180        let builder = reqwest_middleware::ClientBuilder::new(client);
181        Ok(builder)
182    }
183}
184
185/// Apple Push Notification service client.
186///
187/// The [`Client`] is safe to use from multiple threads. However, [`Client`]
188/// uses a [`std::sync::RwLock`] and is not [`Clone`]. To pass [`Client`] to
189/// multiple threads, use [`std::sync::Arc`] for OS threads, or [`std::rc::Rc`]
190/// for green threads.
191pub struct Client {
192    base_url: Url,
193    client: ClientWithMiddleware,
194
195    #[cfg(feature = "jwt")]
196    token_factory: Option<TokenFactory>,
197}
198
199impl Client {
200    /// Creates a [`ClientBuilder`].
201    pub fn builder<'a>() -> ClientBuilder<'a> {
202        ClientBuilder::new()
203    }
204
205    /// Sends a push notification and returns the APNS ID.
206    pub async fn post<T>(&self, request: Request<T>) -> Result<Uuid>
207    where
208        T: Serialize,
209    {
210        let url = self.base_url.join(&request.device_token)?;
211        let payload_size_limit = request.push_type.payload_size_limit();
212        let (headers, payload): (_, Payload<T>) = request.try_into()?;
213
214        let body = serde_json::to_vec(&payload)?;
215        if body.len() > payload_size_limit {
216            return Err(Error::PayloadTooLarge {
217                size: body.len(),
218                limit: payload_size_limit,
219            });
220        }
221
222        let mut req = self.client.post(url).body(body);
223        for (name, value) in headers {
224            if let Some(name) = name {
225                req = req.header(name, value);
226            }
227        }
228
229        #[cfg(feature = "jwt")]
230        if let Some(token_factory) = &self.token_factory {
231            let jwt = token_factory.get()?;
232            req = req.bearer_auth(jwt);
233        }
234
235        let res = req.send().await?;
236
237        if let Err(err) = res.error_for_status_ref() {
238            if let Ok(reason) = res.json::<Reason>().await {
239                Err(reason.into())
240            } else {
241                Err(err.into())
242            }
243        } else {
244            let apns_id = res
245                .headers()
246                .get(&APNS_ID)
247                .and_then(|v| v.to_str().ok())
248                .and_then(|s| s.parse().ok())
249                .unwrap_or_default();
250            Ok(apns_id)
251        }
252    }
253}