use std::{collections::HashSet, future::Future, time::Duration};
use anyhow::{anyhow, bail, Context, Result};
use base64::prelude::{Engine as _, BASE64_STANDARD};
use futures::future::join_all;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
StatusCode, Url,
};
use serde::{Deserialize, Serialize};
use tokio::{
task::{JoinError, JoinHandle},
time::interval,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace, warn};
use uuid::Uuid;
use wireguard_keys::Pubkey;
use crate::{wireguard::WgPeer, CONSUL_TTL};
/// Handle to a spawned background task, pairing its [`JoinHandle`] with the
/// [`CancellationToken`] that is used to ask the task to stop.
#[must_use]
pub struct TaskCancellator {
/// Join handle of the background task; an `Option` so it can be taken out and awaited.
join_handle: Option<JoinHandle<()>>,
/// Token that is cancelled to signal the background task to shut down.
token: CancellationToken,
}
impl TaskCancellator {
    /// Wrap the join handle of an already spawned task together with the token that stops it.
    pub fn new(join_handle: JoinHandle<()>, token: CancellationToken) -> Self {
        let join_handle = Some(join_handle);
        Self { join_handle, token }
    }

    /// Cancel the token and wait for the background task to finish.
    ///
    /// # Errors
    /// Returns the [`JoinError`] produced by awaiting the task's join handle.
    #[tracing::instrument(skip(self))]
    pub async fn cancel(mut self) -> Result<(), JoinError> {
        self.token.cancel();
        // The handle is taken out of the `Option` because `self` implements `Drop` and its
        // fields therefore cannot be moved out directly.
        match self.join_handle.take() {
            Some(handle) => handle.await,
            None => Ok(()),
        }
    }
}
/// Dropping a [`TaskCancellator`] cancels its token so the associated task is still asked to
/// stop even when `cancel` was never called.
impl Drop for TaskCancellator {
fn drop(&mut self) {
// Only signals cancellation; the join handle is not awaited here.
self.token.cancel();
}
}
/// HTTP client for talking to a Consul agent.
#[derive(Clone, Debug)]
pub struct ConsulClient {
/// Underlying HTTP client used for every request.
pub http_client: reqwest::Client,
/// Base address of the Consul HTTP API; `v1/...` endpoint paths are joined onto it.
api_base_url: Url,
/// Base URL of the KV store including the configured key prefix (`v1/kv/<prefix>/`).
pub kv_api_base_url: Url,
}
/// One entry of a Consul KV read response, deserialized from PascalCase JSON keys
/// (`CreateIndex`, `Flags`, `Key`, ...).
// NOTE(review): the per-field meanings are defined by Consul's KV API, not by this file —
// consult the Consul documentation for details.
#[derive(Debug, Eq, PartialEq, Hash, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct ConsulKvGet {
pub create_index: u64,
pub flags: u64,
/// Full key of the entry.
pub key: String,
pub lock_index: u64,
pub modify_index: u64,
/// Stored value; readers in this file decode it as standard base64.
pub value: String,
}
/// Value of the `Behavior` field sent when creating a Consul session; serialized in
/// lowercase (`"delete"`).
#[derive(Serialize)]
#[serde(rename_all = "lowercase")]
enum SessionInvalidationBehavior {
// NOTE(review): presumably makes Consul delete keys held by the session once it is
// invalidated — confirm against the Consul session API documentation.
Delete,
}
/// Duration of a Consul session TTL, serialized as a string such as `"15s"`.
#[derive(Copy, Clone)]
enum SessionDuration {
/// Duration in whole seconds.
Seconds(u32),
}
/// Converts a [`Duration`] into a [`SessionDuration`] with whole-second resolution.
impl TryFrom<Duration> for SessionDuration {
    type Error = anyhow::Error;

