1use std::{
2 fmt,
3 io::Write,
4 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
5 sync::{Arc, OnceLock},
6 time::Duration,
7};
8
9use flate2::{Compression, write::GzEncoder};
10use hickory_resolver::{
11 ConnectionProvider, Resolver,
12 config::{ConnectionConfig, NameServerConfig, ResolverConfig},
13 net::runtime::TokioRuntimeProvider,
14};
15use miette::{IntoDiagnostic, Result, WrapErr};
16use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair};
17use reqwest::Url;
18use time::{Duration as TimeDuration, OffsetDateTime};
19use tokio::sync::RwLock;
20use tracing::debug;
21
22use crate::Redacted;
23
24pub const DEFAULT_CANOPY_URL: &str = "https://meta.tamanu.app";
25
26pub const TAILSCALE_URL: &str = "https://canopy.tail53aef.ts.net";
31
32const TAILSCALE_HOST: &str = "canopy.tail53aef.ts.net";
34
35const CANOPY_HARDCODED_V4: Ipv4Addr = Ipv4Addr::new(100, 99, 98, 97);
38const CANOPY_HARDCODED_V6: Ipv6Addr =
39 Ipv6Addr::new(0xfd7a, 0x115c, 0xa1e0, 0, 0, 0, 0x9337, 0xfb52);
40
41const CERT_VALIDITY_DAYS: i64 = 6;
46
47pub const CERT_RENEW_AFTER: Duration = Duration::from_secs(5 * 24 * 60 * 60);
52
53const TAILSCALE_PROBE_TIMEOUT: Duration = Duration::from_secs(5);
55
56pub type ClientBuilderFactory = Arc<dyn Fn() -> reqwest::ClientBuilder + Send + Sync>;
64
65#[derive(Debug, Clone)]
72pub struct CanopyHttpError {
73 pub status: reqwest::StatusCode,
75 pub path: String,
77 pub body: String,
79}
80
81impl fmt::Display for CanopyHttpError {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 write!(
84 f,
85 "canopy {} returned {}: {}",
86 self.path, self.status, self.body
87 )
88 }
89}
90
91impl std::error::Error for CanopyHttpError {}
92impl miette::Diagnostic for CanopyHttpError {}
93
94fn user_agent() -> &'static str {
102 static UA: OnceLock<String> = OnceLock::new();
103 UA.get_or_init(|| {
104 let os = sysinfo::System::long_os_version()
105 .or_else(sysinfo::System::name)
106 .unwrap_or_else(|| std::env::consts::OS.to_owned());
107 format!(
108 "bestool-canopy/{} ({os}; {})",
109 env!("CARGO_PKG_VERSION"),
110 sysinfo::System::cpu_arch(),
111 )
112 })
113}
114
115pub async fn tailscale_client(make_builder: &ClientBuilderFactory) -> Option<reqwest::Client> {
124 let tailscale_url = TAILSCALE_URL
125 .parse()
126 .expect("default tailscale URL is valid");
127 probe_tailscale(&tailscale_url, make_builder).await
128}
129
130pub struct CanopyClient {
141 base_url: Url,
144 tailscale_url: Url,
147 device_key: Option<Redacted<String>>,
148 make_builder: ClientBuilderFactory,
150 state: RwLock<State>,
151}
152
153enum State {
154 Tailscale(reqwest::Client),
155 Mtls(reqwest::Client),
156}
157
158impl State {
159 fn is_tailscale(&self) -> bool {
160 matches!(self, State::Tailscale(_))
161 }
162
163 fn http(&self) -> reqwest::Client {
164 match self {
165 State::Tailscale(http) | State::Mtls(http) => http.clone(),
166 }
167 }
168}
169
170impl fmt::Debug for CanopyClient {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 f.debug_struct("CanopyClient").finish_non_exhaustive()
173 }
174}
175
176impl CanopyClient {
177 pub async fn new(
188 device_key_pem: Option<&str>,
189 make_builder: impl Fn() -> reqwest::ClientBuilder + Send + Sync + 'static,
190 ) -> Result<Option<Self>> {
191 Self::with_urls(
192 DEFAULT_CANOPY_URL
193 .parse()
194 .expect("default canopy URL is valid"),
195 TAILSCALE_URL
196 .parse()
197 .expect("default tailscale URL is valid"),
198 device_key_pem,
199 make_builder,
200 )
201 .await
202 }
203
204 pub async fn with_urls(
211 base_url: Url,
212 tailscale_url: Url,
213 device_key_pem: Option<&str>,
214 make_builder: impl Fn() -> reqwest::ClientBuilder + Send + Sync + 'static,
215 ) -> Result<Option<Self>> {
216 let device_key = device_key_pem.map(|s| Redacted(s.to_owned()));
217 let make_builder: ClientBuilderFactory = Arc::new(make_builder);
218
219 if let Some(http) = probe_tailscale(&tailscale_url, &make_builder).await {
220 debug!("canopy: tailscale endpoint reachable, preferring it");
221 return Ok(Some(Self {
222 base_url,
223 tailscale_url,
224 device_key,
225 make_builder,
226 state: RwLock::new(State::Tailscale(http)),
227 }));
228 }
229
230 if let Some(pem) = device_key_pem {
231 debug!("canopy: tailscale unreachable, falling back to mTLS");
232 let http = build_mtls_http(&make_builder, pem)?;
233 return Ok(Some(Self {
234 base_url,
235 tailscale_url,
236 device_key,
237 make_builder,
238 state: RwLock::new(State::Mtls(http)),
239 }));
240 }
241
242 Ok(None)
243 }
244
245 pub async fn is_tailscale(&self) -> bool {
247 self.state.read().await.is_tailscale()
248 }
249
250 pub async fn refresh(&self) -> Result<()> {
254 if let Some(http) = probe_tailscale(&self.tailscale_url, &self.make_builder).await {
255 let mut state = self.state.write().await;
256 if !state.is_tailscale() {
257 debug!("canopy refresh: switching to tailscale path");
258 }
259 *state = State::Tailscale(http);
260 return Ok(());
261 }
262
263 if let Some(pem) = &self.device_key {
264 let http = build_mtls_http(&self.make_builder, &pem.0)?;
265 let mut state = self.state.write().await;
266 if state.is_tailscale() {
267 debug!("canopy refresh: tailscale dropped, falling back to mTLS");
268 }
269 *state = State::Mtls(http);
270 return Ok(());
271 }
272
273 debug!("canopy refresh: no auth path available, keeping current state");
274 Ok(())
275 }
276
277 pub async fn renew(&self) -> Result<()> {
283 let Some(pem) = &self.device_key else {
284 return Ok(());
285 };
286 let mut state = self.state.write().await;
287 if state.is_tailscale() {
288 return Ok(());
289 }
290 *state = State::Mtls(build_mtls_http(&self.make_builder, &pem.0)?);
291 Ok(())
292 }
293
294 async fn endpoint_url(&self, path: &str) -> Result<(reqwest::Client, Url)> {
299 let state = self.state.read().await;
300 let url = match &*state {
301 State::Tailscale(_) => self
302 .tailscale_url
303 .join(&format!("/public{path}"))
304 .into_diagnostic()
305 .wrap_err_with(|| format!("building tailscale /public{path} URL"))?,
306 State::Mtls(_) => self
307 .base_url
308 .join(path)
309 .into_diagnostic()
310 .wrap_err_with(|| format!("building {path} URL"))?,
311 };
312 Ok((state.http(), url))
313 }
314
315 async fn send_call<B: serde::Serialize + ?Sized>(
322 &self,
323 method: reqwest::Method,
324 path: &str,
325 body: Option<&B>,
326 ) -> Result<reqwest::Response> {
327 let (http, url) = self.endpoint_url(path).await?;
328 debug!(%url, %method, "canopy request");
329 let mut req = http.request(method, url);
330 if let Some(body) = body {
331 let raw = serde_json::to_vec(body)
332 .into_diagnostic()
333 .wrap_err_with(|| format!("serialising canopy {path} body"))?;
334 let compressed = gzip_bytes(&raw)
335 .into_diagnostic()
336 .wrap_err_with(|| format!("gzipping canopy {path} body"))?;
337 req = req
338 .header(reqwest::header::CONTENT_TYPE, "application/json")
339 .header(reqwest::header::CONTENT_ENCODING, "gzip")
340 .body(compressed);
341 }
342
343 let response = req
344 .send()
345 .await
346 .into_diagnostic()
347 .wrap_err_with(|| format!("calling canopy {path}"))?;
348
349 let status = response.status();
350 if !status.is_success() {
351 let body = response.text().await.unwrap_or_default();
352 return Err(miette::Report::new(CanopyHttpError {
353 status,
354 path: path.to_owned(),
355 body,
356 }));
357 }
358 Ok(response)
359 }
360
361 pub(crate) async fn call_json<B, R>(
363 &self,
364 method: reqwest::Method,
365 path: &str,
366 body: Option<&B>,
367 ) -> Result<R>
368 where
369 B: serde::Serialize + ?Sized,
370 R: serde::de::DeserializeOwned,
371 {
372 let response = self.send_call(method, path, body).await?;
373 response
374 .json::<R>()
375 .await
376 .into_diagnostic()
377 .wrap_err_with(|| format!("parsing canopy {path} response"))
378 }
379
380 pub(crate) async fn call_empty<B: serde::Serialize + ?Sized>(
382 &self,
383 method: reqwest::Method,
384 path: &str,
385 body: Option<&B>,
386 ) -> Result<()> {
387 self.send_call(method, path, body).await.map(drop)
388 }
389
390 #[cfg(feature = "raw-requests")]
396 pub async fn get(&self, tailscale_path: &str, mtls_path: &str) -> Result<reqwest::Response> {
397 let (http, url) = {
398 let state = self.state.read().await;
399 let url = match &*state {
400 State::Tailscale(_) => self
401 .tailscale_url
402 .join(tailscale_path)
403 .into_diagnostic()
404 .wrap_err("building tailscale GET URL")?,
405 State::Mtls(_) => self
406 .base_url
407 .join(mtls_path)
408 .into_diagnostic()
409 .wrap_err("building mTLS GET URL")?,
410 };
411 (state.http(), url)
412 };
413
414 debug!(%url, "GET via canopy");
415 http.get(url)
416 .send()
417 .await
418 .into_diagnostic()
419 .wrap_err("GET via canopy")
420 }
421
422 #[cfg(feature = "raw-requests")]
428 pub async fn request(
429 &self,
430 method: reqwest::Method,
431 path: &str,
432 ) -> Result<reqwest::RequestBuilder> {
433 let (http, url) = self.endpoint_url(path).await?;
434 debug!(%url, %method, "arbitrary canopy request");
435 Ok(http.request(method, url))
436 }
437
438 #[cfg(feature = "raw-requests")]
445 pub async fn request_json<Res: serde::de::DeserializeOwned>(
446 &self,
447 method: reqwest::Method,
448 path: &str,
449 body: Option<&(impl serde::Serialize + ?Sized)>,
450 ) -> Result<Res> {
451 self.call_json(method, path, body).await
452 }
453}
454
455async fn probe_tailscale(
472 tailscale_url: &Url,
473 make_builder: &ClientBuilderFactory,
474) -> Option<reqwest::Client> {
475 let host = tailscale_url.host_str()?;
476
477 if host != TAILSCALE_HOST {
480 return try_probe(tailscale_url, host, &[], make_builder).await;
481 }
482
483 let dns_addrs: Vec<SocketAddr> = tailscale_resolver()
484 .lookup_ip("canopy")
485 .await
486 .ok()
487 .map(|addrs| addrs.iter().map(|ip| SocketAddr::new(ip, 443)).collect())
488 .unwrap_or_default();
489 if !dns_addrs.is_empty()
490 && let Some(client) = try_probe(tailscale_url, host, &dns_addrs, make_builder).await
491 {
492 return Some(client);
493 }
494
495 let hardcoded = [
496 SocketAddr::new(IpAddr::V4(CANOPY_HARDCODED_V4), 443),
497 SocketAddr::new(IpAddr::V6(CANOPY_HARDCODED_V6), 443),
498 ];
499 debug!(
500 ?hardcoded,
501 "canopy tailscale DNS lookup empty or probe failed, trying hardcoded IPs"
502 );
503 try_probe(tailscale_url, host, &hardcoded, make_builder).await
504}
505
506async fn try_probe(
509 tailscale_url: &Url,
510 host: &str,
511 addrs: &[SocketAddr],
512 make_builder: &ClientBuilderFactory,
513) -> Option<reqwest::Client> {
514 let mut builder = make_builder()
515 .user_agent(user_agent())
516 .timeout(TAILSCALE_PROBE_TIMEOUT);
517 if !addrs.is_empty() {
518 builder = builder.resolve_to_addrs(host, addrs);
519 }
520 let client = builder.build().ok()?;
521
522 let url = tailscale_url.join("/public/servers").ok()?;
523 match client.get(url).send().await {
524 Ok(resp) if resp.status().is_success() => Some(client),
525 Ok(resp) => {
526 debug!(status = %resp.status(), ?addrs, "canopy tailscale probe: unexpected status");
527 None
528 }
529 Err(err) => {
530 debug!(?addrs, "canopy tailscale probe failed: {err}");
531 None
532 }
533 }
534}
535
536fn tailscale_resolver() -> Resolver<impl ConnectionProvider> {
537 Resolver::builder_with_config(
538 ResolverConfig::from_parts(
539 None,
540 vec!["tail53aef.ts.net.".parse().unwrap()],
541 vec![NameServerConfig::new(
542 "100.100.100.100".parse().unwrap(),
543 true,
544 vec![ConnectionConfig::udp()],
545 )],
546 ),
547 TokioRuntimeProvider::default(),
548 )
549 .build()
550 .expect("tailscale resolver config is hardcoded and cannot fail to build")
551}
552
553fn gzip_bytes(bytes: &[u8]) -> std::io::Result<Vec<u8>> {
554 let mut encoder = GzEncoder::new(Vec::with_capacity(bytes.len() / 2), Compression::default());
555 encoder.write_all(bytes)?;
556 encoder.finish()
557}
558
559pub fn device_identity(device_key_pem: &str) -> Result<reqwest::Identity> {
568 let key_pair = KeyPair::from_pem(device_key_pem)
569 .into_diagnostic()
570 .wrap_err("parsing device key PEM")?;
571
572 let mut params = CertificateParams::new(vec!["device.local".into()])
573 .into_diagnostic()
574 .wrap_err("building certificate params")?;
575 params.distinguished_name = DistinguishedName::new();
576 params
577 .distinguished_name
578 .push(DnType::CommonName, "device.local");
579
580 let now = OffsetDateTime::now_utc();
581 params.not_before = now - TimeDuration::minutes(1);
582 params.not_after = now + TimeDuration::days(CERT_VALIDITY_DAYS);
583
584 let cert = params
585 .self_signed(&key_pair)
586 .into_diagnostic()
587 .wrap_err("self-signing certificate")?;
588
589 let mut combined = cert.pem();
590 combined.push('\n');
591 combined.push_str(&key_pair.serialize_pem());
592
593 reqwest::Identity::from_pem(combined.as_bytes())
594 .into_diagnostic()
595 .wrap_err("building reqwest TLS identity")
596}
597
598fn build_mtls_http(
599 make_builder: &ClientBuilderFactory,
600 device_key_pem: &str,
601) -> Result<reqwest::Client> {
602 let identity = device_identity(device_key_pem)?;
603
604 make_builder()
605 .user_agent(user_agent())
606 .identity(identity)
607 .use_rustls_tls()
608 .timeout(Duration::from_secs(30))
609 .build()
610 .into_diagnostic()
611 .wrap_err("building canopy HTTP client")
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Throwaway P-256 private key (PKCS#8 PEM), used only by these tests.
    const TEST_DEVICE_KEY: &str = "\
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgVvhzsYiidp38GYn1
KxD5Wipc/h8lglVsy1UFZq/SZbGhRANCAAT2EsEq7xjeWVnim9XwdYXga/LBbppm
fXLgamTYOa/w9n/Ta64fiYWmN54kEd0DgnflJDLtID321Zz6xswvK/VN
-----END PRIVATE KEY-----";

    /// Builder factory with no caller customisation.
    fn test_factory() -> ClientBuilderFactory {
        Arc::new(reqwest::Client::builder)
    }

    #[test]
    fn build_mtls_http_from_p256_key() {
        let result = build_mtls_http(&test_factory(), TEST_DEVICE_KEY);
        assert!(result.is_ok(), "{:?}", result.err());
    }

    #[test]
    fn build_mtls_http_fails_on_garbage_key() {
        assert!(build_mtls_http(&test_factory(), "not a real PEM").is_err());
    }

    #[tokio::test]
    async fn renew_with_mtls_state_swaps_in_fresh_client() {
        let http = build_mtls_http(&test_factory(), TEST_DEVICE_KEY).unwrap();
        let client = CanopyClient {
            base_url: DEFAULT_CANOPY_URL.parse().unwrap(),
            tailscale_url: TAILSCALE_URL.parse().unwrap(),
            device_key: Some(Redacted(TEST_DEVICE_KEY.to_owned())),
            make_builder: test_factory(),
            state: RwLock::new(State::Mtls(http)),
        };
        client.renew().await.expect("renew should succeed");
        assert!(!client.is_tailscale().await);
    }

    #[tokio::test]
    async fn renew_is_noop_in_tailscale_mode() {
        let http = reqwest::Client::new();
        let client = CanopyClient {
            base_url: DEFAULT_CANOPY_URL.parse().unwrap(),
            tailscale_url: TAILSCALE_URL.parse().unwrap(),
            device_key: None,
            make_builder: test_factory(),
            state: RwLock::new(State::Tailscale(http)),
        };
        client.renew().await.expect("renew should be a no-op");
        assert!(client.is_tailscale().await);
    }

    /// A client pinned to the mTLS transport whose base URL points at `base`.
    fn mtls_client_against(base: &str) -> CanopyClient {
        let http = build_mtls_http(&test_factory(), TEST_DEVICE_KEY).unwrap();
        CanopyClient {
            base_url: base.parse().unwrap(),
            tailscale_url: TAILSCALE_URL.parse().unwrap(),
            device_key: Some(Redacted(TEST_DEVICE_KEY.to_owned())),
            make_builder: test_factory(),
            state: RwLock::new(State::Mtls(http)),
        }
    }

    /// A client pinned to the Tailscale transport whose tailnet URL points at `base`.
    fn tailscale_client_against(base: &str) -> CanopyClient {
        CanopyClient {
            base_url: DEFAULT_CANOPY_URL.parse().unwrap(),
            tailscale_url: base.parse().unwrap(),
            device_key: None,
            make_builder: test_factory(),
            state: RwLock::new(State::Tailscale(reqwest::Client::new())),
        }
    }

    /// What `serve_once` saw of the single request it accepted.
    struct Captured {
        request_line: String,
        headers: String,
        body: Vec<u8>,
    }

    /// Starts a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// The server accepts one connection, reads the request head and a `Content-Length` body,
    /// then replies with the raw `response` bytes. Returns the server's base URL plus a handle
    /// that yields the captured request once joined.
    fn serve_once(response: &'static str) -> (String, std::thread::JoinHandle<Captured>) {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let base = format!("http://{}", listener.local_addr().unwrap());
        let handle = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = Vec::new();
            let mut chunk = [0u8; 1024];
            // Read until the blank line that terminates the headers.
            let header_end = loop {
                if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
                    break pos + 4;
                }
                let n = stream.read(&mut chunk).unwrap();
                if n == 0 {
                    panic!("connection closed before headers were complete");
                }
                buf.extend_from_slice(&chunk[..n]);
            };

            let head = String::from_utf8_lossy(&buf[..header_end]).into_owned();
            let content_length = head
                .lines()
                .find_map(|line| {
                    let (name, value) = line.split_once(':')?;
                    name.trim()
                        .eq_ignore_ascii_case("content-length")
                        .then(|| value.trim().parse::<usize>().ok())
                        .flatten()
                })
                .unwrap_or(0);

            // Whatever arrived after the headers is the start of the body; read the rest.
            let mut body = buf[header_end..].to_vec();
            while body.len() < content_length {
                let n = stream.read(&mut chunk).unwrap();
                if n == 0 {
                    break;
                }
                body.extend_from_slice(&chunk[..n]);
            }

            stream.write_all(response.as_bytes()).unwrap();
            stream.flush().unwrap();

            let mut lines = head.lines();
            let request_line = lines.next().unwrap_or_default().to_owned();
            let headers = lines.collect::<Vec<_>>().join("\n");
            Captured {
                request_line,
                headers,
                body,
            }
        });
        (base, handle)
    }

    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct Echo {
        ok: bool,
        who: String,
    }

    #[tokio::test]
    async fn call_json_gzips_body_sets_user_agent_and_parses_response() {
        let (base, handle) = serve_once(
            "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 26\r\n\r\n{\"ok\":true,\"who\":\"device\"}",
        );
        let client = mtls_client_against(&base);

        let payload = serde_json::json!({ "hello": "world" });
        let got: Echo = client
            .call_json(reqwest::Method::POST, "/thing", Some(&payload))
            .await
            .expect("call_json should succeed");

        assert_eq!(
            got,
            Echo {
                ok: true,
                who: "device".into()
            }
        );

        let captured = handle.join().unwrap();
        assert!(
            captured.request_line.starts_with("POST /thing "),
            "unexpected request line: {}",
            captured.request_line
        );
        let headers = captured.headers.to_ascii_lowercase();
        assert!(
            headers.contains("user-agent: bestool-canopy/"),
            "missing canopy user-agent in:\n{}",
            captured.headers
        );
        assert!(
            headers.contains("content-encoding: gzip"),
            "body should be gzipped:\n{}",
            captured.headers
        );
        use flate2::read::GzDecoder;
        use std::io::Read as _;
        let mut decoder = GzDecoder::new(&captured.body[..]);
        let mut raw = Vec::new();
        decoder
            .read_to_end(&mut raw)
            .expect("body should be valid gzip");
        let sent: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(sent, payload);
    }

    /// On the Tailscale transport, API paths must be routed under the `/public` prefix.
    #[tokio::test]
    async fn call_json_in_tailscale_mode_routes_under_public() {
        let (base, handle) = serve_once(
            "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}",
        );
        let client = tailscale_client_against(&base);

        let got: serde_json::Value = client
            .call_json(reqwest::Method::GET, "/servers", None::<&()>)
            .await
            .expect("call_json should succeed");
        assert_eq!(got, serde_json::json!({}));

        let captured = handle.join().unwrap();
        assert!(
            captured.request_line.starts_with("GET /public/servers "),
            "unexpected request line: {}",
            captured.request_line
        );
    }

    #[tokio::test]
    async fn call_json_errors_on_non_success_with_body() {
        let (base, handle) =
            serve_once("HTTP/1.1 418 I'm a teapot\r\nContent-Length: 14\r\n\r\nno coffee here");
        let client = mtls_client_against(&base);

        let err = client
            .call_json::<(), serde_json::Value>(reqwest::Method::GET, "/brew", None::<&()>)
            .await
            .expect_err("non-2xx should error");
        let msg = err.to_string();
        assert!(msg.contains("/brew"), "expected path in error: {msg}");
        assert!(msg.contains("418"), "expected status in error: {msg}");
        assert!(
            msg.contains("no coffee here"),
            "expected body text in error: {msg}"
        );

        handle.join().unwrap();
    }

    #[test]
    fn user_agent_identifies_the_crate_with_os_comment() {
        let ua = user_agent();
        assert!(
            ua.starts_with(concat!("bestool-canopy/", env!("CARGO_PKG_VERSION"), " ")),
            "unexpected user-agent: {ua}"
        );
        assert!(ua.contains('('), "expected OS comment in: {ua}");
        assert!(ua.ends_with(')'), "expected OS comment in: {ua}");
        assert!(
            ua.contains(sysinfo::System::cpu_arch().as_str()),
            "expected arch in: {ua}"
        );
    }

    #[test]
    fn gzip_bytes_roundtrips() {
        use flate2::read::GzDecoder;
        use std::io::Read;

        let original = br#"{"health":[{"check":"x","result":"passed"}]}"#;
        let compressed = gzip_bytes(original).expect("gzip should succeed");
        assert!(
            compressed.starts_with(&[0x1f, 0x8b]),
            "expected gzip magic bytes"
        );
        let mut decoder = GzDecoder::new(&compressed[..]);
        let mut decompressed = Vec::new();
        decoder.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }
}
855}