use std::{
pin::Pin,
task::{ready, Poll},
};
use base64::Engine as _;
use bytes::BytesMut;
use pin_project::pin_project;
use rhai::{Dynamic, Engine, NativeCallContext};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::debug;
use crate::scenario_executor::{
logoverlay::render_content,
types::{Handle, StreamRead},
utils1::{ExtractHandleOrFail, RhResult},
};
use super::{
scenario::ScenarioAccess,
types::{
BufferFlag, BufferFlags, DatagramRead, DatagramSocket, DatagramWrite, PacketRead,
PacketReadResult, PacketWrite, StreamSocket, StreamWrite,
},
utils1::{HandleExt, SimpleErr},
};
/// `AsyncRead` adapter that caps how many bytes a single `poll_read` call may
/// deliver from the wrapped [`StreamRead`].
#[pin_project]
struct ReadChunkLimiter {
    #[pin]
    inner: StreamRead,
    /// Maximum number of bytes handed out by one `poll_read` call.
    limit: usize,
}
impl AsyncRead for ReadChunkLimiter {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        // Expose only the *unfilled* part of `buf`, truncated to `limit`.
        // (`initialized_mut()` would also include the already-filled prefix, so the
        // inner reader would overwrite bytes the caller already holds.)
        // `initialize_unfilled_to` also zero-fills just this window, not the whole tail.
        let n = buf.remaining().min(*this.limit);
        let window = buf.initialize_unfilled_to(n);
        let mut rb = ReadBuf::new(window);
        ready!(tokio::io::AsyncRead::poll_read(this.inner, cx, &mut rb))?;
        let read_len = rb.filled().len();
        // The inner reader wrote into `window`, which starts at the old fill
        // position, so advancing by `read_len` marks exactly those bytes as filled.
        buf.advance(read_len);
        Poll::Ready(Ok(()))
    }
}
fn read_chunk_limiter(
ctx: NativeCallContext,
x: Handle<StreamRead>,
limit: i64,
) -> RhResult<Handle<StreamRead>> {
let x = ctx.lutbar(x)?;
debug!(inner=?x, "read_chunk_limiter");
let x = StreamRead {
reader: Box::pin(ReadChunkLimiter {
inner: x,
limit: limit as usize,
}),
prefix: BytesMut::new(),
};
debug!(wrapped=?x, "read_chunk_limiter");
Ok(x.wrap())
}
/// `AsyncWrite` adapter that hands at most `limit` bytes per `poll_write` call
/// to the wrapped [`StreamWrite`]; flush and shutdown are forwarded unchanged.
struct WriteChunkLimiter {
    inner: StreamWrite,
    /// Upper bound on the slice length passed to the inner writer.
    limit: usize,
}
impl AsyncWrite for WriteChunkLimiter {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let me = self.get_mut();
        // Only the first `limit` bytes are offered; the caller sees a short write
        // and retries with the rest.
        let capped = &buf[..buf.len().min(me.limit)];
        AsyncWrite::poll_write(me.inner.writer.as_mut(), cx, capped)
    }
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = self.get_mut().inner.writer.as_mut();
        AsyncWrite::poll_flush(writer, cx)
    }
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = self.get_mut().inner.writer.as_mut();
        AsyncWrite::poll_shutdown(writer, cx)
    }
}
fn write_chunk_limiter(
ctx: NativeCallContext,
x: Handle<StreamWrite>,
limit: i64,
) -> RhResult<Handle<StreamWrite>> {
let x = ctx.lutbar(x)?;
debug!(inner=?x, "write_chunk_limiter");
let x = StreamWrite {
writer: Box::pin(WriteChunkLimiter {
inner: x,
limit: limit as usize,
}),
};
debug!(wrapped=?x, "write_chunk_limiter");
Ok(x.wrap())
}
/// Returns a stream-socket handle with no read part, no write part, no close
/// handler and no file descriptor.
fn null_stream_socket() -> Handle<StreamSocket> {
    let empty_socket = StreamSocket {
        read: None,
        write: None,
        close: None,
        fd: None,
    };
    Some(empty_socket).wrap()
}
/// Returns a datagram-socket handle with no read part, no write part, no close
/// handler and no file descriptor.
fn null_datagram_socket() -> Handle<DatagramSocket> {
    let empty_socket = DatagramSocket {
        read: None,
        write: None,
        close: None,
        fd: None,
    };
    Some(empty_socket).wrap()
}
/// Returns a stream-socket handle backed by `tokio::io::empty()` in both directions.
///
/// The reader side has no prefix and reads from `empty()`; the writer side writes
/// into `empty()`. There is no close handler and no file descriptor.
fn dummy_stream_socket() -> Handle<StreamSocket> {
    let reader_half = StreamRead {
        reader: Box::pin(tokio::io::empty()),
        prefix: Default::default(),
    };
    let writer_half = StreamWrite {
        writer: Box::pin(tokio::io::empty()),
    };
    let socket = StreamSocket {
        read: Some(reader_half),
        write: Some(writer_half),
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
struct DummyPkt;
impl PacketRead for DummyPkt {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
_buf: &mut [u8],
) -> Poll<std::io::Result<PacketReadResult>> {
Poll::Ready(Ok(PacketReadResult {
flags: BufferFlag::Eof.into(),
buffer_subset: 0..0,
}))
}
}
impl PacketWrite for DummyPkt {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
_buf: &mut [u8],
_flags: super::types::BufferFlags,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
/// Returns a datagram-socket handle whose read and write parts are both `DummyPkt`.
///
/// Reads report EOF immediately and writes are discarded. There is no close
/// handler and no file descriptor.
fn dummy_datagram_socket() -> Handle<DatagramSocket> {
    let source = DatagramRead {
        src: Box::pin(DummyPkt),
    };
    let sink = DatagramWrite {
        snk: Box::pin(DummyPkt),
    };
    let socket = DatagramSocket {
        read: Some(source),
        write: Some(sink),
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
fn write_buffer(
ctx: NativeCallContext,
inner: Handle<StreamWrite>,
capacity: i64,
) -> RhResult<Handle<StreamWrite>> {
Ok(Some(StreamWrite {
writer: Box::pin(tokio::io::BufWriter::with_capacity(
capacity as usize,
ctx.lutbar(inner)?.writer,
)),
})
.wrap())
}
/// Decodes a standard-alphabet base64 string and returns the result as UTF-8 text.
///
/// Fails with a script error if the input is not valid base64, or if the decoded
/// bytes are not valid UTF-8.
fn b64str(ctx: NativeCallContext, x: &str) -> RhResult<String> {
    let decoded = base64::prelude::BASE64_STANDARD
        .decode(x)
        .map_err(|_| ctx.err("Failed to base64-decode the argument"))?;
    String::from_utf8(decoded).map_err(|_| ctx.err("Base64-encoded content is not a valid UTF-8"))
}
/// Renders a script value as a string.
///
/// Blobs are rendered via `render_content` with a `b` prefix. Every other value
/// uses its `Debug` representation.
fn format_str(x: Dynamic) -> String {
    if !x.is_blob() {
        return format!("{:?}", x);
    }
    let blob = x.into_blob().unwrap();
    format!("b{}", render_content(&blob, false))
}
/// Writes `x` to standard output verbatim, without appending a newline.
fn print_stdout(x: &str) {
    print!("{}", x);
}
/// Writes `x` verbatim to the current scenario's `diagnostic_output`.
///
/// Write errors are deliberately ignored; only failing to obtain the scenario is
/// reported to the script.
fn print_stderr(ctx: NativeCallContext, x: &str) -> RhResult<()> {
    let scenario = ctx.get_scenario()?;
    {
        let mut sink = scenario.diagnostic_output.lock().unwrap();
        let _ = write!(sink, "{}", x);
    }
    Ok(())
}
/// Returns a stream socket whose reader first yields the UTF-8 bytes of `data`
/// (as its prefix) and then reads from `tokio::io::empty()`.
///
/// Its writer writes into `tokio::io::empty()`. There is no close handler and no
/// file descriptor.
fn literal_socket(data: String) -> Handle<StreamSocket> {
    let preloaded = BytesMut::from(data.as_bytes());
    let reader_half = StreamRead {
        reader: Box::pin(tokio::io::empty()),
        prefix: preloaded,
    };
    let writer_half = StreamWrite {
        writer: Box::pin(tokio::io::empty()),
    };
    let socket = StreamSocket {
        read: Some(reader_half),
        write: Some(writer_half),
        close: None,
        fd: None,
    };
    Some(socket).wrap()
}
/// Like `literal_socket`, but `data` is standard-alphabet base64 and the decoded
/// bytes become the reader's prefix.
///
/// Fails with a script error if `data` is not valid base64.
fn literal_socket_base64(ctx: NativeCallContext, data: String) -> RhResult<Handle<StreamSocket>> {
    let decoded = match base64::prelude::BASE64_STANDARD.decode(data) {
        Ok(bytes) => bytes,
        Err(_) => return Err(ctx.err("Invalid base64 data")),
    };
    let reader_half = StreamRead {
        reader: Box::pin(tokio::io::empty()),
        prefix: BytesMut::from(decoded.as_slice()),
    };
    let writer_half = StreamWrite {
        writer: Box::pin(tokio::io::empty()),
    };
    let socket = StreamSocket {
        read: Some(reader_half),
        write: Some(writer_half),
        close: None,
        fd: None,
    };
    Ok(Some(socket).wrap())
}
#[pin_project]
pub struct ReadStreamChunks(#[pin] pub StreamRead);
impl PacketRead for ReadStreamChunks {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<PacketReadResult>> {
let sr: Pin<&mut StreamRead> = self.project().0;
let mut rb = ReadBuf::new(buf);
ready!(tokio::io::AsyncRead::poll_read(sr, cx, &mut rb))?;
let new_len = rb.filled().len();
let flags = if new_len > 0 {
BufferFlags::default()
} else {
BufferFlag::Eof.into()
};
Poll::Ready(Ok(PacketReadResult {
flags,
buffer_subset: 0..new_len,
}))
}
}
/// Takes the stream reader out of `x` and exposes it as a datagram source via
/// `ReadStreamChunks`, producing one packet per underlying read.
fn read_stream_chunks(
    ctx: NativeCallContext,
    x: Handle<StreamRead>,
) -> RhResult<Handle<DatagramRead>> {
    let stream = ctx.lutbar(x)?;
    debug!(inner = ?stream, "read_stream_chunks");
    let datagrams = DatagramRead {
        src: Box::pin(ReadStreamChunks(stream)),
    };
    debug!(wrapped = ?datagrams, "read_stream_chunks");
    Ok(datagrams.wrap())
}
/// Adapts a byte-stream writer into a packet sink.
///
/// Each packet's bytes are written in full to the inner [`StreamWrite`]. The
/// writer is then flushed, unless the packet carries [`BufferFlag::NonFinalChunk`],
/// and shut down if the packet carries [`BufferFlag::Eof`].
#[pin_project]
pub struct WriteStreamChunks {
    /// The underlying byte-stream writer.
    pub w: StreamWrite,
    /// Number of bytes of the current packet already accepted by `w`.
    /// This lets a call that returned `Poll::Pending` resume without re-sending
    /// data. It is reset to 0 once the packet completes.
    pub debt: usize,
}
impl PacketWrite for WriteStreamChunks {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
        flags: BufferFlags,
    ) -> Poll<std::io::Result<()>> {
        let p = self.project();
        let sw: &mut StreamWrite = p.w;
        loop {
            // After a `Pending`, the caller must re-present the same packet.
            assert!(buf.len() >= *p.debt);
            let buf_chunk = &buf[*p.debt..];
            if buf_chunk.is_empty() {
                if !flags.contains(BufferFlag::NonFinalChunk) {
                    ready!(tokio::io::AsyncWrite::poll_flush(sw.writer.as_mut(), cx))?;
                }
                if flags.contains(BufferFlag::Eof) {
                    ready!(tokio::io::AsyncWrite::poll_shutdown(sw.writer.as_mut(), cx))?;
                }
                *p.debt = 0;
                break;
            }
            let n = ready!(tokio::io::AsyncWrite::poll_write(
                sw.writer.as_mut(),
                cx,
                buf_chunk
            ))?;
            if n == 0 {
                // The writer accepted nothing from a non-empty slice and will not
                // make progress. Without this check `debt` would never advance and
                // the loop would spin forever.
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole packet to the underlying stream",
                )));
            }
            *p.debt += n;
        }
        Poll::Ready(Ok(()))
    }
}
/// Takes the stream writer out of `x` and exposes it as a datagram sink via
/// `WriteStreamChunks`, writing each packet's bytes into the stream.
fn write_stream_chunks(
    ctx: NativeCallContext,
    x: Handle<StreamWrite>,
) -> RhResult<Handle<DatagramWrite>> {
    let stream = ctx.lutbar(x)?;
    debug!(inner = ?stream, "write_stream_chunks");
    let sink = WriteStreamChunks { w: stream, debt: 0 };
    let datagrams = DatagramWrite {
        snk: Box::pin(sink),
    };
    debug!(wrapped = ?datagrams, "write_stream_chunks");
    Ok(datagrams.wrap())
}
/// Converts a stream socket into a datagram socket.
///
/// Reads go through `ReadStreamChunks` and writes through `WriteStreamChunks`.
/// The original `close` handler and `fd` are carried over unchanged.
///
/// Fails with a script error if the socket lacks either its read or its write part.
fn stream_chunks(
    ctx: NativeCallContext,
    x: Handle<StreamSocket>,
) -> RhResult<Handle<DatagramSocket>> {
    let x = ctx.lutbar(x)?;
    debug!(inner=?x, "stream_chunks");
    let StreamSocket {
        read: Some(r),
        write: Some(w),
        close,
        fd,
    } = x
    else {
        // Previously this returned an empty error message, which gave script
        // authors no clue what went wrong.
        return Err(ctx.err(
            "stream_chunks: the socket must have both a read and a write part",
        ));
    };
    let write = DatagramWrite {
        snk: Box::pin(WriteStreamChunks { w, debt: 0 }),
    };
    let read = DatagramRead {
        src: Box::pin(ReadStreamChunks(r)),
    };
    let x = DatagramSocket {
        read: Some(read),
        write: Some(write),
        close,
        fd,
    };
    debug!(wrapped=?x, "stream_chunks");
    Ok(x.wrap())
}
/// Registers this module's helper functions with the Rhai `engine`.
pub fn register(engine: &mut Engine) {
    // Chunk-size limiting and write buffering.
    engine.register_fn("read_chunk_limiter", read_chunk_limiter);
    engine.register_fn("write_chunk_limiter", write_chunk_limiter);
    engine.register_fn("write_buffer", write_buffer);

    // Placeholder sockets.
    engine.register_fn("null_stream_socket", null_stream_socket);
    engine.register_fn("null_datagram_socket", null_datagram_socket);
    engine.register_fn("dummy_stream_socket", dummy_stream_socket);
    engine.register_fn("dummy_datagram_socket", dummy_datagram_socket);

    // Sockets preloaded with literal content.
    engine.register_fn("literal_socket", literal_socket);
    engine.register_fn("literal_socket_base64", literal_socket_base64);

    // Stream <-> datagram conversion.
    engine.register_fn("read_stream_chunks", read_stream_chunks);
    engine.register_fn("write_stream_chunks", write_stream_chunks);
    engine.register_fn("stream_chunks", stream_chunks);

    // String helpers and printing.
    engine.register_fn("b64str", b64str);
    engine.register_fn("str", format_str);
    engine.register_fn("print_stdout", print_stdout);
    engine.register_fn("print_stderr", print_stderr);
}