use bytes::BytesMut;
use std::{
borrow::Cow,
convert::TryInto,
net::{TcpStream, ToSocketAddrs},
time::Duration,
};
#[cfg(feature = "async")]
use crate::association::AsyncAssociation;
use crate::{
association::{
encode_pdu, private::SyncAssociationSealed, read_pdu_from_wire, Association,
NegotiatedOptions, SocketOptions, SyncAssociation,
},
pdu::{
write_pdu, AbortRQSource, AssociationAC, AssociationRQ, Pdu, PresentationContextNegotiated,
PresentationContextProposed, PresentationContextResultReason, UserIdentity,
UserIdentityType, UserVariableItem, DEFAULT_MAX_PDU, LARGE_PDU_SIZE, PDU_HEADER_SIZE,
},
AeAddr, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME,
};
use snafu::{ensure, ResultExt};
use super::{uid::trim_uid, Result};
/// Deprecated, empty module.
///
/// NOTE(review): presumably kept only so that existing paths into
/// `non_blocking` still resolve; the async client types are now defined
/// directly in this module behind the `async` feature.
#[deprecated(since = "0.9.1")]
pub mod non_blocking {}
/// Blocking TLS stream (rustls over a standard `TcpStream`) returned by the
/// `sync-tls` client API.
#[cfg(feature = "sync-tls")]
pub type TlsStream = rustls::StreamOwned<rustls::ClientConnection, std::net::TcpStream>;
/// Async TLS stream (tokio-rustls over a tokio `TcpStream`) returned by the
/// `async-tls` client API.
#[cfg(feature = "async-tls")]
pub type AsyncTlsStream = tokio_rustls::client::TlsStream<tokio::net::TcpStream>;
pub use crate::association::CloseSocket;
/// Open a blocking TCP connection to `ae_address` and apply the socket
/// timeouts from `opts`.
///
/// With a connection timeout configured, the address is resolved and each
/// resulting socket address is tried in order, each attempt bounded by the
/// timeout, until one succeeds. If all of them fail, the error of the last
/// attempt is reported, or `AddrNotAvailable` when resolution produced no
/// address at all. Without a connection timeout, `TcpStream::connect` is
/// used directly.
fn tcp_connection<T>(ae_address: &AeAddr<T>, opts: &SocketOptions) -> Result<TcpStream>
where
    T: ToSocketAddrs,
{
    let stream = match opts.connection_timeout {
        None => TcpStream::connect(ae_address).context(super::ConnectSnafu)?,
        Some(limit) => {
            let candidates = ae_address
                .to_socket_addrs()
                .context(super::ToAddressSnafu)?;
            // Reported if the resolver yields no address to try.
            let mut outcome: std::io::Result<TcpStream> =
                Err(std::io::ErrorKind::AddrNotAvailable.into());
            for candidate in candidates {
                outcome = TcpStream::connect_timeout(&candidate, limit);
                if outcome.is_ok() {
                    break;
                }
            }
            outcome.context(super::ConnectSnafu)?
        }
    };
    stream
        .set_read_timeout(opts.read_timeout)
        .context(super::SetReadTimeoutSnafu)?;
    stream
        .set_write_timeout(opts.write_timeout)
        .context(super::SetWriteTimeoutSnafu)?;
    Ok(stream)
}
/// Open a TLS-over-TCP connection to `ae_address`.
///
/// `server_name` is handed to rustls as the server's name. It is validated
/// before any network activity, so a malformed name fails fast instead of
/// after a TCP connection has been opened and discarded. The TCP socket is
/// created by `tcp_connection`, so the timeouts in `opts` apply to it.
///
/// NOTE(review): `StreamOwned::new` does not itself run the TLS handshake;
/// it presumably happens on the first read/write — confirm in rustls docs.
#[cfg(feature = "sync-tls")]
fn tls_connection<T>(
    ae_address: &AeAddr<T>,
    server_name: &str,
    opts: &SocketOptions,
    tls_config: std::sync::Arc<rustls::ClientConfig>,
) -> Result<TlsStream>
where
    T: ToSocketAddrs,
{
    use std::convert::TryFrom;
    // Check the configuration before touching the network.
    let server_name = rustls::pki_types::ServerName::try_from(server_name.to_string())
        .context(super::InvalidServerNameSnafu)?;
    let socket = tcp_connection(ae_address, opts)?;
    // `tls_config` is already owned here, so it is moved rather than cloned.
    let conn = rustls::ClientConnection::new(tls_config, server_name)
        .context(super::TlsConnectionSnafu)?;
    Ok(rustls::StreamOwned::new(conn, socket))
}
/// Builder-style options for requesting an association with a remote DICOM
/// node, with this node acting as the association requestor.
///
/// Create with `ClientAssociationOptions::new` (or `Default`), configure it
/// through the builder methods, and finish with one of the `establish*`
/// methods.
#[derive(Debug, Clone)]
pub struct ClientAssociationOptions<'a> {
    /// AE title of this node, sent as the calling AE title.
    calling_ae_title: Cow<'a, str>,
    /// AE title of the remote node. When `None`, the AE title carried by the
    /// target address is used, or `"ANY-SCP"` if there is none.
    called_ae_title: Option<Cow<'a, str>>,
    /// Application context name sent in the association request.
    application_context_name: Cow<'a, str>,
    /// Proposed presentation contexts: an abstract syntax UID paired with the
    /// transfer syntax UIDs offered for it.
    presentation_contexts: Vec<(Cow<'a, str>, Vec<Cow<'a, str>>)>,
    /// Protocol version sent in the request; the acceptor's reply must carry
    /// the same value.
    protocol_version: u16,
    /// Maximum PDU length this node accepts. It is announced to the peer and
    /// used as the limit when reading incoming PDUs.
    max_pdu_length: u32,
    /// Flag forwarded to the PDU reader.
    /// NOTE(review): presumably enables strict PDU length checks — confirm.
    strict: bool,
    /// Username for user identity negotiation.
    username: Option<Cow<'a, str>>,
    /// Password sent together with the username.
    password: Option<Cow<'a, str>>,
    /// Kerberos service ticket for user identity negotiation.
    kerberos_service_ticket: Option<Cow<'a, str>>,
    /// SAML assertion for user identity negotiation.
    saml_assertion: Option<Cow<'a, str>>,
    /// JSON Web Token for user identity negotiation.
    jwt: Option<Cow<'a, str>>,
    /// Read, write and connection timeouts for the socket.
    socket_options: SocketOptions,
    /// TLS client configuration required by the TLS `establish*` methods.
    #[cfg(feature = "sync-tls")]
    tls_config: Option<std::sync::Arc<rustls::ClientConfig>>,
    /// Server name handed to rustls by the TLS `establish*` methods.
    #[cfg(feature = "sync-tls")]
    server_name: Option<String>,
}
impl Default for ClientAssociationOptions<'_> {
fn default() -> Self {
ClientAssociationOptions {
calling_ae_title: "THIS-SCU".into(),
called_ae_title: None,
application_context_name: "1.2.840.10008.3.1.1.1".into(),
presentation_contexts: Vec::new(),
protocol_version: 1,
max_pdu_length: DEFAULT_MAX_PDU,
strict: true,
username: None,
password: None,
kerberos_service_ticket: None,
saml_assertion: None,
jwt: None,
socket_options: SocketOptions {
read_timeout: None,
write_timeout: None,
connection_timeout: None,
},
#[cfg(feature = "sync-tls")]
tls_config: None,
#[cfg(feature = "sync-tls")]
server_name: None,
}
}
}
impl<'a> ClientAssociationOptions<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn calling_ae_title<T>(mut self, calling_ae_title: T) -> Self
where
T: Into<Cow<'a, str>>,
{
self.calling_ae_title = calling_ae_title.into();
self
}
pub fn called_ae_title<T>(mut self, called_ae_title: T) -> Self
where
T: Into<Cow<'a, str>>,
{
let cae = called_ae_title.into();
if cae.is_empty() {
self.called_ae_title = None;
} else {
self.called_ae_title = Some(cae);
}
self
}
/// Propose a presentation context with the given abstract syntax and
/// transfer syntaxes.
///
/// Every UID is passed through `trim_uid` before being stored.
pub fn with_presentation_context<T>(
    mut self,
    abstract_syntax_uid: T,
    transfer_syntax_uids: Vec<T>,
) -> Self
where
    T: Into<Cow<'a, str>>,
{
    let mut transfer_syntaxes: Vec<Cow<'a, str>> =
        Vec::with_capacity(transfer_syntax_uids.len());
    for uid in transfer_syntax_uids {
        transfer_syntaxes.push(trim_uid(uid.into()));
    }
    let abstract_syntax = trim_uid(abstract_syntax_uid.into());
    self.presentation_contexts
        .push((abstract_syntax, transfer_syntaxes));
    self
}

