use std::{
collections::VecDeque,
io::{BufRead, BufReader, Cursor, Read, Write},
};
use bytes::{Buf, BytesMut};
use tracing::warn;
use crate::{
pdu::{LARGE_PDU_SIZE, PDU_HEADER_SIZE, PDV_HEADER_SIZE},
read_pdu, Pdu,
};
/// Combined size, in bytes, of the PDU header and the PDV item header.
///
/// The writers in this module keep this many bytes at the front of their
/// buffer and place the payload right after them.
const PDU_PDV_HEADER_SIZE: usize = (PDU_HEADER_SIZE + PDV_HEADER_SIZE) as usize;
/// Fill in the length and control fields of a buffered P-Data PDU.
///
/// `buffer` must hold the full PDU: `PDU_PDV_HEADER_SIZE` header bytes
/// followed by the payload. The following fields are overwritten in place:
///
/// - bytes `2..6`: PDU length (payload length + 6), big endian;
/// - bytes `6..10`: PDV item length (payload length + 2), big endian;
/// - byte `11`: control header, `0x02` when `is_last` is set, `0x00` otherwise.
///
/// # Panics
///
/// Panics if `buffer` is too short to contain the header fields.
fn setup_pdata_header(buffer: &mut [u8], is_last: bool) {
    let payload_len = (buffer.len() - PDU_PDV_HEADER_SIZE) as u32;

    // PDU length covers the PDV length field (4 bytes),
    // the presentation context ID and the control header (2 bytes),
    // plus the payload itself
    let pdu_len = payload_len + 4 + 2;
    buffer[2..6].copy_from_slice(&pdu_len.to_be_bytes());

    // PDV item length covers the presentation context ID,
    // the control header, and the payload
    let pdv_len = payload_len + 2;
    buffer[6..10].copy_from_slice(&pdv_len.to_be_bytes());

    let control_header = if is_last { 0x02 } else { 0x00 };
    buffer[11] = control_header;
}
/// A writer which splits the bytes written to it into P-Data PDUs
/// and sends them through the underlying stream.
///
/// Bytes are accumulated in an internal buffer;
/// a PDU is sent whenever the buffer would exceed the maximum data length.
/// Call [`finish`](PDataWriter::finish) to send the remaining bytes
/// in a PDU flagged as the last one.
/// Dropping the writer also attempts to send that last PDU,
/// but any error is then discarded.
#[must_use]
pub struct PDataWriter<W: Write> {
// the PDU under construction: header bytes followed by payload bytes
buffer: Vec<u8>,
// the underlying stream that PDUs are written to
stream: W,
// maximum number of payload bytes in a single PDU
max_data_len: u32,
}
impl<W> PDataWriter<W>
where
    W: Write,
{
    /// Construct a new P-Data writer.
    ///
    /// `stream` receives the encoded PDUs,
    /// `presentation_context_id` is recorded in the header of every PDU,
    /// and `max_pdu_length` bounds the size of each PDU sent.
    pub(crate) fn new(stream: W, presentation_context_id: u8, max_pdu_length: u32) -> Self {
        let max_data_length = calculate_max_data_len_single(max_pdu_length);
        let mut buffer =
            Vec::with_capacity((max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize);
        // Set up the header once. The `0xFF` bytes are placeholders for the
        // PDU length (4 bytes), the PDV item length (4 bytes)
        // and the control header (last byte),
        // all of which are filled in by `setup_pdata_header`
        // right before a PDU is sent.
        buffer.extend([
            0x04,
            0x00,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            presentation_context_id,
            0xFF,
        ]);
        PDataWriter {
            stream,
            max_data_len: max_data_length,
            buffer,
        }
    }

    /// Declare that all data has been written,
    /// sending the bytes still buffered in a PDU flagged as the last one.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying stream, if any.
    /// After a failure the buffered bytes are discarded,
    /// so dropping the writer does not try to send them again.
    pub fn finish(mut self) -> std::io::Result<()> {
        self.finish_impl()
    }

    /// Send the last PDU, unless it has been sent already.
    fn finish_impl(&mut self) -> std::io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        setup_pdata_header(&mut self.buffer, true);
        let outcome = self.stream.write_all(&self.buffer);
        // Clear the buffer whether or not the write succeeded.
        // An empty buffer marks the writer as finished,
        // which keeps `Drop` from re-sending the whole PDU
        // on top of whatever part of it already reached the stream.
        self.buffer.clear();
        outcome
    }

    /// Send the buffered bytes as a PDU which is not the last one,
    /// then reset the buffer so that only the header remains.
    fn dispatch_pdu(&mut self) -> std::io::Result<()> {
        debug_assert!(self.buffer.len() >= PDU_PDV_HEADER_SIZE);
        setup_pdata_header(&mut self.buffer, false);
        self.stream.write_all(&self.buffer)?;
        // keep the header bytes for the next PDU
        self.buffer.truncate(PDU_PDV_HEADER_SIZE);
        Ok(())
    }
}
impl<W> Write for PDataWriter<W>
where
    W: Write,
{
    /// Accept bytes into the PDU under construction.
    ///
    /// If `buf` fits in the space left, it is only buffered.
    /// Otherwise the PDU is filled up and sent,
    /// and the number of bytes taken from `buf` is returned
    /// so that the caller writes the rest in a subsequent call.
    ///
    /// A PDU that was filled exactly by a previous call is kept in the
    /// buffer, since it may still turn out to be the last one.
    /// When more data arrives, that PDU is sent first and the new bytes
    /// go into a fresh PDU. Without this step the call would take zero
    /// bytes from a non-empty `buf` and return `Ok(0)`,
    /// which `write_all` reports as a `WriteZero` error.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let total_len = self.max_data_len as usize + PDU_PDV_HEADER_SIZE;
        let mut room = total_len.saturating_sub(self.buffer.len());

        if buf.len() <= room {
            // accumulate into buffer, nothing to send yet
            self.buffer.extend_from_slice(buf);
            return Ok(buf.len());
        }

        if self.max_data_len == 0 {
            // no payload can ever fit: signal that nothing can be written
            return Ok(0);
        }

        if room == 0 {
            // the buffer holds a complete PDU from an earlier exact fit:
            // send it to make room for the incoming bytes
            self.dispatch_pdu()?;
            room = total_len - self.buffer.len();
            if buf.len() <= room {
                self.buffer.extend_from_slice(buf);
                return Ok(buf.len());
            }
        }

        // fill in the rest of the buffer, send the PDU,
        // and leave the remaining bytes for subsequent writes
        self.buffer.extend_from_slice(&buf[..room]);
        debug_assert_eq!(self.buffer.len(), total_len);
        self.dispatch_pdu()?;
        Ok(room)
    }

    /// Does nothing: PDUs are only sent when full or on finish.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// With the P-Data writer dropped,
