use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::types::Request;
use crate::types::Response;
/// Point-in-time snapshot of how much of a request body has been read.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProgressState {
    /// Number of body bytes read so far.
    pub bytes_read: u64,
    /// Expected total body size, when known (e.g. from a `Content-Length` header).
    pub total_bytes: Option<u64>,
}

impl ProgressState {
    /// Returns the completion percentage (0..=100), truncated toward zero.
    ///
    /// Returns `None` when the total size is unknown. A known total of `0` is treated as
    /// fully complete (`Some(100)`), and reads beyond the declared total are clamped to 100.
    ///
    /// Integer arithmetic is used deliberately: the previous floating-point formulation
    /// under-reported exact ratios (e.g. 29 of 100 bytes became 28% because
    /// `0.29 * 100.0 == 28.999999999999996`).
    pub fn percent(&self) -> Option<u8> {
        self.total_bytes.map(|total| {
            if total == 0 {
                return 100;
            }
            // Widen to u128 so `read * 100` cannot overflow for any u64 inputs.
            let read = u128::from(self.bytes_read.min(total));
            // read <= total, so the quotient is at most 100 and always fits in a u8.
            (read * 100 / u128::from(total)) as u8
        })
    }
}
/// Shared handle to the upload progress of a single request.
///
/// Clones share the same underlying byte counter (it is held in an `Arc`), so every clone
/// observes the same running total.
#[derive(Clone, Debug)]
pub struct ProgressTracker {
    /// Running count of body bytes read so far.
    bytes_read: Arc<AtomicU64>,
    /// Expected total body size, if known.
    total_bytes: Option<u64>,
}

impl ProgressTracker {
    /// Returns a snapshot combining the current byte count and the expected total.
    pub fn state(&self) -> ProgressState {
        ProgressState {
            bytes_read: self.bytes_read(),
            total_bytes: self.total_bytes(),
        }
    }

    /// Returns the number of body bytes read so far.
    pub fn bytes_read(&self) -> u64 {
        // Relaxed is sufficient: the counter is a standalone monotonic value and is not
        // used to publish any other memory.
        self.bytes_read.load(Ordering::Relaxed)
    }

    /// Returns the expected total body size, if known.
    pub fn total_bytes(&self) -> Option<u64> {
        self.total_bytes
    }

    /// Returns the completion percentage; see [`ProgressState::percent`].
    pub fn percent(&self) -> Option<u8> {
        self.state().percent()
    }
}
/// Middleware configuration that reports request-body upload progress.
///
/// Build it with [`UploadProgress::new`], optionally attach a callback with
/// [`UploadProgress::on_progress`] and throttle it with
/// [`UploadProgress::min_notify_interval_bytes`], then turn it into middleware via
/// [`IntoMiddleware`].
pub struct UploadProgress {
    /// Called with a progress snapshot as body bytes are read; `None` disables reporting.
    callback: Option<Arc<dyn Fn(ProgressState) + Send + Sync + 'static>>,
    /// Minimum number of newly read bytes between two callback invocations; `0` means no
    /// throttling.
    min_notify_interval: u64,
}

impl Default for UploadProgress {
    /// Equivalent to [`UploadProgress::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl UploadProgress {
    /// Creates a configuration with no callback and no notification throttling.
    #[must_use]
    pub fn new() -> Self {
        Self {
            callback: None,
            min_notify_interval: 0,
        }
    }

    /// Sets the progress callback, replacing any previously configured one.
    ///
    /// The callback is invoked synchronously while the body is being read, so it should
    /// return quickly.
    #[must_use]
    pub fn on_progress<F>(mut self, f: F) -> Self
    where
        F: Fn(ProgressState) + Send + Sync + 'static,
    {
        self.callback = Some(Arc::new(f));
        self
    }

    /// Throttles notifications so the callback fires only after at least `bytes` new bytes
    /// have been read since the previous notification.
    ///
    /// The final byte count is still reported once reading stops. `0` (the default)
    /// disables throttling.
    #[must_use]
    pub fn min_notify_interval_bytes(mut self, bytes: u64) -> Self {
        self.min_notify_interval = bytes;
        self
    }
}
impl IntoMiddleware for UploadProgress {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let callback = self.callback;
let min_interval = self.min_notify_interval;
move |mut req: Request, next: Next| {
let callback = callback.clone();
Box::pin(async move {
let total_bytes = req
.headers()
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let bytes_read = Arc::new(AtomicU64::new(0));
let tracker = ProgressTracker {
bytes_read: Arc::clone(&bytes_read),
total_bytes,
};
req.extensions_mut().insert(tracker);
use http_body_util::BodyExt;
let body = req.body_mut();
let mut collected = Vec::new();
let mut last_notified_at: u64 = 0;
loop {
match body.frame().await {
Some(Ok(frame)) => {
if let Some(data) = frame.data_ref() {
collected.extend_from_slice(data);
let total = bytes_read.fetch_add(data.len() as u64, Ordering::Relaxed)
+ data.len() as u64;
if let Some(cb) = &callback {
if min_interval == 0 || total - last_notified_at >= min_interval {
last_notified_at = total;
cb(ProgressState {
bytes_read: total,
total_bytes,
});
}
}
}
}
Some(Err(_)) => break,
None => break,
}
}
if let Some(cb) = &callback {
let final_read = bytes_read.load(Ordering::Relaxed);
if final_read != last_notified_at {
cb(ProgressState {
bytes_read: final_read,
total_bytes,
});
}
}
let (parts, _) = req.into_parts();
let req = http::Request::from_parts(
parts,
crate::body::TakoBody::from(bytes::Bytes::from(collected)),
);
next.run(req).await
})
}
}
}