1use std::time::Duration;
2
3use reqwest::tls::Version;
4#[cfg(feature = "rustls")]
5use reqwest::{Certificate, Identity};
6use reqwest_middleware::ClientWithMiddleware;
7use serde::Serialize;
8use url::Url;
9use uuid::Uuid;
10
11use crate::endpoint::Endpoint;
12use crate::header::APNS_ID;
13use crate::payload::*;
14use crate::reason::Reason;
15use crate::request::Request;
16use crate::result::{Error, Result};
17#[cfg(feature = "jwt")]
18use crate::token::TokenFactory;
19
/// Default `User-Agent` value, in the form `<package name>/<package version>`.
///
/// The string is assembled at compile time from Cargo's `CARGO_PKG_NAME` and
/// `CARGO_PKG_VERSION` environment variables and is what
/// `ClientBuilder::default()` stores in its `user_agent` field.
pub const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
22
/// Credentials the client presents to the server.
///
/// Which variants exist depends on the enabled crate features: `Certificate`
/// requires `rustls`, `Token` requires `jwt`. All data is borrowed for `'a`.
#[cfg(any(feature = "rustls", feature = "jwt"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "jwt"))))]
#[derive(Debug, Clone)]
pub enum Authentication<'a> {
    /// TLS client-certificate authentication.
    ///
    /// `client_pem` is handed to `reqwest::Identity::from_pem` and installed
    /// as the HTTP client's TLS identity when the client is built; see the
    /// reqwest documentation for the PEM contents that function expects.
    #[cfg(feature = "rustls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    Certificate { client_pem: &'a [u8] },

    /// Token-based authentication.
    ///
    /// The three fields are forwarded, in this order, to `TokenFactory::new`
    /// when the client is built; the token that factory produces is sent as a
    /// bearer token with every request.
    // NOTE(review): presumably `key_id` / `team_id` are the provider-issued key
    // and team identifiers and `key_pem` is the PEM-encoded signing key —
    // confirm against `TokenFactory`.
    #[cfg(feature = "jwt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "jwt")))]
    Token {
        key_id: &'a str,
        key_pem: &'a [u8],
        team_id: &'a str,
    },
}
50
/// An additional root certificate to trust, in one of two encodings.
///
/// When set on the builder, the certificate is parsed and added to the HTTP
/// client's root certificates while the client is being built.
#[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
#[derive(Debug, Clone)]
pub enum CertificateAuthority<'a> {
    /// PEM-encoded certificate bytes, parsed with `reqwest::Certificate::from_pem`.
    Pem(&'a [u8]),
    /// DER-encoded certificate bytes, parsed with `reqwest::Certificate::from_der`.
    Der(&'a [u8]),
}
59
/// Configuration from which a `Client` is built.
///
/// All fields are public and may be set directly; `Default` provides the
/// starting values.
#[derive(Debug, Clone)]
pub struct ClientBuilder<'a> {
    /// Endpoint whose URL becomes the base URL that requests are sent to.
    pub endpoint: Endpoint,
    /// Value used for the HTTP client's `User-Agent`.
    pub user_agent: &'a str,

    /// Optional extra root certificate to trust (only with the `rustls` feature).
    #[cfg(feature = "rustls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    pub ca: Option<CertificateAuthority<'a>>,

    /// Optional credentials (only with the `rustls` and/or `jwt` features).
    #[cfg(any(feature = "rustls", feature = "jwt"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "jwt"))))]
    pub authentication: Option<Authentication<'a>>,
}
74
75impl<'a> Default for ClientBuilder<'a> {
76 fn default() -> Self {
77 Self {
78 endpoint: Endpoint::default(),
79 user_agent: USER_AGENT,
80
81 #[cfg(feature = "rustls")]
82 ca: None,
83
84 #[cfg(any(feature = "rustls", feature = "jwt"))]
85 authentication: None,
86 }
87 }
88}
89
90impl<'a> ClientBuilder<'a> {
91 pub fn new() -> Self {
93 Default::default()
94 }
95
96 pub fn build(self) -> Result<Client> {
98 let client = self.reqwest_client_builder()?.build();
99 self.with_reqwest_middleware_client(client)
100 }
101
102 pub fn build_with_middleware<F>(self, f: F) -> Result<Client>
112 where
113 F: FnOnce(reqwest_middleware::ClientBuilder) -> Result<reqwest_middleware::ClientBuilder>,
114 {
115 let builder = self.reqwest_client_builder()?;
116 let builder = f(builder)?;
117 self.with_reqwest_middleware_client(builder.build())
118 }
119
120 fn with_reqwest_middleware_client(&self, client: ClientWithMiddleware) -> Result<Client> {
121 let base_url = self.endpoint.as_url().clone();
122
123 #[cfg(feature = "jwt")]
124 let token_factory = if let Some(Authentication::Token {
125 key_id,
126 key_pem,
127 team_id,
128 }) = self.authentication
129 {
130 Some(TokenFactory::new(key_id, key_pem, team_id)?)
131 } else {
132 None
133 };
134
135 Ok(Client {
136 base_url,
137 client,
138 #[cfg(feature = "jwt")]
139 token_factory,
140 })
141 }
142
143 fn reqwest_client_builder(&self) -> Result<reqwest_middleware::ClientBuilder> {
144 #[allow(unused_mut)]
145 let mut builder = reqwest::Client::builder()
146 .user_agent(self.user_agent)
147 .pool_idle_timeout(None)
148 .http2_keep_alive_interval(Some(Duration::from_secs(60 * 60)))
149 .http2_keep_alive_timeout(Duration::from_secs(60))
150 .http2_keep_alive_while_idle(true)
151 .min_tls_version(Version::TLS_1_2);
152
153 #[cfg(not(feature = "http1"))]
154 {
155 builder = builder.http2_prior_knowledge();
156 }
157
158 #[cfg(feature = "rustls")]
159 {
160 builder = builder.use_rustls_tls();
162
163 if let Some(ca) = &self.ca {
165 let cert = match ca {
166 CertificateAuthority::Pem(pem) => Certificate::from_pem(pem)?,
167 CertificateAuthority::Der(der) => Certificate::from_der(der)?,
168 };
169 builder = builder.add_root_certificate(cert);
170 }
171
172 if let Some(Authentication::Certificate { client_pem }) = self.authentication {
174 let identity = Identity::from_pem(client_pem)?;
175 builder = builder.identity(identity);
176 }
177 }
178
179 let client = builder.build()?;
180 let builder = reqwest_middleware::ClientBuilder::new(client);
181 Ok(builder)
182 }
183}
184
/// Client that posts notification requests to a configured endpoint.
///
/// Instances are produced by `ClientBuilder`.
pub struct Client {
    /// Base URL that each request's device token is joined onto.
    base_url: Url,
    /// HTTP client (with middleware) used to send requests.
    client: ClientWithMiddleware,

