async-opcua-client 0.18.0

OPC UA client API
Documentation
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use futures::future::Either;
use opcua_core::comms::sequence_number::SequenceNumberHandle;
use opcua_core::{trace_read_lock, RequestMessage, ResponseMessage};
use tracing::{debug, error, trace, warn};

use opcua_core::comms::buffer::SendBuffer;
use opcua_core::comms::message_chunk::MessageIsFinalType;
use opcua_core::comms::{
    chunker::Chunker, message_chunk::MessageChunk, message_chunk_info::ChunkInfo,
    tcp_codec::Message,
};
use opcua_types::{Error, StatusCode};

use crate::transport::state::SecureChannelState;
use crate::transport::RequestRecv;

/// One received chunk, after its security has been removed, stored together with
/// its decoded chunk header so chunks can be ordered before reassembly.
#[derive(Debug)]
struct MessageChunkWithChunkInfo {
    /// Decoded header information (message and sequence headers, body length).
    header: ChunkInfo,
    /// The raw chunk bytes, header included, as handed back to the chunker later.
    data_with_header: Vec<u8>,
}

/// Bookkeeping for a single request that is still waiting for its response.
pub(crate) struct MessageState {
    /// Receives either the decoded response or the status code the request failed with.
    callback: tokio::sync::oneshot::Sender<Result<ResponseMessage, StatusCode>>,
    /// Response chunks received so far for this request.
    chunks: Vec<MessageChunkWithChunkInfo>,
    /// Once this instant has passed, the request is failed with `BadTimeout`.
    deadline: Instant,
}

/// Internal state of a transport implementation.
///
/// Tracks requests that are waiting for a response, collects the incoming chunks
/// for each of them by request id, and validates the sequence numbers of incoming
/// chunks.
pub struct TransportState {
    /// Channel for outgoing requests. Will only be polled if the number of inflight requests is below the limit.
    outgoing_recv: tokio::sync::mpsc::Receiver<OutgoingMessage>,
    /// Pending requests keyed by request id: response callback, chunks received
    /// so far, and deadline.
    message_states: HashMap<u32, MessageState>,
    /// Secure channel
    pub channel_state: Arc<SecureChannelState>,
    /// Maximum number of chunks allowed in a single incoming message.
    /// `0` means no limit is enforced.
    max_chunk_count: usize,
    /// Validates (and advances) the sequence number of every incoming chunk.
    sequence_numbers: SequenceNumberHandle,
    /// Max size of incoming chunks
    // Stored for completeness but not currently read by this type.
    #[allow(unused)]
    receive_buffer_size: usize,
}

/// Where a transport is in its shutdown lifecycle.
// NOTE(review): the transitions between these states happen outside this section;
// the per-variant descriptions below are inferred from the variant names.
#[derive(Clone, Copy, Debug)]
pub(super) enum TransportCloseState {
    /// The transport is open.
    Open,
    /// The transport is in the process of closing with the given status code.
    Closing(StatusCode),
    /// The transport has closed with the given status code.
    Closed(StatusCode),
}

/// Result of polling a transport implementation.
/// This represents a single iteration of the transport event loop.
#[derive(Debug)]
pub enum TransportPollResult {
    /// An outgoing message was received and enqueued.
    OutgoingMessage,
    /// An outgoing message was sent to the server.
    OutgoingMessageSent,
    /// An incoming message was received from the server.
    IncomingMessage,
    /// An error occurred that is recoverable, so the transport can continue and
    /// simply fail the request.
    RecoverableError(StatusCode),
    /// The transport was closed with the given status code.
    Closed(StatusCode),
}

/// A request queued for the transport to send.
pub struct OutgoingMessage {
    /// The request message to put on the wire.
    pub request: RequestMessage,
    /// Channel through which the response, or the failure status, is delivered.
    /// When `None`, no pending-request entry is created, so any response that
    /// arrives for this request is ignored.
    pub callback: Option<tokio::sync::oneshot::Sender<Result<ResponseMessage, StatusCode>>>,
    /// Instant after which a still-pending request is failed with `BadTimeout`.
    pub deadline: Instant,
}

