// plane 0.5.5
//
// Session backend orchestrator for ambitious browser-based apps.
// Documentation
use chrono::{DateTime, Utc};
use plane_common::names::{ControllerName, DroneName};
use plane_common::types::{BackendStatus, ClusterName, DronePoolName, NodeId};
use sqlx::{postgres::types::PgInterval, query, PgPool};
use std::str::FromStr;
use std::time::Duration;

use crate::heartbeat_consts::UNHEALTHY_SECONDS;

/// Thin query layer over the `drone` and `node` tables.
///
/// Borrows a Postgres connection pool for the lifetime `'a`; every method
/// runs its statement directly against that pool.
pub struct DroneDatabase<'a> {
    /// Borrowed connection pool used to execute every query.
    pool: &'a PgPool,
}

impl<'a> DroneDatabase<'a> {
    pub fn new(pool: &'a PgPool) -> Self {
        Self { pool }
    }

    /// Inserts a row for drone `id` into the `drone` table, or refreshes the existing one.
    ///
    /// A new row is created with `draining = false` plus the given `ready` flag and `pool`.
    /// If a row with this `id` already exists, only `ready` is overwritten; the stored
    /// `draining` and `pool` values are left untouched.
    ///
    /// # Errors
    ///
    /// Returns any error sqlx reports while executing the statement.
    pub async fn register_drone(
        &self,
        id: NodeId,
        ready: bool,
        pool: DronePoolName,
    ) -> sqlx::Result<()> {
        let drone_id = id.as_i32();
        let pool_name = pool.to_string();

        query!(
            r#"
            insert into drone (id, draining, ready, pool)
            values ($1, false, $2, $3)
            on conflict (id) do update set
                ready = $2
            "#,
            drone_id,
            ready,
            pool_name,
        )
        .execute(self.pool)
        .await
        .map(|_| ())
    }

    pub async fn drain(&self, id: NodeId) -> sqlx::Result<bool> {
        let result = query!(
            r#"
            update drone
            set draining = true
            where id = $1
            returning (
                select draining
                from drone
                where id = $1
            ) as "was_draining!"
            "#,
            id.as_i32(),
        )
        .fetch_optional(self.pool)
        .await?;

        if let Some(was_draining) = result {
            Ok(!was_draining.was_draining)
        } else {
            Err(sqlx::Error::RowNotFound)
        }
    }

    /// Records a heartbeat for drone `id`.
    ///
    /// Sets `last_heartbeat` to the database's `now()` and stores the supplied
    /// `local_time` in `last_local_time`. The number of affected rows is not
    /// inspected, so an unknown `id` still yields `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns any error sqlx reports while executing the statement.
    pub async fn heartbeat(&self, id: NodeId, local_time: DateTime<Utc>) -> sqlx::Result<()> {
        let drone_id = id.as_i32();

        query!(
            r#"
            update drone
            set last_heartbeat = now(), last_local_time = $2
            where id = $1
            "#,
            drone_id,
            local_time,
        )
        .execute(self.pool)
        .await
        .map(|_| ())
    }

    /// Looks up the pool that drone `id` is registered in.
    ///
    /// # Errors
    ///
    /// Because the row is fetched with `fetch_one`, a missing drone surfaces as an
    /// sqlx error rather than an `Option`; other sqlx errors are passed through.
    pub async fn get_drone_pool(&self, id: NodeId) -> sqlx::Result<DronePoolName> {
        let drone_id = id.as_i32();

        let row = query!(
            r#"
            select pool
            from drone
            where id = $1
            "#,
            drone_id,
        )
        .fetch_one(self.pool)
        .await?;

        let pool_name: DronePoolName = row.pool.into();
        Ok(pool_name)
    }

