ts_runtime 0.2.0

Tailscale runtime
Documentation
use std::{
    collections::HashMap,
    sync::Arc,
    time::{Duration, Instant},
};

use kameo::{
    actor::ActorRef,
    error::SendError,
    message::{Context, Message},
};
use tokio::{sync::watch, task::JoinSet};
use ts_control::DerpRegion;
use ts_keys::NodeKeyPair;
use ts_transport::{UnderlayTransport, UnderlayTransportId};
use ts_transport_derp::RegionId;

use crate::{
    Env, Error,
    dataplane::{DataplaneActor, NewUnderlayTransport, UnderlayFromDataplane, UnderlayToDataplane},
    derp_latency::DerpLatencyMeasurement,
};

/// Consumes derp map updates and spawns a task per region that runs an underlay transport.
/// Also consumes home derp indications (for this node) to notify the relevant task that it
/// should keep the transport awake even if there is no traffic.
///
/// Other than the home task (which is always kept alive to receive packets), the transport
/// tasks keep the connection alive as long as there is traffic sent or received, and for a
/// short grace period afterward. Connections are otherwise closed when not in use.
pub struct Multiderp {
    /// Runtime environment: supplies this node's keys, the shutdown signal, and message
    /// subscriptions.
    env: Env,
    /// Dataplane actor, asked for a new underlay transport whenever a region is started.
    dataplane: ActorRef<DataplaneActor>,
    /// One entry per region whose transport task has been spawned.
    derps: HashMap<RegionId, RegionEntry>,
    /// Region most recently selected as home from derp latency measurements, if any. A task
    /// for this region may not have been spawned yet.
    current_home_derp: Option<RegionId>,
    /// Per-region transport tasks spawned by `ensure_region`.
    tasks: JoinSet<()>,
}

/// Bookkeeping for a derp region whose transport task has been spawned.
struct RegionEntry {
    /// Underlay transport id the dataplane allocated for this region.
    transport_id: UnderlayTransportId,
    /// Tells the region's task whether it is currently the home region. While `true`, the
    /// task keeps its derp connection open even without traffic.
    home_derp: watch::Sender<bool>,
}

impl Multiderp {
    #[tracing::instrument(skip_all, fields(region_id = %id))]
    async fn ensure_region(
        &mut self,
        id: RegionId,
        region: &DerpRegion,
        mut shutdown: watch::Receiver<bool>,
    ) {
        // TODO(npry): update if region info changes

        if self.derps.contains_key(&id) {
            tracing::trace!("region already existed");
            return;
        }

        let region = region.clone();
        let keys = self.env.keys.node_keys;

        let (transport_id, mut up, down) = match self.dataplane.ask(NewUnderlayTransport).await {
            Ok(val) => val,
            Err(SendError::ActorNotRunning(..) | SendError::ActorStopped) => {
                if !*shutdown.borrow() {
                    panic!("dataplane has stopped but we're not shutting down");
                }

                return;
            }
            Err(e) => unreachable!("{}", e),
        };
        let (home_derp_tx, mut home_derp_rx) = watch::channel(false);

        self.tasks.spawn(async move {
            while !*shutdown.borrow() {
                tokio::select! {
                    _ = shutdown.changed() => {
                        break;
                    },
                    ret = run_derp_once(
                        id,
                        &region,
                        keys,
                        &down,
                        &mut up,
                        &mut home_derp_rx,
                    ) => if let Err(e) = ret {
                        tracing::error!(error = %e, region_id = %id, "running derp client");
                        tokio::time::sleep(Duration::from_millis(500)).await;
                    },
                }

                if up.is_closed() {
                    tracing::warn!(region_id = %id, "underlay up channel closed!");
                    break;
                }

                if down.is_closed() {
                    tracing::warn!(region_id = %id, "underlay down channel closed!");
                    break;
                }
            }
        });

        self.derps.insert(
            id,
            RegionEntry {
                transport_id,
                home_derp: home_derp_tx,
            },
        );
    }
}

#[kameo::messages]
impl Multiderp {
    /// Looks up the underlay transport id assigned to derp region `id`.
    ///
    /// Returns `None` if no transport task has been spawned for that region.
    #[message]
    pub fn transport_id_for_region(&self, id: RegionId) -> Option<UnderlayTransportId> {
        self.derps.get(&id).map(|entry| entry.transport_id)
    }
}

