Skip to main content

procwire_client/
writer.rs

1//! Dedicated writer task for high-throughput frame sending.
2//!
3//! This module replaces the `Arc<Mutex<BoxedWriter>>` pattern with a dedicated
4//! writer task that receives frames via an mpsc channel. This eliminates lock
5//! contention and enables batching multiple frames into single syscalls.
6//!
7//! # Architecture
8//!
9//! ```text
10//! Handler 1 ─┐
11//! Handler 2 ─┼─► mpsc::Sender<OutboundFrame> ─► Writer Task ─► Pipe
12//! Handler N ─┘
13//! ```
14//!
15//! # Benefits
16//!
17//! - **No lock contention**: Channel-based, not mutex-based
18//! - **Batching**: Multiple frames can be written in a single syscall via writev
19//! - **Backpressure**: Built-in pending count tracking with configurable limits
20
21use std::io::IoSlice;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use bytes::Bytes;
27use tokio::io::{AsyncWrite, AsyncWriteExt};
28use tokio::sync::mpsc;
29use tokio::task::JoinHandle;
30
31use crate::error::{ProcwireError, Result};
32use crate::protocol::{Header, HEADER_SIZE};
33
34/// Default maximum pending frames before backpressure kicks in.
35pub const DEFAULT_MAX_PENDING_FRAMES: usize = 1024;
36
37/// Default channel capacity.
38pub const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
39
40/// Default backpressure timeout.
41pub const DEFAULT_BACKPRESSURE_TIMEOUT: Duration = Duration::from_secs(5);
42
43/// Maximum frames to batch in a single write operation.
44const MAX_BATCH_SIZE: usize = 64;
45
46/// A frame ready to be written to the pipe.
47#[derive(Debug)]
48pub struct OutboundFrame {
49    /// Pre-encoded header (11 bytes).
50    pub header: [u8; HEADER_SIZE],
51    /// Payload bytes (can be empty for ACK, STREAM_END, etc.).
52    pub payload: Bytes,
53}
54
55impl OutboundFrame {
56    /// Create a new outbound frame.
57    #[inline]
58    pub fn new(header: &Header, payload: Bytes) -> Self {
59        Self {
60            header: header.encode(),
61            payload,
62        }
63    }
64
65    /// Create a new outbound frame with empty payload.
66    #[inline]
67    pub fn empty(header: &Header) -> Self {
68        Self {
69            header: header.encode(),
70            payload: Bytes::new(),
71        }
72    }
73
74    /// Total size of this frame (header + payload).
75    #[inline]
76    pub fn size(&self) -> usize {
77        HEADER_SIZE + self.payload.len()
78    }
79}
80
81/// Configuration for the writer task.
82#[derive(Debug, Clone)]
83pub struct WriterConfig {
84    /// Maximum pending frames before backpressure kicks in.
85    pub max_pending_frames: usize,
86    /// Channel capacity for frame queue.
87    pub channel_capacity: usize,
88    /// Timeout when waiting for backpressure to clear.
89    pub backpressure_timeout: Duration,
90}
91
92impl Default for WriterConfig {
93    fn default() -> Self {
94        Self {
95            max_pending_frames: DEFAULT_MAX_PENDING_FRAMES,
96            channel_capacity: DEFAULT_CHANNEL_CAPACITY,
97            backpressure_timeout: DEFAULT_BACKPRESSURE_TIMEOUT,
98        }
99    }
100}
101
102/// Handle for sending frames to the writer task.
103///
104/// This is cheaply cloneable and can be shared across multiple handlers.
105#[derive(Clone)]
106pub struct WriterHandle {
107    /// Channel sender for frames.
108    tx: mpsc::Sender<OutboundFrame>,
109    /// Pending frame count (for backpressure).
110    pending: Arc<AtomicUsize>,
111    /// Maximum pending frames.
112    max_pending: usize,
113    /// Backpressure timeout.
114    timeout: Duration,
115}
116
117impl WriterHandle {
118    /// Create a new writer handle.
119    fn new(
120        tx: mpsc::Sender<OutboundFrame>,
121        pending: Arc<AtomicUsize>,
122        max_pending: usize,
123        timeout: Duration,
124    ) -> Self {
125        Self {
126            tx,
127            pending,
128            max_pending,
129            timeout,
130        }
131    }
132
133    /// Send a frame to the writer task.
134    ///
135    /// This method will wait if backpressure is active, timing out after
136    /// the configured duration.
137    pub async fn send(&self, frame: OutboundFrame) -> Result<()> {
138        // Check backpressure
139        let current = self.pending.load(Ordering::Acquire);
140        if current >= self.max_pending {
141            // Wait with timeout for backpressure to clear
142            self.wait_for_backpressure().await?;
143        }
144
145        // Increment pending count BEFORE sending
146        self.pending.fetch_add(1, Ordering::AcqRel);
147
148        // Send to channel
149        self.tx.send(frame).await.map_err(|_| {
150            // Decrement on failure
151            self.pending.fetch_sub(1, Ordering::Release);
152            ProcwireError::ConnectionClosed
153        })
154    }
155
156    /// Wait for backpressure to clear with timeout.
157    async fn wait_for_backpressure(&self) -> Result<()> {
158        let start = Instant::now();
159        let check_interval = Duration::from_micros(100);
160
161        loop {
162            if self.pending.load(Ordering::Acquire) < self.max_pending {
163                return Ok(());
164            }
165
166            if start.elapsed() > self.timeout {
167                return Err(ProcwireError::BackpressureTimeout);
168            }
169
170            tokio::time::sleep(check_interval).await;
171        }
172    }
173
174    /// Check if backpressure is currently active.
175    #[inline]
176    pub fn is_backpressure_active(&self) -> bool {
177        self.pending.load(Ordering::Acquire) >= self.max_pending
178    }
179
180    /// Get current pending frame count.
181    #[inline]
182    pub fn pending_count(&self) -> usize {
183        self.pending.load(Ordering::Acquire)
184    }
185
186    /// Try to send a frame without waiting for backpressure.
187    ///
188    /// Returns `Err(BackpressureTimeout)` immediately if at capacity.
189    pub fn try_send(&self, frame: OutboundFrame) -> Result<()> {
190        let current = self.pending.load(Ordering::Acquire);
191        if current >= self.max_pending {
192            return Err(ProcwireError::BackpressureTimeout);
193        }
194
195        self.pending.fetch_add(1, Ordering::AcqRel);
196
197        self.tx.try_send(frame).map_err(|e| {
198            self.pending.fetch_sub(1, Ordering::Release);
199            match e {
200                mpsc::error::TrySendError::Full(_) => ProcwireError::BackpressureTimeout,
201                mpsc::error::TrySendError::Closed(_) => ProcwireError::ConnectionClosed,
202            }
203        })
204    }
205}
206
207/// Spawn the writer task and return a handle for sending frames.
208///
209/// # Arguments
210///
211/// * `writer` - The async writer (pipe write half)
212/// * `config` - Writer configuration
213///
214/// # Returns
215///
216/// A tuple of `(WriterHandle, JoinHandle)` where the JoinHandle can be used
217/// to wait for the writer task to complete.
218pub fn spawn_writer_task<W>(
219    writer: W,
220    config: WriterConfig,
221) -> (WriterHandle, JoinHandle<Result<()>>)
222where
223    W: AsyncWrite + Unpin + Send + 'static,
224{
225    let (tx, rx) = mpsc::channel(config.channel_capacity);
226    let pending = Arc::new(AtomicUsize::new(0));
227
228    let handle = WriterHandle::new(
229        tx,
230        pending.clone(),
231        config.max_pending_frames,
232        config.backpressure_timeout,
233    );
234
235    let task = tokio::spawn(writer_loop(rx, writer, pending));
236
237    (handle, task)
238}
239
240/// Spawn the writer task with default configuration.
241pub fn spawn_writer_task_default<W>(writer: W) -> (WriterHandle, JoinHandle<Result<()>>)
242where
243    W: AsyncWrite + Unpin + Send + 'static,
244{
245    spawn_writer_task(writer, WriterConfig::default())
246}
247
248/// Main writer loop - receives frames and writes them to the pipe.
249///
250/// Uses batching and scatter/gather I/O (writev) for efficiency.
251async fn writer_loop<W>(
252    mut rx: mpsc::Receiver<OutboundFrame>,
253    mut writer: W,
254    pending: Arc<AtomicUsize>,
255) -> Result<()>
256where
257    W: AsyncWrite + Unpin,
258{
259    loop {
260        // Wait for first frame
261        let first = match rx.recv().await {
262            Some(f) => f,
263            None => {
264                // Channel closed, clean shutdown
265                return Ok(());
266            }
267        };
268
269        // Collect additional ready frames (non-blocking)
270        let mut batch = Vec::with_capacity(MAX_BATCH_SIZE);
271        batch.push(first);
272
273        while batch.len() < MAX_BATCH_SIZE {
274            match rx.try_recv() {
275                Ok(frame) => batch.push(frame),
276                Err(_) => break,
277            }
278        }
279
280        // Write the batch
281        let batch_size = batch.len();
282        write_batch(&mut writer, &batch).await?;
283
284        // Update pending count
285        pending.fetch_sub(batch_size, Ordering::Release);
286    }
287}
288
289/// Write a batch of frames using scatter/gather I/O (write_vectored).
290///
291/// Always uses write_vectored for both single and multiple frames to minimize
292/// syscalls. For a single frame with payload, this reduces from 2-3 syscalls
293/// (header write, payload write, flush) to 1-2 syscalls (vectored write, flush).
294async fn write_batch<W>(writer: &mut W, batch: &[OutboundFrame]) -> Result<()>
295where
296    W: AsyncWrite + Unpin,
297{
298    if batch.is_empty() {
299        return Ok(());
300    }
301
302    // Build IoSlice array: each frame contributes 1-2 slices (header, optionally payload)
303    // Using write_vectored even for single frame to minimize syscalls
304    let mut slices: Vec<IoSlice<'_>> = Vec::with_capacity(batch.len() * 2);
305
306    for frame in batch {
307        slices.push(IoSlice::new(&frame.header));
308        if !frame.payload.is_empty() {
309            slices.push(IoSlice::new(&frame.payload));
310        }
311    }
312
313    // Calculate total size
314    let total_size: usize = batch.iter().map(|f| f.size()).sum();
315
316    // Fast path: try single write_vectored call first
317    // This is the common case when kernel buffer has enough space
318    let written = writer.write_vectored(&slices).await?;
319
320    if written == total_size {
321        // All data written in one syscall - optimal case
322        writer.flush().await?;
323        return Ok(());
324    }
325
326    if written == 0 {
327        return Err(ProcwireError::Io(std::io::Error::new(
328            std::io::ErrorKind::WriteZero,
329            "write_vectored returned 0",
330        )));
331    }
332
333    // Slow path: partial write, need to continue with remaining data
334    let mut total_written = written;
335
336    while total_written < total_size {
337        // Rebuild slices for remaining data
338        let remaining_slices = build_remaining_slices(batch, total_written);
339        if remaining_slices.is_empty() {
340            break;
341        }
342
343        let written = writer.write_vectored(&remaining_slices).await?;
344        if written == 0 {
345            return Err(ProcwireError::Io(std::io::Error::new(
346                std::io::ErrorKind::WriteZero,
347                "write_vectored returned 0",
348            )));
349        }
350
351        total_written += written;
352    }
353
354    writer.flush().await?;
355    Ok(())
356}
357
358/// Build IoSlice array for remaining data after partial write.
359fn build_remaining_slices(batch: &[OutboundFrame], skip_bytes: usize) -> Vec<IoSlice<'_>> {
360    let mut slices = Vec::with_capacity(batch.len() * 2);
361    let mut skipped = 0;
362
363    for frame in batch {
364        // Handle header
365        let header_start = skipped;
366        let header_end = skipped + HEADER_SIZE;
367
368        if skip_bytes < header_end {
369            let start_in_header = skip_bytes.saturating_sub(header_start);
370            slices.push(IoSlice::new(&frame.header[start_in_header..]));
371        }
372        skipped = header_end;
373
374        // Handle payload
375        if !frame.payload.is_empty() {
376            let payload_start = skipped;
377            let payload_end = skipped + frame.payload.len();
378
379            if skip_bytes < payload_end {
380                let start_in_payload = skip_bytes.saturating_sub(payload_start);
381                slices.push(IoSlice::new(&frame.payload[start_in_payload..]));
382            }
383            skipped = payload_end;
384        }
385    }
386
387    slices
388}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use std::io::Cursor;
394    use tokio::io::duplex;
395
396    #[test]
397    fn test_outbound_frame_creation() {
398        let header = Header::new(1, 0x03, 42, 5);
399        let payload = Bytes::from_static(b"hello");
400        let frame = OutboundFrame::new(&header, payload);
401
402        assert_eq!(frame.header.len(), HEADER_SIZE);
403        assert_eq!(frame.payload.len(), 5);
404        assert_eq!(frame.size(), HEADER_SIZE + 5);
405    }
406
407    #[test]
408    fn test_outbound_frame_empty() {
409        let header = Header::new(1, 0x23, 42, 0);
410        let frame = OutboundFrame::empty(&header);
411
412        assert!(frame.payload.is_empty());
413        assert_eq!(frame.size(), HEADER_SIZE);
414    }
415
416    #[test]
417    fn test_writer_config_default() {
418        let config = WriterConfig::default();
419        assert_eq!(config.max_pending_frames, DEFAULT_MAX_PENDING_FRAMES);
420        assert_eq!(config.channel_capacity, DEFAULT_CHANNEL_CAPACITY);
421        assert_eq!(config.backpressure_timeout, DEFAULT_BACKPRESSURE_TIMEOUT);
422    }
423
424    #[tokio::test]
425    async fn test_writer_handle_send() {
426        let (client, mut server) = duplex(4096);
427        let (handle, _task) = spawn_writer_task_default(client);
428
429        // Send a frame
430        let header = Header::new(1, 0x03, 42, 5);
431        let frame = OutboundFrame::new(&header, Bytes::from_static(b"hello"));
432        handle.send(frame).await.unwrap();
433
434        // Small delay for writer task to process
435        tokio::time::sleep(Duration::from_millis(10)).await;
436
437        // Read from server side
438        let mut buf = vec![0u8; 64];
439        let n = tokio::io::AsyncReadExt::read(&mut server, &mut buf)
440            .await
441            .unwrap();
442
443        assert_eq!(n, HEADER_SIZE + 5);
444    }
445
446    #[tokio::test]
447    async fn test_writer_handle_pending_count() {
448        let (client, _server) = duplex(4096);
449        let config = WriterConfig {
450            max_pending_frames: 1000,
451            channel_capacity: 100,
452            backpressure_timeout: Duration::from_secs(1),
453        };
454        let (handle, _task) = spawn_writer_task(client, config);
455
456        assert_eq!(handle.pending_count(), 0);
457        assert!(!handle.is_backpressure_active());
458    }
459
460    #[tokio::test]
461    async fn test_writer_batching() {
462        let (client, mut server) = duplex(4096);
463        let (handle, _task) = spawn_writer_task_default(client);
464
465        // Send multiple frames quickly
466        for i in 0..10u32 {
467            let header = Header::new(1, 0x03, i, 4);
468            let payload = Bytes::copy_from_slice(&i.to_be_bytes());
469            let frame = OutboundFrame::new(&header, payload);
470            handle.send(frame).await.unwrap();
471        }
472
473        // Wait for writes to complete
474        tokio::time::sleep(Duration::from_millis(50)).await;
475
476        // Read all data
477        let mut buf = vec![0u8; 1024];
478        let n = tokio::io::AsyncReadExt::read(&mut server, &mut buf)
479            .await
480            .unwrap();
481
482        // Should have received all 10 frames
483        let expected_size = 10 * (HEADER_SIZE + 4);
484        assert_eq!(n, expected_size);
485    }
486
487    #[tokio::test]
488    async fn test_try_send_at_capacity() {
489        let (tx, _rx) = mpsc::channel::<OutboundFrame>(10);
490        let pending = Arc::new(AtomicUsize::new(100)); // At capacity
491
492        let handle = WriterHandle::new(tx, pending, 100, Duration::from_secs(1));
493
494        let header = Header::new(1, 0x03, 42, 0);
495        let frame = OutboundFrame::empty(&header);
496
497        let result = handle.try_send(frame);
498        assert!(matches!(result, Err(ProcwireError::BackpressureTimeout)));
499    }
500
501    #[test]
502    fn test_build_remaining_slices_no_skip() {
503        let header = Header::new(1, 0x03, 42, 5);
504        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];
505
506        let slices = build_remaining_slices(&batch, 0);
507        assert_eq!(slices.len(), 2); // header + payload
508    }
509
510    #[test]
511    fn test_build_remaining_slices_partial_header() {
512        let header = Header::new(1, 0x03, 42, 5);
513        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];
514
515        let slices = build_remaining_slices(&batch, 5);
516        // Should have partial header (6 bytes) + full payload
517        assert_eq!(slices.len(), 2);
518        assert_eq!(slices[0].len(), HEADER_SIZE - 5);
519        assert_eq!(slices[1].len(), 5);
520    }
521
522    #[test]
523    fn test_build_remaining_slices_skip_header() {
524        let header = Header::new(1, 0x03, 42, 5);
525        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];
526
527        let slices = build_remaining_slices(&batch, HEADER_SIZE);
528        // Should have only payload
529        assert_eq!(slices.len(), 1);
530        assert_eq!(slices[0].len(), 5);
531    }
532
533    #[tokio::test]
534    async fn test_write_batch_single() {
535        let mut buf = Cursor::new(Vec::new());
536
537        let header = Header::new(1, 0x03, 42, 5);
538        let batch = vec![OutboundFrame::new(&header, Bytes::from_static(b"hello"))];
539
540        write_batch(&mut buf, &batch).await.unwrap();
541
542        let written = buf.into_inner();
543        assert_eq!(written.len(), HEADER_SIZE + 5);
544    }
545
546    #[tokio::test]
547    async fn test_write_batch_multiple() {
548        let mut buf = Cursor::new(Vec::new());
549
550        let batch: Vec<_> = (0..5)
551            .map(|i| {
552                let header = Header::new(1, 0x03, i, 3);
553                OutboundFrame::new(&header, Bytes::from_static(b"abc"))
554            })
555            .collect();
556
557        write_batch(&mut buf, &batch).await.unwrap();
558
559        let written = buf.into_inner();
560        assert_eq!(written.len(), 5 * (HEADER_SIZE + 3));
561    }
562
563    #[tokio::test]
564    async fn test_writer_shutdown_on_channel_close() {
565        let (client, _server) = duplex(4096);
566        let (handle, task) = spawn_writer_task_default(client);
567
568        // Drop the handle to close the channel
569        drop(handle);
570
571        // Writer task should complete cleanly
572        let result = task.await.unwrap();
573        assert!(result.is_ok());
574    }
575}