1use std::collections::HashMap;
37use std::future::Future;
38use std::sync::Arc;
39
40use bytes::Bytes;
41use serde::de::DeserializeOwned;
42use tokio::sync::{oneshot, RwLock, Semaphore};
43use tokio::task::JoinHandle;
44use tokio_util::sync::CancellationToken;
45
46use crate::codec::MsgPackCodec;
47use crate::control::{build_init_message, write_stdout_line, ResponseType};
48use crate::error::{ProcwireError, Result};
49use crate::handler::{HandlerRegistry, HandlerResult, RequestContext};
50use crate::protocol::{flags, FrameBuffer, Header, ABORT_METHOD_ID};
51use crate::transport::{generate_pipe_path, PipeListener};
52use crate::writer::{spawn_writer_task, OutboundFrame, WriterConfig, WriterHandle};
53
/// Handler concurrency limit that [`ClientBuilder::new`] starts with; the
/// value is later used as the permit count of the handler semaphore.
pub const DEFAULT_MAX_CONCURRENT_HANDLERS: usize = 256;
56
57pub struct ClientBuilder {
62 registry: HandlerRegistry,
63 writer_config: WriterConfig,
64 max_concurrent_handlers: usize,
65}
66
67impl ClientBuilder {
68 pub fn new() -> Self {
70 Self {
71 registry: HandlerRegistry::new(),
72 writer_config: WriterConfig::default(),
73 max_concurrent_handlers: DEFAULT_MAX_CONCURRENT_HANDLERS,
74 }
75 }
76
77 pub fn handle<F, T, Fut>(mut self, method: &str, handler: F) -> Self
81 where
82 F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
83 T: DeserializeOwned + Send + 'static,
84 Fut: Future<Output = HandlerResult> + Send + 'static,
85 {
86 self.registry
87 .register(method, ResponseType::Result, handler);
88 self
89 }
90
91 pub fn handle_stream<F, T, Fut>(mut self, method: &str, handler: F) -> Self
95 where
96 F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
97 T: DeserializeOwned + Send + 'static,
98 Fut: Future<Output = HandlerResult> + Send + 'static,
99 {
100 self.registry
101 .register(method, ResponseType::Stream, handler);
102 self
103 }
104
105 pub fn handle_ack<F, T, Fut>(mut self, method: &str, handler: F) -> Self
109 where
110 F: Fn(T, RequestContext) -> Fut + Send + Sync + 'static,
111 T: DeserializeOwned + Send + 'static,
112 Fut: Future<Output = HandlerResult> + Send + 'static,
113 {
114 self.registry.register(method, ResponseType::Ack, handler);
115 self
116 }
117
118 pub fn event(mut self, name: &str) -> Self {
122 self.registry.register_event(name);
123 self
124 }
125
126 pub fn max_concurrent_handlers(mut self, limit: usize) -> Self {
131 self.max_concurrent_handlers = limit;
132 self
133 }
134
135 pub fn max_pending_frames(mut self, limit: usize) -> Self {
141 self.writer_config.max_pending_frames = limit;
142 self
143 }
144
145 pub fn channel_capacity(mut self, capacity: usize) -> Self {
149 self.writer_config.channel_capacity = capacity;
150 self
151 }
152
153 pub fn backpressure_timeout(mut self, timeout: std::time::Duration) -> Self {
157 self.writer_config.backpressure_timeout = timeout;
158 self
159 }
160
161 pub async fn start(self) -> Result<Client> {
170 Client::start(
171 self.registry,
172 self.writer_config,
173 self.max_concurrent_handlers,
174 )
175 .await
176 }
177}
178
179impl Default for ClientBuilder {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
/// Bookkeeping entry stored per request id while its handler task is running.
struct ActiveContext {
    /// Token that is cancelled when an ABORT frame arrives for the request.
    cancellation_token: CancellationToken,
}
190
/// Handle to a started client: lets the caller emit events, inspect writer
/// backpressure and wait for the read loop to finish.
pub struct Client {
    /// Shared registry; consulted to resolve event names to ids when emitting.
    registry: Arc<HandlerRegistry>,
    /// Handle used to queue outbound frames on the writer task.
    writer: WriterHandle,
    /// Completes once the background read-loop task has exited.
    shutdown_rx: oneshot::Receiver<()>,
    /// Join handle of the writer task; stored but never read in this file.
    _writer_task: JoinHandle<Result<()>>,
    /// Request id -> in-flight context map shared with the read loop; stored
    /// here but never read through this field in this file.
    _active_contexts: Arc<RwLock<HashMap<u32, ActiveContext>>>,
}
208
209impl Client {
210 pub fn builder() -> ClientBuilder {
212 ClientBuilder::new()
213 }
214
215 async fn start(
217 registry: HandlerRegistry,
218 writer_config: WriterConfig,
219 max_concurrent_handlers: usize,
220 ) -> Result<Self> {
221 let pipe_path = generate_pipe_path();
223
224 let listener = PipeListener::bind(&pipe_path).await?;
226
227 let schema = registry.build_schema();
229
230 let init_msg = build_init_message(&pipe_path, &schema);
232 write_stdout_line(&init_msg)?;
233
234 let stream = listener.accept().await?;
236
237 let (reader, write_half) = stream.into_split();
239
240 let (writer, writer_task) = spawn_writer_task(write_half, writer_config);
242
243 let handler_semaphore = Arc::new(Semaphore::new(max_concurrent_handlers));
245
246 let active_contexts = Arc::new(RwLock::new(HashMap::new()));
248
249 let (shutdown_tx, shutdown_rx) = oneshot::channel();
251 let registry = Arc::new(registry);
252 let writer_clone = writer.clone();
253 let registry_clone = registry.clone();
254 let active_contexts_clone = active_contexts.clone();
255
256 tokio::spawn(async move {
257 if let Err(e) = Self::read_loop(
258 reader,
259 registry_clone,
260 writer_clone,
261 handler_semaphore,
262 active_contexts_clone,
263 )
264 .await
265 {
266 tracing::error!("Read loop error: {}", e);
267 }
268 let _ = shutdown_tx.send(());
269 });
270
271 Ok(Client {
272 registry,
273 writer,
274 shutdown_rx,
275 _writer_task: writer_task,
276 _active_contexts: active_contexts,
277 })
278 }
279
280 async fn read_loop<R: tokio::io::AsyncRead + Unpin>(
282 mut reader: R,
283 registry: Arc<HandlerRegistry>,
284 writer: WriterHandle,
285 semaphore: Arc<Semaphore>,
286 active_contexts: Arc<RwLock<HashMap<u32, ActiveContext>>>,
287 ) -> Result<()> {
288 use tokio::io::AsyncReadExt;
289
290 let mut frame_buffer = FrameBuffer::new();
291 let mut buf = vec![0u8; 64 * 1024]; loop {
294 let n = match reader.read(&mut buf).await {
295 Ok(0) => return Ok(()), Ok(n) => n,
297 Err(e) => return Err(ProcwireError::Io(e)),
298 };
299
300 let frames = frame_buffer.push(&buf[..n])?;
302
303 for frame in frames {
305 Self::dispatch_frame(&frame, ®istry, &writer, &semaphore, &active_contexts)
306 .await;
307 }
308 }
309 }
310
311 async fn dispatch_frame(
313 frame: &crate::protocol::Frame,
314 registry: &Arc<HandlerRegistry>,
315 writer: &WriterHandle,
316 semaphore: &Arc<Semaphore>,
317 active_contexts: &Arc<RwLock<HashMap<u32, ActiveContext>>>,
318 ) {
319 let header = &frame.header;
320
321 if header.method_id == ABORT_METHOD_ID {
323 tracing::debug!("Received ABORT for request {}", header.request_id);
324
325 let contexts = active_contexts.read().await;
327 if let Some(ctx) = contexts.get(&header.request_id) {
328 ctx.cancellation_token.cancel();
329 tracing::debug!("Cancelled request {}", header.request_id);
330 } else {
331 tracing::warn!(
332 "ABORT for unknown or completed request {}",
333 header.request_id
334 );
335 }
336 return;
337 }
338
339 if header.is_response() {
341 tracing::warn!("Received unexpected response frame");
342 return;
343 }
344
345 let permit = match semaphore.clone().try_acquire_owned() {
347 Ok(p) => p,
348 Err(_) => {
349 tracing::warn!(
350 "Handler capacity reached, dropping request {} for method {}",
351 header.request_id,
352 header.method_id
353 );
354 return;
355 }
356 };
357
358 let cancellation_token = CancellationToken::new();
360
361 {
363 let mut contexts = active_contexts.write().await;
364 contexts.insert(
365 header.request_id,
366 ActiveContext {
367 cancellation_token: cancellation_token.clone(),
368 },
369 );
370 }
371
372 let ctx = RequestContext::with_writer_and_token(
374 header.method_id,
375 header.request_id,
376 writer.clone(),
377 cancellation_token,
378 );
379
380 let payload = frame.payload.clone();
382
383 let registry = registry.clone();
385 let method_id = header.method_id;
386 let request_id = header.request_id;
387 let active_contexts = active_contexts.clone();
388
389 tokio::spawn(async move {
391 let _permit = permit;
393
394 match registry.dispatch(method_id, &payload, ctx).await {
395 Ok(()) => {}
396 Err(e) => {
397 tracing::error!("Handler error for method {}: {}", method_id, e);
398 }
399 }
400
401 let mut contexts = active_contexts.write().await;
403 contexts.remove(&request_id);
404 });
405 }
406
407 pub async fn emit<T: serde::Serialize>(&self, event: &str, data: &T) -> Result<()> {
411 let event_id = self
412 .registry
413 .get_event_id(event)
414 .ok_or_else(|| ProcwireError::Protocol(format!("Unknown event: {}", event)))?;
415
416 let payload = MsgPackCodec::encode(data)?;
417
418 let header = Header::new(
419 event_id,
420 flags::DIRECTION_TO_PARENT, 0, payload.len() as u32,
423 );
424
425 let frame = OutboundFrame::new(&header, Bytes::from(payload));
426 self.writer.send(frame).await
427 }
428
429 pub async fn emit_raw(&self, event: &str, data: &[u8]) -> Result<()> {
431 let event_id = self
432 .registry
433 .get_event_id(event)
434 .ok_or_else(|| ProcwireError::Protocol(format!("Unknown event: {}", event)))?;
435
436 let header = Header::new(event_id, flags::DIRECTION_TO_PARENT, 0, data.len() as u32);
437
438 let frame = OutboundFrame::new(&header, Bytes::copy_from_slice(data));
439 self.writer.send(frame).await
440 }
441
442 pub fn is_backpressure_active(&self) -> bool {
444 self.writer.is_backpressure_active()
445 }
446
447 pub fn pending_frames(&self) -> usize {
449 self.writer.pending_count()
450 }
451
452 pub async fn wait_for_shutdown(self) -> Result<()> {
456 let _ = self.shutdown_rx.await;
457 Ok(())
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// `ClientBuilder::new` constructs without panicking.
    #[test]
    fn test_builder_creation() {
        let builder = ClientBuilder::new();
        let _ = builder;
    }

    /// `ClientBuilder::default` constructs without panicking.
    #[test]
    fn test_builder_default() {
        let builder = ClientBuilder::default();
        let _ = builder;
    }

    /// Every registration method adds its entry to the schema.
    #[test]
    fn test_builder_method_chaining() {
        let builder = Client::builder()
            .handle("echo", |_data: String, _ctx| async { Ok(()) })
            .handle_stream("stream", |_data: i32, _ctx| async { Ok(()) })
            .handle_ack("ack", |_data: (), _ctx| async { Ok(()) })
            .event("progress");

        let schema = builder.registry.build_schema();
        assert!(schema.get_method("echo").is_some());
        assert!(schema.get_method("stream").is_some());
        assert!(schema.get_method("ack").is_some());
        assert!(schema.get_event("progress").is_some());
    }

    /// Each `handle*` variant records the matching response type.
    #[test]
    fn test_builder_response_types() {
        let builder = Client::builder()
            .handle("result", |_: (), _ctx| async { Ok(()) })
            .handle_stream("stream", |_: (), _ctx| async { Ok(()) })
            .handle_ack("ack", |_: (), _ctx| async { Ok(()) });

        assert_eq!(
            builder.registry.get_response_type("result"),
            Some(ResponseType::Result)
        );
        assert_eq!(
            builder.registry.get_response_type("stream"),
            Some(ResponseType::Stream)
        );
        assert_eq!(
            builder.registry.get_response_type("ack"),
            Some(ResponseType::Ack)
        );
    }

    /// Configuration setters store their values on the builder.
    #[test]
    fn test_builder_configuration() {
        let builder = Client::builder()
            .max_concurrent_handlers(512)
            .max_pending_frames(2048)
            .channel_capacity(512)
            .backpressure_timeout(std::time::Duration::from_secs(10));

        assert_eq!(builder.max_concurrent_handlers, 512);
        assert_eq!(builder.writer_config.max_pending_frames, 2048);
        assert_eq!(builder.writer_config.channel_capacity, 512);
        assert_eq!(
            builder.writer_config.backpressure_timeout,
            std::time::Duration::from_secs(10)
        );
    }

    /// An ABORT frame cancels the token registered for its request id.
    #[tokio::test]
    async fn test_abort_cancels_active_handler() {
        use crate::protocol::{Frame, Header, ABORT_METHOD_ID};

        let active_contexts = Arc::new(RwLock::new(HashMap::new()));
        let cancellation_token = CancellationToken::new();

        {
            let mut contexts = active_contexts.write().await;
            contexts.insert(
                42,
                ActiveContext {
                    cancellation_token: cancellation_token.clone(),
                },
            );
        }

        assert!(!cancellation_token.is_cancelled());

        let abort_header = Header::new(ABORT_METHOD_ID, 0, 42, 0);
        let abort_frame = Frame::new(abort_header, bytes::Bytes::new());

        let registry = Arc::new(HandlerRegistry::new());
        let (client, _server) = tokio::io::duplex(4096);
        let (writer, _task) =
            crate::writer::spawn_writer_task(client, crate::writer::WriterConfig::default());
        let semaphore = Arc::new(Semaphore::new(256));

        Client::dispatch_frame(
            &abort_frame,
            &registry,
            &writer,
            &semaphore,
            &active_contexts,
        )
        .await;

        assert!(cancellation_token.is_cancelled());
    }

    /// An ABORT for an untracked request id is ignored: it must not panic and
    /// must not create a context entry.
    #[tokio::test]
    async fn test_abort_for_unknown_request_logs_warning() {
        use crate::protocol::{Frame, Header, ABORT_METHOD_ID};

        let active_contexts = Arc::new(RwLock::new(HashMap::new()));

        let abort_header = Header::new(ABORT_METHOD_ID, 0, 999, 0);
        let abort_frame = Frame::new(abort_header, bytes::Bytes::new());

        let registry = Arc::new(HandlerRegistry::new());
        let (client, _server) = tokio::io::duplex(4096);
        let (writer, _task) =
            crate::writer::spawn_writer_task(client, crate::writer::WriterConfig::default());
        let semaphore = Arc::new(Semaphore::new(256));

        Client::dispatch_frame(
            &abort_frame,
            &registry,
            &writer,
            &semaphore,
            &active_contexts,
        )
        .await;

        assert!(
            active_contexts.read().await.is_empty(),
            "ABORT for an unknown request must not register a context"
        );
    }

    /// The context entry exists while the handler runs and is removed once the
    /// handler has returned.
    #[tokio::test]
    async fn test_handler_context_is_removed_after_completion() {
        use crate::protocol::{Frame, Header};
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::time::Duration;

        let active_contexts = Arc::new(RwLock::new(HashMap::new()));
        let handler_started = Arc::new(AtomicBool::new(false));
        let handler_completed = Arc::new(AtomicBool::new(false));

        let handler_started_clone = handler_started.clone();
        let handler_completed_clone = handler_completed.clone();

        let mut registry = HandlerRegistry::new();
        registry.register(
            "test",
            crate::control::ResponseType::Result,
            move |_: (), ctx: RequestContext| {
                let started = handler_started_clone.clone();
                let completed = handler_completed_clone.clone();
                async move {
                    started.store(true, Ordering::SeqCst);
                    // Stay in flight long enough for the test to observe the entry.
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    ctx.respond(&"done").await?;
                    completed.store(true, Ordering::SeqCst);
                    Ok(())
                }
            },
        );

        let registry = Arc::new(registry);

        let (client, _server) = tokio::io::duplex(4096);
        let (writer, _task) =
            crate::writer::spawn_writer_task(client, crate::writer::WriterConfig::default());
        let semaphore = Arc::new(Semaphore::new(256));

        let method_id = registry.get_method_id("test").unwrap();

        let payload = crate::codec::MsgPackCodec::encode(&()).unwrap();
        let header = Header::new(method_id, 0, 123, payload.len() as u32);
        let frame = Frame::new(header, bytes::Bytes::from(payload));

        Client::dispatch_frame(&frame, &registry, &writer, &semaphore, &active_contexts).await;

        tokio::time::timeout(Duration::from_millis(100), async {
            while !handler_started.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("Handler should start");

        {
            let contexts = active_contexts.read().await;
            assert!(
                contexts.contains_key(&123),
                "Context should be active while handler runs"
            );
        }

        tokio::time::timeout(Duration::from_millis(100), async {
            while !handler_completed.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("Handler should complete");

        // Cleanup happens just after the handler returns; poll for it under a
        // timeout rather than guessing a fixed delay.
        tokio::time::timeout(Duration::from_millis(100), async {
            while active_contexts.read().await.contains_key(&123) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("Context should be removed after handler completes");

        {
            let contexts = active_contexts.read().await;
            assert!(
                !contexts.contains_key(&123),
                "Context should be removed after handler completes"
            );
        }
    }
}