jupyter_websocket_client/
websocket.rs

1use anyhow::{Context, Result};
2use async_tungstenite::{tokio::ConnectStream, tungstenite::Message, WebSocketStream};
3use futures::{Sink, SinkExt as _, Stream, StreamExt};
4
5use jupyter_protocol::{JupyterConnection, JupyterMessage};
6use std::pin::Pin;
7use std::task::{Context as TaskContext, Poll};
8
9#[derive(Debug)]
10pub struct JupyterWebSocket {
11    pub inner: WebSocketStream<ConnectStream>,
12}
13
14impl Stream for JupyterWebSocket {
15    type Item = Result<JupyterMessage>;
16
17    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
18        match self.inner.poll_next_unpin(cx) {
19            Poll::Ready(Some(Ok(msg))) => match msg {
20                Message::Text(text) => Poll::Ready(Some(
21                    serde_json::from_str(&text)
22                        .context("Failed to parse JSON")
23                        .and_then(|value| {
24                            JupyterMessage::from_value(value)
25                                .context("Failed to create JupyterMessage")
26                        }),
27                )),
28                _ => Poll::Ready(Some(Err(anyhow::anyhow!("Received non-text message")))),
29            },
30            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
31            Poll::Ready(None) => Poll::Ready(None),
32            Poll::Pending => Poll::Pending,
33        }
34    }
35}
36
37impl Sink<JupyterMessage> for JupyterWebSocket {
38    type Error = anyhow::Error;
39
40    fn poll_ready(
41        mut self: Pin<&mut Self>,
42        cx: &mut TaskContext<'_>,
43    ) -> Poll<Result<(), Self::Error>> {
44        self.inner.poll_ready_unpin(cx).map_err(Into::into)
45    }
46
47    fn start_send(mut self: Pin<&mut Self>, item: JupyterMessage) -> Result<(), Self::Error> {
48        let message_str =
49            serde_json::to_string(&item).context("Failed to serialize JupyterMessage")?;
50        self.inner
51            .start_send_unpin(Message::Text(message_str.into()))
52            .map_err(Into::into)
53    }
54
55    fn poll_flush(
56        mut self: Pin<&mut Self>,
57        cx: &mut TaskContext<'_>,
58    ) -> Poll<Result<(), Self::Error>> {
59        self.inner.poll_flush_unpin(cx).map_err(Into::into)
60    }
61
62    fn poll_close(
63        mut self: Pin<&mut Self>,
64        cx: &mut TaskContext<'_>,
65    ) -> Poll<Result<(), Self::Error>> {
66        self.inner.poll_close_unpin(cx).map_err(Into::into)
67    }
68}
69
70impl JupyterConnection for JupyterWebSocket {}
71
72pub type JupyterWebSocketReader = futures::stream::SplitStream<JupyterWebSocket>;
73pub type JupyterWebSocketWriter = futures::stream::SplitSink<JupyterWebSocket, JupyterMessage>;