/// this `Drop` implementation
/// will construct and emit the last P-Data fragment PDU
/// if there is any data left to send.
impl<W> Drop for PDataWriter<W>
where
    W: Write,
{
    fn drop(&mut self) {
        // errors cannot be reported from `drop`, so the outcome is discarded
        drop(self.finish_impl());
    }
}
/// A reader which decodes P-Data PDUs from the underlying stream
/// and exposes their payload as a continuous stream of bytes.
///
/// Reading reports end of stream once the payload of a PDU
/// whose last data value is flagged as last has been consumed,
/// or after `stop_receiving` is called and the buffered bytes run out.
#[must_use]
pub struct PDataReader<'a, R> {
// payload bytes already decoded but not yet handed out to the caller
buffer: VecDeque<u8>,
// the underlying stream that PDUs are read from
stream: R,
// presentation context ID of the first data value seen, if any
presentation_context_id: Option<u8>,
// length limit handed to `read_pdu`
max_data_length: u32,
// whether the last data value received was flagged as the last one
last_pdu: bool,
// raw bytes received from the stream which were not parsed into a PDU yet
read_buffer: &'a mut BytesMut,
}
impl<'a, R> PDataReader<'a, R> {
    /// Create a new P-Data reader over `stream`.
    ///
    /// `max_data_length` is the length limit used when decoding PDUs.
    /// `remaining` holds bytes already received from the stream
    /// which were not decoded yet; bytes left over by this reader
    /// stay in it.
    pub fn new(stream: R, max_data_length: u32, remaining: &'a mut BytesMut) -> Self {
        // bound the up-front allocation for very large limits
        let initial_capacity = max_data_length.min(LARGE_PDU_SIZE) as usize;
        Self {
            buffer: VecDeque::with_capacity(initial_capacity),
            stream,
            presentation_context_id: None,
            max_data_length,
            last_pdu: false,
            read_buffer: remaining,
        }
    }

