//! ostool-server 0.1.2
//!
//! Server for managing development boards, serial sessions, and TFTP artifacts.
use std::time::Duration;

use anyhow::Context;
use axum::extract::ws::{Message, WebSocket};
use base64::Engine;
use futures_util::{Sink, SinkExt, StreamExt};
use serde::Deserialize;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::task::JoinHandle;
use tokio_serial::SerialPortBuilderExt;

use crate::{
    config::BoardConfig,
    power::{PowerAction, PowerActionError, execute_power_action_for_board},
    session::SessionState,
    state::AppState,
};

/// JSON control frame sent by the client as a websocket text message.
///
/// `run_serial_ws_inner` deserializes every inbound text frame into this
/// shape and dispatches on `kind`.
#[derive(Debug, Deserialize)]
struct ClientControlMessage {
    /// Message discriminator, read from the JSON `type` field. The relay loop
    /// handles `"close"` and `"tx"`; any other value is rejected as an error.
    #[serde(rename = "type")]
    kind: String,
    /// Encoding of `data` for `"tx"` messages: `"base64"`, `"utf8"`, or
    /// absent (treated the same as `"utf8"`).
    encoding: Option<String>,
    /// Payload to write to the serial port; required for `"tx"` messages.
    data: Option<String>,
}

pub async fn run_serial_ws(
    socket: WebSocket,
    state: AppState,
    session: std::sync::Arc<SessionState>,
) {
    let result = run_serial_ws_inner(socket, &state, session.clone()).await;
    session.clear_serial_connected();
    if let Err(err) = result {
        log::warn!("serial websocket ended with error: {err:#}");
    }
}