/// Propose a presentation context for the given abstract syntax.
///
/// It offers Explicit VR Little Endian (`1.2.840.10008.1.2.1`) followed by
/// Implicit VR Little Endian (`1.2.840.10008.1.2`).
pub fn with_abstract_syntax<T>(self, abstract_syntax_uid: T) -> Self
where
    T: Into<Cow<'a, str>>,
{
    let abstract_syntax: Cow<'a, str> = abstract_syntax_uid.into();
    let transfer_syntaxes = vec![
        Cow::Borrowed("1.2.840.10008.1.2.1"),
        Cow::Borrowed("1.2.840.10008.1.2"),
    ];
    self.with_presentation_context(abstract_syntax, transfer_syntaxes)
}
pub fn max_pdu_length(mut self, value: u32) -> Self {
self.max_pdu_length = value;
self
}
pub fn strict(mut self, strict: bool) -> Self {
self.strict = strict;
self
}
pub fn username<T>(mut self, username: T) -> Self
where
T: Into<Cow<'a, str>>,
{
let username = username.into();
if username.is_empty() {
self.username = None;
} else {
self.username = Some(username);
self.saml_assertion = None;
self.jwt = None;
self.kerberos_service_ticket = None;
}
self
}
pub fn password<T>(mut self, password: T) -> Self
where
T: Into<Cow<'a, str>>,
{
let password = password.into();
if password.is_empty() {
self.password = None;
} else {
self.password = Some(password);
self.saml_assertion = None;
self.jwt = None;
self.kerberos_service_ticket = None;
}
self
}
pub fn username_password<T, U>(mut self, username: T, password: U) -> Self
where
T: Into<Cow<'a, str>>,
U: Into<Cow<'a, str>>,
{
let username = username.into();
let password = password.into();
if username.is_empty() {
self.username = None;
self.password = None;
} else {
self.username = Some(username);
self.password = Some(password);
self.saml_assertion = None;
self.jwt = None;
self.kerberos_service_ticket = None;
}
self
}
pub fn kerberos_service_ticket<T>(mut self, kerberos_service_ticket: T) -> Self
where
T: Into<Cow<'a, str>>,
{
let kerberos_service_ticket = kerberos_service_ticket.into();
if kerberos_service_ticket.is_empty() {
self.kerberos_service_ticket = None;
} else {
self.kerberos_service_ticket = Some(kerberos_service_ticket);
self.username = None;
self.password = None;
self.saml_assertion = None;
self.jwt = None;
}
self
}
pub fn saml_assertion<T>(mut self, saml_assertion: T) -> Self
where
T: Into<Cow<'a, str>>,
{
let saml_assertion = saml_assertion.into();
if saml_assertion.is_empty() {
self.saml_assertion = None;
} else {
self.saml_assertion = Some(saml_assertion);
self.username = None;
self.password = None;
self.jwt = None;
self.kerberos_service_ticket = None;
}
self
}
pub fn jwt<T>(mut self, jwt: T) -> Self
where
T: Into<Cow<'a, str>>,
{
let jwt = jwt.into();
if jwt.is_empty() {
self.jwt = None;
} else {
self.jwt = Some(jwt);
self.username = None;
self.password = None;
self.saml_assertion = None;
self.kerberos_service_ticket = None;
}
self
}
/// Set the TLS client configuration required by the TLS `establish*`
/// methods.
#[cfg(feature = "sync-tls")]
pub fn tls_config(self, config: impl Into<std::sync::Arc<rustls::ClientConfig>>) -> Self {
    Self {
        tls_config: Some(config.into()),
        ..self
    }
}

