use bytes::Bytes;
use futures::SinkExt;
use log::warn;
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use super::{
combined_codec::ConnectionCodec,
envelope::Envelope,
fragmentation_state::FragmentationState,
};
use crate::{
codec::FrameCodec,
fragment::{FragmentationConfig, FragmentationError},
message::EncodeWith,
serializer::Serializer,
};
/// Outbound frame pipeline.
///
/// Envelopes handed to it are optionally split by a fragmentation stage. The
/// resulting frames are buffered until the caller drains them.
pub(crate) struct FramePipeline {
    /// Fragmentation stage; `None` means envelopes are forwarded unsplit.
    fragmentation: Option<FragmentationState>,
    /// Frames produced by `process` that have not yet been taken by `drain_output`.
    out: Vec<Envelope>,
}
impl FramePipeline {
    /// Builds a pipeline.
    ///
    /// When `fragmentation` is `Some`, a `FragmentationState` is created from the
    /// config and every processed envelope is routed through it. When it is
    /// `None`, envelopes are buffered exactly as received.
    pub(crate) fn new(fragmentation: Option<FragmentationConfig>) -> Self {
        let fragmentation = fragmentation.map(FragmentationState::new);
        Self {
            fragmentation,
            out: Vec::new(),
        }
    }

    /// Runs `envelope` through fragmentation (if enabled) and appends every
    /// resulting frame to the output buffer.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` wrapping the `FragmentationError` if fragmentation
    /// fails. In that case:
    /// - the failure is logged at `warn` level;
    /// - it is counted via `crate::metrics::inc_handler_errors`;
    /// - no frame is pushed onto the output buffer.
    pub(crate) fn process(&mut self, envelope: Envelope) -> io::Result<()> {
        // Copy the identifiers out first: `envelope` is moved into the
        // fragmenter, but the error log below still needs them.
        let (id, correlation_id) = (envelope.id, envelope.correlation_id);
        match self.fragment_envelope(envelope) {
            Ok(frames) => {
                for frame in frames {
                    self.push_frame(frame);
                }
                Ok(())
            }
            Err(err) => {
                warn!(
                    "failed to fragment outbound envelope: id={id}, \
                     correlation_id={correlation_id:?}, error={err:?}"
                );
                crate::metrics::inc_handler_errors();
                Err(io::Error::other(err))
            }
        }
    }

    /// Splits `envelope` using the fragmentation state.
    ///
    /// When fragmentation is disabled, returns the envelope as a single-element
    /// list.
    fn fragment_envelope(
        &mut self,
        envelope: Envelope,
    ) -> Result<Vec<Envelope>, FragmentationError> {
        if let Some(state) = self.fragmentation.as_mut() {
            state.fragment(envelope)
        } else {
            Ok(vec![envelope])
        }
    }

    /// Forwards to `FragmentationState::purge_expired`.
    ///
    /// Does nothing when fragmentation is disabled.
    pub(crate) fn purge_expired(&mut self) {
        if let Some(state) = &mut self.fragmentation {
            state.purge_expired();
        }
    }

    /// Takes every buffered frame, leaving an empty buffer behind.
    ///
    /// `mem::take` swaps in `Vec::default()`, so the previous allocation moves
    /// out with the returned vector.
    pub(crate) fn drain_output(&mut self) -> Vec<Envelope> {
        std::mem::take(&mut self.out)
    }

    /// Mutable access to the fragmentation state.
    ///
    /// Returns `None` when fragmentation is disabled.
    pub(crate) fn fragmentation_mut(&mut self) -> Option<&mut FragmentationState> {
        self.fragmentation.as_mut()
    }

    /// Test helper: reports whether a fragmentation stage is configured.
    #[cfg(test)]
    pub(crate) fn has_fragmentation(&self) -> bool {
        self.fragmentation.is_some()
    }

