1use std::path::PathBuf;
34use thiserror::Error;
35use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
36
37#[derive(Debug, Error)]
39pub enum StreamError {
40 #[error("IO error: {0}")]
42 Io(#[from] std::io::Error),
43
44 #[error("Stream exhausted")]
46 Exhausted,
47
48 #[error("Invalid configuration: {0}")]
50 InvalidConfig(String),
51
52 #[error("Seek failed: {0}")]
54 SeekFailed(String),
55}
56
/// Tuning parameters for [`ContentStream`].
///
/// Start from [`StreamConfig::new`] (or [`Default`]) and adjust with the
/// `with_*` builder methods; [`StreamConfig::validate`] rejects zero sizes.
#[derive(Clone, Debug)]
pub struct StreamConfig {
    /// Length in bytes of the buffer handed to each read, i.e. the maximum
    /// size of a single chunk yielded by the stream. Must be non-zero.
    pub chunk_size: usize,

    /// Whether bandwidth tracking is requested.
    /// NOTE(review): not read by any code in this file — confirm the consumer.
    pub track_bandwidth: bool,

    /// Maximum number of retries.
    /// NOTE(review): not read by any code in this file — confirm the consumer.
    pub max_retries: u32,

    /// Buffer size in bytes. Must be non-zero.
    /// NOTE(review): only checked by `validate`; not otherwise used in this file.
    pub buffer_size: usize,
}
72
73impl Default for StreamConfig {
74 #[inline]
75 fn default() -> Self {
76 Self {
77 chunk_size: 256 * 1024, track_bandwidth: true,
79 max_retries: 3,
80 buffer_size: 8 * 1024, }
82 }
83}
84
85impl StreamConfig {
86 #[must_use]
88 #[inline]
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 #[must_use]
95 #[inline]
96 pub fn with_chunk_size(mut self, size: usize) -> Self {
97 self.chunk_size = size;
98 self
99 }
100
101 #[must_use]
103 #[inline]
104 pub fn with_bandwidth_tracking(mut self, enabled: bool) -> Self {
105 self.track_bandwidth = enabled;
106 self
107 }
108
109 #[must_use]
111 #[inline]
112 pub fn with_max_retries(mut self, retries: u32) -> Self {
113 self.max_retries = retries;
114 self
115 }
116
117 pub fn validate(&self) -> Result<(), StreamError> {
119 if self.chunk_size == 0 {
120 return Err(StreamError::InvalidConfig(
121 "chunk_size must be greater than 0".to_string(),
122 ));
123 }
124 if self.buffer_size == 0 {
125 return Err(StreamError::InvalidConfig(
126 "buffer_size must be greater than 0".to_string(),
127 ));
128 }
129 Ok(())
130 }
131}
132
133pub struct ContentStream<R> {
135 reader: R,
137
138 config: StreamConfig,
140
141 total_size: Option<u64>,
143
144 bytes_read: u64,
146
147 start_time: std::time::Instant,
149
150 exhausted: bool,
152}
153
154impl<R: AsyncRead + Unpin> ContentStream<R> {
155 pub fn new(
157 reader: R,
158 config: StreamConfig,
159 total_size: Option<u64>,
160 ) -> Result<Self, StreamError> {
161 config.validate()?;
162 Ok(Self {
163 reader,
164 config,
165 total_size,
166 bytes_read: 0,
167 start_time: std::time::Instant::now(),
168 exhausted: false,
169 })
170 }
171
172 pub async fn next_chunk(&mut self) -> Result<Option<Vec<u8>>, StreamError> {
174 if self.exhausted {
175 return Ok(None);
176 }
177
178 let mut buffer = vec![0u8; self.config.chunk_size];
179 let bytes = self.reader.read(&mut buffer).await?;
180
181 if bytes == 0 {
182 self.exhausted = true;
183 return Ok(None);
184 }
185
186 buffer.truncate(bytes);
187 self.bytes_read += bytes as u64;
188
189 Ok(Some(buffer))
190 }
191
192 #[inline]
194 #[must_use]
195 pub fn progress(&self) -> f64 {
196 if let Some(total) = self.total_size {
197 if total == 0 {
198 1.0
199 } else {
200 self.bytes_read as f64 / total as f64
201 }
202 } else {
203 0.0
204 }
205 }
206
207 #[inline]
209 #[must_use]
210 pub const fn bytes_read(&self) -> u64 {
211 self.bytes_read
212 }
213
214 #[inline]
216 #[must_use]
217 pub const fn total_size(&self) -> Option<u64> {
218 self.total_size
219 }
220
221 #[inline]
223 #[must_use]
224 pub const fn is_exhausted(&self) -> bool {
225 self.exhausted
226 }
227
228 #[inline]
230 #[must_use]
231 pub fn bandwidth_bps(&self) -> f64 {
232 let elapsed = self.start_time.elapsed().as_secs_f64();
233 if elapsed > 0.0 {
234 self.bytes_read as f64 / elapsed
235 } else {
236 0.0
237 }
238 }
239
240 #[inline]
242 #[must_use]
243 pub fn bandwidth_mbps(&self) -> f64 {
244 self.bandwidth_bps() * 8.0 / 1_000_000.0
245 }
246
247 #[must_use]
249 #[inline]
250 pub fn time_remaining_secs(&self) -> Option<f64> {
251 if let Some(total) = self.total_size {
252 let remaining = total.saturating_sub(self.bytes_read);
253 let bps = self.bandwidth_bps();
254 if bps > 0.0 {
255 Some(remaining as f64 / bps)
256 } else {
257 None
258 }
259 } else {
260 None
261 }
262 }
263
264 pub async fn read_to_vec(&mut self) -> Result<Vec<u8>, StreamError> {
266 let mut result = Vec::new();
267 while let Some(chunk) = self.next_chunk().await? {
268 result.extend_from_slice(&chunk);
269 }
270 Ok(result)
271 }
272
273 pub async fn reset(&mut self) -> Result<(), StreamError>
275 where
276 R: AsyncSeek,
277 {
278 self.reader
279 .seek(std::io::SeekFrom::Start(0))
280 .await
281 .map_err(|e| StreamError::SeekFailed(e.to_string()))?;
282 self.bytes_read = 0;
283 self.exhausted = false;
284 self.start_time = std::time::Instant::now();
285 Ok(())
286 }
287}
288
289impl ContentStream<tokio::fs::File> {
290 pub async fn from_file(path: PathBuf, config: StreamConfig) -> Result<Self, StreamError> {
292 let file = tokio::fs::File::open(&path).await?;
293 let metadata = file.metadata().await?;
294 let total_size = Some(metadata.len());
295 Self::new(file, config, total_size)
296 }
297}
298
/// Async writer wrapper that counts bytes written and measures throughput.
pub struct ChunkWriter<W> {
    /// Underlying byte sink.
    writer: W,

    /// Total bytes whose `write_all` completed successfully.
    bytes_written: u64,

    /// When the writer was created; the bandwidth baseline.
    start_time: std::time::Instant,
}
310
311impl<W: tokio::io::AsyncWrite + Unpin> ChunkWriter<W> {
312 #[must_use]
314 pub fn new(writer: W) -> Self {
315 Self {
316 writer,
317 bytes_written: 0,
318 start_time: std::time::Instant::now(),
319 }
320 }
321
322 pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), StreamError> {
324 use tokio::io::AsyncWriteExt;
325 self.writer.write_all(chunk).await?;
326 self.bytes_written += chunk.len() as u64;
327 Ok(())
328 }
329
330 pub async fn flush(&mut self) -> Result<(), StreamError> {
332 use tokio::io::AsyncWriteExt;
333 self.writer.flush().await?;
334 Ok(())
335 }
336
337 #[inline]
339 pub const fn bytes_written(&self) -> u64 {
340 self.bytes_written
341 }
342
343 #[inline]
345 pub fn bandwidth_bps(&self) -> f64 {
346 let elapsed = self.start_time.elapsed().as_secs_f64();
347 if elapsed > 0.0 {
348 self.bytes_written as f64 / elapsed
349 } else {
350 0.0
351 }
352 }
353}
354
355impl ChunkWriter<tokio::fs::File> {
356 pub async fn to_file(path: PathBuf) -> Result<Self, StreamError> {
358 let file = tokio::fs::File::create(&path).await?;
359 Ok(Self::new(file))
360 }
361}
362
363pub async fn stream_copy<R, W>(
365 mut reader: ContentStream<R>,
366 mut writer: ChunkWriter<W>,
367) -> Result<u64, StreamError>
368where
369 R: AsyncRead + Unpin,
370 W: tokio::io::AsyncWrite + Unpin,
371{
372 let mut total_bytes = 0u64;
373
374 while let Some(chunk) = reader.next_chunk().await? {
375 writer.write_chunk(&chunk).await?;
376 total_bytes += chunk.len() as u64;
377 }
378
379 writer.flush().await?;
380 Ok(total_bytes)
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;

    #[tokio::test]
    async fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert_eq!(config.chunk_size, 256 * 1024);
        assert!(config.track_bandwidth);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.buffer_size, 8 * 1024);
    }

    #[tokio::test]
    async fn test_stream_config_builder() {
        let config = StreamConfig::new()
            .with_chunk_size(512 * 1024)
            .with_bandwidth_tracking(false)
            .with_max_retries(5);

        assert_eq!(config.chunk_size, 512 * 1024);
        assert!(!config.track_bandwidth);
        assert_eq!(config.max_retries, 5);
    }

    #[tokio::test]
    async fn test_stream_config_validate() {
        let mut config = StreamConfig::default();
        assert!(config.validate().is_ok());

        config.chunk_size = 0;
        assert!(config.validate().is_err());
    }

    /// `validate` must also reject a zero `buffer_size`, not just a zero `chunk_size`.
    #[tokio::test]
    async fn test_stream_config_validate_buffer_size() {
        let config = StreamConfig {
            buffer_size: 0,
            ..StreamConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(StreamError::InvalidConfig(_))
        ));
    }

    /// `ContentStream::new` runs validation and refuses an invalid config.
    #[tokio::test]
    async fn test_new_rejects_invalid_config() {
        let config = StreamConfig::new().with_chunk_size(0);
        let result = ContentStream::new(&b"abc"[..], config, None);
        assert!(matches!(result, Err(StreamError::InvalidConfig(_))));
    }

    #[tokio::test]
    async fn test_content_stream_basic() {
        let data = b"Hello, World!";
        let config = StreamConfig::default();
        let mut stream = ContentStream::new(
            tokio::io::BufReader::new(&data[..]),
            config,
            Some(data.len() as u64),
        )
        .unwrap();

        let chunk = stream.next_chunk().await.unwrap();
        assert!(chunk.is_some());
        assert_eq!(chunk.unwrap(), data);

        let chunk = stream.next_chunk().await.unwrap();
        assert!(chunk.is_none());
        assert!(stream.is_exhausted());
    }

    /// A small `chunk_size` splits the input into several chunks, in order.
    #[tokio::test]
    async fn test_content_stream_respects_chunk_size() {
        let data = b"Hello, World!";
        let config = StreamConfig::new().with_chunk_size(4);
        let mut stream =
            ContentStream::new(&data[..], config, Some(data.len() as u64)).unwrap();

        let mut chunks = Vec::new();
        while let Some(chunk) = stream.next_chunk().await.unwrap() {
            chunks.push(chunk);
        }

        assert_eq!(
            chunks,
            vec![
                b"Hell".to_vec(),
                b"o, W".to_vec(),
                b"orld".to_vec(),
                b"!".to_vec()
            ]
        );
        assert_eq!(stream.bytes_read(), data.len() as u64);
        assert!(stream.is_exhausted());
    }

    #[tokio::test]
    async fn test_content_stream_progress() {
        let data = b"Hello, World!";
        let config = StreamConfig::default();
        let mut stream = ContentStream::new(
            tokio::io::BufReader::new(&data[..]),
            config,
            Some(data.len() as u64),
        )
        .unwrap();

        assert_eq!(stream.progress(), 0.0);
        let _ = stream.next_chunk().await.unwrap();
        assert_eq!(stream.progress(), 1.0);
    }

    /// Without a size hint, progress stays at 0 and no ETA is available.
    #[tokio::test]
    async fn test_progress_without_total_size() {
        let mut stream =
            ContentStream::new(&b"abc"[..], StreamConfig::default(), None).unwrap();
        let _ = stream.next_chunk().await.unwrap();
        assert_eq!(stream.total_size(), None);
        assert_eq!(stream.progress(), 0.0);
        assert!(stream.time_remaining_secs().is_none());
    }

    /// A known-empty source counts as fully complete.
    #[tokio::test]
    async fn test_progress_with_zero_total_size() {
        let stream = ContentStream::new(&b""[..], StreamConfig::default(), Some(0)).unwrap();
        assert_eq!(stream.progress(), 1.0);
    }

    #[tokio::test]
    async fn test_content_stream_bandwidth() {
        let data = b"Hello, World!";
        let config = StreamConfig::default();
        let mut stream = ContentStream::new(
            tokio::io::BufReader::new(&data[..]),
            config,
            Some(data.len() as u64),
        )
        .unwrap();

        let _ = stream.next_chunk().await.unwrap();
        let bps = stream.bandwidth_bps();
        assert!(bps > 0.0);
    }

    /// `reset` rewinds a seekable reader and clears the counters.
    #[tokio::test]
    async fn test_reset_rewinds_seekable_reader() {
        let data = b"Hello, World!".to_vec();
        let mut stream = ContentStream::new(
            std::io::Cursor::new(data.clone()),
            StreamConfig::default(),
            Some(data.len() as u64),
        )
        .unwrap();

        assert_eq!(stream.read_to_vec().await.unwrap(), data);
        assert!(stream.is_exhausted());

        stream.reset().await.unwrap();
        assert!(!stream.is_exhausted());
        assert_eq!(stream.bytes_read(), 0);
        assert_eq!(stream.read_to_vec().await.unwrap(), data);
    }

    #[tokio::test]
    async fn test_chunk_writer() {
        let mut buffer = Vec::new();
        let bytes_written = {
            let mut writer = ChunkWriter::new(&mut buffer);

            writer.write_chunk(b"Hello, ").await.unwrap();
            writer.write_chunk(b"World!").await.unwrap();
            writer.flush().await.unwrap();

            writer.bytes_written()
        };

        assert_eq!(buffer, b"Hello, World!");
        assert_eq!(bytes_written, 13);
    }

    #[tokio::test]
    async fn test_stream_copy() {
        let data = b"Hello, World!";
        let config = StreamConfig::default();
        let stream = ContentStream::new(
            tokio::io::BufReader::new(&data[..]),
            config,
            Some(data.len() as u64),
        )
        .unwrap();

        let mut buffer = Vec::new();
        let writer = ChunkWriter::new(&mut buffer);

        let bytes = stream_copy(stream, writer).await.unwrap();
        assert_eq!(bytes, 13);
        assert_eq!(buffer, data);
    }

    /// `stream_copy` handles input that spans several chunks.
    #[tokio::test]
    async fn test_stream_copy_multiple_chunks() {
        let data = b"Hello, World!";
        let config = StreamConfig::new().with_chunk_size(3);
        let stream = ContentStream::new(&data[..], config, None).unwrap();

        let mut buffer = Vec::new();
        let bytes = stream_copy(stream, ChunkWriter::new(&mut buffer))
            .await
            .unwrap();
        assert_eq!(bytes, data.len() as u64);
        assert_eq!(buffer, data);
    }

    #[tokio::test]
    async fn test_read_to_vec() {
        let data = b"Hello, World!";
        let config = StreamConfig::default();
        let mut stream = ContentStream::new(
            tokio::io::BufReader::new(&data[..]),
            config,
            Some(data.len() as u64),
        )
        .unwrap();

        let result = stream.read_to_vec().await.unwrap();
        assert_eq!(result, data);
    }

    #[tokio::test]
    async fn test_stream_from_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        let mut file = tokio::fs::File::create(&file_path).await.unwrap();
        file.write_all(b"Hello, World!").await.unwrap();
        file.flush().await.unwrap();
        drop(file);

        let config = StreamConfig::default();
        let mut stream = ContentStream::from_file(file_path, config).await.unwrap();
        assert_eq!(stream.total_size(), Some(13));

        let data = stream.read_to_vec().await.unwrap();
        assert_eq!(data, b"Hello, World!");
    }

    /// Opening a missing file surfaces as an I/O error.
    #[tokio::test]
    async fn test_from_file_missing_path_is_io_error() {
        let temp_dir = tempfile::tempdir().unwrap();
        let missing = temp_dir.path().join("missing.txt");

        let result = ContentStream::from_file(missing, StreamConfig::default()).await;
        assert!(matches!(result, Err(StreamError::Io(_))));
    }

    /// Data written through `ChunkWriter::to_file` lands on disk after `flush`.
    #[tokio::test]
    async fn test_chunk_writer_to_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("out.bin");

        let mut writer = ChunkWriter::to_file(file_path.clone()).await.unwrap();
        writer.write_chunk(b"abc").await.unwrap();
        writer.flush().await.unwrap();
        assert_eq!(writer.bytes_written(), 3);
        drop(writer);

        assert_eq!(tokio::fs::read(&file_path).await.unwrap(), b"abc");
    }
}