impl TransportState {
    /// Create a new transport state.
    ///
    /// Whether legacy sequence numbering is used is read from the security policy
    /// of the secure channel at construction time.
    pub fn new(
        channel_state: Arc<SecureChannelState>,
        outgoing_recv: RequestRecv,
        max_chunk_count: usize,
        receive_buffer_size: usize,
    ) -> Self {
        let legacy_sequence_numbers = channel_state
            .secure_channel()
            .read()
            .security_policy()
            .legacy_sequence_numbers();
        Self {
            channel_state,
            outgoing_recv,
            message_states: HashMap::new(),
            sequence_numbers: SequenceNumberHandle::new(legacy_sequence_numbers),
            max_chunk_count,
            receive_buffer_size,
        }
    }

    /// Wait for an outgoing message. Will also check for timed out messages.
    ///
    /// While waiting, pending requests whose deadline passes are failed with
    /// `BadTimeout`. If the received message carries a callback, a pending-request
    /// entry is registered under a freshly allocated request id so that the response
    /// can be routed back to it.
    ///
    /// Returns the request and its request id, or `None` once the outgoing channel
    /// yields no more messages.
    pub async fn wait_for_outgoing_message(
        &mut self,
        send_buffer: &mut SendBuffer,
    ) -> Option<(RequestMessage, u32)> {
        loop {
            // Fail any requests that have already timed out, and find out how long
            // until the next one does.
            let timeout_fut = match self.next_timeout() {
                Some(t) => Either::Left(tokio::time::sleep_until(t.into())),
                None => Either::Right(futures::future::pending::<()>()),
            };

            tokio::select! {
                _ = timeout_fut => {
                    // A deadline was reached; loop so `next_timeout` fails that request.
                    continue;
                }
                outgoing = self.outgoing_recv.recv() => {
                    let outgoing = outgoing?;
                    let request_id = send_buffer.next_request_id();
                    if let Some(callback) = outgoing.callback {
                        self.message_states.insert(request_id, MessageState {
                            callback,
                            chunks: Vec::new(),
                            deadline: outgoing.deadline,
                        });
                    }
                    break Some((outgoing.request, request_id));
                }
            }
        }
    }

    /// Store incoming messages in the message state.
    ///
    /// Chunks are unsecured, sequence-checked and appended to the matching pending
    /// request; a final chunk completes the response and delivers it to the waiting
    /// caller.
    ///
    /// # Errors
    /// Returns a bad status for an unexpected acknowledge, an error message sent by
    /// the server, an unrecognized message, or a chunk that could not be processed.
    pub fn handle_incoming_message(&mut self, message: Message) -> Result<(), StatusCode> {
        let status = match message {
            Message::Acknowledge(ack) => {
                debug!("Reader got an unexpected ack {:?}", ack);
                StatusCode::BadUnexpectedError
            }
            Message::Chunk(chunk) => self.process_chunk(chunk).err().unwrap_or(StatusCode::Good),
            Message::Error(error) => {
                error!(
                    "Received error {} from server. Reason: {}",
                    error.error, error.reason
                );
                error.error
            }
            m => {
                error!("Expected a recognized message, got {:?}", m);
                StatusCode::BadUnexpectedError
            }
        };

        if status.is_good() {
            Ok(())
        } else {
            Err(status)
        }
    }

    /// Call this if sending a message fails. This will notify the waiting request
    /// that the message could not be sent.
    pub fn message_send_failed(&mut self, request_id: u32, err: StatusCode) {
        if let Some(message_state) = self.message_states.remove(&request_id) {
            let _ = message_state.callback.send(Err(err));
        }
    }

