jupyter_websocket_client/
websocket.rs1use anyhow::{Context, Result};
2use async_tungstenite::{tokio::ConnectStream, tungstenite::Message, WebSocketStream};
3use futures::{Sink, SinkExt as _, Stream, StreamExt};
4
5use jupyter_protocol::{JupyterConnection, JupyterMessage};
6use std::pin::Pin;
7use std::task::{Context as TaskContext, Poll};
8
9#[derive(Debug)]
10pub struct JupyterWebSocket {
11 pub inner: WebSocketStream<ConnectStream>,
12}
13
14impl Stream for JupyterWebSocket {
15 type Item = Result<JupyterMessage>;
16
17 fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
18 match self.inner.poll_next_unpin(cx) {
19 Poll::Ready(Some(Ok(msg))) => match msg {
20 Message::Text(text) => Poll::Ready(Some(
21 serde_json::from_str(&text)
22 .context("Failed to parse JSON")
23 .and_then(|value| {
24 JupyterMessage::from_value(value)
25 .context("Failed to create JupyterMessage")
26 }),
27 )),
28 _ => Poll::Ready(Some(Err(anyhow::anyhow!("Received non-text message")))),
29 },
30 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
31 Poll::Ready(None) => Poll::Ready(None),
32 Poll::Pending => Poll::Pending,
33 }
34 }
35}
36
37impl Sink<JupyterMessage> for JupyterWebSocket {
38 type Error = anyhow::Error;
39
40 fn poll_ready(
41 mut self: Pin<&mut Self>,
42 cx: &mut TaskContext<'_>,
43 ) -> Poll<Result<(), Self::Error>> {
44 self.inner.poll_ready_unpin(cx).map_err(Into::into)
45 }
46
47 fn start_send(mut self: Pin<&mut Self>, item: JupyterMessage) -> Result<(), Self::Error> {
48 let message_str =
49 serde_json::to_string(&item).context("Failed to serialize JupyterMessage")?;
50 self.inner
51 .start_send_unpin(Message::Text(message_str.into()))
52 .map_err(Into::into)
53 }
54
55 fn poll_flush(
56 mut self: Pin<&mut Self>,
57 cx: &mut TaskContext<'_>,
58 ) -> Poll<Result<(), Self::Error>> {
59 self.inner.poll_flush_unpin(cx).map_err(Into::into)
60 }
61
62 fn poll_close(
63 mut self: Pin<&mut Self>,
64 cx: &mut TaskContext<'_>,
65 ) -> Poll<Result<(), Self::Error>> {
66 self.inner.poll_close_unpin(cx).map_err(Into::into)
67 }
68}
69
70impl JupyterConnection for JupyterWebSocket {}
71
72pub type JupyterWebSocketReader = futures::stream::SplitStream<JupyterWebSocket>;
73pub type JupyterWebSocketWriter = futures::stream::SplitSink<JupyterWebSocket, JupyterMessage>;