ostool 0.15.0

A tool for operating system development
Documentation
use std::sync::{
    Arc,
    atomic::{AtomicBool, Ordering},
};

use anyhow::{Context as _, anyhow};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;

use crate::sterm::{AsyncTerminal, TerminalConfig};

/// JSON control frame carried in a websocket text message.
///
/// The reader loop in `run_serial_terminal` attempts to decode every text
/// frame into this type; frames that fail to decode are treated as plain
/// terminal output instead.
#[derive(Debug, Deserialize)]
pub(crate) struct ServerControlMessage {
    /// Discriminator taken from the JSON `type` field. The values this file
    /// reacts to are `"opened"`, `"closed"` and `"error"`.
    #[serde(rename = "type")]
    pub(crate) kind: String,
    /// Optional detail text. Only read for `"error"` frames, where it is
    /// appended to the reported error.
    pub(crate) message: Option<String>,
}

/// What the websocket reader should do with a decoded control frame.
///
/// Produced by `classify_server_control_message`.
pub(crate) enum ServerControlAction {
    /// Drop the frame and keep reading.
    Ignore,
    /// Stop reading without reporting an error.
    Close,
    /// Stop reading and report the contained error.
    Error(anyhow::Error),
    /// Not a frame with special handling: pass the raw text on to the terminal.
    Forward,
}

pub(crate) fn classify_server_control_message(
    control: &ServerControlMessage,
    locally_closed: bool,
) -> ServerControlAction {
    match control.kind.as_str() {
        "opened" => ServerControlAction::Ignore,
        "closed" => {
            if locally_closed {
                ServerControlAction::Close
            } else {
                ServerControlAction::Error(anyhow!(
                    "ostool-server closed the serial websocket; the board session may have been released"
                ))
            }
        }
        "error" => {
            let message = control
                .message
                .clone()
                .unwrap_or_else(|| "serial websocket error".to_string());
            ServerControlAction::Error(anyhow!("ostool-server serial websocket error: {message}"))
        }
        _ => ServerControlAction::Forward,
    }
}

