ostool 0.15.0

A tool for operating system development
Documentation
use std::sync::{
    Arc,
    atomic::{AtomicBool, Ordering},
};

use anyhow::Context as _;
use futures::{SinkExt, StreamExt};
use tokio::{
    io::{AsyncReadExt, split},
    task::JoinHandle,
    time::timeout,
};
use tokio_tungstenite::tungstenite::Message;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

use crate::board::terminal::{
    ServerControlAction, ServerControlMessage, classify_server_control_message,
};

/// Heap-allocated, type-erased [`futures::AsyncRead`] that is `Send` and `Unpin`.
pub type BoxedAsyncRead = Box<dyn futures::AsyncRead + Send + Unpin>;
/// Heap-allocated, type-erased [`futures::AsyncWrite`] that is `Send` and `Unpin`.
pub type BoxedAsyncWrite = Box<dyn futures::AsyncWrite + Send + Unpin>;

/// Join handles of the two background tasks that back a serial websocket stream.
///
/// Both tasks resolve to `anyhow::Result<()>`; use the `shutdown*` methods to
/// wait for them and collect their outcome.
pub struct SerialStreamTasks {
    /// Task that forwards incoming websocket messages to the local read side.
    read_task: JoinHandle<anyhow::Result<()>>,
    /// Task that forwards locally written bytes to the websocket.
    write_task: JoinHandle<anyhow::Result<()>>,
}

pub async fn connect_serial_stream(
    ws_url: reqwest::Url,
) -> anyhow::Result<(BoxedAsyncWrite, BoxedAsyncRead, SerialStreamTasks)> {
    let (stream, _) = tokio_tungstenite::connect_async(ws_url.as_str())
        .await
        .with_context(|| format!("failed to connect serial websocket {}", ws_url))?;
    let (mut ws_sink, mut ws_stream) = stream.split();
    let locally_closed = Arc::new(AtomicBool::new(false));

    let (runner_stream, bridge_stream) = tokio::io::duplex(64 * 1024);
    let (runner_rx, runner_tx) = split(runner_stream);
    let (mut bridge_rx, mut bridge_tx) = split(bridge_stream);

    let read_task = tokio::spawn({
        let locally_closed = locally_closed.clone();
        async move {
            while let Some(message) = ws_stream.next().await {
                match message.context("serial websocket read failed")? {
                    Message::Binary(bytes) => {
                        tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, &bytes)
                            .await
                            .context("failed to write serial websocket bytes")?;
                        tokio::io::AsyncWriteExt::flush(&mut bridge_tx)
                            .await
                            .context("failed to flush serial websocket bytes")?;
                    }
                    Message::Text(text) => {
                        if let Ok(control) = serde_json::from_str::<ServerControlMessage>(&text) {
                            match classify_server_control_message(
                                &control,
                                locally_closed.load(Ordering::SeqCst),
                            ) {
                                ServerControlAction::Ignore => continue,
                                ServerControlAction::Close => break,
                                ServerControlAction::Error(err) => return Err(err),
                                ServerControlAction::Forward => {}
                            }
                        }

                        tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, text.as_bytes())
                            .await
                            .context("failed to write text serial websocket payload")?;
                        tokio::io::AsyncWriteExt::flush(&mut bridge_tx)
                            .await
                            .context("failed to flush text serial websocket payload")?;
                    }
                    Message::Close(_) => {
                        if locally_closed.load(Ordering::SeqCst) {
                            break;
                        }
                        anyhow::bail!(
                            "ostool-server closed the serial websocket; the board session may have been released"
                        );
                    }
                    Message::Ping(_) => {}
                    Message::Pong(_) | Message::Frame(_) => {}
                }
            }

            Ok(())
        }
    });

    let write_task = tokio::spawn({
        let locally_closed = locally_closed.clone();
        async move {
            let mut buffer = [0u8; 4096];
            loop {
                let read = bridge_rx
                    .read(&mut buffer)
                    .await
                    .context("failed to read runner serial bytes")?;
                if read == 0 {
                    break;
                }
                ws_sink
                    .send(Message::Binary(buffer[..read].to_vec().into()))
                    .await
                    .context("serial websocket write failed")?;
            }

            locally_closed.store(true, Ordering::SeqCst);
            let _ = ws_sink
                .send(Message::Text(r#"{"type":"close"}"#.to_string().into()))
                .await;
            let _ = ws_sink.send(Message::Close(None)).await;
            Ok(())
        }
    });

    Ok((
        Box::new(runner_tx.compat_write()),
        Box::new(runner_rx.compat()),
        SerialStreamTasks {
            read_task,
            write_task,
        },
    ))
}

impl SerialStreamTasks {
    pub async fn shutdown(self) -> anyhow::Result<()> {
        let write_result = self.write_task.await;
        let read_result = self.read_task.await;

        if let Ok(Err(err)) = write_result {
            return Err(err);
        }
        if let Err(err) = write_result
            && !err.is_cancelled()
        {
            return Err(anyhow::anyhow!("serial websocket writer join error: {err}"));
        }
        if let Ok(Err(err)) = read_result {
            return Err(err);
        }
        if let Err(err) = read_result
            && !err.is_cancelled()
        {
            return Err(anyhow::anyhow!("serial websocket reader join error: {err}"));
        }

        Ok(())
    }

    pub async fn shutdown_with_timeout(self, duration: std::time::Duration) -> anyhow::Result<()> {
        let SerialStreamTasks {
            mut read_task,
            mut write_task,
        } = self;
        let shutdown = async {
            let write_result = (&mut write_task).await;
            let read_result = (&mut read_task).await;

            if let Ok(Err(err)) = write_result {
                return Err(err);
            }
            if let Err(err) = write_result
                && !err.is_cancelled()
            {
                return Err(anyhow::anyhow!("serial websocket writer join error: {err}"));
            }
            if let Ok(Err(err)) = read_result {
                return Err(err);
            }
            if let Err(err) = read_result
                && !err.is_cancelled()
            {
                return Err(anyhow::anyhow!("serial websocket reader join error: {err}"));
            }

            Ok(())
        };

        match timeout(duration, shutdown).await {
            Ok(result) => result,
            Err(_) => {
                write_task.abort();
                read_task.abort();
                let _ = write_task.await;
                let _ = read_task.await;
                Err(anyhow::anyhow!(
                    "serial websocket shutdown timed out after {}s",
                    duration.as_secs_f64()
                ))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    use tokio::{sync::Notify, task::JoinHandle};

    use super::SerialStreamTasks;

    /// The reader only terminates once the writer has flagged completion and woken it,
    /// so `shutdown` must be able to join the writer while the reader is still parked.
    #[tokio::test]
    async fn shutdown_waits_for_writer_before_reader() {
        let wake_reader = Arc::new(Notify::new());
        let writer_done = Arc::new(AtomicBool::new(false));

        // Reader: park on the notifier until the writer's flag is observed.
        let (reader_wake, reader_flag) = (Arc::clone(&wake_reader), Arc::clone(&writer_done));
        let read_task: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
            loop {
                if reader_flag.load(Ordering::SeqCst) {
                    return Ok(());
                }
                reader_wake.notified().await;
            }
        });

        // Writer: raise the flag, then wake whoever is currently waiting.
        let (writer_wake, writer_flag) = (Arc::clone(&wake_reader), Arc::clone(&writer_done));
        let write_task: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
            writer_flag.store(true, Ordering::SeqCst);
            writer_wake.notify_waiters();
            Ok(())
        });

        let tasks = SerialStreamTasks {
            read_task,
            write_task,
        };
        tasks.shutdown().await.unwrap();
    }
}