/// Set the server name handed to rustls by the TLS `establish*` methods.
#[cfg(feature = "sync-tls")]
pub fn server_name(self, server_name: &str) -> Self {
    Self {
        server_name: Some(server_name.to_owned()),
        ..self
    }
}
/// Connect to `address` over plain TCP and negotiate an association.
///
/// The socket timeouts configured on these options are applied to the
/// connection.
pub fn establish<A: ToSocketAddrs>(
    self,
    address: A,
) -> Result<ClientAssociation<std::net::TcpStream>> {
    let target = AeAddr::new_socket_addr(address);
    tcp_connection(&target, &self.socket_options)
        .and_then(|stream| self.establish_impl(target, stream))
}
/// Connect to `address` over TLS and negotiate an association.
///
/// # Errors
///
/// Fails with `TlsConfigMissing` unless both `tls_config` and `server_name`
/// were set. It also fails on any connection or negotiation error.
#[cfg(feature = "sync-tls")]
pub fn establish_tls<A: ToSocketAddrs>(
    self,
    address: A,
) -> Result<ClientAssociation<TlsStream>> {
    let (Some(tls_config), Some(server_name)) = (&self.tls_config, &self.server_name) else {
        // `fail()` already yields the function's `Result` type; no `?` needed.
        return super::TlsConfigMissingSnafu.fail();
    };
    let addr = AeAddr::new_socket_addr(address);
    let socket = tls_connection(&addr, server_name, &self.socket_options, tls_config.clone())?;
    self.establish_impl(addr, socket)
}
/// Connect over plain TCP to the node named by `ae_address` and negotiate
/// an association.
///
/// `ae_address` is first converted into a full AE address. If that
/// conversion fails, the string is used as a plain socket address.
// The fallback arm may be unreachable for an infallible conversion.
#[allow(unreachable_patterns)]
pub fn establish_with(self, ae_address: &str) -> Result<ClientAssociation<TcpStream>> {
    match ae_address.try_into() {
        Ok(parsed) => {
            let stream = tcp_connection(&parsed, &self.socket_options)?;
            self.establish_impl(parsed, stream)
        }
        Err(_) => {
            let fallback = AeAddr::new_socket_addr(ae_address);
            tcp_connection(&fallback, &self.socket_options)
                .and_then(|stream| self.establish_impl(fallback, stream))
        }
    }
}
/// Connect over TLS to the node named by `ae_address` and negotiate an
/// association.
///
/// `ae_address` is first converted into a full AE address. If that
/// conversion fails, the string is used as a plain socket address.
///
/// # Errors
///
/// Fails with `TlsConfigMissing` unless both `tls_config` and `server_name`
/// were set. It also fails on any connection or negotiation error.
// The fallback arm may be unreachable for an infallible conversion.
#[allow(unreachable_patterns)]
#[cfg(feature = "sync-tls")]
pub fn establish_with_tls(self, ae_address: &str) -> Result<ClientAssociation<TlsStream>> {
    let (Some(tls_config), Some(server_name)) = (&self.tls_config, &self.server_name) else {
        // `fail()` already yields the function's `Result` type; no `?` needed.
        return super::TlsConfigMissingSnafu.fail();
    };
    match ae_address.try_into() {
        Ok(ae_address) => {
            let socket = tls_connection(
                &ae_address,
                server_name,
                &self.socket_options,
                tls_config.clone(),
            )?;
            self.establish_impl(ae_address, socket)
        }
        Err(_) => {
            let addr = AeAddr::new_socket_addr(ae_address);
            let socket = tls_connection(
                &addr,
                server_name,
                &self.socket_options,
                tls_config.clone(),
            )?;
            self.establish_impl(addr, socket)
        }
    }
}
pub fn read_timeout(self, timeout: Duration) -> Self {
Self {
socket_options: SocketOptions {
read_timeout: Some(timeout),
write_timeout: self.socket_options.write_timeout,
connection_timeout: self.socket_options.connection_timeout,
},
..self
}
}
pub fn write_timeout(self, timeout: Duration) -> Self {
Self {
socket_options: SocketOptions {
read_timeout: self.socket_options.read_timeout,
write_timeout: Some(timeout),
connection_timeout: self.socket_options.connection_timeout,
},
..self
}
}
pub fn connection_timeout(self, timeout: Duration) -> Self {
Self {
socket_options: SocketOptions {
read_timeout: self.socket_options.read_timeout,
write_timeout: self.socket_options.write_timeout,
connection_timeout: Some(timeout),
},
..self
}
}
/// Build the A-ASSOCIATE-RQ PDU for these options.
///
/// `ae_title` is the AE title carried by the target address, if any. The
/// `called_ae_title` option takes precedence over it (with a warning when
/// the two differ), and `"ANY-SCP"` is used when neither is set.
///
/// Returns the presentation contexts as proposed, with their assigned IDs,
/// so the reply can be matched against them. The PDU is returned alongside.
///
/// # Errors
///
/// Fails if no presentation context was configured.
// `&'a self` is required: the credential strings are borrowed from `self`
// and passed to `determine_user_identity`, which wants `Into<Cow<'a, str>>`.
fn create_a_associate_req(
    &'a self,
    ae_title: Option<&str>,
) -> Result<(Vec<PresentationContextProposed>, Pdu)> {
    let ClientAssociationOptions {
        calling_ae_title,
        called_ae_title,
        application_context_name,
        presentation_contexts,
        protocol_version,
        max_pdu_length,
        username,
        password,
        kerberos_service_ticket,
        saml_assertion,
        jwt,
        ..
    } = self;
    ensure!(
        !presentation_contexts.is_empty(),
        crate::association::MissingAbstractSyntaxSnafu
    );
    let called_ae_title: &str = match (&called_ae_title, ae_title) {
        (Some(aec), Some(aet)) => {
            if aec != aet {
                tracing::warn!(
                    "Option `called_ae_title` overrides the AE title from `{aet}` to `{aec}`"
                );
            }
            aec
        }
        (Some(aec), None) => aec,
        (None, Some(aec)) => aec,
        (None, None) => "ANY-SCP",
    };
    // Presentation context IDs are odd numbers that must fit in one byte
    // (1, 3, ..., 255), so at most 128 contexts can be proposed. Computing
    // `(2 * i + 1) as u8` past that point would wrap around and yield
    // duplicate IDs, which would make the accepted contexts in the reply map
    // to the wrong abstract syntax. Extra contexts are left out with a warning.
    const MAX_PRESENTATION_CONTEXTS: usize = 128;
    if presentation_contexts.len() > MAX_PRESENTATION_CONTEXTS {
        tracing::warn!(
            "Only the first {} of {} presentation contexts can be proposed",
            MAX_PRESENTATION_CONTEXTS,
            presentation_contexts.len()
        );
    }
    let presentation_contexts_proposed: Vec<_> = presentation_contexts
        .iter()
        // 1, 3, ..., 255: zipping stops after the 128th context.
        .zip((1..=u8::MAX).step_by(2))
        .map(
            |((abstract_syntax, transfer_syntaxes), id)| PresentationContextProposed {
                id,
                abstract_syntax: abstract_syntax.to_string(),
                transfer_syntaxes: transfer_syntaxes
                    .iter()
                    .map(|uid| uid.to_string())
                    .collect(),
            },
        )
        .collect();
    let mut user_variables = vec![
        UserVariableItem::MaxLength(*max_pdu_length),
        UserVariableItem::ImplementationClassUID(IMPLEMENTATION_CLASS_UID.to_string()),
        UserVariableItem::ImplementationVersionName(IMPLEMENTATION_VERSION_NAME.to_string()),
    ];
    if let Some(user_identity) = Self::determine_user_identity(
        username.as_deref(),
        password.as_deref(),
        kerberos_service_ticket.as_deref(),
        saml_assertion.as_deref(),
        jwt.as_deref(),
    ) {
        user_variables.push(UserVariableItem::UserIdentityItem(user_identity));
    }
    Ok((
        presentation_contexts_proposed.clone(),
        Pdu::AssociationRQ(AssociationRQ {
            protocol_version: *protocol_version,
            calling_ae_title: calling_ae_title.to_string(),
            called_ae_title: called_ae_title.to_string(),
            application_context_name: application_context_name.to_string(),
            presentation_contexts: presentation_contexts_proposed,
            user_variables,
        }),
    ))
}
/// Interpret the peer's reply to our A-ASSOCIATE-RQ.
///
/// An A-ASSOCIATE-AC is accepted when its protocol version matches ours and
/// at least one presentation context we proposed was accepted. The result
/// holds:
/// - those contexts, with the abstract syntax taken from our proposal;
/// - the acceptor's maximum PDU length;
/// - its user variables;
/// - the AE title from the reply.
///
/// # Errors
///
/// - protocol version mismatch;
/// - no accepted presentation context;
/// - association rejected (A-ASSOCIATE-RJ);
/// - any other PDU (unexpected or unknown).
fn process_a_association_resp(
    &self,
    msg: Pdu,
    presentation_contexts_proposed: &[PresentationContextProposed],
) -> Result<NegotiatedOptions> {
    match msg {
        Pdu::AssociationAC(AssociationAC {
            protocol_version: protocol_version_scp,
            application_context_name: _,
            presentation_contexts: presentation_contexts_scp,
            calling_ae_title: _,
            called_ae_title,
            user_variables,
        }) => {
            ensure!(
                self.protocol_version == protocol_version_scp,
                crate::association::ProtocolVersionMismatchSnafu {
                    expected: self.protocol_version,
                    got: protocol_version_scp,
                }
            );
            // A missing MaxLength item means the default length; an explicit
            // 0 means "no limit", represented here as `u32::MAX`.
            let acceptor_max_pdu_length = user_variables
                .iter()
                .find_map(|item| match item {
                    UserVariableItem::MaxLength(len) => Some(*len),
                    _ => None,
                })
                .unwrap_or(DEFAULT_MAX_PDU);
            let acceptor_max_pdu_length = if acceptor_max_pdu_length == 0 {
                u32::MAX
            } else {
                acceptor_max_pdu_length
            };
            // Keep accepted contexts that match one of our proposals (by ID);
            // one lookup per context, and no `unwrap`.
            let presentation_contexts: Vec<_> = presentation_contexts_scp
                .into_iter()
                .filter(|c| c.reason == PresentationContextResultReason::Acceptance)
                .filter_map(|c| {
                    presentation_contexts_proposed
                        .iter()
                        .find(|pc| pc.id == c.id)
                        .map(|pcp| PresentationContextNegotiated {
                            id: c.id,
                            reason: c.reason,
                            transfer_syntax: c.transfer_syntax,
                            abstract_syntax: pcp.abstract_syntax.clone(),
                        })
                })
                .collect();
            ensure!(
                !presentation_contexts.is_empty(),
                crate::association::NoAcceptedPresentationContextsSnafu
            );
            Ok(NegotiatedOptions {
                presentation_contexts,
                peer_max_pdu_length: acceptor_max_pdu_length,
                user_variables,
                peer_ae_title: called_ae_title,
            })
        }
        Pdu::AssociationRJ(association_rj) => {
            crate::association::RejectedSnafu { association_rj }.fail()
        }
        pdu @ Pdu::AbortRQ { .. }
        | pdu @ Pdu::ReleaseRQ
        | pdu @ Pdu::AssociationRQ { .. }
        | pdu @ Pdu::PData { .. }
        | pdu @ Pdu::ReleaseRP => crate::association::UnexpectedPduSnafu { pdu }.fail(),
        pdu @ Pdu::Unknown { .. } => crate::association::UnknownPduSnafu { pdu }.fail(),
    }
}
fn establish_impl<T, S>(
self,
ae_address: AeAddr<T>,
mut socket: S,
) -> Result<ClientAssociation<S>>
where
T: ToSocketAddrs,
S: CloseSocket + std::io::Read + std::io::Write,
{
let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?;
let mut buffer: Vec<u8> = Vec::with_capacity(self.max_pdu_length as usize);
write_pdu(&mut buffer, &a_associate).context(super::SendPduSnafu)?;
socket.write_all(&buffer).context(super::WireSendSnafu)?;
buffer.clear();
let mut buf = BytesMut::with_capacity(
(self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
);
let resp = read_pdu_from_wire(&mut socket, &mut buf, self.max_pdu_length, self.strict)?;
let negotiated_options = self.process_a_association_resp(resp, &pc_proposed);
match negotiated_options {
Err(e) => {
let _ = write_pdu(
&mut buffer,
&Pdu::AbortRQ {
source: AbortRQSource::ServiceUser,
},
);
let _ = socket.write_all(&buffer);
buffer.clear();
Err(e)
}
Ok(NegotiatedOptions {
presentation_contexts,
peer_max_pdu_length,
user_variables,
peer_ae_title,
}) => {
Ok(ClientAssociation {
presentation_contexts,
requestor_max_pdu_length: self.max_pdu_length,
acceptor_max_pdu_length: peer_max_pdu_length,
socket,
write_buffer: buffer,
strict: self.strict,
read_buffer: buf,
read_timeout: self.socket_options.read_timeout,
write_timeout: self.socket_options.write_timeout,
user_variables,
peer_ae_title,
})
}
}
}
/// Build the user identity item from the configured credentials.
///
/// Precedence is:
/// 1. username, combined with the password when one is present;
/// 2. Kerberos service ticket;
/// 3. SAML assertion;
/// 4. JWT.
///
/// A password without a username is ignored. Returns `None` when no
/// credential is set. The first argument passed to `UserIdentity::new` is
/// always `false`.
/// NOTE(review): presumably the "positive response requested" flag — confirm.
fn determine_user_identity<T>(
    username: Option<T>,
    password: Option<T>,
    kerberos_service_ticket: Option<T>,
    saml_assertion: Option<T>,
    jwt: Option<T>,
) -> Option<UserIdentity>
where
    T: Into<Cow<'a, str>>,
{
    /// Raw bytes of a credential as carried in the PDU.
    fn bytes_of<'s>(value: impl Into<Cow<'s, str>>) -> Vec<u8> {
        value.into().as_bytes().to_vec()
    }

    if let Some(user) = username {
        let user_bytes = bytes_of(user);
        let (kind, secondary) = match password {
            Some(pass) => (UserIdentityType::UsernamePassword, bytes_of(pass)),
            None => (UserIdentityType::Username, Vec::new()),
        };
        return Some(UserIdentity::new(false, kind, user_bytes, secondary));
    }

    let (kind, primary) = if let Some(ticket) = kerberos_service_ticket {
        (UserIdentityType::KerberosServiceTicket, ticket)
    } else if let Some(assertion) = saml_assertion {
        (UserIdentityType::SamlAssertion, assertion)
    } else if let Some(token) = jwt {
        (UserIdentityType::Jwt, token)
    } else {
        return None;
    };
    Some(UserIdentity::new(false, kind, bytes_of(primary), Vec::new()))
}
}
/// An established association seen from the requestor (client) side, over a
/// blocking stream `S`, such as a plain `TcpStream` or a TLS stream.
#[derive(Debug)]
pub struct ClientAssociation<S> {
    /// Presentation contexts accepted by the peer.
    presentation_contexts: Vec<PresentationContextNegotiated>,
    /// Maximum PDU length announced by this node; limits incoming PDUs.
    requestor_max_pdu_length: u32,
    /// Maximum PDU length announced by the acceptor (`u32::MAX` if it sent
    /// 0); limits outgoing PDUs.
    acceptor_max_pdu_length: u32,
    /// Underlying stream.
    socket: S,
    /// Reusable buffer for encoding outgoing PDUs.
    write_buffer: Vec<u8>,
    /// Flag forwarded to the PDU reader.
    strict: bool,
    /// Configured read timeout, reported by `read_timeout()`.
    read_timeout: Option<Duration>,
    /// Configured write timeout, reported by `write_timeout()`.
    write_timeout: Option<Duration>,
    /// Buffer of bytes read from the stream but not yet decoded.
    read_buffer: BytesMut,
    /// User variables received in the acceptor's A-ASSOCIATE-AC.
    user_variables: Vec<UserVariableItem>,
    /// AE title reported by the acceptor.
    peer_ae_title: String,
}
/// Client side of an association: this node is the requestor and the peer
/// is the acceptor.
impl<S> Association for ClientAssociation<S>
where
    S: CloseSocket + std::io::Read + std::io::Write,
{
    fn peer_ae_title(&self) -> &str {
        self.peer_ae_title.as_str()
    }

    fn presentation_contexts(&self) -> &[PresentationContextNegotiated] {
        self.presentation_contexts.as_slice()
    }

    fn user_variables(&self) -> &[UserVariableItem] {
        self.user_variables.as_slice()
    }

    fn requestor_max_pdu_length(&self) -> u32 {
        self.requestor_max_pdu_length
    }

    fn acceptor_max_pdu_length(&self) -> u32 {
        self.acceptor_max_pdu_length
    }

    // Locally we are the requestor...
    fn local_max_pdu_length(&self) -> u32 {
        self.requestor_max_pdu_length
    }

    // ...and the peer is the acceptor.
    fn peer_max_pdu_length(&self) -> u32 {
        self.acceptor_max_pdu_length
    }
}
/// Read-only accessors for the negotiated association parameters.
impl<S> ClientAssociation<S>
where
    S: CloseSocket + std::io::Read + std::io::Write,
{
    /// Presentation contexts accepted by the peer.
    pub fn presentation_contexts(&self) -> &[PresentationContextNegotiated] {
        self.presentation_contexts.as_slice()
    }

    /// User variables received from the acceptor.
    pub fn user_variables(&self) -> &[UserVariableItem] {
        self.user_variables.as_slice()
    }

    /// Maximum PDU length announced by this node.
    pub fn requestor_max_pdu_length(&self) -> u32 {
        self.requestor_max_pdu_length
    }

    /// Maximum PDU length announced by the acceptor.
    pub fn acceptor_max_pdu_length(&self) -> u32 {
        self.acceptor_max_pdu_length
    }

    /// Configured read timeout.
    pub fn read_timeout(&self) -> Option<Duration> {
        self.read_timeout
    }

    /// Configured write timeout.
    pub fn write_timeout(&self) -> Option<Duration> {
        self.write_timeout
    }
}
/// Inherent shortcuts for the `SyncAssociation` operations, callable without
/// importing the trait.
impl<S> ClientAssociation<S>
where
    S: CloseSocket + std::io::Read + std::io::Write,
{
    /// Send a PDU to the peer (delegates to `SyncAssociation::send`).
    pub fn send(&mut self, pdu: &Pdu) -> Result<()> {
        SyncAssociation::send(self, pdu)
    }
    /// Read the next PDU from the peer (delegates to `SyncAssociation::receive`).
    pub fn receive(&mut self) -> Result<Pdu> {
        SyncAssociation::receive(self)
    }
    /// Obtain a `PDataWriter` for the given presentation context (delegates
    /// to `SyncAssociation::send_pdata`).
    pub fn send_pdata(
        &mut self,
        presentation_context_id: u8,
    ) -> crate::association::pdata::PDataWriter<&mut S> {
        SyncAssociation::send_pdata(self, presentation_context_id)
    }
    /// Release the association, consuming it (delegates to
    /// `SyncAssociation::release`).
    pub fn release(self) -> Result<()> {
        SyncAssociation::release(self)
    }
    /// Abort the association, consuming it (delegates to
    /// `SyncAssociation::abort`).
    pub fn abort(self) -> Result<()> {
        SyncAssociation::abort(self)
    }
    /// Obtain a `PDataReader` over incoming data (delegates to
    /// `SyncAssociation::receive_pdata`).
    pub fn receive_pdata(&mut self) -> crate::association::pdata::PDataReader<'_, &mut S> {
        SyncAssociation::receive_pdata(self)
    }
    /// Mutable access to the underlying stream.
    pub fn inner_stream(&mut self) -> &mut S {
        SyncAssociation::inner_stream(self)
    }
}
/// Low-level PDU I/O for the blocking client association.
impl<S> SyncAssociationSealed<S> for ClientAssociation<S>
where
    S: CloseSocket + std::io::Read + std::io::Write,
{
    /// Encode `pdu` (bounded by the acceptor's maximum PDU length plus the
    /// PDU header) and write it to the socket.
    fn send(&mut self, pdu: &Pdu) -> Result<()> {
        self.write_buffer.clear();
        // `acceptor_max_pdu_length` is `u32::MAX` when the acceptor announced
        // 0 ("no limit"). Plain `+` would overflow there: a panic in debug
        // builds, or a tiny wrapped limit in release builds that rejects
        // every PDU. Saturate instead.
        encode_pdu(
            &mut self.write_buffer,
            pdu,
            self.acceptor_max_pdu_length.saturating_add(PDU_HEADER_SIZE),
        )?;
        self.socket
            .write_all(&self.write_buffer)
            .context(super::WireSendSnafu)
    }

    /// Read the next PDU, limited by our own (requestor) maximum PDU length.
    fn receive(&mut self) -> Result<Pdu> {
        read_pdu_from_wire(
            &mut self.socket,
            &mut self.read_buffer,
            self.requestor_max_pdu_length,
            self.strict,
        )
    }

    /// Close the underlying socket.
    fn close(&mut self) -> std::io::Result<()> {
        self.socket.close()
    }
}
/// Stream access for the blocking client association.
impl<S> SyncAssociation<S> for ClientAssociation<S>
where
    S: CloseSocket + std::io::Read + std::io::Write,
{
    fn inner_stream(&mut self) -> &mut S {
        &mut self.socket
    }

    /// Borrow the stream and the pending read buffer at the same time.
    fn get_mut(&mut self) -> (&mut S, &mut BytesMut) {
        (&mut self.socket, &mut self.read_buffer)
    }
}
/// Deprecated release operation, kept for backward compatibility.
/// Call `SyncAssociation::release` instead.
#[deprecated(since = "0.9.1", note = "Call `SyncAssociation::release` instead")]
pub trait Release {
    /// Release the association.
    #[deprecated(since = "0.9.1", note = "Call `SyncAssociation::release` instead")]
    fn release(&mut self) -> Result<()>;
}
/// Legacy implementation for plain TCP associations; forwards to
/// `SyncAssociationSealed::release`.
#[allow(deprecated)]
impl Release for ClientAssociation<std::net::TcpStream> {
    fn release(&mut self) -> Result<()> {
        SyncAssociationSealed::release(self)
    }
}
/// Open an async TCP connection to `ae_address`.
///
/// The whole connect operation is bounded by `opts.connection_timeout`,
/// when set.
#[cfg(feature = "async")]
pub(crate) async fn async_connection<T>(
    ae_address: &AeAddr<T>,
    opts: &SocketOptions,
) -> Result<tokio::net::TcpStream>
where
    T: tokio::net::ToSocketAddrs,
{
    let connect = async {
        tokio::net::TcpStream::connect(ae_address.socket_addr())
            .await
            .context(crate::association::ConnectSnafu)
    };
    super::timeout(opts.connection_timeout, connect).await
}
/// Open an async TLS-over-TCP connection to `ae_address` and perform the TLS
/// handshake.
///
/// `server_name` is validated before connecting, so a malformed name fails
/// fast without opening a TCP connection. This matches the blocking
/// `tls_connection`.
#[cfg(feature = "async-tls")]
pub(crate) async fn async_tls_connection<T>(
    ae_address: &AeAddr<T>,
    server_name: &str,
    opts: &SocketOptions,
    tls_config: std::sync::Arc<rustls::ClientConfig>,
) -> Result<AsyncTlsStream>
where
    T: tokio::net::ToSocketAddrs,
{
    use rustls::pki_types::ServerName;
    use std::convert::TryFrom;
    // Check the configuration before touching the network.
    let domain = ServerName::try_from(server_name.to_string())
        .context(crate::association::InvalidServerNameSnafu)?;
    let connector = tokio_rustls::TlsConnector::from(tls_config);
    let tcp_stream = async_connection(ae_address, opts).await?;
    let tls_stream = connector
        .connect(domain, tcp_stream)
        .await
        .context(crate::association::ConnectSnafu)?;
    Ok(tls_stream)
}
/// An established association seen from the requestor (client) side, over
/// an async stream `S`, such as a tokio `TcpStream` or a TLS stream.
#[cfg(feature = "async")]
#[derive(Debug)]
pub struct AsyncClientAssociation<S> {
    /// Presentation contexts accepted by the peer.
    presentation_contexts: Vec<PresentationContextNegotiated>,
    /// Maximum PDU length announced by this node; limits incoming PDUs.
    requestor_max_pdu_length: u32,
    /// Maximum PDU length announced by the acceptor (`u32::MAX` if it sent
    /// 0); limits outgoing PDUs.
    acceptor_max_pdu_length: u32,
    /// Underlying async stream.
    socket: S,
    /// Reusable buffer for encoding outgoing PDUs.
    write_buffer: Vec<u8>,
    /// Flag forwarded to the PDU reader.
    strict: bool,
    /// Timeout applied to each receive.
    read_timeout: Option<Duration>,
    /// Timeout applied to each send.
    write_timeout: Option<Duration>,
    /// Buffer of bytes read from the stream but not yet decoded.
    read_buffer: BytesMut,
    /// User variables received in the acceptor's A-ASSOCIATE-AC.
    user_variables: Vec<UserVariableItem>,
    /// AE title reported by the acceptor.
    peer_ae_title: String,
}
#[cfg(feature = "async")]
impl<'a> ClientAssociationOptions<'a> {
/// Run the association handshake over an already connected async stream.
///
/// Sends the A-ASSOCIATE-RQ built from these options (bounded by the write
/// timeout) and reads the reply (bounded by the read timeout). If the
/// association is accepted, `socket` is wrapped into an
/// `AsyncClientAssociation`. If the reply cannot be accepted, an A-ABORT is
/// sent on a best-effort basis and the negotiation error is returned. A
/// failure while sending the abort no longer replaces that error, which
/// matches the blocking path.
async fn establish_impl_async<T, S>(
    self,
    ae_address: AeAddr<T>,
    mut socket: S,
) -> Result<AsyncClientAssociation<S>>
where
    T: tokio::net::ToSocketAddrs,
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
{
    use tokio::io::AsyncWriteExt;
    let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?;
    let mut write_buffer: Vec<u8> = Vec::with_capacity(DEFAULT_MAX_PDU as usize);
    write_pdu(&mut write_buffer, &a_associate).context(crate::association::SendPduSnafu)?;
    super::timeout(self.socket_options.write_timeout, async {
        socket
            .write_all(&write_buffer)
            .await
            .context(crate::association::WireSendSnafu)?;
        Ok(())
    })
    .await?;
    write_buffer.clear();
    let mut read_buffer = BytesMut::with_capacity(
        (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
    );
    let resp = super::timeout(self.socket_options.read_timeout, async {
        super::read_pdu_from_wire_async(
            &mut socket,
            &mut read_buffer,
            self.max_pdu_length,
            self.strict,
        )
        .await
    })
    .await?;
    let negotiated_options = self.process_a_association_resp(resp, &pc_proposed);
    match negotiated_options {
        Err(e) => {
            // Best-effort abort, bounded by the write timeout so a stalled
            // peer cannot hang us; any failure here is deliberately ignored
            // so that the negotiation error `e` reaches the caller.
            let _ = write_pdu(
                &mut write_buffer,
                &Pdu::AbortRQ {
                    source: AbortRQSource::ServiceUser,
                },
            );
            let _ = super::timeout(self.socket_options.write_timeout, async {
                socket
                    .write_all(&write_buffer)
                    .await
                    .context(crate::association::WireSendSnafu)
            })
            .await;
            Err(e)
        }
        Ok(NegotiatedOptions {
            presentation_contexts,
            peer_max_pdu_length,
            user_variables,
            peer_ae_title,
        }) => Ok(AsyncClientAssociation {
            presentation_contexts,
            requestor_max_pdu_length: self.max_pdu_length,
            acceptor_max_pdu_length: peer_max_pdu_length,
            socket,
            write_buffer,
            strict: self.strict,
            read_buffer,
            read_timeout: self.socket_options.read_timeout,
            write_timeout: self.socket_options.write_timeout,
            user_variables,
            peer_ae_title,
        }),
    }
}
/// Connect to `address` over async TCP and negotiate an association.
///
/// The connection timeout bounds the connect, and the read/write timeouts
/// bound the handshake.
pub async fn establish_async<A: tokio::net::ToSocketAddrs>(
    self,
    address: A,
) -> Result<AsyncClientAssociation<tokio::net::TcpStream>> {
    let target = AeAddr::new_socket_addr(address);
    let stream = async_connection(&target, &self.socket_options).await?;
    self.establish_impl_async(target, stream).await
}
/// Connect to `address` over async TLS and negotiate an association.
///
/// # Errors
///
/// Fails with `TlsConfigMissing` unless both `tls_config` and `server_name`
/// were set. It also fails on any connection or negotiation error.
#[cfg(feature = "async-tls")]
pub async fn establish_tls_async<A: tokio::net::ToSocketAddrs>(
    self,
    address: A,
) -> Result<AsyncClientAssociation<AsyncTlsStream>> {
    let (Some(tls_config), Some(server_name)) = (&self.tls_config, &self.server_name) else {
        // `fail()` already yields the function's `Result` type; no `?` needed.
        return crate::association::TlsConfigMissingSnafu.fail();
    };
    let addr = AeAddr::new_socket_addr(address);
    let socket = async_tls_connection(
        &addr,
        server_name,
        &self.socket_options,
        tls_config.clone(),
    )
    .await?;
    self.establish_impl_async(addr, socket).await
}
/// Connect over async TCP to the node named by `ae_address` and negotiate an
/// association.
///
/// `ae_address` is first converted into a full AE address. If that
/// conversion fails, the string is used as a plain socket address.
// The fallback arm may be unreachable for an infallible conversion.
#[allow(unreachable_patterns)]
pub async fn establish_with_async(
    self,
    ae_address: &str,
) -> Result<AsyncClientAssociation<tokio::net::TcpStream>> {
    match ae_address.try_into() {
        Ok(parsed) => {
            let stream = async_connection(&parsed, &self.socket_options).await?;
            self.establish_impl_async(parsed, stream).await
        }
        Err(_) => {
            let fallback = AeAddr::new_socket_addr(ae_address);
            let stream = async_connection(&fallback, &self.socket_options).await?;
            self.establish_impl_async(fallback, stream).await
        }
    }
}
/// Connect over async TLS to the node named by `ae_address` and negotiate an
/// association.
///
/// `ae_address` is first converted into a full AE address. If that
/// conversion fails, the string is used as a plain socket address.
///
/// # Errors
///
/// Fails with `TlsConfigMissing` unless both `tls_config` and `server_name`
/// were set. It also fails on any connection or negotiation error.
#[cfg(feature = "async-tls")]
// The fallback arm may be unreachable for an infallible conversion.
#[allow(unreachable_patterns)]
pub async fn establish_with_async_tls(
    self,
    ae_address: &str,
) -> Result<AsyncClientAssociation<AsyncTlsStream>> {
    let (Some(tls_config), Some(server_name)) = (&self.tls_config, &self.server_name) else {
        // `fail()` already yields the function's `Result` type; no `?` needed.
        return crate::association::TlsConfigMissingSnafu.fail();
    };
    match ae_address.try_into() {
        Ok(ae_address) => {
            let socket = async_tls_connection(
                &ae_address,
                server_name,
                &self.socket_options,
                tls_config.clone(),
            )
            .await?;
            self.establish_impl_async(ae_address, socket).await
        }
        Err(_) => {
            let addr = AeAddr::new_socket_addr(ae_address);
            let socket = async_tls_connection(
                &addr,
                server_name,
                &self.socket_options,
                tls_config.clone(),
            )
            .await?;
            self.establish_impl_async(addr, socket).await
        }
    }
}
}
/// Client side of an async association: this node is the requestor and the
/// peer is the acceptor.
#[cfg(feature = "async")]
impl<S> Association for AsyncClientAssociation<S> {
    fn peer_ae_title(&self) -> &str {
        self.peer_ae_title.as_str()
    }

