use crate::errors::SdkError;
use futures_util::StreamExt;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
/// `User-Agent` header value sent when a caller passes `None` for `user_agent`
/// to the download functions in this module.
const DEFAULT_USER_AGENT: &str = "mangater-sdk/0.1 (+https://github.com/quoeamaster/mangater)";
pub async fn download_resource(
uri: String,
user_agent: Option<String>,
) -> Result<Vec<u8>, SdkError> {
let user_agent = user_agent.unwrap_or(DEFAULT_USER_AGENT.to_string());
let client = reqwest::Client::builder()
.user_agent(user_agent)
.build()
.map_err(|e| SdkError::Network(e.to_string()))?;
tracing::debug!("downloading resource: {}", uri);
let response = client
.get(uri)
.send()
.await
.map_err(|e| SdkError::Network(e.to_string()))?
.error_for_status()
.map_err(|e| SdkError::Network(e.to_string()))?;
let body = response
.bytes()
.await
.map_err(|e| SdkError::Network(e.to_string()))?;
Ok(body.to_vec())
}
pub fn create_parent_folders_if_needed(file_path: String) -> Result<(), SdkError> {
let local_file_path = file_path.clone();
let file_path = std::path::Path::new(&local_file_path);
if let Some(parent) = file_path.parent() {
std::fs::create_dir_all(parent).unwrap();
}
tracing::debug!("folder(s) created: {}", file_path.display());
Ok(())
}
pub async fn download_resource_to_file(
uri: String,
user_agent: Option<String>,
file_path: String,
) -> Result<(), SdkError> {
let user_agent = user_agent.unwrap_or(DEFAULT_USER_AGENT.to_string());
let client = reqwest::Client::builder()
.user_agent(user_agent)
.build()
.map_err(|e| SdkError::Network(e.to_string()))?;
let response = client
.get(uri.clone())
.send()
.await
.map_err(|e| SdkError::Network(e.to_string()))?
.error_for_status()
.map_err(|e| SdkError::Network(e.to_string()))?;
tracing::debug!(
"downloading resource to file: {}, source is {}",
file_path,
uri
);
create_parent_folders_if_needed(file_path.clone())?;
let mut file = File::create(file_path).await?;
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| SdkError::Network(e.to_string()))?;
file.write_all(&chunk).await.map_err(SdkError::Storage)?;
}
file.flush().await.map_err(SdkError::Storage)?;
Ok(())
}
#[cfg(test)]
mod tests {
    //! Tests for the download helpers.
    //!
    //! NOTE: these fetch live Wikipedia/Wikimedia URLs, so they need network
    //! access and depend on the remote content staying available.
    use super::*;
    use futures::stream;
    use futures_util::StreamExt;
    use std::fs;

    /// Sanity-check lower bound (in bytes) for the downloaded NoSQL article.
    const MIN_ARTICLE_LEN: usize = 100_000;

    #[tokio::test]
    async fn test_download_resource() -> Result<(), SdkError> {
        let uri = "https://en.wikipedia.org/wiki/NoSQL";
        let content = download_resource(uri.to_string(), None).await?;
        println!("content length: {}", content.len());
        assert!(
            content.len() > MIN_ARTICLE_LEN,
            "expected more than {} bytes, got {}",
            MIN_ARTICLE_LEN,
            content.len()
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_download_resource_to_file() -> Result<(), SdkError> {
        let uri = "https://en.wikipedia.org/wiki/NoSQL";
        let file_path = "testdata/resource_test_download_resource_to_file.txt";
        download_resource_to_file(uri.to_string(), None, file_path.to_string()).await?;
        let content = fs::read_to_string(file_path)?;
        println!("after-stream download - content length: {}", content.len());
        assert!(
            content.len() > MIN_ARTICLE_LEN,
            "expected more than {} bytes, got {}",
            MIN_ARTICLE_LEN,
            content.len()
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_download_resources_in_parallel() -> Result<(), SdkError> {
        struct UrlFile {
            url: String,
            file_path: String,
        }
        let urls = vec![
            UrlFile {
                url: "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b7/Last.fm_software_screenshot.png/250px-Last.fm_software_screenshot.png".to_string(),
                file_path: "testdata/parallel_download/250px-Last.fm_software_screenshot.png".to_string(),
            },
            UrlFile {
                url: "https://en.wikipedia.org/static/images/icons/enwiki-25.svg".to_string(),
                file_path: "testdata/parallel_download/enwiki-25.svg".to_string(),
            },
            UrlFile {
                url: "https://en.wikipedia.org/w/resources/assets/mediawiki_compact.svg".to_string(),
                file_path: "testdata/parallel_download/mediawiki_compact.svg".to_string(),
            },
        ];
        // Run up to 5 downloads concurrently. Each outcome is kept (not just logged)
        // so that any failed download actually fails the test.
        let results = stream::iter(urls)
            .map(|url_file| async move {
                let result =
                    download_resource_to_file(url_file.url.clone(), None, url_file.file_path)
                        .await;
                if let Err(err) = &result {
                    tracing::error!("error downloading resource: {} {:?}", url_file.url, err);
                }
                result
            })
            .buffer_unordered(5)
            .collect::<Vec<_>>()
            .await;
        for result in results {
            result?;
        }
        Ok(())
    }
}