1use asupersync::Cx;
26use asupersync::stream::Stream;
27use std::io::{self, Read, Seek, SeekFrom};
28use std::path::Path;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31
/// Default size of a single streamed chunk, in bytes (64 KiB).
pub const DEFAULT_CHUNK_SIZE: usize = 64 * 1024;

/// Default value of [`StreamConfig::max_buffer_size`], in bytes (4 MiB).
pub const DEFAULT_MAX_BUFFER_SIZE: usize = 4 * 1024 * 1024;

/// Tunable settings for the streaming types in this module.
///
/// Values are set through the consuming `with_*` builder methods and read
/// back through the accessor of the same name.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Bytes requested per chunk; never below 1024 when set via
    /// [`StreamConfig::with_chunk_size`].
    chunk_size: usize,
    /// Buffer ceiling in bytes. Stored as given; nothing visible in this file
    /// consults it yet.
    max_buffer_size: usize,
    /// Checkpoint toggle. Stored as given; nothing visible in this file
    /// consults it yet.
    checkpoint_enabled: bool,
}

impl Default for StreamConfig {
    /// Same as [`StreamConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl StreamConfig {
    /// Smallest chunk size [`StreamConfig::with_chunk_size`] will store.
    const MIN_CHUNK_SIZE: usize = 1024;

    /// Creates a configuration with [`DEFAULT_CHUNK_SIZE`],
    /// [`DEFAULT_MAX_BUFFER_SIZE`] and checkpoints enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            chunk_size: DEFAULT_CHUNK_SIZE,
            max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
            checkpoint_enabled: true,
        }
    }

    /// Sets the chunk size; requests below 1024 bytes are raised to 1024.
    #[must_use]
    pub fn with_chunk_size(self, size: usize) -> Self {
        Self {
            chunk_size: size.max(Self::MIN_CHUNK_SIZE),
            ..self
        }
    }

    /// Sets the maximum buffer size. The value is stored without validation.
    #[must_use]
    pub fn with_max_buffer_size(self, size: usize) -> Self {
        Self {
            max_buffer_size: size,
            ..self
        }
    }

    /// Enables or disables the checkpoint flag.
    #[must_use]
    pub fn with_checkpoint(self, enabled: bool) -> Self {
        Self {
            checkpoint_enabled: enabled,
            ..self
        }
    }

    /// Returns the configured chunk size in bytes.
    #[must_use]
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Returns the configured maximum buffer size in bytes.
    #[must_use]
    pub fn max_buffer_size(&self) -> usize {
        self.max_buffer_size
    }

    /// Returns whether the checkpoint flag is set.
    #[must_use]
    pub fn checkpoint_enabled(&self) -> bool {
        self.checkpoint_enabled
    }
}
105
/// Failure modes of a streaming operation.
#[derive(Debug)]
pub enum StreamError {
    /// An underlying I/O operation failed; carries the originating error.
    Io(io::Error),
    /// The stream was cancelled.
    Cancelled,
    /// The stream buffer is full.
    BufferFull,
}

impl std::fmt::Display for StreamError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(cause) => write!(f, "streaming I/O error: {cause}"),
            Self::Cancelled => f.write_str("stream cancelled"),
            Self::BufferFull => f.write_str("stream buffer full"),
        }
    }
}

impl std::error::Error for StreamError {
    /// Returns the wrapped [`io::Error`] for [`StreamError::Io`]; the other
    /// variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(cause) => Some(cause),
            Self::Cancelled | Self::BufferFull => None,
        }
    }
}

