use std::{
pin::Pin,
task::{ready, Poll},
};
use futures::FutureExt;
use rhai::{Dynamic, Engine, NativeCallContext};
use tracing::{debug, debug_span, error, trace, warn};
use crate::scenario_executor::{
types::Hangup,
utils1::{ExtractHandleOrFail, HandleExt, SimpleErr},
};
use super::{
types::{
BufferFlag, BufferFlags, DatagramRead, DatagramSocket, DatagramWrite, Handle, PacketRead,
PacketReadResult, PacketWrite,
},
utils1::RhResult,
};
/// One downstream sink of a `TeeWriter`, plus its per-message progress flag.
struct TeeWriterNode {
/// The sink itself. `TeeWriter::poll_write` replaces it with `None` after the
/// sink reports an error (unless the whole write is failed instead), so a
/// `None` node is skipped on all later writes.
w: Option<DatagramWrite>,
/// Set once this sink has accepted the message currently being written, so
/// that re-polls of the same message do not write to it again.
a_write_completed: bool,
}
/// Packet writer that offers every outgoing message to each of its nodes.
struct TeeWriter {
/// The sinks to duplicate messages to.
nodes: Vec<TeeWriterNode>,
/// When `true`, the first error returned by any sink is returned from
/// `poll_write` right away; when `false`, the failing sink is dropped instead.
fail_all_if_one_fails: bool,
/// `true` while a message has been started but not yet completed; a poll that
/// finds it `false` treats the call as a new message and clears every node's
/// `a_write_completed` flag.
writing_in_progress: bool,
}
impl PacketWrite for TeeWriter {
    /// Offers one message (`buf` with `flags`) to every live node.
    ///
    /// The operation may span several polls: a node that has already accepted
    /// the current message is remembered via `a_write_completed` and is not
    /// written to again when the caller re-polls after `Poll::Pending`.
    ///
    /// Outcome:
    /// * `Poll::Pending` while at least one node is still pending;
    /// * with `fail_all_if_one_fails` set, the first node error is returned
    ///   immediately;
    /// * otherwise a failing node is dropped (its sink becomes `None`) and an
    ///   error is returned only if a node failed during the final poll and no
    ///   node accepted the message in this or any earlier poll;
    /// * `Ok(())` in all remaining cases, including when there is no live node
    ///   left to write to.
    ///
    /// Every `Poll::Ready` result ends the current message, so the next call
    /// starts with fresh per-node completion flags.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
        flags: BufferFlags,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();

        // First poll of a new message: forget the previous message's progress.
        if !this.writing_in_progress {
            this.writing_in_progress = true;
            for n in &mut this.nodes {
                n.a_write_completed = false;
            }
        }

        // The three counters describe this poll only (they feed the trace line);
        // `delivered_somewhere` also covers nodes completed in earlier polls.
        let mut ok_count = 0;
        let mut pend_count = 0;
        let mut err_count = 0;
        let mut delivered_somewhere = false;
        let mut err = None;

        for n in &mut this.nodes {
            if n.a_write_completed {
                // Accepted the message on an earlier poll; still counts as a
                // successful delivery for the final verdict.
                delivered_somewhere = true;
                continue;
            }
            let Some(ref mut w) = n.w else { continue };
            match w.snk.as_mut().poll_write(cx, buf, flags) {
                Poll::Ready(Ok(())) => {
                    ok_count += 1;
                    delivered_somewhere = true;
                    n.a_write_completed = true;
                }
                Poll::Ready(Err(e)) => {
                    err_count += 1;
                    if this.fail_all_if_one_fails {
                        // The write as a whole is over; do not leak stale
                        // completion flags into the next message.
                        this.writing_in_progress = false;
                        return Poll::Ready(Err(e));
                    }
                    // Tolerated failure: retire the node, remember the error
                    // in case nobody else takes the message.
                    n.w = None;
                    err = Some(e);
                }
                Poll::Pending => {
                    pend_count += 1;
                }
            }
        }

        trace!("ok={ok_count}, pend={pend_count}, err={err_count}");

        if pend_count > 0 {
            return Poll::Pending;
        }

