1use std::io::IoSlice;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use bytes::Bytes;
27use tokio::io::{AsyncWrite, AsyncWriteExt};
28use tokio::sync::mpsc;
29use tokio::task::JoinHandle;
30
31use crate::error::{ProcwireError, Result};
32use crate::protocol::{Header, HEADER_SIZE};
33
/// Default maximum number of frames that may be queued (accepted but not yet
/// written) before senders start experiencing backpressure.
pub const DEFAULT_MAX_PENDING_FRAMES: usize = 1024;

/// Default capacity of the mpsc channel between `WriterHandle`s and the
/// background writer task.
pub const DEFAULT_CHANNEL_CAPACITY: usize = 1024;

/// Default time a sender will wait for backpressure to clear before giving up
/// with `ProcwireError::BackpressureTimeout`.
pub const DEFAULT_BACKPRESSURE_TIMEOUT: Duration = Duration::from_secs(5);

/// Maximum number of frames coalesced into a single vectored write by the
/// writer loop.
const MAX_BATCH_SIZE: usize = 64;
/// A single frame (pre-encoded header plus payload) queued for writing to the
/// underlying transport.
#[derive(Debug)]
pub struct OutboundFrame {
    // Wire header, already encoded to its fixed-size byte representation.
    pub header: [u8; HEADER_SIZE],
    // Frame body; empty for header-only frames.
    pub payload: Bytes,
}
54
55impl OutboundFrame {
56 #[inline]
58 pub fn new(header: &Header, payload: Bytes) -> Self {
59 Self {
60 header: header.encode(),
61 payload,
62 }
63 }
64
65 #[inline]
67 pub fn empty(header: &Header) -> Self {
68 Self {
69 header: header.encode(),
70 payload: Bytes::new(),
71 }
72 }
73
74 #[inline]
76 pub fn size(&self) -> usize {
77 HEADER_SIZE + self.payload.len()
78 }
79}
80
/// Tuning parameters for the writer task and its backpressure behavior.
#[derive(Debug, Clone)]
pub struct WriterConfig {
    // Frames allowed in flight before senders must wait.
    pub max_pending_frames: usize,
    // Capacity of the handle-to-task mpsc channel.
    pub channel_capacity: usize,
    // How long a sender waits for backpressure to clear before erroring.
    pub backpressure_timeout: Duration,
}
91
92impl Default for WriterConfig {
93 fn default() -> Self {
94 Self {
95 max_pending_frames: DEFAULT_MAX_PENDING_FRAMES,
96 channel_capacity: DEFAULT_CHANNEL_CAPACITY,
97 backpressure_timeout: DEFAULT_BACKPRESSURE_TIMEOUT,
98 }
99 }
100}
101
/// Cloneable handle for queueing frames onto the background writer task.
///
/// All clones share one pending-frame counter, so backpressure is enforced
/// across every handle for the same writer.
#[derive(Clone)]
pub struct WriterHandle {
    // Channel to the writer task's receive loop.
    tx: mpsc::Sender<OutboundFrame>,
    // Frames queued but not yet written; shared with the writer task, which
    // decrements it after each batch is flushed.
    pending: Arc<AtomicUsize>,
    // Backpressure threshold applied to `pending`.
    max_pending: usize,
    // Maximum time `send` waits for backpressure to clear.
    timeout: Duration,
}
116
117impl WriterHandle {
118 fn new(
120 tx: mpsc::Sender<OutboundFrame>,
121 pending: Arc<AtomicUsize>,
122 max_pending: usize,
123 timeout: Duration,
124 ) -> Self {
125 Self {
126 tx,
127 pending,
128 max_pending,
129 timeout,
130 }
131 }
132
133 pub async fn send(&self, frame: OutboundFrame) -> Result<()> {
138 let current = self.pending.load(Ordering::Acquire);
140 if current >= self.max_pending {
141 self.wait_for_backpressure().await?;
143 }
144
145 self.pending.fetch_add(1, Ordering::AcqRel);
147
148 self.tx.send(frame).await.map_err(|_| {
150 self.pending.fetch_sub(1, Ordering::Release);
152 ProcwireError::ConnectionClosed
153 })
154 }
155
156 async fn wait_for_backpressure(&self) -> Result<()> {
158 let start = Instant::now();
159 let check_interval = Duration::from_micros(100);
160
161 loop {
162 if self.pending.load(Ordering::Acquire) < self.max_pending {
163 return Ok(());
164 }
165
166 if start.elapsed() > self.timeout {
167 return Err(ProcwireError::BackpressureTimeout);
168 }
169
170 tokio::time::sleep(check_interval).await;
171 }
172 }
173
174 #[inline]
176 pub fn is_backpressure_active(&self) -> bool {
177 self.pending.load(Ordering::Acquire) >= self.max_pending
178 }
179
180 #[inline]
182 pub fn pending_count(&self) -> usize {
183 self.pending.load(Ordering::Acquire)
184 }
185
186 pub fn try_send(&self, frame: OutboundFrame) -> Result<()> {
190 let current = self.pending.load(Ordering::Acquire);
191 if current >= self.max_pending {
192 return Err(ProcwireError::BackpressureTimeout);
193 }
194
195 self.pending.fetch_add(1, Ordering::AcqRel);
196
197 self.tx.try_send(frame).map_err(|e| {
198 self.pending.fetch_sub(1, Ordering::Release);
199 match e {
200 mpsc::error::TrySendError::Full(_) => ProcwireError::BackpressureTimeout,
201 mpsc::error::TrySendError::Closed(_) => ProcwireError::ConnectionClosed,
202 }
203 })
204 }
205}
206
207pub fn spawn_writer_task<W>(
219 writer: W,
220 config: WriterConfig,
221) -> (WriterHandle, JoinHandle<Result<()>>)
222where
223 W: AsyncWrite + Unpin + Send + 'static,
224{
225 let (tx, rx) = mpsc::channel(config.channel_capacity);
226 let pending = Arc::new(AtomicUsize::new(0));
227
228 let handle = WriterHandle::new(
229 tx,
230 pending.clone(),
231 config.max_pending_frames,
232 config.backpressure_timeout,
233 );
234
235 let task = tokio::spawn(writer_loop(rx, writer, pending));
236
237 (handle, task)
238}
239
240pub fn spawn_writer_task_default<W>(writer: W) -> (WriterHandle, JoinHandle<Result<()>>)
242where
243 W: AsyncWrite + Unpin + Send + 'static,
244{
245 spawn_writer_task(writer, WriterConfig::default())
246}
247
248async fn writer_loop<W>(
252 mut rx: mpsc::Receiver<OutboundFrame>,
253 mut writer: W,
254 pending: Arc<AtomicUsize>,
255) -> Result<()>
256where
257 W: AsyncWrite + Unpin,
258{
259 loop {
260 let first = match rx.recv().await {
262 Some(f) => f,
263 None => {
264 return Ok(());
266 }
267 };
268
269 let mut batch = Vec::with_capacity(MAX_BATCH_SIZE);
271 batch.push(first);
272
273 while batch.len() < MAX_BATCH_SIZE {
274 match rx.try_recv() {
275 Ok(frame) => batch.push(frame),
276 Err(_) => break,
277 }
278 }
279
280 let batch_size = batch.len();
282 write_batch(&mut writer, &batch).await?;
283
284 pending.fetch_sub(batch_size, Ordering::Release);
286 }
287}
288
289async fn write_batch<W>(writer: &mut W, batch: &[OutboundFrame]) -> Result<()>
295where
296 W: AsyncWrite + Unpin,
297{
298 if batch.is_empty() {
299 return Ok(());
300 }
301
302 let mut slices: Vec<IoSlice<'_>> = Vec::with_capacity(batch.len() * 2);
305
306 for frame in batch {
307 slices.push(IoSlice::new(&frame.header));
308 if !frame.payload.is_empty() {
309 slices.push(IoSlice::new(&frame.payload));
310 }
311 }
312
313 let total_size: usize = batch.iter().map(|f| f.size()).sum();
315
316 let written = writer.write_vectored(&slices).await?;
319
320 if written == total_size {
321 writer.flush().await?;
323 return Ok(());
324 }
325
326 if written == 0 {
327 return Err(ProcwireError::Io(std::io::Error::new(
328 std::io::ErrorKind::WriteZero,
329 "write_vectored returned 0",
330 )));
331 }
332
333 let mut total_written = written;
335
336 while total_written < total_size {
337 let remaining_slices = build_remaining_slices(batch, total_written);
339 if remaining_slices.is_empty() {
340 break;
341 }
342
343 let written = writer.write_vectored(&remaining_slices).await?;
344 if written == 0 {
345 return Err(ProcwireError::Io(std::io::Error::new(
346 std::io::ErrorKind::WriteZero,
347 "write_vectored returned 0",
348 )));
349 }
350
351 total_written += written;
352 }
353
354 writer.flush().await?;
355 Ok(())
356}
357
358fn build_remaining_slices(batch: &[OutboundFrame], skip_bytes: usize) -> Vec<IoSlice<'_>> {
360 let mut slices = Vec::with_capacity(batch.len() * 2);
361 let mut skipped = 0;
362
363 for frame in batch {
364 let header_start = skipped;
366 let header_end = skipped + HEADER_SIZE;
367
368 if skip_bytes < header_end {
369 let start_in_header = skip_bytes.saturating_sub(header_start);
370 slices.push(IoSlice::new(&frame.header[start_in_header..]));
371 }
372 skipped = header_end;
373
374 if !frame.payload.is_empty() {
376 let payload_start = skipped;
377 let payload_end = skipped + frame.payload.len();
378
379 if skip_bytes < payload_end {
380 let start_in_payload = skip_bytes.saturating_sub(payload_start);
381 slices.push(IoSlice::new(&frame.payload[start_in_payload..]));
382 }
383 skipped = payload_end;
384 }
385 }
386
387 slices
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::duplex;

    // Frame construction stores the encoded header and the given payload.
    #[test]
    fn test_outbound_frame_creation() {
        let header = Header::new(1, 0x03, 42, 5);
        let payload = Bytes::from_static(b"hello");
        let frame = OutboundFrame::new(&header, payload);

        assert_eq!(frame.header.len(), HEADER_SIZE);
        assert_eq!(frame.payload.len(), 5);
        assert_eq!(frame.size(), HEADER_SIZE + 5);
    }

    // A header-only frame has an empty payload and header-sized wire size.
    #[test]
    fn test_outbound_frame_empty() {
        let header = Header::new(1, 0x23, 42, 0);
        let frame = OutboundFrame::empty(&header);

        assert!(frame.payload.is_empty());
        assert_eq!(frame.size(), HEADER_SIZE);
    }

    // Default config must mirror the module-level DEFAULT_* constants.
    #[test]
    fn test_writer_config_default() {
        let config = WriterConfig::default();
        assert_eq!(config.max_pending_frames, DEFAULT_MAX_PENDING_FRAMES);
        assert_eq!(config.channel_capacity, DEFAULT_CHANNEL_CAPACITY);
        assert_eq!(config.backpressure_timeout, DEFAULT_BACKPRESSURE_TIMEOUT);
    }

    // End-to-end: a frame sent through the handle arrives on the peer of the
    // duplex pipe as header bytes followed by the payload.
    #[tokio::test]
    async fn test_writer_handle_send() {
        let (client, mut server) = duplex(4096);
        let (handle, _task) = spawn_writer_task_default(client);

        let header = Header::new(1, 0x03, 42, 5);
        let frame = OutboundFrame::new(&header, Bytes::from_static(b"hello"));
        handle.send(frame).await.unwrap();

        // Give the background writer task a moment to drain the channel.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let mut buf = vec![0u8; 64];
        let n = tokio::io::AsyncReadExt::read(&mut server, &mut buf)
            .await
            .unwrap();

        assert_eq!(n, HEADER_SIZE + 5);
    }

    // A fresh handle starts with zero pending frames and no backpressure.
    #[tokio::test]
    async fn test_writer_handle_pending_count() {
        let (client, _server) = duplex(4096);
        let config = WriterConfig {
            max_pending_frames: 1000,
            channel_capacity: 100,
            backpressure_timeout: Duration::from_secs(1),
        };
        let (handle, _task) = spawn_writer_task(client, config);

        assert_eq!(handle.pending_count(), 0);
        assert!(!handle.is_backpressure_active());
    }

    // Several quickly-sent frames are all written; total byte count on the
    // peer equals the sum of the individual frame sizes.
    #[tokio::test]
    async fn test_writer_batching() {
        let (client, mut server) = duplex(4096);
        let (handle, _task) = spawn_writer_task_default(client);

        for i in 0..10u32 {
            let header = Header::new(1, 0x03, i, 4);
            let payload = Bytes::copy_from_slice(&i.to_be_bytes());
            let frame = OutboundFrame::new(&header, payload);
            handle.send(frame).await.unwrap();
        }

        // Allow the writer task to flush every batch before reading.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut buf = vec![0u8; 1024];
        let n = tokio::io::AsyncReadExt::read(&mut server, &mut buf)
            .await
            .unwrap();

        let expected_size = 10 * (HEADER_SIZE + 4);
        assert_eq!(n, expected_size);
    }

    // With the pending counter pre-loaded at the limit, try_send must refuse
    // immediately with BackpressureTimeout.
    #[tokio::test]
    async fn test_try_send_at_capacity() {
        let (tx, _rx) = mpsc::channel::<OutboundFrame>(10);
        let pending = Arc::new(AtomicUsize::new(100)); let handle = WriterHandle::new(tx, pending, 100, Duration::from_secs(1));

        let header = Header::new(1, 0x03, 42, 0);
        let frame = OutboundFrame::empty(&header);

        let result = handle.try_send(frame);
        assert!(matches!(result, Err(ProcwireError::BackpressureTimeout)));
    }

    // Skip of 0 reproduces the full slice set: one header + one payload slice.
    #[test]
    fn test_build_remaining_slices_no_skip() {
        let header = Header::new(1, 0x03, 42, 5);
        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];

        let slices = build_remaining_slices(&batch, 0);
        assert_eq!(slices.len(), 2); }

    // A skip point inside the header shortens the header slice only.
    #[test]
    fn test_build_remaining_slices_partial_header() {
        let header = Header::new(1, 0x03, 42, 5);
        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];

        let slices = build_remaining_slices(&batch, 5);
        assert_eq!(slices.len(), 2);
        assert_eq!(slices[0].len(), HEADER_SIZE - 5);
        assert_eq!(slices[1].len(), 5);
    }

    // Skipping exactly the header drops it and leaves only the payload slice.
    #[test]
    fn test_build_remaining_slices_skip_header() {
        let header = Header::new(1, 0x03, 42, 5);
        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];

        let slices = build_remaining_slices(&batch, HEADER_SIZE);
        assert_eq!(slices.len(), 1);
        assert_eq!(slices[0].len(), 5);
    }

    // write_batch over an in-memory cursor emits header + payload bytes.
    #[tokio::test]
    async fn test_write_batch_single() {
        let mut buf = Cursor::new(Vec::new());

        let header = Header::new(1, 0x03, 42, 5);
        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];

        write_batch(&mut buf, &batch).await.unwrap();

        let written = buf.into_inner();
        assert_eq!(written.len(), HEADER_SIZE + 5);
    }

    // A multi-frame batch writes every frame back-to-back.
    #[tokio::test]
    async fn test_write_batch_multiple() {
        let mut buf = Cursor::new(Vec::new());

        let batch: Vec<_> = (0..5)
            .map(|i| {
                let header = Header::new(1, 0x03, i, 3);
                OutboundFrame::new(&header, Bytes::from_static(b"abc"))
            })
            .collect();

        write_batch(&mut buf, &batch).await.unwrap();

        let written = buf.into_inner();
        assert_eq!(written.len(), 5 * (HEADER_SIZE + 3));
    }

    // Dropping the last handle closes the channel; the task exits cleanly.
    #[tokio::test]
    async fn test_writer_shutdown_on_channel_close() {
        let (client, _server) = duplex(4096);
        let (handle, task) = spawn_writer_task_default(client);

        drop(handle);

        let result = task.await.unwrap();
        assert!(result.is_ok());
    }
}