1use std::{
2 convert::TryFrom,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::Result;
8use crate::{r2::js_object, Error};
9use futures_util::FutureExt;
10use js_sys::{
11 Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
12 Reflect, Uint8Array,
13};
14use std::convert::TryInto;
15use std::io::Error as IoError;
16use std::io::Result as IoResult;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use wasm_bindgen::{JsCast, JsValue};
19use wasm_bindgen_futures::JsFuture;
20use web_sys::{
21 ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
22};
23
/// Address information extracted from the value a socket's `opened` promise resolves to.
///
/// Built through the `TryFrom<JsValue>` implementation in this file.
#[derive(Debug)]
pub struct SocketInfo {
    /// String value of the `remoteAddress` property, or `None` when that property is not a string.
    pub remote_address: Option<String>,
    /// String value of the `localAddress` property, or `None` when that property is not a string.
    pub local_address: Option<String>,
}
29
30impl TryFrom<JsValue> for SocketInfo {
31 type Error = Error;
32 fn try_from(value: JsValue) -> Result<Self> {
33 let remote_address_value =
34 js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?;
35 let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?;
36 Ok(Self {
37 remote_address: remote_address_value.as_string(),
38 local_address: local_address_value.as_string(),
39 })
40 }
41}
42
/// Progress of the read side of a `Socket` between `poll_read` calls.
#[derive(Debug, Default)]
enum Reading {
    /// No read is in flight and no leftover bytes are buffered.
    #[default]
    None,
    /// A `read()` promise is in flight; the reader is kept so its lock can be released
    /// once the promise settles.
    Pending(JsFuture, ReadableStreamDefaultReader),
    /// Bytes from an earlier chunk that did not fit into the caller's buffer.
    Ready(Vec<u8>),
}
50
/// Progress of the write side of a `Socket` between `poll_write` / `poll_flush` calls.
#[derive(Debug, Default)]
enum Writing {
    /// A write promise is in flight. Holds the writer (so its lock can be released when
    /// the promise settles) and the byte count to report on success.
    Pending(JsFuture, WritableStreamDefaultWriter, usize),
    /// No write is in flight.
    #[default]
    None,
}
57
/// Progress of `poll_shutdown` between calls.
#[derive(Debug, Default)]
enum Closing {
    /// The promise returned by closing the writable stream has not settled yet.
    Pending(JsFuture),
    /// No close is in flight.
    #[default]
    None,
}
64
/// A socket backed by a `worker_sys::Socket`, exposed through tokio's `AsyncRead` and
/// `AsyncWrite` traits.
#[derive(Debug)]
pub struct Socket {
    /// The underlying runtime socket handle.
    inner: worker_sys::Socket,
    /// Writable stream taken from `inner` at construction time.
    writable: WritableStream,
    /// Readable stream taken from `inner` at construction time.
    readable: ReadableStream,
    /// In-flight write state; `None` is treated the same as `Writing::None`.
    write: Option<Writing>,
    /// In-flight read state or buffered leftover bytes; `None` is treated as `Reading::None`.
    read: Option<Reading>,
    /// In-flight close state; `None` is treated the same as `Closing::None`.
    close: Option<Closing>,
}
75
// SAFETY (assumption — TODO confirm): the wrapped JS handles are not thread-safe on their
// own. These impls presumably rely on the target runtime executing the wasm module on a
// single thread, so a `Socket` is never actually accessed from two threads.
unsafe impl Send for Socket {}
unsafe impl Sync for Socket {}
79
80impl Socket {
81 fn new(inner: worker_sys::Socket) -> Self {
82 let writable = inner.writable().unwrap();
83 let readable = inner.readable().unwrap();
84 Socket {
85 inner,
86 writable,
87 readable,
88 read: None,
89 write: None,
90 close: None,
91 }
92 }
93
94 pub(crate) fn from_inner(inner: worker_sys::Socket) -> Self {
95 Self::new(inner)
96 }
97
98 pub async fn close(&mut self) -> Result<()> {
100 JsFuture::from(self.inner.close()?).await?;
101 Ok(())
102 }
103
104 pub async fn closed(&self) -> Result<()> {
107 JsFuture::from(self.inner.closed()?).await?;
108 Ok(())
109 }
110
111 pub async fn opened(&self) -> Result<SocketInfo> {
112 let value = JsFuture::from(self.inner.opened()?).await?;
113 value.try_into()
114 }
115
116 pub fn start_tls(self) -> Socket {
122 let inner = self.inner.start_tls().unwrap();
123 Socket::new(inner)
124 }
125
126 pub fn builder() -> ConnectionBuilder {
127 ConnectionBuilder::default()
128 }
129
130 fn handle_write_future(
131 cx: &mut Context<'_>,
132 mut fut: JsFuture,
133 writer: WritableStreamDefaultWriter,
134 len: usize,
135 ) -> (Writing, Poll<IoResult<usize>>) {
136 match fut.poll_unpin(cx) {
137 Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
138 Poll::Ready(res) => {
139 writer.release_lock();
140 match res {
141 Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
142 Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
143 }
144 }
145 }
146 }
147}
148
149fn js_value_to_std_io_error(value: JsValue) -> IoError {
150 let s = if value.is_string() {
151 value.as_string().unwrap()
152 } else if let Some(value) = value.dyn_ref::<JsError>() {
153 value.to_string().into()
154 } else {
155 format!("Error interpreting JsError: {value:?}")
156 };
157 IoError::other(s)
158}
159impl AsyncRead for Socket {
160 fn poll_read(
161 mut self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 buf: &mut ReadBuf<'_>,
164 ) -> Poll<IoResult<()>> {
165 fn handle_future(
166 cx: &mut Context<'_>,
167 buf: &mut ReadBuf<'_>,
168 mut fut: JsFuture,
169 reader: ReadableStreamDefaultReader,
170 ) -> (Reading, Poll<IoResult<()>>) {
171 match fut.poll_unpin(cx) {
172 Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
173 Poll::Ready(res) => match res {
174 Ok(value) => {
175 reader.release_lock();
176 let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
177 Ok(value) => value.into(),
178 Err(error) => {
179 let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {error:?}");
180 return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
181 }
182 };
183 if done.is_truthy() {
184 (Reading::None, Poll::Ready(Ok(())))
185 } else {
186 let arr: Uint8Array = match Reflect::get(
187 &value,
188 &JsValue::from("value"),
189 ) {
190 Ok(value) => value.into(),
191 Err(error) => {
192 let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {error:?}");
193 return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
194 }
195 };
196 let data = arr.to_vec();
197 handle_data(buf, data)
198 }
199 }
200 Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
201 },
202 }
203 }
204
205 let (new_reading, poll) = match self.read.take().unwrap_or_default() {
206 Reading::None => {
207 let reader: ReadableStreamDefaultReader =
208 match self.readable.get_reader().dyn_into() {
209 Ok(reader) => reader,
210 Err(error) => {
211 let msg = format!(
212 "Unable to cast JsObject to ReadableStreamDefaultReader: {error:?}"
213 );
214 return Poll::Ready(Err(IoError::other(msg)));
215 }
216 };
217
218 handle_future(cx, buf, JsFuture::from(reader.read()), reader)
219 }
220 Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
221 Reading::Ready(data) => handle_data(buf, data),
222 };
223 self.read = Some(new_reading);
224 poll
225 }
226}
227
228impl AsyncWrite for Socket {
229 fn poll_write(
230 mut self: Pin<&mut Self>,
231 cx: &mut Context<'_>,
232 buf: &[u8],
233 ) -> Poll<IoResult<usize>> {
234 let (new_writing, poll) = match self.write.take().unwrap_or_default() {
235 Writing::None => {
236 let obj = JsValue::from(Uint8Array::from(buf));
237 let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
238 Ok(writer) => writer,
239 Err(error) => {
240 let msg = format!("Could not retrieve Writer: {error:?}");
241 return Poll::Ready(Err(IoError::other(msg)));
242 }
243 };
244 Self::handle_write_future(
245 cx,
246 JsFuture::from(writer.write_with_chunk(&obj)),
247 writer,
248 buf.len(),
249 )
250 }
251 Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
252 };
253 self.write = Some(new_writing);
254 poll
255 }
256
257 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
258 let (new_writing, poll) = match self.write.take().unwrap_or_default() {
260 Writing::Pending(fut, writer, len) => {
261 let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
262 (writing, poll.map(|res| res.map(|_| ())))
264 }
265 writing => (writing, Poll::Ready(Ok(()))),
266 };
267 self.write = Some(new_writing);
268 poll
269 }
270
271 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
272 fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
273 match fut.poll_unpin(cx) {
274 Poll::Pending => (Closing::Pending(fut), Poll::Pending),
275 Poll::Ready(res) => match res {
276 Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
277 Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
278 },
279 }
280 }
281 let (new_closing, poll) = match self.close.take().unwrap_or_default() {
282 Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
283 Closing::Pending(fut) => handle_future(cx, fut),
284 };
285 self.close = Some(new_closing);
286 poll
287 }
288}
289
/// TLS mode requested for an outbound socket.
///
/// Each variant is sent to the runtime as the `secureTransport` option string produced by
/// `secure_transport_label`.
///
/// The enum is field-less, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecureTransport {
    /// Sent as `"off"`.
    Off,
    /// Sent as `"on"`.
    On,
    /// Sent as `"starttls"`; presumably the connection starts unencrypted and is upgraded
    /// later via `Socket::start_tls` — confirm against the runtime documentation.
    StartTls,
}
301
/// Options passed to the runtime when opening a socket.
///
/// Serialized by `socket_options_to_js_value` into an object with the keys
/// `secureTransport` and `allowHalfOpen`.
#[derive(Debug, Clone)]
pub struct SocketOptions {
    /// TLS mode; sent as the `secureTransport` option. Defaults to `SecureTransport::Off`.
    pub secure_transport: SecureTransport,
    /// Sent as the `allowHalfOpen` option. Defaults to `false`.
    // NOTE(review): presumably controls whether the writable side stays open after the
    // readable side reaches end-of-stream — confirm against the runtime documentation.
    pub allow_half_open: bool,
}
313
314impl Default for SocketOptions {
315 fn default() -> Self {
316 SocketOptions {
317 secure_transport: SecureTransport::Off,
318 allow_half_open: false,
319 }
320 }
321}
322
/// A host name / port pair describing a socket endpoint.
///
/// Not referenced elsewhere in the visible part of this file.
#[derive(Debug, Clone)]
pub struct SocketAddress {
    /// Host name of the endpoint.
    pub hostname: String,
    /// Port number of the endpoint.
    pub port: u16,
}
331
/// Builder that collects [`SocketOptions`] and then opens a [`Socket`] via `connect`.
#[derive(Default, Debug, Clone)]
pub struct ConnectionBuilder {
    /// Options forwarded to the runtime when `connect` is called.
    options: SocketOptions,
}
336
337impl ConnectionBuilder {
338 pub fn new() -> Self {
340 ConnectionBuilder {
341 options: SocketOptions::default(),
342 }
343 }
344
345 pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
348 self.options.allow_half_open = allow_half_open;
349 self
350 }
351
352 pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
354 self.options.secure_transport = secure_transport;
355 self
356 }
357
358 pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
360 let address: JsValue = js_object!(
361 "hostname" => JsObject::from(JsString::from(hostname.into())),
362 "port" => JsNumber::from(port)
363 )
364 .into();
365
366 let options = socket_options_to_js_value(&self.options);
367
368 let inner = worker_sys::connect(address, options)?;
369 Ok(Socket::new(inner))
370 }
371}
372
373pub(crate) fn secure_transport_label(secure_transport: &SecureTransport) -> &'static str {
374 match secure_transport {
375 SecureTransport::On => "on",
376 SecureTransport::Off => "off",
377 SecureTransport::StartTls => "starttls",
378 }
379}
380
381pub(crate) fn socket_options_to_js_value(options: &SocketOptions) -> JsValue {
382 js_object!(
383 "allowHalfOpen" => JsBoolean::from(options.allow_half_open),
384 "secureTransport" => JsString::from(secure_transport_label(&options.secure_transport))
385 )
386 .into()
387}
388
389fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
391 let idx = buf.remaining().min(data.len());
392 let store = data.split_off(idx);
393 buf.put_slice(&data);
394 if store.is_empty() {
395 (Reading::None, Poll::Ready(Ok(())))
396 } else {
397 (Reading::Ready(store), Poll::Ready(Ok(())))
398 }
399}
400
/// `tokio_postgres` TLS glue for [`Socket`], compiled only with the `tokio-postgres`
/// feature.
#[cfg(feature = "tokio-postgres")]
pub mod postgres_tls {
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// `TlsConnect` implementation that upgrades the stream by calling
    /// `Socket::start_tls` on it.
    #[derive(Debug, Clone, Default)]
    pub struct PassthroughTls;