    /// Keeps only the whole seconds of `value`.
    ///
    /// # Errors
    /// Fails when `value` is longer than 86400 seconds (24 hours).
    fn try_from(value: Duration) -> Result<Self> {
        match value.as_secs() {
            secs if secs > 86400 => {
                bail!("Tried to convert a duration longer than 24 hours into SessionDuration")
            }
            // The cast cannot truncate: the arm above bounds `secs` to at most 86400.
            secs => Ok(SessionDuration::Seconds(secs as u32)),
        }
    }
}
/// Serializes a [`SessionDuration`] as a string made of the number of seconds followed by `s`.
impl Serialize for SessionDuration {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `Seconds` is the only variant, so this pattern is irrefutable.
        let Self::Seconds(s) = self;
        let rendered = format!("{s}s");
        serializer.serialize_str(&rendered)
    }
}
/// JSON request body for creating a Consul session; keys are serialized in PascalCase.
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct CreateSession {
/// Name given to the session.
name: String,
/// Behavior requested for when the session is invalidated.
behavior: SessionInvalidationBehavior,
/// Session TTL; renamed explicitly because PascalCase would otherwise produce `Ttl`.
#[serde(rename = "TTL")]
ttl: SessionDuration,
}
/// JSON response body of the session creation request.
#[derive(Deserialize)]
struct CreateSessionResponse {
/// ID of the newly created session, read from the `ID` key.
#[serde(rename = "ID")]
id: Uuid,
}
impl ConsulClient {
    /// Create a client for the Consul HTTP API reachable at `consul_address`.
    ///
    /// KV keys are addressed below `v1/kv/<consul_prefix>/`. When `consul_token` is given, it
    /// is sent as the `X-Consul-Token` header with every request made through this client.
    ///
    /// # Errors
    /// Fails if the KV base URL cannot be built, if the token is not a valid header value, or
    /// if the HTTP client cannot be constructed.
    pub fn new(
        consul_address: Url,
        consul_prefix: &str,
        consul_token: Option<&str>,
    ) -> Result<ConsulClient> {
        // A leading `/` would make `Url::join` treat the prefix as an absolute path and throw
        // away the `v1/kv/` segments, so it is stripped before joining.
        let consul_prefix = consul_prefix.trim_start_matches('/');
        let mut kv_api_base_url = consul_address.join("v1/")?.join("kv/")?;
        if !consul_prefix.is_empty() {
            // Make sure the prefix ends with a `/` so later `join` calls append to it instead
            // of replacing its last path segment.
            let consul_prefix = if consul_prefix.ends_with('/') {
                consul_prefix.to_string()
            } else {
                format!("{consul_prefix}/")
            };
            kv_api_base_url = kv_api_base_url.join(&consul_prefix)?;
        }

        let client_builder = reqwest::Client::builder();
        let client_builder = match consul_token {
            Some(secret_token) => {
                let mut token_value = HeaderValue::from_str(secret_token)?;
                // Keep the secret out of `Debug` output of the header map.
                token_value.set_sensitive(true);

                let mut headers = HeaderMap::new();
                // `HeaderName::from_static` panics on uppercase characters, so the name has to
                // be spelled in lowercase. Header names are case-insensitive on the wire.
                headers.insert(HeaderName::from_static("x-consul-token"), token_value);
                client_builder.default_headers(headers)
            }
            None => client_builder,
        };
        let client = client_builder.build()?;

        Ok(ConsulClient {
            http_client: client,
            api_base_url: consul_address,
            kv_api_base_url,
        })
    }

    /// Fetch the peers stored in every datacenter known to Consul.
    ///
    /// The datacenter list is read from `v1/catalog/datacenters`; the datacenters are then
    /// queried concurrently and their peers merged into one set.
    ///
    /// # Errors
    /// Fails if listing the datacenters fails or if any single datacenter query fails.
    #[tracing::instrument(skip(self))]
    pub async fn get_peers(&self) -> Result<HashSet<WgPeer>> {
        let dcs = self
            .http_client
            .get(self.api_base_url.join("v1/catalog/datacenters")?)
            .send()
            .await?
            .error_for_status()?
            .json::<Vec<String>>()
            .await?;

        let mut peers = HashSet::new();
        for dc_peers in join_all(dcs.iter().map(|dc| self.get_peers_for_dc(dc))).await {
            peers.extend(dc_peers?);
        }
        Ok(peers)
    }

    /// Fetch the peers stored below `peers/` in the KV store of datacenter `dc`.
    ///
    /// Every value is base64-decoded and then parsed as a JSON [`WgPeer`]. A `404 Not Found`
    /// response is treated as "no peers" and yields an empty set.
    ///
    /// # Errors
    /// Fails on request errors, on any other HTTP error status, and when a stored value is
    /// not valid base64 or not valid peer JSON.
    #[tracing::instrument(skip(self))]
    async fn get_peers_for_dc(&self, dc: &str) -> Result<HashSet<WgPeer>> {
        let mut peers_url = self.kv_api_base_url.join("peers/")?;
        peers_url
            .query_pairs_mut()
            .append_pair("recurse", "true")
            .append_pair("dc", dc)
            .append_pair("stale", "1");

        let resp = self
            .http_client
            .get(peers_url)
            .send()
            .await?
            .error_for_status();

        match resp {
            Ok(resp) => {
                let kv_get: HashSet<ConsulKvGet> = resp.json().await?;
                // The KV contents come from other nodes, so malformed entries are reported as
                // errors rather than panicking the whole process.
                kv_get
                    .into_iter()
                    .map(|entry| -> Result<WgPeer> {
                        let decoded = BASE64_STANDARD
                            .decode(&entry.value)
                            .with_context(|| {
                                format!("Can't decode base64 value of Consul key {}", entry.key)
                            })?;
                        serde_json::from_slice(&decoded).with_context(|| {
                            format!(
                                "Can't interpret JSON out of decoded base64 of Consul key {}",
                                entry.key
                            )
                        })
                    })
                    .collect::<Result<HashSet<_>>>()
            }
            Err(err) => {
                if err.status() == Some(StatusCode::NOT_FOUND) {
                    return Ok(HashSet::new());
                }
                Err(anyhow!(err))
            }
        }
    }

