use std::{
fmt,
io::Write,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::{Arc, OnceLock},
time::Duration,
};
use flate2::{Compression, write::GzEncoder};
use hickory_resolver::{
ConnectionProvider, Resolver,
config::{ConnectionConfig, NameServerConfig, ResolverConfig},
net::runtime::TokioRuntimeProvider,
};
use miette::{IntoDiagnostic, Result, WrapErr};
use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair};
use reqwest::Url;
use time::{Duration as TimeDuration, OffsetDateTime};
use tokio::sync::RwLock;
use tracing::debug;
use crate::Redacted;
/// Canopy base URL used on the mTLS path.
pub const DEFAULT_CANOPY_URL: &str = "https://meta.tamanu.app";
/// Canopy endpoint reachable from inside the tailnet.
pub const TAILSCALE_URL: &str = "https://canopy.tail53aef.ts.net";
/// Host part of `TAILSCALE_URL`. Probes against this exact host first resolve it through the
/// tailnet DNS server (100.100.100.100) and then fall back to the fixed addresses below.
const TAILSCALE_HOST: &str = "canopy.tail53aef.ts.net";
/// Fixed tailnet IPv4 address of the canopy node (`100.99.98.97`).
const CANOPY_HARDCODED_V4: Ipv4Addr = Ipv4Addr::new(100, 99, 98, 97);
/// Fixed tailnet IPv6 address of the canopy node (`fd7a:115c:a1e0::9337:fb52`).
const CANOPY_HARDCODED_V6: Ipv6Addr = Ipv6Addr::new(
    0xfd7a, 0x115c, 0xa1e0, 0x0000, 0x0000, 0x0000, 0x9337, 0xfb52,
);
/// Lifetime, in days, of the self-signed mTLS client certificate.
const CERT_VALIDITY_DAYS: i64 = 6;
/// Age after which the client certificate should be re-issued: five days, i.e. one day
/// before the six-day certificate lifetime runs out.
pub const CERT_RENEW_AFTER: Duration = Duration::from_secs(5 * 86_400);
/// Time limit for the tailscale reachability probe.
const TAILSCALE_PROBE_TIMEOUT: Duration = Duration::from_secs(5);
/// Factory that produces a fresh `reqwest::ClientBuilder` each time canopy needs an HTTP client.
///
/// Callers use it to supply base client configuration. Canopy then adds its own user agent,
/// timeout and, on the mTLS path, client identity before building.
pub type ClientBuilderFactory = Arc<dyn Fn() -> reqwest::ClientBuilder + Send + Sync>;
/// Error produced when canopy answers a call with a non-success HTTP status.
///
/// Carries the status, the API path that was called, and the response body text, which is
/// empty if the body could not be read.
#[derive(Debug, Clone)]
pub struct CanopyHttpError {
    /// HTTP status returned by canopy.
    pub status: reqwest::StatusCode,
    /// API path that was called, as passed by the caller.
    pub path: String,
    /// Response body text.
    pub body: String,
}

impl fmt::Display for CanopyHttpError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { status, path, body } = self;
        write!(f, "canopy {} returned {}: {}", path, status, body)
    }
}

impl std::error::Error for CanopyHttpError {}

