use bytes::Bytes;
use thiserror::Error;
use crate::{connection::send_buffer::SendBuffer, frame, VarInt};
/// Send-side state of a single stream: flow-control limit, lifecycle state and buffered data.
#[derive(Debug)]
pub(super) struct Send {
/// Highest stream offset `write` may buffer up to; only ever raised by `increase_max_data`.
pub(super) max_data: u64,
/// Lifecycle state; writes and `finish` are accepted only while this is `SendState::Ready`.
pub(super) state: SendState,
/// Buffered stream data; its `offset()` is compared against `max_data` to compute the
/// remaining write budget.
pub(super) pending: SendBuffer,
/// Initialized to `0` by `new`.
/// NOTE(review): not read anywhere in this file — presumably a scheduling priority used by
/// the caller; confirm.
pub(super) priority: i32,
/// Set by `finish`; makes `is_pending` report `true` even when there is no unsent data.
pub(super) fin_pending: bool,
/// Initialized to `false` by `new`.
/// NOTE(review): not read or written elsewhere in this file — presumably tracks whether the
/// stream is queued as blocked on connection-level flow control; confirm against the caller.
pub(super) connection_blocked: bool,
/// Error code recorded by `stop`; while set, `write` and `finish` fail with `Stopped`.
pub(super) stop_reason: Option<VarInt>,
}
impl Send {
pub(super) fn new(max_data: VarInt) -> Self {
Self {
max_data: max_data.into(),
state: SendState::Ready,
pending: SendBuffer::new(),
priority: 0,
fin_pending: false,
connection_blocked: false,
stop_reason: None,
}
}
pub(super) fn is_reset(&self) -> bool {
matches!(self.state, SendState::ResetSent { .. })
}
pub(super) fn finish(&mut self) -> Result<(), FinishError> {
if let Some(error_code) = self.stop_reason {
Err(FinishError::Stopped(error_code))
} else if self.state == SendState::Ready {
self.state = SendState::DataSent {
finish_acked: false,
};
self.fin_pending = true;
Ok(())
} else {
Err(FinishError::UnknownStream)
}
}
pub(super) fn write<S: BytesSource>(
&mut self,
source: &mut S,
limit: u64,
) -> Result<Written, WriteError> {
if !self.is_writable() {
return Err(WriteError::UnknownStream);
}
if let Some(error_code) = self.stop_reason {
return Err(WriteError::Stopped(error_code));
}
let budget = self.max_data - self.pending.offset();
if budget == 0 {
return Err(WriteError::Blocked);
}
let mut limit = limit.min(budget) as usize;
let mut result = Written::default();
loop {
let (chunk, chunks_consumed) = source.pop_chunk(limit);
result.chunks += chunks_consumed;
result.bytes += chunk.len();
if chunk.is_empty() {
break;
}
limit -= chunk.len();
self.pending.write(chunk);
}
Ok(result)
}
pub(super) fn reset(&mut self) {
use SendState::*;
if let DataSent { .. } | Ready = self.state {
self.state = ResetSent;
}
}
pub(super) fn stop(&mut self, error_code: VarInt) {
self.stop_reason = Some(error_code);
}
pub(super) fn ack(&mut self, frame: frame::StreamMeta) -> bool {
self.pending.ack(frame.offsets);
match self.state {
SendState::DataSent {
ref mut finish_acked,
} => {
*finish_acked |= frame.fin;
*finish_acked && self.pending.is_fully_acked()
}
_ => false,
}
}
pub(super) fn increase_max_data(&mut self, offset: u64) -> bool {
if offset <= self.max_data || self.state != SendState::Ready {
return false;
}
let was_blocked = self.pending.offset() == self.max_data;
self.max_data = offset;
was_blocked
}
pub(super) fn offset(&self) -> u64 {
self.pending.offset()
}
pub(super) fn is_pending(&self) -> bool {
self.pending.has_unsent_data() || self.fin_pending
}
pub(super) fn is_writable(&self) -> bool {
matches!(self.state, SendState::Ready)
}
}
/// A `BytesSource` backed by a mutable slice of `Bytes` chunks.
pub struct BytesArray<'a> {
/// The caller's chunks; fully consumed entries are replaced by empty `Bytes`, and a
/// partially consumed entry keeps only its unread tail.
chunks: &'a mut [Bytes],
/// Number of leading entries of `chunks` that have been fully consumed.
consumed: usize,
}
impl<'a> BytesArray<'a> {
    /// Builds a source over `chunks`, starting with none of them consumed.
    pub fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
        let consumed = 0;
        Self { chunks, consumed }
    }
}
impl<'a> BytesSource for BytesArray<'a> {
    /// Returns the next non-empty run of at most `limit` bytes, together with the number of
    /// array entries that were fully consumed by this call.
    ///
    /// Empty entries are skipped but still counted as consumed. An entry longer than `limit`
    /// is split: its first `limit` bytes are returned and the rest stays in place, so it is
    /// not counted. An empty `Bytes` is returned once the array is exhausted or `limit` is
    /// zero while data remains.
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
        let mut finished = 0;
        while let Some(head) = self.chunks.get_mut(self.consumed) {
            if head.len() > limit {
                if limit == 0 {
                    // No room left; leave the entry untouched.
                    break;
                }
                // Hand out a prefix and keep the tail for a later call.
                return (head.split_to(limit), finished);
            }

            // The entire entry fits: move it out, leaving an empty `Bytes` behind.
            let whole = std::mem::take(head);
            self.consumed += 1;
            finished += 1;
            if !whole.is_empty() {
                return (whole, finished);
            }
        }
        (Bytes::new(), finished)
    }
}
/// A `BytesSource` backed by a borrowed byte slice, treated as a single chunk.
pub struct ByteSlice<'a> {
/// The bytes not yet handed out by `pop_chunk`.
data: &'a [u8],
}
impl<'a> ByteSlice<'a> {
pub fn from_slice(data: &'a [u8]) -> Self {
Self { data }
}
}
impl<'a> BytesSource for ByteSlice<'a> {
    /// Copies up to `limit` bytes from the front of the slice into a new `Bytes`.
    ///
    /// The consumed-chunk count is `1` only on the call that drains the slice completely and
    /// `0` otherwise, including when nothing is returned because `limit` is zero or the slice
    /// is already empty.
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
        let take = self.data.len().min(limit);
        if take == 0 {
            return (Bytes::new(), 0);
        }

