use std::{future::Future, sync::Arc};
use futures::StreamExt;
use opcua_core::{
comms::{
buffer::SendBuffer,
secure_channel::SecureChannel,
tcp_codec::{Message, TcpCodec},
tcp_types::{AcknowledgeMessage, HelloMessage, ReverseHelloMessage},
},
sync::RwLock,
trace_read_lock, RequestMessage,
};
use opcua_crypto::SecurityPolicy;
use opcua_types::{DecodingOptions, Error, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::FramedRead;
use tracing::{debug, error};
use crate::transport::{
core::{TransportCloseState, TransportState},
state::SecureChannelState,
tcp::TransportConfiguration,
Connector, OutgoingMessage, Transport, TransportPollResult,
};
/// A freshly opened byte stream, split into a framed reader and a raw writer,
/// on which the hello/acknowledge handshake has not been performed yet.
///
/// Produced by the connector closure handed to [`StreamConnector`].
pub struct StreamConnection<R, W> {
    /// Incoming half; frames are decoded into OPC UA messages by [`TcpCodec`].
    reader: FramedRead<R, TcpCodec>,
    /// Outgoing half; encoded bytes are written to it directly.
    writer: W,
    /// URL sent in the hello message and later reported as the connected URL.
    endpoint_url: String,
}

impl<R, W> StreamConnection<R, W> {
    /// Bundle the two halves of a stream with the endpoint URL they target.
    pub fn new(reader: FramedRead<R, TcpCodec>, writer: W, endpoint_url: String) -> Self {
        Self { endpoint_url, reader, writer }
    }
}
pub struct StreamConnector<R, W, C, F> {
connector: C,
default_endpoint_url: String,
_f: std::marker::PhantomData<fn() -> F>,
_r: std::marker::PhantomData<fn() -> R>,
_w: std::marker::PhantomData<fn() -> W>,
}
impl<R, W, C, F> StreamConnector<R, W, C, F>
where
R: AsyncRead + Unpin + Send + Sync + 'static,
W: AsyncWrite + Unpin + Send + Sync + 'static,
C: Fn(String, DecodingOptions) -> F + Send + Sync,
F: Future<Output = Result<StreamConnection<R, W>, Error>> + Send + Sync,
{
pub fn new(connector: C, default_endpoint_url: String) -> Self {
Self {
connector,
default_endpoint_url,
_f: std::marker::PhantomData,
_r: std::marker::PhantomData,
_w: std::marker::PhantomData,
}
}
async fn hello_exchange(
reader: &mut FramedRead<R, TcpCodec>,
writer: &mut W,
endpoint_url: &str,
config: &TransportConfiguration,
) -> Result<AcknowledgeMessage, Error> {
let hello = HelloMessage::new(
endpoint_url,
config.send_buffer_size,
config.recv_buffer_size,
config.max_message_size,
config.max_chunk_count,
);
tracing::trace!("Send hello message: {hello:?}");
writer
.write_all(&opcua_types::SimpleBinaryEncodable::encode_to_vec(&hello))
.await
.map_err(|err| {
error!("Cannot send hello to server, err = {}", err);
Error::new(
StatusCode::BadCommunicationError,
format!("Cannot send hello to server, err = {}", err),
)
})?;
writer.flush().await.map_err(|err| {
Error::new(
StatusCode::BadCommunicationError,
format!("Cannot send hello to server, err = {}", err),
)
})?;
match reader.next().await {
Some(Ok(Message::Acknowledge(ack))) => {
if ack.send_buffer_size > hello.receive_buffer_size {
tracing::warn!("Acknowledged send buffer size is greater than receive buffer size in hello message!")
}
if ack.receive_buffer_size > hello.send_buffer_size {
tracing::warn!("Acknowledged receive buffer size is greater than send buffer size in hello message!")
}
tracing::trace!("Received acknowledgement: {:?}", ack);
Ok(ack)
}
other => {
error!(
"Unexpected error while waiting for server ACK. Expected ACK, got {:?}",
other
);
Err(Error::new(
StatusCode::BadConnectionClosed,
format!(
"Unexpected error while waiting for server ACK. Expected ACK, got {:?}",
other
),
))
}
}
}
async fn connect_inner(
&self,
secure_channel: &RwLock<SecureChannel>,
config: &TransportConfiguration,
) -> Result<(StreamConnection<R, W>, AcknowledgeMessage, SecurityPolicy), Error> {
let (decoding_options, policy) = {
let secure_channel = trace_read_lock!(secure_channel);
(
secure_channel.decoding_options(),
secure_channel.security_policy(),
)
};
let mut connection =
(self.connector)(self.default_endpoint_url.clone(), decoding_options).await?;
let ack = Self::hello_exchange(
&mut connection.reader,
&mut connection.writer,
&connection.endpoint_url,
config,
)
.await?;
Ok((connection, ack, policy))
}
}
impl<R, W, C, F> Connector for StreamConnector<R, W, C, F>
where
R: AsyncRead + Unpin + Send + Sync + 'static,
W: AsyncWrite + Unpin + Send + Sync + 'static,
C: Fn(String, DecodingOptions) -> F + Send + Sync,
F: Future<Output = Result<StreamConnection<R, W>, Error>> + Send + Sync,
{
type Transport = StreamTransport<R, W>;
async fn connect(
&self,
channel: Arc<SecureChannelState>,
outgoing_recv: tokio::sync::mpsc::Receiver<OutgoingMessage>,
config: TransportConfiguration,
) -> Result<StreamTransport<R, W>, StatusCode> {
let (connection, ack, policy) = self
.connect_inner(channel.secure_channel(), &config)
.await
.map_err(|e| e.status())?;
let mut buffer = SendBuffer::new(
config.send_buffer_size,
config.max_message_size,
config.max_chunk_count,
policy.legacy_sequence_numbers(),
);
buffer.revise(
ack.receive_buffer_size as usize,
ack.max_message_size as usize,
ack.max_chunk_count as usize,
);
Ok(StreamTransport {
state: TransportState::new(
channel,
outgoing_recv,
config.max_chunk_count,
ack.send_buffer_size.min(config.recv_buffer_size as u32) as usize,
),
read: connection.reader,
write: connection.writer,
send_buffer: buffer,
should_close: false,
closed: TransportCloseState::Open,
connected_url: connection.endpoint_url,
})
}
fn default_endpoint(&self) -> opcua_types::EndpointDescription {
opcua_types::EndpointDescription::from(self.default_endpoint_url.as_str())
}
}
/// A [`Transport`] that speaks the OPC UA binary protocol over the reader and
/// writer halves of a byte stream.
///
/// Built by [`StreamConnector`] after the hello/acknowledge handshake.
pub struct StreamTransport<R, W> {
    /// Transport bookkeeping. It is built from the secure channel state and the
    /// outgoing-message receiver, and it handles incoming messages and close.
    state: TransportState,
    /// Incoming half, decoded into [`Message`]s by [`TcpCodec`].
    read: FramedRead<R, TcpCodec>,
    /// Outgoing half that encoded chunks are written into.
    write: W,
    /// Outgoing messages waiting to be encoded into chunks and written.
    send_buffer: SendBuffer,
    /// Set when a `CloseSecureChannel` request has been queued. Once the send
    /// buffer has nothing left to write, `poll_inner` reports `Closed(Good)`.
    should_close: bool,
    /// Shutdown progress, tracked by `Transport::poll`.
    closed: TransportCloseState,
    /// Endpoint URL returned by `connected_url`.
    connected_url: String,
}

impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> StreamTransport<R, W> {
    /// Translate the result of reading the next frame into a poll result.
    ///
    /// * `None` (end of stream) closes with `BadCommunicationError`.
    /// * A read or decode error is logged and closes with `BadConnectionClosed`.
    /// * A message is passed to the transport state. If the state rejects it,
    ///   the transport closes with the returned status.
    fn handle_incoming_message(
        &mut self,
        incoming: Option<Result<Message, std::io::Error>>,
    ) -> TransportPollResult {
        let Some(incoming) = incoming else {
            return TransportPollResult::Closed(StatusCode::BadCommunicationError);
        };
        match incoming {
            Ok(message) => {
                if let Err(e) = self.state.handle_incoming_message(message) {
                    TransportPollResult::Closed(e)
                } else {
                    TransportPollResult::IncomingMessage
                }
            }
            Err(err) => {
                error!("Error reading from stream {}", err);
                TransportPollResult::Closed(StatusCode::BadConnectionClosed)
            }
        }
    }

    /// Perform one step of transport I/O.
    ///
    /// 1. If the send buffer has chunks to encode, encode one of them.
    /// 2. If encoded bytes are pending, race writing them against reading the
    ///    next incoming frame.
    /// 3. Otherwise, if a close was requested, report `Closed(Good)`.
    /// 4. Otherwise, race waiting for the next outgoing message against reading
    ///    the next incoming frame.
    async fn poll_inner(&mut self) -> TransportPollResult {
        // Encode at most one pending chunk per step. The channel read lock is
        // released at the end of this `if` block, before any `.await` below.
        if self.send_buffer.should_encode_chunks() {
            let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel());
            if let Err(e) = self.send_buffer.encode_next_chunk(&secure_channel) {
                return TransportPollResult::Closed(e);
            }
        }
        if self.send_buffer.can_read() {
            // Bytes are pending. Whichever of write/read completes first wins,
            // and the other future is dropped.
            // NOTE(review): this assumes `read_into_async` is cancel-safe
            // (i.e. no bytes are lost if it is dropped mid-write) — confirm.
            tokio::select! {
                r = self.send_buffer.read_into_async(&mut self.write) => {
                    if let Err(e) = r {
                        error!("write bytes task failed: {}", e);
                        return TransportPollResult::Closed(StatusCode::BadCommunicationError);
                    }
                    TransportPollResult::OutgoingMessageSent
                }
                incoming = self.read.next() => {
                    self.handle_incoming_message(incoming)
                }
            }
        } else {
            // Nothing left to write. If a close was requested earlier, finish now.
            if self.should_close {
                debug!("Writer is setting the connection state to finished(good)");
                return TransportPollResult::Closed(StatusCode::Good);
            }
            tokio::select! {
                outgoing = self.state.wait_for_outgoing_message(&mut self.send_buffer) => {
                    // `None`: no further outgoing message will be produced
                    // (presumably the sending side is gone), so close cleanly.
                    let Some((outgoing, request_id)) = outgoing else {
                        return TransportPollResult::Closed(StatusCode::Good);
                    };
                    // After a CloseSecureChannel request is written, the transport
                    // reports itself closed on the next idle step (see above).
                    let close_connection =
                        matches!(outgoing, RequestMessage::CloseSecureChannel(_));
                    if close_connection {
                        self.should_close = true;
                        debug!("Writer is about to send a CloseSecureChannelRequest which means it should close in a moment");
                    }
                    // Queue the request into the send buffer under the channel read lock.
                    let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel());
                    if let Err(e) = self.send_buffer.write(request_id, outgoing, &secure_channel) {
                        // Release the lock before reporting the failure to the state.
                        drop(secure_channel);
                        // Errors that identify the request fail only that request
                        // (recoverable). Any other error closes the transport.
                        if let Some((request_id, request_handle)) = e.full_context() {
                            error!("Failed to send message with request handle {}: {}", request_handle, e);
                            self.state.message_send_failed(request_id, e.status());
                            TransportPollResult::RecoverableError(e.status())
                        } else {
                            TransportPollResult::Closed(e.status())
                        }
                    } else {
                        TransportPollResult::OutgoingMessage
                    }
                }
                incoming = self.read.next() => {
                    self.handle_incoming_message(incoming)
                }
            }
        }
    }
}
impl<R, W> Transport for StreamTransport<R, W>
where
    R: AsyncRead + Unpin + Send + Sync + 'static,
    W: AsyncWrite + Unpin + Send + Sync + 'static,
{
    /// Run one transport step, and shut the transport state down once a step
    /// reports closure.
    ///
    /// `self.closed` records shutdown progress so that polling again after a
    /// close is well-defined:
    /// * `Open`: run `poll_inner`. If it returns `Closed(status)`, mark the
    ///   transport `Closing`, await `state.close(status)`, then mark it `Closed`.
    /// * `Closing`: within this file, this state is only left set if a previous
    ///   `poll` future was dropped while awaiting `state.close`. The close is
    ///   therefore run again.
    /// * `Closed`: return the stored status without doing any I/O.
    async fn poll(&mut self) -> TransportPollResult {
        match self.closed {
            TransportCloseState::Open => {}
            TransportCloseState::Closing(c) => {
                let r = self.state.close(c).await;
                self.closed = TransportCloseState::Closed(c);
                return TransportPollResult::Closed(r);
            }
            TransportCloseState::Closed(c) => {
                return TransportPollResult::Closed(c);
            }
        }
        let r = self.poll_inner().await;
        if let TransportPollResult::Closed(status) = &r {
            // Record `Closing` before awaiting, so a dropped future resumes the
            // close on the next poll.
            self.closed = TransportCloseState::Closing(*status);
            let r = self.state.close(*status).await;
            // NOTE(review): this path stores the result of `close` and returns
            // the original status. The `Closing` branch above stores the
            // original status and returns the result of `close`. Confirm
            // which is intended.
            self.closed = TransportCloseState::Closed(r);
        }
        r
    }

    /// The endpoint URL this transport is connected to.
    fn connected_url(&self) -> &str {
        &self.connected_url
    }
}
pub async fn wait_for_reverse_hello<R: AsyncRead + Unpin>(
framed_read: &mut FramedRead<R, TcpCodec>,
) -> Result<ReverseHelloMessage, Error> {
match framed_read.next().await {
Some(Ok(Message::ReverseHello(rev_hello))) => {
tracing::trace!("Received ReverseHello message: {:?}", rev_hello);
Ok(rev_hello)
}
Some(Ok(_)) => Err(Error::new(
StatusCode::BadConnectionClosed,
"Unexpected message while waiting for ReverseHello",
)),
Some(Err(err)) => Err(Error::new(
StatusCode::BadConnectionClosed,
format!("Error while waiting for ReverseHello: {}", err),
)),
None => Err(Error::new(
StatusCode::BadConnectionClosed,
"Connection closed while waiting for ReverseHello",
)),
}
}