    /// Error type required by the `TlsConnect` implementation; `connect` below never
    /// constructs it.
    #[derive(Debug)]
    pub struct PassthroughTlsError;

    impl Error for PassthroughTlsError {}

    impl Display for PassthroughTlsError {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("PassthroughTlsError")
        }
    }

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` via `Socket::start_tls` and returns an already-completed future.
        ///
        /// A failing upgrade panics inside `start_tls` instead of surfacing as `Err`.
        fn connect(self, s: Self::Stream) -> Self::Future {
            ready(Ok(s.start_tls()))
        }
    }

    impl TlsStream for Socket {
        /// Always reports that no channel binding is available.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `incoming` bytes into a fresh buffer of `capacity` bytes and reports the
    /// resulting read state plus the buffer's remaining and filled sizes.
    fn feed(capacity: usize, incoming: usize) -> (Reading, usize, usize) {
        let mut backing = vec![0u8; capacity];
        let mut buf = ReadBuf::new(&mut backing);
        let (state, _) = handle_data(&mut buf, vec![1u8; incoming]);
        (state, buf.remaining(), buf.filled().len())
    }

    #[test]
    fn secure_transport_labels_match_runtime_strings() {
        let expected = [
            (SecureTransport::On, "on"),
            (SecureTransport::Off, "off"),
            (SecureTransport::StartTls, "starttls"),
        ];
        for (mode, label) in &expected {
            assert_eq!(secure_transport_label(mode), *label);
        }
    }

    // Data exactly fills the buffer: nothing is left over.
    #[test]
    fn test_handle_data() {
        let (state, remaining, filled) = feed(32, 32);

        assert!(matches!(state, Reading::None));
        assert_eq!(remaining, 0);
        assert_eq!(filled, 32);
    }

    // Data exceeds the buffer: the overflow is kept for the next read.
    #[test]
    fn test_handle_large_data() {
        let (state, remaining, filled) = feed(32, 64);

        assert!(matches!(state, Reading::Ready(ref rest) if rest.len() == 32));
        assert_eq!(remaining, 0);
        assert_eq!(filled, 32);
    }

    // Data is smaller than the buffer: everything is copied and space remains.
    #[test]
    fn test_handle_small_data() {
        let (state, remaining, filled) = feed(32, 16);

        assert!(matches!(state, Reading::None));
        assert_eq!(remaining, 16);
        assert_eq!(filled, 16);
    }
}