use std::{
pin::Pin,
sync::{Arc, Mutex},
task::{ready, Poll},
};
use base64::Engine as _;
use bytes::BytesMut;
use pin_project::pin_project;
use rhai::{Dynamic, Engine, NativeCallContext};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::{debug, debug_span, trace, warn};
use crate::scenario_executor::{
logoverlay::render_content,
types::{Handle, StreamRead},
utils1::{ExtractHandleOrFail, RhResult},
};
use super::{
scenario::ScenarioAccess,
types::{
BufferFlag, BufferFlags, DatagramRead, DatagramSocket, DatagramWrite, PacketRead,
PacketReadResult, PacketWrite, StreamSocket, StreamWrite,
},
utils1::{HandleExt, SimpleErr},
};
/// Reader adapter that caps how many bytes a single `poll_read` call may
/// request from the wrapped reader.
#[pin_project]
struct ReadChunkLimiter {
    /// Reader that actually supplies the data.
    #[pin]
    inner: StreamRead,
    /// Upper bound on the number of bytes requested from `inner` per read.
    limit: usize,
}

impl AsyncRead for ReadChunkLimiter {
    /// Reads at most `limit` bytes from `inner` into the unfilled part of `buf`.
    ///
    /// Data that is already filled in `buf` is left untouched; new bytes are
    /// appended after it.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        // Only the *unfilled* region may be handed to the inner reader.
        // `initialized_mut()` would start at offset 0 and therefore include
        // (and let the inner reader overwrite) bytes that are already filled.
        let limit = buf.remaining().min(*this.limit);
        let window = buf.initialize_unfilled_to(limit);
        let mut rb = ReadBuf::new(window);
        ready!(tokio::io::AsyncRead::poll_read(this.inner, cx, &mut rb))?;
        let read_len = rb.filled().len();
        // The window aliases the start of the unfilled region, so advancing the
        // outer buffer by the same amount publishes exactly the bytes read.
        buf.advance(read_len);
        Poll::Ready(Ok(()))
    }
}
/// Wraps a stream reader so that each read requests at most `limit` bytes.
///
/// # Errors
/// Returns an error when `limit` is not a positive integer that fits in
/// `usize`, or when the reader cannot be obtained from the handle `x`.
fn read_chunk_limiter(
    ctx: NativeCallContext,
    x: Handle<StreamRead>,
    limit: i64,
) -> RhResult<Handle<StreamRead>> {
    // A plain `as usize` cast would turn negative values into enormous limits,
    // and a zero limit would make every read look like end-of-stream.
    // Validate before the handle is touched.
    let limit = match usize::try_from(limit) {
        Ok(n) if n > 0 => n,
        _ => return Err(ctx.err("read_chunk_limiter: limit must be a positive integer")),
    };
    let x = ctx.lutbar(x)?;
    debug!(inner=?x, "read_chunk_limiter");
    let x = StreamRead {
        reader: Box::pin(ReadChunkLimiter { inner: x, limit }),
        prefix: BytesMut::new(),
    };
    debug!(wrapped=?x, "read_chunk_limiter");
    Ok(x.wrap())
}
/// Writer adapter that forwards at most `limit` bytes of each write to the
/// wrapped writer.
struct WriteChunkLimiter {
    /// Writer that receives the (possibly shortened) data.
    inner: StreamWrite,
    /// Upper bound on the number of bytes passed along per write call.
    limit: usize,
}