    /// Declare no intention to read more PDUs from the stream.
    ///
    /// Bytes already buffered can still be read,
    /// after which the reader reports end of stream.
    /// This implementation always returns `Ok(())`.
    pub fn stop_receiving(&mut self) -> std::io::Result<()> {
        self.last_pdu = true;
        Ok(())
    }
}
impl<R> Read for PDataReader<'_, R>
where
    R: Read,
{
    /// Read payload bytes into `buf`.
    ///
    /// When no decoded bytes are left, one more PDU is read from the stream
    /// (unless the last one was already seen, in which case `Ok(0)` is
    /// returned) and the payload of all its data values is buffered.
    ///
    /// # Errors
    ///
    /// Fails if the PDU cannot be decoded, if the stream ends before a
    /// complete PDU is received, or if the PDU received is not a P-Data PDU.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.buffer.is_empty() {
            if self.last_pdu {
                // reached the end of the P-Data stream
                return Ok(0);
            }

            // NOTE(review): the async counterpart passes
            // `max_data_length + PDV_HEADER_SIZE` to `read_pdu` instead;
            // confirm which limit is intended.
            let mut reader = BufReader::new(&mut self.stream);
            let msg = loop {
                // try to decode a PDU out of the bytes received so far
                let mut pending = Cursor::new(&self.read_buffer[..]);
                let parsed = read_pdu(&mut pending, self.max_data_length, false)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
                if let Some(pdu) = parsed {
                    let consumed = pending.position() as usize;
                    self.read_buffer.advance(consumed);
                    break pdu;
                }

                // not enough bytes for a full PDU: fetch more from the stream
                let recv = reader.fill_buf()?.to_vec();
                reader.consume(recv.len());
                if recv.is_empty() {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "Connection closed by peer",
                    ));
                }
                self.read_buffer.extend_from_slice(&recv);
            };

            let data = match msg {
                Pdu::PData { data } => data,
                _ => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "Unexpected PDU type",
                    ))
                }
            };

            for pdata_value in data {
                // remember the first presentation context seen,
                // and complain about values which do not match it
                match self.presentation_context_id {
                    None => {
                        self.presentation_context_id = Some(pdata_value.presentation_context_id);
                    }
                    Some(cid) if cid != pdata_value.presentation_context_id => {
                        warn!("Received PData value of presentation context {}, but should be {}", pdata_value.presentation_context_id, cid);
                    }
                    Some(_) => {}
                }
                self.buffer.extend(pdata_value.data);
                self.last_pdu = pdata_value.is_last;
            }
        }

        self.buffer.read(buf)
    }
}
/// Determine the maximum length of actual PDV data
/// when encapsulated in a PDU with the given length property.
///
/// Yields 0 when `pdu_len` is smaller than the PDV header,
/// rather than overflowing: an unchecked subtraction would panic in debug
/// builds and wrap around to a length close to `u32::MAX` in release builds.
#[inline]
fn calculate_max_data_len_single(pdu_len: u32) -> u32 {
    pdu_len.saturating_sub(PDV_HEADER_SIZE)
}
#[cfg(feature = "async")]
pub mod non_blocking {
use std::{
io::Cursor,
pin::Pin,
task::{ready, Context, Poll},
};
use bytes::{Buf, BufMut};
use tokio::io::{
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf,
};
use tracing::warn;
use crate::{
pdu::{PDU_HEADER_SIZE, PDV_HEADER_SIZE},
read_pdu, Pdu,
};
pub use super::PDataReader;
use super::{calculate_max_data_len_single, setup_pdata_header};
const PDU_PDV_HEADER_SIZE: usize = (PDU_HEADER_SIZE + PDV_HEADER_SIZE) as usize;
/// Progress of the writer with respect to the PDU being sent.
enum WriteState {
    /// No PDU is in flight: incoming bytes are accumulated in the buffer.
    Ready,
    /// A full PDU is being sent to the underlying stream.
    Writing {
        /// Number of bytes of the buffer already written to the stream.
        pos: usize,
        /// Number of bytes taken from the caller's slice to complete this
        /// PDU, reported to the caller once the PDU was sent in full.
        consumed: usize,
    },
}

/// A writer which splits the bytes written to it into P-Data PDUs
/// and sends them through the underlying asynchronous stream.
///
/// Bytes are accumulated in an internal buffer;
/// a PDU is sent whenever the buffer would exceed the maximum data length.
/// Call [`finish`](AsyncPDataWriter::finish) to send the remaining bytes
/// in a PDU flagged as the last one.
#[must_use]
pub struct AsyncPDataWriter<W: AsyncWrite + Unpin> {
    /// the PDU under construction: header bytes followed by payload bytes
    buffer: Vec<u8>,
    /// the underlying stream that PDUs are written to
    stream: W,
    /// maximum number of payload bytes in a single PDU
    max_data_len: u32,
    /// whether a PDU is currently being sent
    state: WriteState,
}

#[cfg(feature = "async")]
impl<W> AsyncPDataWriter<W>
where
    W: AsyncWrite + Unpin,
{
    /// Construct a new P-Data writer.
    ///
    /// `stream` receives the encoded PDUs,
    /// `presentation_context_id` is recorded in the header of every PDU,
    /// and `max_pdu_length` bounds the size of each PDU sent.
    pub(crate) fn new(stream: W, presentation_context_id: u8, max_pdu_length: u32) -> Self {
        use crate::pdu::LARGE_PDU_SIZE;
        let max_data_length = calculate_max_data_len_single(max_pdu_length);
        let mut buffer =
            Vec::with_capacity((max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize);
        // Set up the header once. The `0xFF` bytes are placeholders for the
        // PDU length (4 bytes), the PDV item length (4 bytes)
        // and the control header (last byte),
        // all of which are filled in by `setup_pdata_header`
        // right before a PDU is sent.
        buffer.extend([
            0x04,
            0x00,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            0xFF,
            presentation_context_id,
            0xFF,
        ]);
        AsyncPDataWriter {
            stream,
            max_data_len: max_data_length,
            buffer,
            state: WriteState::Ready,
        }
    }

    /// Declare that all data has been written,
    /// sending the bytes still buffered in a PDU flagged as the last one.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying stream, if any.
    /// After a failure the buffered bytes are discarded,
    /// so dropping the writer does not try to send them again.
    pub async fn finish(mut self) -> std::io::Result<()> {
        self.finish_impl().await
    }

    /// Send the last PDU, unless it has been sent already.
    async fn finish_impl(&mut self) -> std::io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let outcome = self.send_last_pdu().await;
        // Clear the buffer whether or not the write succeeded.
        // An empty buffer marks the writer as finished,
        // so that later calls do not send any more PDUs.
        self.buffer.clear();
        self.state = WriteState::Ready;
        outcome
    }

    /// Complete the PDU in flight (if any), then send the buffered bytes
    /// in a PDU flagged as the last one.
    async fn send_last_pdu(&mut self) -> std::io::Result<()> {
        if let WriteState::Writing { pos, .. } = self.state {
            // a write was abandoned while its PDU was only partially sent:
            // the rest has to go out first to keep the stream well formed
            self.stream.write_all(&self.buffer[pos..]).await?;
            self.buffer.truncate(PDU_PDV_HEADER_SIZE);
            self.state = WriteState::Ready;
        }
        setup_pdata_header(&mut self.buffer, true);
        self.stream.write_all(&self.buffer).await
    }
}

#[cfg(feature = "async")]
impl<W> AsyncWrite for AsyncPDataWriter<W>
where
    W: AsyncWrite + Unpin,
{
    /// Accept bytes into the PDU under construction.
    ///
    /// If `buf` fits in the space left, it is only buffered.
    /// Otherwise the PDU is filled up and sent to the underlying stream;
    /// the call completes once the PDU was written in full,
    /// yielding the number of bytes taken from `buf`.
    ///
    /// A PDU that was filled exactly by a previous call is kept in the
    /// buffer, since it may still turn out to be the last one.
    /// When more data arrives, that PDU is sent first
    /// and the new bytes go into a fresh PDU.
    ///
    /// `Poll::Pending` is only returned when the underlying stream returned
    /// it, so the task is woken once the stream can make progress.
    /// As with the original design, a pending call is expected to be
    /// polled again with the same `buf`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        let this = self.get_mut();
        let total_len = this.max_data_len as usize + PDU_PDV_HEADER_SIZE;

        loop {
            match this.state {
                WriteState::Ready => {
                    let room = total_len.saturating_sub(this.buffer.len());
                    if buf.len() <= room {
                        // accumulate into buffer, nothing to send yet
                        this.buffer.extend_from_slice(buf);
                        return Poll::Ready(Ok(buf.len()));
                    }
                    if this.max_data_len == 0 {
                        // no payload can ever fit:
                        // signal that nothing can be written
                        return Poll::Ready(Ok(0));
                    }
                    // fill in the rest of the buffer (possibly nothing,
                    // if it is already full) and start sending the PDU
                    this.buffer.extend_from_slice(&buf[..room]);
                    debug_assert_eq!(this.buffer.len(), total_len);
                    setup_pdata_header(&mut this.buffer, false);
                    this.state = WriteState::Writing {
                        pos: 0,
                        consumed: room,
                    };
                }
                WriteState::Writing { pos, consumed } => {
                    let mut pos = pos;
                    while pos < this.buffer.len() {
                        let n = ready!(
                            Pin::new(&mut this.stream).poll_write(cx, &this.buffer[pos..])
                        )?;
                        if n == 0 {
                            return Poll::Ready(Err(std::io::Error::new(
                                std::io::ErrorKind::WriteZero,
                                "failed to write whole PDU",
                            )));
                        }
                        pos += n;
                        // record progress in case the next poll is pending
                        this.state = WriteState::Writing { pos, consumed };
                    }
                    // PDU sent: keep the header bytes for the next one
                    this.buffer.truncate(PDU_PDV_HEADER_SIZE);
                    this.state = WriteState::Ready;
                    if consumed > 0 {
                        return Poll::Ready(Ok(consumed));
                    }
                    // The PDU sent was already full before this call,
                    // so nothing was taken from `buf` yet:
                    // go around to accept bytes into the fresh PDU
                    // instead of reporting a zero-length write.
                }
            }
        }
    }