    fn presentation_contexts(&self) -> &[PresentationContextNegotiated] {
        self.presentation_contexts.as_slice()
    }

    fn user_variables(&self) -> &[UserVariableItem] {
        self.user_variables.as_slice()
    }

    fn requestor_max_pdu_length(&self) -> u32 {
        self.requestor_max_pdu_length
    }

    fn acceptor_max_pdu_length(&self) -> u32 {
        self.acceptor_max_pdu_length
    }

    // Locally we are the requestor...
    fn local_max_pdu_length(&self) -> u32 {
        self.requestor_max_pdu_length
    }

    // ...and the peer is the acceptor.
    fn peer_max_pdu_length(&self) -> u32 {
        self.acceptor_max_pdu_length
    }
}
/// Read-only accessors for the negotiated async association parameters.
#[cfg(feature = "async")]
impl<S> AsyncClientAssociation<S> {
    /// Presentation contexts accepted by the peer.
    pub fn presentation_contexts(&self) -> &[PresentationContextNegotiated] {
        self.presentation_contexts.as_slice()
    }

    /// User variables received from the acceptor.
    pub fn user_variables(&self) -> &[UserVariableItem] {
        self.user_variables.as_slice()
    }

    /// Maximum PDU length announced by this node.
    pub fn requestor_max_pdu_length(&self) -> u32 {
        self.requestor_max_pdu_length
    }