    /// Returns a list of drones that have been seen in the last `seen_in_last` duration.
    /// Limits the result to the most recent 100 drones.
    pub async fn get_drones_for_pool(
        &self,
        cluster: &ClusterName,
        pool: &DronePoolName,
        seen_in_last: Duration,
    ) -> sqlx::Result<Vec<DroneWithMetadata>> {
        // Converting to a PgInterval fails if the duration has nanosecond granularity.
        // So let's round up to the nearest microsecond.
        let seen_in_last = if seen_in_last.subsec_nanos() % 1000 == 0 {
            seen_in_last
        } else {
            let nanos = (seen_in_last.subsec_micros() + 1) * 1000;
            Duration::new(seen_in_last.as_secs(), nanos)
        };

        let Ok(seen_in_last) = PgInterval::try_from(seen_in_last) else {
            return Err(sqlx::Error::Protocol("invalid interval".to_string()));
        };
        let result = query!(
            r#"
            select
                drone.id as id,
                drone.ready as ready,
                drone.draining as draining,
                drone.last_heartbeat as "last_heartbeat!",
                drone.last_local_time as "last_local_time!",
                drone.pool as pool,
                node.name as name,
                node.cluster as "cluster!",
                node.plane_version as plane_version,
                node.plane_hash as plane_hash,
                node.controller as controller,
                node.last_connection_start_time as "last_connection_start_time!"
            from node
            left join drone on node.id = drone.id
            where
                cluster = $1
                and now() - drone.last_heartbeat < $2
                and pool = $3
                and last_local_time is not null
                and last_connection_start_time is not null
            order by drone.id desc
            limit 100
            "#,
            cluster.to_string(),
            seen_in_last,
            pool.to_string(),
        )
        .fetch_all(self.pool)
        .await?;

        let drones: Vec<DroneWithMetadata> = result
            .into_iter()
            .map(|r| DroneWithMetadata {
                id: NodeId::from(r.id),
                name: DroneName::try_from(r.name).expect("valid drone name"),
                ready: r.ready,
                draining: r.draining,
                last_heartbeat: r.last_heartbeat,
                last_local_time: r.last_local_time,
                pool: r.pool.into(),
                cluster: ClusterName::from_str(&r.cluster).expect("valid cluster name"),
                plane_version: r.plane_version,
                plane_hash: r.plane_hash,
                controller: r
                    .controller
                    .map(|c| ControllerName::try_from(c).expect("valid controller name")),
                last_connection_start_time: r.last_connection_start_time,
            })
            .collect();

        Ok(drones)
    }

    /// TODO: simple algorithm until we collect more metrics.
    pub async fn pick_drone_for_spawn(
        &self,
        cluster: &ClusterName,
        pool: &DronePoolName,
    ) -> sqlx::Result<Option<DroneForSpawn>> {
        let result = query!(
            r#"
            select
                drone.id,
                node.name,
                drone.last_local_time as "last_local_time!"
            from node
            left join drone
                on node.id = drone.id
            left join controller
                on node.controller = controller.id
            where
                drone.ready = true
                and controller is not null
                and cluster = $1
                and now() - drone.last_heartbeat < $2
                and now() - controller.last_heartbeat < $2
                and controller.is_online = true
                and draining = false
                and last_local_time is not null
                and pool = $3
            order by (
                select
                    count(*)
                from backend
                where drone_id = node.id
                and last_status != $4
            ) asc, random()
            limit 1
            "#,
            cluster.to_string(),
            PgInterval::try_from(Duration::from_secs(UNHEALTHY_SECONDS as _))
                .expect("valid interval"),
            pool.to_string(),
            BackendStatus::Terminated.to_string(),
        )
        .fetch_optional(self.pool)
        .await?;

        let result = match result {
            Some(result) => {
                let id = NodeId::from(result.id);
                let drone = DroneName::try_from(result.name).expect("valid drone name");
                let last_local_time = result.last_local_time;

                Some(DroneForSpawn {
                    id,
                    drone,
                    last_local_time,
                })
            }
            None => return Ok(None),
        };

        Ok(result)
    }
}

/// The drone selected by `DroneDatabase::pick_drone_for_spawn`.
pub struct DroneForSpawn {
    /// Id of the selected drone (`drone.id`).
    pub id: NodeId,
    /// Name of the selected drone, parsed from `node.name`.
    pub drone: DroneName,
    /// Value of `drone.last_local_time`, i.e. the local time the drone supplied
    /// with its most recent heartbeat.
    pub last_local_time: DateTime<Utc>,
}

/// One row of `DroneDatabase::get_drones_for_pool`: a drone joined with its
/// `node` record.
pub struct DroneWithMetadata {
    /// Drone id (`drone.id`).
    pub id: NodeId,
    /// Drone name, parsed from `node.name`.
    pub name: DroneName,
    /// Value of `drone.ready`.
    pub ready: bool,
    /// Value of `drone.draining`.
    pub draining: bool,
    /// Database time of the most recent heartbeat (`drone.last_heartbeat`).
    pub last_heartbeat: DateTime<Utc>,
    /// Local time the drone supplied with its most recent heartbeat
    /// (`drone.last_local_time`).
    pub last_local_time: DateTime<Utc>,
    /// Pool the drone is registered in (`drone.pool`).
    pub pool: DronePoolName,
    /// Cluster the node belongs to, parsed from `node.cluster`.
    pub cluster: ClusterName,
    /// Value of `node.plane_version`.
    pub plane_version: String,
    /// Value of `node.plane_hash`.
    pub plane_hash: String,
    /// Controller referenced by `node.controller`, if any.
    pub controller: Option<ControllerName>,
    /// Value of `node.last_connection_start_time`.
    pub last_connection_start_time: DateTime<Utc>,
}