/// Runs the derp client for one region until it fails or a dataplane channel closes.
///
/// The connection is opened lazily. While this isn't the home region, the function waits
/// until either an outbound packet arrives from the dataplane or `home_derp_rx` flips to
/// `true`. Once connected, it forwards packets in both directions. A non-home connection is
/// dropped after `INACTIVITY_TIMEOUT` without traffic and re-opened on demand. The home
/// connection is never dropped for inactivity.
///
/// # Errors
///
/// Propagates any error from connecting to, sending to, or receiving from the derp server.
///
/// Returns `Ok(())` once either dataplane channel is closed. Reconnecting in that state
/// would only hit the same closed channel again, so the caller is left to observe the
/// closure and stop.
#[tracing::instrument(skip_all, fields(region_id = %id), name = "derp packet transport")]
async fn run_derp_once(
    id: RegionId,
    region: &DerpRegion,
    keys: NodeKeyPair,
    to_dataplane: &UnderlayToDataplane,
    from_dataplane: &mut UnderlayFromDataplane,
    home_derp_rx: &mut watch::Receiver<bool>,
) -> Result<(), ts_transport_derp::Error> {
    const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(10);

    loop {
        let mut pending = None;

        tracing::trace!("waiting for packet activity or for this to become home derp");

        while !*home_derp_rx.borrow_and_update() {
            tokio::select! {
                _ = home_derp_rx.changed() => {
                    tracing::trace!(is_home_derp = *home_derp_rx.borrow());
                },

                from_net = from_dataplane.recv() => {
                    // A closed queue must not count as activity. Otherwise we would connect,
                    // immediately hit the same closed queue below, and reconnect forever.
                    let Some(from_net) = from_net else {
                        tracing::warn!("transport queue closed");
                        return Ok(());
                    };

                    tracing::trace!("received packet to send");
                    pending = Some(from_net);
                    break;
                }
            }
        }

        tracing::trace!("establishing derp connection");

        let client = ts_transport_derp::DefaultClient::connect(&region.servers, &keys).await?;

        if let Some(pending) = pending {
            tracing::trace!("sending queued packet");
            client.send([pending]).await?;
        }

        let mut last_activity = Instant::now();

        loop {
            let span = tracing::trace_span!("derp_loop");

            // Only non-home connections are subject to the inactivity timeout.
            let inactivity_timeout =
                (!*home_derp_rx.borrow()).then(|| last_activity + INACTIVITY_TIMEOUT);

            tokio::select! {
                from_derp = client.recv_one() => {
                    last_activity = Instant::now();

                    let (peer, pkt) = from_derp?;
                    tracing::trace!(parent: &span, %peer, len = pkt.len(), "packet from derp server");

                    let Ok(()) = to_dataplane.send((peer, vec![pkt])) else {
                        tracing::error!(parent: &span, "underlay receive channel closed");
                        // Nowhere to deliver packets; let the caller tear the task down.
                        return Ok(());
                    };
                },

                from_net = from_dataplane.recv() => {
                    last_activity = Instant::now();

                    let Some(from_net) = from_net else {
                        tracing::warn!(parent: &span, "transport queue closed");
                        // Returning (rather than breaking into a reconnect) lets the caller
                        // observe the closed channel and stop.
                        return Ok(());
                    };

                    tracing::trace!(parent: &span, peer = %from_net.0, packets = from_net.1.len(), "packets to derp server");

                    client.send([from_net]).await?;
                },

                _ = option_timeout(inactivity_timeout) => {
                    if !*home_derp_rx.borrow_and_update() {
                        tracing::trace!(parent: &span, "timed out and not home derp, closing derp conn");
                        break;
                    }
                },

                _ = home_derp_rx.changed() => {
                    tracing::trace!(is_home_derp = *home_derp_rx.borrow());
                },
            }
        }
    }
}

/// Sleeps until `deadline` when one is given; with `None`, never completes.
///
/// Intended as a `select!` branch, so a disabled timeout simply never fires.
async fn option_timeout(deadline: Option<Instant>) {
    if let Some(deadline) = deadline {
        tokio::time::sleep_until(deadline.into()).await;
    } else {
        core::future::pending::<()>().await;
    }
}

impl kameo::Actor for Multiderp {
    type Args = (Env, ActorRef<DataplaneActor>);
    type Error = Error;

    async fn on_start(
        (env, dataplane): Self::Args,
        slf: ActorRef<Self>,
    ) -> Result<Self, Self::Error> {
        env.subscribe::<Arc<ts_control::StateUpdate>>(&slf).await?;
        env.subscribe::<DerpLatencyMeasurement>(&slf).await?;

        Ok(Self {
            env,
            dataplane,
            derps: Default::default(),
            tasks: JoinSet::new(),
            current_home_derp: None,
        })
    }
}

impl Message<Arc<ts_control::StateUpdate>> for Multiderp {
    type Reply = ();

    /// Ensures a transport task exists for every region in the update's derp map.
    ///
    /// Updates without a derp map are ignored. Regions that are already running are left
    /// as-is, and regions absent from the new map are not torn down.
    #[tracing::instrument(skip_all, name = "multiderp map update")]
    async fn handle(
        &mut self,
        msg: Arc<ts_control::StateUpdate>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) {
        let Some(derp_map) = &msg.derp else {
            return;
        };

        for (id, region) in derp_map {
            self.ensure_region(*id, region, self.env.shutdown.clone())
                .await;

            // If this is the home region and it was just started, it needs to be notified
            // that it's the home region. `ensure_region` returns without creating an entry
            // when the dataplane has stopped during shutdown. The entry is therefore looked up
            // rather than unwrapped.
            if self.current_home_derp == Some(*id)
                && let Some(entry) = self.derps.get(id)
            {
                entry.home_derp.send_replace(true);
            }
        }
    }
}

impl Message<DerpLatencyMeasurement> for Multiderp {
    type Reply = ();

    async fn handle(&mut self, msg: DerpLatencyMeasurement, _ctx: &mut Context<Self, Self::Reply>) {
        let Some(result) = msg.measurement.as_ref().first() else {
            tracing::trace!("received home derp measurement message but none was set");
            return;
        };

        if let Some(home_derp) = self.current_home_derp {
            self.derps
                .get_mut(&home_derp)
                .unwrap()
                .home_derp
                .send_replace(false);
        }

        if self.current_home_derp.is_none_or(|id| id != result.id) {
            self.current_home_derp = Some(result.id);
            if let Some(derp) = self.derps.get_mut(&result.id) {
                derp.home_derp.send_replace(true);
            }

            tracing::info!(
                region_id = %result.id,
                latency_ms = result.latency.as_secs_f32() * 1000.,
                "new home derp region selected"
            );
        }
    }
}