use std::fmt;
use std::io;
use std::path::Path;
use bytes::Bytes;
use reqwest::{Client, RequestBuilder, Response, StatusCode};
use crate::progress::{
BarColumn, DownloadColumn, Progress, ProgressColumn, TaskId, TextColumn, TimeRemainingColumn,
TransferSpeedColumn,
};
pub use reqwest::{Error, Result};
/// Error returned by [`ProgressResponse::text`].
#[derive(Debug)]
pub enum TextError {
    /// Sending the request or reading the response body failed.
    Reqwest(reqwest::Error),
    /// The response body was not valid UTF-8.
    Utf8(std::string::FromUtf8Error),
}

impl fmt::Display for TextError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TextError::Reqwest(e) => write!(f, "request error: {}", e),
            TextError::Utf8(e) => write!(f, "UTF-8 error: {}", e),
        }
    }
}

impl std::error::Error for TextError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TextError::Reqwest(e) => Some(e),
            TextError::Utf8(e) => Some(e),
        }
    }
}

impl From<reqwest::Error> for TextError {
    fn from(e: reqwest::Error) -> Self {
        TextError::Reqwest(e)
    }
}

/// Lets `?` convert a failed `String::from_utf8` into [`TextError::Utf8`], so every
/// variant has a matching `From` conversion.
impl From<std::string::FromUtf8Error> for TextError {
    fn from(e: std::string::FromUtf8Error) -> Self {
        TextError::Utf8(e)
    }
}
#[derive(Debug)]
pub enum JsonError {
Reqwest(reqwest::Error),
Json(serde_json::Error),
}
impl fmt::Display for JsonError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
JsonError::Reqwest(e) => write!(f, "request error: {}", e),
JsonError::Json(e) => write!(f, "JSON error: {}", e),
}
}
}
impl std::error::Error for JsonError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
JsonError::Reqwest(e) => Some(e),
JsonError::Json(e) => Some(e),
}
}
}
impl From<reqwest::Error> for JsonError {
fn from(e: reqwest::Error) -> Self {
JsonError::Reqwest(e)
}
}
#[derive(Debug)]
pub enum DownloadError {
Http(reqwest::Error),
Io(io::Error),
}
impl fmt::Display for DownloadError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DownloadError::Http(e) => write!(f, "HTTP error: {}", e),
DownloadError::Io(e) => write!(f, "I/O error: {}", e),
}
}
}
impl std::error::Error for DownloadError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
DownloadError::Http(e) => Some(e),
DownloadError::Io(e) => Some(e),
}
}
}
impl From<reqwest::Error> for DownloadError {
fn from(e: reqwest::Error) -> Self {
DownloadError::Http(e)
}
}
impl From<io::Error> for DownloadError {
fn from(e: io::Error) -> Self {
DownloadError::Io(e)
}
}
/// Extension trait that attaches a progress-bar label to an outgoing request.
///
/// Bring it into scope to call `.with_progress(..)` on a `reqwest` [`RequestBuilder`].
pub trait RequestBuilderProgress: Sized {
    /// Wraps the builder so that sending it yields a [`ProgressResponse`] whose body
    /// transfer is tracked by a progress bar labelled `description`.
    fn with_progress(self, description: &str) -> ProgressRequestBuilder;
}

impl RequestBuilderProgress for RequestBuilder {
    fn with_progress(self, description: &str) -> ProgressRequestBuilder {
        let label = String::from(description);
        ProgressRequestBuilder {
            description: label,
            inner: self,
        }
    }
}
/// A [`RequestBuilder`] paired with the label of the progress bar that tracks its
/// response body.
///
/// Created by [`RequestBuilderProgress::with_progress`].
pub struct ProgressRequestBuilder {
    inner: RequestBuilder,
    description: String,
}