impl AsyncWrite for WriteChunkLimiter {
    /// Forwards the first `min(buf.len(), limit)` bytes of `buf` to the inner writer.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let me = self.get_mut();
        let allowed = buf.len().min(me.limit);
        AsyncWrite::poll_write(me.inner.writer.as_mut(), cx, &buf[..allowed])
    }

    /// Delegates flushing to the inner writer.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let me = self.get_mut();
        AsyncWrite::poll_flush(me.inner.writer.as_mut(), cx)
    }

    /// Delegates shutdown to the inner writer.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let me = self.get_mut();
        AsyncWrite::poll_shutdown(me.inner.writer.as_mut(), cx)
    }
}
/// Wraps a stream writer so that each write passes along at most `limit` bytes.
///
/// # Errors
/// Returns an error when `limit` is not a positive integer that fits in
/// `usize`, or when the writer cannot be obtained from the handle `x`.
fn write_chunk_limiter(
    ctx: NativeCallContext,
    x: Handle<StreamWrite>,
    limit: i64,
) -> RhResult<Handle<StreamWrite>> {
    // A plain `as usize` cast would turn negative values into enormous limits,
    // and a zero limit would produce zero-length writes that never make
    // progress. Validate before the handle is touched.
    let limit = match usize::try_from(limit) {
        Ok(n) if n > 0 => n,
        _ => return Err(ctx.err("write_chunk_limiter: limit must be a positive integer")),
    };
    let x = ctx.lutbar(x)?;
    debug!(inner=?x, "write_chunk_limiter");
    let x = StreamWrite {
        writer: Box::pin(WriteChunkLimiter { inner: x, limit }),
    };
    debug!(wrapped=?x, "write_chunk_limiter");
    Ok(x.wrap())
}
/// Builds a handle to a stream socket whose every field is `None`
/// (no reader, no writer, no `close`, no `fd`).
fn null_stream_socket() -> Handle<StreamSocket> {
    let socket = StreamSocket {
        read: None,
        write: None,
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
/// Builds a handle to a datagram socket whose every field is `None`
/// (no reader, no writer, no `close`, no `fd`).
fn null_datagram_socket() -> Handle<DatagramSocket> {
    let socket = DatagramSocket {
        read: None,
        write: None,
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
/// Builds a handle to a stream socket whose reader and writer are both
/// `tokio::io::empty()` and whose read prefix is empty.
fn dummy_stream_socket() -> Handle<StreamSocket> {
    let read_half = StreamRead {
        reader: Box::pin(tokio::io::empty()),
        prefix: Default::default(),
    };
    let write_half = StreamWrite {
        writer: Box::pin(tokio::io::empty()),
    };
    let socket = StreamSocket {
        read: Some(read_half),
        write: Some(write_half),
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
struct DummyPkt;
impl PacketRead for DummyPkt {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
_buf: &mut [u8],
) -> Poll<std::io::Result<PacketReadResult>> {
Poll::Ready(Ok(PacketReadResult {
flags: BufferFlag::Eof.into(),
buffer_subset: 0..0,
}))
}
}
impl PacketWrite for DummyPkt {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
_buf: &mut [u8],
_flags: super::types::BufferFlags,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
/// Builds a handle to a datagram socket whose source and sink are both `DummyPkt`.
fn dummy_datagram_socket() -> Handle<DatagramSocket> {
    let read_half = DatagramRead {
        src: Box::pin(DummyPkt),
    };
    let write_half = DatagramWrite {
        snk: Box::pin(DummyPkt),
    };
    let socket = DatagramSocket {
        read: Some(read_half),
        write: Some(write_half),
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
/// Wraps a stream writer in a `tokio::io::BufWriter` with the given capacity.
///
/// # Errors
/// Returns an error when `capacity` is negative (or does not fit in `usize`),
/// or when the writer cannot be obtained from the handle `inner`.
fn write_buffer(
    ctx: NativeCallContext,
    inner: Handle<StreamWrite>,
    capacity: i64,
) -> RhResult<Handle<StreamWrite>> {
    // `capacity as usize` would wrap a negative value into a huge allocation
    // request; reject it as a script error instead.
    let Ok(capacity) = usize::try_from(capacity) else {
        return Err(ctx.err("write_buffer: capacity must be a non-negative integer"));
    };
    let inner = ctx.lutbar(inner)?;
    let buffered = tokio::io::BufWriter::with_capacity(capacity, inner.writer);
    Ok(Some(StreamWrite {
        writer: Box::pin(buffered),
    })
    .wrap())
}
/// Decodes standard base64 text `x` and returns the result as a UTF-8 string.
///
/// # Errors
/// Fails when `x` is not valid base64 or when the decoded bytes are not valid UTF-8.
fn b64str(ctx: NativeCallContext, x: &str) -> RhResult<String> {
    let decoded = match base64::prelude::BASE64_STANDARD.decode(x) {
        Ok(bytes) => bytes,
        Err(_) => return Err(ctx.err("Failed to base64-decode the argument")),
    };
    match String::from_utf8(decoded) {
        Ok(text) => Ok(text),
        Err(_) => Err(ctx.err("Base64-encoded content is not a valid UTF-8")),
    }
}
/// Renders a dynamic value as text.
///
/// Blobs become `b` followed by the output of `render_content`; every other
/// value is formatted with its `Debug` representation.
fn format_str(x: Dynamic) -> String {
    if !x.is_blob() {
        return format!("{x:?}");
    }
    // The blob check above guarantees the conversion succeeds.
    let blob = x.into_blob().unwrap();
    let rendered = render_content(&blob, false);
    format!("b{}", rendered)
}
/// Writes `x` to standard output without appending a newline.
///
/// Output is best-effort: a failed write (for example, because stdout is a
/// closed pipe) is ignored instead of panicking, which is what `print!` would do.
fn print_stdout(x: &str) {
    let mut out = std::io::stdout().lock();
    let _ = std::io::Write::write_all(&mut out, x.as_bytes());
}
/// Writes `x` to the scenario's diagnostic output.
///
/// The write is best-effort: its result is discarded.
///
/// # Errors
/// Fails only when the scenario cannot be obtained from `ctx`.
fn print_stderr(ctx: NativeCallContext, x: &str) -> RhResult<()> {
    let scenario = ctx.get_scenario()?;
    let mut sink = scenario.diagnostic_output.lock().unwrap();
    // Deliberately ignore write failures.
    let _ = write!(sink, "{}", x);
    Ok(())
}
/// Builds a stream socket handle whose read prefix holds the bytes of `data`.
///
/// The underlying reader and the writer are both `tokio::io::empty()`.
fn literal_socket(data: String) -> Handle<StreamSocket> {
    let read_half = StreamRead {
        reader: Box::pin(tokio::io::empty()),
        prefix: BytesMut::from(data.as_bytes()),
    };
    let write_half = StreamWrite {
        writer: Box::pin(tokio::io::empty()),
    };
    let socket = StreamSocket {
        read: Some(read_half),
        write: Some(write_half),
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
/// Builds a stream socket handle whose read prefix holds the bytes obtained
/// by base64-decoding `data`.
///
/// The underlying reader and the writer are both `tokio::io::empty()`.
///
/// # Errors
/// Fails when `data` is not valid standard base64.
fn literal_socket_base64(ctx: NativeCallContext, data: String) -> RhResult<Handle<StreamSocket>> {
    let decoded = match base64::prelude::BASE64_STANDARD.decode(data) {
        Ok(bytes) => bytes,
        Err(_) => return Err(ctx.err("Invalid base64 data")),
    };
    let read_half = StreamRead {
        reader: Box::pin(tokio::io::empty()),
        prefix: BytesMut::from(&decoded[..]),
    };
    let write_half = StreamWrite {
        writer: Box::pin(tokio::io::empty()),
    };
    let socket = StreamSocket {
        read: Some(read_half),
        write: Some(write_half),
        close: None,
        fd: None,
    };
    Ok(Some(socket).wrap())
}
#[pin_project]
pub struct ReadStreamChunks(#[pin] pub StreamRead);
impl PacketRead for ReadStreamChunks {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<PacketReadResult>> {
let sr: Pin<&mut StreamRead> = self.project().0;
let mut rb = ReadBuf::new(buf);
ready!(tokio::io::AsyncRead::poll_read(sr, cx, &mut rb))?;
let new_len = rb.filled().len();
let flags = if new_len > 0 {
BufferFlags::default()
} else {
BufferFlag::Eof.into()
};
Poll::Ready(Ok(PacketReadResult {
flags,
buffer_subset: 0..new_len,
}))
}
}
/// Converts a stream reader into a datagram reader by wrapping it in
/// `ReadStreamChunks`.
///
/// # Errors
/// Fails when the reader cannot be obtained from the handle `x`.
fn read_stream_chunks(
    ctx: NativeCallContext,
    x: Handle<StreamRead>,
) -> RhResult<Handle<DatagramRead>> {
    let stream = ctx.lutbar(x)?;
    debug!(inner = ?stream, "read_stream_chunks");
    let chunked = DatagramRead {
        src: Box::pin(ReadStreamChunks(stream)),
    };
    debug!(wrapped = ?chunked, "read_stream_chunks");
    Ok(chunked.wrap())
}
/// Packet writer that writes each packet's bytes to a byte stream.
#[pin_project]
pub struct WriteStreamChunks {
    /// Destination byte stream.
    pub w: StreamWrite,
    /// Number of bytes of the current packet that have already been written.
    /// Lets a packet be resumed after the writer returned `Pending`.
    pub debt: usize,
}

impl PacketWrite for WriteStreamChunks {
    /// Writes all of `buf` to the stream, resuming from `debt` on repeated polls.
    ///
    /// Once the whole buffer is written, the stream is flushed unless `flags`
    /// contains `NonFinalChunk`, and shut down if `flags` contains `Eof`.
    ///
    /// # Errors
    /// Propagates errors from the inner writer and returns `WriteZero` when
    /// the inner writer accepts no bytes for a non-empty chunk.
    ///
    /// # Panics
    /// Panics if `buf` is shorter than the number of bytes already recorded
    /// as written for the current packet.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
        flags: BufferFlags,
    ) -> Poll<std::io::Result<()>> {
        let p = self.project();
        let sw: &mut StreamWrite = p.w;
        loop {
            assert!(buf.len() >= *p.debt);
            let buf_chunk = &buf[*p.debt..];
            if buf_chunk.is_empty() {
                if !flags.contains(BufferFlag::NonFinalChunk) {
                    ready!(tokio::io::AsyncWrite::poll_flush(sw.writer.as_mut(), cx))?;
                }
                if flags.contains(BufferFlag::Eof) {
                    ready!(tokio::io::AsyncWrite::poll_shutdown(sw.writer.as_mut(), cx))?;
                }
                // Packet fully delivered; the next call starts a fresh one.
                *p.debt = 0;
                return Poll::Ready(Ok(()));
            }
            let n = ready!(tokio::io::AsyncWrite::poll_write(
                sw.writer.as_mut(),
                cx,
                buf_chunk
            ))?;
            if n == 0 {
                // No progress on a non-empty chunk: without this check the
                // loop would spin forever instead of yielding or failing.
                return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
            }
            *p.debt += n;
        }
    }
}
/// Converts a stream writer into a datagram writer by wrapping it in
/// `WriteStreamChunks`.
///
/// # Errors
/// Fails when the writer cannot be obtained from the handle `x`.
fn write_stream_chunks(
    ctx: NativeCallContext,
    x: Handle<StreamWrite>,
) -> RhResult<Handle<DatagramWrite>> {
    let stream = ctx.lutbar(x)?;
    debug!(inner = ?stream, "write_stream_chunks");
    let sink = WriteStreamChunks { w: stream, debt: 0 };
    let chunked = DatagramWrite {
        snk: Box::pin(sink),
    };
    debug!(wrapped = ?chunked, "write_stream_chunks");
    Ok(chunked.wrap())
}
/// Converts a stream socket into a datagram socket.
///
/// The read half is wrapped in `ReadStreamChunks` and the write half in
/// `WriteStreamChunks`; `close` and `fd` are carried over unchanged.
///
/// # Errors
/// Fails when the socket cannot be obtained from the handle `x`, or when the
/// socket lacks a read half or a write half.
fn stream_chunks(
    ctx: NativeCallContext,
    x: Handle<StreamSocket>,
) -> RhResult<Handle<DatagramSocket>> {
    let x = ctx.lutbar(x)?;
    debug!(inner=?x, "stream_chunks");
    let StreamSocket {
        read: Some(r),
        write: Some(w),
        close,
        fd,
    } = x
    else {
        // Previously reported with an empty message, which gave script
        // authors nothing to act on.
        return Err(ctx.err(
            "stream_chunks requires a stream socket that has both read and write parts",
        ));
    };
    let write = DatagramWrite {
        snk: Box::pin(WriteStreamChunks { w, debt: 0 }),
    };
    let read = DatagramRead {
        src: Box::pin(ReadStreamChunks(r)),
    };
    let x = DatagramSocket {
        read: Some(read),
        write: Some(write),
        close,
        fd,
    };
    debug!(wrapped=?x, "stream_chunks");
    Ok(x.wrap())
}
/// Builds a stream socket whose read and write halves are the two ends of a
/// `tokio::io::simplex` pipe.
///
/// `opts` must deserialize into a structure with a `max_buf_size` field,
/// which is passed to `tokio::io::simplex`.
///
/// # Errors
/// Fails when `opts` cannot be deserialized.
fn bytemirror_socket(opts: Dynamic) -> RhResult<Handle<StreamSocket>> {
    let span = debug_span!("bytemirror_socket");
    debug!(parent: &span, "node created");

    #[derive(serde::Deserialize)]
    struct Opts {
        max_buf_size: usize,
    }

    let Opts { max_buf_size }: Opts = rhai::serde::from_dynamic(&opts)?;
    debug!(parent: &span, "options parsed");

    let (pipe_reader, pipe_writer) = tokio::io::simplex(max_buf_size);
    let read_half = StreamRead {
        reader: Box::pin(pipe_reader),
        prefix: Default::default(),
    };
    let write_half = StreamWrite {
        writer: Box::pin(pipe_writer),
    };
    let socket = StreamSocket {
        read: Some(read_half),
        write: Some(write_half),
        close: None,
        fd: None,
    };
    debug!(parent: &span, s = ?socket, "socket created");
    Ok(Some(socket).wrap())
}
/// Shared state of a packet mirror: a single packet slot plus the wakers of
/// the parked reader and writer.
#[derive(Debug, Default)]
struct PacketMirror {
    /// Bytes of the packet currently stored in the slot.
    buf: BytesMut,
    /// Flags that accompanied the stored packet.
    flags: BufferFlags,
    /// `true` while the slot holds a packet that has not been read yet.
    packet_ready: bool,
    /// Waker of a reader that found the slot empty.
    read_waker: Option<std::task::Waker>,
    /// Waker of a writer that found the slot occupied.
    write_waker: Option<std::task::Waker>,
}

/// Cloneable handle to a `PacketMirror`; clones share the same state.
#[derive(Clone)]
pub struct PacketMirrorHandle(Arc<Mutex<PacketMirror>>);

impl PacketMirrorHandle {
    /// Creates a handle to a fresh mirror in its default (empty) state.
    pub fn new() -> PacketMirrorHandle {
        let state = Mutex::new(PacketMirror::default());
        PacketMirrorHandle(Arc::new(state))
    }
}
impl PacketRead for PacketMirrorHandle {
    /// Takes the stored packet out of the mirror, if there is one.
    ///
    /// With no packet stored, the caller's waker is saved and `Pending` is
    /// returned. With a packet stored, its bytes are copied into `buf`, the
    /// slot is emptied and a parked writer (if any) is woken.
    ///
    /// # Errors
    /// Returns `InvalidInput` when the stored packet does not fit into `buf`;
    /// the packet stays in the slot in that case.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<PacketReadResult>> {
        let mut mirror = self.get_mut().0.lock().unwrap();

        if !mirror.packet_ready {
            trace!("packet mirror's packet is not ready");
            mirror.read_waker = Some(cx.waker().to_owned());
            return Poll::Pending;
        }

        trace!("packet mirror's packet is ready");
        let len = mirror.buf.len();
        if len > buf.len() {
            warn!("packet fragment is too large for packet mirror's reader buffer");
            return Poll::Ready(Err(std::io::ErrorKind::InvalidInput.into()));
        }

        buf[..len].copy_from_slice(&mirror.buf);
        mirror.buf.clear();
        mirror.packet_ready = false;
        // The slot is free again: let a writer that was waiting for it proceed.
        if let Some(waker) = mirror.write_waker.take() {
            waker.wake();
        }

        Poll::Ready(Ok(PacketReadResult {
            flags: mirror.flags,
            buffer_subset: 0..len,
        }))
    }
}
impl PacketWrite for PacketMirrorHandle {
    /// Stores a packet in the mirror so that the read side can pick it up.
    ///
    /// While a previous packet is still unread, the caller's waker is saved
    /// and `Pending` is returned. Otherwise `buf` and `flags` are recorded,
    /// the slot is marked ready and a parked reader (if any) is woken.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
        flags: BufferFlags,
    ) -> Poll<std::io::Result<()>> {
        let mut mirror = self.get_mut().0.lock().unwrap();

        if mirror.packet_ready {
            trace!("packet mirror's packet slot is busy");
            mirror.write_waker = Some(cx.waker().to_owned());
            return Poll::Pending;
        }

        mirror.buf.extend_from_slice(buf);
        mirror.flags = flags;
        mirror.packet_ready = true;
        // A packet is available now: let a reader that was waiting for it proceed.
        if let Some(waker) = mirror.read_waker.take() {
            waker.wake();
        }
        Poll::Ready(Ok(()))
    }
}
/// Builds a datagram socket whose read and write halves are two clones of
/// one `PacketMirrorHandle`, so they share the same mirror state.
///
/// `opts` must deserialize into a structure without fields.
///
/// # Errors
/// Fails when `opts` cannot be deserialized.
fn packetmirror_socket(opts: Dynamic) -> RhResult<Handle<DatagramSocket>> {
    let span = debug_span!("packetmirror_socket");
    debug!(parent: &span, "node created");

    #[derive(serde::Deserialize)]
    struct Opts {}

    let _parsed: Opts = rhai::serde::from_dynamic(&opts)?;
    debug!(parent: &span, "options parsed");

    let read_end = PacketMirrorHandle::new();
    let write_end = read_end.clone();
    let socket = DatagramSocket {
        read: Some(DatagramRead {
            src: Box::pin(read_end),
        }),
        write: Some(DatagramWrite {
            snk: Box::pin(write_end),
        }),
        close: None,
        fd: None,
    };
    debug!(parent: &span, s = ?socket, "socket created");
    Ok(Some(socket).wrap())
}
/// Registers every helper defined in this file with the scripting engine
/// under its script-visible name.
pub fn register(engine: &mut Engine) {
    engine
        // Chunk-size limiting adapters.
        .register_fn("read_chunk_limiter", read_chunk_limiter)
        .register_fn("write_chunk_limiter", write_chunk_limiter)
        // Placeholder sockets.
        .register_fn("null_stream_socket", null_stream_socket)
        .register_fn("null_datagram_socket", null_datagram_socket)
        .register_fn("dummy_stream_socket", dummy_stream_socket)
        .register_fn("dummy_datagram_socket", dummy_datagram_socket)
        .register_fn("write_buffer", write_buffer)
        // String and printing helpers.
        .register_fn("b64str", b64str)
        .register_fn("str", format_str)
        .register_fn("print_stderr", print_stderr)
        .register_fn("print_stdout", print_stdout)
        // Sockets that serve literal data.
        .register_fn("literal_socket", literal_socket)
        .register_fn("literal_socket_base64", literal_socket_base64)
        // Stream <-> datagram conversion.
        .register_fn("read_stream_chunks", read_stream_chunks)
        .register_fn("write_stream_chunks", write_stream_chunks)
        .register_fn("stream_chunks", stream_chunks)
        // Mirror sockets.
        .register_fn("bytemirror_socket", bytemirror_socket)
        .register_fn("packetmirror_socket", packetmirror_socket);
}