    /// Create a Consul session named after `public_key` and spawn a background task that
    /// keeps it alive.
    ///
    /// The session is created with the `delete` invalidation behavior and a TTL of
    /// `CONSUL_TTL`. `parent_token` is handed to the background task, which cancels it when
    /// renewing the session fails.
    ///
    /// # Errors
    /// Fails if `CONSUL_TTL` cannot be converted into a session TTL, if the create request
    /// fails, or if the URLs needed by the background task cannot be built.
    #[tracing::instrument(skip(self, parent_token))]
    pub async fn create_session(
        &self,
        public_key: Pubkey,
        parent_token: CancellationToken,
    ) -> Result<ConsulSession> {
        let url = self.api_base_url.join("v1/session/create")?;
        let res = self
            .http_client
            .put(url)
            .json(&CreateSession {
                name: format!("wiresmith-{}", public_key.to_base64_urlsafe()),
                behavior: SessionInvalidationBehavior::Delete,
                ttl: CONSUL_TTL.try_into()?,
            })
            .send()
            .await?
            .error_for_status()?
            .json::<CreateSessionResponse>()
            .await?;

        let session_token = CancellationToken::new();
        let join_handle = tokio::spawn(
            session_handler(self.clone(), session_token.clone(), parent_token, res.id)
                .context("failed to create Consul session handler")?,
        );
        trace!("Created Consul session with id {}", res.id);

        Ok(ConsulSession {
            client: self.clone(),
            id: res.id,
            cancellator: TaskCancellator::new(join_handle, session_token),
        })
    }
}
fn session_handler(
client: ConsulClient,
session_token: CancellationToken,
parent_token: CancellationToken,
session_id: Uuid,
) -> Result<impl Future<Output = ()> + Send> {
let session_id = session_id.to_string();
let renewal_url = client
.api_base_url
.join("v1/session/renew/")
.context("failed to build session renewal URL")?
.join(&session_id)
.context("failed to build session renewal URL")?;
let destroy_url = client
.api_base_url
.join("v1/session/destroy/")
.context("failed to build session destroy URL")?
.join(&session_id)
.context("failed to build session destroy URL")?;
Ok(async move {
let mut interval = interval(CONSUL_TTL / 2);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = session_token.cancelled() => {
trace!("Consul session handler was cancelled");
break;
},
_ = interval.tick() => {},
};
trace!("Renewing Consul session");
let res = client
.http_client
.put(renewal_url.clone())
.send()
.await
.and_then(|res| res.error_for_status());
if let Err(err) = res {
error!("Renewing Consul session failed, aborting: {err:?}");
parent_token.cancel();
return;
}
}
trace!("Destroying Consul session");
let res = client
.http_client
.put(destroy_url)
.send()
.await
.and_then(|res| res.error_for_status());
if let Err(err) = res {
warn!("Destroying Consul session failed: {err:?}");
}
})
}
/// A Consul session together with the background task that keeps it alive.
pub struct ConsulSession {
/// Client used for requests made on behalf of this session.
client: ConsulClient,
/// ID of the session as returned by Consul on creation.
id: Uuid,
/// Handle used to stop the session's background task.
cancellator: TaskCancellator,
}
impl ConsulSession {
    /// Stop the session's background task and wait for it to finish.
    ///
    /// # Errors
    /// Returns the [`JoinError`] produced by joining the background task.
    #[tracing::instrument(skip(self))]
    pub async fn cancel(self) -> Result<(), JoinError> {
        let Self { cancellator, .. } = self;
        cancellator.cancel().await
    }