impl From<io::Error> for StreamError {
    /// Wraps an I/O error so `?` can convert it into a [`StreamError`].
    fn from(cause: io::Error) -> Self {
        StreamError::Io(cause)
    }
}
141
142pub struct CancelAwareStream<S> {
152 inner: S,
153 cx: Cx,
154 cancelled: bool,
155}
156
157impl<S> CancelAwareStream<S> {
158 pub fn new(inner: S, cx: Cx) -> Self {
160 Self {
161 inner,
162 cx,
163 cancelled: false,
164 }
165 }
166
167 #[must_use]
169 pub fn is_cancelled(&self) -> bool {
170 self.cancelled
171 }
172}
173
174impl<S> Stream for CancelAwareStream<S>
175where
176 S: Stream + Unpin,
177{
178 type Item = S::Item;
179
180 fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181 if self.cx.is_cancel_requested() {
183 self.cancelled = true;
184 return Poll::Ready(None);
185 }
186
187 Pin::new(&mut self.inner).poll_next(ctx)
189 }
190}
191
/// Internal state machine backing [`FileStream`].
enum FileStreamState {
    /// The file is open and may still have bytes to deliver.
    Active {
        /// Open handle that chunks are read from.
        file: std::fs::File,
        /// Scratch buffer the next chunk is read into before being handed out.
        buffer: Vec<u8>,
        /// Upper bound on the number of bytes still to be delivered.
        remaining: u64,
    },
    /// The stream has ended: all bytes were delivered, end-of-file was hit,
    /// or cancellation was requested.
    Complete,
    /// A read failed with a non-`Interrupted` I/O error; the stream yields
    /// nothing further.
    Error,
}
205
/// Stream that reads a file (or a byte range of it) and yields its contents
/// as `Vec<u8>` chunks.
///
/// Reads use the blocking `std::fs::File` API from inside `poll_next`.
pub struct FileStream {
    /// Current position in the read state machine; owns the open file.
    state: FileStreamState,
    /// Context checked for a cancellation request at the start of each poll.
    cx: Cx,
    /// Settings for this stream; `chunk_size` bounds the size of each read.
    config: StreamConfig,
}
232
233impl FileStream {
234 pub fn open<P: AsRef<Path>>(path: P, cx: Cx, config: StreamConfig) -> io::Result<Self> {
246 let mut file = std::fs::File::open(path)?;
247 let metadata = file.metadata()?;
248 let file_size = metadata.len();
249
250 file.seek(SeekFrom::Start(0))?;
252
253 let buffer = Vec::with_capacity(config.chunk_size);
254
255 Ok(Self {
256 state: FileStreamState::Active {
257 file,
258 buffer,
259 remaining: file_size,
260 },
261 cx,
262 config,
263 })
264 }
265
266 pub fn open_range<P: AsRef<Path>>(
282 path: P,
283 start: u64,
284 length: u64,
285 cx: Cx,
286 config: StreamConfig,
287 ) -> io::Result<Self> {
288 let mut file = std::fs::File::open(path)?;
289 file.seek(SeekFrom::Start(start))?;
290
291 let buffer = Vec::with_capacity(config.chunk_size);
292
293 Ok(Self {
294 state: FileStreamState::Active {
295 file,
296 buffer,
297 remaining: length,
298 },
299 cx,
300 config,
301 })
302 }
303
304 #[must_use]
306 pub fn remaining(&self) -> u64 {
307 match &self.state {
308 FileStreamState::Active { remaining, .. } => *remaining,
309 _ => 0,
310 }
311 }
312
313 #[must_use]
315 pub fn is_complete(&self) -> bool {
316 matches!(self.state, FileStreamState::Complete)
317 }
318}
319
320impl Stream for FileStream {
321 type Item = Vec<u8>;
322
323 fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324 if self.cx.is_cancel_requested() {
326 self.state = FileStreamState::Complete;
327 return Poll::Ready(None);
328 }
329
330 let chunk_size = self.config.chunk_size;
332
333 match &mut self.state {
334 FileStreamState::Active {
335 file,
336 buffer,
337 remaining,
338 } => {
339 if *remaining == 0 {
340 self.state = FileStreamState::Complete;
341 return Poll::Ready(None);
342 }
343
344 let to_read = (chunk_size as u64).min(*remaining) as usize;
346
347 buffer.clear();
349 buffer.resize(to_read, 0);
350
351 match file.read(&mut buffer[..to_read]) {
353 Ok(0) => {
354 self.state = FileStreamState::Complete;
356 Poll::Ready(None)
357 }
358 Ok(n) => {
359 *remaining -= n as u64;
360 buffer.truncate(n);
361
362 let chunk = std::mem::take(buffer);
364 *buffer = Vec::with_capacity(chunk_size);
365
366 Poll::Ready(Some(chunk))
367 }
368 Err(e) if e.kind() == io::ErrorKind::Interrupted => {
369 _ctx.waker().wake_by_ref();
371 Poll::Pending
372 }
373 Err(_) => {
374 self.state = FileStreamState::Error;
375 Poll::Ready(None)
376 }
377 }
378 }
379 FileStreamState::Complete | FileStreamState::Error => Poll::Ready(None),
380 }
381 }
382}
383
// Manually asserts that `FileStream` may be moved to another thread.
//
// NOTE(review): the fields visible in this file are `std::fs::File`,
// `Vec<u8>`, `u64`, `StreamConfig` (two `usize` and a `bool`) and `Cx`. All but
// `Cx` are `Send` on their own. If `Cx` is `Send`, this impl is redundant and
// the auto trait would apply; if it is not, this impl is only sound when `Cx`
// tolerates being used from another thread. That cannot be verified from this
// file — confirm against `asupersync::Cx`.
// The `allow` presumably overrides a crate-level `unsafe_code` lint — TODO confirm.
#[allow(unsafe_code)]
unsafe impl Send for FileStream {}
392
393pub struct ChunkedBytes {
397 data: Vec<u8>,
398 position: usize,
399 chunk_size: usize,
400}
401
402impl ChunkedBytes {
403 #[must_use]
405 pub fn new(data: Vec<u8>, chunk_size: usize) -> Self {
406 Self {
407 data,
408 position: 0,
409 chunk_size: chunk_size.max(1),
410 }
411 }
412
413 #[must_use]
415 pub fn with_default_chunks(data: Vec<u8>) -> Self {
416 Self::new(data, DEFAULT_CHUNK_SIZE)
417 }
418
419 #[must_use]
421 pub fn total_size(&self) -> usize {
422 self.data.len()
423 }
424
425 #[must_use]
427 pub fn remaining(&self) -> usize {
428 self.data.len().saturating_sub(self.position)
429 }
430}
431
432impl Stream for ChunkedBytes {
433 type Item = Vec<u8>;
434
435 fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
436 if self.position >= self.data.len() {
437 return Poll::Ready(None);
438 }
439
440 let end = (self.position + self.chunk_size).min(self.data.len());
441 let chunk = self.data[self.position..end].to_vec();
442 self.position = end;
443
444 Poll::Ready(Some(chunk))
445 }
446}
447
/// Constructors for responses whose body is streamed from a file.
///
/// Implemented in this file for `fastapi_core::Response`.
pub trait StreamingResponseExt {
    /// Builds a response that streams the whole file at `path`, using
    /// [`StreamConfig::default`].
    ///
    /// `content_type` is the raw value for the `content-type` header and `cx`
    /// is the context whose cancellation request stops the stream.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be opened for streaming.
    fn stream_file<P: AsRef<Path>>(
        path: P,
        cx: Cx,
        content_type: &[u8],
    ) -> io::Result<fastapi_core::Response>;

