pub mod client;
pub mod server;
#[cfg(test)]
mod tests;
mod uid;
pub(crate) mod pdata;
use std::{
backtrace::Backtrace,
io::{BufRead, BufReader, Cursor, Read},
time::Duration,
};
use bytes::{Buf, BytesMut};
#[cfg(feature = "async")]
pub use client::AsyncClientAssociation;
pub use client::{ClientAssociation, ClientAssociationOptions};
#[cfg(feature = "async")]
pub use pdata::non_blocking::AsyncPDataWriter;
pub use pdata::{PDataReader, PDataWriter};
#[cfg(feature = "async")]
pub use server::AsyncServerAssociation;
pub use server::{ServerAssociation, ServerAssociationOptions};
use snafu::{ensure, ResultExt, Snafu};
use crate::{
pdu::{self, AssociationRJ, PresentationContextNegotiated, ReadPduSnafu, UserVariableItem},
write_pdu, Pdu,
};
/// Module-local result alias whose error type defaults to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
/// Errors which can be raised while establishing, using, or closing an association.
///
/// Every variant carries an explicit `Display` message,
/// so that no error is rendered as a bare variant identifier.
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum Error {
    /// An abstract syntax was required but none was available.
    #[snafu(display("missing abstract syntax"))]
    MissingAbstractSyntax { backtrace: Backtrace },

    /// The target address could not be turned into a socket address.
    #[snafu(display("could not convert to socket address"))]
    ToAddress {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// The connection to the peer could not be established.
    #[snafu(display("could not connect to peer"))]
    Connect {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// The read timeout could not be applied to the socket.
    #[snafu(display("could not set read timeout"))]
    SetReadTimeout {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// The write timeout could not be applied to the socket.
    #[snafu(display("could not set write timeout"))]
    SetWriteTimeout {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// A PDU could not be encoded for sending.
    #[snafu(display("failed to send pdu: {}", source))]
    SendPdu {
        #[snafu(backtrace)]
        source: crate::pdu::WriteError,
    },

    /// A PDU could not be read or decoded.
    #[snafu(display("failed to receive pdu: {}", source))]
    ReceivePdu {
        #[snafu(backtrace)]
        source: crate::pdu::ReadError,
    },

    /// The peer answered with a known PDU which is not valid at this point.
    #[snafu(display("unexpected response from peer `{:?}`", pdu))]
    #[non_exhaustive]
    UnexpectedPdu {
        /// The PDU obtained from the peer.
        pdu: Box<Pdu>,
    },

    /// The peer answered with a PDU of an unknown kind.
    #[snafu(display("unknown response from peer `{:?}`", pdu))]
    #[non_exhaustive]
    UnknownPdu {
        /// The PDU obtained from the peer.
        pdu: Box<Pdu>,
    },

    /// The protocol version announced by the peer differs from the expected one.
    #[snafu(display("protocol version mismatch: expected {}, got {}", expected, got))]
    ProtocolVersionMismatch {
        expected: u16,
        got: u16,
        backtrace: Backtrace,
    },

    /// The association was rejected.
    #[snafu(display("association rejected {}", association_rj.source))]
    Rejected {
        /// The rejection PDU describing the outcome.
        association_rj: AssociationRJ,
        backtrace: Backtrace,
    },

    /// The association was aborted.
    #[snafu(display("association aborted"))]
    Aborted { backtrace: Backtrace },

    /// None of the presentation contexts were accepted.
    #[snafu(display("no presentation contexts accepted"))]
    NoAcceptedPresentationContexts { backtrace: Backtrace },

    /// Writing a PDU message to the wire failed.
    #[snafu(display("failed to send PDU message on wire"))]
    #[non_exhaustive]
    WireSend {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// Reading a PDU message from the wire failed.
    #[snafu(display("failed to read PDU message from wire"))]
    #[non_exhaustive]
    WireRead {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// An operation did not complete within its time limit.
    #[snafu(display("operation timed out"))]
    #[non_exhaustive]
    Timeout {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// The underlying connection could not be closed.
    #[snafu(display("failed to close connection: {}", source))]
    Close {
        source: std::io::Error,
        backtrace: Backtrace,
    },

    /// The encoded PDU exceeds the length limit it was checked against.
    #[snafu(display(
        "PDU is too large ({} bytes) to be sent to the remote application entity",
        length
    ))]
    #[non_exhaustive]
    SendTooLongPdu { length: usize, backtrace: Backtrace },

    /// The stream reached end of input before a full PDU could be read.
    #[snafu(display("Connection closed by peer"))]
    ConnectionClosed,

    /// A TLS configuration was needed but is absent.
    #[cfg(feature = "sync-tls")]
    #[snafu(display("TLS configuration is required but not provided"))]
    TlsConfigMissing { backtrace: Backtrace },

    /// The server name is not usable for a TLS connection.
    #[cfg(feature = "sync-tls")]
    #[snafu(display("Invalid server name for TLS connection"))]
    InvalidServerName {
        source: rustls::pki_types::InvalidDnsNameError,
        backtrace: Backtrace,
    },

    /// The TLS connection could not be established.
    #[cfg(feature = "sync-tls")]
    #[snafu(display("Failed to establish TLS connection: {:?}", source))]
    TlsConnection {
        source: rustls::Error,
        backtrace: Backtrace,
    },
}
/// Bundle of values describing a negotiated association.
///
/// NOTE(review): presumably produced by the client/server negotiation code,
/// which is not part of this file section — confirm against those modules.
pub(crate) struct NegotiatedOptions {
/// Maximum PDU length attributed to the peer.
peer_max_pdu_length: u32,
/// User variable items kept for the association.
user_variables: Vec<UserVariableItem>,
/// Presentation contexts resulting from the negotiation.
presentation_contexts: Vec<PresentationContextNegotiated>,
/// Application entity title of the peer.
peer_ae_title: String,
}
/// Optional time limits for socket operations.
///
/// The derived `Default` leaves every limit as `None`.
/// NOTE(review): how each limit is applied to a socket is decided elsewhere — confirm
/// against the client/server modules.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct SocketOptions {
/// Time limit for read operations, if any.
read_timeout: Option<Duration>,
/// Time limit for write operations, if any.
write_timeout: Option<Duration>,
/// Time limit for establishing the connection, if any.
connection_timeout: Option<Duration>,
}
/// A stream which can be explicitly closed.
pub trait CloseSocket {
    /// Close the underlying connection.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported while closing, if any.
    fn close(&mut self) -> std::io::Result<()>;
}

impl CloseSocket for std::net::TcpStream {
    /// Shuts down both the reading and the writing half of the TCP stream.
    fn close(&mut self) -> std::io::Result<()> {
        use std::net::Shutdown;

        std::net::TcpStream::shutdown(self, Shutdown::Both)
    }
}
#[cfg(feature = "sync-tls")]
impl CloseSocket for rustls::StreamOwned<rustls::ClientConnection, std::net::TcpStream> {
    /// Closes the TLS session, then shuts down both halves of the TCP socket beneath it.
    ///
    /// When the handshake has completed, a `close_notify` alert is queued and written out
    /// before the socket is shut down, so that the peer observes an orderly TLS closure
    /// instead of an unexpected end of stream.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the TCP shutdown, if any.
    /// Failing to deliver the alert is not reported.
    fn close(&mut self) -> std::io::Result<()> {
        if !self.conn.is_handshaking() {
            self.conn.send_close_notify();
            // Best effort and write-only: never wait on the peer here.
            // Stop on the first error or if the socket accepts no more bytes.
            while self.conn.wants_write() {
                match self.conn.write_tls(&mut self.sock) {
                    Ok(written) if written > 0 => {}
                    _ => break,
                }
            }
        }
        self.get_mut().shutdown(std::net::Shutdown::Both)
    }
}
#[cfg(feature = "sync-tls")]
impl CloseSocket for rustls::StreamOwned<rustls::ServerConnection, std::net::TcpStream> {
    /// Closes the TLS session, then shuts down both halves of the TCP socket beneath it.
    ///
    /// When the handshake has completed, a `close_notify` alert is queued and written out
    /// before the socket is shut down, so that the peer observes an orderly TLS closure
    /// instead of an unexpected end of stream.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the TCP shutdown, if any.
    /// Failing to deliver the alert is not reported.
    fn close(&mut self) -> std::io::Result<()> {
        if !self.conn.is_handshaking() {
            self.conn.send_close_notify();
            // Best effort and write-only: never wait on the peer here.
            // Stop on the first error or if the socket accepts no more bytes.
            while self.conn.wants_write() {
                match self.conn.write_tls(&mut self.sock) {
                    Ok(written) if written > 0 => {}
                    _ => break,
                }
            }
        }
        self.get_mut().shutdown(std::net::Shutdown::Both)
    }
}
/// Read-only accessors for the parameters of an association.
pub trait Association {
/// Application entity title of the peer.
fn peer_ae_title(&self) -> &str;
/// Maximum PDU length attributed to the acceptor.
///
/// NOTE(review): presumably the value announced by the node which accepted
/// the association — confirm against the implementors.
fn acceptor_max_pdu_length(&self) -> u32;
/// Maximum PDU length attributed to the requestor.
///
/// NOTE(review): presumably the value announced by the node which requested
/// the association — confirm against the implementors.
fn requestor_max_pdu_length(&self) -> u32;
/// Maximum PDU length of the local node.
///
/// The P-Data readers created by the association traits in this module
/// are given this value as their limit.
fn local_max_pdu_length(&self) -> u32;
/// Maximum PDU length of the peer.
///
/// The P-Data writers created by the association traits in this module
/// are given this value as their limit.
fn peer_max_pdu_length(&self) -> u32;
/// Presentation contexts negotiated for this association.
fn presentation_contexts(&self) -> &[PresentationContextNegotiated];
/// User variable items associated with this association.
fn user_variables(&self) -> &[UserVariableItem];
}
mod private {
    use crate::{
        pdu::{AbortRQServiceProviderReason, AbortRQSource},
        Pdu,
    };
    use snafu::ResultExt;

    /// Validate the PDU received in reply to a release request.
    ///
    /// Only a release response is accepted;
    /// an unknown PDU and any other known PDU are turned into
    /// the respective error, carrying the offending PDU.
    fn expect_release_rp(reply: Pdu) -> super::Result<()> {
        match reply {
            Pdu::ReleaseRP => Ok(()),
            pdu @ Pdu::Unknown { .. } => super::UnknownPduSnafu { pdu }.fail(),
            pdu @ (Pdu::AbortRQ { .. }
            | Pdu::AssociationAC { .. }
            | Pdu::AssociationRJ { .. }
            | Pdu::AssociationRQ { .. }
            | Pdu::PData { .. }
            | Pdu::ReleaseRQ) => super::UnexpectedPduSnafu { pdu }.fail(),
        }
    }

    /// Build the abort request sent by `abort`:
    /// service provider as the source, with no specific reason.
    fn provider_abort_rq() -> Pdu {
        Pdu::AbortRQ {
            source: AbortRQSource::ServiceProvider(
                AbortRQServiceProviderReason::ReasonNotSpecified,
            ),
        }
    }

    /// Blocking association operations which take `&mut self`.
    ///
    /// The public `SyncAssociation` trait of the parent module delegates to these methods.
    pub trait SyncAssociationSealed<S: std::io::Read + std::io::Write + super::CloseSocket> {
        /// Close the underlying connection.
        fn close(&mut self) -> std::io::Result<()>;

        /// Send a PDU to the peer.
        fn send(&mut self, pdu: &Pdu) -> super::Result<()>;

        /// Receive the next PDU from the peer.
        fn receive(&mut self) -> super::Result<Pdu>;

        /// Send a release request, wait for the release response,
        /// and then close the connection.
        ///
        /// # Errors
        ///
        /// Fails if sending or receiving fails,
        /// if the peer replies with anything other than a release response
        /// (in which case the connection is not closed by this method),
        /// or if closing the connection fails.
        fn release(&mut self) -> super::Result<()> {
            let request = Pdu::ReleaseRQ;
            self.send(&request)?;
            let reply = self.receive()?;
            expect_release_rp(reply)?;
            self.close().context(super::CloseSnafu)
        }

        /// Send an abort request and close the connection.
        ///
        /// The connection is closed whether or not the request could be sent;
        /// only the outcome of sending is reported.
        fn abort(&mut self) -> super::Result<()>
        where
            Self: Sized,
        {
            let request = provider_abort_rq();
            let sent = self.send(&request);
            // closing is best effort: its outcome is deliberately discarded
            let _ = self.close();
            sent
        }
    }

    /// Non-blocking association operations which take `&mut self`.
    ///
    /// The public `AsyncAssociation` trait of the parent module delegates to these methods.
    #[cfg(feature = "async")]
    pub trait AsyncAssociationSealed<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> {
        /// Close the underlying connection.
        fn close(&mut self) -> impl std::future::Future<Output = std::io::Result<()>> + Send
        where
            Self: Send;

        /// Send a PDU to the peer.
        fn send(
            &mut self,
            pdu: &Pdu,
        ) -> impl std::future::Future<Output = super::Result<()>> + Send
        where
            Self: Send;

        /// Receive the next PDU from the peer.
        fn receive(&mut self) -> impl std::future::Future<Output = super::Result<Pdu>> + Send
        where
            Self: Send;

        /// Send a release request, wait for the release response,
        /// and then close the connection.
        ///
        /// # Errors
        ///
        /// Fails if sending or receiving fails,
        /// if the peer replies with anything other than a release response
        /// (in which case the connection is not closed by this method),
        /// or if closing the connection fails.
        fn release(&mut self) -> impl std::future::Future<Output = super::Result<()>> + Send
        where
            Self: Send,
        {
            async move {
                let request = Pdu::ReleaseRQ;
                self.send(&request).await?;
                let reply = self.receive().await?;
                expect_release_rp(reply)?;
                self.close().await.context(super::CloseSnafu)
            }
        }

        /// Send an abort request and close the connection.
        ///
        /// The connection is closed whether or not the request could be sent;
        /// only the outcome of sending is reported.
        fn abort(&mut self) -> impl std::future::Future<Output = super::Result<()>> + Send
        where
            Self: Sized + Send,
        {
            let request = provider_abort_rq();
            async move {
                let sent = self.send(&request).await;
                // closing is best effort: its outcome is deliberately discarded
                let _ = self.close().await;
                sent
            }
        }
    }
}
/// Blocking operations available on an established association.
///
/// The PDU level operations delegate to the sealed trait in the `private` module.
pub trait SyncAssociation<S: std::io::Read + std::io::Write + CloseSocket>:
    private::SyncAssociationSealed<S> + Association
{
    /// Mutable access to the underlying stream.
    fn inner_stream(&mut self) -> &mut S;

    /// Mutable access to the underlying stream and to the read buffer, at once.
    fn get_mut(&mut self) -> (&mut S, &mut BytesMut);

    /// Send a PDU to the peer.
    fn send(&mut self, pdu: &Pdu) -> Result<()> {
        <Self as private::SyncAssociationSealed<S>>::send(self, pdu)
    }

    /// Receive the next PDU from the peer.
    fn receive(&mut self) -> Result<Pdu> {
        <Self as private::SyncAssociationSealed<S>>::receive(self)
    }

    /// Abort the association, consuming it:
    /// an abort request is sent and the connection is closed.
    fn abort(mut self) -> Result<()>
    where
        Self: Sized,
    {
        <Self as private::SyncAssociationSealed<S>>::abort(&mut self)
    }

    /// Release the association, consuming it:
    /// a release request is sent, the release response is awaited,
    /// and the connection is closed.
    fn release(mut self) -> Result<()>
    where
        Self: Sized,
    {
        <Self as private::SyncAssociationSealed<S>>::release(&mut self)
    }

    /// Create a P-Data writer over the underlying stream
    /// for the given presentation context,
    /// limited by the maximum PDU length of the peer.
    fn send_pdata(&mut self, presentation_context_id: u8) -> PDataWriter<&mut S> {
        // read the limit first: the stream borrow below keeps `self` borrowed
        let limit = self.peer_max_pdu_length();
        let stream = self.inner_stream();
        PDataWriter::new(stream, presentation_context_id, limit)
    }

    /// Create a P-Data reader over the underlying stream and read buffer,
    /// limited by the local maximum PDU length.
    fn receive_pdata(&mut self) -> PDataReader<'_, &mut S> {
        let limit = self.local_max_pdu_length();
        let (stream, pending) = self.get_mut();
        PDataReader::new(stream, limit, pending)
    }
}
/// Non-blocking operations available on an established association.
///
/// The PDU level operations delegate to the sealed trait in the `private` module.
#[cfg(feature = "async")]
pub trait AsyncAssociation<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>:
    private::AsyncAssociationSealed<S> + Association
{
    /// Mutable access to the underlying stream.
    fn inner_stream(&mut self) -> &mut S;

    /// Mutable access to the underlying stream and to the read buffer, at once.
    fn get_mut(&mut self) -> (&mut S, &mut BytesMut);

    /// Send a PDU to the peer.
    fn send(&mut self, pdu: &Pdu) -> impl std::future::Future<Output = Result<()>> + Send
    where
        Self: Send,
    {
        async move { <Self as private::AsyncAssociationSealed<S>>::send(self, pdu).await }
    }

    /// Receive the next PDU from the peer.
    fn receive(&mut self) -> impl std::future::Future<Output = Result<Pdu>> + Send
    where
        Self: Send,
    {
        async move { <Self as private::AsyncAssociationSealed<S>>::receive(self).await }
    }

    /// Abort the association, consuming it:
    /// an abort request is sent and the connection is closed.
    fn abort(mut self) -> impl std::future::Future<Output = Result<()>> + Send
    where
        Self: Sized + Send,
    {
        async move { <Self as private::AsyncAssociationSealed<S>>::abort(&mut self).await }
    }

    /// Release the association, consuming it:
    /// a release request is sent, the release response is awaited,
    /// and the connection is closed.
    fn release(mut self) -> impl std::future::Future<Output = Result<()>> + Send
    where
        Self: Sized + Send,
    {
        async move { <Self as private::AsyncAssociationSealed<S>>::release(&mut self).await }
    }

    /// Create a P-Data writer over the underlying stream
    /// for the given presentation context,
    /// limited by the maximum PDU length of the peer.
    fn send_pdata(&mut self, presentation_context_id: u8) -> AsyncPDataWriter<&mut S> {
        // read the limit first: the stream borrow below keeps `self` borrowed
        let limit = self.peer_max_pdu_length();
        let stream = self.inner_stream();
        AsyncPDataWriter::new(stream, presentation_context_id, limit)
    }

    /// Create a P-Data reader over the underlying stream and read buffer,
    /// limited by the local maximum PDU length.
    fn receive_pdata(&mut self) -> PDataReader<'_, &mut S> {
        let limit = self.local_max_pdu_length();
        let (stream, pending) = self.get_mut();
        PDataReader::new(stream, limit, pending)
    }
}
/// Run `block` to completion, optionally bounded by a time limit.
///
/// Without a limit, the future is simply awaited.
/// With a limit, a future which does not finish in time
/// results in [`Error::Timeout`],
/// whose source is an I/O error of kind `TimedOut`.
#[cfg(feature = "async")]
async fn timeout<T>(
    timeout: Option<Duration>,
    block: impl std::future::Future<Output = Result<T>>,
) -> Result<T> {
    match timeout {
        None => block.await,
        Some(limit) => {
            let outcome = tokio::time::timeout(limit, block).await;
            // the outer layer reports expiry, the inner one is the result of `block` itself
            outcome
                .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))
                .context(crate::association::TimeoutSnafu)?
        }
    }
}
pub(crate) fn encode_pdu(buffer: &mut Vec<u8>, pdu: &Pdu, peer_max_pdu_length: u32) -> Result<()> {
write_pdu(buffer, pdu).context(SendPduSnafu)?;
if buffer.len() > peer_max_pdu_length as usize {
return SendTooLongPduSnafu {
length: buffer.len(),
}
.fail();
}
Ok(())
}
/// Read a single PDU from a blocking reader,
/// using `read_buffer` to hold bytes which were received but not decoded yet.
///
/// Decoding is attempted on the bytes already in `read_buffer`.
/// While `pdu::read_pdu` yields no PDU, more bytes are fetched from `reader`
/// and appended to `read_buffer` before trying again.
/// Once a PDU is decoded, the bytes it occupied are removed from `read_buffer`;
/// any bytes after it stay there for the next call.
///
/// `max_pdu_length` and `strict` are forwarded to `pdu::read_pdu`.
///
/// # Errors
///
/// - [`Error::ReceivePdu`] if decoding fails or the reader reports an I/O error;
/// - [`Error::ConnectionClosed`] if the reader has no more bytes to give
///   before a PDU could be decoded.
pub fn read_pdu_from_wire<R>(
    reader: &mut R,
    read_buffer: &mut BytesMut,
    max_pdu_length: u32,
    strict: bool,
) -> Result<Pdu>
where
    R: Read,
{
    let mut wire = BufReader::new(reader);
    loop {
        // try to decode a PDU out of what has been collected so far
        let mut pending = Cursor::new(&read_buffer[..]);
        let decoded =
            pdu::read_pdu(&mut pending, max_pdu_length, strict).context(ReceivePduSnafu)?;
        if let Some(message) = decoded {
            let consumed = pending.position() as usize;
            read_buffer.advance(consumed);
            return Ok(message);
        }

        // not enough data: take whatever the reader provides next.
        // Everything fetched is consumed right away,
        // so nothing is left behind in the `BufReader` when it goes out of scope
        let chunk = wire
            .fill_buf()
            .context(ReadPduSnafu)
            .context(ReceivePduSnafu)?;
        let chunk_len = chunk.len();
        read_buffer.extend_from_slice(chunk);
        wire.consume(chunk_len);
        ensure!(chunk_len != 0, ConnectionClosedSnafu);
    }
}
/// Read a single PDU from an asynchronous reader,
/// using `read_buffer` to hold bytes which were received but not decoded yet.
///
/// Decoding is attempted on the bytes already in `read_buffer`.
/// While `pdu::read_pdu` yields no PDU, more bytes are read from `reader`
/// directly into `read_buffer` before trying again.
/// Once a PDU is decoded, the bytes it occupied are removed from `read_buffer`;
/// any bytes after it stay there for the next call.
///
/// `max_pdu_length` and `strict` are forwarded to `pdu::read_pdu`.
///
/// # Errors
///
/// - [`Error::ReceivePdu`] if decoding fails or the reader reports an I/O error;
/// - [`Error::ConnectionClosed`] if a read yields no bytes
///   before a PDU could be decoded.
#[cfg(feature = "async")]
pub async fn read_pdu_from_wire_async<R: tokio::io::AsyncRead + Unpin>(
    reader: &mut R,
    read_buffer: &mut BytesMut,
    max_pdu_length: u32,
    strict: bool,
) -> Result<Pdu> {
    use tokio::io::AsyncReadExt;

    loop {
        // try to decode a PDU out of what has been collected so far
        let mut pending = Cursor::new(&read_buffer[..]);
        let decoded =
            pdu::read_pdu(&mut pending, max_pdu_length, strict).context(ReceivePduSnafu)?;
        if let Some(message) = decoded {
            let consumed = pending.position() as usize;
            read_buffer.advance(consumed);
            return Ok(message);
        }

        // not enough data: append whatever the reader provides next
        let received = reader
            .read_buf(read_buffer)
            .await
            .context(ReadPduSnafu)
            .context(ReceivePduSnafu)?;
        ensure!(received > 0, ConnectionClosedSnafu);
    }
}