    /// Flush the underlying stream.
    ///
    /// Bytes buffered for the PDU under construction are not sent.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }

    /// Shut down the underlying stream.
    ///
    /// This does not send the last PDU.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        Pin::new(&mut self.stream).poll_shutdown(cx)
    }
}
impl<W> Drop for AsyncPDataWriter<W>
where
W: AsyncWrite + Unpin,
{
fn drop(&mut self) {
tokio::task::block_in_place(move || {
tokio::runtime::Handle::current().block_on(async move {
let _ = self.finish_impl().await;
})
})
}
}
impl<R> AsyncRead for PDataReader<'_, R>
where
    R: AsyncRead + Unpin,
{
    /// Read payload bytes into `buf`.
    ///
    /// When no decoded bytes are left, one more PDU is read from the stream
    /// (unless the last one was already seen, in which case nothing is
    /// written to `buf`) and the payload of all its data values is buffered.
    ///
    /// Fails if the PDU cannot be decoded, if the stream ends before a
    /// complete PDU is received, or if the PDU received is not a P-Data PDU.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();

        if this.buffer.is_empty() {
            if this.last_pdu {
                // reached the end of the P-Data stream
                return Poll::Ready(Ok(()));
            }

            // NOTE(review): the blocking counterpart passes
            // `max_data_length` alone to `read_pdu`;
            // confirm which limit is intended.
            let mut reader = BufReader::new(&mut this.stream);
            let msg = loop {
                // try to decode a PDU out of the bytes received so far
                let mut pending = Cursor::new(&this.read_buffer[..]);
                let parsed = read_pdu(&mut pending, this.max_data_length + PDV_HEADER_SIZE, false)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
                if let Some(pdu) = parsed {
                    let consumed = pending.position() as usize;
                    this.read_buffer.advance(consumed);
                    break pdu;
                }

                // not enough bytes for a full PDU: fetch more from the stream
                let recv = ready!(Pin::new(&mut reader).poll_fill_buf(cx))?.to_vec();
                reader.consume(recv.len());
                if recv.is_empty() {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "Connection closed by peer",
                    )));
                }
                this.read_buffer.extend_from_slice(&recv);
            };

            let data = match msg {
                Pdu::PData { data } => data,
                _ => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "Unexpected PDU type",
                    )))
                }
            };

            for pdata_value in data {
                // remember the first presentation context seen,
                // and complain about values which do not match it
                match this.presentation_context_id {
                    None => {
                        this.presentation_context_id = Some(pdata_value.presentation_context_id);
                    }
                    Some(cid) if cid != pdata_value.presentation_context_id => {
                        warn!("Received PData value of presentation context {}, but should be {}", pdata_value.presentation_context_id, cid);
                    }
                    Some(_) => {}
                }
                this.buffer.extend(pdata_value.data);
                this.last_pdu = pdata_value.is_last;
            }
        }

        // hand out as many buffered bytes as the caller has room for
        let len = std::cmp::min(this.buffer.len(), buf.remaining());
        for byte in this.buffer.drain(..len) {
            buf.put_u8(byte);
        }
        Poll::Ready(Ok(()))
    }
}
}
#[cfg(test)]
mod tests {
use std::io::{Read, Write};
use crate::association::pdata::PDataWriter;
use crate::pdu::{read_pdu, Pdu, MINIMUM_PDU_SIZE, PDV_HEADER_SIZE};
use crate::pdu::{PDataValue, PDataValueType};
use crate::write_pdu;
use super::PDataReader;
use bytes::BytesMut;
#[cfg(feature = "async")]
use tokio::io::AsyncWriteExt;
#[cfg(feature = "async")]
use crate::association::pdata::non_blocking::AsyncPDataWriter;
/// A payload smaller than one PDU is sent as a single P-Data PDU on finish.
#[test]
fn test_write_pdata_and_finish() {
    let presentation_context_id = 12;
    let payload: Vec<u8> = (0..64).collect();

    let mut buf = Vec::new();
    {
        let mut writer = PDataWriter::new(&mut buf, presentation_context_id, MINIMUM_PDU_SIZE);
        writer.write_all(&payload).unwrap();
        writer.finish().unwrap();
    }

    let mut cursor = &buf[..];
    let values = match read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap() {
        Some(Pdu::PData { data }) => data,
        pdu => panic!("Expected PData, got {:?}", pdu),
    };

    // concentrate on the first data value
    let first = &values[0];
    assert_eq!(first.value_type, PDataValueType::Data);
    assert_eq!(first.presentation_context_id, presentation_context_id);
    assert_eq!(first.data.len(), 64);
    assert_eq!(first.data, payload);

    // nothing else was written
    assert_eq!(cursor.len(), 0);
}
/// A payload smaller than one PDU is sent as a single P-Data PDU on finish
/// (asynchronous writer).
#[cfg(feature = "async")]
#[tokio::test(flavor = "multi_thread")]
async fn test_async_write_pdata_and_finish() {
    let presentation_context_id = 12;
    let payload: Vec<u8> = (0..64).collect();

    let mut buf = Vec::new();
    {
        let mut writer =
            AsyncPDataWriter::new(&mut buf, presentation_context_id, MINIMUM_PDU_SIZE);
        writer.write_all(&payload).await.unwrap();
        writer.finish().await.unwrap();
    }

    let mut cursor = &buf[..];
    let values = match read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap() {
        Some(Pdu::PData { data }) => data,
        pdu => panic!("Expected PData, got {:?}", pdu),
    };

    // concentrate on the first data value
    let first = &values[0];
    assert_eq!(first.value_type, PDataValueType::Data);
    assert_eq!(first.presentation_context_id, presentation_context_id);
    assert_eq!(first.data.len(), 64);
    assert_eq!(first.data, payload);

    // nothing else was written
    assert_eq!(cursor.len(), 0);
}
/// A payload larger than one PDU is split into full PDUs
/// followed by one with the remaining bytes.
#[test]
fn test_write_large_pdata_and_finish() {
    let presentation_context_id = 32;
    let my_data: Vec<_> = (0..2500).map(|x: u32| x as u8).collect();
    // number of payload bytes in a full PDU
    let full_len = (MINIMUM_PDU_SIZE - PDV_HEADER_SIZE) as usize;

    let mut buf = Vec::new();
    {
        let mut writer = PDataWriter::new(&mut buf, presentation_context_id, MINIMUM_PDU_SIZE);
        writer.write_all(&my_data).unwrap();
        writer.finish().unwrap();
    }

    let mut cursor = &buf[..];
    let pdu_1 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap();
    let pdu_2 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap();
    let pdu_3 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap();

    let (values_1, values_2, values_3) = match (pdu_1, pdu_2, pdu_3) {
        (
            Some(Pdu::PData { data: first }),
            Some(Pdu::PData { data: second }),
            Some(Pdu::PData { data: third }),
        ) => (first, second, third),
        x => panic!("Expected 3 PDatas, got {:?}", x),
    };

    // one data value per PDU
    assert_eq!(values_1.len(), 1);
    assert_eq!(values_2.len(), 1);
    assert_eq!(values_3.len(), 1);
    let (pdv_1, pdv_2, pdv_3) = (&values_1[0], &values_2[0], &values_3[0]);

    assert_eq!(pdv_1.value_type, PDataValueType::Data);
    assert_eq!(pdv_2.value_type, PDataValueType::Data);
    assert_eq!(pdv_1.presentation_context_id, presentation_context_id);
    assert_eq!(pdv_2.presentation_context_id, presentation_context_id);

    // the first two PDUs are full, the third holds the rest
    assert_eq!(pdv_1.data.len(), full_len);
    assert_eq!(pdv_2.data.len(), full_len);
    assert_eq!(pdv_3.data.len(), 2500 - full_len * 2);

    let expected_first: Vec<u8> = (0..MINIMUM_PDU_SIZE - PDV_HEADER_SIZE)
        .map(|x| x as u8)
        .collect();
    assert_eq!(&pdv_1.data[..], &expected_first[..]);

    assert_eq!(
        pdv_1.data.len() + pdv_2.data.len() + pdv_3.data.len(),
        2500
    );

    // the pieces put together give back the original data
    let all_data: Vec<u8> = [&pdv_1.data[..], &pdv_2.data[..], &pdv_3.data[..]].concat();
    assert_eq!(all_data, my_data);

    // nothing else was written
    assert_eq!(cursor.len(), 0);
}
/// A payload larger than one PDU is split into full PDUs
/// followed by one with the remaining bytes (asynchronous writer).
#[cfg(feature = "async")]
#[tokio::test(flavor = "multi_thread")]
async fn test_async_write_large_pdata_and_finish() {
    let presentation_context_id = 32;
    let my_data: Vec<_> = (0..2500).map(|x: u32| x as u8).collect();
    // number of payload bytes in a full PDU
    let full_len = (MINIMUM_PDU_SIZE - PDV_HEADER_SIZE) as usize;

    let mut buf = Vec::new();
    {
        let mut writer =
            AsyncPDataWriter::new(&mut buf, presentation_context_id, MINIMUM_PDU_SIZE);
        writer.write_all(&my_data).await.unwrap();
        writer.finish().await.unwrap();
    }

    let mut cursor = &buf[..];
    let pdu_1 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap();
    let pdu_2 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap();
    let pdu_3 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap();

    let (values_1, values_2, values_3) = match (pdu_1, pdu_2, pdu_3) {
        (
            Some(Pdu::PData { data: first }),
            Some(Pdu::PData { data: second }),
            Some(Pdu::PData { data: third }),
        ) => (first, second, third),
        x => panic!("Expected 3 PDatas, got {:?}", x),
    };

    // one data value per PDU
    assert_eq!(values_1.len(), 1);
    assert_eq!(values_2.len(), 1);
    assert_eq!(values_3.len(), 1);
    let (pdv_1, pdv_2, pdv_3) = (&values_1[0], &values_2[0], &values_3[0]);

    assert_eq!(pdv_1.value_type, PDataValueType::Data);
    assert_eq!(pdv_2.value_type, PDataValueType::Data);
    assert_eq!(pdv_1.presentation_context_id, presentation_context_id);
    assert_eq!(pdv_2.presentation_context_id, presentation_context_id);

    // the first two PDUs are full, the third holds the rest
    assert_eq!(pdv_1.data.len(), full_len);
    assert_eq!(pdv_2.data.len(), full_len);
    assert_eq!(pdv_3.data.len(), 2500 - full_len * 2);

    let expected_first: Vec<u8> = (0..MINIMUM_PDU_SIZE - PDV_HEADER_SIZE)
        .map(|x| x as u8)
        .collect();
    assert_eq!(&pdv_1.data[..], &expected_first[..]);

    assert_eq!(
        pdv_1.data.len() + pdv_2.data.len() + pdv_3.data.len(),
        2500
    );

    // the pieces put together give back the original data
    let all_data: Vec<u8> = [&pdv_1.data[..], &pdv_2.data[..], &pdv_3.data[..]].concat();
    assert_eq!(all_data, my_data);

    // nothing else was written
    assert_eq!(cursor.len(), 0);
}
/// The payload of consecutive P-Data PDUs is read back as one byte stream,
/// ending after the PDU flagged as last.
#[test]
fn test_read_large_pdata_and_finish() {
    use std::collections::VecDeque;
    let presentation_context_id = 32;
    let my_data: Vec<_> = (0..2500).map(|x: u32| x as u8).collect();

    // build a P-Data PDU with a single data value holding the given bytes
    let single_value_pdu = |bytes: &[u8], is_last: bool| Pdu::PData {
        data: vec![PDataValue {
            value_type: PDataValueType::Data,
            data: bytes.to_owned(),
            presentation_context_id,
            is_last,
        }],
    };

    let mut pdu_stream = VecDeque::new();
    write_pdu(&mut pdu_stream, &single_value_pdu(&my_data[0..1018], false)).unwrap();
    write_pdu(&mut pdu_stream, &single_value_pdu(&my_data[1018..2036], false)).unwrap();
    write_pdu(&mut pdu_stream, &single_value_pdu(&my_data[2036..], true)).unwrap();

    let mut buf = Vec::new();
    {
        let mut read_buf = BytesMut::new();
        let mut reader = PDataReader::new(&mut pdu_stream, MINIMUM_PDU_SIZE, &mut read_buf);
        reader.read_to_end(&mut buf).unwrap();
    }
    assert_eq!(buf, my_data);
}
/// The payload of consecutive P-Data PDUs is read back as one byte stream,
/// ending after the PDU flagged as last (asynchronous reader).
#[cfg(feature = "async")]
#[tokio::test]
async fn test_async_read_large_pdata_and_finish() {
    use tokio::io::AsyncReadExt;
    let presentation_context_id = 32;
    let my_data: Vec<_> = (0..9000).map(|x: u32| x as u8).collect();

    // build a P-Data PDU with a single data value holding the given bytes
    let single_value_pdu = |bytes: &[u8], is_last: bool| Pdu::PData {
        data: vec![PDataValue {
            value_type: PDataValueType::Data,
            data: bytes.to_owned(),
            presentation_context_id,
            is_last,
        }],
    };

    let mut pdu_stream = std::io::Cursor::new(Vec::new());
    write_pdu(&mut pdu_stream, &single_value_pdu(&my_data[0..3000], false)).unwrap();
    write_pdu(&mut pdu_stream, &single_value_pdu(&my_data[3000..6000], false)).unwrap();
    write_pdu(&mut pdu_stream, &single_value_pdu(&my_data[6000..], true)).unwrap();

    let mut buf = Vec::new();
    let inner = pdu_stream.into_inner();
    let mut stream = tokio::io::BufReader::new(inner.as_slice());
    {
        let mut read_buf = BytesMut::new();
        let mut reader = PDataReader::new(&mut stream, MINIMUM_PDU_SIZE, &mut read_buf);
        reader.read_to_end(&mut buf).await.unwrap();
    }
    assert_eq!(buf, my_data);
}
}