    /// Same as [`StreamingResponseExt::stream_file`], with an explicit
    /// streaming `config`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be opened for streaming.
    fn stream_file_with_config<P: AsRef<Path>>(
        path: P,
        cx: Cx,
        content_type: &[u8],
        config: StreamConfig,
    ) -> io::Result<fastapi_core::Response>;

    /// Builds a partial-content response that streams `range` of the file at
    /// `path`, using [`StreamConfig::default`].
    ///
    /// `total_size` is the size reported in the `content-range` header. The
    /// implementation in this file does not check `range` against the actual
    /// file; callers are presumably expected to validate it first — TODO
    /// confirm.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be opened or positioned at
    /// the start of the range.
    fn stream_file_range<P: AsRef<Path>>(
        path: P,
        range: crate::range::ByteRange,
        total_size: u64,
        cx: Cx,
        content_type: &[u8],
    ) -> io::Result<fastapi_core::Response>;

    /// Same as [`StreamingResponseExt::stream_file_range`], with an explicit
    /// streaming `config`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be opened or positioned at
    /// the start of the range.
    fn stream_file_range_with_config<P: AsRef<Path>>(
        path: P,
        range: crate::range::ByteRange,
        total_size: u64,
        cx: Cx,
        content_type: &[u8],
        config: StreamConfig,
    ) -> io::Result<fastapi_core::Response>;
}
516
517impl StreamingResponseExt for fastapi_core::Response {
518 fn stream_file<P: AsRef<Path>>(
519 path: P,
520 cx: Cx,
521 content_type: &[u8],
522 ) -> io::Result<fastapi_core::Response> {
523 Self::stream_file_with_config(path, cx, content_type, StreamConfig::default())
524 }
525
526 fn stream_file_with_config<P: AsRef<Path>>(
527 path: P,
528 cx: Cx,
529 content_type: &[u8],
530 config: StreamConfig,
531 ) -> io::Result<fastapi_core::Response> {
532 let stream = FileStream::open(path, cx, config)?;
533
534 Ok(fastapi_core::Response::ok()
535 .header("content-type", content_type.to_vec())
536 .header("accept-ranges", b"bytes".to_vec())
537 .body(fastapi_core::ResponseBody::stream(stream)))
538 }
539
540 fn stream_file_range<P: AsRef<Path>>(
541 path: P,
542 range: crate::range::ByteRange,
543 total_size: u64,
544 cx: Cx,
545 content_type: &[u8],
546 ) -> io::Result<fastapi_core::Response> {
547 Self::stream_file_range_with_config(
548 path,
549 range,
550 total_size,
551 cx,
552 content_type,
553 StreamConfig::default(),
554 )
555 }
556
557 fn stream_file_range_with_config<P: AsRef<Path>>(
558 path: P,
559 range: crate::range::ByteRange,
560 total_size: u64,
561 cx: Cx,
562 content_type: &[u8],
563 config: StreamConfig,
564 ) -> io::Result<fastapi_core::Response> {
565 let stream = FileStream::open_range(path, range.start, range.len(), cx, config)?;
566
567 Ok(fastapi_core::Response::partial_content()
568 .header("content-type", content_type.to_vec())
569 .header("accept-ranges", b"bytes".to_vec())
570 .header(
571 "content-range",
572 range.content_range_header(total_size).into_bytes(),
573 )
574 .header("content-length", range.len().to_string().into_bytes())
575 .body(fastapi_core::ResponseBody::stream(stream)))
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use std::task::{Wake, Waker};

    /// Waker that does nothing; the tests drive `poll_next` by hand.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWaker))
    }

    /// Fixture file in the system temp directory that is removed when the
    /// guard is dropped, including when an assertion panics.
    ///
    /// The process id is part of the file name so concurrent test processes
    /// do not overwrite each other's fixtures.
    struct TempFile {
        path: PathBuf,
    }

    impl TempFile {
        /// Writes `contents` to a fresh fixture file derived from `name`.
        fn create(name: &str, contents: &[u8]) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{name}", std::process::id()));
            std::fs::write(&path, contents).expect("failed to write temp fixture file");
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test result.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Returns the value of the first header called `wanted`, lossily decoded
    /// as UTF-8.
    fn header_text(response: &fastapi_core::Response, wanted: &str) -> Option<String> {
        response
            .headers()
            .iter()
            .find(|(name, _)| name == wanted)
            .map(|(_, value)| String::from_utf8_lossy(value).to_string())
    }

    #[test]
    fn stream_config_defaults() {
        let config = StreamConfig::default();
        assert_eq!(config.chunk_size(), DEFAULT_CHUNK_SIZE);
        assert_eq!(config.max_buffer_size(), DEFAULT_MAX_BUFFER_SIZE);
        assert!(config.checkpoint_enabled());
    }

    #[test]
    fn stream_config_custom() {
        let config = StreamConfig::new()
            .with_chunk_size(1024)
            .with_max_buffer_size(2048)
            .with_checkpoint(false);

        assert_eq!(config.chunk_size(), 1024);
        assert_eq!(config.max_buffer_size(), 2048);
        assert!(!config.checkpoint_enabled());
    }

    #[test]
    fn stream_config_minimum_chunk_size() {
        // Requests below the 1024-byte floor are raised to it.
        let config = StreamConfig::new().with_chunk_size(100);
        assert_eq!(config.chunk_size(), 1024);
    }

    #[test]
    fn chunked_bytes_basic() {
        let mut stream = ChunkedBytes::new(b"Hello, World!".to_vec(), 5);

        assert_eq!(stream.total_size(), 13);
        assert_eq!(stream.remaining(), 13);

        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(Some(b"Hello".to_vec())));
        assert_eq!(stream.remaining(), 8);

        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(Some(b", Wor".to_vec())));

        // The final chunk is shorter than the chunk size.
        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(Some(b"ld!".to_vec())));

        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(None));
    }

    #[test]
    fn chunked_bytes_empty() {
        let mut stream = ChunkedBytes::new(Vec::new(), 5);
        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(None));
    }

    #[test]
    fn chunked_bytes_exact_chunk_size() {
        let mut stream = ChunkedBytes::new(b"12345".to_vec(), 5);

        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(Some(b"12345".to_vec())));

        // No empty trailing chunk when the data is an exact multiple.
        let chunk = Pin::new(&mut stream).poll_next(&mut ctx);
        assert_eq!(chunk, Poll::Ready(None));
    }

    #[test]
    fn cancel_aware_stream_propagates_items() {
        let inner = asupersync::stream::iter(vec![1, 2, 3]);
        let cx = Cx::for_testing();
        let mut stream = CancelAwareStream::new(inner, cx);

        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        assert_eq!(
            Pin::new(&mut stream).poll_next(&mut ctx),
            Poll::Ready(Some(1))
        );
        assert_eq!(
            Pin::new(&mut stream).poll_next(&mut ctx),
            Poll::Ready(Some(2))
        );
        assert_eq!(
            Pin::new(&mut stream).poll_next(&mut ctx),
            Poll::Ready(Some(3))
        );
        assert_eq!(Pin::new(&mut stream).poll_next(&mut ctx), Poll::Ready(None));

        // Natural end-of-stream is not cancellation.
        assert!(!stream.is_cancelled());
    }

    #[test]
    fn stream_error_display() {
        let err = StreamError::Cancelled;
        assert_eq!(format!("{err}"), "stream cancelled");

        let err = StreamError::BufferFull;
        assert_eq!(format!("{err}"), "stream buffer full");

        let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let err = StreamError::Io(io_err);
        assert!(format!("{err}").contains("streaming I/O error"));
    }

    #[test]
    fn stream_file_adds_accept_ranges_header() {
        let file = TempFile::create("test_stream_accept_ranges.txt", b"Hello, streaming world!");

        let cx = Cx::for_testing();
        let response =
            fastapi_core::Response::stream_file(file.path(), cx, b"text/plain").unwrap();

        assert_eq!(
            header_text(&response, "accept-ranges"),
            Some("bytes".to_string())
        );
    }

    #[test]
    fn stream_file_range_returns_206() {
        use crate::range::ByteRange;

        let file = TempFile::create("test_stream_range_206.txt", b"0123456789ABCDEF");

        let cx = Cx::for_testing();
        let range = ByteRange::new(0, 4);
        let response =
            fastapi_core::Response::stream_file_range(file.path(), range, 16, cx, b"text/plain")
                .unwrap();

        assert_eq!(response.status().as_u16(), 206);
    }

    #[test]
    fn stream_file_range_sets_content_range_header() {
        use crate::range::ByteRange;

        let file = TempFile::create("test_stream_content_range.txt", b"0123456789ABCDEF");

        let cx = Cx::for_testing();
        let range = ByteRange::new(5, 9);
        let response =
            fastapi_core::Response::stream_file_range(file.path(), range, 16, cx, b"text/plain")
                .unwrap();

        assert_eq!(
            header_text(&response, "content-range"),
            Some("bytes 5-9/16".to_string())
        );
    }

    #[test]
    fn stream_file_range_sets_content_length_header() {
        use crate::range::ByteRange;

        let file = TempFile::create("test_stream_content_length.txt", b"0123456789ABCDEF");

        let cx = Cx::for_testing();
        // The header reflects the requested range length, not the file size.
        let range = ByteRange::new(0, 99);
        let response =
            fastapi_core::Response::stream_file_range(file.path(), range, 16, cx, b"text/plain")
                .unwrap();

        assert_eq!(
            header_text(&response, "content-length"),
            Some("100".to_string())
        );
    }

    #[test]
    fn stream_large_response_in_chunks() {
        const TARGET_SIZE: usize = 10 * 1024 * 1024;
        const CHUNK_SIZE: usize = 64 * 1024;

        let data: Vec<u8> = (0..TARGET_SIZE).map(|i| (i % 256) as u8).collect();
        let mut stream = ChunkedBytes::new(data, CHUNK_SIZE);

        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        let mut total_received = 0usize;
        let mut chunk_count = 0usize;

        loop {
            match Pin::new(&mut stream).poll_next(&mut ctx) {
                Poll::Ready(Some(chunk)) => {
                    if total_received + CHUNK_SIZE <= TARGET_SIZE {
                        assert_eq!(
                            chunk.len(),
                            CHUNK_SIZE,
                            "Non-final chunks should be {CHUNK_SIZE} bytes"
                        );
                    }
                    total_received += chunk.len();
                    chunk_count += 1;
                }
                Poll::Ready(None) => break,
                Poll::Pending => panic!("ChunkedBytes should never return Pending"),
            }
        }

        assert_eq!(total_received, TARGET_SIZE, "Should receive all 10MB");
        let expected_chunks = TARGET_SIZE.div_ceil(CHUNK_SIZE);
        assert_eq!(
            chunk_count, expected_chunks,
            "Should have correct number of chunks"
        );
    }

    #[test]
    fn cancel_aware_stream_stops_on_cancellation() {
        let inner = asupersync::stream::iter(vec![1, 2, 3, 4, 5]);
        let cx = Cx::for_testing();

        // Request cancellation before the first poll.
        cx.set_cancel_requested(true);

        let mut stream = CancelAwareStream::new(inner, cx);

        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        assert_eq!(Pin::new(&mut stream).poll_next(&mut ctx), Poll::Ready(None));
        assert!(
            stream.is_cancelled(),
            "Stream should be marked as cancelled"
        );
    }

    #[test]
    fn file_stream_reads_complete_file() {
        const FILE_SIZE: usize = 256 * 1024;

        let data: Vec<u8> = (0..FILE_SIZE).map(|i| (i % 256) as u8).collect();
        let file = TempFile::create("test_file_stream_complete.bin", &data);

        let cx = Cx::for_testing();
        let config = StreamConfig::new().with_chunk_size(32 * 1024);
        let mut stream = FileStream::open(file.path(), cx, config).unwrap();

        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        let mut total_received = 0usize;
        let mut received_data = Vec::new();

        loop {
            match Pin::new(&mut stream).poll_next(&mut ctx) {
                Poll::Ready(Some(chunk)) => {
                    total_received += chunk.len();
                    received_data.extend(chunk);
                }
                Poll::Ready(None) => break,
                // Only returned after an interrupted read; just poll again.
                Poll::Pending => {}
            }
        }

        assert_eq!(total_received, FILE_SIZE, "Should receive complete file");
        assert_eq!(received_data, data, "Data should match original");
    }

    #[test]
    fn chunked_bytes_total_size_is_correct() {
        const SIZE: usize = 1024 * 100;

        let stream = ChunkedBytes::new(vec![0u8; SIZE], 1024);

        assert_eq!(
            stream.total_size(),
            SIZE,
            "Total size should be known upfront"
        );
    }

    #[test]
    fn file_stream_size_is_known_via_remaining() {
        const FILE_SIZE: usize = 12345;

        let file = TempFile::create("test_file_size_known.txt", &vec![b'X'; FILE_SIZE]);

        let cx = Cx::for_testing();
        let config = StreamConfig::default();
        let stream = FileStream::open(file.path(), cx, config).unwrap();

        assert_eq!(
            stream.remaining(),
            FILE_SIZE as u64,
            "File size should be known via remaining()"
        );
    }
}