        let (head, tail) = self.data.split_at(take);
        let chunk = Bytes::from(head.to_vec());
        self.data = tail;

        (chunk, usize::from(self.data.is_empty()))
    }
}
/// A source of byte chunks that `Send::write` drains into a stream's buffer.
pub trait BytesSource {
/// Returns the next chunk of at most `limit` bytes, plus the number of underlying chunks
/// that were fully consumed by this call.
///
/// `Send::write` treats an empty returned chunk as the signal to stop pulling; the
/// implementations in this file return one when they are exhausted or `limit` is zero.
/// The consumed count may be non-zero even when the returned chunk is empty.
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize);
}
/// Result of a successful `Send::write`: how much of the source was accepted.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Written {
/// Total number of bytes moved into the stream's buffer.
pub bytes: usize,
/// Sum of the consumed-chunk counts reported by the source's `pop_chunk` calls.
pub chunks: usize,
}
/// Reasons `Send::write` can refuse data.
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum WriteError {
/// The stream's flow-control budget (`max_data` minus the buffer offset) is zero.
#[error("unable to accept further writes")]
Blocked,
/// A stop reason has been recorded on the stream; carries that error code.
#[error("stopped by peer: code {0}")]
Stopped(VarInt),
/// The stream is not writable, i.e. its state is not `Ready`.
#[error("unknown stream")]
UnknownStream,
}
/// Lifecycle state of a stream's send half.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) enum SendState {
/// Initial state; the only state in which writes, `finish` and flow-control increases are
/// accepted.
Ready,
/// Entered by `Send::finish`. `finish_acked` starts `false` and becomes `true` once a frame
/// carrying FIN has been acknowledged.
DataSent { finish_acked: bool },
/// Entered by `Send::reset` from either `Ready` or `DataSent`.
ResetSent,
}
/// Reasons `Send::finish` can fail.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum FinishError {
/// A stop reason has been recorded on the stream; carries that error code.
#[error("stopped by peer: code {0}")]
Stopped(VarInt),
/// The stream is not in the `Ready` state.
#[error("unknown stream")]
UnknownStream,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference payload that both sources are expected to reproduce byte-for-byte.
    const FULL: &[u8] = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ";

    /// Pops chunks from `source` until an empty chunk comes back, never requesting more than
    /// `limit` bytes in total.
    ///
    /// Returns `(bytes collected, non-empty chunks popped, chunks reported as consumed)`.
    fn drain<S: BytesSource>(source: &mut S, limit: usize) -> (Vec<u8>, usize, usize) {
        let mut buf = Vec::new();
        let mut chunks_popped = 0;
        let mut chunks_consumed = 0;
        let mut remaining = limit;
        loop {
            let (chunk, consumed) = source.pop_chunk(remaining);
            // Consumed chunks are reported even alongside an empty chunk.
            chunks_consumed += consumed;
            if chunk.is_empty() {
                break;
            }
            buf.extend_from_slice(&chunk);
            remaining -= chunk.len();
            chunks_popped += 1;
        }
        (buf, chunks_popped, chunks_consumed)
    }

    #[test]
    fn bytes_array() {
        // The upper bound is inclusive so that the "everything fits" case is exercised; with
        // an exclusive range the `limit == FULL.len()` assertions below could never run.
        for limit in 0..=FULL.len() {
            let mut chunks = [
                Bytes::from_static(b""),
                Bytes::from_static(b"Hello "),
                Bytes::from_static(b"Wo"),
                Bytes::from_static(b""),
                Bytes::from_static(b"r"),
                Bytes::from_static(b"ld"),
                Bytes::from_static(b""),
                Bytes::from_static(b" 12345678"),
                Bytes::from_static(b"9 ABCDE"),
                Bytes::from_static(b"F"),
                Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"),
            ];
            let num_chunks = chunks.len();
            let last_chunk_len = chunks[num_chunks - 1].len();

            let mut array = BytesArray::from_chunks(&mut chunks);
            let (buf, chunks_popped, chunks_consumed) = drain(&mut array, limit);

            assert_eq!(&buf[..], &FULL[..limit]);

            if limit == FULL.len() {
                // Every entry is consumed; the three empty ones never surface as chunks.
                assert_eq!(chunks_consumed, num_chunks);
                assert_eq!(chunks_consumed, chunks_popped + 3);
            } else if limit > FULL.len() - last_chunk_len {
                // The last entry is only partially read, so it is popped but not consumed.
                assert_eq!(chunks_consumed, num_chunks - 1);
                assert_eq!(chunks_consumed, chunks_popped + 2);
            }
        }
    }

    #[test]
    fn byte_slice() {
        // Inclusive for the same reason as in `bytes_array`: the slice is only reported as
        // consumed when `limit` covers all of it.
        for limit in 0..=FULL.len() {
            let mut source = ByteSlice::from_slice(FULL);
            let (buf, chunks_popped, chunks_consumed) = drain(&mut source, limit);

            assert_eq!(&buf[..], &FULL[..limit]);

            // A slice hands out everything it can in a single chunk.
            if limit != 0 {
                assert_eq!(chunks_popped, 1);
            } else {
                assert_eq!(chunks_popped, 0);
            }

            if limit == FULL.len() {
                assert_eq!(chunks_consumed, 1);
            } else {
                assert_eq!(chunks_consumed, 0);
            }
        }
    }
}