use std::collections::HashMap;
use std::sync::Arc;
use super::fetcher::{Fetcher, GitFetcher, HttpsFetcher};
/// Maps URI scheme names (e.g. `"https"`, `"git+ssh"`) to the [`Fetcher`] that handles them.
///
/// One fetcher may serve several schemes; in that case every scheme entry holds its own
/// `Arc` handle to the same fetcher object. Cloning the registry therefore only bumps
/// reference counts — the fetchers themselves are shared, not copied.
#[derive(Clone, Default)]
pub struct FetcherRegistry {
    // Scheme name (without the trailing `:`) -> shared handle to the fetcher serving it.
    fetchers: HashMap<&'static str, Arc<dyn Fetcher>>,
}
impl FetcherRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, fetcher: Arc<dyn Fetcher>) {
for scheme in fetcher.schemes() {
self.fetchers.insert(*scheme, Arc::clone(&fetcher));
}
}
pub fn get(&self, scheme: &str) -> Option<Arc<dyn Fetcher>> {
self.fetchers.get(scheme).map(Arc::clone)
}
pub fn contains(&self, scheme: &str) -> bool {
self.fetchers.contains_key(scheme)
}
}
/// Shows the registered scheme names rather than the fetchers themselves, which are
/// opaque `dyn Fetcher` trait objects.
impl std::fmt::Debug for FetcherRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `HashMap` iteration order is arbitrary (and randomized per map), so sort the
        // schemes to keep `{:?}` output stable across runs, logs and snapshot tests.
        let mut schemes: Vec<&str> = self.fetchers.keys().copied().collect();
        schemes.sort_unstable();
        f.debug_struct("FetcherRegistry")
            .field("schemes", &schemes)
            .finish()
    }
}
pub fn default_fetcher_registry() -> FetcherRegistry {
let mut r = FetcherRegistry::new();
r.register(Arc::new(HttpsFetcher));
r.register(Arc::new(GitFetcher));
r
}
#[cfg(test)]
mod tests {
    use super::super::uri::ParsedUri;
    use super::super::FetchError;
    use super::*;
    use std::path::Path;

    /// Test fetcher serving only the `custom:` scheme; it must never actually be invoked.
    struct Custom;

    impl Fetcher for Custom {
        fn fetch(&self, _uri: &ParsedUri, _dest: &Path) -> Result<(), FetchError> {
            unreachable!("test fetcher: fetch shouldn't be called")
        }

        fn schemes(&self) -> &'static [&'static str] {
            &["custom"]
        }
    }

    /// Test fetcher claiming the `shared:` scheme. `fetch` always fails with the tag as
    /// its message, so a test can tell which instance the registry resolved to.
    struct Tagged(&'static str);

    impl Fetcher for Tagged {
        fn fetch(&self, _uri: &ParsedUri, _dest: &Path) -> Result<(), FetchError> {
            Err(FetchError::Other {
                message: self.0.into(),
            })
        }

        fn schemes(&self) -> &'static [&'static str] {
            &["shared"]
        }
    }

    #[test]
    fn default_registry_has_transport_schemes_only() {
        let registry = default_fetcher_registry();
        for scheme in ["https", "git", "git+ssh"] {
            assert!(registry.contains(scheme), "default registry missing `{scheme}:`");
        }
        for scheme in ["path", "github", "gitlab"] {
            assert!(!registry.contains(scheme));
        }
    }

    #[test]
    fn register_then_get() {
        let mut registry = FetcherRegistry::new();
        registry.register(Arc::new(Custom));
        assert!(registry.contains("custom"));
        let _ = registry.get("custom").expect("custom fetcher present");
    }

    #[test]
    fn register_overwrites_on_scheme_collision() {
        let mut registry = FetcherRegistry::new();
        // "B" is registered last, so it must win the shared scheme.
        for tag in ["A", "B"] {
            registry.register(Arc::new(Tagged(tag)));
        }
        let winner = registry.get("shared").unwrap();
        let uri = ParsedUri::parse("shared:x").unwrap();
        let tmp = tempfile::tempdir().unwrap();
        match winner.fetch(&uri, tmp.path()).unwrap_err() {
            FetchError::Other { message } => assert_eq!(message, "B"),
            other => panic!("expected Other(B), got: {other}"),
        }
    }
}