Skip to main content

opcua_core/comms/
buffer.rs

//! Shared implementation of an OPC-UA buffer, handling
//! encoding of data and the state of a communication channel.
3
4use std::{
5    collections::VecDeque,
6    io::{BufRead, Cursor},
7};
8
9use tokio::io::AsyncWriteExt;
10use tracing::trace;
11
12use crate::{
13    comms::{chunker::Chunker, message_chunk::MessageChunk, secure_channel::SecureChannel},
14    Message,
15};
16
17use opcua_types::{Error, SimpleBinaryEncodable, StatusCode};
18
19use super::{
20    sequence_number::SequenceNumberHandle,
21    tcp_types::{AcknowledgeMessage, ErrorMessage},
22};
23
/// Phase of the send buffer: either draining already-encoded bytes to the
/// output stream, or ready to encode the next pending payload.
#[derive(Copy, Clone, Debug)]
enum SendBufferState {
    /// Draining the buffer to the stream; the encoded payload ends at the
    /// contained byte offset into the buffer.
    Reading(usize),
    /// The buffer is free; the next pending payload may be encoded into it.
    Writing,
}
29
/// A single outgoing payload queued for encoding into the send buffer.
#[derive(Debug)]
enum PendingPayload {
    /// A chunk of a regular OPC-UA message; security is applied to it when
    /// it is encoded into the buffer.
    Chunk(MessageChunk),
    /// An acknowledge message, written directly without security.
    Ack(AcknowledgeMessage),
    /// An error message, written directly without security. Queuing one
    /// clears any other pending payloads.
    Error(ErrorMessage),
}
36
/// General implementation of a buffer of outgoing messages.
pub struct SendBuffer {
    /// The send buffer, holding at most one encoded payload at a time.
    /// Sized to `send_buffer_size` plus headroom at construction.
    buffer: Cursor<Vec<u8>>,
    /// Queued payloads (chunks, acks, errors) not yet encoded into `buffer`.
    chunks: VecDeque<PendingPayload>,
    /// The last request id handed out by `next_request_id`.
    last_request_id: u32,
    /// Last sent sequence number
    sequence_numbers: SequenceNumberHandle,
    /// Maximum size of a message, total. Use 0 for no limit
    pub max_message_size: usize,
    /// Maximum number of chunks in a message.
    pub max_chunk_count: usize,
    /// Maximum size of each individual chunk.
    pub send_buffer_size: usize,

