procwire_client/handler/
context.rs1use bytes::Bytes;
49use tokio_util::sync::CancellationToken;
50
51use crate::codec::MsgPackCodec;
52use crate::error::Result;
53use crate::protocol::{flags, Header};
54use crate::writer::{OutboundFrame, WriterHandle};
55
56#[derive(Clone)]
73pub struct RequestContext {
74 method_id: u16,
76 request_id: u32,
78 writer: Option<WriterHandle>,
80 cancellation_token: CancellationToken,
82}
83
84impl RequestContext {
85 pub fn new(method_id: u16, request_id: u32) -> Self {
87 Self {
88 method_id,
89 request_id,
90 writer: None,
91 cancellation_token: CancellationToken::new(),
92 }
93 }
94
95 pub fn with_writer(method_id: u16, request_id: u32, writer: WriterHandle) -> Self {
97 Self {
98 method_id,
99 request_id,
100 writer: Some(writer),
101 cancellation_token: CancellationToken::new(),
102 }
103 }
104
105 pub(crate) fn with_writer_and_token(
109 method_id: u16,
110 request_id: u32,
111 writer: WriterHandle,
112 cancellation_token: CancellationToken,
113 ) -> Self {
114 Self {
115 method_id,
116 request_id,
117 writer: Some(writer),
118 cancellation_token,
119 }
120 }
121
122 #[inline]
124 pub fn method_id(&self) -> u16 {
125 self.method_id
126 }
127
128 #[inline]
130 pub fn request_id(&self) -> u32 {
131 self.request_id
132 }
133
134 #[inline]
150 pub fn is_cancelled(&self) -> bool {
151 self.cancellation_token.is_cancelled()
152 }
153
154 pub fn cancelled(&self) -> tokio_util::sync::WaitForCancellationFuture<'_> {
172 self.cancellation_token.cancelled()
173 }
174
175 pub fn cancellation_token(&self) -> CancellationToken {
191 self.cancellation_token.clone()
192 }
193
194 #[allow(dead_code)]
199 pub(crate) fn cancel(&self) {
200 self.cancellation_token.cancel();
201 }
202
203 pub async fn respond<T: serde::Serialize>(&self, payload: &T) -> Result<()> {
207 let data = MsgPackCodec::encode(payload)?;
208 self.send_frame(flags::RESPONSE, Bytes::from(data)).await
209 }
210
211 pub async fn respond_raw(&self, payload: &[u8]) -> Result<()> {
213 self.send_frame(flags::RESPONSE, Bytes::copy_from_slice(payload))
214 .await
215 }
216
217 pub async fn respond_bytes(&self, payload: Bytes) -> Result<()> {
219 self.send_frame(flags::RESPONSE, payload).await
220 }
221
222 pub async fn ack(&self) -> Result<()> {
224 self.send_frame_empty(flags::ACK_RESPONSE).await
225 }
226
227 pub async fn chunk<T: serde::Serialize>(&self, payload: &T) -> Result<()> {
231 let data = MsgPackCodec::encode(payload)?;
232 self.send_frame(flags::STREAM_CHUNK, Bytes::from(data))
233 .await
234 }
235
236 pub async fn chunk_raw(&self, payload: &[u8]) -> Result<()> {
238 self.send_frame(flags::STREAM_CHUNK, Bytes::copy_from_slice(payload))
239 .await
240 }
241
242 pub async fn chunk_bytes(&self, payload: Bytes) -> Result<()> {
244 self.send_frame(flags::STREAM_CHUNK, payload).await
245 }
246
247 pub async fn end(&self) -> Result<()> {
252 self.send_frame_empty(flags::STREAM_END_RESPONSE).await
254 }
255
256 pub async fn error(&self, message: &str) -> Result<()> {
260 let data = MsgPackCodec::encode(&message)?;
261 self.send_frame(flags::ERROR_RESPONSE, Bytes::from(data))
262 .await
263 }
264
265 async fn send_frame(&self, frame_flags: u8, payload: Bytes) -> Result<()> {
267 let writer = match &self.writer {
268 Some(w) => w,
269 None => {
270 return Ok(());
272 }
273 };
274
275 let header = Header::new(
276 self.method_id,
277 frame_flags,
278 self.request_id,
279 payload.len() as u32,
280 );
281
282 let frame = OutboundFrame::new(&header, payload);
283 writer.send(frame).await
284 }
285
286 async fn send_frame_empty(&self, frame_flags: u8) -> Result<()> {
288 let writer = match &self.writer {
289 Some(w) => w,
290 None => {
291 return Ok(());
293 }
294 };
295
296 let header = Header::new(self.method_id, frame_flags, self.request_id, 0);
297
298 let frame = OutboundFrame::empty(&header);
299 writer.send(frame).await
300 }
301}
302
303pub struct RawPayload(pub Bytes);
305
306impl RawPayload {
307 pub fn new(bytes: Bytes) -> Self {
309 Self(bytes)
310 }
311
312 pub fn as_bytes(&self) -> &[u8] {
314 &self.0
315 }
316
317 pub fn into_bytes(self) -> Bytes {
319 self.0
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::spawn_writer_task_default;
    use std::time::Duration;
    use tokio::io::duplex;

    /// Method id used by every context in these tests.
    const METHOD_ID: u16 = 1;
    /// Request id used by every context in these tests.
    const REQUEST_ID: u32 = 42;

    /// Context without a writer: responses are accepted and discarded.
    fn writerless_ctx() -> RequestContext {
        RequestContext::new(METHOD_ID, REQUEST_ID)
    }

    #[test]
    fn test_context_creation() {
        let ctx = writerless_ctx();
        assert_eq!((ctx.method_id(), ctx.request_id()), (1, 42));
    }

    #[tokio::test]
    async fn test_respond_without_writer() {
        let ctx = writerless_ctx();
        // With no writer the call must succeed as a no-op.
        let outcome = ctx.respond(&"test").await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_all_response_methods_without_writer() {
        let ctx = writerless_ctx();

        // Single-response family.
        ctx.respond(&"test").await.expect("respond");
        ctx.respond_raw(b"test").await.expect("respond_raw");
        ctx.respond_bytes(Bytes::from_static(b"test"))
            .await
            .expect("respond_bytes");
        ctx.ack().await.expect("ack");

        // Streaming family.
        ctx.chunk(&1i32).await.expect("chunk");
        ctx.chunk_raw(b"chunk").await.expect("chunk_raw");
        ctx.chunk_bytes(Bytes::from_static(b"chunk"))
            .await
            .expect("chunk_bytes");
        ctx.end().await.expect("end");

        // Error reporting.
        ctx.error("error message").await.expect("error");
    }

    #[tokio::test]
    async fn test_chunk_allows_multiple_calls() {
        let ctx = writerless_ctx();

        // Several chunks in a row, then the end marker.
        for value in [1i32, 2, 3] {
            assert!(ctx.chunk(&value).await.is_ok());
        }
        assert!(ctx.end().await.is_ok());
    }

    #[tokio::test]
    async fn test_end_can_be_called_after_chunks() {
        let ctx = writerless_ctx();

        for part in ["first", "second"] {
            ctx.chunk(&part).await.unwrap();
        }
        ctx.end().await.unwrap();
    }

    #[test]
    fn test_context_is_clone() {
        let original = writerless_ctx();
        let copy = original.clone();

        assert_eq!(copy.method_id(), original.method_id());
        assert_eq!(copy.request_id(), original.request_id());
    }

    #[test]
    fn test_raw_payload() {
        let bytes = Bytes::from_static(b"hello world");
        let wrapped = RawPayload::new(bytes.clone());

        assert_eq!(wrapped.as_bytes(), b"hello world");
        assert_eq!(wrapped.into_bytes(), bytes);
    }

    #[tokio::test]
    async fn test_context_with_writer() {
        // The far end and the writer task are bound (not `_`) so that they
        // stay alive for the whole test.
        let (near_end, _far_end) = duplex(4096);
        let (writer_handle, _writer_task) = spawn_writer_task_default(near_end);

        let ctx = RequestContext::with_writer(METHOD_ID, REQUEST_ID, writer_handle);

        assert!(ctx.respond(&"hello").await.is_ok());
        assert!(ctx.chunk(&123i32).await.is_ok());
        assert!(ctx.end().await.is_ok());
    }

    #[test]
    fn test_cancellation_token_initially_not_cancelled() {
        assert!(!writerless_ctx().is_cancelled());
    }

    #[test]
    fn test_cancellation_after_cancel() {
        let ctx = writerless_ctx();
        assert!(!ctx.is_cancelled());

        ctx.cancel();

        assert!(ctx.is_cancelled());
    }

    #[test]
    fn test_cancellation_propagates_to_clones() {
        let ctx = writerless_ctx();
        let sibling = ctx.clone();

        assert!(!ctx.is_cancelled());
        assert!(!sibling.is_cancelled());

        ctx.cancel();

        // Both handles share one token, so both must report cancellation.
        assert!(ctx.is_cancelled());
        assert!(sibling.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancelled_future_completes_after_cancel() {
        let ctx = writerless_ctx();
        let canceller = ctx.clone();

        // Cancel from another task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            canceller.cancel();
        });

        let waited = tokio::time::timeout(Duration::from_millis(100), ctx.cancelled()).await;
        waited.expect("cancelled() should complete within timeout");
    }

    #[tokio::test]
    async fn test_cancellation_token_can_be_passed_to_child_task() {
        let ctx = writerless_ctx();
        let token = ctx.cancellation_token();

        // The child only sees the token, not the context itself.
        let child = tokio::spawn(async move {
            tokio::select! {
                _ = token.cancelled() => "cancelled",
                _ = tokio::time::sleep(Duration::from_secs(10)) => "timeout",
            }
        });

        tokio::time::sleep(Duration::from_millis(10)).await;
        ctx.cancel();

        let outcome = tokio::time::timeout(Duration::from_millis(100), child)
            .await
            .expect("task should complete")
            .expect("task should not panic");

        assert_eq!(outcome, "cancelled");
    }

    #[tokio::test]
    async fn test_with_writer_and_token() {
        let (near_end, _far_end) = duplex(4096);
        let (writer_handle, _writer_task) = spawn_writer_task_default(near_end);

        let external = CancellationToken::new();
        let ctx = RequestContext::with_writer_and_token(
            METHOD_ID,
            REQUEST_ID,
            writer_handle,
            external.clone(),
        );

        assert!(!ctx.is_cancelled());

        // Cancelling the caller's token must be visible through the context.
        external.cancel();

        assert!(ctx.is_cancelled());
    }
}