// No diagnostic codes, help text or labels; miette's default methods are used as-is.
impl miette::Diagnostic for CanopyHttpError {}
/// User-Agent sent on every canopy request.
///
/// Shape: `bestool-canopy/<crate version> (<OS description>; <CPU arch>)`. The OS description
/// is the long OS version if available, else the OS name, else Rust's compile-time OS constant.
/// Computed once on first call and cached for the rest of the process.
fn user_agent() -> &'static str {
    static CACHED: OnceLock<String> = OnceLock::new();
    CACHED.get_or_init(|| {
        let os = match sysinfo::System::long_os_version() {
            Some(long_version) => long_version,
            None => sysinfo::System::name().unwrap_or_else(|| std::env::consts::OS.to_owned()),
        };
        let arch = sysinfo::System::cpu_arch();
        format!(
            "bestool-canopy/{} ({os}; {})",
            env!("CARGO_PKG_VERSION"),
            arch,
        )
    })
}
pub async fn tailscale_client(make_builder: &ClientBuilderFactory) -> Option<reqwest::Client> {
let tailscale_url = TAILSCALE_URL
.parse()
.expect("default tailscale URL is valid");
probe_tailscale(&tailscale_url, make_builder).await
}
/// Client for the canopy API that chooses between two ways of reaching it.
///
/// When the tailscale endpoint is reachable, API calls go to `tailscale_url` under its
/// `/public` prefix. Otherwise they go to `base_url` over mTLS, using a certificate
/// self-signed with `device_key`. The active path is held in `state` and can be switched
/// at runtime.
pub struct CanopyClient {
/// Canopy base URL used on the mTLS path.
base_url: Url,
/// Tailscale endpoint URL.
tailscale_url: Url,
/// Device private key PEM used to (re)build the mTLS client; without it mTLS is unavailable.
device_key: Option<Redacted<String>>,
/// Produces the base `reqwest::ClientBuilder` for every HTTP client this instance builds.
make_builder: ClientBuilderFactory,
/// Active path plus its HTTP client; replaced by `refresh` and `renew`.
state: RwLock<State>,
}
/// Which way the client currently reaches canopy, with the HTTP client configured for it.
enum State {
/// Via the tailscale endpoint, using the client returned by a successful probe.
Tailscale(reqwest::Client),
/// Via the base URL, using a client that presents the device's mTLS identity.
Mtls(reqwest::Client),
}
impl State {
    /// `true` when requests currently go via tailscale.
    fn is_tailscale(&self) -> bool {
        match self {
            State::Tailscale(_) => true,
            State::Mtls(_) => false,
        }
    }

    /// A clone of the HTTP client for the active path.
    fn http(&self) -> reqwest::Client {
        let (State::Tailscale(active) | State::Mtls(active)) = self;
        active.clone()
    }
}
// Prints just `CanopyClient { .. }`: no field, including the device key, is shown.
impl fmt::Debug for CanopyClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut debug = f.debug_struct("CanopyClient");
        debug.finish_non_exhaustive()
    }
}
impl CanopyClient {
    /// Connect to canopy using the default production URLs.
    ///
    /// Shorthand for [`CanopyClient::with_urls`] with [`DEFAULT_CANOPY_URL`] and
    /// [`TAILSCALE_URL`].
    ///
    /// # Errors
    ///
    /// Same as [`CanopyClient::with_urls`].
    pub async fn new(
        device_key_pem: Option<&str>,
        make_builder: impl Fn() -> reqwest::ClientBuilder + Send + Sync + 'static,
    ) -> Result<Option<Self>> {
        Self::with_urls(
            DEFAULT_CANOPY_URL
                .parse()
                .expect("default canopy URL is valid"),
            TAILSCALE_URL
                .parse()
                .expect("default tailscale URL is valid"),
            device_key_pem,
            make_builder,
        )
        .await
    }

