1use std::path::PathBuf;
34use thiserror::Error;
35use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
36
37#[derive(Debug, Error)]
39pub enum StreamError {
40 #[error("IO error: {0}")]
42 Io(#[from] std::io::Error),
43
44 #[error("Stream exhausted")]
46 Exhausted,
47
48 #[error("Invalid configuration: {0}")]
50 InvalidConfig(String),
51
52 #[error("Seek failed: {0}")]
54 SeekFailed(String),
55}
56
/// Tunables for [`ContentStream`].
///
/// Start from [`Default`] (or `StreamConfig::new`), adjust with the `with_*`
/// builders, and check the result with `StreamConfig::validate`.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Size in bytes of the buffer used for each `next_chunk` read; this is the
    /// upper bound on the length of a returned chunk. Must be non-zero.
    pub chunk_size: usize,

    /// Bandwidth-tracking switch.
    /// NOTE(review): not consulted by the bandwidth methods in this file — confirm intent.
    pub track_bandwidth: bool,

    /// Retry budget.
    /// NOTE(review): not read anywhere in this file — presumably used by callers; verify.
    pub max_retries: u32,

    /// Buffer size in bytes. Must be non-zero.
    /// NOTE(review): only checked by `validate` in this file, not otherwise used.
    pub buffer_size: usize,
}

