hypermangle_py/
lib.rs

1#![feature(exclusive_wrapper)]
2
3use std::mem::replace;
4use std::ops::Deref;
5use std::ops::DerefMut;
6use std::sync::Arc;
7
8use axum::extract::ws::Message;
9use axum::extract::WebSocketUpgrade;
10use axum::response::Response;
11use pyo3::create_exception;
12use pyo3::exceptions::PyValueError;
13use pyo3::prelude::*;
14use tokio::sync::Mutex;
15
16create_exception!(hypermangle_py, ClosedWebSocket, pyo3::exceptions::PyException);
17create_exception!(hypermangle_py, WebSocketError, pyo3::exceptions::PyException);
18create_exception!(hypermangle_py, NotYetAccepted, pyo3::exceptions::PyException);
19create_exception!(hypermangle_py, AlreadyAccepted, pyo3::exceptions::PyException);
20
21enum WebSocketInner {
22    Pending((WebSocketUpgrade, tokio::sync::oneshot::Sender<Response>)),
23    Accepting,
24    Accepted(axum::extract::ws::WebSocket),
25}
26
27#[pyclass(frozen)]
28pub struct WebSocket {
29    inner: Arc<Mutex<WebSocketInner>>,
30}
31
32#[pyclass(frozen)]
33struct WebSocketMessage {
34    msg: Message,
35}
36
37#[pymethods]
38impl WebSocketMessage {
39    fn as_string(&self) -> Option<&str> {
40        match &self.msg {
41            Message::Text(msg) => Some(msg),
42            _ => None,
43        }
44    }
45
46    fn as_bytes(&self) -> Option<&[u8]> {
47        match &self.msg {
48            Message::Binary(msg) => Some(msg),
49            _ => None,
50        }
51    }
52}
53
54#[pymethods]
55impl WebSocket {
56    fn accept(&self) -> PyResult<()> {
57        let mut lock = self.inner.clone().blocking_lock_owned();
58
59        if let WebSocketInner::Pending(_) = lock.deref() {
60            // Should be in this state
61        } else {
62            return Err(AlreadyAccepted::new_err(()));
63        }
64
65        let WebSocketInner::Pending((ws, sender)) =
66            replace(lock.deref_mut(), WebSocketInner::Accepting)
67        else {
68            unreachable!()
69        };
70
71        sender
72            .send(ws.on_upgrade(move |ws| async move {
73                *lock = WebSocketInner::Accepted(ws);
74            }))
75            .expect("WebSocket Response Receiver should not have been dropped yet");
76
77        Ok(())
78    }
79
80    fn recv_msg<'a>(&self, py: Python<'a>) -> PyResult<&'a PyAny> {
81        let inner = self.inner.clone();
82
83        pyo3_asyncio::tokio::future_into_py(py, async move {
84            let mut lock = inner.lock().await;
85            let WebSocketInner::Accepted(ws) = lock.deref_mut() else {
86                return Err(NotYetAccepted::new_err(()));
87            };
88            let Some(result) = ws.recv().await else {
89                return Err(ClosedWebSocket::new_err(()));
90            };
91
92            match result {
93                Ok(msg) => Ok(WebSocketMessage { msg }),
94                Err(e) => Err(WebSocketError::new_err(e.to_string())),
95            }
96        })
97    }
98
99    fn send_msg<'a>(&self, py: Python<'a>, msg: &'a PyAny) -> PyResult<&'a PyAny> {
100        let msg = if let Ok(msg) = msg.extract::<String>() {
101            Message::Text(msg)
102        } else if let Ok(msg) = msg.extract::<Vec<u8>>() {
103            Message::Binary(msg)
104        } else {
105            return Err(PyValueError::new_err(
106                "WebSockets can only send Strings or Bytes",
107            ));
108        };
109        let inner = self.inner.clone();
110        pyo3_asyncio::tokio::future_into_py(py, async move {
111            let mut lock = inner.lock().await;
112            let WebSocketInner::Accepted(ws) = lock.deref_mut() else {
113                return Err(NotYetAccepted::new_err(()));
114            };
115            ws.send(msg)
116                .await
117                .map_err(|e| WebSocketError::new_err(e.to_string()))
118        })
119    }
120}
121
122impl WebSocket {
123    pub fn new(ws: WebSocketUpgrade) -> (Self, tokio::sync::oneshot::Receiver<Response>) {
124        let (sender, receiver) = tokio::sync::oneshot::channel();
125        (
126            Self {
127                inner: Arc::new(Mutex::new(WebSocketInner::Pending((ws, sender)))),
128            },
129            receiver,
130        )
131    }
132}
133
134#[pymodule]
135fn hypermangle_py(py: Python<'_>, m: &PyModule) -> PyResult<()> {
136    m.add("ClosedWebSocket", py.get_type::<ClosedWebSocket>())?;
137    m.add("WebSocketError", py.get_type::<WebSocketError>())?;
138    m.add("NotYetAccepted", py.get_type::<NotYetAccepted>())?;
139    m.add("AlreadyAccepted", py.get_type::<AlreadyAccepted>())?;
140    m.add_class::<WebSocket>()?;
141    m.add_class::<WebSocketMessage>()?;
142    Ok(())
143}