// bytes_radar/net/traits.rs

1use async_trait::async_trait;
2use reqwest::Client;
3use std::collections::HashMap;
4use std::time::Duration;
5
/// Observer interface for reporting download and processing progress.
///
/// Implementors must be `Send + Sync`, and every callback receives `&self`,
/// so any state a hook records has to live behind interior mutability.
pub trait ProgressHook: Send + Sync {
    /// Reports an update to the download progress.
    ///
    /// # Arguments
    /// * `downloaded` - Number of bytes downloaded so far
    /// * `total` - Total size in bytes, or `None` when it is not known
    fn on_download_progress(&self, downloaded: u64, total: Option<u64>);

    /// Signals that processing is starting.
    ///
    /// # Arguments
    /// * `message` - Status message describing the current operation
    fn on_processing_start(&self, message: &str);

    /// Reports an update to the processing progress.
    ///
    /// # Arguments
    /// * `current` - Current item being processed
    /// * `total` - Total number of items to process
    fn on_processing_progress(&self, current: usize, total: usize);
}
28
/// No-operation progress hook that ignores all progress updates.
///
/// This is a zero-sized unit struct. It derives `Debug`, `Clone`, `Copy` and
/// `Default` so it can be logged, duplicated and constructed generically
/// like any other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpProgressHook;
31
32impl ProgressHook for NoOpProgressHook {
33    fn on_download_progress(&self, _downloaded: u64, _total: Option<u64>) {}
34    fn on_processing_start(&self, _message: &str) {}
35    fn on_processing_progress(&self, _current: usize, _total: usize) {}
36}
37
/// Configuration shared by all Git providers.
///
/// All fields are public, so a value can be built with a struct literal as
/// well as through the `with_*` builder methods.
///
/// NOTE(review): the derived `Debug` prints `credentials` verbatim — avoid
/// logging a populated config with `{:?}`.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Extra HTTP headers to send with requests, keyed by header name
    pub headers: HashMap<String, String>,

    /// Request timeout in seconds (`None` for default)
    pub timeout: Option<u64>,

    /// Upper bound on the number of redirects to follow
    pub max_redirects: Option<u32>,

    /// User agent string to send with requests
    pub user_agent: Option<String>,

    /// Whether invalid SSL certificates are accepted
    pub accept_invalid_certs: bool,

    /// Authentication credentials; which keys are used varies by provider
    pub credentials: HashMap<String, String>,

    /// Free-form provider-specific settings
    pub provider_settings: HashMap<String, String>,

    /// Largest file size to download, in bytes
    pub max_file_size: Option<u64>,

    /// Whether compression is used for requests
    pub use_compression: bool,

    /// Custom proxy URL
    pub proxy: Option<String>,
}
71
72impl Default for ProviderConfig {
73    fn default() -> Self {
74        Self {
75            headers: HashMap::new(),
76            timeout: Some(300), // 5 minutes default
77            max_redirects: Some(10),
78            user_agent: Some("bytes-radar/1.0.0".to_string()),
79            accept_invalid_certs: false,
80            credentials: HashMap::new(),
81            provider_settings: HashMap::new(),
82            max_file_size: Some(100 * 1024 * 1024), // 100MB default
83            use_compression: true,
84            proxy: None,
85        }
86    }
87}
88
89impl ProviderConfig {
90    /// Create a new configuration with default values
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Set a custom header
96    ///
97    /// # Arguments
98    /// * `name` - Header name
99    /// * `value` - Header value
100    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
101        self.headers.insert(name.into(), value.into());
102        self
103    }
104
105    /// Set request timeout in seconds
106    ///
107    /// # Arguments
108    /// * `timeout` - Timeout in seconds
109    pub fn with_timeout(mut self, timeout: u64) -> Self {
110        self.timeout = Some(timeout);
111        self
112    }
113
114    /// Set user agent string
115    ///
116    /// # Arguments
117    /// * `user_agent` - User agent string
118    pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
119        self.user_agent = Some(user_agent.into());
120        self
121    }
122
123    /// Set whether to accept invalid SSL certificates
124    ///
125    /// # Arguments
126    /// * `accept` - Whether to accept invalid certificates
127    pub fn with_accept_invalid_certs(mut self, accept: bool) -> Self {
128        self.accept_invalid_certs = accept;
129        self
130    }
131
132    /// Set authentication credentials
133    ///
134    /// # Arguments
135    /// * `key` - Credential key (e.g., "token", "username")
136    /// * `value` - Credential value
137    pub fn with_credential(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
138        self.credentials.insert(key.into(), value.into());
139        self
140    }
141
142    /// Set provider-specific setting
143    ///
144    /// # Arguments
145    /// * `key` - Setting key
146    /// * `value` - Setting value
147    pub fn with_provider_setting(
148        mut self,
149        key: impl Into<String>,
150        value: impl Into<String>,
151    ) -> Self {
152        self.provider_settings.insert(key.into(), value.into());
153        self
154    }
155
156    /// Set maximum file size in bytes
157    ///
158    /// # Arguments
159    /// * `size` - Maximum file size in bytes
160    pub fn with_max_file_size(mut self, size: u64) -> Self {
161        self.max_file_size = Some(size);
162        self
163    }
164
165    /// Set proxy URL
166    ///
167    /// # Arguments
168    /// * `proxy` - Proxy URL
169    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
170        self.proxy = Some(proxy.into());
171        self
172    }
173}
174
/// Repository coordinates extracted from a URL.
#[derive(Debug, Clone)]
pub struct ParsedRepository {
    /// Owner or organization the repository belongs to
    pub owner: String,

