use futures_util::stream::{FuturesUnordered, StreamExt};
use pulith_resource::{RequestedResource, ResolvedResource};
use pulith_source::{PlannedSources, ResolvedSourceCandidate, SelectionStrategy, SourceSpec};
use std::path::Path;
use std::sync::Arc;
use crate::config::{DownloadSource, MultiSourceOptions, SourceSelectionStrategy};
use crate::error::{Error, Result};
use crate::fetch::fetcher::{FetchReceipt, FetchSource, Fetcher};
use crate::net::http::HttpClient;
/// Fetches one artifact from a set of alternative sources.
///
/// Wraps a shared [`Fetcher`] and adds source-selection strategies
/// (ordered fallback, racing) on top of it.
pub struct MultiSourceFetcher<C: HttpClient> {
// Shared single-source fetcher that performs the actual transfers.
fetcher: Arc<Fetcher<C>>,
}
impl<C: HttpClient + 'static> MultiSourceFetcher<C> {
pub fn new(fetcher: Arc<Fetcher<C>>) -> Self {
Self { fetcher }
}
pub async fn fetch_multi_source_with_receipt(
&self,
sources: Vec<DownloadSource>,
destination: &Path,
options: MultiSourceOptions,
) -> Result<FetchReceipt> {
if sources.is_empty() {
return Err(Error::InvalidState("No sources provided".into()));
}
match options.strategy {
SourceSelectionStrategy::Priority => {
self.fetch_priority(sources, destination, options).await
}
SourceSelectionStrategy::RaceAll => {
self.fetch_race(sources, destination, options).await
}
SourceSelectionStrategy::FastestFirst => {
self.fetch_fastest(sources, destination, options).await
}
SourceSelectionStrategy::Geographic => {
self.fetch_geographic(sources, destination, options).await
}
}
}
/// Tries the sources one after another, in the order given, and returns the
/// first successful receipt.
///
/// Failures of individual sources are discarded; once every source has failed
/// a generic [`Error::Network`] is returned. `_options` is not consulted.
async fn fetch_priority(
    &self,
    sources: Vec<DownloadSource>,
    destination: &Path,
    _options: MultiSourceOptions,
) -> Result<FetchReceipt> {
    for candidate in sources {
        let attempt = self
            .try_source(&candidate, destination, &crate::FetchOptions::default())
            .await;
        if let Ok(receipt) = attempt {
            return Ok(receipt);
        }
    }
    Err(Error::Network("All sources failed".to_string()))
}
async fn fetch_race(
&self,
sources: Vec<DownloadSource>,
destination: &Path,
_options: MultiSourceOptions,
) -> Result<FetchReceipt> {
let mut futures = FuturesUnordered::new();
for source in sources {
let fetcher = self.fetcher.clone();
let dest = destination.to_path_buf();
let future = async move {
fetcher
.fetch_with_receipt(&source.url, &dest, crate::FetchOptions::default())
.await
};
futures.push(Box::pin(future));
}
while let Some(result) = futures.next().await {
if let Ok(path) = result {
return Ok(path);
}
}
Err(Error::Network("All sources failed".to_string()))
}
/// Handler for the `FastestFirst` strategy.
///
/// No speed measurement happens here: the call is forwarded unchanged to
/// `fetch_priority`, so sources are tried in the order given.
async fn fetch_fastest(
&self,
sources: Vec<DownloadSource>,
destination: &Path,
_options: MultiSourceOptions,
) -> Result<FetchReceipt> {
self.fetch_priority(sources, destination, _options).await
}
/// Handler for the `Geographic` strategy.
///
/// No location-based ordering happens here: the call is forwarded unchanged
/// to `fetch_priority`, so sources are tried in the order given.
async fn fetch_geographic(
&self,
sources: Vec<DownloadSource>,
destination: &Path,
_options: MultiSourceOptions,
) -> Result<FetchReceipt> {
self.fetch_priority(sources, destination, _options).await
}
/// Attempts a download from a single `source`.
///
/// The fetch runs with a copy of `options` whose `checksum` is replaced by
/// the source's own checksum.
async fn try_source(
    &self,
    source: &DownloadSource,
    destination: &Path,
    options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
    let per_source_options = {
        let mut adjusted = options.clone();
        adjusted.checksum = source.checksum;
        adjusted
    };
    let fetcher = &self.fetcher;
    fetcher
        .fetch_with_receipt(&source.url, destination, per_source_options)
        .await
}
pub async fn fetch_planned_sources_with_receipt(
&self,
planned: &PlannedSources,
destination: &Path,
options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
let candidates = planned.candidates();
if candidates.is_empty() {
return Err(Error::InvalidState(
"No planned source candidates provided".into(),
));
}
match planned.strategy() {
SelectionStrategy::OrderedFallback | SelectionStrategy::Exhaustive => {
self.fetch_candidate_sequence(candidates, destination, options)
.await
}
SelectionStrategy::Race => {
self.fetch_candidate_race(candidates, destination, options)
.await
}
}
}
/// Plans `spec` with `strategy` and immediately executes the resulting plan.
///
/// Equivalent to calling `spec.plan(strategy)` followed by
/// `fetch_planned_sources_with_receipt`.
pub async fn fetch_source_spec_with_receipt(
    &self,
    spec: SourceSpec,
    strategy: SelectionStrategy,
    destination: &Path,
    options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
    let plan = spec.plan(strategy);
    let outcome = self.fetch_planned_sources_with_receipt(&plan, destination, options);
    outcome.await
}
pub async fn fetch_requested_resource_with_receipt(
&self,
resource: &RequestedResource,
strategy: SelectionStrategy,
destination: &Path,
options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
let planned = PlannedSources::from_requested_resource(resource, strategy)
.map_err(|error| Error::InvalidState(error.to_string()))?;
self.fetch_planned_sources_with_receipt(&planned, destination, options)
.await
}
pub async fn fetch_resolved_resource_with_receipt(
&self,
resource: &ResolvedResource,
strategy: SelectionStrategy,
destination: &Path,
options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
let planned = PlannedSources::from_resolved_resource(resource, strategy)
.map_err(|error| Error::InvalidState(error.to_string()))?;
self.fetch_planned_sources_with_receipt(&planned, destination, options)
.await
}
async fn fetch_candidate_sequence(
&self,
candidates: &[ResolvedSourceCandidate],
destination: &Path,
options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
let mut last_error = None;
for candidate in candidates {
match self.try_candidate(candidate, destination, options).await {
Ok(path) => return Ok(path),
Err(error) => last_error = Some(error),
}
}
Err(last_error
.unwrap_or_else(|| Error::Network("All planned candidates failed".to_string())))
}
async fn fetch_candidate_race(
&self,
candidates: &[ResolvedSourceCandidate],
destination: &Path,
options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
let mut futures = FuturesUnordered::new();
for candidate in candidates.iter().cloned() {
let fetcher = self.fetcher.clone();
let dest = destination.to_path_buf();
let options = options.clone();
futures.push(Box::pin(async move {
match candidate {
ResolvedSourceCandidate::Url(url) => {
fetcher
.fetch_with_receipt(url.as_url().as_ref(), &dest, options)
.await
}
ResolvedSourceCandidate::LocalPath(path) => copy_local_candidate(&path, &dest),
ResolvedSourceCandidate::Git { .. } => Err(Error::InvalidState(
"git candidates are not executable by pulith-fetch yet".to_string(),
)),
}
}));
}
let mut last_error = None;
while let Some(result) = futures.next().await {
match result {
Ok(path) => return Ok(path),
Err(error) => last_error = Some(error),
}
}
Err(last_error
.unwrap_or_else(|| Error::Network("All planned candidates failed".to_string())))
}
async fn try_candidate(
&self,
candidate: &ResolvedSourceCandidate,
destination: &Path,
options: &crate::FetchOptions,
) -> Result<FetchReceipt> {
match candidate {
ResolvedSourceCandidate::Url(url) => {
self.fetcher
.fetch_with_receipt(url.as_url().as_ref(), destination, options.clone())
.await
}
ResolvedSourceCandidate::LocalPath(path) => copy_local_candidate(path, destination),
ResolvedSourceCandidate::Git { .. } => Err(Error::InvalidState(
"git candidates are not executable by pulith-fetch yet".to_string(),
)),
}
}
}
/// Copies a local file candidate to `destination` and describes the result as
/// a [`FetchReceipt`].
///
/// The file is first copied into a temporary staging directory created inside
/// the destination's parent directory, then moved into place with
/// `replace_destination_file`. The staging directory is removed when this
/// function returns.
///
/// The receipt's byte counts come from the value returned by `std::fs::copy`,
/// so no extra `metadata` call is made after the destination has been replaced.
///
/// # Errors
/// - [`Error::InvalidState`] when `source` is a directory.
/// - [`Error::Fs`] when the destination directory, the staging directory or
///   the staged copy cannot be created.
/// - Whatever `replace_destination_file` reports.
fn copy_local_candidate(source: &Path, destination: &Path) -> Result<FetchReceipt> {
    // Wraps an I/O failure together with the path it concerns.
    let write_error = |path: &Path, io_error: std::io::Error| {
        Error::Fs(pulith_fs::Error::Write {
            path: path.to_path_buf(),
            source: io_error,
        })
    };

    if source.is_dir() {
        return Err(Error::InvalidState(
            "local directory candidates are not executable by pulith-fetch".to_string(),
        ));
    }

    let dest_dir = destination.parent().unwrap_or_else(|| Path::new("."));
    std::fs::create_dir_all(dest_dir).map_err(|error| write_error(dest_dir, error))?;

    let file_name = destination
        .file_name()
        .unwrap_or_else(|| std::ffi::OsStr::new("download"));

    // Stage inside the destination's own directory; the staged file is renamed
    // into place afterwards. A failure here is a filesystem problem, so it is
    // reported as `Error::Fs`, not as a network error.
    let staging_dir = tempfile::Builder::new()
        .prefix(".pulith-local-copy.")
        .tempdir_in(dest_dir)
        .map_err(|error| write_error(dest_dir, error))?;
    let staging_path = staging_dir.path().join(file_name);

    // `std::fs::copy` returns the number of bytes copied, which is the file size.
    let size = std::fs::copy(source, &staging_path)
        .map_err(|error| write_error(&staging_path, error))?;

    replace_destination_file(&staging_path, destination)?;

    Ok(FetchReceipt {
        source: FetchSource::LocalPath(source.to_path_buf()),
        destination: destination.to_path_buf(),
        bytes_downloaded: size,
        total_bytes: Some(size),
        sha256_hex: None,
    })
}
/// Moves `staged_path` to `destination`, replacing an existing non-directory entry.
///
/// The rename is attempted first, without deleting the old destination:
/// `std::fs::rename` replaces an existing file, so the old file stays in
/// place until the new one takes over and survives a failed rename. Only when
/// that rename fails and a destination entry exists does this fall back to
/// removing the entry and renaming again.
///
/// # Errors
/// - [`Error::DestinationIsDirectory`] when `destination` is a directory
///   (symlinks are not followed for this check).
/// - [`Error::Fs`] when the entry cannot be removed or the rename fails.
fn replace_destination_file(staged_path: &Path, destination: &Path) -> Result<()> {
    let write_error = |io_error: std::io::Error| {
        Error::Fs(pulith_fs::Error::Write {
            path: destination.to_path_buf(),
            source: io_error,
        })
    };

    // `symlink_metadata` inspects the entry itself; a missing entry is not an error.
    let existing = std::fs::symlink_metadata(destination).ok();
    if existing
        .as_ref()
        .is_some_and(|metadata| metadata.file_type().is_dir())
    {
        return Err(Error::DestinationIsDirectory);
    }

    match std::fs::rename(staged_path, destination) {
        Ok(()) => Ok(()),
        // The direct replacement was refused: fall back to remove-then-rename.
        Err(_) if existing.is_some() => {
            std::fs::remove_file(destination).map_err(&write_error)?;
            std::fs::rename(staged_path, destination).map_err(&write_error)
        }
        Err(error) => Err(write_error(error)),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::Error;
use crate::net::http::BoxStream;
use crate::{DownloadSource, MultiSourceOptions, SourceSelectionStrategy};
use bytes::Bytes;
use pulith_resource::{
RequestedResource, ResolvedLocator, ResolvedVersion, ResourceId, ResourceLocator,
ResourceSpec, ValidUrl,
};
use pulith_source::{
HttpAssetSource, LocalSource, RemoteSource, SelectionStrategy, SourceDefinition, SourceSet,
SourceSpec,
};
use std::path::PathBuf;
use std::sync::Arc;
/// Minimal string-backed error type returned by the mock HTTP client.
#[derive(Debug)]
struct MockError(String);

impl std::fmt::Display for MockError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let MockError(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for MockError {}
/// Test double for [`HttpClient`].
///
/// With `should_fail` unset it serves a single `"test data"` chunk and
/// reports a content length of 1024; with it set, both calls fail.
struct MockHttpClient {
    should_fail: bool,
}

impl MockHttpClient {
    /// Creates a client whose requests succeed.
    fn new() -> Self {
        MockHttpClient { should_fail: false }
    }
}

impl HttpClient for MockHttpClient {
    type Error = MockError;

    async fn stream(
        &self,
        _url: &str,
        _headers: &[(String, String)],
    ) -> std::result::Result<
        BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
        Self::Error,
    > {
        if self.should_fail {
            return Err(MockError("Stream failed".to_string()));
        }
        // One-chunk body; the URL and headers are ignored.
        let stream = futures_util::stream::once(async { Ok(Bytes::from("test data")) });
        Ok(Box::pin(stream)
            as BoxStream<
                'static,
                std::result::Result<Bytes, Self::Error>,
            >)
    }

    async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
        if self.should_fail {
            return Err(MockError("HEAD request failed".to_string()));
        }
        Ok(Some(1024))
    }
}
// Smoke test: wrapping a shared fetcher must not panic.
#[tokio::test]
async fn test_multi_source_fetcher_new() {
    let shared = Arc::new(Fetcher::new(MockHttpClient::new(), "/tmp"));
    let _multi_fetcher = MultiSourceFetcher::new(shared);
}
#[tokio::test]
async fn test_fetch_multi_source_empty_sources() {
let client = MockHttpClient::new();
let fetcher = Arc::new(Fetcher::new(client, "/tmp"));
let multi_fetcher = MultiSourceFetcher::new(fetcher);
let sources = Vec::new();
let destination = std::path::Path::new("/tmp/test");
let options = MultiSourceOptions {
sources: Vec::new(),
strategy: SourceSelectionStrategy::Priority,
verify_consistency: false,
per_source_timeout: None,
};
let result = multi_fetcher
.fetch_multi_source_with_receipt(sources, destination, options)
.await;
assert!(result.is_err());
match result.unwrap_err() {
Error::InvalidState(msg) => assert_eq!(msg, "No sources provided"),
_ => panic!("Expected InvalidState error"),
}
}
#[tokio::test]
async fn test_fetch_multi_source_priority_strategy() {
let client = MockHttpClient::new();
let fetcher = Arc::new(Fetcher::new(client, "/tmp"));
let multi_fetcher = MultiSourceFetcher::new(fetcher);
let sources = vec![
DownloadSource::new("http://example1.com".to_string()),
DownloadSource::new("http://example2.com".to_string()),
];
let destination = std::path::Path::new("/tmp/test");
let options = MultiSourceOptions {
sources: sources.clone(),
strategy: SourceSelectionStrategy::Priority,
verify_consistency: false,
per_source_timeout: None,
};
let result = multi_fetcher
.fetch_multi_source_with_receipt(sources, destination, options)
.await;
assert!(result.is_err() || result.is_ok());
}
#[tokio::test]
async fn test_fetch_multi_source_race_all_strategy() {
let client = MockHttpClient::new();
let fetcher = Arc::new(Fetcher::new(client, "/tmp"));
let multi_fetcher = MultiSourceFetcher::new(fetcher);
let sources = vec![
DownloadSource::new("http://example1.com".to_string()),
DownloadSource::new("http://example2.com".to_string()),
];
let destination = std::path::Path::new("/tmp/test");
let options = MultiSourceOptions {
sources: sources.clone(),
strategy: SourceSelectionStrategy::RaceAll,
verify_consistency: false,
per_source_timeout: None,
};
let result = multi_fetcher
.fetch_multi_source_with_receipt(sources, destination, options)
.await;
assert!(result.is_err() || result.is_ok());
}
#[tokio::test]
async fn test_fetch_multi_source_fastest_first_strategy() {
let client = MockHttpClient::new();
let fetcher = Arc::new(Fetcher::new(client, "/tmp"));
let multi_fetcher = MultiSourceFetcher::new(fetcher);
let sources = vec![
DownloadSource::new("http://example1.com".to_string()),
DownloadSource::new("http://example2.com".to_string()),
];
let destination = std::path::Path::new("/tmp/test");
let options = MultiSourceOptions {
sources: sources.clone(),
strategy: SourceSelectionStrategy::FastestFirst,
verify_consistency: false,
per_source_timeout: None,
};
let result = multi_fetcher
.fetch_multi_source_with_receipt(sources, destination, options)
.await;
assert!(result.is_err() || result.is_ok());
}
#[tokio::test]
async fn test_fetch_multi_source_geographic_strategy() {
let client = MockHttpClient::new();
let fetcher = Arc::new(Fetcher::new(client, "/tmp"));
let multi_fetcher = MultiSourceFetcher::new(fetcher);
let sources = vec![
DownloadSource::new("http://us.example.com".to_string()),
DownloadSource::new("http://eu.example.com".to_string()),
];
let destination = std::path::Path::new("/tmp/test");
let options = MultiSourceOptions {
sources: sources.clone(),
strategy: SourceSelectionStrategy::Geographic,
verify_consistency: false,
per_source_timeout: None,
};
let result = multi_fetcher
.fetch_multi_source_with_receipt(sources, destination, options)
.await;
assert!(result.is_err() || result.is_ok());
}
// An ordered-fallback plan with two HTTP assets must download to the destination.
#[tokio::test]
async fn test_fetch_planned_sources_with_http_candidates() {
    let scratch = tempfile::tempdir().unwrap();
    let multi_fetcher = MultiSourceFetcher::new(Arc::new(Fetcher::new(
        MockHttpClient::new(),
        scratch.path().join("workspace"),
    )));

    let primary = SourceDefinition::Remote(RemoteSource::HttpAsset(HttpAssetSource {
        url: ValidUrl::parse("https://example.com/file").unwrap(),
        file_name: None,
    }));
    let mirror = SourceDefinition::Remote(RemoteSource::HttpAsset(HttpAssetSource {
        url: ValidUrl::parse("https://mirror.example.com/file").unwrap(),
        file_name: None,
    }));
    let source_set = SourceSet::new(vec![primary, mirror]).unwrap();
    let planned = SourceSpec::new(source_set).plan(SelectionStrategy::OrderedFallback);

    let destination = scratch.path().join("downloads").join("artifact.bin");
    let fetch_options = crate::FetchOptions::default();
    let result = multi_fetcher
        .fetch_planned_sources_with_receipt(&planned, &destination, &fetch_options)
        .await;

    assert!(result.is_ok());
    assert!(destination.exists());
}
// Passing a raw `SourceSpec` must plan it and fetch the single HTTP asset.
#[tokio::test]
async fn test_fetch_source_spec_with_receipt_plans_and_fetches() {
    let scratch = tempfile::tempdir().unwrap();
    let multi_fetcher = MultiSourceFetcher::new(Arc::new(Fetcher::new(
        MockHttpClient::new(),
        scratch.path().join("workspace"),
    )));

    let asset = SourceDefinition::Remote(RemoteSource::HttpAsset(HttpAssetSource {
        url: ValidUrl::parse("https://example.com/file").unwrap(),
        file_name: None,
    }));
    let spec = SourceSpec::new(SourceSet::new(vec![asset]).unwrap());

    let destination = scratch.path().join("downloads").join("artifact.bin");
    let fetch_options = crate::FetchOptions::default();
    let result = multi_fetcher
        .fetch_source_spec_with_receipt(
            spec,
            SelectionStrategy::OrderedFallback,
            &destination,
            &fetch_options,
        )
        .await;

    assert!(result.is_ok());
    assert!(destination.exists());
}
// A plan with a single local file must copy that file's bytes to the destination.
#[tokio::test]
async fn test_fetch_planned_sources_with_local_candidate() {
    let destination_root = tempfile::tempdir().unwrap();
    let source_root = tempfile::tempdir().unwrap();
    let multi_fetcher = MultiSourceFetcher::new(Arc::new(Fetcher::new(
        MockHttpClient::new(),
        destination_root.path().join("workspace"),
    )));

    let source_path = source_root.path().join("local.bin");
    std::fs::write(&source_path, b"local-data").unwrap();
    let local = SourceDefinition::Local(LocalSource { path: source_path });
    let planned = SourceSpec::new(SourceSet::new(vec![local]).unwrap())
        .plan(SelectionStrategy::OrderedFallback);

    let destination = destination_root
        .path()
        .join("downloads")
        .join("artifact.bin");
    let fetch_options = crate::FetchOptions::default();
    let result = multi_fetcher
        .fetch_planned_sources_with_receipt(&planned, &destination, &fetch_options)
        .await;

    assert!(result.is_ok());
    assert_eq!(std::fs::read(destination).unwrap(), b"local-data");
}
// Two consecutive local copies into the same destination directory must both
// succeed and keep their own contents.
#[tokio::test]
async fn test_fetch_planned_sources_with_repeated_local_candidates_in_same_directory() {
    let destination_root = tempfile::tempdir().unwrap();
    let source_root = tempfile::tempdir().unwrap();
    let multi_fetcher = MultiSourceFetcher::new(Arc::new(Fetcher::new(
        MockHttpClient::new(),
        destination_root.path().join("workspace"),
    )));
    let destination_dir = destination_root.path().join("downloads");

    let cases = [("artifact-v1.bin", b"v1"), ("artifact-v2.bin", b"v2")];
    for (name, payload) in cases {
        let source_path = source_root.path().join(name);
        std::fs::write(&source_path, payload).unwrap();

        let local = SourceDefinition::Local(LocalSource {
            path: source_path.clone(),
        });
        let planned = SourceSpec::new(SourceSet::new(vec![local]).unwrap())
            .plan(SelectionStrategy::OrderedFallback);

        let destination = destination_dir.join(name);
        let fetch_options = crate::FetchOptions::default();
        let result = multi_fetcher
            .fetch_planned_sources_with_receipt(&planned, &destination, &fetch_options)
            .await;

        assert!(result.is_ok());
        assert_eq!(std::fs::read(destination).unwrap(), payload);
    }
}
// Fetching a resolved resource must end with the local archive's bytes at the destination.
// NOTE(review): the resolved locator is `/local/runtime.zip`, yet the assertion
// expects the bytes written to the requested path. Presumably planning also
// considers the requested locator; confirm against `PlannedSources`.
#[tokio::test]
async fn test_fetch_resolved_resource_with_receipt() {
    let destination_root = tempfile::tempdir().unwrap();
    let source_root = tempfile::tempdir().unwrap();
    let multi_fetcher = MultiSourceFetcher::new(Arc::new(Fetcher::new(
        MockHttpClient::new(),
        destination_root.path().join("workspace"),
    )));

    let source_path = source_root.path().join("runtime.zip");
    std::fs::write(&source_path, b"archive-bytes").unwrap();

    let spec = ResourceSpec::new(
        ResourceId::parse("example/runtime").unwrap(),
        ResourceLocator::LocalPath(source_path),
    );
    let resource = RequestedResource::new(spec).resolve(
        ResolvedVersion::new("1.0.0").unwrap(),
        ResolvedLocator::LocalPath(PathBuf::from("/local/runtime.zip")),
        None,
    );

    let destination = destination_root
        .path()
        .join("downloads")
        .join("runtime.zip");
    let fetch_options = crate::FetchOptions::default();
    let result = multi_fetcher
        .fetch_resolved_resource_with_receipt(
            &resource,
            SelectionStrategy::OrderedFallback,
            &destination,
            &fetch_options,
        )
        .await;

    assert!(result.is_ok());
    assert_eq!(std::fs::read(destination).unwrap(), b"archive-bytes");
}
// Fetching a requested resource with a local locator must copy that file to the destination.
#[tokio::test]
async fn test_fetch_requested_resource_with_receipt() {
    let destination_root = tempfile::tempdir().unwrap();
    let source_root = tempfile::tempdir().unwrap();
    let multi_fetcher = MultiSourceFetcher::new(Arc::new(Fetcher::new(
        MockHttpClient::new(),
        destination_root.path().join("workspace"),
    )));

    let source_path = source_root.path().join("runtime.zip");
    std::fs::write(&source_path, b"archive-bytes-requested").unwrap();

    let spec = ResourceSpec::new(
        ResourceId::parse("example/runtime").unwrap(),
        ResourceLocator::LocalPath(source_path),
    );
    let resource = RequestedResource::new(spec);

    let destination = destination_root
        .path()
        .join("downloads")
        .join("runtime-requested.zip");
    let fetch_options = crate::FetchOptions::default();
    let result = multi_fetcher
        .fetch_requested_resource_with_receipt(
            &resource,
            SelectionStrategy::OrderedFallback,
            &destination,
            &fetch_options,
        )
        .await;

    assert!(result.is_ok());
    assert_eq!(
        std::fs::read(destination).unwrap(),
        b"archive-bytes-requested"
    );
}
}