Skip to main content

procwire_client/
client.rs

//! Client builder and runtime loop.
//!
//! The [`ClientBuilder`] provides a fluent API for configuring handlers
//! and building the client. The [`Client`] manages the lifecycle:
//! 1. Create pipe listener
//! 2. Send `$init` via stdout
//! 3. Accept parent connection
//! 4. Read frames and dispatch to handlers
//!
//! # Example
//!
//! ```ignore
//! use procwire_client::Client;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let client = Client::builder()
//!         .handle("echo", |data: String, ctx| async move {
//!             ctx.respond(&data).await
//!         })
//!         .handle_stream("count", |n: i32, ctx| async move {
//!             for i in 0..n {
//!                 ctx.chunk(&i).await?;
//!             }
//!             ctx.end().await
//!         })
//!         .event("progress")
//!         .start()
//!         .await?;
//!
//!     client.wait_for_shutdown().await?;
//!     Ok(())
//! }
//! ```
35
36use std::collections::HashMap;
37use std::future::Future;
38use std::sync::Arc;
39
40use bytes::Bytes;
41use serde::de::DeserializeOwned;
42use tokio::sync::{oneshot, RwLock, Semaphore};
43use tokio::task::JoinHandle;
44use tokio_util::sync::CancellationToken;
45
46use crate::codec::MsgPackCodec;
47use crate::control::{build_init_message, write_stdout_line, ResponseType};
48use crate::error::{ProcwireError, Result};
49use crate::handler::{HandlerRegistry, HandlerResult, RequestContext};
50use crate::protocol::{flags, FrameBuffer, Header, ABORT_METHOD_ID};
51use crate::transport::{generate_pipe_path, PipeListener};
52use crate::writer::{spawn_writer_task, OutboundFrame, WriterConfig, WriterHandle};
53
/// Number of handler tasks allowed to run at once when the builder is not
/// given an explicit limit via [`ClientBuilder::max_concurrent_handlers`].
pub const DEFAULT_MAX_CONCURRENT_HANDLERS: usize = 256;
56
57/// Builder for configuring and creating a Procwire client.
58///
59/// Use the fluent API to register handlers and events, then call `start()`
60/// to begin the client lifecycle.
61pub struct ClientBuilder {
62    registry: HandlerRegistry,
63    writer_config: WriterConfig,
64    max_concurrent_handlers: usize,
65}
66
67impl ClientBuilder {
68    /// Create a new client builder.
69    pub fn new() -> Self {
70        Self {
71            registry: HandlerRegistry::new(),
72            writer_config: WriterConfig::default(),
73            max_concurrent_handlers: DEFAULT_MAX_CONCURRENT_HANDLERS,
74        }
75    }
76
77    /// Register a method handler with "result" response type.
78    ///
79    /// The handler receives deserialized payload and a context for responding.
80    pub fn handle<F, T, Fut>(mut self, method: &str, handler: F) -> Self
81    where
82        F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
83        T: DeserializeOwned + Send + 'static,
84        Fut: Future<Output = HandlerResult> + Send + 'static,
85    {
86        self.registry
87            .register(method, ResponseType::Result, handler);
88        self
89    }
90
91    /// Register a method handler with "stream" response type.
92    ///
93    /// Use `ctx.chunk()` to send stream chunks and `ctx.end()` to finish.
94    pub fn handle_stream<F, T, Fut>(mut self, method: &str, handler: F) -> Self
95    where
96        F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
97        T: DeserializeOwned + Send + 'static,
98        Fut: Future<Output = HandlerResult> + Send + 'static,
99    {
100        self.registry
101            .register(method, ResponseType::Stream, handler);
102        self
103    }
104
105    /// Register a method handler with "ack" response type.
106    ///
107    /// Use `ctx.ack()` to send the acknowledgment.
108    pub fn handle_ack<F, T, Fut>(mut self, method: &str, handler: F) -> Self
109    where
110        F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
111        T: DeserializeOwned + Send + 'static,
112        Fut: Future<Output = HandlerResult> + Send + 'static,
113    {
114        self.registry.register(method, ResponseType::Ack, handler);
115        self
116    }
117
118    /// Register an event that this client can emit.
119    ///
120    /// Events are fire-and-forget messages to the parent.
121    pub fn event(mut self, name: &str) -> Self {
122        self.registry.register_event(name);
123        self
124    }
125
126    /// Set the maximum number of concurrent handlers.
127    ///
128    /// When this limit is reached, new requests will be dropped with a warning.
129    /// Default: 256
130    pub fn max_concurrent_handlers(mut self, limit: usize) -> Self {
131        self.max_concurrent_handlers = limit;
132        self
133    }
134
135    /// Set the maximum pending frames for backpressure.
136    ///
137    /// When this limit is reached, response methods will wait until
138    /// backpressure clears or timeout.
139    /// Default: 1024
140    pub fn max_pending_frames(mut self, limit: usize) -> Self {
141        self.writer_config.max_pending_frames = limit;
142        self
143    }
144
145    /// Set the writer channel capacity.
146    ///
147    /// Default: 1024
148    pub fn channel_capacity(mut self, capacity: usize) -> Self {
149        self.writer_config.channel_capacity = capacity;
150        self
151    }
152
153    /// Set the backpressure timeout.
154    ///
155    /// Default: 5 seconds
156    pub fn backpressure_timeout(mut self, timeout: std::time::Duration) -> Self {
157        self.writer_config.backpressure_timeout = timeout;
158        self
159    }
160
161    /// Build and start the client.
162    ///
163    /// This will:
164    /// 1. Generate pipe path
165    /// 2. Start pipe listener
166    /// 3. Send `$init` to parent (stdout)
167    /// 4. Accept parent connection
168    /// 5. Start frame processing loop
169    pub async fn start(self) -> Result<Client> {
170        Client::start(
171            self.registry,
172            self.writer_config,
173            self.max_concurrent_handlers,
174        )
175        .await
176    }
177}
178
179impl Default for ClientBuilder {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185/// Active request context holder for ABORT handling.
186struct ActiveContext {
187    /// Cancellation token to signal abort.
188    cancellation_token: CancellationToken,
189}
190
191/// A running Procwire client.
192///
193/// Use `emit()` to send events to the parent.
194/// Use `wait_for_shutdown()` to block until the connection closes.
195pub struct Client {
196    /// Registry of handlers.
197    registry: Arc<HandlerRegistry>,
198    /// Writer handle for sending frames.
199    writer: WriterHandle,
200    /// Shutdown signal receiver.
201    shutdown_rx: oneshot::Receiver<()>,
202    /// Writer task handle.
203    _writer_task: JoinHandle<Result<()>>,
204    /// Active request contexts for ABORT handling.
205    /// Maps request_id -> ActiveContext
206    _active_contexts: Arc<RwLock<HashMap<u32, ActiveContext>>>,
207}
208
209impl Client {
210    /// Create a new client builder.
211    pub fn builder() -> ClientBuilder {
212        ClientBuilder::new()
213    }
214
215    /// Start the client with the given registry and configuration.
216    async fn start(
217        registry: HandlerRegistry,
218        writer_config: WriterConfig,
219        max_concurrent_handlers: usize,
220    ) -> Result<Self> {
221        // 1. Generate pipe path
222        let pipe_path = generate_pipe_path();
223
224        // 2. Start pipe listener
225        let listener = PipeListener::bind(&pipe_path).await?;
226
227        // 3. Build schema from registry
228        let schema = registry.build_schema();
229
230        // 4. Send $init to parent (stdout)
231        let init_msg = build_init_message(&pipe_path, &schema);
232        write_stdout_line(&init_msg)?;
233
234        // 5. Accept parent connection
235        let stream = listener.accept().await?;
236
237        // 6. Split stream into reader and writer
238        let (reader, write_half) = stream.into_split();
239
240        // 7. Spawn writer task (replaces Arc<Mutex<Writer>>)
241        let (writer, writer_task) = spawn_writer_task(write_half, writer_config);
242
243        // 8. Create handler semaphore
244        let handler_semaphore = Arc::new(Semaphore::new(max_concurrent_handlers));
245
246        // 9. Create active contexts map for ABORT handling
247        let active_contexts = Arc::new(RwLock::new(HashMap::new()));
248
249        // 10. Spawn read loop
250        let (shutdown_tx, shutdown_rx) = oneshot::channel();
251        let registry = Arc::new(registry);
252        let writer_clone = writer.clone();
253        let registry_clone = registry.clone();
254        let active_contexts_clone = active_contexts.clone();
255
256        tokio::spawn(async move {
257            if let Err(e) = Self::read_loop(
258                reader,
259                registry_clone,
260                writer_clone,
261                handler_semaphore,
262                active_contexts_clone,
263            )
264            .await
265            {
266                tracing::error!("Read loop error: {}", e);
267            }
268            let _ = shutdown_tx.send(());
269        });
270
271        Ok(Client {
272            registry,
273            writer,
274            shutdown_rx,
275            _writer_task: writer_task,
276            _active_contexts: active_contexts,
277        })
278    }
279
280    /// Main read loop - reads frames and dispatches to handlers.
281    async fn read_loop<R: tokio::io::AsyncRead + Unpin>(
282        mut reader: R,
283        registry: Arc<HandlerRegistry>,
284        writer: WriterHandle,
285        semaphore: Arc<Semaphore>,
286        active_contexts: Arc<RwLock<HashMap<u32, ActiveContext>>>,
287    ) -> Result<()> {
288        use tokio::io::AsyncReadExt;
289
290        let mut frame_buffer = FrameBuffer::new();
291        let mut buf = vec![0u8; 64 * 1024]; // 64KB read buffer
292
293        loop {
294            let n = match reader.read(&mut buf).await {
295                Ok(0) => return Ok(()), // Connection closed
296                Ok(n) => n,
297                Err(e) => return Err(ProcwireError::Io(e)),
298            };
299
300            // Parse frames
301            let frames = frame_buffer.push(&buf[..n])?;
302
303            // Dispatch each frame
304            for frame in frames {
305                Self::dispatch_frame(&frame, &registry, &writer, &semaphore, &active_contexts)
306                    .await;
307            }
308        }
309    }
310
311    /// Dispatch a single frame to its handler.
312    async fn dispatch_frame(
313        frame: &crate::protocol::Frame,
314        registry: &Arc<HandlerRegistry>,
315        writer: &WriterHandle,
316        semaphore: &Arc<Semaphore>,
317        active_contexts: &Arc<RwLock<HashMap<u32, ActiveContext>>>,
318    ) {
319        let header = &frame.header;
320
321        // Handle ABORT signal
322        if header.method_id == ABORT_METHOD_ID {
323            tracing::debug!("Received ABORT for request {}", header.request_id);
324
325            // Find and cancel the active context
326            let contexts = active_contexts.read().await;
327            if let Some(ctx) = contexts.get(&header.request_id) {
328                ctx.cancellation_token.cancel();
329                tracing::debug!("Cancelled request {}", header.request_id);
330            } else {
331                tracing::warn!(
332                    "ABORT for unknown or completed request {}",
333                    header.request_id
334                );
335            }
336            return;
337        }
338
339        // Skip responses (we only handle requests)
340        if header.is_response() {
341            tracing::warn!("Received unexpected response frame");
342            return;
343        }
344
345        // Try to acquire semaphore permit
346        let permit = match semaphore.clone().try_acquire_owned() {
347            Ok(p) => p,
348            Err(_) => {
349                tracing::warn!(
350                    "Handler capacity reached, dropping request {} for method {}",
351                    header.request_id,
352                    header.method_id
353                );
354                return;
355            }
356        };
357
358        // Create cancellation token for this request
359        let cancellation_token = CancellationToken::new();
360
361        // Register active context for abort handling
362        {
363            let mut contexts = active_contexts.write().await;
364            contexts.insert(
365                header.request_id,
366                ActiveContext {
367                    cancellation_token: cancellation_token.clone(),
368                },
369            );
370        }
371
372        // Create context for handler with the cancellation token
373        let ctx = RequestContext::with_writer_and_token(
374            header.method_id,
375            header.request_id,
376            writer.clone(),
377            cancellation_token,
378        );
379
380        // Get payload
381        let payload = frame.payload.clone();
382
383        // Clone what we need for the spawned task
384        let registry = registry.clone();
385        let method_id = header.method_id;
386        let request_id = header.request_id;
387        let active_contexts = active_contexts.clone();
388
389        // Spawn handler task
390        tokio::spawn(async move {
391            // Permit is held until this task completes
392            let _permit = permit;
393
394            match registry.dispatch(method_id, &payload, ctx).await {
395                Ok(()) => {}
396                Err(e) => {
397                    tracing::error!("Handler error for method {}: {}", method_id, e);
398                }
399            }
400
401            // Remove from active contexts when handler completes
402            let mut contexts = active_contexts.write().await;
403            contexts.remove(&request_id);
404        });
405    }
406
407    /// Emit an event to the parent (fire-and-forget).
408    ///
409    /// Events are one-way messages that don't expect a response.
410    pub async fn emit<T: serde::Serialize>(&self, event: &str, data: &T) -> Result<()> {
411        let event_id = self
412            .registry
413            .get_event_id(event)
414            .ok_or_else(|| ProcwireError::Protocol(format!("Unknown event: {}", event)))?;
415
416        let payload = MsgPackCodec::encode(data)?;
417
418        let header = Header::new(
419            event_id,
420            flags::DIRECTION_TO_PARENT, // Event, not a response
421            0,                          // Events have request_id = 0
422            payload.len() as u32,
423        );
424
425        let frame = OutboundFrame::new(&header, Bytes::from(payload));
426        self.writer.send(frame).await
427    }
428
429    /// Emit an event with raw bytes payload.
430    pub async fn emit_raw(&self, event: &str, data: &[u8]) -> Result<()> {
431        let event_id = self
432            .registry
433            .get_event_id(event)
434            .ok_or_else(|| ProcwireError::Protocol(format!("Unknown event: {}", event)))?;
435
436        let header = Header::new(event_id, flags::DIRECTION_TO_PARENT, 0, data.len() as u32);
437
438        let frame = OutboundFrame::new(&header, Bytes::copy_from_slice(data));
439        self.writer.send(frame).await
440    }
441
442    /// Get the current backpressure status.
443    pub fn is_backpressure_active(&self) -> bool {
444        self.writer.is_backpressure_active()
445    }
446
447    /// Get the current pending frame count.
448    pub fn pending_frames(&self) -> usize {
449        self.writer.pending_count()
450    }
451
452    /// Wait for shutdown (pipe close or parent kill).
453    ///
454    /// This consumes the client and blocks until the connection closes.
455    pub async fn wait_for_shutdown(self) -> Result<()> {
456        let _ = self.shutdown_rx.await;
457        Ok(())
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use crate::protocol::{Frame, Header, ABORT_METHOD_ID};

    /// Build an empty-payload ABORT frame aimed at `request_id`.
    fn abort_frame_for(request_id: u32) -> Frame {
        let header = Header::new(ABORT_METHOD_ID, 0, request_id, 0);
        Frame::new(header, bytes::Bytes::new())
    }

    /// Yield to the runtime until `flag` becomes true, panicking with
    /// `failure_message` if that takes longer than 100 ms.
    async fn wait_for_flag(flag: &AtomicBool, failure_message: &str) {
        tokio::time::timeout(Duration::from_millis(100), async {
            while !flag.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect(failure_message);
    }

    #[test]
    fn test_builder_creation() {
        // Construction alone is the test: it must compile and not panic.
        let _builder = ClientBuilder::new();
    }

    #[test]
    fn test_builder_default() {
        let _builder = ClientBuilder::default();
    }

    #[test]
    fn test_builder_method_chaining() {
        let configured = Client::builder()
            .handle("echo", |_data: String, _ctx| async { Ok(()) })
            .handle_stream("stream", |_data: i32, _ctx| async { Ok(()) })
            .handle_ack("ack", |_data: (), _ctx| async { Ok(()) })
            .event("progress");

        // Everything registered above must show up in the generated schema.
        let schema = configured.registry.build_schema();
        for method in ["echo", "stream", "ack"] {
            assert!(schema.get_method(method).is_some());
        }
        assert!(schema.get_event("progress").is_some());
    }

    #[test]
    fn test_builder_response_types() {
        let configured = Client::builder()
            .handle("result", |_: (), _ctx| async { Ok(()) })
            .handle_stream("stream", |_: (), _ctx| async { Ok(()) })
            .handle_ack("ack", |_: (), _ctx| async { Ok(()) });

        let registry = &configured.registry;
        assert_eq!(
            registry.get_response_type("result"),
            Some(ResponseType::Result)
        );
        assert_eq!(
            registry.get_response_type("stream"),
            Some(ResponseType::Stream)
        );
        assert_eq!(registry.get_response_type("ack"), Some(ResponseType::Ack));
    }

    #[test]
    fn test_builder_configuration() {
        let ten_seconds = std::time::Duration::from_secs(10);
        let configured = Client::builder()
            .max_concurrent_handlers(512)
            .max_pending_frames(2048)
            .channel_capacity(512)
            .backpressure_timeout(ten_seconds);

        assert_eq!(configured.max_concurrent_handlers, 512);

        let writer_config = &configured.writer_config;
        assert_eq!(writer_config.max_pending_frames, 2048);
        assert_eq!(writer_config.channel_capacity, 512);
        assert_eq!(writer_config.backpressure_timeout, ten_seconds);
    }

    #[tokio::test]
    async fn test_abort_cancels_active_handler() {
        // An in-flight request with id 42 whose token we can observe.
        let token = CancellationToken::new();
        let active_contexts = Arc::new(RwLock::new(HashMap::new()));
        active_contexts.write().await.insert(
            42, // request_id
            ActiveContext {
                cancellation_token: token.clone(),
            },
        );

        // Nothing has aborted it yet.
        assert!(!token.is_cancelled());

        // Minimal collaborators required by dispatch_frame.
        let registry = Arc::new(HandlerRegistry::new());
        let (pipe_end, _peer) = tokio::io::duplex(4096);
        let (writer, _writer_task) =
            crate::writer::spawn_writer_task(pipe_end, crate::writer::WriterConfig::default());
        let semaphore = Arc::new(Semaphore::new(256));

        // Deliver the ABORT for request 42.
        let abort_frame = abort_frame_for(42);
        Client::dispatch_frame(
            &abort_frame,
            &registry,
            &writer,
            &semaphore,
            &active_contexts,
        )
        .await;

        // The token registered for request 42 must now be cancelled.
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_abort_for_unknown_request_logs_warning() {
        // No requests are in flight.
        let active_contexts = Arc::new(RwLock::new(HashMap::new()));

        // Minimal collaborators required by dispatch_frame.
        let registry = Arc::new(HandlerRegistry::new());
        let (pipe_end, _peer) = tokio::io::duplex(4096);
        let (writer, _writer_task) =
            crate::writer::spawn_writer_task(pipe_end, crate::writer::WriterConfig::default());
        let semaphore = Arc::new(Semaphore::new(256));

        // ABORT for an id nobody registered: must only log, never panic.
        let abort_frame = abort_frame_for(999);
        Client::dispatch_frame(
            &abort_frame,
            &registry,
            &writer,
            &semaphore,
            &active_contexts,
        )
        .await;

        // Reaching this point without a panic is the whole assertion.
    }

    #[tokio::test]
    async fn test_handler_context_is_removed_after_completion() {
        let active_contexts = Arc::new(RwLock::new(HashMap::new()));

        // Flags the handler flips so the test can follow its progress.
        let handler_started = Arc::new(AtomicBool::new(false));
        let handler_completed = Arc::new(AtomicBool::new(false));
        let started_for_handler = Arc::clone(&handler_started);
        let completed_for_handler = Arc::clone(&handler_completed);

        // Registry with a single handler that reports start and completion.
        let mut registry = HandlerRegistry::new();
        registry.register(
            "test",
            crate::control::ResponseType::Result,
            move |_: (), ctx: RequestContext| {
                let started = Arc::clone(&started_for_handler);
                let completed = Arc::clone(&completed_for_handler);
                async move {
                    started.store(true, Ordering::SeqCst);
                    // Stay in flight briefly so the test can inspect active_contexts.
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    ctx.respond(&"done").await?;
                    completed.store(true, Ordering::SeqCst);
                    Ok(())
                }
            },
        );
        let registry = Arc::new(registry);

        // Writer backed by an in-memory duplex pipe; the peer end is kept alive.
        let (pipe_end, _peer) = tokio::io::duplex(4096);
        let (writer, _writer_task) =
            crate::writer::spawn_writer_task(pipe_end, crate::writer::WriterConfig::default());
        let semaphore = Arc::new(Semaphore::new(256));

        // Request frame for "test" (id 123) carrying the MsgPack encoding of ().
        let method_id = registry.get_method_id("test").unwrap();
        let encoded_unit = crate::codec::MsgPackCodec::encode(&()).unwrap();
        let request_header = Header::new(method_id, 0, 123, encoded_unit.len() as u32);
        let request_frame = Frame::new(request_header, bytes::Bytes::from(encoded_unit));

        Client::dispatch_frame(
            &request_frame,
            &registry,
            &writer,
            &semaphore,
            &active_contexts,
        )
        .await;

        wait_for_flag(&handler_started, "Handler should start").await;

        // While the handler is running its context must be registered.
        {
            let contexts = active_contexts.read().await;
            assert!(
                contexts.contains_key(&123),
                "Context should be active while handler runs"
            );
        }

        wait_for_flag(&handler_completed, "Handler should complete").await;

        // Leave a little room for the spawned task to run its cleanup.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // After completion the context must be gone again.
        {
            let contexts = active_contexts.read().await;
            assert!(
                !contexts.contains_key(&123),
                "Context should be removed after handler completes"
            );
        }
    }
}