use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::{
ConnectionActor,
ConnectionChannels,
drain::{DrainContext, QueueKind},
multi_packet::MultiPacketTerminationReason,
state::ActorState,
};
use crate::{
app::{Packet, PacketParts},
hooks::ProtocolHooks,
push::{PushConfigError, PushQueues},
};
/// Minimal [`Packet`] implementation so a single byte can be used as a frame
/// by the test helpers in this module.
impl Packet for u8 {
    /// Every byte frame reports the identifier `0`.
    fn id(&self) -> u32 {
        0
    }

    /// Wraps the byte in a one-element payload, using identifier `0` and
    /// `None` for the second argument of `PacketParts::new`.
    fn into_parts(self) -> PacketParts {
        let payload = vec![self];
        PacketParts::new(0, None, payload)
    }

    /// Takes the first payload byte, falling back to `u8::default()` when the
    /// payload is empty. Any further bytes are discarded.
    fn from_parts(parts: PacketParts) -> Self {
        let payload = parts.into_payload();
        payload.into_iter().next().unwrap_or_default()
    }
}
/// [`Packet`] implementation that treats a byte vector as the payload itself.
impl Packet for Vec<u8> {
    /// Every byte-vector frame reports the identifier `0`.
    fn id(&self) -> u32 {
        0
    }

    /// Moves the vector into the parts as the payload, using identifier `0`
    /// and `None` for the second argument of `PacketParts::new`.
    fn into_parts(self) -> PacketParts {
        let payload = self;
        PacketParts::new(0, None, payload)
    }

    /// Recovers the payload bytes; the other fields of `parts` are dropped.
    fn from_parts(parts: PacketParts) -> Self {
        parts.into_payload()
    }
}
/// Builds a `ConnectionActor<u8, ()>` wired to freshly created push queues.
///
/// The queues are built with `high_capacity(4)` and `low_capacity(4)`. The
/// actor receives a new [`CancellationToken`], the supplied `hooks`, and
/// `None` for the optional second constructor argument.
///
/// # Errors
///
/// Returns the [`PushConfigError`] produced by the queue builder if `build()`
/// fails.
pub fn create_test_actor_with_hooks(
    hooks: ProtocolHooks<u8, ()>,
) -> Result<ConnectionActor<u8, ()>, PushConfigError> {
    let builder = PushQueues::<u8>::builder()
        .high_capacity(4)
        .low_capacity(4);
    let (queues, handle) = builder.build()?;

    let channels = ConnectionChannels::new(queues, handle);
    let shutdown = CancellationToken::new();
    Ok(ConnectionActor::with_hooks(channels, None, shutdown, hooks))
}
/// Test harness that bundles a `ConnectionActor<u8, ()>` with the
/// `ActorState` and output buffer its methods operate on.
pub struct ActorHarness {
// Actor under test; reachable from outside through `actor_mut`.
actor: ConnectionActor<u8, ()>,
// State passed by mutable reference to the actor's drain and shutdown calls.
state: ActorState,
/// Buffer handed to the actor as `out` by the harness's drain and close
/// methods; public so tests can inspect it directly.
pub out: Vec<u8>,
}
impl ActorHarness {
pub fn new_with_state(
hooks: ProtocolHooks<u8, ()>,
has_response: bool,
has_multi_packet: bool,
) -> Result<Self, PushConfigError> {
let actor = create_test_actor_with_hooks(hooks)?;
Ok(Self {
actor,
state: ActorState::new(has_response, has_multi_packet),
out: Vec::new(),
})
}
pub fn new() -> Result<Self, PushConfigError> {
Self::new_with_state(ProtocolHooks::<u8, ()>::default(), false, false)
}
#[must_use]
pub fn snapshot(&self) -> ActorStateSnapshot {
ActorStateSnapshot {
is_active: self.state.is_active(),
is_shutting_down: self.state.is_shutting_down(),
is_done: self.state.is_done(),
total_sources: self.state.total_sources(),
closed_sources: self.state.closed_sources(),
}
}
pub fn set_low_queue(&mut self, queue: Option<mpsc::Receiver<u8>>) {
self.actor.set_low_queue(queue);
}
pub fn set_multi_queue(
&mut self,
queue: Option<mpsc::Receiver<u8>>,
) -> Result<(), crate::connection::ConnectionStateError> {
self.actor.set_multi_packet(queue)
}
#[must_use]
pub fn has_low_queue(&self) -> bool { self.actor.low_rx.is_some() }
#[must_use]
pub fn has_multi_queue(&self) -> bool { self.actor.active_output.is_multi_packet() }
pub fn process_multi_packet(&mut self, res: Option<u8>) {
self.actor.process_queue(
QueueKind::Multi,
res,
DrainContext {
out: &mut self.out,
state: &mut self.state,
},
);
}
pub fn handle_multi_packet_closed(&mut self) {
self.actor.handle_multi_packet_closed(
MultiPacketTerminationReason::Drained,
&mut self.state,
&mut self.out,
);
}
pub fn start_shutdown(&mut self) { self.actor.start_shutdown(&mut self.state); }
pub fn try_drain_low(&mut self) -> bool {
let state = &mut self.state;
let out = &mut self.out;
self.actor
.try_opportunistic_drain(QueueKind::Low, DrainContext { out, state })
}
pub fn try_drain_multi(&mut self) -> bool {
let state = &mut self.state;
let out = &mut self.out;
self.actor
.try_opportunistic_drain(QueueKind::Multi, DrainContext { out, state })
}
pub fn actor_mut(&mut self) -> &mut ConnectionActor<u8, ()> { &mut self.actor }
}
/// Plain-data copy of the values read from an `ActorState` through its
/// accessors, as filled in by the harness `snapshot` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ActorStateSnapshot {
/// Value of `ActorState::is_active()` when the snapshot was taken.
pub is_active: bool,
/// Value of `ActorState::is_shutting_down()` when the snapshot was taken.
pub is_shutting_down: bool,
/// Value of `ActorState::is_done()` when the snapshot was taken.
pub is_done: bool,
/// Value of `ActorState::total_sources()` when the snapshot was taken.
pub total_sources: usize,
/// Value of `ActorState::closed_sources()` when the snapshot was taken.
pub closed_sources: usize,
}
/// Test harness that wraps a bare `ActorState`, with no actor attached.
pub struct ActorStateHarness {
// State under test; mutated by `mark_closed` and read by `snapshot`.
state: ActorState,
}
impl ActorStateHarness {
    /// Creates a harness whose state is initialised via
    /// `ActorState::new(has_response, has_multi_packet)`.
    #[must_use]
    pub fn new(has_response: bool, has_multi_packet: bool) -> Self {
        let state = ActorState::new(has_response, has_multi_packet);
        Self { state }
    }