    /// Connect to canopy, preferring tailscale and falling back to mTLS.
    ///
    /// `tailscale_url` is probed first; if it answers, calls go through it. Otherwise, if
    /// `device_key_pem` is given, calls go to `base_url` over mTLS with a certificate
    /// self-signed by that key. `make_builder` supplies the base `reqwest::ClientBuilder` for
    /// every HTTP client this instance creates.
    ///
    /// Returns `Ok(None)` when tailscale is unreachable and there is no device key.
    ///
    /// # Errors
    ///
    /// Fails if the mTLS client cannot be built, e.g. because the key PEM is invalid.
    pub async fn with_urls(
        base_url: Url,
        tailscale_url: Url,
        device_key_pem: Option<&str>,
        make_builder: impl Fn() -> reqwest::ClientBuilder + Send + Sync + 'static,
    ) -> Result<Option<Self>> {
        let make_builder: ClientBuilderFactory = Arc::new(make_builder);
        let state = if let Some(http) = probe_tailscale(&tailscale_url, &make_builder).await {
            debug!("canopy: tailscale endpoint reachable, preferring it");
            State::Tailscale(http)
        } else if let Some(pem) = device_key_pem {
            debug!("canopy: tailscale unreachable, falling back to mTLS");
            State::Mtls(build_mtls_http(&make_builder, pem)?)
        } else {
            return Ok(None);
        };
        Ok(Some(Self {
            base_url,
            tailscale_url,
            device_key: device_key_pem.map(|s| Redacted(s.to_owned())),
            make_builder,
            state: RwLock::new(state),
        }))
    }

    /// Whether calls currently go via tailscale (as opposed to mTLS).
    pub async fn is_tailscale(&self) -> bool {
        self.state.read().await.is_tailscale()
    }

    /// Re-evaluate which path to use.
    ///
    /// Switches to (or stays on) tailscale when it is reachable. Otherwise the mTLS client is
    /// rebuilt if a device key is available. With neither, the current state is kept.
    ///
    /// # Errors
    ///
    /// Fails if the mTLS client cannot be built; the current state is then left unchanged.
    pub async fn refresh(&self) -> Result<()> {
        // Probe and build before taking the write lock, so readers in `endpoint_url` are not
        // blocked for the duration of the probe.
        if let Some(http) = probe_tailscale(&self.tailscale_url, &self.make_builder).await {
            let mut state = self.state.write().await;
            if !state.is_tailscale() {
                debug!("canopy refresh: switching to tailscale path");
            }
            *state = State::Tailscale(http);
            return Ok(());
        }
        if let Some(pem) = &self.device_key {
            let http = build_mtls_http(&self.make_builder, &pem.0)?;
            let mut state = self.state.write().await;
            if state.is_tailscale() {
                debug!("canopy refresh: tailscale dropped, falling back to mTLS");
            }
            *state = State::Mtls(http);
            return Ok(());
        }
        debug!("canopy refresh: no auth path available, keeping current state");
        Ok(())
    }

    /// Rebuild the mTLS client so it presents a freshly self-signed certificate.
    ///
    /// Does nothing in tailscale mode or when there is no device key.
    ///
    /// # Errors
    ///
    /// Fails if the new mTLS client cannot be built; the old client is then kept.
    pub async fn renew(&self) -> Result<()> {
        let Some(pem) = &self.device_key else {
            return Ok(());
        };
        let mut state = self.state.write().await;
        if state.is_tailscale() {
            return Ok(());
        }
        *state = State::Mtls(build_mtls_http(&self.make_builder, &pem.0)?);
        Ok(())
    }

    /// Resolve `path` against the active endpoint and return it with the matching client.
    ///
    /// On tailscale, routes live under `/public`, so `path` is placed after that prefix. On
    /// mTLS, `path` is joined onto `base_url`. The read lock is released before returning.
    async fn endpoint_url(&self, path: &str) -> Result<(reqwest::Client, Url)> {
        let state = self.state.read().await;
        let url = match &*state {
            State::Tailscale(_) => {
                // Without a separator a relative path like "servers" would become
                // "/publicservers" here, whereas mTLS resolves it to "/servers".
                let separator = match path.chars().next() {
                    None | Some('/' | '?' | '#') => "",
                    Some(_) => "/",
                };
                self.tailscale_url
                    .join(&format!("/public{separator}{path}"))
                    .into_diagnostic()
                    .wrap_err_with(|| format!("building tailscale /public{path} URL"))?
            }
            State::Mtls(_) => self
                .base_url
                .join(path)
                .into_diagnostic()
                .wrap_err_with(|| format!("building {path} URL"))?,
        };
        Ok((state.http(), url))
    }