    /// Maximum PDU length announced by the acceptor.
    pub fn acceptor_max_pdu_length(&self) -> u32 {
        self.acceptor_max_pdu_length
    }

    /// Timeout applied to each receive.
    pub fn read_timeout(&self) -> Option<Duration> {
        self.read_timeout
    }

    /// Timeout applied to each send.
    pub fn write_timeout(&self) -> Option<Duration> {
        self.write_timeout
    }
}
/// Inherent shortcuts for the `AsyncAssociation` operations, callable without
/// importing the trait.
#[cfg(feature = "async")]
impl<S> AsyncClientAssociation<S>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
{
    /// Mutable access to the underlying stream.
    pub fn inner_stream(&mut self) -> &mut S {
        AsyncAssociation::inner_stream(self)
    }
    /// Send a PDU to the peer (delegates to `AsyncAssociation::send`).
    pub async fn send(&mut self, msg: &Pdu) -> Result<()> {
        AsyncAssociation::send(self, msg).await
    }
    /// Read the next PDU from the peer (delegates to `AsyncAssociation::receive`).
    pub async fn receive(&mut self) -> Result<Pdu> {
        AsyncAssociation::receive(self).await
    }
    /// Release the association, consuming it (delegates to
    /// `AsyncAssociation::release`).
    pub async fn release(self) -> Result<()> {
        AsyncAssociation::release(self).await
    }
    /// Abort the association, consuming it (delegates to
    /// `AsyncAssociation::abort`).
    pub async fn abort(self) -> Result<()> {
        AsyncAssociation::abort(self).await
    }
    /// Obtain an `AsyncPDataWriter` for the given presentation context
    /// (delegates to `AsyncAssociation::send_pdata`).
    pub fn send_pdata(
        &mut self,
        presentation_context_id: u8,
    ) -> crate::association::pdata::non_blocking::AsyncPDataWriter<&mut S> {
        AsyncAssociation::send_pdata(self, presentation_context_id)
    }
    /// Obtain a `PDataReader` over incoming data (delegates to
    /// `AsyncAssociation::receive_pdata`).
    pub fn receive_pdata(&mut self) -> crate::association::pdata::PDataReader<'_, &mut S> {
        AsyncAssociation::receive_pdata(self)
    }
}
/// Low-level PDU I/O for the async client association; every send and
/// receive is bounded by the configured write/read timeout.
#[cfg(feature = "async")]
impl<S> super::private::AsyncAssociationSealed<S> for AsyncClientAssociation<S>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
{
    /// Encode `msg` (bounded by the acceptor's maximum PDU length plus the
    /// PDU header) and write it to the socket.
    async fn send(&mut self, msg: &Pdu) -> Result<()> {
        use tokio::io::AsyncWriteExt;
        self.write_buffer.clear();
        // `acceptor_max_pdu_length` is `u32::MAX` when the acceptor announced
        // 0 ("no limit"). Plain `+` would overflow there: a panic in debug
        // builds, or a tiny wrapped limit in release builds that rejects
        // every PDU. Saturate instead.
        encode_pdu(
            &mut self.write_buffer,
            msg,
            self.acceptor_max_pdu_length.saturating_add(PDU_HEADER_SIZE),
        )?;
        super::timeout(self.write_timeout, async {
            self.socket
                .write_all(&self.write_buffer)
                .await
                .context(crate::association::WireSendSnafu)
        })
        .await
    }