    /// Forwards to `ActorState::mark_closed` on the wrapped state.
    pub fn mark_closed(&mut self) {
        let state = &mut self.state;
        state.mark_closed();
    }

    /// Captures the current values of the state's accessors.
    #[must_use]
    pub fn snapshot(&self) -> ActorStateSnapshot {
        let state = &self.state;
        ActorStateSnapshot {
            is_active: state.is_active(),
            is_shutting_down: state.is_shutting_down(),
            is_done: state.is_done(),
            total_sources: state.total_sources(),
            closed_sources: state.closed_sources(),
        }
    }
}
pub async fn poll_queue_next(rx: Option<&mut mpsc::Receiver<u8>>) -> Option<u8> {
ConnectionActor::<u8, ()>::poll_queue(rx).await
}
#[cfg(test)]
mod tests {
    //! Checks that the harness reports multi-packet queue installation
    //! correctly.

    use rstest::{fixture, rstest};
    use tokio::sync::mpsc;

    use super::*;

    /// Result alias so tests can use `?` with any error type.
    type TestResult<T> = Result<T, Box<dyn std::error::Error>>;

    /// Provides a default harness, converting the construction error into a
    /// boxed error.
    #[fixture]
    fn harness() -> TestResult<ActorHarness> {
        let built = ActorHarness::new()?;
        Ok(built)
    }

    /// `has_multi_queue` must be `expected` after optionally installing and
    /// then optionally clearing a multi-packet queue.
    #[rstest]
    #[case::default(false, false, false)]
    #[case::install(true, false, true)]
    #[case::clear(true, true, false)]
    fn has_multi_queue_states(
        #[case] install: bool,
        #[case] clear: bool,
        #[case] expected: bool,
        harness: TestResult<ActorHarness>,
    ) -> TestResult<()> {
        let mut actor_harness = harness?;

        if install {
            // The sender is held until the end of this block, i.e. until
            // after the receiver has been installed.
            let (_sender, receiver) = mpsc::channel(1);
            actor_harness.set_multi_queue(Some(receiver))?;
        }
        if clear {
            actor_harness.set_multi_queue(None)?;
        }

        if actor_harness.has_multi_queue() == expected {
            Ok(())
        } else {
            Err("multi-packet queue state mismatch".into())
        }
    }
}