    /// Whether the buffer currently holds encoded data being drained
    /// (`Reading`) or is free to accept the next payload (`Writing`).
    state: SendBufferState,
}
56
57// The send buffer works as follows:
58//  - `write` is called with a message that is written to the internal buffer.
59//  - `read_into_async` is called, which sets the state to `Writing`.
60//  - Once the buffer is exhausted, the state is set back to `Reading`.
61//  - `write` cannot be called while we are writing to the output.
62impl SendBuffer {
63    /// Create a new send buffer with the given initial limits.
64    pub fn new(
65        buffer_size: usize,
66        max_message_size: usize,
67        max_chunk_count: usize,
68        sequence_numbers_legacy: bool,
69    ) -> Self {
70        Self {
71            buffer: Cursor::new(vec![0u8; buffer_size + 1024]),
72            chunks: VecDeque::with_capacity(max_chunk_count),
73            last_request_id: 1000,
74            sequence_numbers: SequenceNumberHandle::new(sequence_numbers_legacy),
75            max_message_size,
76            max_chunk_count,
77            send_buffer_size: buffer_size,
78            state: SendBufferState::Writing,
79        }
80    }
81
82    /// Encode the next chunk in the queue to the out-buffer.
83    pub fn encode_next_chunk(&mut self, secure_channel: &SecureChannel) -> Result<(), StatusCode> {
84        if matches!(self.state, SendBufferState::Reading(_)) {
85            return Err(StatusCode::BadInvalidState);
86        }
87
88        let Some(next_chunk) = self.chunks.pop_front() else {
89            return Ok(());
90        };
91
92        let size = match next_chunk {
93            PendingPayload::Chunk(c) => secure_channel.apply_security(&c, self.buffer.get_mut())?,
94            PendingPayload::Ack(a) => {
95                a.encode(&mut self.buffer)?;
96                self.buffer.position() as usize
97            }
98            PendingPayload::Error(e) => {
99                e.encode(&mut self.buffer)?;
100                self.buffer.position() as usize
101            }
102        };
103        self.buffer.set_position(0);
104        self.state = SendBufferState::Reading(size);
105
106        Ok(())
107    }
108
109    /// Set whether we are using legacy sequence numbers or not.
110    /// This depends on the active security policy.
111    pub fn set_sequence_number_legacy(&mut self, is_legacy: bool) {
112        self.sequence_numbers.set_is_legacy(is_legacy);
113    }
114
115    /// Clear the list of pending messages, then
116    /// add an error.
117    pub fn write_error(&mut self, error: ErrorMessage) {
118        // Clear any pending chunks, we're erroring out
119        self.chunks.clear();
120        self.chunks.push_back(PendingPayload::Error(error));
121    }
122
123    /// Write an acknowledge message to the list of pending messages.
124    pub fn write_ack(&mut self, ack: AcknowledgeMessage) {
125        self.chunks.push_back(PendingPayload::Ack(ack));
126    }
127
128    /// Encode a message to chunks, then write them to the pending message queue.
129    ///
130    /// The messages are encrypted as they are sent.
131    pub fn write(
132        &mut self,
133        request_id: u32,
134        message: impl Message,
135        secure_channel: &SecureChannel,
136    ) -> Result<u32, Error> {
137        trace!("Writing request to buffer");
138
139        // Turn message to chunk(s)
140        let chunks = Chunker::encode(
141            self.sequence_numbers.clone(),
142            request_id,
143            self.max_message_size,
144            self.send_buffer_size,
145            secure_channel,
146            &message,
147        )
148        .map_err(|e| e.with_context(Some(request_id), Some(message.request_handle())))?;
149
150        if self.max_chunk_count > 0 && chunks.len() > self.max_chunk_count {
151            Err(Error::new(
152                StatusCode::BadCommunicationError,
153                format!(
154                    "Cannot write message since {} chunks exceeds {} chunk limit",
155                    chunks.len(),
156                    self.max_chunk_count
157                ),
158            )
159            .with_context(Some(request_id), Some(message.request_handle())))
160        } else {
161            // Sequence number monotonically increases per chunk
162            self.sequence_numbers.increment(chunks.len() as u32);
163
164            // Send chunks
165            self.chunks
166                .extend(chunks.into_iter().map(PendingPayload::Chunk));
167            Ok(request_id)
168        }
169    }
170
171    /// Get the next request ID.
172    pub fn next_request_id(&mut self) -> u32 {
173        self.last_request_id += 1;
174        self.last_request_id
175    }
176
177    /// Read the pending buffer into the given stream.
178    pub async fn read_into_async(
179        &mut self,
180        write: &mut (impl tokio::io::AsyncWrite + Unpin),
181    ) -> Result<(), tokio::io::Error> {
182        // Set the state to writing, or get the current end point
183        let end = match self.state {
184            SendBufferState::Writing => {
185                let end = self.buffer.position() as usize;
186                self.state = SendBufferState::Reading(end);
187                self.buffer.set_position(0);
188                end
189            }
190            SendBufferState::Reading(end) => end,
191        };
192
193        let pos = self.buffer.position() as usize;
194        let buf = &self.buffer.get_ref()[pos..end];
195        // Write to the stream, note that we do not actually advance the stream before
196        // after we have written. This means that since `write` is cancellation safe, our stream is
197        // cancellation safe, which is essential.
198        let written = write.write(buf).await?;
199
200        self.buffer.consume(written);
201
202        if end == self.buffer.position() as usize {
203            self.state = SendBufferState::Writing;
204            self.buffer.set_position(0);
205        }
206
207        Ok(())
208    }
209
210    /// Return `true` if we should encode a new chunk.
211    pub fn should_encode_chunks(&self) -> bool {
212        !self.chunks.is_empty() && !self.can_read()
213    }
214
215    /// Check if we can read data from the buffer into the stream.
216    pub fn can_read(&self) -> bool {
217        matches!(self.state, SendBufferState::Reading(_)) || self.buffer.position() != 0
218    }
219
220    /// Revise the limits with the result of a hello/acknowledge message.
221    pub fn revise(
222        &mut self,
223        send_buffer_size: usize,
224        max_message_size: usize,
225        max_chunk_count: usize,
226    ) {
227        if self.send_buffer_size > send_buffer_size {
228            self.buffer.get_mut().shrink_to(send_buffer_size + 1024);
229            self.send_buffer_size = send_buffer_size;
230        }
231        if self.max_message_size > max_message_size && max_message_size > 0 {
232            self.max_message_size = max_message_size;
233        }
234        if self.max_chunk_count > max_chunk_count && max_chunk_count > 0 {
235            self.max_chunk_count = max_chunk_count;
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::sync::Arc;

    use parking_lot::RwLock;

    use super::SendBuffer;

    use crate::comms::secure_channel::{Role, SecureChannel};
    use crate::RequestMessage;
    use opcua_crypto::CertificateStore;
    use opcua_types::StatusCode;
    use opcua_types::{
        DateTime, NodeId, ReadRequest, ReadValueId, RequestHeader, TimestampsToReturn,
    };

    /// Build a buffer with an 8196-byte chunk size, 81960-byte message size
    /// limit and a 5 chunk limit, plus a client secure channel with no
    /// security applied.
    fn get_buffer_and_channel() -> (SendBuffer, SecureChannel) {
        let buffer = SendBuffer::new(8196, 81960, 5, true);
        let channel = SecureChannel::new(
            Arc::new(RwLock::new(CertificateStore::new(std::path::Path::new(
                "./pki",
            )))),
            Role::Client,
            Default::default(),
        );

        (buffer, channel)
    }

    /// Build a `ReadRequest` with `count` nodes to read. Larger counts yield
    /// larger encoded messages, used to exercise the chunking limits.
    fn make_read_request(count: u32) -> ReadRequest {
        ReadRequest {
            request_header: RequestHeader::new(&NodeId::null(), &DateTime::null(), 101),
            max_age: 0.0,
            timestamps_to_return: TimestampsToReturn::Both,
            nodes_to_read: Some(
                (0..count)
                    .map(|r| ReadValueId {
                        node_id: (1, r).into(),
                        attribute_id: 1,
                        ..Default::default()
                    })
                    .collect(),
            ),
        }
    }

    #[tokio::test]
    async fn test_buffer_simple() {
        // Write a small message to the buffer
        let message = make_read_request(1);

        let (mut buffer, channel) = get_buffer_and_channel();

        let m: RequestMessage = message.into();
        let request_id = buffer.write(1, m, &channel).unwrap();
        assert_eq!(request_id, 1);

        assert!(buffer.should_encode_chunks());
        assert_eq!(buffer.chunks.len(), 1);
        buffer.encode_next_chunk(&channel).unwrap();
        assert!(buffer.can_read());

        let mut cursor = Cursor::new(Vec::new());
        buffer.read_into_async(&mut cursor).await.unwrap();
        assert!(cursor.get_ref().len() > 50);
    }

    #[tokio::test]
    async fn test_buffer_chunking() {
        // Write a large enough message that it is split into chunks.
        let message = make_read_request(1000);

        let (mut buffer, channel) = get_buffer_and_channel();

        let m: RequestMessage = message.into();
        let request_id = buffer.write(1, m, &channel).unwrap();
        assert_eq!(request_id, 1);

        assert_eq!(buffer.chunks.len(), 3);
        let mut cursor = Cursor::new(Vec::new());

        // Each chunk must be fully encoded and drained before the next.
        for _ in 0..3 {
            assert!(buffer.should_encode_chunks());
            buffer.encode_next_chunk(&channel).unwrap();
            assert!(!buffer.should_encode_chunks());
            assert!(buffer.can_read());

            buffer.read_into_async(&mut cursor).await.unwrap();
        }
        assert!(!buffer.should_encode_chunks());
        assert!(!buffer.can_read());
        assert!(cursor.get_ref().len() > 8196 * 2 && cursor.get_ref().len() < 8196 * 3);
    }

    #[test]
    fn test_buffer_too_large_message() {
        // Write a very large message exceeding the max message size.
        let message = make_read_request(10000);

        let (mut buffer, channel) = get_buffer_and_channel();

        let m: RequestMessage = message.into();
        let err = buffer.write(1, m, &channel).unwrap_err();
        assert_eq!(err.status(), StatusCode::BadRequestTooLarge);
    }

    #[test]
    fn test_buffer_too_many_chunks() {
        // Write a large enough message that we exceed the maximum chunk count.
        let message = make_read_request(4000);

        let (mut buffer, channel) = get_buffer_and_channel();

        let m: RequestMessage = message.into();
        let err = buffer.write(1, m, &channel).unwrap_err();
        assert_eq!(err.status(), StatusCode::BadCommunicationError);
    }

    #[tokio::test]
    async fn test_buffer_read_partial() {
        // Write a large message to the buffer.
        let message = make_read_request(1000);

        let (mut buffer, channel) = get_buffer_and_channel();

        let m: RequestMessage = message.into();
        let request_id = buffer.write(1, m, &channel).unwrap();
        assert_eq!(request_id, 1);

        assert_eq!(buffer.chunks.len(), 3);
        // Use a fixed size buffer exactly half the chunk size. This simulates a TCP connection
        // writing data in smaller chunks than configured chunk size.
        let mut buf = [0u8; 4098];
        // Cursor<&mut [u8; N]> doesn't support AsyncWrite, but Cursor<&mut [u8]> does.
        let mut cursor = Cursor::new(&mut buf as &mut [u8]);

        // The first two chunks are full-sized and require two partial writes each.
        for _ in 0..2 {
            println!("Encode chunks");
            assert!(buffer.should_encode_chunks());
            buffer.encode_next_chunk(&channel).unwrap();
            assert!(!buffer.should_encode_chunks());
            assert!(buffer.can_read());

            buffer.read_into_async(&mut cursor).await.unwrap();
            assert!(buffer.can_read());
            assert_eq!(cursor.position(), 4098);
            cursor.set_position(0);
            buffer.read_into_async(&mut cursor).await.unwrap();
            assert!(!buffer.can_read());
            assert_eq!(cursor.position(), 4098);
            cursor.set_position(0);
        }
        // The final chunk is smaller than the half-size output buffer.
        assert!(buffer.should_encode_chunks());
        buffer.encode_next_chunk(&channel).unwrap();
        assert!(buffer.can_read());
        buffer.read_into_async(&mut cursor).await.unwrap();
        assert!(cursor.position() < 4098);

        assert!(!buffer.should_encode_chunks());
        assert!(!buffer.can_read());
    }
}
446}