    /// Name of the repository
    pub repo: String,

    /// Branch name or commit hash, when one was specified
    pub branch_or_commit: Option<String>,

    /// `true` when `branch_or_commit` holds a commit hash rather than a branch
    pub is_commit: bool,

    /// Generated display name for the project
    pub project_name: String,

    /// Host name (e.g., "github.com"), when known
    pub host: Option<String>,
}
196
197impl ParsedRepository {
198    /// Create a new parsed repository with default main branch
199    ///
200    /// # Arguments
201    /// * `owner` - Repository owner
202    /// * `repo` - Repository name
203    pub fn new(owner: String, repo: String) -> Self {
204        let project_name = format!("{}@main", repo);
205        Self {
206            owner,
207            repo,
208            branch_or_commit: None,
209            is_commit: false,
210            project_name,
211            host: None,
212        }
213    }
214
215    /// Set the branch and update project name
216    ///
217    /// # Arguments
218    /// * `branch` - Branch name
219    pub fn with_branch(mut self, branch: String) -> Self {
220        self.project_name = format!("{}@{}", self.repo, branch);
221        self.branch_or_commit = Some(branch);
222        self.is_commit = false;
223        self
224    }
225
226    /// Set the commit hash and update project name
227    ///
228    /// # Arguments
229    /// * `commit` - Commit hash
230    pub fn with_commit(mut self, commit: String) -> Self {
231        let short_commit = &commit[..7.min(commit.len())];
232        self.project_name = format!("{}@{}", self.repo, short_commit);
233        self.branch_or_commit = Some(commit);
234        self.is_commit = true;
235        self
236    }
237
238    /// Set the host name
239    ///
240    /// # Arguments
241    /// * `host` - Host name
242    pub fn with_host(mut self, host: String) -> Self {
243        self.host = Some(host);
244        self
245    }
246}
247
248/// Git provider trait for handling different repository hosting services
249#[async_trait]
250pub trait GitProvider: Send + Sync {
251    /// Get the provider name (e.g., "github", "gitlab")
252    fn name(&self) -> &'static str;
253
254    /// Check if this provider can handle the given URL
255    ///
256    /// # Arguments
257    /// * `url` - URL to check
258    fn can_handle(&self, url: &str) -> bool;
259
260    /// Parse a URL into repository information
261    ///
262    /// # Arguments
263    /// * `url` - URL to parse
264    fn parse_url(&self, url: &str) -> Option<ParsedRepository>;
265
266    /// Build download URLs for the parsed repository
267    ///
268    /// # Arguments
269    /// * `parsed` - Parsed repository information
270    fn build_download_urls(&self, parsed: &ParsedRepository) -> Vec<String>;
271
272    /// Get the default branch for a repository (if supported)
273    ///
274    /// # Arguments
275    /// * `client` - HTTP client to use
276    /// * `parsed` - Parsed repository information
277    async fn get_default_branch(
278        &self,
279        client: &Client,
280        parsed: &ParsedRepository,
281    ) -> Option<String>;
282
283    /// Apply configuration to this provider
284    ///
285    /// # Arguments
286    /// * `config` - Configuration to apply
287    fn apply_config(&mut self, config: &ProviderConfig);
288
289    /// Get project name from URL
290    ///
291    /// # Arguments
292    /// * `url` - URL to extract project name from
293    fn get_project_name(&self, url: &str) -> String;
294
295    /// Build HTTP client with provider-specific configuration
296    ///
297    /// # Arguments
298    /// * `config` - Configuration to use
299    fn build_client(
300        &self,
301        config: &ProviderConfig,
302    ) -> Result<Client, Box<dyn std::error::Error + Send + Sync>> {
303        let mut builder = Client::builder();
304
305        // Set user agent
306        if let Some(ref user_agent) = config.user_agent {
307            builder = builder.user_agent(user_agent);
308        }
309
310        // Set timeout (works on both wasm and native)
311        if let Some(timeout) = config.timeout {
312            #[cfg(not(target_arch = "wasm32"))]
313            {
314                builder = builder.timeout(Duration::from_secs(timeout));
315            }
316        }
317
318        // Set redirects
319        #[cfg(not(target_arch = "wasm32"))]
320        if let Some(max_redirects) = config.max_redirects {
321            builder = builder.redirect(reqwest::redirect::Policy::limited(max_redirects as usize));
322        }
323
324        // Set SSL verification
325        #[cfg(not(target_arch = "wasm32"))]
326        if config.accept_invalid_certs {
327            builder = builder.danger_accept_invalid_certs(true);
328        }
329
330        // Set compression
331        #[cfg(not(target_arch = "wasm32"))]
332        if !config.use_compression {
333            builder = builder.no_gzip();
334            builder = builder.no_brotli();
335            builder = builder.no_deflate();
336        }
337
338        // Set proxy (only on native)
339        #[cfg(not(target_arch = "wasm32"))]
340        if let Some(ref proxy) = config.proxy {
341            let proxy = reqwest::Proxy::all(proxy)?;
342            builder = builder.proxy(proxy);
343        }
344
345        // Build default headers
346        let mut headers = reqwest::header::HeaderMap::new();
347
348        // Add custom headers
349        for (name, value) in &config.headers {
350            let header_name = reqwest::header::HeaderName::from_bytes(name.as_bytes())?;
351            let header_value = reqwest::header::HeaderValue::from_str(value)?;
352            headers.insert(header_name, header_value);
353        }
354
355        // Add provider-specific auth headers
356        self.add_auth_headers(&mut headers, config)?;
357
358        if !headers.is_empty() {
359            builder = builder.default_headers(headers);
360        }
361
362        Ok(builder.build()?)
363    }
364
365    /// Add authentication headers specific to this provider
366    ///
367    /// # Arguments
368    /// * `headers` - Header map to add to
369    /// * `config` - Configuration containing credentials
370    fn add_auth_headers(
371        &self,
372        _headers: &mut reqwest::header::HeaderMap,
373        _config: &ProviderConfig,
374    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
375        // Default implementation does nothing
376        // Providers can override this
377        Ok(())
378    }
379
380    /// Validate configuration for this provider
381    ///
382    /// # Arguments
383    /// * `config` - Configuration to validate
384    fn validate_config(&self, config: &ProviderConfig) -> Result<(), String> {
385        // Basic validation
386        if let Some(timeout) = config.timeout {
387            if timeout == 0 {
388                return Err("Timeout cannot be zero".to_string());
389            }
390            if timeout > 3600 {
391                return Err("Timeout cannot exceed 1 hour".to_string());
392            }
393        }
394
395        if let Some(max_file_size) = config.max_file_size {
396            if max_file_size > 1024 * 1024 * 1024 {
397                return Err("Max file size cannot exceed 1GB".to_string());
398            }
399        }
400
401        Ok(())
402    }
403}