    /// Buffers one outbound frame and bumps the outbound frame counter.
    fn push_frame(&mut self, frame: Envelope) {
        self.out.push(frame);
        crate::metrics::inc_frames(crate::metrics::Direction::Outbound);
    }
}
/// Serializes `envelope`, wraps the bytes in a frame via `codec`, and sends the
/// frame on `framed`.
///
/// `SinkExt::send` both writes and flushes the frame.
///
/// # Errors
///
/// Returns an `io::Error` if serialization fails or if the frame cannot be
/// sent. Either failure is:
/// - logged at `warn` level together with the envelope's `id` and `correlation_id`;
/// - counted via `crate::metrics::inc_handler_errors`.
pub(super) async fn send_envelope<S, W, F>(
    serializer: &S,
    codec: &F,
    framed: &mut Framed<W, ConnectionCodec<F>>,
    envelope: &Envelope,
) -> io::Result<()>
where
    S: Serializer + Send + Sync,
    W: AsyncRead + AsyncWrite + Unpin,
    F: FrameCodec,
    Envelope: EncodeWith<S>,
{
    // Captured once for the diagnostics on either failure path.
    let id = envelope.id;
    let correlation_id = envelope.correlation_id;

    let payload = match serializer.serialize(envelope) {
        Ok(encoded) => Bytes::from(encoded),
        Err(e) => {
            warn!(
                "failed to serialize outbound envelope: id={id}, correlation_id={correlation_id:?}, \
                 error={e:?}"
            );
            crate::metrics::inc_handler_errors();
            return Err(io::Error::other(e));
        }
    };

    let frame = codec.wrap_payload(payload);
    if let Err(e) = framed.send(frame).await {
        warn!(
            "failed to send outbound frame: id={id}, correlation_id={correlation_id:?}, \
             error={e:?}"
        );
        crate::metrics::inc_handler_errors();
        return Err(io::Error::other(e));
    }
    Ok(())
}
/// Sends every envelope in `envelopes`, in order, via `send_envelope`.
///
/// The vector is emptied while its allocation is kept for reuse.
///
/// # Errors
///
/// Stops at the first envelope that fails to serialize or send, and returns
/// that error. Envelopes after the failing one are not sent. They are still
/// removed from `envelopes`, because dropping a `Vec::drain` iterator removes
/// the whole drained range.
pub(super) async fn flush_pipeline_output<S, W, F>(
    serializer: &S,
    codec: &F,
    framed: &mut Framed<W, ConnectionCodec<F>>,
    envelopes: &mut Vec<Envelope>,
) -> io::Result<()>
where
    S: Serializer + Send + Sync,
    W: AsyncRead + AsyncWrite + Unpin,
    F: FrameCodec,
    Envelope: EncodeWith<S>,
{
    let pending = envelopes.drain(..);
    for queued in pending {
        send_envelope(serializer, codec, framed, &queued).await?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};

    use super::*;

    /// A pipeline built without a fragmentation config.
    #[fixture]
    fn pipeline() -> FramePipeline {
        FramePipeline::new(None)
    }

    /// With fragmentation disabled, one envelope in yields exactly one
    /// identical envelope out.
    #[rstest]
    fn process_single_envelope_emits_one_frame(mut pipeline: FramePipeline) {
        let input = Envelope::new(1, Some(42), vec![1, 2, 3]);
        pipeline
            .process(input)
            .expect("processing should succeed without fragmentation");

        let frames = pipeline.drain_output();
        assert_eq!(frames.len(), 1);
        let frame = frames
            .into_iter()
            .next()
            .expect("pipeline should emit exactly one envelope");
        assert_eq!(frame.id, 1);
        assert_eq!(frame.correlation_id, Some(42));
        assert_eq!(frame.payload, vec![1, 2, 3]);
    }

    /// Draining a second time after a drain returns nothing.
    #[rstest]
    fn drain_clears_buffer(mut pipeline: FramePipeline) {
        pipeline
            .process(Envelope::new(1, None, vec![]))
            .expect("processing should succeed without fragmentation");

        assert_eq!(pipeline.drain_output().len(), 1);
        assert!(pipeline.drain_output().is_empty());
    }

    /// The default fixture has no fragmentation stage configured.
    #[rstest]
    fn pipeline_without_fragmentation(pipeline: FramePipeline) {
        assert!(!pipeline.has_fragmentation());
    }
}