impl ProgressRequestBuilder {
    /// Sends the request and wraps the response so its body is read under a progress bar.
    ///
    /// The progress bar is created and started as soon as the response arrives.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] produced by [`RequestBuilder::send`]. No status check
    /// is performed, so 4xx/5xx responses are returned as `Ok`; inspect
    /// [`ProgressResponse::status`] if that matters.
    pub async fn send(self) -> Result<ProgressResponse> {
        // `self` is owned, so move both fields out instead of cloning the label.
        let Self { inner, description } = self;
        let response = inner.send().await?;
        Ok(ProgressResponse::new(response, description))
    }
}
/// An HTTP response whose body is read while a progress bar tracks the transfer.
///
/// Reading the body via [`bytes`](Self::bytes), [`text`](Self::text) or
/// [`json`](Self::json) advances the bar once per received chunk. The bar is stopped
/// after the body has been read, or when the value is dropped.
pub struct ProgressResponse {
    /// The wrapped response; `None` once `collect_body` has taken it.
    inner: Option<Response>,
    /// The live progress display; `None` once `stop_progress` has run.
    progress: Option<Progress>,
    /// Task registered on `progress` for this transfer.
    task_id: TaskId,
    /// Expected body size from `Content-Length`, if the response reported one.
    total: Option<f64>,
    /// Label the task was created with. It is stored but never read in this module,
    /// hence the `allow`.
    #[allow(dead_code)]
    description: String,
}

impl ProgressResponse {
    /// Wraps `response` and immediately starts a progress bar labelled `description`.
    ///
    /// The task's total comes from the response's `Content-Length`, if present.
    fn new(response: Response, description: String) -> Self {
        let total = response.content_length().map(|n| n as f64);
        let mut progress = create_progress(total);
        let task_id = progress.add_task(&description, total);
        progress.start();
        Self {
            inner: Some(response),
            progress: Some(progress),
            task_id,
            total,
            description,
        }
    }

    /// Returns the HTTP status of the response.
    ///
    /// Falls back to `200 OK` once the body has been taken. Every body reader consumes
    /// `self`, so callers cannot observe that fallback.
    pub fn status(&self) -> StatusCode {
        self.inner
            .as_ref()
            .map(|r| r.status())
            .unwrap_or(StatusCode::OK)
    }

    /// Returns the `Content-Length` reported for the body, if known.
    pub fn content_length(&self) -> Option<u64> {
        self.inner.as_ref().and_then(|r| r.content_length())
    }

    /// Reads the whole body as raw bytes, updating the progress bar as chunks arrive.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] if reading the body stream fails.
    pub async fn bytes(mut self) -> Result<Bytes> {
        let result = self.collect_body().await.map(Bytes::from);
        self.stop_progress();
        result
    }

    /// Reads the whole body and decodes it as UTF-8 text.
    ///
    /// # Errors
    ///
    /// [`TextError::Reqwest`] if reading the body fails; [`TextError::Utf8`] if the body
    /// is not valid UTF-8.
    pub async fn text(mut self) -> std::result::Result<String, TextError> {
        let body = self.collect_body().await.map_err(TextError::Reqwest)?;
        self.stop_progress();
        // `body` is already an owned Vec, so validation happens without copying it.
        String::from_utf8(body).map_err(TextError::Utf8)
    }

    /// Reads the whole body and deserializes it as JSON into `T`.
    ///
    /// # Errors
    ///
    /// [`JsonError::Reqwest`] if reading the body fails; [`JsonError::Json`] if the
    /// body is not valid JSON for `T`.
    pub async fn json<T: serde::de::DeserializeOwned>(
        mut self,
    ) -> std::result::Result<T, JsonError> {
        let body = self.collect_body().await.map_err(JsonError::Reqwest)?;
        self.stop_progress();
        serde_json::from_slice(&body).map_err(JsonError::Json)
    }

    /// Drains the response body into one owned buffer, advancing the progress task by
    /// the size of each chunk.
    ///
    /// # Panics
    ///
    /// Panics if the body was already taken. Every public caller consumes `self`, so this
    /// indicates a bug inside this module.
    async fn collect_body(&mut self) -> Result<Vec<u8>> {
        use futures_util::StreamExt;
        let response = self.inner.take().expect("inner response already consumed");
        let mut body_stream = response.bytes_stream();
        let mut collected = Vec::new();
        while let Some(chunk) = body_stream.next().await {
            let chunk: Bytes = chunk?;
            collected.extend_from_slice(&chunk);
            if let Some(ref mut progress) = self.progress {
                progress.advance(self.task_id, chunk.len() as f64);
                progress.refresh();
            }
        }
        Ok(collected)
    }

