/// Observable registration state of a device.
///
/// `wait_for_running` keeps waiting only while the state is
/// [`DeviceState::Connecting`]; every other variant ends the wait.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeviceState {
    /// Not yet settled; waiters keep watching for the next update.
    Connecting,
    /// The device is up; waiters resolve successfully.
    Running,
    /// Control asked for an interactive login at the contained URL.
    NeedsLogin(url::Url),
    /// The node key expired (surfaced to waiters as `RegistrationError::KeyExpired`).
    Expired,
    /// Registration ended with the contained error.
    Failed(RegistrationError),
}
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum RegistrationError {
#[error("authentication rejected by control: {0}")]
AuthRejected(String),
#[error("node key expired; re-authentication required")]
KeyExpired,
#[error("interactive login required at {0}")]
NeedsLogin(url::Url),
#[error("control plane unreachable")]
NetworkUnreachable,
#[error("timed out waiting for the device to finish registering")]
Timeout,
}
impl RegistrationError {
pub fn is_permanent(&self) -> bool {
matches!(
self,
RegistrationError::AuthRejected(_) | RegistrationError::KeyExpired
)
}
}
impl From<&ts_control::Error> for RegistrationError {
fn from(e: &ts_control::Error) -> Self {
match e {
ts_control::Error::MachineNotAuthorized(u) => RegistrationError::NeedsLogin(u.clone()),
ts_control::Error::Registration(reason) => {
RegistrationError::AuthRejected(reason.clone())
}
ts_control::Error::NetworkError(_) => RegistrationError::NetworkUnreachable,
other => RegistrationError::AuthRejected(other.to_string()),
}
}
}
/// Waits until the state published on `rx` leaves [`DeviceState::Connecting`].
///
/// The current value is checked first, so an already-settled channel resolves
/// without waiting for a new update.
///
/// # Errors
///
/// - [`DeviceState::Failed`] yields a clone of the contained error.
/// - [`DeviceState::Expired`] yields [`RegistrationError::KeyExpired`].
/// - [`DeviceState::NeedsLogin`] yields [`RegistrationError::NeedsLogin`] with the same URL.
/// - [`RegistrationError::NetworkUnreachable`] is returned if the sender side of `rx`
///   is dropped while the state is still `Connecting`.
/// - [`RegistrationError::Timeout`] is returned if `timeout` is `Some` and elapses first.
///   With `None` the wait is unbounded.
pub(crate) async fn wait_for_running(
    mut rx: tokio::sync::watch::Receiver<DeviceState>,
    timeout: Option<core::time::Duration>,
) -> Result<(), RegistrationError> {
    /// Maps a state to its final outcome, or `None` while still connecting.
    fn outcome(state: &DeviceState) -> Option<Result<(), RegistrationError>> {
        let result = match state {
            DeviceState::Connecting => return None,
            DeviceState::Running => Ok(()),
            DeviceState::Expired => Err(RegistrationError::KeyExpired),
            DeviceState::NeedsLogin(login) => Err(RegistrationError::NeedsLogin(login.clone())),
            DeviceState::Failed(err) => Err(err.clone()),
        };
        Some(result)
    }

    let settle = async move {
        loop {
            // Evaluate in its own statement so the watch read guard is released
            // before the `.await` below.
            let current = outcome(&rx.borrow_and_update());
            if let Some(result) = current {
                break result;
            }
            // `changed` errors once the sender has been dropped; no further state
            // can arrive after that.
            if rx.changed().await.is_err() {
                break Err(RegistrationError::NetworkUnreachable);
            }
        }
    };

    if let Some(limit) = timeout {
        match tokio::time::timeout(limit, settle).await {
            Ok(result) => result,
            Err(_elapsed) => Err(RegistrationError::Timeout),
        }
    } else {
        settle.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::time::Duration;
    use tokio::sync::watch;

    /// Upper bound for waits that are expected to settle well before it.
    const SETTLE_WITHIN: Duration = Duration::from_secs(1);

    /// Builds a login URL under the fixed test host.
    fn login_url(path: &str) -> url::Url {
        format!("https://login.example/{}", path).parse().unwrap()
    }

    #[test]
    fn permanence_classification() {
        let permanent = [
            RegistrationError::AuthRejected("bad key".into()),
            RegistrationError::KeyExpired,
        ];
        let transient = [
            RegistrationError::NeedsLogin(login_url("x")),
            RegistrationError::NetworkUnreachable,
            RegistrationError::Timeout,
        ];
        for err in &permanent {
            assert!(err.is_permanent(), "{:?} should be permanent", err);
        }
        for err in &transient {
            assert!(!err.is_permanent(), "{:?} should not be permanent", err);
        }
    }

    #[test]
    fn maps_control_error_variants() {
        let url = login_url("a");
        let cases = [
            (
                ts_control::Error::MachineNotAuthorized(url.clone()),
                RegistrationError::NeedsLogin(url),
            ),
            (
                ts_control::Error::Registration("bad auth key".into()),
                RegistrationError::AuthRejected("bad auth key".into()),
            ),
            (
                ts_control::Error::NetworkError(ts_control::Operation::Registration),
                RegistrationError::NetworkUnreachable,
            ),
        ];
        for (control_err, expected) in &cases {
            assert_eq!(&RegistrationError::from(control_err), expected);
        }
    }

    #[tokio::test]
    async fn wait_resolves_when_already_running() {
        let (_tx, rx) = watch::channel(DeviceState::Running);
        let outcome = wait_for_running(rx, Some(SETTLE_WITHIN)).await;
        assert_eq!(outcome, Ok(()));
    }

    #[tokio::test]
    async fn wait_resolves_on_transition_to_running() {
        let (tx, rx) = watch::channel(DeviceState::Connecting);
        let promote = async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            tx.send_replace(DeviceState::Running);
        };
        tokio::spawn(promote);
        let outcome = wait_for_running(rx, Some(SETTLE_WITHIN)).await;
        assert_eq!(outcome, Ok(()));
    }

    #[tokio::test]
    async fn wait_maps_each_settled_failure() {
        let cases = [
            (
                DeviceState::Failed(RegistrationError::AuthRejected("bad".into())),
                RegistrationError::AuthRejected("bad".into()),
            ),
            (DeviceState::Expired, RegistrationError::KeyExpired),
            (
                DeviceState::NeedsLogin(login_url("x")),
                RegistrationError::NeedsLogin(login_url("x")),
            ),
        ];
        for (state, expected) in cases {
            // Keep the sender alive so only the initial state can decide the outcome.
            let (_tx, rx) = watch::channel(state);
            let outcome = wait_for_running(rx, Some(SETTLE_WITHIN)).await;
            assert_eq!(outcome, Err(expected));
        }
    }

    #[tokio::test]
    async fn wait_times_out_while_connecting() {
        // Keep the sender alive; dropping it would end the wait as NetworkUnreachable.
        let (_tx, rx) = watch::channel(DeviceState::Connecting);
        let outcome = wait_for_running(rx, Some(Duration::from_millis(30))).await;
        assert_eq!(outcome, Err(RegistrationError::Timeout));
    }

    #[tokio::test]
    async fn wait_sender_dropped_is_network_unreachable() {
        let (tx, rx) = watch::channel(DeviceState::Connecting);
        drop(tx);
        let outcome = wait_for_running(rx, Some(SETTLE_WITHIN)).await;
        assert_eq!(outcome, Err(RegistrationError::NetworkUnreachable));
    }
}