Skip to main content

procwire_client/handler/
context.rs

1//! Request context for handlers.
2//!
3//! Provides methods for responding to requests:
4//! - `respond` - send a single response
5//! - `ack` - send acknowledgment
6//! - `chunk` - send a stream chunk
7//! - `end` - end a stream (empty payload)
8//! - `error` - send an error response
9//!
10//! # Cancellation Support
11//!
12//! Handlers can check for and respond to ABORT signals from the parent:
13//! - `is_cancelled()` - check if request was aborted
14//! - `cancelled().await` - wait for cancellation (use with `tokio::select!`)
15//! - `cancellation_token()` - get token for child tasks
16//!
17//! # Example
18//!
19//! ```ignore
20//! async fn echo_handler(data: String, ctx: RequestContext) -> Result<()> {
21//!     ctx.respond(&data).await
22//! }
23//!
24//! async fn stream_handler(count: i32, ctx: RequestContext) -> Result<()> {
25//!     for i in 0..count {
26//!         // Check for cancellation
27//!         if ctx.is_cancelled() {
28//!             return Ok(());
29//!         }
30//!         ctx.chunk(&i).await?;
31//!     }
32//!     ctx.end().await
33//! }
34//!
35//! async fn long_task(data: Input, ctx: RequestContext) -> Result<()> {
36//!     tokio::select! {
37//!         _ = ctx.cancelled() => {
38//!             // Request was aborted, clean up
39//!             return Ok(());
40//!         }
41//!         result = do_work() => {
42//!             ctx.respond(&result).await
43//!         }
44//!     }
45//! }
46//! ```
47
48use bytes::Bytes;
49use tokio_util::sync::CancellationToken;
50
51use crate::codec::MsgPackCodec;
52use crate::error::Result;
53use crate::protocol::{flags, Header};
54use crate::writer::{OutboundFrame, WriterHandle};
55
56/// Context passed to request handlers.
57///
58/// Provides methods for sending responses back to the parent.
59/// All response methods handle serialization and frame building internally.
60///
61/// # Thread Safety
62///
63/// `RequestContext` is `Clone` and can be safely shared across async tasks.
64/// The underlying writer uses a channel-based architecture that eliminates
65/// lock contention.
66///
67/// # Cancellation
68///
69/// Each context has a [`CancellationToken`] that is triggered when the parent
70/// sends an ABORT signal. Handlers should check `is_cancelled()` periodically
71/// or use `cancelled().await` with `tokio::select!` for immediate response.
72#[derive(Clone)]
73pub struct RequestContext {
74    /// Method ID for this request.
75    method_id: u16,
76    /// Request ID for this request (0 = event).
77    request_id: u32,
78    /// Writer handle for sending responses.
79    writer: Option<WriterHandle>,
80    /// Cancellation token for ABORT handling.
81    cancellation_token: CancellationToken,
82}
83
84impl RequestContext {
85    /// Create a new request context (for testing without writer).
86    pub fn new(method_id: u16, request_id: u32) -> Self {
87        Self {
88            method_id,
89            request_id,
90            writer: None,
91            cancellation_token: CancellationToken::new(),
92        }
93    }
94
95    /// Create a new request context with a writer.
96    pub fn with_writer(method_id: u16, request_id: u32, writer: WriterHandle) -> Self {
97        Self {
98            method_id,
99            request_id,
100            writer: Some(writer),
101            cancellation_token: CancellationToken::new(),
102        }
103    }
104
105    /// Create a new request context with a writer and cancellation token.
106    ///
107    /// Used internally when tracking active contexts for ABORT handling.
108    pub(crate) fn with_writer_and_token(
109        method_id: u16,
110        request_id: u32,
111        writer: WriterHandle,
112        cancellation_token: CancellationToken,
113    ) -> Self {
114        Self {
115            method_id,
116            request_id,
117            writer: Some(writer),
118            cancellation_token,
119        }
120    }
121
122    /// Get the method ID.
123    #[inline]
124    pub fn method_id(&self) -> u16 {
125        self.method_id
126    }
127
128    /// Get the request ID.
129    #[inline]
130    pub fn request_id(&self) -> u32 {
131        self.request_id
132    }
133
134    /// Check if this request has been cancelled.
135    ///
136    /// Handlers should check this periodically during long operations.
137    ///
138    /// # Example
139    ///
140    /// ```ignore
141    /// for i in 0..1000 {
142    ///     if ctx.is_cancelled() {
143    ///         tracing::info!("Request cancelled at step {}", i);
144    ///         return Ok(());
145    ///     }
146    ///     do_step(i).await;
147    /// }
148    /// ```
149    #[inline]
150    pub fn is_cancelled(&self) -> bool {
151        self.cancellation_token.is_cancelled()
152    }
153
154    /// Wait for cancellation.
155    ///
156    /// Use with `tokio::select!` to handle cancellation immediately:
157    ///
158    /// # Example
159    ///
160    /// ```ignore
161    /// tokio::select! {
162    ///     _ = ctx.cancelled() => {
163    ///         // Request was cancelled, clean up
164    ///         return Ok(());
165    ///     }
166    ///     result = do_work() => {
167    ///         ctx.respond(&result).await
168    ///     }
169    /// }
170    /// ```
171    pub fn cancelled(&self) -> tokio_util::sync::WaitForCancellationFuture<'_> {
172        self.cancellation_token.cancelled()
173    }
174
175    /// Get the cancellation token for advanced use cases.
176    ///
177    /// Useful when you need to pass the token to child tasks.
178    ///
179    /// # Example
180    ///
181    /// ```ignore
182    /// let token = ctx.cancellation_token();
183    /// let handle = tokio::spawn(async move {
184    ///     tokio::select! {
185    ///         _ = token.cancelled() => None,
186    ///         result = do_work() => Some(result),
187    ///     }
188    /// });
189    /// ```
190    pub fn cancellation_token(&self) -> CancellationToken {
191        self.cancellation_token.clone()
192    }
193
194    /// Cancel this request (internal use).
195    ///
196    /// Called when an ABORT frame is received for this request.
197    /// Currently used only in tests, but kept for potential future use.
198    #[allow(dead_code)]
199    pub(crate) fn cancel(&self) {
200        self.cancellation_token.cancel();
201    }
202
203    /// Send a response with the given payload.
204    ///
205    /// Serializes the payload using MsgPack and sends a response frame.
206    pub async fn respond<T: serde::Serialize>(&self, payload: &T) -> Result<()> {
207        let data = MsgPackCodec::encode(payload)?;
208        self.send_frame(flags::RESPONSE, Bytes::from(data)).await
209    }
210
211    /// Send a response with raw bytes (zero-copy).
212    pub async fn respond_raw(&self, payload: &[u8]) -> Result<()> {
213        self.send_frame(flags::RESPONSE, Bytes::copy_from_slice(payload))
214            .await
215    }
216
217    /// Send a response with pre-allocated Bytes (zero-copy).
218    pub async fn respond_bytes(&self, payload: Bytes) -> Result<()> {
219        self.send_frame(flags::RESPONSE, payload).await
220    }
221
222    /// Send an acknowledgment (empty payload).
223    pub async fn ack(&self) -> Result<()> {
224        self.send_frame_empty(flags::ACK_RESPONSE).await
225    }
226
227    /// Send a stream chunk.
228    ///
229    /// Serializes the payload using MsgPack and sends a stream chunk frame.
230    pub async fn chunk<T: serde::Serialize>(&self, payload: &T) -> Result<()> {
231        let data = MsgPackCodec::encode(payload)?;
232        self.send_frame(flags::STREAM_CHUNK, Bytes::from(data))
233            .await
234    }
235
236    /// Send a stream chunk with raw bytes.
237    pub async fn chunk_raw(&self, payload: &[u8]) -> Result<()> {
238        self.send_frame(flags::STREAM_CHUNK, Bytes::copy_from_slice(payload))
239            .await
240    }
241
242    /// Send a stream chunk with pre-allocated Bytes (zero-copy).
243    pub async fn chunk_bytes(&self, payload: Bytes) -> Result<()> {
244        self.send_frame(flags::STREAM_CHUNK, payload).await
245    }
246
247    /// End a stream.
248    ///
249    /// Sends a stream end frame with empty payload.
250    /// **IMPORTANT**: STREAM_END frames always have empty payload!
251    pub async fn end(&self) -> Result<()> {
252        // NOTE: STREAM_END always has empty payload (payloadLength=0)
253        self.send_frame_empty(flags::STREAM_END_RESPONSE).await
254    }
255
256    /// Send an error response.
257    ///
258    /// Serializes the error message and sends an error frame.
259    pub async fn error(&self, message: &str) -> Result<()> {
260        let data = MsgPackCodec::encode(&message)?;
261        self.send_frame(flags::ERROR_RESPONSE, Bytes::from(data))
262            .await
263    }
264
265    /// Send a frame with the given flags and payload.
266    async fn send_frame(&self, frame_flags: u8, payload: Bytes) -> Result<()> {
267        let writer = match &self.writer {
268            Some(w) => w,
269            None => {
270                // No writer configured (testing mode)
271                return Ok(());
272            }
273        };
274
275        let header = Header::new(
276            self.method_id,
277            frame_flags,
278            self.request_id,
279            payload.len() as u32,
280        );
281
282        let frame = OutboundFrame::new(&header, payload);
283        writer.send(frame).await
284    }
285
286    /// Send a frame with empty payload.
287    async fn send_frame_empty(&self, frame_flags: u8) -> Result<()> {
288        let writer = match &self.writer {
289            Some(w) => w,
290            None => {
291                // No writer configured (testing mode)
292                return Ok(());
293            }
294        };
295
296        let header = Header::new(self.method_id, frame_flags, self.request_id, 0);
297
298        let frame = OutboundFrame::empty(&header);
299        writer.send(frame).await
300    }
301}
302
303/// Wrapper for Bytes payload (zero-copy).
304pub struct RawPayload(pub Bytes);
305
306impl RawPayload {
307    /// Create from bytes.
308    pub fn new(bytes: Bytes) -> Self {
309        Self(bytes)
310    }
311
312    /// Get the bytes.
313    pub fn as_bytes(&self) -> &[u8] {
314        &self.0
315    }
316
317    /// Into bytes.
318    pub fn into_bytes(self) -> Bytes {
319        self.0
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    use crate::writer::spawn_writer_task_default;
    use tokio::io::duplex;

    /// Writer-less context shared by most tests (method 1, request 42).
    fn bare_ctx() -> RequestContext {
        RequestContext::new(1, 42)
    }

    #[test]
    fn test_context_creation() {
        let context = bare_ctx();
        assert_eq!(context.method_id(), 1);
        assert_eq!(context.request_id(), 42);
    }

    #[tokio::test]
    async fn test_respond_without_writer() {
        let context = bare_ctx();
        // Without a writer the call must neither panic nor fail.
        let outcome = context.respond(&"test").await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_all_response_methods_without_writer() {
        let context = bare_ctx();

        // Single responses in all three payload forms.
        assert!(context.respond(&"test").await.is_ok());
        assert!(context.respond_raw(b"test").await.is_ok());
        assert!(context
            .respond_bytes(Bytes::from_static(b"test"))
            .await
            .is_ok());

        // Acknowledgment.
        assert!(context.ack().await.is_ok());

        // Streaming in all three payload forms, then stream end.
        assert!(context.chunk(&1i32).await.is_ok());
        assert!(context.chunk_raw(b"chunk").await.is_ok());
        assert!(context
            .chunk_bytes(Bytes::from_static(b"chunk"))
            .await
            .is_ok());
        assert!(context.end().await.is_ok());

        // Error response.
        assert!(context.error("error message").await.is_ok());
    }

    #[tokio::test]
    async fn test_chunk_allows_multiple_calls() {
        let context = bare_ctx();

        // Every chunk in a sequence must succeed, followed by the end frame.
        assert!(context.chunk(&1i32).await.is_ok());
        assert!(context.chunk(&2i32).await.is_ok());
        assert!(context.chunk(&3i32).await.is_ok());
        assert!(context.end().await.is_ok());
    }

    #[tokio::test]
    async fn test_end_can_be_called_after_chunks() {
        let context = bare_ctx();

        context.chunk(&"first").await.unwrap();
        context.chunk(&"second").await.unwrap();
        context.end().await.unwrap();
    }

    #[test]
    fn test_context_is_clone() {
        let original = bare_ctx();
        let duplicate = original.clone();

        assert_eq!(original.method_id(), duplicate.method_id());
        assert_eq!(original.request_id(), duplicate.request_id());
    }

    #[test]
    fn test_raw_payload() {
        let source = Bytes::from_static(b"hello world");
        let wrapped = RawPayload::new(source.clone());

        assert_eq!(wrapped.as_bytes(), b"hello world");
        assert_eq!(wrapped.into_bytes(), source);
    }

    #[tokio::test]
    async fn test_context_with_writer() {
        // Keep the peer end and the writer task bound so neither is dropped
        // before the assertions run.
        let (client, _server) = duplex(4096);
        let (writer_handle, _task) = spawn_writer_task_default(client);

        let context = RequestContext::with_writer(1, 42, writer_handle);

        // Sending through a real writer must succeed.
        assert!(context.respond(&"hello").await.is_ok());
        assert!(context.chunk(&123i32).await.is_ok());
        assert!(context.end().await.is_ok());
    }

    #[test]
    fn test_cancellation_token_initially_not_cancelled() {
        let context = bare_ctx();
        assert!(!context.is_cancelled());
    }

    #[test]
    fn test_cancellation_after_cancel() {
        let context = bare_ctx();
        assert!(!context.is_cancelled());

        context.cancel();

        assert!(context.is_cancelled());
    }

    #[test]
    fn test_cancellation_propagates_to_clones() {
        let original = bare_ctx();
        let duplicate = original.clone();

        assert!(!original.is_cancelled());
        assert!(!duplicate.is_cancelled());

        original.cancel();

        // The token is shared, so both handles observe the cancellation.
        assert!(original.is_cancelled());
        assert!(duplicate.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancelled_future_completes_after_cancel() {
        let waiter = bare_ctx();
        let canceller = waiter.clone();

        // Cancel from a background task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            canceller.cancel();
        });

        // The future must resolve once the background task has cancelled.
        tokio::time::timeout(Duration::from_millis(100), waiter.cancelled())
            .await
            .expect("cancelled() should complete within timeout");
    }

    #[tokio::test]
    async fn test_cancellation_token_can_be_passed_to_child_task() {
        let context = bare_ctx();
        let child_token = context.cancellation_token();

        // The child races cancellation against a long sleep.
        let child = tokio::spawn(async move {
            tokio::select! {
                _ = child_token.cancelled() => "cancelled",
                _ = tokio::time::sleep(Duration::from_secs(10)) => "timeout",
            }
        });

        // Give the child time to start, then cancel through the context.
        tokio::time::sleep(Duration::from_millis(10)).await;
        context.cancel();

        // The child must finish promptly via the cancellation branch.
        let observed = tokio::time::timeout(Duration::from_millis(100), child)
            .await
            .expect("task should complete")
            .expect("task should not panic");

        assert_eq!(observed, "cancelled");
    }

    #[tokio::test]
    async fn test_with_writer_and_token() {
        let (client, _server) = duplex(4096);
        let (writer_handle, _task) = spawn_writer_task_default(client);

        let external_token = CancellationToken::new();
        let context =
            RequestContext::with_writer_and_token(1, 42, writer_handle, external_token.clone());

        assert!(!context.is_cancelled());

        // Cancelling the externally held token must be visible on the context.
        external_token.cancel();

        assert!(context.is_cancelled());
    }
}