1#![feature(exclusive_wrapper)]
2
3use std::mem::replace;
4use std::ops::Deref;
5use std::ops::DerefMut;
6use std::sync::Arc;
7
8use axum::extract::ws::Message;
9use axum::extract::WebSocketUpgrade;
10use axum::response::Response;
11use pyo3::create_exception;
12use pyo3::exceptions::PyValueError;
13use pyo3::prelude::*;
14use tokio::sync::Mutex;
15
// Python exception types raised by this module; all four subclass Python's built-in
// `Exception` (via `PyException`) and are registered on the module in `hypermangle_py`.

// Raised when the socket has no further messages to deliver (`recv_msg` reached end of stream).
create_exception!(hypermangle_py, ClosedWebSocket, pyo3::exceptions::PyException);
// Wraps an error returned while receiving or sending; the exception message is the error's text.
create_exception!(hypermangle_py, WebSocketError, pyo3::exceptions::PyException);
// Raised by `recv_msg`/`send_msg` when the socket is not in the accepted state.
create_exception!(hypermangle_py, NotYetAccepted, pyo3::exceptions::PyException);
// Raised by `accept` when the socket is no longer waiting to be accepted.
create_exception!(hypermangle_py, AlreadyAccepted, pyo3::exceptions::PyException);
20
/// Lifecycle state of a [`WebSocket`], shared behind its mutex.
enum WebSocketInner {
    /// Not yet accepted: the axum upgrade extractor plus the sending half of the channel
    /// whose receiver is returned by `WebSocket::new`. The upgrade `Response` goes through it.
    Pending((WebSocketUpgrade, tokio::sync::oneshot::Sender<Response>)),
    /// `accept` has taken the pending upgrade, but the upgrade callback has not stored the
    /// connected socket. Nothing moves the state past this point if that callback never runs.
    Accepting,
    /// The upgrade completed; this is the live socket used by `recv_msg` and `send_msg`.
    Accepted(axum::extract::ws::WebSocket),
}
26
/// Handle to one WebSocket connection: accept the upgrade, then receive and send messages.
#[pyclass(frozen)]
pub struct WebSocket {
    // Shared lifecycle state. This is a tokio (async) mutex because `recv_msg`/`send_msg`
    // hold the guard across `.await`. The `Arc` lets those futures, and the upgrade
    // callback created in `accept`, own the state independently of the Python object.
    inner: Arc<Mutex<WebSocketInner>>,
}
31
/// A message received by `WebSocket.recv_msg()`.
///
/// Read the payload with `as_string()` for text messages or `as_bytes()` for binary ones.
#[pyclass(frozen)]
struct WebSocketMessage {
    // The message exactly as returned by the axum socket.
    msg: Message,
}
36
37#[pymethods]
38impl WebSocketMessage {
39 fn as_string(&self) -> Option<&str> {
40 match &self.msg {
41 Message::Text(msg) => Some(msg),
42 _ => None,
43 }
44 }
45
46 fn as_bytes(&self) -> Option<&[u8]> {
47 match &self.msg {
48 Message::Binary(msg) => Some(msg),
49 _ => None,
50 }
51 }
52}
53
54#[pymethods]
55impl WebSocket {
56 fn accept(&self) -> PyResult<()> {
57 let mut lock = self.inner.clone().blocking_lock_owned();
58
59 if let WebSocketInner::Pending(_) = lock.deref() {
60 } else {
62 return Err(AlreadyAccepted::new_err(()));
63 }
64
65 let WebSocketInner::Pending((ws, sender)) =
66 replace(lock.deref_mut(), WebSocketInner::Accepting)
67 else {
68 unreachable!()
69 };
70
71 sender
72 .send(ws.on_upgrade(move |ws| async move {
73 *lock = WebSocketInner::Accepted(ws);
74 }))
75 .expect("WebSocket Response Receiver should not have been dropped yet");
76
77 Ok(())
78 }
79
80 fn recv_msg<'a>(&self, py: Python<'a>) -> PyResult<&'a PyAny> {
81 let inner = self.inner.clone();
82
83 pyo3_asyncio::tokio::future_into_py(py, async move {
84 let mut lock = inner.lock().await;
85 let WebSocketInner::Accepted(ws) = lock.deref_mut() else {
86 return Err(NotYetAccepted::new_err(()));
87 };
88 let Some(result) = ws.recv().await else {
89 return Err(ClosedWebSocket::new_err(()));
90 };
91
92 match result {
93 Ok(msg) => Ok(WebSocketMessage { msg }),
94 Err(e) => Err(WebSocketError::new_err(e.to_string())),
95 }
96 })
97 }
98
99 fn send_msg<'a>(&self, py: Python<'a>, msg: &'a PyAny) -> PyResult<&'a PyAny> {
100 let msg = if let Ok(msg) = msg.extract::<String>() {
101 Message::Text(msg)
102 } else if let Ok(msg) = msg.extract::<Vec<u8>>() {
103 Message::Binary(msg)
104 } else {
105 return Err(PyValueError::new_err(
106 "WebSockets can only send Strings or Bytes",
107 ));
108 };
109 let inner = self.inner.clone();
110 pyo3_asyncio::tokio::future_into_py(py, async move {
111 let mut lock = inner.lock().await;
112 let WebSocketInner::Accepted(ws) = lock.deref_mut() else {
113 return Err(NotYetAccepted::new_err(()));
114 };
115 ws.send(msg)
116 .await
117 .map_err(|e| WebSocketError::new_err(e.to_string()))
118 })
119 }
120}
121
122impl WebSocket {
123 pub fn new(ws: WebSocketUpgrade) -> (Self, tokio::sync::oneshot::Receiver<Response>) {
124 let (sender, receiver) = tokio::sync::oneshot::channel();
125 (
126 Self {
127 inner: Arc::new(Mutex::new(WebSocketInner::Pending((ws, sender)))),
128 },
129 receiver,
130 )
131 }
132}
133
134#[pymodule]
135fn hypermangle_py(py: Python<'_>, m: &PyModule) -> PyResult<()> {
136 m.add("ClosedWebSocket", py.get_type::<ClosedWebSocket>())?;
137 m.add("WebSocketError", py.get_type::<WebSocketError>())?;
138 m.add("NotYetAccepted", py.get_type::<NotYetAccepted>())?;
139 m.add("AlreadyAccepted", py.get_type::<AlreadyAccepted>())?;
140 m.add_class::<WebSocket>()?;
141 m.add_class::<WebSocketMessage>()?;
142 Ok(())
143}