    /// Fail every pending request whose deadline has passed with `BadTimeout`, and
    /// return the earliest deadline among the requests that remain, if any.
    fn next_timeout(&mut self) -> Option<Instant> {
        let now = Instant::now();
        let timed_out: Vec<u32> = self
            .message_states
            .iter()
            .filter(|(_, state)| state.deadline <= now)
            .map(|(id, _)| *id)
            .collect();
        for id in timed_out {
            if let Some(state) = self.message_states.remove(&id) {
                debug!("Message {} timed out", id);
                let _ = state.callback.send(Err(StatusCode::BadTimeout));
            }
        }
        // Every remaining request has a deadline strictly after `now`.
        self.message_states.values().map(|state| state.deadline).min()
    }

    /// Verify a single incoming chunk and add it to its pending request.
    ///
    /// Chunks for unknown request ids are ignored. A request whose response exceeds
    /// `max_chunk_count` chunks, or that receives a final-error chunk, is failed
    /// and forgotten. A final chunk completes the request: its chunks are merged
    /// and decoded, and the result is delivered through the request's callback.
    ///
    /// # Errors
    /// Returns an error if security verification, chunk header decoding or
    /// sequence-number validation fails, or if the completed message cannot be
    /// decoded. In the last case the waiting request is notified of the same error.
    fn process_chunk(&mut self, chunk: MessageChunk) -> Result<(), StatusCode> {
        let secure_channel = trace_read_lock!(self.channel_state.secure_channel());
        let chunk = secure_channel.verify_and_remove_security(chunk.data)?;

        let chunk_info = chunk.chunk_info(&secure_channel)?;
        drop(secure_channel);
        let req_id = chunk_info.sequence_header.request_id;

        self.sequence_numbers
            .validate_and_increment(chunk_info.sequence_header.sequence_number)?;

        // We do not care at all about incoming messages without a
        // corresponding request.
        let Some(message_state) = self.message_states.get_mut(&req_id) else {
            trace!(
                "Received chunk for unknown request id {}:{}. Ignoring.",
                req_id,
                chunk_info.sequence_header.sequence_number
            );

            return Ok(());
        };

        match chunk_info.message_header.is_final {
            MessageIsFinalType::Intermediate => {
                trace!(
                    "receive chunk intermediate {}:{}. Length {}",
                    chunk_info.sequence_header.request_id,
                    chunk_info.sequence_header.sequence_number,
                    chunk_info.body_length
                );
                message_state.chunks.push(MessageChunkWithChunkInfo {
                    header: chunk_info,
                    data_with_header: chunk.data,
                });
                // An intermediate chunk is always followed by at least one more (the
                // final chunk), so once the intermediate chunks alone reach the limit
                // the complete message is guaranteed to exceed it.
                if self.max_chunk_count > 0 && message_state.chunks.len() >= self.max_chunk_count {
                    error!(
                        "Message has more than {} chunks, exceeding negotiated limits",
                        self.max_chunk_count
                    );
                    // Removing the message state means that we ignore any further chunks.
                    let message_state = self
                        .message_states
                        .remove(&req_id)
                        .expect("message state was found above");
                    let _ = message_state
                        .callback
                        .send(Err(StatusCode::BadEncodingLimitsExceeded));
                }
            }
            MessageIsFinalType::FinalError => {
                warn!("Discarding chunk marked in as final error");
                let message_state = self
                    .message_states
                    .remove(&req_id)
                    .expect("message state was found above");
                let _ = message_state
                    .callback
                    .send(Err(StatusCode::BadCommunicationError));
            }
            MessageIsFinalType::Final => {
                trace!(
                    "receive chunk final {}:{}. Length {}",
                    chunk_info.sequence_header.request_id,
                    chunk_info.sequence_header.sequence_number,
                    chunk_info.body_length
                );
                message_state.chunks.push(MessageChunkWithChunkInfo {
                    header: chunk_info,
                    data_with_header: chunk.data,
                });
                let message_state = self
                    .message_states
                    .remove(&req_id)
                    .expect("message state was found above");
                // The state has already left the map, so the waiting request must be
                // told about a decoding failure here; otherwise its sender is dropped
                // silently and nothing else can report the actual status to it.
                match self.decode_final_message(message_state.chunks) {
                    Ok(message) => {
                        let _ = message_state.callback.send(Ok(message));
                    }
                    Err(status) => {
                        let _ = message_state.callback.send(Err(status));
                        return Err(status);
                    }
                }
            }
        }
        Ok(())
    }