async fn run_serial_ws_inner(
    socket: WebSocket,
    _state: &AppState,
    session: std::sync::Arc<SessionState>,
) -> anyhow::Result<()> {
    let session_id = session.snapshot().await.id;
    let board = session.board().clone();
    let serial = board
        .serial
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("board has no serial configuration"))?;
    let port = tokio_serial::new(&serial.port, serial.baud_rate)
        .timeout(Duration::from_millis(200))
        .open_native_async()
        .with_context(|| format!("failed to open serial port {}", serial.port))?;

    let (mut ws_sender, mut ws_receiver) = socket.split();
    let (mut serial_rx, mut serial_tx) = tokio::io::split(port);
    let mut serial_buffer = [0u8; 1024];
    let mut power_on_task = Some(spawn_power_action_task(board.clone(), PowerAction::On));
    let power_linked = true;
    let mut shutdown_rx = session.subscribe_shutdown();

    ws_sender
        .send(Message::Text(r#"{"type":"opened"}"#.to_string().into()))
        .await
        .ok();
    loop {
        if let Some(task) = power_on_task.as_mut() {
            tokio::select! {
                power_result = task => {
                    power_on_task = None;
                    match power_result {
                        Ok(Ok(_)) => {}
                        Ok(Err(err)) => {
                            let message = format!("automatic power-on failed: {err}");
                            log::warn!("session `{session_id}` {message}");
                            send_power_on_failure_and_close(&mut ws_sender, &message).await;
                            break;
                        }
                        Err(err) => {
                            let message = format!("automatic power-on task join failed: {err}");
                            log::warn!("session `{session_id}` {message}");
                            send_power_on_failure_and_close(&mut ws_sender, &message).await;
                            break;
                        }
                    }
                }
                changed = shutdown_rx.changed() => {
                    if changed.is_ok() && *shutdown_rx.borrow() {
                        let _ = ws_sender
                            .send(Message::Text(r#"{"type":"closed"}"#.to_string().into()))
                            .await;
                        break;
                    }
                }
                maybe_message = ws_receiver.next() => {
                    let Some(message) = maybe_message else {
                        break;
                    };
                    match message {
                        Ok(Message::Binary(bytes)) => {
                            write_serial_payload(&mut serial_tx, &bytes).await?;
                        }
                        Ok(Message::Text(text)) => {
                            let control: ClientControlMessage = serde_json::from_str(&text)?;
                            match control.kind.as_str() {
                                "close" => {
                                    let _ = ws_sender
                                        .send(Message::Text(r#"{"type":"closed"}"#.to_string().into()))
                                        .await;
                                    break;
                                }
                                "tx" => {
                                    let Some(data) = control.data.as_deref() else {
                                        anyhow::bail!("missing tx data");
                                    };
                                    let payload = match control.encoding.as_deref() {
                                        Some("base64") => base64::engine::general_purpose::STANDARD
                                            .decode(data)
                                            .context("invalid base64 payload")?,
                                        Some("utf8") | None => data.as_bytes().to_vec(),
                                        Some(other) => anyhow::bail!("unsupported encoding `{other}`"),
                                    };
                                    write_serial_payload(&mut serial_tx, &payload).await?;
                                }
                                other => anyhow::bail!("unsupported websocket control type `{other}`"),
                            }
                        }
                        Ok(Message::Close(_)) => break,
                        Ok(Message::Ping(payload)) => {
                            ws_sender.send(Message::Pong(payload)).await.ok();
                        }
                        Ok(Message::Pong(_)) => {}
                        Err(err) => return Err(err.into()),
                    }
                    let _ = session.heartbeat().await;
                }
                read = serial_rx.read(&mut serial_buffer) => {
                    let read = read.context("serial read failed")?;
                    if read == 0 {
                        break;
                    }
                    ws_sender
                        .send(Message::Binary(serial_buffer[..read].to_vec().into()))
                        .await
                        .context("failed to send serial output over websocket")?;
                    let _ = session.heartbeat().await;
                }
            }
        } else {
            tokio::select! {
                changed = shutdown_rx.changed() => {
                    if changed.is_ok() && *shutdown_rx.borrow() {
                        let _ = ws_sender
                            .send(Message::Text(r#"{"type":"closed"}"#.to_string().into()))
                            .await;
                        break;
                    }
                }
                maybe_message = ws_receiver.next() => {
                    let Some(message) = maybe_message else {
                        break;
                    };
                    match message {
                        Ok(Message::Binary(bytes)) => {
                            write_serial_payload(&mut serial_tx, &bytes).await?;
                        }
                        Ok(Message::Text(text)) => {
                            let control: ClientControlMessage = serde_json::from_str(&text)?;
                            match control.kind.as_str() {
                                "close" => {
                                    let _ = ws_sender
                                        .send(Message::Text(r#"{"type":"closed"}"#.to_string().into()))
                                        .await;
                                    break;
                                }
                                "tx" => {
                                    let Some(data) = control.data.as_deref() else {
                                        anyhow::bail!("missing tx data");
                                    };
                                    let payload = match control.encoding.as_deref() {
                                        Some("base64") => base64::engine::general_purpose::STANDARD
                                            .decode(data)
                                            .context("invalid base64 payload")?,
                                        Some("utf8") | None => data.as_bytes().to_vec(),
                                        Some(other) => anyhow::bail!("unsupported encoding `{other}`"),
                                    };
                                    write_serial_payload(&mut serial_tx, &payload).await?;
                                }
                                other => anyhow::bail!("unsupported websocket control type `{other}`"),
                            }
                        }
                        Ok(Message::Close(_)) => break,
                        Ok(Message::Ping(payload)) => {
                            ws_sender.send(Message::Pong(payload)).await.ok();
                        }
                        Ok(Message::Pong(_)) => {}
                        Err(err) => return Err(err.into()),
                    }
                    let _ = session.heartbeat().await;
                }
                read = serial_rx.read(&mut serial_buffer) => {
                    let read = read.context("serial read failed")?;
                    if read == 0 {
                        break;
                    }
                    ws_sender
                        .send(Message::Binary(serial_buffer[..read].to_vec().into()))
                        .await
                        .context("failed to send serial output over websocket")?;
                    let _ = session.heartbeat().await;
                }
            }
        }
    }

    cleanup_power_link(&board, power_linked, power_on_task).await;
    let _ = ws_sender.send(Message::Close(None)).await;
    Ok(())
}

/// Runs `action` for `board` on a separate tokio task.
///
/// Returns the task's join handle; the handle's output is the result of
/// `execute_power_action_for_board`.
fn spawn_power_action_task(
    board: BoardConfig,
    action: PowerAction,
) -> JoinHandle<Result<String, PowerActionError>> {
    let job = async move { execute_power_action_for_board(&board, action).await };
    tokio::spawn(job)
}

/// Powers the board off at the end of a session when power is linked to it.
///
/// Does nothing when `power_linked` is false. Otherwise any still-pending
/// power-on task is awaited first, so the power-off action is only started
/// after power-on has finished, and then the power-off action is run on its
/// own spawned task. Failures at either step are logged at warn level and
/// never propagated.
async fn cleanup_power_link(
    board: &BoardConfig,
    power_linked: bool,
    power_on_task: Option<JoinHandle<Result<String, PowerActionError>>>,
) {
    if power_linked {
        // Let an in-flight power-on finish before issuing power-off.
        if let Some(pending_power_on) = power_on_task {
            let joined = pending_power_on.await;
            match joined {
                Err(err) => log::warn!("session `{}` power-on task join failed: {err}", board.id),
                Ok(Err(err)) => {
                    log::warn!(
                        "session `{}` power-on task ended with error: {err}",
                        board.id
                    )
                }
                Ok(Ok(_)) => {}
            }
        }

        let board_for_task = board.clone();
        let power_off = tokio::spawn(async move {
            execute_power_action_for_board(&board_for_task, PowerAction::Off).await
        });
        match power_off.await {
            Err(err) => log::warn!(
                "session `{}` automatic power-off task join failed: {err}",
                board.id
            ),
            Ok(Err(err)) => log::warn!("session `{}` automatic power-off failed: {err}", board.id),
            Ok(Ok(_)) => {}
        }
    }
}

/// Tells the client that automatic power-on failed and closes the socket.
///
/// Sends, in order: a JSON text frame of type `error` carrying `message`, a
/// JSON text frame of type `closed`, and a websocket close frame. Each send
/// is best-effort: a failure is ignored and the remaining frames are still
/// attempted.
async fn send_power_on_failure_and_close<S>(ws_sender: &mut S, message: &str)
where
    S: Sink<Message> + Unpin,
{
    let error_frame = serde_json::json!({
        "type": "error",
        "message": message,
    })
    .to_string();

    let frames = [
        Message::Text(error_frame.into()),
        Message::Text(r#"{"type":"closed"}"#.to_string().into()),
        Message::Close(None),
    ];
    for frame in frames {
        let _ = ws_sender.send(frame).await;
    }
}

async fn write_serial_payload(
    port: &mut tokio::io::WriteHalf<tokio_serial::SerialStream>,
    payload: &[u8],
) -> anyhow::Result<()> {
    port.write_all(payload).await?;
    port.flush().await?;
    Ok(())
}

// Unit tests for the control-message parsing, the power cleanup ordering, and
// the frames sent to the client when automatic power-on fails.
#[cfg(test)]
mod tests {
    use std::{
        fs,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    };

    use axum::extract::ws::Message;
    use futures_util::Sink;
    use tempfile::tempdir;

    use super::{ClientControlMessage, cleanup_power_link, send_power_on_failure_and_close};
    use crate::{
        config::{
            BoardConfig, BootConfig, CustomPowerManagement, PowerManagementConfig, PxeProfile,
        },
        power::PowerActionError,
    };

    // In-memory sink that records every message it is sent, standing in for
    // the websocket sender half.
    #[derive(Default)]
    struct VecSink {
        messages: Vec<Message>,
    }

    // Always ready and never fails: each sent item is appended to `messages`.
    impl Sink<Message> for VecSink {
        type Error = ();

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            self.get_mut().messages.push(item);
            Ok(())
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    // The JSON `type` field must land in `kind`; `encoding` and `data` may be absent.
    #[test]
    fn control_message_parses_close_type() {
        let message: ClientControlMessage = serde_json::from_str(r#"{"type":"close"}"#).unwrap();
        assert_eq!(message.kind, "close");
    }

    // Cleanup must await the pending power-on task before running power-off.
    #[tokio::test]
    async fn cleanup_waits_for_power_on_task_before_power_off() {
        let dir = tempdir().unwrap();
        let output_path = dir.path().join("power.log");
        let board = BoardConfig {
            id: "demo".into(),
            board_type: "demo".into(),
            tags: vec![],
            serial: None,
            power_management: PowerManagementConfig::Custom(CustomPowerManagement {
                power_on_cmd: String::new(),
                // Power-off appends to the log, so the final content reveals the ordering.
                power_off_cmd: format!("printf 'off\\n' >> {}", output_path.display()),
            }),
            boot: BootConfig::Pxe(PxeProfile::default()),
            notes: None,
            disabled: false,
        };

        // Stand-in for the power-on task: after a short delay it overwrites the
        // log with "on". Had power-off run first, its "off" line would be lost.
        let power_on_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            fs::write(&output_path, "on\n").unwrap();
            Ok::<String, PowerActionError>("executed".into())
        });

        cleanup_power_link(&board, true, Some(power_on_task)).await;

        let content = fs::read_to_string(dir.path().join("power.log")).unwrap();
        assert_eq!(content, "on\noff\n");
    }

    // A power-on failure sends exactly: error text frame, closed text frame, close frame.
    #[tokio::test]
    async fn power_on_failure_sends_error_then_close_messages() {
        let mut sender = VecSink::default();
        send_power_on_failure_and_close(&mut sender, "automatic power-on failed").await;
        let mut messages = sender.messages.into_iter();
        let first = messages.next().unwrap();
        let second = messages.next().unwrap();
        let third = messages.next().unwrap();

        match first {
            Message::Text(text) => assert!(text.contains(r#""type":"error""#)),
            other => panic!("unexpected first message: {other:?}"),
        }
        match second {
            Message::Text(text) => assert_eq!(text, r#"{"type":"closed"}"#),
            other => panic!("unexpected second message: {other:?}"),
        }
        assert!(matches!(third, Message::Close(_)));
    }
}