    /// Read the next PDU, limited by our own (requestor) maximum PDU length.
    async fn receive(&mut self) -> Result<Pdu> {
        use crate::association::read_pdu_from_wire_async;
        super::timeout(self.read_timeout, async {
            read_pdu_from_wire_async(
                &mut self.socket,
                &mut self.read_buffer,
                self.requestor_max_pdu_length,
                self.strict,
            )
            .await
        })
        .await
    }

    /// Shut down the write side of the underlying stream.
    async fn close(&mut self) -> std::io::Result<()> {
        use tokio::io::AsyncWriteExt;
        self.socket.shutdown().await
    }
}
/// Stream access for the async client association.
#[cfg(feature = "async")]
impl<S> AsyncAssociation<S> for AsyncClientAssociation<S>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
{
    fn inner_stream(&mut self) -> &mut S {
        &mut self.socket
    }

    /// Borrow the stream and the pending read buffer at the same time.
    fn get_mut(&mut self) -> (&mut S, &mut BytesMut) {
        (&mut self.socket, &mut self.read_buffer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "async")]
    use crate::association::read_pdu_from_wire_async;
    use std::io::Write;

    /// Test-only variants of the handshake that misbehave on purpose, so
    /// that tests can exercise the peer's handling of unusual traffic.
    impl<'a> ClientAssociationOptions<'a> {
        /// Like `establish`, but writes `extra_pdus` right after the
        /// A-ASSOCIATE-RQ in the same write, before waiting for the reply.
        ///
        /// # Panics
        ///
        /// Panics if the reply is not an acceptable A-ASSOCIATE-AC.
        pub(crate) fn establish_with_extra_pdus<T>(
            &self,
            ae_address: AeAddr<T>,
            extra_pdus: Vec<Pdu>,
        ) -> Result<ClientAssociation<std::net::TcpStream>>
        where
            T: ToSocketAddrs,
        {
            let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?;
            let mut socket = tcp_connection(&ae_address, &self.socket_options)?;
            let mut write_buffer: Vec<u8> = Vec::with_capacity(DEFAULT_MAX_PDU as usize);
            write_pdu(&mut write_buffer, &a_associate).context(crate::association::SendPduSnafu)?;
            // Append the extra PDUs so they go out in the same write.
            for pdu in extra_pdus {
                write_pdu(&mut write_buffer, &pdu).context(crate::association::SendPduSnafu)?;
            }
            socket
                .write_all(&write_buffer)
                .context(crate::association::WireSendSnafu)?;
            write_buffer.clear();
            let mut read_buffer = BytesMut::with_capacity(
                (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
            );
            let resp = read_pdu_from_wire(
                &mut socket,
                &mut read_buffer,
                self.max_pdu_length,
                self.strict,
            )?;
            let NegotiatedOptions {
                presentation_contexts,
                peer_max_pdu_length,
                user_variables,
                peer_ae_title,
            } = self
                .process_a_association_resp(resp, &pc_proposed)
                .expect("Failed to process a associate response");
            Ok(ClientAssociation {
                presentation_contexts,
                requestor_max_pdu_length: self.max_pdu_length,
                acceptor_max_pdu_length: peer_max_pdu_length,
                socket,
                write_buffer,
                strict: self.strict,
                read_buffer,
                read_timeout: self.socket_options.read_timeout,
                write_timeout: self.socket_options.write_timeout,
                user_variables,
                peer_ae_title,
            })
        }

        /// Async counterpart of `establish_with_extra_pdus`. Unlike
        /// `establish_async`, it applies no read/write timeouts to the
        /// handshake.
        ///
        /// # Panics
        ///
        /// Panics if the reply is not an acceptable A-ASSOCIATE-AC.
        #[cfg(feature = "async")]
        pub(crate) async fn establish_with_extra_pdus_async<T>(
            &self,
            ae_address: AeAddr<T>,
            extra_pdus: Vec<Pdu>,
        ) -> Result<AsyncClientAssociation<tokio::net::TcpStream>>
        where
            T: tokio::net::ToSocketAddrs,
        {
            use tokio::io::AsyncWriteExt;
            let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?;
            let mut socket = async_connection(&ae_address, &self.socket_options).await?;
            let mut buffer: Vec<u8> = Vec::with_capacity(DEFAULT_MAX_PDU as usize);
            write_pdu(&mut buffer, &a_associate).context(crate::association::SendPduSnafu)?;
            // Append the extra PDUs so they go out in the same write.
            for pdu in extra_pdus {
                write_pdu(&mut buffer, &pdu).context(crate::association::SendPduSnafu)?;
            }
            socket
                .write_all(&buffer)
                .await
                .context(crate::association::WireSendSnafu)?;
            buffer.clear();
            let mut buf = BytesMut::with_capacity(
                (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
            );
            let resp =
                read_pdu_from_wire_async(&mut socket, &mut buf, self.max_pdu_length, self.strict)
                    .await?;
            let NegotiatedOptions {
                presentation_contexts,
                peer_max_pdu_length,
                user_variables,
                peer_ae_title,
            } = self
                .process_a_association_resp(resp, &pc_proposed)
                .expect("Failed to process a associate response");
            Ok(AsyncClientAssociation {
                presentation_contexts,
                requestor_max_pdu_length: self.max_pdu_length,
                acceptor_max_pdu_length: peer_max_pdu_length,
                socket,
                write_buffer: buffer,
                strict: self.strict,
                read_buffer: buf,
                read_timeout: self.socket_options.read_timeout,
                write_timeout: self.socket_options.write_timeout,
                user_variables,
                peer_ae_title,
            })
        }

        /// Like `establish`, but the returned association gets a fresh,
        /// empty read buffer instead of the one used to read the reply. Any
        /// bytes already buffered past the A-ASSOCIATE-AC are therefore
        /// dropped on purpose.
        ///
        /// # Panics
        ///
        /// Panics if the reply is not an acceptable A-ASSOCIATE-AC.
        pub fn broken_establish<T>(
            &self,
            ae_address: AeAddr<T>,
        ) -> Result<ClientAssociation<std::net::TcpStream>>
        where
            T: ToSocketAddrs,
        {
            let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?;
            let mut socket = tcp_connection(&ae_address, &self.socket_options)?;
            let mut buffer: Vec<u8> = Vec::with_capacity(DEFAULT_MAX_PDU as usize);
            write_pdu(&mut buffer, &a_associate).context(crate::association::SendPduSnafu)?;
            socket
                .write_all(&buffer)
                .context(crate::association::WireSendSnafu)?;
            buffer.clear();
            let mut buf = BytesMut::with_capacity(
                (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
            );
            let resp = read_pdu_from_wire(&mut socket, &mut buf, self.max_pdu_length, self.strict)?;
            let NegotiatedOptions {
                presentation_contexts,
                peer_max_pdu_length,
                user_variables,
                peer_ae_title,
            } = self
                .process_a_association_resp(resp, &pc_proposed)
                .expect("Failed to process a associate response");
            Ok(ClientAssociation {
                presentation_contexts,
                requestor_max_pdu_length: self.max_pdu_length,
                acceptor_max_pdu_length: peer_max_pdu_length,
                socket,
                write_buffer: buffer,
                strict: self.strict,
                // Deliberately not `buf`: leftover bytes are discarded.
                read_buffer: BytesMut::with_capacity(
                    (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
                ),
                read_timeout: self.socket_options.read_timeout,
                write_timeout: self.socket_options.write_timeout,
                user_variables,
                peer_ae_title,
            })
        }

        /// Async counterpart of `broken_establish`. The returned association
        /// gets a fresh, empty read buffer, so bytes buffered past the
        /// A-ASSOCIATE-AC are dropped on purpose.
        ///
        /// # Panics
        ///
        /// Panics if the reply is not an acceptable A-ASSOCIATE-AC.
        #[cfg(feature = "async")]
        pub async fn broken_establish_async<T>(
            &self,
            ae_address: AeAddr<T>,
        ) -> Result<AsyncClientAssociation<tokio::net::TcpStream>>
        where
            T: tokio::net::ToSocketAddrs,
        {
            use tokio::io::AsyncWriteExt;
            let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?;
            let mut socket = async_connection(&ae_address, &self.socket_options).await?;
            let mut buffer: Vec<u8> = Vec::with_capacity(DEFAULT_MAX_PDU as usize);
            write_pdu(&mut buffer, &a_associate).context(crate::association::SendPduSnafu)?;
            socket
                .write_all(&buffer)
                .await
                .context(crate::association::WireSendSnafu)?;
            buffer.clear();
            let mut buf = BytesMut::with_capacity(
                (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
            );
            let resp =
                read_pdu_from_wire_async(&mut socket, &mut buf, self.max_pdu_length, self.strict)
                    .await?;
            let NegotiatedOptions {
                presentation_contexts,
                peer_max_pdu_length,
                user_variables,
                peer_ae_title,
            } = self
                .process_a_association_resp(resp, &pc_proposed)
                .expect("Failed to process a associate response");
            Ok(AsyncClientAssociation {
                presentation_contexts,
                requestor_max_pdu_length: self.max_pdu_length,
                acceptor_max_pdu_length: peer_max_pdu_length,
                socket,
                write_buffer: buffer,
                strict: self.strict,
                // Deliberately not `buf`: leftover bytes are discarded.
                read_buffer: BytesMut::with_capacity(
                    (self.max_pdu_length.min(LARGE_PDU_SIZE) + PDU_HEADER_SIZE) as usize,
                ),
                read_timeout: self.socket_options.read_timeout,
                write_timeout: self.socket_options.write_timeout,
                user_variables,
                peer_ae_title,
            })
        }
    }
}