    /// Merge and decode the complete set of chunks of a response.
    ///
    /// If the response completes an open/renew secure channel exchange, the secure
    /// channel is updated before returning.
    fn decode_final_message(
        &mut self,
        chunks: Vec<MessageChunkWithChunkInfo>,
    ) -> Result<ResponseMessage, StatusCode> {
        let in_chunks = Self::merge_chunks(chunks)?;
        let message = self.turn_received_chunks_into_message(&in_chunks)?;

        // If the message is a response to opening a secure channel, we need to update encryption keys
        // right now. If we wait, we risk new messages using the new encryption keys arriving before
        // we've updated the secure channel.
        if let ResponseMessage::OpenSecureChannel(msg) = &message {
            self.channel_state.end_issue_or_renew_secure_channel(msg)?;
        }

        Ok(message)
    }

    /// Validate the chunks of a complete message and decode them into a response.
    fn turn_received_chunks_into_message(
        &mut self,
        chunks: &[MessageChunk],
    ) -> Result<ResponseMessage, Error> {
        // Validate that all chunks have incrementing sequence numbers and valid chunk types
        let secure_channel = trace_read_lock!(self.channel_state.secure_channel());
        Chunker::validate_chunks(&secure_channel, chunks)?;
        // Now decode
        Chunker::decode(chunks, &secure_channel, None)
    }

    /// Order the chunks of a message by sequence number and return them as raw
    /// [`MessageChunk`]s ready for validation and decoding.
    ///
    /// Starting from the lowest sequence number, each chunk is expected to carry the
    /// next consecutive sequence number. A chunk that does not (for example a
    /// duplicate) is logged and skipped. An empty input yields an empty output.
    // NOTE(review): sorting by raw sequence number does not account for sequence
    // number wrap-around; a message whose chunks span the wrap would be truncated.
    fn merge_chunks(
        mut chunks: Vec<MessageChunkWithChunkInfo>,
    ) -> Result<Vec<MessageChunk>, StatusCode> {
        chunks.sort_by_key(|c| c.header.sequence_header.sequence_number);
        let Some(first) = chunks.first() else {
            return Ok(Vec::new());
        };
        let mut expected_sequence_number = first.header.sequence_header.sequence_number;
        let mut ret = Vec::with_capacity(chunks.len());
        for c in chunks {
            let sequence_number = c.header.sequence_header.sequence_number;
            if sequence_number != expected_sequence_number {
                warn!(
                    "receive wrong chunk expect seq={} got={}",
                    expected_sequence_number, sequence_number
                );
                continue; // may be a duplicate chunk
            }
            // `wrapping_add` avoids an overflow panic (in debug builds) when a chunk's
            // sequence number is `u32::MAX`.
            expected_sequence_number = expected_sequence_number.wrapping_add(1);
            ret.push(MessageChunk {
                data: c.data_with_header,
            });
        }
        Ok(ret)
    }

    /// Close the transport, aborting any pending requests.
    /// If `status` is good, the pending requests will be terminated with
    /// `BadConnectionClosed`.
    ///
    /// Requests still queued in the outgoing channel are failed the same way.
    /// Returns the `status` that was passed in.
    pub async fn close(&mut self, status: StatusCode) -> StatusCode {
        // If the status is good, we still want to send a bad status code
        // to the pending requests. They didn't succeed, after all.
        let request_status = if status.is_good() {
            StatusCode::BadConnectionClosed
        } else {
            status
        };

        for (_, pending) in self.message_states.drain() {
            let _ = pending.callback.send(Err(request_status));
        }

        // Make sure we also send a bad status for any remaining messages in the queue
        // Close the channel first.
        self.outgoing_recv.close();

        // recv is no longer blocking.
        while let Some(msg) = self.outgoing_recv.recv().await {
            if let Some(cb) = msg.callback {
                let _ = cb.send(Err(request_status));
            }
        }

        status
    }
}