    /// Send `method` to `path` on the active endpoint and return the successful response.
    ///
    /// A `body`, if given, is JSON-serialised and gzip-compressed, with `Content-Type` and
    /// `Content-Encoding` headers set accordingly.
    ///
    /// # Errors
    ///
    /// Fails on URL, serialisation or transport errors. Any non-2xx status becomes a
    /// [`CanopyHttpError`] carrying the response body text.
    async fn send_call<B: serde::Serialize + ?Sized>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<reqwest::Response> {
        let (http, url) = self.endpoint_url(path).await?;
        debug!(%url, %method, "canopy request");
        let mut req = http.request(method, url);
        if let Some(body) = body {
            let raw = serde_json::to_vec(body)
                .into_diagnostic()
                .wrap_err_with(|| format!("serialising canopy {path} body"))?;
            let compressed = gzip_bytes(&raw)
                .into_diagnostic()
                .wrap_err_with(|| format!("gzipping canopy {path} body"))?;
            req = req
                .header(reqwest::header::CONTENT_TYPE, "application/json")
                .header(reqwest::header::CONTENT_ENCODING, "gzip")
                .body(compressed);
        }
        let response = req
            .send()
            .await
            .into_diagnostic()
            .wrap_err_with(|| format!("calling canopy {path}"))?;
        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty rather than
            // replacing the status error.
            let body = response.text().await.unwrap_or_default();
            return Err(miette::Report::new(CanopyHttpError {
                status,
                path: path.to_owned(),
                body,
            }));
        }
        Ok(response)
    }

    /// Make a call (see `send_call`) and deserialise the JSON response into `R`.
    pub(crate) async fn call_json<B, R>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<R>
    where
        B: serde::Serialize + ?Sized,
        R: serde::de::DeserializeOwned,
    {
        let response = self.send_call(method, path, body).await?;
        response
            .json::<R>()
            .await
            .into_diagnostic()
            .wrap_err_with(|| format!("parsing canopy {path} response"))
    }

    /// Make a call (see `send_call`) and discard the response.
    pub(crate) async fn call_empty<B: serde::Serialize + ?Sized>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<()> {
        self.send_call(method, path, body).await.map(drop)
    }

    /// Send a GET using a separate path for each mode.
    ///
    /// `tailscale_path` is joined onto the tailscale URL as-is (no `/public` prefix is added),
    /// and `mtls_path` onto the base URL. The response is returned without checking its status.
    #[cfg(feature = "raw-requests")]
    pub async fn get(&self, tailscale_path: &str, mtls_path: &str) -> Result<reqwest::Response> {
        // Scope the read guard so it is dropped before the request is awaited.
        let (http, url) = {
            let state = self.state.read().await;
            let url = match &*state {
                State::Tailscale(_) => self
                    .tailscale_url
                    .join(tailscale_path)
                    .into_diagnostic()
                    .wrap_err("building tailscale GET URL")?,
                State::Mtls(_) => self
                    .base_url
                    .join(mtls_path)
                    .into_diagnostic()
                    .wrap_err("building mTLS GET URL")?,
            };
            (state.http(), url)
        };
        debug!(%url, "GET via canopy");
        http.get(url)
            .send()
            .await
            .into_diagnostic()
            .wrap_err("GET via canopy")
    }

    /// Start a request to `path` on the active endpoint (URL resolved as for API calls).
    ///
    /// Headers, body and sending are left to the caller; no gzip or status handling is applied.
    #[cfg(feature = "raw-requests")]
    pub async fn request(
        &self,
        method: reqwest::Method,
        path: &str,
    ) -> Result<reqwest::RequestBuilder> {
        let (http, url) = self.endpoint_url(path).await?;
        debug!(%url, %method, "arbitrary canopy request");
        Ok(http.request(method, url))
    }

    /// Public wrapper around `call_json`: gzip-encoded JSON body in, parsed JSON out.
    #[cfg(feature = "raw-requests")]
    pub async fn request_json<Res: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&(impl serde::Serialize + ?Sized)>,
    ) -> Result<Res> {
        self.call_json(method, path, body).await
    }
}
/// Find a working HTTP client for the tailscale canopy endpoint, if it is reachable.
///
/// For any host other than `TAILSCALE_HOST` (e.g. a custom URL), the URL is probed using
/// normal name resolution. For the canopy host, the probe first tries whatever the tailnet
/// DNS server returns for `canopy`. If that yields nothing or the probe fails, it falls back
/// to the hardcoded IPv4/IPv6 addresses.
///
/// Returns `None` if the URL has no host or every attempt fails.
async fn probe_tailscale(
    tailscale_url: &Url,
    make_builder: &ClientBuilderFactory,
) -> Option<reqwest::Client> {
    let host = tailscale_url.host_str()?;

    if host != TAILSCALE_HOST {
        return try_probe(tailscale_url, host, &[], make_builder).await;
    }

    // First choice: the addresses tailnet DNS gives for `canopy`, all on port 443.
    let dns_addrs: Vec<SocketAddr> = match tailscale_resolver().lookup_ip("canopy").await {
        Ok(lookup) => lookup.iter().map(|ip| SocketAddr::new(ip, 443)).collect(),
        Err(_) => Vec::new(),
    };
    if !dns_addrs.is_empty() {
        if let Some(client) = try_probe(tailscale_url, host, &dns_addrs, make_builder).await {
            return Some(client);
        }
    }

    // Last resort: the canopy node's fixed tailnet addresses.
    let hardcoded = [
        SocketAddr::new(IpAddr::V4(CANOPY_HARDCODED_V4), 443),
        SocketAddr::new(IpAddr::V6(CANOPY_HARDCODED_V6), 443),
    ];
    debug!(
        ?hardcoded,
        "canopy tailscale DNS lookup empty or probe failed, trying hardcoded IPs"
    );
    try_probe(tailscale_url, host, &hardcoded, make_builder).await
}
async fn try_probe(
tailscale_url: &Url,
host: &str,
addrs: &[SocketAddr],
make_builder: &ClientBuilderFactory,
) -> Option<reqwest::Client> {
let mut builder = make_builder()
.user_agent(user_agent())
.timeout(TAILSCALE_PROBE_TIMEOUT);
if !addrs.is_empty() {
builder = builder.resolve_to_addrs(host, addrs);
}
let client = builder.build().ok()?;
let url = tailscale_url.join("/public/servers").ok()?;
match client.get(url).send().await {
Ok(resp) if resp.status().is_success() => Some(client),
Ok(resp) => {
debug!(status = %resp.status(), ?addrs, "canopy tailscale probe: unexpected status");
None
}
Err(err) => {
debug!(?addrs, "canopy tailscale probe failed: {err}");
None
}
}
}
/// Build a DNS resolver that talks directly to the tailnet DNS server.
///
/// It queries `100.100.100.100` over UDP and uses `tail53aef.ts.net.` as its only search
/// domain, so a bare name such as `canopy` is looked up within that domain.
///
/// # Panics
///
/// Only if the hardcoded values fail to parse or the resolver fails to build. Both are
/// constant, so this is treated as unreachable.
fn tailscale_resolver() -> Resolver<impl ConnectionProvider> {
    let search_domains = vec!["tail53aef.ts.net.".parse().unwrap()];
    let name_servers = vec![NameServerConfig::new(
        "100.100.100.100".parse().unwrap(),
        // NOTE(review): presumably hickory's `trust_negative_responses` flag — confirm.
        true,
        vec![ConnectionConfig::udp()],
    )];
    let config = ResolverConfig::from_parts(None, search_domains, name_servers);
    Resolver::builder_with_config(config, TokioRuntimeProvider::default())
        .build()
        .expect("tailscale resolver config is hardcoded and cannot fail to build")
}
/// Gzip-compress `bytes` at the default compression level.
///
/// The output buffer is pre-allocated at half the input size as a rough guess at the
/// compressed length.
fn gzip_bytes(bytes: &[u8]) -> std::io::Result<Vec<u8>> {
    let mut gz = GzEncoder::new(Vec::with_capacity(bytes.len() / 2), Compression::default());
    gz.write_all(bytes).and_then(|()| gz.finish())
}
/// Create a short-lived, self-signed TLS client identity from the device's private key.
///
/// The certificate's subject common name and SAN are both `device.local`. It is valid from
/// one minute ago (to absorb small clock differences) until six days from now. The
/// identity is the certificate PEM followed by the key PEM.
///
/// # Errors
///
/// Fails if `device_key_pem` is not a parseable key, or if certificate generation or
/// identity construction fails.
pub fn device_identity(device_key_pem: &str) -> Result<reqwest::Identity> {
    const SUBJECT: &str = "device.local";

    let key_pair = KeyPair::from_pem(device_key_pem)
        .into_diagnostic()
        .wrap_err("parsing device key PEM")?;

    let mut params = CertificateParams::new(vec![SUBJECT.into()])
        .into_diagnostic()
        .wrap_err("building certificate params")?;
    let mut subject_name = DistinguishedName::new();
    subject_name.push(DnType::CommonName, SUBJECT);
    params.distinguished_name = subject_name;

    let issued_at = OffsetDateTime::now_utc();
    params.not_before = issued_at - TimeDuration::minutes(1);
    params.not_after = issued_at + TimeDuration::days(CERT_VALIDITY_DAYS);

    let certificate = params
        .self_signed(&key_pair)
        .into_diagnostic()
        .wrap_err("self-signing certificate")?;

    let bundle = format!("{}\n{}", certificate.pem(), key_pair.serialize_pem());
    reqwest::Identity::from_pem(bundle.as_bytes())
        .into_diagnostic()
        .wrap_err("building reqwest TLS identity")
}
/// Build the HTTP client used on the mTLS path.
///
/// Starts from `make_builder`, then sets the canopy user agent, a freshly self-signed client
/// identity from `device_identity`, the rustls TLS backend, and a 30-second overall request
/// timeout.
///
/// # Errors
///
/// Fails if the identity cannot be created or the client cannot be built.
fn build_mtls_http(
    make_builder: &ClientBuilderFactory,
    device_key_pem: &str,
) -> Result<reqwest::Client> {
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

    let client_identity = device_identity(device_key_pem)?;
    let builder = make_builder()
        .user_agent(user_agent())
        .identity(client_identity)
        .use_rustls_tls()
        .timeout(REQUEST_TIMEOUT);
    builder
        .build()
        .into_diagnostic()
        .wrap_err("building canopy HTTP client")
}
// Unit tests for the canopy client: mTLS client construction, `renew` behaviour, request
// encoding and error reporting against a one-shot local HTTP server, the user agent, and gzip.
#[cfg(test)]
mod tests {
use super::*;
// Throwaway P-256 private key (PKCS#8 PEM), used only to exercise mTLS client construction.
const TEST_DEVICE_KEY: &str = "\
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgVvhzsYiidp38GYn1
KxD5Wipc/h8lglVsy1UFZq/SZbGhRANCAAT2EsEq7xjeWVnim9XwdYXga/LBbppm
fXLgamTYOa/w9n/Ta64fiYWmN54kEd0DgnflJDLtID321Zz6xswvK/VN
-----END PRIVATE KEY-----";
// Builder factory with no configuration beyond reqwest's defaults.
fn test_factory() -> ClientBuilderFactory {
Arc::new(reqwest::Client::builder)
}
#[test]
fn build_mtls_http_from_p256_key() {
let result = build_mtls_http(&test_factory(), TEST_DEVICE_KEY);
assert!(result.is_ok(), "{:?}", result.err());
}
#[test]
fn build_mtls_http_fails_on_garbage_key() {
assert!(build_mtls_http(&test_factory(), "not a real PEM").is_err());
}
// In mTLS mode, `renew` rebuilds the client and the client stays in mTLS mode.
#[tokio::test]
async fn renew_with_mtls_state_swaps_in_fresh_client() {
let http = build_mtls_http(&test_factory(), TEST_DEVICE_KEY).unwrap();
let client = CanopyClient {
base_url: DEFAULT_CANOPY_URL.parse().unwrap(),
tailscale_url: TAILSCALE_URL.parse().unwrap(),
device_key: Some(Redacted(TEST_DEVICE_KEY.to_owned())),
make_builder: test_factory(),
state: RwLock::new(State::Mtls(http)),
};
client.renew().await.expect("renew should succeed");
assert!(!client.is_tailscale().await);
}
// In tailscale mode, `renew` does nothing, even with no device key configured.
#[tokio::test]
async fn renew_is_noop_in_tailscale_mode() {
let http = reqwest::Client::new();
let client = CanopyClient {
base_url: DEFAULT_CANOPY_URL.parse().unwrap(),
tailscale_url: TAILSCALE_URL.parse().unwrap(),
device_key: None,
make_builder: test_factory(),
state: RwLock::new(State::Tailscale(http)),
};
client.renew().await.expect("renew should be a no-op");
assert!(client.is_tailscale().await);
}
// A client already in mTLS mode with `base` as its base URL, built without any tailscale probe.
fn mtls_client_against(base: &str) -> CanopyClient {
let http = build_mtls_http(&test_factory(), TEST_DEVICE_KEY).unwrap();
CanopyClient {
base_url: base.parse().unwrap(),
tailscale_url: TAILSCALE_URL.parse().unwrap(),
device_key: Some(Redacted(TEST_DEVICE_KEY.to_owned())),
make_builder: test_factory(),
state: RwLock::new(State::Mtls(http)),
}
}
// What `serve_once` received: the request line, the remaining header lines (newline-joined),
// and the raw body bytes as sent (still gzip-encoded if the client compressed them).
struct Captured {
request_line: String,
headers: String,
body: Vec<u8>,
}
// Minimal one-shot HTTP server on an ephemeral localhost port, run on a background thread.
// It accepts a single connection and reads the request head plus `Content-Length` bytes of
// body. It then writes `response` verbatim and hands back what it captured through the join
// handle. Returns the server's `http://` base URL alongside the handle.
fn serve_once(response: &'static str) -> (String, std::thread::JoinHandle<Captured>) {
use std::io::{Read, Write};
use std::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let base = format!("http://{}", listener.local_addr().unwrap());
let handle = std::thread::spawn(move || {
let (mut stream, _) = listener.accept().unwrap();
let mut buf = Vec::new();
let mut chunk = [0u8; 1024];
// Keep reading until the blank line that ends the headers is in the buffer; yields the
// offset just past it.
let header_end = loop {
if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
let n = stream.read(&mut chunk).unwrap();
if n == 0 {
panic!("connection closed before headers were complete");
}
buf.extend_from_slice(&chunk[..n]);
};
let head = String::from_utf8_lossy(&buf[..header_end]).into_owned();
// Header name matched case-insensitively; a missing or unparsable value means no body.
let content_length = head
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.trim()
.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
// Body bytes that arrived together with the headers are kept; read the remainder.
let mut body = buf[header_end..].to_vec();
while body.len() < content_length {
let n = stream.read(&mut chunk).unwrap();
if n == 0 {
break;
}
body.extend_from_slice(&chunk[..n]);
}
// Reply only after the whole request is consumed.
stream.write_all(response.as_bytes()).unwrap();
stream.flush().unwrap();
let mut lines = head.lines();
let request_line = lines.next().unwrap_or_default().to_owned();
let headers = lines.collect::<Vec<_>>().join("\n");
Captured {
request_line,
headers,
body,
}
});
(base, handle)
}
// Response shape the JSON round-trip test expects to parse.
#[derive(Debug, serde::Deserialize, PartialEq)]
struct Echo {
ok: bool,
who: String,
}
// Checks that a POST through `call_json` sends the canopy user agent and a gzipped JSON body
// that decodes back to the payload, and that the JSON response is parsed.
// The Content-Length in the canned response must equal the byte length of its body (26).
#[tokio::test]
async fn call_json_gzips_body_sets_user_agent_and_parses_response() {
let (base, handle) = serve_once(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 26\r\n\r\n{\"ok\":true,\"who\":\"device\"}",
);
let client = mtls_client_against(&base);
let payload = serde_json::json!({ "hello": "world" });
let got: Echo = client
.call_json(reqwest::Method::POST, "/thing", Some(&payload))
.await
.expect("call_json should succeed");
assert_eq!(
got,
Echo {
ok: true,
who: "device".into()
}
);
let captured = handle.join().unwrap();
assert!(
captured.request_line.starts_with("POST /thing "),
"unexpected request line: {}",
captured.request_line
);
let headers = captured.headers.to_ascii_lowercase();
assert!(
headers.contains("user-agent: bestool-canopy/"),
"missing canopy user-agent in:\n{}",
captured.headers
);
assert!(
headers.contains("content-encoding: gzip"),
"body should be gzipped:\n{}",
captured.headers
);
use flate2::read::GzDecoder;
use std::io::Read as _;
let mut decoder = GzDecoder::new(&captured.body[..]);
let mut raw = Vec::new();
decoder
.read_to_end(&mut raw)
.expect("body should be valid gzip");
let sent: serde_json::Value = serde_json::from_slice(&raw).unwrap();
assert_eq!(sent, payload);
}
// A non-2xx response surfaces as an error whose message includes the path, the status and
// the response body. Content-Length (14) matches "no coffee here".
#[tokio::test]
async fn call_json_errors_on_non_success_with_body() {
let (base, handle) =
serve_once("HTTP/1.1 418 I'm a teapot\r\nContent-Length: 14\r\n\r\nno coffee here");
let client = mtls_client_against(&base);
let err = client
.call_json::<(), serde_json::Value>(reqwest::Method::GET, "/brew", None::<&()>)
.await
.expect_err("non-2xx should error");
let msg = err.to_string();
assert!(msg.contains("/brew"), "expected path in error: {msg}");
assert!(msg.contains("418"), "expected status in error: {msg}");
assert!(
msg.contains("no coffee here"),
"expected body text in error: {msg}"
);
handle.join().unwrap();
}
// The user agent starts with the crate name and version, followed by a parenthesised
// comment that contains the CPU architecture.
#[test]
fn user_agent_identifies_the_crate_with_os_comment() {
let ua = user_agent();
assert!(
ua.starts_with(concat!("bestool-canopy/", env!("CARGO_PKG_VERSION"), " ")),
"unexpected user-agent: {ua}"
);
assert!(ua.contains('('), "expected OS comment in: {ua}");
assert!(ua.ends_with(')'), "expected OS comment in: {ua}");
assert!(
ua.contains(sysinfo::System::cpu_arch().as_str()),
"expected arch in: {ua}"
);
}
// `gzip_bytes` produces a real gzip stream (magic bytes 1f 8b) that decompresses to the input.
#[test]
fn gzip_bytes_roundtrips() {
use flate2::read::GzDecoder;
use std::io::Read;
let original = br#"{"health":[{"check":"x","result":"passed"}]}"#;
let compressed = gzip_bytes(original).expect("gzip should succeed");
assert!(
compressed.starts_with(&[0x1f, 0x8b]),
"expected gzip magic bytes"
);
let mut decoder = GzDecoder::new(&compressed[..]);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed).unwrap();
assert_eq!(decompressed, original);
}
}