use async_trait::async_trait;
use reqwest::Client;
use std::collections::HashMap;
use std::time::Duration;
/// Observer notified about download and processing progress.
///
/// The `Send + Sync` bound lets one hook instance be shared between threads.
pub trait ProgressHook: Send + Sync {
    /// Reports download progress in bytes received so far.
    ///
    /// `total` is optional; callers presumably pass `None` when the overall size is
    /// not known in advance (NOTE(review): confirm at call sites).
    fn on_download_progress(&self, downloaded: u64, total: Option<u64>);
    /// Signals the start of a processing phase, with a human-readable description.
    fn on_processing_start(&self, message: &str);
    /// Reports processing progress as `current` out of `total` items.
    fn on_processing_progress(&self, current: usize, total: usize);
}

/// A [`ProgressHook`] that ignores every notification.
///
/// Derives the common traits so it can be default-constructed, copied freely and
/// printed in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpProgressHook;

impl ProgressHook for NoOpProgressHook {
    fn on_download_progress(&self, _downloaded: u64, _total: Option<u64>) {}
    fn on_processing_start(&self, _message: &str) {}
    fn on_processing_progress(&self, _current: usize, _total: usize) {}
}
/// Network and provider options used when building an HTTP client for a Git host.
///
/// Create one with [`ProviderConfig::new`] (equivalent to [`Default::default`]) and
/// refine it with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Extra headers installed as client default headers.
    pub headers: HashMap<String, String>,
    /// Request timeout in seconds; `None` leaves the client's default in place.
    pub timeout: Option<u64>,
    /// Maximum number of redirects to follow; `None` leaves the client's default.
    pub max_redirects: Option<u32>,
    /// `User-Agent` header value; `None` leaves the client's default.
    pub user_agent: Option<String>,
    /// When `true`, TLS certificate validation is disabled. Dangerous; use only for testing.
    pub accept_invalid_certs: bool,
    /// Provider credentials as key/value pairs. How they are applied is provider-specific
    /// (presumably through an authentication-header hook; verify per provider).
    pub credentials: HashMap<String, String>,
    /// Free-form provider-specific settings.
    pub provider_settings: HashMap<String, String>,
    /// Upper bound on file size in bytes (default 100 MiB).
    pub max_file_size: Option<u64>,
    /// When `false`, gzip/brotli/deflate response decompression is turned off.
    pub use_compression: bool,
    /// Proxy URL applied to all request schemes.
    pub proxy: Option<String>,
}

impl Default for ProviderConfig {
    fn default() -> Self {
        Self {
            headers: HashMap::new(),
            timeout: Some(300),
            max_redirects: Some(10),
            user_agent: Some("bytes-radar/1.0.0".to_string()),
            accept_invalid_certs: false,
            credentials: HashMap::new(),
            provider_settings: HashMap::new(),
            max_file_size: Some(100 * 1024 * 1024),
            use_compression: true,
            proxy: None,
        }
    }
}

impl ProviderConfig {
    /// Creates a configuration populated with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds (or replaces) a default header sent with every request.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }

    /// Sets the request timeout in seconds.
    pub fn with_timeout(mut self, timeout: u64) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the `User-Agent` header value.
    pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
        self.user_agent = Some(user_agent.into());
        self
    }

    /// Enables or disables acceptance of invalid TLS certificates.
    pub fn with_accept_invalid_certs(mut self, accept: bool) -> Self {
        self.accept_invalid_certs = accept;
        self
    }

    /// Adds (or replaces) a credential entry.
    pub fn with_credential(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.credentials.insert(key.into(), value.into());
        self
    }

    /// Adds (or replaces) a provider-specific setting.
    pub fn with_provider_setting(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.provider_settings.insert(key.into(), value.into());
        self
    }

    /// Sets the maximum file size in bytes.
    pub fn with_max_file_size(mut self, size: u64) -> Self {
        self.max_file_size = Some(size);
        self
    }

    /// Sets the proxy URL.
    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
        self.proxy = Some(proxy.into());
        self
    }

    /// Sets the maximum number of redirects to follow.
    ///
    /// Added for parity with the other fields, which all have builder setters.
    pub fn with_max_redirects(mut self, max_redirects: u32) -> Self {
        self.max_redirects = Some(max_redirects);
        self
    }

    /// Enables or disables transparent response decompression.
    pub fn with_compression(mut self, enabled: bool) -> Self {
        self.use_compression = enabled;
        self
    }
}
/// A repository reference extracted from a URL.
///
/// `project_name` is a display label of the form `"{repo}@{ref}"`. It is kept in sync by
/// the `with_branch` / `with_commit` builders.
#[derive(Debug, Clone)]
pub struct ParsedRepository {
    /// Repository owner (user or organisation).
    pub owner: String,
    /// Repository name.
    pub repo: String,
    /// Branch name or commit id, if one was specified.
    pub branch_or_commit: Option<String>,
    /// `true` when `branch_or_commit` holds a commit rather than a branch.
    pub is_commit: bool,
    /// Display label, `"{repo}@{ref}"`.
    pub project_name: String,
    /// Host name, if one was recorded.
    pub host: Option<String>,
}

impl ParsedRepository {
    /// Creates a reference with no branch/commit; the label defaults to `"{repo}@main"`.
    pub fn new(owner: String, repo: String) -> Self {
        let project_name = format!("{}@main", repo);
        Self {
            owner,
            repo,
            branch_or_commit: None,
            is_commit: false,
            project_name,
            host: None,
        }
    }