/// Connects to the serial websocket at `ws_url` and bridges it to an
/// interactive [`AsyncTerminal`] until the terminal returns.
///
/// Two background tasks are spawned:
///
/// * a reader that turns incoming websocket frames into bytes for the
///   terminal and interprets JSON control frames, and
/// * a writer that sends bytes produced by the terminal as binary frames and,
///   once its channel is closed, sends a `{"type":"close"}` text frame followed
///   by a websocket close frame.
///
/// After the terminal returns, the writer is given up to 1 second and then the
/// reader up to 300 milliseconds to finish; a task that does not finish in time
/// is aborted.
///
/// # Errors
///
/// * Returns an error if the websocket connection cannot be established.
/// * If the terminal itself returned an error, that error is returned; any
///   shutdown error from the background tasks is only logged as a warning.
/// * If the terminal returned `Ok`, the first available shutdown error is
///   returned, with the writer's error taking precedence over the reader's.
pub async fn run_serial_terminal(ws_url: reqwest::Url) -> anyhow::Result<()> {
    let (stream, _) = tokio_tungstenite::connect_async(ws_url.as_str())
        .await
        .with_context(|| format!("failed to connect serial websocket {}", ws_url))?;
    let (mut sink, mut stream) = stream.split();
    // Set by the writer just before it sends its own close frames, so the
    // reader can tell a close we initiated from one initiated by the server.
    let locally_closed = Arc::new(AtomicBool::new(false));

    // inbound: websocket -> terminal; outbound: terminal -> websocket.
    let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::<Vec<u8>>();
    let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel::<Vec<u8>>();

    let read_task = tokio::spawn({
        let locally_closed = locally_closed.clone();
        async move {
            while let Some(message) = stream.next().await {
                match message.context("serial websocket read failed")? {
                    Message::Binary(bytes) => {
                        // A failed send means the receiving side is gone; stop reading.
                        if inbound_tx.send(bytes.to_vec()).is_err() {
                            break;
                        }
                    }
                    Message::Text(text) => {
                        // Text that decodes as a control frame may be consumed here;
                        // anything else falls through and is forwarded verbatim.
                        if let Ok(control) = serde_json::from_str::<ServerControlMessage>(&text) {
                            match classify_server_control_message(
                                &control,
                                locally_closed.load(Ordering::SeqCst),
                            ) {
                                ServerControlAction::Ignore => continue,
                                ServerControlAction::Close => break,
                                ServerControlAction::Error(err) => {
                                    // Best effort: also show the error in the terminal
                                    // output before failing the task.
                                    let _ = inbound_tx
                                        .send(format!("\n[ostool-server] {err}\n").into_bytes());
                                    return Err(err);
                                }
                                ServerControlAction::Forward => {}
                            }
                        }
                        if inbound_tx.send(text.bytes().collect()).is_err() {
                            break;
                        }
                    }
                    Message::Close(_) => {
                        // A close frame is only expected after we asked for one.
                        if locally_closed.load(Ordering::SeqCst) {
                            break;
                        }
                        return Err(anyhow!(
                            "ostool-server closed the serial websocket; the board session may have been released"
                        ));
                    }
                    Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {}
                }
            }

            Ok::<(), anyhow::Error>(())
        }
    });

    let write_task = tokio::spawn({
        let locally_closed = locally_closed.clone();
        async move {
            // Runs until every sender of the outbound channel has been dropped.
            // NOTE(review): presumably that happens when `terminal.run` finishes
            // and drops `outbound_tx` — confirm against AsyncTerminal.
            while let Some(bytes) = outbound_rx.recv().await {
                sink.send(Message::Binary(bytes.into()))
                    .await
                    .context("serial websocket write failed")?;
            }

            // Mark the shutdown as ours *before* sending the close frames so the
            // reader treats the server's reply as a normal close.
            locally_closed.store(true, Ordering::SeqCst);
            // Failures while closing are deliberately ignored.
            let _ = sink
                .send(Message::Text(r#"{"type":"close"}"#.to_string().into()))
                .await;
            let _ = sink.send(Message::Close(None)).await;
            Ok::<(), anyhow::Error>(())
        }
    });

    let terminal = AsyncTerminal::new(TerminalConfig {
        intercept_exit_sequence: true,
        timeout: None,
        timeout_label: "remote serial terminal".to_string(),
    });
    // The per-byte callback is a no-op here.
    let run_result = terminal
        .run(inbound_rx, outbound_tx, |_handle, _byte| {})
        .await;

    // Give each task a bounded amount of time to wind down. The handles are
    // polled by `&mut` so they can still be aborted below if the wait times out.
    let mut write_task = write_task;
    let write_result =
        tokio::time::timeout(std::time::Duration::from_secs(1), &mut write_task).await;
    let mut read_task = read_task;
    let read_result =
        tokio::time::timeout(std::time::Duration::from_millis(300), &mut read_task).await;

    // Result layers, outermost first: timeout elapsed / join error / task result.
    let write_error = match write_result {
        Ok(Ok(Ok(()))) => None,
        Ok(Ok(Err(err))) => Some(err),
        Ok(Err(err)) if !err.is_cancelled() => {
            Some(anyhow!("serial websocket writer join error: {err}"))
        }
        // A cancelled task is not reported as an error.
        Ok(Err(_)) => None,
        Err(_) => {
            // Timed out: abort the task and wait for the abort to take effect.
            write_task.abort();
            let _ = write_task.await;
            Some(anyhow!("serial websocket writer shutdown timed out"))
        }
    };
    let read_error = match read_result {
        Ok(Ok(Ok(()))) => None,
        Ok(Ok(Err(err))) => Some(err),
        Ok(Err(err)) if !err.is_cancelled() => {
            Some(anyhow!("serial websocket reader join error: {err}"))
        }
        // A cancelled task is not reported as an error.
        Ok(Err(_)) => None,
        Err(_) => {
            // Timed out: abort the task and wait for the abort to take effect.
            read_task.abort();
            let _ = read_task.await;
            Some(anyhow!("serial websocket reader shutdown timed out"))
        }
    };

    // The writer's error wins when both tasks failed. A shutdown error is only
    // returned when the terminal itself succeeded; otherwise it is logged and
    // the terminal's own error is returned.
    if let Some(err) = write_error.or(read_error) {
        if run_result.is_ok() {
            return Err(err);
        }
        log::warn!("remote serial terminal shutdown failed: {err:#}");
    }

    run_result
}

#[cfg(test)]
mod tests {
    use super::{ServerControlAction, ServerControlMessage, classify_server_control_message};

    /// Decodes a JSON fixture into a control message, panicking if the fixture
    /// itself is malformed.
    fn decode(raw: &str) -> ServerControlMessage {
        serde_json::from_str(raw).unwrap()
    }

    /// The JSON `type` field is mapped onto `kind`.
    #[test]
    fn parse_server_control_message() {
        let control = decode(r#"{"type":"opened"}"#);
        assert_eq!(control.kind, "opened");
    }

    /// An error frame exposes both its kind and its optional message.
    #[test]
    fn parse_server_error_control_message() {
        let control = decode(r#"{"type":"error","message":"power failed"}"#);
        assert_eq!(control.kind, "error");
        assert_eq!(control.message.as_deref(), Some("power failed"));
    }

    /// A `closed` frame that we did not ask for is reported as an error.
    #[test]
    fn closed_control_message_becomes_error_when_not_locally_closed() {
        let control = decode(r#"{"type":"closed"}"#);
        let ServerControlAction::Error(err) = classify_server_control_message(&control, false)
        else {
            panic!("expected error action");
        };
        assert!(err.to_string().contains("may have been released"));
    }

    /// A `closed` frame after a local close is a normal shutdown.
    #[test]
    fn closed_control_message_is_normal_when_locally_closed() {
        let control = decode(r#"{"type":"closed"}"#);
        let action = classify_server_control_message(&control, true);
        assert!(matches!(action, ServerControlAction::Close));
    }
}