1use std::{
5 collections::VecDeque,
6 io::{BufRead, Cursor},
7};
8
9use tokio::io::AsyncWriteExt;
10use tracing::trace;
11
12use crate::{
13 comms::{chunker::Chunker, message_chunk::MessageChunk, secure_channel::SecureChannel},
14 Message,
15};
16
17use opcua_types::{Error, SimpleBinaryEncodable, StatusCode};
18
19use super::{
20 sequence_number::SequenceNumberHandle,
21 tcp_types::{AcknowledgeMessage, ErrorMessage},
22};
23
/// Whether the staging buffer is accepting a new payload or being drained.
#[derive(Clone, Copy, Debug)]
enum SendBufferState {
    /// The buffer is free to receive the next encoded payload.
    Writing,
    /// The buffer holds an encoded payload of the given length (in bytes) that is
    /// still being written out.
    Reading(usize),
}
29
30#[derive(Debug)]
31enum PendingPayload {
32 Chunk(MessageChunk),
33 Ack(AcknowledgeMessage),
34 Error(ErrorMessage),
35}
36
37pub struct SendBuffer {
39 buffer: Cursor<Vec<u8>>,
41 chunks: VecDeque<PendingPayload>,
43 last_request_id: u32,
45 sequence_numbers: SequenceNumberHandle,
47 pub max_message_size: usize,
49 pub max_chunk_count: usize,
51 pub send_buffer_size: usize,
53
54 state: SendBufferState,
55}
56
57impl SendBuffer {
63 pub fn new(
65 buffer_size: usize,
66 max_message_size: usize,
67 max_chunk_count: usize,
68 sequence_numbers_legacy: bool,
69 ) -> Self {
70 Self {
71 buffer: Cursor::new(vec![0u8; buffer_size + 1024]),
72 chunks: VecDeque::with_capacity(max_chunk_count),
73 last_request_id: 1000,
74 sequence_numbers: SequenceNumberHandle::new(sequence_numbers_legacy),
75 max_message_size,
76 max_chunk_count,
77 send_buffer_size: buffer_size,
78 state: SendBufferState::Writing,
79 }
80 }
81
82 pub fn encode_next_chunk(&mut self, secure_channel: &SecureChannel) -> Result<(), StatusCode> {
84 if matches!(self.state, SendBufferState::Reading(_)) {
85 return Err(StatusCode::BadInvalidState);
86 }
87
88 let Some(next_chunk) = self.chunks.pop_front() else {
89 return Ok(());
90 };
91
92 let size = match next_chunk {
93 PendingPayload::Chunk(c) => secure_channel.apply_security(&c, self.buffer.get_mut())?,
94 PendingPayload::Ack(a) => {
95 a.encode(&mut self.buffer)?;
96 self.buffer.position() as usize
97 }
98 PendingPayload::Error(e) => {
99 e.encode(&mut self.buffer)?;
100 self.buffer.position() as usize
101 }
102 };
103 self.buffer.set_position(0);
104 self.state = SendBufferState::Reading(size);
105
106 Ok(())
107 }
108
109 pub fn set_sequence_number_legacy(&mut self, is_legacy: bool) {
112 self.sequence_numbers.set_is_legacy(is_legacy);
113 }
114
115 pub fn write_error(&mut self, error: ErrorMessage) {
118 self.chunks.clear();
120 self.chunks.push_back(PendingPayload::Error(error));
121 }
122
123 pub fn write_ack(&mut self, ack: AcknowledgeMessage) {
125 self.chunks.push_back(PendingPayload::Ack(ack));
126 }
127
128 pub fn write(
132 &mut self,
133 request_id: u32,
134 message: impl Message,
135 secure_channel: &SecureChannel,
136 ) -> Result<u32, Error> {
137 trace!("Writing request to buffer");
138
139 let chunks = Chunker::encode(
141 self.sequence_numbers.clone(),
142 request_id,
143 self.max_message_size,
144 self.send_buffer_size,
145 secure_channel,
146 &message,
147 )
148 .map_err(|e| e.with_context(Some(request_id), Some(message.request_handle())))?;
149
150 if self.max_chunk_count > 0 && chunks.len() > self.max_chunk_count {
151 Err(Error::new(
152 StatusCode::BadCommunicationError,
153 format!(
154 "Cannot write message since {} chunks exceeds {} chunk limit",
155 chunks.len(),
156 self.max_chunk_count
157 ),
158 )
159 .with_context(Some(request_id), Some(message.request_handle())))
160 } else {
161 self.sequence_numbers.increment(chunks.len() as u32);
163
164 self.chunks
166 .extend(chunks.into_iter().map(PendingPayload::Chunk));
167 Ok(request_id)
168 }
169 }
170
171 pub fn next_request_id(&mut self) -> u32 {
173 self.last_request_id += 1;
174 self.last_request_id
175 }
176
177 pub async fn read_into_async(
179 &mut self,
180 write: &mut (impl tokio::io::AsyncWrite + Unpin),
181 ) -> Result<(), tokio::io::Error> {
182 let end = match self.state {
184 SendBufferState::Writing => {
185 let end = self.buffer.position() as usize;
186 self.state = SendBufferState::Reading(end);
187 self.buffer.set_position(0);
188 end
189 }
190 SendBufferState::Reading(end) => end,
191 };
192
193 let pos = self.buffer.position() as usize;
194 let buf = &self.buffer.get_ref()[pos..end];
195 let written = write.write(buf).await?;
199
200 self.buffer.consume(written);
201
202 if end == self.buffer.position() as usize {
203 self.state = SendBufferState::Writing;
204 self.buffer.set_position(0);
205 }
206
207 Ok(())
208 }
209
210 pub fn should_encode_chunks(&self) -> bool {
212 !self.chunks.is_empty() && !self.can_read()
213 }
214
215 pub fn can_read(&self) -> bool {
217 matches!(self.state, SendBufferState::Reading(_)) || self.buffer.position() != 0
218 }
219
220 pub fn revise(
222 &mut self,
223 send_buffer_size: usize,
224 max_message_size: usize,
225 max_chunk_count: usize,
226 ) {
227 if self.send_buffer_size > send_buffer_size {
228 self.buffer.get_mut().shrink_to(send_buffer_size + 1024);
229 self.send_buffer_size = send_buffer_size;
230 }
231 if self.max_message_size > max_message_size && max_message_size > 0 {
232 self.max_message_size = max_message_size;
233 }
234 if self.max_chunk_count > max_chunk_count && max_chunk_count > 0 {
235 self.max_chunk_count = max_chunk_count;
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::sync::Arc;

    use parking_lot::RwLock;

    use super::SendBuffer;

    use crate::comms::secure_channel::{Role, SecureChannel};
    use crate::RequestMessage;
    use opcua_crypto::CertificateStore;
    use opcua_types::StatusCode;
    use opcua_types::{
        DateTime, NodeId, ReadRequest, ReadValueId, RequestHeader, TimestampsToReturn,
    };

    /// Chunk size configured on the buffer under test.
    const SEND_BUFFER_SIZE: usize = 8196;

    /// Creates the buffer under test and a client-role secure channel with default settings.
    ///
    /// The buffer uses 8196-byte chunks, an 81960-byte message limit, at most 5 chunks,
    /// and legacy sequence numbers.
    fn get_buffer_and_channel() -> (SendBuffer, SecureChannel) {
        let buffer = SendBuffer::new(SEND_BUFFER_SIZE, 81960, 5, true);
        let certificate_store = CertificateStore::new(std::path::Path::new("./pki"));
        let channel = SecureChannel::new(
            Arc::new(RwLock::new(certificate_store)),
            Role::Client,
            Default::default(),
        );
        (buffer, channel)
    }

    /// Wraps `nodes` in a `ReadRequest` and converts it into a `RequestMessage`.
    fn read_request(nodes: Vec<ReadValueId>) -> RequestMessage {
        let request = ReadRequest {
            request_header: RequestHeader::new(&NodeId::null(), &DateTime::null(), 101),
            max_age: 0.0,
            timestamps_to_return: TimestampsToReturn::Both,
            nodes_to_read: Some(nodes),
        };
        request.into()
    }

    #[tokio::test]
    async fn test_buffer_simple() {
        let message = read_request(vec![ReadValueId {
            node_id: (1, 1).into(),
            attribute_id: 1,
            ..Default::default()
        }]);
        let (mut buffer, channel) = get_buffer_and_channel();

        assert_eq!(buffer.write(1, message, &channel).unwrap(), 1);
        assert!(buffer.should_encode_chunks());
        assert_eq!(buffer.chunks.len(), 1);

        buffer.encode_next_chunk(&channel).unwrap();
        assert!(buffer.can_read());

        let mut sink = Cursor::new(Vec::new());
        buffer.read_into_async(&mut sink).await.unwrap();
        assert!(sink.get_ref().len() > 50);
    }

    #[tokio::test]
    async fn test_buffer_chunking() {
        let nodes = (0..1000)
            .map(|r| ReadValueId {
                node_id: (1, r).into(),
                attribute_id: 1,
                ..Default::default()
            })
            .collect();
        let message = read_request(nodes);
        let (mut buffer, channel) = get_buffer_and_channel();

        assert_eq!(buffer.write(1, message, &channel).unwrap(), 1);
        assert_eq!(buffer.chunks.len(), 3);

        let mut sink = Cursor::new(Vec::new());
        for _ in 0..3 {
            assert!(buffer.should_encode_chunks());
            buffer.encode_next_chunk(&channel).unwrap();
            assert!(!buffer.should_encode_chunks());
            assert!(buffer.can_read());

            buffer.read_into_async(&mut sink).await.unwrap();
        }

        assert!(!buffer.should_encode_chunks());
        assert!(!buffer.can_read());
        let total = sink.get_ref().len();
        assert!(total > SEND_BUFFER_SIZE * 2);
        assert!(total < SEND_BUFFER_SIZE * 3);
    }

    #[test]
    fn test_buffer_too_large_message() {
        let nodes = (0..10000)
            .map(|r| ReadValueId {
                node_id: (1, r).into(),
                attribute_id: 1,
                ..Default::default()
            })
            .collect();
        let message = read_request(nodes);
        let (mut buffer, channel) = get_buffer_and_channel();

        let err = buffer.write(1, message, &channel).unwrap_err();
        assert_eq!(err.status(), StatusCode::BadRequestTooLarge);
    }

    #[test]
    fn test_buffer_too_many_chunks() {
        let nodes = (0..4000)
            .map(|r| ReadValueId {
                node_id: (1, r).into(),
                attribute_id: 1,
                ..Default::default()
            })
            .collect();
        let message = read_request(nodes);
        let (mut buffer, channel) = get_buffer_and_channel();

        let err = buffer.write(1, message, &channel).unwrap_err();
        assert_eq!(err.status(), StatusCode::BadCommunicationError);
    }

    #[tokio::test]
    async fn test_buffer_read_partial() {
        let nodes = (0..1000)
            .map(|r| ReadValueId {
                node_id: (1, r).into(),
                attribute_id: 1,
                ..Default::default()
            })
            .collect();
        let message = read_request(nodes);
        let (mut buffer, channel) = get_buffer_and_channel();

        assert_eq!(buffer.write(1, message, &channel).unwrap(), 1);
        assert_eq!(buffer.chunks.len(), 3);

        // The sink holds exactly half of a full chunk, so each full chunk takes two reads.
        let mut half_chunk = [0u8; SEND_BUFFER_SIZE / 2];
        let half = (SEND_BUFFER_SIZE / 2) as u64;
        let mut sink = Cursor::new(&mut half_chunk[..]);

        for _ in 0..2 {
            println!("Encode chunks");
            assert!(buffer.should_encode_chunks());
            buffer.encode_next_chunk(&channel).unwrap();
            assert!(!buffer.should_encode_chunks());
            assert!(buffer.can_read());

            // First half: data remains in the buffer.
            buffer.read_into_async(&mut sink).await.unwrap();
            assert!(buffer.can_read());
            assert_eq!(sink.position(), half);
            sink.set_position(0);

            // Second half: the chunk is fully drained.
            buffer.read_into_async(&mut sink).await.unwrap();
            assert!(!buffer.can_read());
            assert_eq!(sink.position(), half);
            sink.set_position(0);
        }

        // The final chunk is shorter and fits in a single read.
        assert!(buffer.should_encode_chunks());
        buffer.encode_next_chunk(&channel).unwrap();
        assert!(buffer.can_read());
        buffer.read_into_async(&mut sink).await.unwrap();
        assert!(sink.position() < half);

        assert!(!buffer.should_encode_chunks());
        assert!(!buffer.can_read());
    }
}