    /// Stops the progress display. It runs at most once; later calls are no-ops because
    /// `progress` is taken.
    ///
    /// NOTE(review): when a total is known, this always calls `update` with
    /// `Some(total)`, including on failure paths that reach here via `Drop`. Confirm what
    /// that argument means in `Progress::update`. If it sets the completed amount, a
    /// failed transfer is displayed as finished.
    fn stop_progress(&mut self) {
        if let Some(mut progress) = self.progress.take() {
            if let Some(total) = self.total {
                progress.update(self.task_id, Some(total), None, None, None, None);
            }
            progress.stop();
        }
    }
}

impl Drop for ProgressResponse {
    /// Ensures the progress display is stopped even if the body was never read or reading
    /// it returned an error early.
    fn drop(&mut self) {
        self.stop_progress();
    }
}
/// Sends a `GET` request to `url` with a fresh [`Client`]. The response body is tracked
/// by a progress bar labelled "Downloading".
///
/// # Errors
///
/// Returns the [`reqwest::Error`] if the request cannot be sent.
pub async fn get(url: &str) -> Result<ProgressResponse> {
    let request = Client::new().get(url);
    request.with_progress("Downloading").send().await
}
/// Sends a `POST` request with no body to `url` using a fresh [`Client`].
///
/// NOTE(review): the bar is labelled "Uploading", but what it tracks is the *response*
/// body being read. No request body is attached here. Confirm the label is intended.
///
/// # Errors
///
/// Returns the [`reqwest::Error`] if the request cannot be sent.
pub async fn post(url: &str) -> Result<ProgressResponse> {
    let request = Client::new().post(url);
    request.with_progress("Uploading").send().await
}
pub async fn download(url: &str, path: &Path) -> std::result::Result<u64, DownloadError> {
download_with_progress(url, path, "Downloading").await
}
pub async fn download_with_progress(
url: &str,
path: &Path,
description: &str,
) -> std::result::Result<u64, DownloadError> {
use std::fs::File;
use std::io::Write;
let response = Client::new()
.get(url)
.with_progress(description)
.send()
.await?;
let bytes = response.bytes().await.map_err(DownloadError::Http)?;
let len = bytes.len() as u64;
let mut file = File::create(path)?;
file.write_all(&bytes)?;
Ok(len)
}
/// Builds the progress display shared by every transfer in this module.
///
/// Columns, left to right: description, bar, downloaded amount, transfer speed, time
/// remaining. The display auto-refreshes 10 times per second.
///
/// `_total` is currently unused. The caller passes the total per task to `add_task`
/// instead.
fn create_progress(_total: Option<f64>) -> Progress {
    let layout: [Box<dyn ProgressColumn>; 5] = [
        Box::new(TextColumn::new("{task.description}")),
        Box::new(BarColumn::new()),
        Box::new(DownloadColumn::new()),
        Box::new(TransferSpeedColumn::new()),
        Box::new(TimeRemainingColumn::new()),
    ];
    let display = Progress::new(Vec::from(layout)).with_auto_refresh(true);
    display.with_refresh_per_second(10.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `with_progress` keeps the caller's label verbatim.
    #[test]
    fn test_progress_request_builder_creation() {
        let progress_builder = Client::new()
            .get("https://example.com")
            .with_progress("Test");
        assert_eq!(progress_builder.description.as_str(), "Test");
    }

    /// Building the display with a known total must not panic.
    #[test]
    fn test_create_progress_with_total() {
        let _display = create_progress(Some(1000.0));
    }

    /// Building the display without a total must not panic.
    #[test]
    fn test_create_progress_without_total() {
        let _display = create_progress(None);
    }
}