        // Every node has either completed, failed, or is gone: message finished.
        this.writing_in_progress = false;
        match err {
            Some(e) if !delivered_somewhere => Poll::Ready(Err(e)),
            _ => Poll::Ready(Ok(())),
        }
    }
}
/// Packet reader that merges incoming packets from several sources.
struct TeeReader {
/// The sources to read from. A slot is set to `None` once its source has
/// returned EOF or an error during the scan over all nodes.
nodes: Vec<Option<DatagramRead>>,
/// When `true`, an EOF from any single source stops the scan and is returned
/// to the caller, even if other sources are still active.
propagate_any_eof: bool,
/// When `true`, an error from a source that is in the middle of a fragmented
/// message is returned as an error; when `false`, an empty packet is returned
/// in its place and a warning is logged.
err_on_orphaned_fragment: bool,
/// When `true`, any source error observed during a scan fails the read.
err_on_any_error: bool,
/// Index of the source that last delivered a `NonFinalChunk`; while set, reads
/// go exclusively to that source until a chunk without that flag arrives.
chosen_node: Option<usize>,
/// Counter decremented each time a slot in `nodes` is set to `None`;
/// initialised by `tee` to the number of sources.
active_nodes_remains: usize,
}
impl PacketRead for TeeReader {
    /// Produces the next packet (or packet fragment) from the underlying nodes.
    ///
    /// Two modes:
    /// * If a node is "chosen" (it previously delivered a `NonFinalChunk`),
    ///   only that node is polled, so fragments of one message are not
    ///   interleaved with packets from other nodes.
    /// * Otherwise every live node is polled in order. The first data packet
    ///   or fragment ends the scan; nodes that report EOF or an error are
    ///   retired (their slot becomes `None`).
    ///
    /// After a scan the result is picked in priority order: error (when
    /// `err_on_any_error`), a packet/EOF to deliver, the last error once no
    /// node is left, `Pending`, and finally a `BrokenPipe` error when nothing
    /// at all was observed.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<PacketReadResult>> {
        let this = self.get_mut();

        // Empty, flag-less packet handed out in place of a message that was
        // cut short by its node.
        let trimmed_packet = || PacketReadResult {
            flags: Default::default(),
            buffer_subset: 0..0,
        };

        if let Some(i) = this.chosen_node {
            trace!("reading from chosen node {i}");
            let Some(r) = this.nodes[i].as_mut() else {
                unreachable!()
            };
            let prr = match ready!(r.src.as_mut().poll_read(cx, buf)) {
                Ok(prr) => prr,
                Err(e) => {
                    if this.err_on_orphaned_fragment {
                        return Poll::Ready(Err(e));
                    }
                    warn!("Trimmed datagram coming from `tee:` overlay due to one of the nodes abruptly disconnecting: {e}");
                    this.chosen_node = None;
                    return Poll::Ready(Ok(trimmed_packet()));
                }
            };

            // EOF in the middle of a message: swallow it unless either
            // strictness option asks for it to be passed through.
            let swallow_eof = !(this.err_on_orphaned_fragment || this.propagate_any_eof);
            if swallow_eof && prr.flags.contains(BufferFlag::Eof) {
                warn!("Trimmed datagram coming from `tee:` overlay due to one of the nodes abruptly disconnecting");
                this.chosen_node = None;
                return Poll::Ready(Ok(trimmed_packet()));
            }

            // Last fragment received: go back to scanning all nodes.
            if !prr.flags.contains(BufferFlag::NonFinalChunk) {
                this.chosen_node = None;
            }
            return Poll::Ready(Ok(prr));
        }

        let mut ok_count = 0;
        let mut frag_count = 0;
        let mut eof_count = 0;
        let mut pend_count = 0;
        let mut err_count = 0;
        let mut err_to_propagate = None;
        let mut prr_to_propagate = None;

        for (i, slot) in this.nodes.iter_mut().enumerate() {
            let Some(r) = slot.as_mut() else {
                continue;
            };

            // The match only falls through for nodes that are finished (EOF or
            // error); its value tells which of the two it was.
            let eof_occured = match r.src.as_mut().poll_read(cx, buf) {
                Poll::Pending => {
                    pend_count += 1;
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    err_to_propagate = Some(e);
                    err_count += 1;
                    false
                }
                Poll::Ready(Ok(prr)) => {
                    let is_fragment = prr.flags.contains(BufferFlag::NonFinalChunk);
                    let is_eof = prr.flags.contains(BufferFlag::Eof);
                    prr_to_propagate = Some(prr);
                    if is_fragment {
                        // Stick to this node until the message is complete.
                        frag_count += 1;
                        this.chosen_node = Some(i);
                        break;
                    }
                    if !is_eof {
                        ok_count += 1;
                        break;
                    }
                    eof_count += 1;
                    true
                }
            };

            // Retire the finished node.
            *slot = None;
            this.active_nodes_remains -= 1;

            if eof_occured && this.propagate_any_eof {
                break;
            }
        }

        trace!("ok={ok_count} frag={frag_count} eof={eof_count} pend={pend_count} err={err_count}");

        let nobody_left = this.active_nodes_remains == 0;
        let have_packet = ok_count > 0 || frag_count > 0;
        let eof_is_final = eof_count > 0 && (this.propagate_any_eof || nobody_left);