    /// When present, supplies the bearer token attached to every request.
    #[cfg(feature = "jwt")]
    token_factory: Option<TokenFactory>,
}
198
199impl Client {
200 pub fn builder<'a>() -> ClientBuilder<'a> {
202 ClientBuilder::new()
203 }
204
205 pub async fn post<T>(&self, request: Request<T>) -> Result<Uuid>
207 where
208 T: Serialize,
209 {
210 let url = self.base_url.join(&request.device_token)?;
211 let payload_size_limit = request.push_type.payload_size_limit();
212 let (headers, payload): (_, Payload<T>) = request.try_into()?;
213
214 let body = serde_json::to_vec(&payload)?;
215 if body.len() > payload_size_limit {
216 return Err(Error::PayloadTooLarge {
217 size: body.len(),
218 limit: payload_size_limit,
219 });
220 }
221
222 let mut req = self.client.post(url).body(body);
223 for (name, value) in headers {
224 if let Some(name) = name {
225 req = req.header(name, value);
226 }
227 }
228
229 #[cfg(feature = "jwt")]
230 if let Some(token_factory) = &self.token_factory {
231 let jwt = token_factory.get()?;
232 req = req.bearer_auth(jwt);
233 }
234
235 let res = req.send().await?;
236
237 if let Err(err) = res.error_for_status_ref() {
238 if let Ok(reason) = res.json::<Reason>().await {
239 Err(reason.into())
240 } else {
241 Err(err.into())
242 }
243 } else {
244 let apns_id = res
245 .headers()
246 .get(&APNS_ID)
247 .and_then(|v| v.to_str().ok())
248 .and_then(|s| s.parse().ok())
249 .unwrap_or_default();
250 Ok(apns_id)
251 }
252 }
253}