    /// Targets `branch` and updates the label to `"{repo}@{branch}"`.
    pub fn with_branch(mut self, branch: String) -> Self {
        self.project_name = format!("{}@{}", self.repo, branch);
        self.branch_or_commit = Some(branch);
        self.is_commit = false;
        self
    }

    /// Targets `commit` and labels it with at most its first 7 characters.
    pub fn with_commit(mut self, commit: String) -> Self {
        // Cut at a character boundary: byte-index slicing (`&commit[..7]`) panics when
        // byte 7 falls inside a multi-byte UTF-8 character from user-supplied input.
        let cut = commit
            .char_indices()
            .nth(7)
            .map_or(commit.len(), |(idx, _)| idx);
        let short_commit = &commit[..cut];
        self.project_name = format!("{}@{}", self.repo, short_commit);
        self.branch_or_commit = Some(commit);
        self.is_commit = true;
        self
    }

    /// Records the host the repository lives on.
    pub fn with_host(mut self, host: String) -> Self {
        self.host = Some(host);
        self
    }
}
/// A Git hosting backend: recognises its URLs, parses them into a [`ParsedRepository`],
/// and builds download URLs and an HTTP client for them.
#[async_trait]
pub trait GitProvider: Send + Sync {
    /// Short static identifier for this provider.
    fn name(&self) -> &'static str;

    /// Returns `true` if this provider recognises `url`.
    fn can_handle(&self, url: &str) -> bool;

    /// Parses `url` into a repository reference, or `None` if it cannot be parsed.
    fn parse_url(&self, url: &str) -> Option<ParsedRepository>;

    /// Candidate download URLs for `parsed`.
    fn build_download_urls(&self, parsed: &ParsedRepository) -> Vec<String>;

    /// Looks up the repository's default branch using `client`; `None` if it cannot be
    /// determined (presumably via the host's API; implementation-specific).
    async fn get_default_branch(
        &self,
        client: &Client,
        parsed: &ParsedRepository,
    ) -> Option<String>;

    /// Applies `config` to this provider's internal state.
    fn apply_config(&mut self, config: &ProviderConfig);

    /// Display name for the project referenced by `url`.
    fn get_project_name(&self, url: &str) -> String;

    /// Builds a `reqwest` client from `config`.
    ///
    /// Timeout, redirect limit, certificate checking, compression and proxy settings are
    /// applied only on non-`wasm32` targets. Headers from `config.headers` plus any added
    /// by [`GitProvider::add_auth_headers`] become client default headers.
    ///
    /// # Errors
    /// Returns an error if a header name/value or the proxy URL is invalid, if
    /// `add_auth_headers` fails, or if the client cannot be built.
    fn build_client(
        &self,
        config: &ProviderConfig,
    ) -> Result<Client, Box<dyn std::error::Error + Send + Sync>> {
        let mut builder = Client::builder();

        if let Some(ref user_agent) = config.user_agent {
            builder = builder.user_agent(user_agent);
        }

        // Gate the whole `if let`, not just its body, so `timeout` is not an unused
        // binding on wasm32. This matches the other native-only options below.
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(timeout) = config.timeout {
            builder = builder.timeout(Duration::from_secs(timeout));
        }

        #[cfg(not(target_arch = "wasm32"))]
        if let Some(max_redirects) = config.max_redirects {
            builder = builder.redirect(reqwest::redirect::Policy::limited(max_redirects as usize));
        }

        #[cfg(not(target_arch = "wasm32"))]
        if config.accept_invalid_certs {
            builder = builder.danger_accept_invalid_certs(true);
        }

        #[cfg(not(target_arch = "wasm32"))]
        if !config.use_compression {
            builder = builder.no_gzip();
            builder = builder.no_brotli();
            builder = builder.no_deflate();
        }

        #[cfg(not(target_arch = "wasm32"))]
        if let Some(ref proxy) = config.proxy {
            let proxy = reqwest::Proxy::all(proxy)?;
            builder = builder.proxy(proxy);
        }

        let mut headers = reqwest::header::HeaderMap::new();
        for (name, value) in &config.headers {
            let header_name = reqwest::header::HeaderName::from_bytes(name.as_bytes())?;
            let header_value = reqwest::header::HeaderValue::from_str(value)?;
            headers.insert(header_name, header_value);
        }
        self.add_auth_headers(&mut headers, config)?;
        if !headers.is_empty() {
            builder = builder.default_headers(headers);
        }

        Ok(builder.build()?)
    }

    /// Hook for providers to insert authentication headers; the default adds nothing.
    fn add_auth_headers(
        &self,
        _headers: &mut reqwest::header::HeaderMap,
        _config: &ProviderConfig,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }

    /// Checks `config` for out-of-range values.
    ///
    /// # Errors
    /// Returns a message if the timeout is zero or above 3600 seconds, or if the maximum
    /// file size exceeds 1 GiB.
    fn validate_config(&self, config: &ProviderConfig) -> Result<(), String> {
        if let Some(timeout) = config.timeout {
            if timeout == 0 {
                return Err("Timeout cannot be zero".to_string());
            }
            if timeout > 3600 {
                return Err("Timeout cannot exceed 1 hour".to_string());
            }
        }
        if let Some(max_file_size) = config.max_file_size {
            if max_file_size > 1024 * 1024 * 1024 {
                return Err("Max file size cannot exceed 1GB".to_string());
            }
        }
        Ok(())
    }
}