        if this.err_on_any_error && err_count > 0 {
            Poll::Ready(Err(err_to_propagate.unwrap()))
        } else if have_packet || eof_is_final {
            Poll::Ready(Ok(prr_to_propagate.unwrap()))
        } else if err_count > 0 && nobody_left {
            Poll::Ready(Err(err_to_propagate.unwrap()))
        } else if pend_count > 0 {
            Poll::Pending
        } else {
            error!("strange tee's read state: ok={ok_count} frag={frag_count} eof={eof_count} pend={pend_count} err={err_count}");
            Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
        }
    }
}
/// Scenario function `tee`: combines several datagram sockets into one.
///
/// Writes to the resulting socket go through a `TeeWriter` over the write
/// halves of `sockets`; reads come from a `TeeReader` over their read halves.
///
/// `opts` is deserialized into the local `Opts` struct (every field defaults
/// to `false`):
/// * `write_fail_all_if_one_fails` -> `TeeWriter::fail_all_if_one_fails`
/// * `read_fail_all_if_one_fails` -> `TeeReader::err_on_any_error`
/// * `propagate_eofs` -> `TeeReader::propagate_any_eof`
/// * `tolerate_torn_msgs` -> negated into `TeeReader::err_on_orphaned_fragment`
/// * `use_hangups` -> the combined socket's hangup fires when any inner
///   socket's hangup does
/// * `use_first_hangup` -> the combined socket takes the hangup of the first
///   socket only
///
/// Fails when an array element is not a datagram socket handle, when both
/// hangup options are set, or when `sockets` is empty.
fn tee(
    ctx: NativeCallContext,
    opts: Dynamic,
    sockets: Vec<Dynamic>,
) -> RhResult<Handle<DatagramSocket>> {
    let span = debug_span!("tee");

    #[derive(serde::Deserialize)]
    struct Opts {
        #[serde(default)]
        write_fail_all_if_one_fails: bool,
        #[serde(default)]
        read_fail_all_if_one_fails: bool,
        #[serde(default)]
        propagate_eofs: bool,
        #[serde(default)]
        tolerate_torn_msgs: bool,
        #[serde(default)]
        use_hangups: bool,
        #[serde(default)]
        use_first_hangup: bool,
    }
    let opts: Opts = rhai::serde::from_dynamic(&opts)?;

    // Turn the dynamically-typed array into concrete sockets.
    let mut ss: Vec<DatagramSocket> = Vec::with_capacity(sockets.len());
    for (i, socket) in sockets.into_iter().enumerate() {
        match socket.try_cast::<Handle<DatagramSocket>>() {
            Some(handle) => {
                debug!(i, "extracted socket");
                ss.push(ctx.lutbar(handle)?);
            }
            None => {
                return Err(ctx.err("Non-datagram-socket element in the `tee`'s array"));
            }
        }
    }

    if opts.use_hangups && opts.use_first_hangup {
        return Err(ctx.err("use_hangups and use_first_hangup are incompatible"));
    }
    if ss.is_empty() {
        return Err(ctx.err("Specify at least one socket"));
    }
    debug!(parent: &span, "options parsed");

    // The two hangup options are mutually exclusive (checked above), so at
    // most one of the first two branches can apply.
    let close: Option<Hangup> = if opts.use_hangups {
        let mut hgtoks: Vec<Hangup> = ss.iter_mut().filter_map(|s| s.close.take()).collect();
        match hgtoks.len() {
            0 => None,
            1 => hgtoks.pop(),
            _ => {
                let any_hangup: Hangup =
                    Box::pin(futures::future::select_all(hgtoks).map(|_| ()));
                Some(any_hangup)
            }
        }
    } else if opts.use_first_hangup {
        ss.first_mut().and_then(|s| s.close.take())
    } else {
        None
    };

    // Split every socket into its halves; missing halves are simply omitted.
    let mut write_nodes = Vec::with_capacity(ss.len());
    let mut read_nodes = Vec::with_capacity(ss.len());
    for s in ss {
        read_nodes.extend(s.read.map(Some));
        write_nodes.extend(s.write.map(|w| TeeWriterNode {
            w: Some(w),
            a_write_completed: false,
        }));
    }

    let writer = TeeWriter {
        nodes: write_nodes,
        fail_all_if_one_fails: opts.write_fail_all_if_one_fails,
        writing_in_progress: false,
    };
    let active_nodes_remains = read_nodes.len();
    let reader = TeeReader {
        nodes: read_nodes,
        propagate_any_eof: opts.propagate_eofs,
        err_on_orphaned_fragment: !opts.tolerate_torn_msgs,
        err_on_any_error: opts.read_fail_all_if_one_fails,
        chosen_node: None,
        active_nodes_remains,
    };

    let s = DatagramSocket {
        read: Some(DatagramRead {
            src: Box::pin(reader),
        }),
        write: Some(DatagramWrite {
            snk: Box::pin(writer),
        }),
        close,
        fd: None,
    };
    debug!(parent: &span, ?s, "wrapped");
    Ok(Some(s).wrap())
}
/// Registers the `tee` function with the given Rhai engine under the name `tee`.
pub fn register(engine: &mut Engine) {
engine.register_fn("tee", tee);
}