1use crate::error::Error;
4use crate::error::Error::ResponseError;
5use crate::signer::Signer;
6use tokio::time::timeout;
7
8use crate::request::payload::PayloadLike;
9use crate::response::Response;
10use http::header::{AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE};
11use http_body_util::combinators::BoxBody;
12use http_body_util::{BodyExt, Full};
13use hyper::body::Bytes;
14use hyper::{self, StatusCode};
15use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
16use hyper_util::client::legacy::Client as HttpClient;
17use hyper_util::client::legacy::connect::HttpConnector;
18use hyper_util::rt::{TokioExecutor, TokioTimer};
19use std::convert::Infallible;
20use std::io::Read;
21use std::sync::Arc;
22use std::time::Duration;
23use std::{fmt, io};
24
25const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(20);
26
27type HyperConnector = HttpsConnector<HttpConnector>;
28
/// The APNs environment a client sends its requests to.
///
/// The `Display` implementation of this type renders the host name of the selected
/// environment.
///
/// The enum is a plain, fieldless tag, so it is `Copy` and can be compared and hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
    /// The production environment, served from `api.push.apple.com`.
    Production,
    /// The sandbox environment, served from `api.sandbox.push.apple.com`.
    Sandbox,
}
37
38impl fmt::Display for Endpoint {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 let host = match self {
41 Endpoint::Production => "api.push.apple.com",
42 Endpoint::Sandbox => "api.sandbox.push.apple.com",
43 };
44
45 write!(f, "{}", host)
46 }
47}
48
49#[derive(Debug, Clone)]
58pub struct Client {
59 options: ConnectionOptions,
60 http_client: HttpClient<HyperConnector, BoxBody<Bytes, Infallible>>,
61}
62
63#[derive(Debug, Clone)]
64pub struct ClientConfig {
67 pub endpoint: Endpoint,
69 pub request_timeout: Option<Duration>,
71 pub pool_idle_timeout: Option<Duration>,
73 pub http2_keep_alive_interval: Option<Duration>,
74 pub http2_keep_alive_while_idle: bool,
75}
76
77impl Default for ClientConfig {
78 fn default() -> Self {
79 Self {
80 endpoint: Endpoint::Production,
81 request_timeout: Some(DEFAULT_REQUEST_TIMEOUT),
82 pool_idle_timeout: None,
83 http2_keep_alive_interval: Some(Duration::from_secs(60 * 60)),
86 http2_keep_alive_while_idle: true,
87 }
88 }
89}
90
91impl ClientConfig {
92 pub fn new(endpoint: Endpoint) -> Self {
93 ClientConfig {
94 endpoint,
95 ..Default::default()
96 }
97 }
98}
99
100#[derive(Debug, Clone, Default)]
101struct ClientBuilder {
102 config: ClientConfig,
103 signer: Option<Signer>,
104 connector: Option<HyperConnector>,
105}
106
107impl ClientBuilder {
108 fn connector(mut self, connector: HyperConnector) -> Self {
109 self.connector = Some(connector);
110 self
111 }
112
113 fn signer(mut self, signer: Signer) -> Self {
114 self.signer = Some(signer);
115 self
116 }
117
118 fn config(mut self, config: ClientConfig) -> Self {
119 self.config = config;
120 self
121 }
122
123 fn build(self) -> Result<Client, Error> {
124 let ClientBuilder {
125 config:
126 ClientConfig {
127 endpoint,
128 request_timeout,
129 pool_idle_timeout,
130 http2_keep_alive_interval,
131 http2_keep_alive_while_idle,
132 },
133 signer,
134 connector,
135 } = self;
136
137 let connector = if let Some(connector) = connector {
138 connector
139 } else {
140 default_connector()?
141 };
142
143 let http_client = HttpClient::builder(TokioExecutor::new())
144 .pool_idle_timeout(pool_idle_timeout)
145 .http2_only(true)
146 .http2_keep_alive_interval(http2_keep_alive_interval)
147 .http2_keep_alive_while_idle(http2_keep_alive_while_idle)
148 .timer(TokioTimer::new())
149 .build(connector);
150
151 Ok(Client {
152 http_client,
153 options: ConnectionOptions::new(endpoint, signer, request_timeout),
154 })
155 }
156}
157
158#[derive(Debug, Clone)]
159struct ConnectionOptions {
160 endpoint: Endpoint,
161 request_timeout: Duration,
162 signer: Option<Signer>,
163}
164
165impl ConnectionOptions {
166 fn new(endpoint: Endpoint, signer: Option<Signer>, request_timeout: Option<Duration>) -> Self {
167 let request_timeout = request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT);
168 Self {
169 endpoint,
170 request_timeout,
171 signer,
172 }
173 }
174}
175
176impl Client {
177 fn builder() -> ClientBuilder {
180 ClientBuilder::default()
181 }
182
183 pub fn certificate<R>(certificate: &mut R, password: &str, config: ClientConfig) -> Result<Client, Error>
185 where
186 R: Read,
187 {
188 #[cfg(feature = "aws-lc-rs")]
189 fn create_connector(
190 certificate_bytes: &[u8],
191 password: &str,
192 ) -> Result<HttpsConnector<HttpConnector>, Error> {
193 let (cert_pem, key_pem) = crate::pkcs12::parse_pkcs12(certificate_bytes, password)?;
195 client_cert_connector(&cert_pem, &key_pem)
198 }
199
200 #[cfg(all(not(feature = "aws-lc-rs"), feature = "openssl"))]
201 fn create_connector(
202 certificate_bytes: &[u8],
203 password: &str,
204 ) -> Result<HttpsConnector<HttpConnector>, Error> {
205 let pkcs = openssl::pkcs12::Pkcs12::from_der(certificate_bytes)?.parse2(password)?;
206 let Some((cert, pkey)) = pkcs.cert.zip(pkcs.pkey) else {
207 return Err(Error::InvalidCertificate);
208 };
209 client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)
210 }
211
212 let certificate_bytes = {
214 let mut data = Vec::<u8>::new();
215 certificate.read_to_end(&mut data)?;
216 data
217 };
218
219 let connector = create_connector(certificate_bytes.as_ref(), password)?;
220 Self::builder().connector(connector).config(config).build()
221 }
222
223 pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], config: ClientConfig) -> Result<Client, Error> {
227 let connector = client_cert_connector(cert_pem, key_pem)?;
228
229 Self::builder().config(config).connector(connector).build()
230 }
231
232 pub fn token<S, T, R>(pkcs8_pem: R, key_id: S, team_id: T, config: ClientConfig) -> Result<Client, Error>
237 where
238 S: Into<String>,
239 T: Into<String>,
240 R: Read,
241 {
242 let signature_ttl = Duration::from_secs(60 * 55);
243 let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?;
244
245 Self::builder().config(config).signer(signer).build()
246 }
247
248 #[cfg_attr(feature = "tracing", ::tracing::instrument)]
252 pub async fn send<T: PayloadLike>(&self, payload: T) -> Result<Response, Error> {
253 let request = self.build_request(payload)?;
254 let requesting = self.http_client.request(request);
255
256 let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else {
257 return Err(Error::RequestTimeout(self.options.request_timeout.as_secs()));
258 };
259
260 let response = response_result?;
261
262 let header_map = response.headers();
263
264 fn get_header_key_opt(header_map: &http::HeaderMap, key: &'static str) -> Option<String> {
265 header_map
266 .get(key)
267 .and_then(|s| s.to_str().ok())
268 .map(String::from)
269 }
270
271 let apns_id = get_header_key_opt(header_map, "apns-id");
272
273 let apns_unique_id = if matches!(self.options.endpoint, Endpoint::Sandbox) {
274 get_header_key_opt(header_map, "apns-unique-id")
275 } else {
276 None
277 };
278
279 match response.status() {
280 StatusCode::OK => Ok(Response {
281 apns_id,
282 apns_unique_id,
283 error: None,
284 code: response.status().as_u16(),
285 }),
286 status => {
287 let body = response.into_body().collect().await?;
288
289 Err(ResponseError(Response {
290 apns_id,
291 apns_unique_id,
292 error: serde_json::from_slice(&body.to_bytes()).ok(),
293 code: status.as_u16(),
294 }))
295 }
296 }
297 }
298
299 fn build_request<T: PayloadLike>(
300 &self,
301 payload: T,
302 ) -> Result<hyper::Request<BoxBody<Bytes, Infallible>>, Error> {
303 let path = format!(
304 "https://{}/3/device/{}",
305 self.options.endpoint,
306 payload.get_device_token()
307 );
308
309 let mut builder = hyper::Request::builder()
310 .uri(&path)
311 .method("POST")
312 .header(CONTENT_TYPE, "application/json");
313
314 let options = payload.get_options();
315 if let Some(ref apns_priority) = options.apns_priority {
316 builder = builder.header("apns-priority", apns_priority.to_string().as_bytes());
317 }
318 if let Some(apns_id) = options.apns_id {
319 builder = builder.header("apns-id", apns_id.as_bytes());
320 }
321 if let Some(apns_push_type) = options.apns_push_type.as_ref() {
322 builder = builder.header("apns-push-type", apns_push_type.to_string().as_bytes());
323 }
324 if let Some(ref apns_expiration) = options.apns_expiration {
325 builder = builder.header("apns-expiration", apns_expiration.to_string().as_bytes());
326 }
327 if let Some(ref apns_collapse_id) = options.apns_collapse_id {
328 builder = builder.header("apns-collapse-id", apns_collapse_id.value.as_bytes());
329 }
330 if let Some(apns_topic) = options.apns_topic {
331 builder = builder.header("apns-topic", apns_topic.as_bytes());
332 }
333 if let Some(ref signer) = self.options.signer {
334 let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?;
335
336 builder = builder.header(AUTHORIZATION, auth.as_bytes());
337 }
338
339 let payload_json = payload.to_json_string()?;
340 builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes());
341
342 let request_body = Full::from(payload_json.into_bytes()).boxed();
343 builder.body(request_body).map_err(Error::BuildRequestError)
344 }
345}
346
347#[cfg(feature = "aws-lc-rs")]
348fn default_crypto_provider() -> Arc<rustls::crypto::CryptoProvider> {
349 Arc::new(rustls::crypto::aws_lc_rs::default_provider())
350}
351
352#[cfg(all(not(feature = "aws-lc-rs"), feature = "openssl"))]
353fn default_crypto_provider() -> Arc<rustls::crypto::CryptoProvider> {
354 Arc::new(rustls_openssl::default_provider())
355}
356
357#[cfg(all(not(feature = "aws-lc-rs"), not(feature = "openssl")))]
358fn default_crypto_provider() -> Arc<rustls::crypto::CryptoProvider> {
359 panic!("No provider set");
360}
361
362fn client_config_builder()
365-> Result<rustls::ConfigBuilder<rustls::ClientConfig, rustls::client::WantsClientCert>, Error> {
366 use hyper_rustls::ConfigBuilderExt as _;
367 let provider = rustls::crypto::CryptoProvider::get_default()
370 .cloned()
371 .unwrap_or_else(default_crypto_provider);
372
373 Ok(rustls::client::ClientConfig::builder_with_provider(provider)
374 .with_safe_default_protocol_versions()?
375 .try_with_platform_verifier()?)
376}
377
378fn default_connector() -> Result<HyperConnector, Error> {
380 let config = client_config_builder()?.with_no_client_auth();
381
382 Ok(HttpsConnectorBuilder::new()
383 .with_tls_config(config)
384 .https_only()
385 .enable_http2()
386 .build())
387}
388
389fn client_cert_connector(cert_pem: &[u8], key_pem: &[u8]) -> Result<HyperConnector, Error> {
390 use rustls_pki_types::{CertificateDer, PrivatePkcs8KeyDer, pem::PemObject};
391
392 let cert_error_fn = |e: rustls_pki_types::pem::Error| io::Error::new(io::ErrorKind::InvalidData, e);
393
394 let key = PrivatePkcs8KeyDer::from_pem_slice(key_pem).map_err(cert_error_fn)?;
395
396 let cert_chain = CertificateDer::pem_slice_iter(cert_pem)
397 .collect::<Result<Vec<_>, _>>()
398 .map_err(cert_error_fn)?;
399
400 let config = client_config_builder()?.with_client_auth_cert(cert_chain, key.into())?;
401
402 Ok(HttpsConnectorBuilder::new()
403 .with_tls_config(config)
404 .https_only()
405 .enable_http2()
406 .build())
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PushType;
    use crate::request::notification::DefaultNotificationBuilder;
    use crate::request::notification::NotificationBuilder;
    use crate::request::notification::{CollapseId, NotificationOptions, Priority};
    use crate::signer::Signer;
    use http::header::{AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE};
    use hyper::Method;

    const PRIVATE_KEY: &str = "-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg8g/n6j9roKvnUkwu
lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3
jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
-----END PRIVATE KEY-----";

    /// Device token used by every test that needs a well-formed one.
    const DEVICE_TOKEN: &str = "a_test_id";

    /// Request type produced by `Client::build_request`.
    type BuiltRequest = hyper::Request<BoxBody<Bytes, Infallible>>;

    /// Client with the default configuration, no signer and the default connector.
    fn default_client() -> Client {
        Client::builder().build().unwrap()
    }

    /// Request the default client builds for `payload`.
    fn request_for<P: PayloadLike>(payload: P) -> BuiltRequest {
        default_client().build_request(payload).unwrap()
    }

    #[test]
    fn test_production_request_uri() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert_eq!(
            "https://api.push.apple.com/3/device/a_test_id",
            &request.uri().to_string()
        );
    }

    #[test]
    fn test_sandbox_request_uri() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let sandbox_config = ClientConfig {
            endpoint: Endpoint::Sandbox,
            ..Default::default()
        };
        let client = Client::builder().config(sandbox_config).build().unwrap();
        let request = client.build_request(payload).unwrap();

        assert_eq!(
            "https://api.sandbox.push.apple.com/3/device/a_test_id",
            &request.uri().to_string()
        );
    }

    #[test]
    fn test_request_method() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert_eq!(&Method::POST, request.method());
    }

    #[test]
    fn test_request_invalid() {
        // A device token made of line breaks cannot be part of a valid URI.
        let payload = DefaultNotificationBuilder::new().build("\r\n", Default::default());
        let outcome = default_client().build_request(payload);

        assert!(matches!(outcome, Err(Error::BuildRequestError(_))));
    }

    #[test]
    fn test_request_content_type() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
    }

    #[test]
    fn test_request_content_length() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload.clone());

        let expected = payload.to_json_string().unwrap().len().to_string();
        let actual = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn test_request_authorization_with_no_signer() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert!(request.headers().get(AUTHORIZATION).is_none());
    }

    #[test]
    fn test_request_authorization_with_a_signer() {
        let signer = Signer::new(
            PRIVATE_KEY.as_bytes(),
            "89AFRD1X22",
            "ASDFQWERTY",
            Duration::from_secs(100),
        )
        .unwrap();

        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let client = Client::builder().signer(signer).build().unwrap();
        let request = client.build_request(payload).unwrap();

        assert!(request.headers().get(AUTHORIZATION).is_some());
    }

    #[test]
    fn test_request_with_background_type() {
        let options = NotificationOptions {
            apns_push_type: Some(PushType::Background),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("background", request.headers().get("apns-push-type").unwrap());
    }

    #[test]
    fn test_request_with_default_priority() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert!(request.headers().get("apns-priority").is_none());
    }

    #[test]
    fn test_request_with_normal_priority() {
        let options = NotificationOptions {
            apns_priority: Some(Priority::Normal),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("5", request.headers().get("apns-priority").unwrap());
    }

    #[test]
    fn test_request_with_high_priority() {
        let options = NotificationOptions {
            apns_priority: Some(Priority::High),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("10", request.headers().get("apns-priority").unwrap());
    }

    #[test]
    fn test_request_with_default_apns_id() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert!(request.headers().get("apns-id").is_none());
    }

    #[test]
    fn test_request_with_an_apns_id() {
        let options = NotificationOptions {
            apns_id: Some("a-test-apns-id"),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("a-test-apns-id", request.headers().get("apns-id").unwrap());
    }

    #[test]
    fn test_request_with_default_apns_expiration() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert!(request.headers().get("apns-expiration").is_none());
    }

    #[test]
    fn test_request_with_an_apns_expiration() {
        let options = NotificationOptions {
            apns_expiration: Some(420),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("420", request.headers().get("apns-expiration").unwrap());
    }

    #[test]
    fn test_request_with_default_apns_collapse_id() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert!(request.headers().get("apns-collapse-id").is_none());
    }

    #[test]
    fn test_request_with_an_apns_collapse_id() {
        let options = NotificationOptions {
            apns_collapse_id: Some(CollapseId::new("a_collapse_id").unwrap()),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("a_collapse_id", request.headers().get("apns-collapse-id").unwrap());
    }

    #[test]
    fn test_request_with_default_apns_topic() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload);

        assert!(request.headers().get("apns-topic").is_none());
    }

    #[test]
    fn test_request_with_an_apns_topic() {
        let options = NotificationOptions {
            apns_topic: Some("a_topic"),
            ..Default::default()
        };
        let request = request_for(DefaultNotificationBuilder::new().build(DEVICE_TOKEN, options));

        assert_eq!("a_topic", request.headers().get("apns-topic").unwrap());
    }

    #[tokio::test]
    async fn test_request_body() {
        let payload = DefaultNotificationBuilder::new().build(DEVICE_TOKEN, Default::default());
        let request = request_for(payload.clone());

        let body = request.into_body().collect().await.unwrap().to_bytes();
        let body_str = String::from_utf8(body.to_vec()).unwrap();

        assert_eq!(payload.to_json_string().unwrap(), body_str);
    }

    #[tokio::test]
    async fn test_cert_parts() -> Result<(), Error> {
        let key = include_str!("../test_cert/test.key").as_bytes();
        let cert = include_str!("../test_cert/test.crt").as_bytes();

        let client = Client::certificate_parts(cert, key, ClientConfig::default())?;
        // Certificate-authenticated clients carry no token signer.
        assert!(client.options.signer.is_none());
        Ok(())
    }
}