impl Default for StreamConfig {
    /// 256 KiB chunks, bandwidth tracking on, 3 retries, 8 KiB buffer.
    #[inline]
    fn default() -> Self {
        const KIB: usize = 1024;
        Self {
            buffer_size: 8 * KIB,
            max_retries: 3,
            track_bandwidth: true,
            chunk_size: 256 * KIB,
        }
    }
}
84
85impl StreamConfig {
86 #[must_use]
88 #[inline]
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 #[must_use]
95 #[inline]
96 pub fn with_chunk_size(mut self, size: usize) -> Self {
97 self.chunk_size = size;
98 self
99 }
100
101 #[must_use]
103 #[inline]
104 pub fn with_bandwidth_tracking(mut self, enabled: bool) -> Self {
105 self.track_bandwidth = enabled;
106 self
107 }
108
109 #[must_use]
111 #[inline]
112 pub fn with_max_retries(mut self, retries: u32) -> Self {
113 self.max_retries = retries;
114 self
115 }
116
117 pub fn validate(&self) -> Result<(), StreamError> {
119 if self.chunk_size == 0 {
120 return Err(StreamError::InvalidConfig(
121 "chunk_size must be greater than 0".to_string(),
122 ));
123 }
124 if self.buffer_size == 0 {
125 return Err(StreamError::InvalidConfig(
126 "buffer_size must be greater than 0".to_string(),
127 ));
128 }
129 Ok(())
130 }
131}
132
133pub struct ContentStream<R> {
135 reader: R,
137
138 config: StreamConfig,
140
141 total_size: Option<u64>,
143
144 bytes_read: u64,
146
147 start_time: std::time::Instant,
149
150 exhausted: bool,
152}
153
154impl<R: AsyncRead + Unpin> ContentStream<R> {
155 pub fn new(
157 reader: R,
158 config: StreamConfig,
159 total_size: Option<u64>,
160 ) -> Result<Self, StreamError> {
161 config.validate()?;
162 Ok(Self {
163 reader,
164 config,
165 total_size,
166 bytes_read: 0,
167 start_time: std::time::Instant::now(),
168 exhausted: false,
169 })
170 }
171
172 pub async fn next_chunk(&mut self) -> Result<Option<Vec<u8>>, StreamError> {
174 if self.exhausted {
175 return Ok(None);
176 }
177
178 let mut buffer = vec![0u8; self.config.chunk_size];
179 let bytes = self.reader.read(&mut buffer).await?;
180
181 if bytes == 0 {
182 self.exhausted = true;
183 return Ok(None);
184 }
185
186 buffer.truncate(bytes);
187 self.bytes_read += bytes as u64;
188
189 Ok(Some(buffer))
190 }
191
192 #[inline]
194 #[must_use]
195 pub fn progress(&self) -> f64 {
196 if let Some(total) = self.total_size {
197 if total == 0 {
198 1.0
199 } else {
200 self.bytes_read as f64 / total as f64
201 }
202 } else {
203 0.0
204 }
205 }
206
207 #[inline]
209 #[must_use]
210 pub const fn bytes_read(&self) -> u64 {
211 self.bytes_read
212 }
213
214 #[inline]
216 #[must_use]
217 pub const fn total_size(&self) -> Option<u64> {
218 self.total_size
219 }
220
221 #[inline]
223 #[must_use]
224 pub const fn is_exhausted(&self) -> bool {
225 self.exhausted
226 }
227
228 #[inline]
230 #[must_use]
231 pub fn bandwidth_bps(&self) -> f64 {
232 let elapsed_secs = self
236 .start_time
237 .elapsed()
238 .max(std::time::Duration::from_nanos(1))
239 .as_secs_f64();
240 self.bytes_read as f64 / elapsed_secs
241 }
242
243 #[inline]
245 #[must_use]
246 pub fn bandwidth_mbps(&self) -> f64 {
247 self.bandwidth_bps() * 8.0 / 1_000_000.0
248 }
249
250 #[must_use]
252 #[inline]
253 pub fn time_remaining_secs(&self) -> Option<f64> {
254 if let Some(total) = self.total_size {
255 let remaining = total.saturating_sub(self.bytes_read);
256 let bps = self.bandwidth_bps();
257 if bps > 0.0 {
258 Some(remaining as f64 / bps)
259 } else {
260 None
261 }
262 } else {
263 None
264 }
265 }
266
267 pub async fn read_to_vec(&mut self) -> Result<Vec<u8>, StreamError> {
269 let mut result = Vec::new();
270 while let Some(chunk) = self.next_chunk().await? {
271 result.extend_from_slice(&chunk);
272 }
273 Ok(result)
274 }
275
276 pub async fn reset(&mut self) -> Result<(), StreamError>
278 where
279 R: AsyncSeek,
280 {
281 self.reader
282 .seek(std::io::SeekFrom::Start(0))
283 .await
284 .map_err(|e| StreamError::SeekFailed(e.to_string()))?;
285 self.bytes_read = 0;
286 self.exhausted = false;
287 self.start_time = std::time::Instant::now();
288 Ok(())
289 }
290}
291
292impl ContentStream<tokio::fs::File> {
293 pub async fn from_file(path: PathBuf, config: StreamConfig) -> Result<Self, StreamError> {
295 let file = tokio::fs::File::open(&path).await?;
296 let metadata = file.metadata().await?;
297 let total_size = Some(metadata.len());
298 Self::new(file, config, total_size)
299 }
300}
301
/// Async sink that writes byte chunks and counts how many bytes went through it.
pub struct ChunkWriter<W> {
    /// Reference instant for bandwidth figures, set at construction.
    start_time: std::time::Instant,
    /// Destination for the chunks.
    writer: W,
    /// Running total of bytes successfully written by `write_chunk`.
    bytes_written: u64,
}
313
314impl<W: tokio::io::AsyncWrite + Unpin> ChunkWriter<W> {
315 #[must_use]
317 pub fn new(writer: W) -> Self {
318 Self {
319 writer,
320 bytes_written: 0,
321 start_time: std::time::Instant::now(),
322 }
323 }
324
325 pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), StreamError> {
327 use tokio::io::AsyncWriteExt;
328 self.writer.write_all(chunk).await?;
329 self.bytes_written += chunk.len() as u64;
330 Ok(())
331 }
332
333 pub async fn flush(&mut self) -> Result<(), StreamError> {
335 use tokio::io::AsyncWriteExt;
336 self.writer.flush().await?;
337 Ok(())
338 }
339
340 #[inline]
342 pub const fn bytes_written(&self) -> u64 {
343 self.bytes_written
344 }
345
346 #[inline]
348 pub fn bandwidth_bps(&self) -> f64 {
349 let elapsed_secs = self
350 .start_time
351 .elapsed()
352 .max(std::time::Duration::from_nanos(1))
353 .as_secs_f64();
354 self.bytes_written as f64 / elapsed_secs
355 }
356}
357
358impl ChunkWriter<tokio::fs::File> {
359 pub async fn to_file(path: PathBuf) -> Result<Self, StreamError> {
361 let file = tokio::fs::File::create(&path).await?;
362 Ok(Self::new(file))
363 }
364}
365
366pub async fn stream_copy<R, W>(
368 mut reader: ContentStream<R>,
369 mut writer: ChunkWriter<W>,
370) -> Result<u64, StreamError>
371where
372 R: AsyncRead + Unpin,
373 W: tokio::io::AsyncWrite + Unpin,
374{
375 let mut total_bytes = 0u64;
376
377 while let Some(chunk) = reader.next_chunk().await? {
378 writer.write_chunk(&chunk).await?;
379 total_bytes += chunk.len() as u64;
380 }
381
382 writer.flush().await?;
383 Ok(total_bytes)
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;

    /// Shared payload used by most stream tests.
    const HELLO: &[u8] = b"Hello, World!";

    /// Builds a stream over `HELLO` that announces its exact length as `total_size`.
    fn hello_stream(config: StreamConfig) -> ContentStream<tokio::io::BufReader<&'static [u8]>> {
        ContentStream::new(
            tokio::io::BufReader::new(HELLO),
            config,
            Some(HELLO.len() as u64),
        )
        .expect("test configs are valid")
    }

    // Pure config tests need no runtime, so they use plain #[test].
    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert_eq!(config.chunk_size, 256 * 1024);
        assert!(config.track_bandwidth);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.buffer_size, 8 * 1024);
    }

    #[test]
    fn test_stream_config_builder() {
        let config = StreamConfig::new()
            .with_chunk_size(512 * 1024)
            .with_bandwidth_tracking(false)
            .with_max_retries(5);

        assert_eq!(config.chunk_size, 512 * 1024);
        assert!(!config.track_bandwidth);
        assert_eq!(config.max_retries, 5);
    }

    #[test]
    fn test_stream_config_validate() {
        let mut config = StreamConfig::default();
        assert!(config.validate().is_ok());

        config.chunk_size = 0;
        assert!(matches!(
            config.validate(),
            Err(StreamError::InvalidConfig(_))
        ));

        // The buffer_size branch was previously untested.
        let mut config = StreamConfig::default();
        config.buffer_size = 0;
        assert!(matches!(
            config.validate(),
            Err(StreamError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_content_stream_rejects_invalid_config() {
        let config = StreamConfig::new().with_chunk_size(0);
        assert!(ContentStream::new(&b""[..], config, None).is_err());
    }

    #[tokio::test]
    async fn test_content_stream_basic() {
        let mut stream = hello_stream(StreamConfig::default());

        let chunk = stream.next_chunk().await.unwrap();
        assert!(chunk.is_some());
        assert_eq!(chunk.unwrap(), HELLO);

        let chunk = stream.next_chunk().await.unwrap();
        assert!(chunk.is_none());
        assert!(stream.is_exhausted());
    }

    #[tokio::test]
    async fn test_content_stream_small_chunks() {
        let mut stream = hello_stream(StreamConfig::new().with_chunk_size(4));

        let mut joined = Vec::new();
        while let Some(chunk) = stream.next_chunk().await.unwrap() {
            assert!(!chunk.is_empty() && chunk.len() <= 4);
            joined.extend_from_slice(&chunk);
        }

        assert_eq!(joined, HELLO);
        assert_eq!(stream.bytes_read(), HELLO.len() as u64);
        assert!(stream.is_exhausted());
    }

    #[tokio::test]
    async fn test_content_stream_progress() {
        let mut stream = hello_stream(StreamConfig::default());

        assert_eq!(stream.progress(), 0.0);
        let _ = stream.next_chunk().await.unwrap();
        assert_eq!(stream.progress(), 1.0);
    }

    #[tokio::test]
    async fn test_content_stream_bandwidth() {
        let mut stream = hello_stream(StreamConfig::default());

        let _ = stream.next_chunk().await.unwrap();
        let bps = stream.bandwidth_bps();
        assert!(bps > 0.0);
    }

    #[tokio::test]
    async fn test_chunk_writer() {
        let mut buffer = Vec::new();
        let bytes_written = {
            let mut writer = ChunkWriter::new(&mut buffer);

            writer.write_chunk(b"Hello, ").await.unwrap();
            writer.write_chunk(b"World!").await.unwrap();
            writer.flush().await.unwrap();

            writer.bytes_written()
        };

        assert_eq!(buffer, HELLO);
        assert_eq!(bytes_written, 13);
    }

    #[tokio::test]
    async fn test_stream_copy() {
        let stream = hello_stream(StreamConfig::default());

        let mut buffer = Vec::new();
        let writer = ChunkWriter::new(&mut buffer);

        let bytes = stream_copy(stream, writer).await.unwrap();
        assert_eq!(bytes, 13);
        assert_eq!(buffer, HELLO);
    }

    #[tokio::test]
    async fn test_read_to_vec() {
        let mut stream = hello_stream(StreamConfig::default());

        let result = stream.read_to_vec().await.unwrap();
        assert_eq!(result, HELLO);
    }

    #[tokio::test]
    async fn test_stream_from_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        let mut file = tokio::fs::File::create(&file_path).await.unwrap();
        file.write_all(HELLO).await.unwrap();
        file.flush().await.unwrap();
        drop(file);

        let config = StreamConfig::default();
        let mut stream = ContentStream::from_file(file_path, config).await.unwrap();

        let data = stream.read_to_vec().await.unwrap();
        assert_eq!(data, HELLO);
    }
}