    /// Write `wgpeer` to `peers/<public key>` in the KV store, acquiring the key's lock for
    /// this session, and spawn a background task that watches the key afterwards.
    ///
    /// `parent_token` is handed to the watcher task. The returned [`TaskCancellator`] stops
    /// the watcher.
    ///
    /// # Errors
    /// Fails if the URL cannot be built, if the request fails, if the response cannot be
    /// parsed, or if Consul reports that the lock was not acquired.
    #[tracing::instrument(skip(self, wgpeer, parent_token))]
    pub async fn put_config(
        &self,
        wgpeer: &WgPeer,
        parent_token: CancellationToken,
    ) -> Result<TaskCancellator> {
        let key_name = wgpeer.public_key.to_base64_urlsafe();
        let peer_url = self
            .client
            .kv_api_base_url
            .join("peers/")?
            .join(&key_name)?;

        // The write goes to the same key URL with `acquire=<session id>` appended; the plain
        // URL is kept for the watcher task.
        let mut acquire_url = peer_url.clone();
        acquire_url
            .query_pairs_mut()
            .append_pair("acquire", &self.id.to_string());

        let response = self
            .client
            .http_client
            .put(acquire_url)
            .json(wgpeer)
            .send()
            .await?;
        let response = response
            .error_for_status()
            .context("failed to put node config into Consul")?;
        let got_lock: bool = response
            .json()
            .await
            .context("Failed to parse Consul KV put response")?;
        if !got_lock {
            bail!("Did not get Consul lock for node config");
        }
        info!("Wrote node config into Consul");

        let config_token = CancellationToken::new();
        let watcher = config_handler(
            self.client.clone(),
            self.id,
            peer_url,
            config_token.clone(),
            parent_token,
        );
        let join_handle = tokio::spawn(watcher);
        Ok(TaskCancellator::new(join_handle, config_token))
    }
}
/// Background task that watches the node's own config key at `peer_url`.
///
/// After an initial 50 ms delay it checks at most once per second that the key is still
/// locked by `session_id`. It cancels `parent_token` and stops when the key is locked by a
/// different session or when five consecutive checks have failed. It stops without touching
/// `parent_token` when `config_token` is cancelled.
async fn config_handler(
    client: ConsulClient,
    session_id: Uuid,
    peer_url: Url,
    config_token: CancellationToken,
    parent_token: CancellationToken,
) {
    tokio::time::sleep(Duration::from_millis(50)).await;

    let mut index = None;
    let mut failed_fetches = 0;
    let mut ticker = interval(Duration::from_secs(1));
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    loop {
        tokio::select! {
            _ = config_token.cancelled() => {
                trace!("Consul config handler was cancelled");
                break;
            },
            _ = ticker.tick() => {},
        };

        // The fetch itself is raced against cancellation as well, so a slow request cannot
        // delay shutdown.
        let fetched = tokio::select! {
            _ = config_token.cancelled() => {
                trace!("Consul config handler was cancelled");
                break;
            },
            res = ensure_config_exists(&client, peer_url.clone(), &mut index) => res,
        };

        let owner_id = match fetched {
            Ok(owner_id) => owner_id,
            Err(err) => {
                failed_fetches += 1;
                if failed_fetches >= 5 {
                    error!("Failed to fetch own node config {failed_fetches} times, cancelling");
                    parent_token.cancel();
                    break;
                }
                error!("Could not get own node config from Consul ({failed_fetches} failed fetches): {err:?}");
                continue;
            }
        };

        // Any successful fetch resets the failure streak.
        failed_fetches = 0;
        if owner_id != session_id {
            error!(
                "Consul key is locked by {owner_id}, expected it to be us ({session_id})"
            );
            parent_token.cancel();
            break;
        }
        trace!("Successfully fetched own node config from Consul");
    }
}
/// The part of a Consul KV read response entry that this file needs.
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ReadKeyResponse {
/// ID of the session holding the key's lock, read from the `Session` key; `None` when the
/// field is absent or null.
session: Option<Uuid>,
}
/// Read the key at `peer_url` and return the ID of the session that holds its lock.
///
/// When `index` holds a value, it is sent as the `index` query parameter. If the response
/// carries an `X-Consul-Index` header, `index` is updated with its value.
///
/// # Errors
/// Fails if the request fails or returns an error status, if the index header is not a valid
/// string, if the body cannot be parsed, if the response array is empty, or if the key is
/// not locked by any session.
#[tracing::instrument(skip(client))]
async fn ensure_config_exists(
    client: &ConsulClient,
    peer_url: Url,
    index: &mut Option<String>,
) -> Result<Uuid> {
    // Zero or one `index=<value>` pair, depending on whether an index is known already.
    let index_param: Vec<(&str, &String)> = index.iter().map(|known| ("index", known)).collect();
    let response = client
        .http_client
        .get(peer_url)
        .query(&index_param)
        .send()
        .await?
        .error_for_status()?;

    if let Some(header) = response.headers().get("X-Consul-Index") {
        let new_index = header
            .to_str()
            .context("Failed to convert new Consul index to String")?;
        *index = Some(new_index.to_string());
    }

    let entries: Vec<ReadKeyResponse> = response
        .json()
        .await
        .context("Failed to parse KV response")?;
    let entry = entries
        .first()
        .context("Consul unexpectedly returned an empty array")?;
    entry
        .session
        .ok_or_else(|| anyhow!("Key was not locked by any session"))
}