Skip to main content

chub_core/
fetch.rs

1use std::fs;
2use std::path::PathBuf;
3
4use sha2::{Digest, Sha256};
5
6use crate::cache::{
7    get_source_data_dir, get_source_dir, get_source_registry_path, get_source_search_index_path,
8    read_cached_doc, read_meta, save_cached_doc, save_source_registry,
9    should_fetch_remote_registry, write_meta,
10};
11use crate::config::{load_config, SourceConfig};
12use crate::error::{Error, Result};
13
/// Timeout, in seconds, passed to `reqwest::ClientBuilder::timeout` for the
/// HTTP clients built by the fetch functions in this file.
// NOTE(review): the same value is used for the (up to 500 MB) bundle download;
// presumably this is a whole-request deadline — confirm against the reqwest docs.
const FETCH_TIMEOUT_SECS: u64 = 30;

/// Maximum size for registry.json downloads (50 MB).
const MAX_REGISTRY_SIZE: usize = 50 * 1024 * 1024;
/// Maximum size for bundle.tar.gz downloads (500 MB).
const MAX_BUNDLE_SIZE: usize = 500 * 1024 * 1024;
/// Maximum size for individual doc downloads (10 MB).
const MAX_DOC_SIZE: usize = 10 * 1024 * 1024;
22
23/// Fetch registry for a single remote source.
24pub async fn fetch_remote_registry(source: &SourceConfig, force: bool) -> Result<()> {
25    if !force && !should_fetch_remote_registry(&source.name) {
26        return Ok(());
27    }
28
29    let url = format!(
30        "{}/registry.json",
31        source.url.as_deref().unwrap_or_default()
32    );
33
34    let client = reqwest::Client::builder()
35        .timeout(std::time::Duration::from_secs(FETCH_TIMEOUT_SECS))
36        .build()
37        .map_err(|e| Error::Config(format!("HTTP client error: {}", e)))?;
38
39    let res = client.get(&url).send().await.map_err(|e| {
40        Error::Config(format!(
41            "Failed to fetch registry from {}: {}",
42            source.name, e
43        ))
44    })?;
45
46    if !res.status().is_success() {
47        return Err(Error::Config(format!(
48            "Failed to fetch registry from {}: {} {}",
49            source.name,
50            res.status().as_u16(),
51            res.status().canonical_reason().unwrap_or("")
52        )));
53    }
54
55    let data = read_response_limited(res, MAX_REGISTRY_SIZE, "registry").await?;
56
57    save_source_registry(&source.name, &data);
58    crate::cache::touch_source_meta(&source.name);
59    Ok(())
60}
61
62/// Fetch registries for all configured sources.
63pub async fn fetch_all_registries(force: bool) -> Vec<FetchError> {
64    let config = load_config();
65    let mut errors = Vec::new();
66
67    for source in &config.sources {
68        if source.path.is_some() {
69            continue;
70        }
71        if let Err(e) = fetch_remote_registry(source, force).await {
72            errors.push(FetchError {
73                source: source.name.clone(),
74                error: e.to_string(),
75            });
76        }
77    }
78
79    errors
80}
81
/// A registry fetch failure for one source, as collected by
/// `fetch_all_registries`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct FetchError {
    /// Name of the source whose fetch failed.
    pub source: String,
    /// The error rendered with `to_string()`.
    pub error: String,
}
87
88/// Download full bundle for a remote source.
89pub async fn fetch_full_bundle(source_name: &str) -> Result<()> {
90    let config = load_config();
91    let source = config
92        .sources
93        .iter()
94        .find(|s| s.name == source_name)
95        .ok_or_else(|| Error::Config(format!("Source \"{}\" not found", source_name)))?;
96
97    if source.path.is_some() {
98        return Err(Error::Config(format!(
99            "Source \"{}\" is not a remote source.",
100            source_name
101        )));
102    }
103
104    let url = format!(
105        "{}/bundle.tar.gz",
106        source.url.as_deref().unwrap_or_default()
107    );
108
109    let client = reqwest::Client::builder()
110        .timeout(std::time::Duration::from_secs(FETCH_TIMEOUT_SECS))
111        .build()
112        .map_err(|e| Error::Config(format!("HTTP client error: {}", e)))?;
113
114    let res = client.get(&url).send().await.map_err(|e| {
115        Error::Config(format!(
116            "Failed to fetch bundle from {}: {}",
117            source_name, e
118        ))
119    })?;
120
121    if !res.status().is_success() {
122        return Err(Error::Config(format!(
123            "Failed to fetch bundle from {}: {} {}",
124            source_name,
125            res.status().as_u16(),
126            res.status().canonical_reason().unwrap_or("")
127        )));
128    }
129
130    let bytes = read_response_bytes_limited(res, MAX_BUNDLE_SIZE, "bundle").await?;
131
132    let source_dir = get_source_dir(source_name);
133    fs::create_dir_all(&source_dir)?;
134
135    // Use a unique temp file name to avoid predictable-name attacks
136    let tmp_name = format!("bundle.{}.tar.gz", std::process::id());
137    let tmp_path = source_dir.join(&tmp_name);
138    fs::write(&tmp_path, &bytes)?;
139
140    // Extract tar.gz with path validation
141    let data_dir = get_source_data_dir(source_name);
142    fs::create_dir_all(&data_dir)?;
143
144    let file = fs::File::open(&tmp_path)?;
145    let gz = flate2::read::GzDecoder::new(file);
146    let mut archive = tar::Archive::new(gz);
147
148    // Validate each entry path before extraction to prevent path traversal
149    for entry_result in archive.entries()? {
150        let mut entry = entry_result?;
151        let entry_path = entry.path()?.to_path_buf();
152        let entry_str = entry_path.to_string_lossy();
153
154        // Reject absolute paths, paths with "..", and paths with backslashes
155        if entry_path.is_absolute() || entry_str.contains("..") || entry_str.contains('\\') {
156            return Err(Error::Config(format!(
157                "Malicious tar entry rejected: \"{}\"",
158                entry_str
159            )));
160        }
161
162        let target = data_dir.join(&entry_path);
163        if let Some(parent) = target.parent() {
164            fs::create_dir_all(parent)?;
165        }
166        entry.unpack(&target)?;
167    }
168
169    // Copy registry.json from extracted bundle if present
170    let extracted_registry = data_dir.join("registry.json");
171    if extracted_registry.exists() {
172        let reg_data = fs::read_to_string(&extracted_registry)?;
173        fs::write(get_source_registry_path(source_name), &reg_data)?;
174    }
175
176    // Copy search-index.json from extracted bundle if present
177    let extracted_search_index = data_dir.join("search-index.json");
178    if extracted_search_index.exists() {
179        let idx_data = fs::read_to_string(&extracted_search_index)?;
180        fs::write(get_source_search_index_path(source_name), &idx_data)?;
181    } else {
182        let _ = fs::remove_file(get_source_search_index_path(source_name));
183    }
184
185    // Update meta
186    let mut meta = read_meta(source_name);
187    meta.last_updated = Some(
188        std::time::SystemTime::now()
189            .duration_since(std::time::UNIX_EPOCH)
190            .unwrap_or_default()
191            .as_millis() as u64,
192    );
193    meta.full_bundle = true;
194    write_meta(source_name, &meta);
195
196    // Clean up temp file
197    let _ = fs::remove_file(&tmp_path);
198
199    Ok(())
200}
201
202/// Fetch a single doc. Source must have name + (url or path).
203pub async fn fetch_doc(source: &SourceConfig, doc_path: &str) -> Result<String> {
204    // Local source: read directly
205    if let Some(ref local_path) = source.path {
206        let full_path = PathBuf::from(local_path).join(doc_path);
207        if !full_path.exists() {
208            return Err(Error::NotFound(format!(
209                "File not found: {}",
210                full_path.display()
211            )));
212        }
213        return Ok(fs::read_to_string(&full_path)?);
214    }
215
216    // Remote source: check cache first
217    if let Some(content) = read_cached_doc(&source.name, doc_path) {
218        return Ok(content);
219    }
220
221    // Fetch from CDN
222    let url = format!("{}/{}", source.url.as_deref().unwrap_or_default(), doc_path);
223
224    let client = reqwest::Client::builder()
225        .timeout(std::time::Duration::from_secs(FETCH_TIMEOUT_SECS))
226        .build()
227        .map_err(|e| Error::Config(format!("HTTP client error: {}", e)))?;
228
229    let res = client.get(&url).send().await.map_err(|e| {
230        Error::Config(format!(
231            "Failed to fetch {} from {}: {}",
232            doc_path, source.name, e
233        ))
234    })?;
235
236    if !res.status().is_success() {
237        return Err(Error::Config(format!(
238            "Failed to fetch {} from {}: {} {}",
239            doc_path,
240            source.name,
241            res.status().as_u16(),
242            res.status().canonical_reason().unwrap_or("")
243        )));
244    }
245
246    let content = read_response_limited(res, MAX_DOC_SIZE, "doc").await?;
247
248    // Cache locally
249    save_cached_doc(&source.name, doc_path, &content);
250
251    Ok(content)
252}
253
254/// Fetch all files in an entry directory. Returns vec of (filename, content).
255pub async fn fetch_doc_full(
256    source: &SourceConfig,
257    base_path: &str,
258    files: &[String],
259) -> Result<Vec<(String, String)>> {
260    let mut results = Vec::new();
261    for file in files {
262        let file_path = format!("{}/{}", base_path, file);
263        let content = fetch_doc(source, &file_path).await?;
264        results.push((file.clone(), content));
265    }
266    Ok(results)
267}
268
269/// Read a text response body with a size limit.
270async fn read_response_limited(
271    res: reqwest::Response,
272    max_bytes: usize,
273    kind: &str,
274) -> Result<String> {
275    // Check Content-Length header first (if present)
276    if let Some(len) = res.content_length() {
277        if len as usize > max_bytes {
278            return Err(Error::Config(format!(
279                "Response too large for {} ({} bytes, max {})",
280                kind, len, max_bytes
281            )));
282        }
283    }
284
285    let bytes = read_response_bytes_limited(res, max_bytes, kind).await?;
286    String::from_utf8(bytes)
287        .map_err(|_| Error::Config(format!("Invalid UTF-8 in {} response", kind)))
288}
289
290/// Read a binary response body with a size limit.
291async fn read_response_bytes_limited(
292    res: reqwest::Response,
293    max_bytes: usize,
294    kind: &str,
295) -> Result<Vec<u8>> {
296    // Check Content-Length header first (if present)
297    if let Some(len) = res.content_length() {
298        if len as usize > max_bytes {
299            return Err(Error::Config(format!(
300                "Response too large for {} ({} bytes, max {})",
301                kind, len, max_bytes
302            )));
303        }
304    }
305
306    let bytes = res
307        .bytes()
308        .await
309        .map_err(|e| Error::Config(format!("Failed to read {} body: {}", kind, e)))?;
310
311    if bytes.len() > max_bytes {
312        return Err(Error::Config(format!(
313            "Response too large for {} ({} bytes, max {})",
314            kind,
315            bytes.len(),
316            max_bytes
317        )));
318    }
319
320    Ok(bytes.to_vec())
321}
322
323/// Verify fetched content against an expected SHA-256 hash.
324/// Returns Ok(content) if hash matches or no hash was provided.
325/// Returns Err if hash mismatch (content tampering detected).
326pub fn verify_content_hash(
327    content: &str,
328    expected_hash: Option<&str>,
329    doc_path: &str,
330) -> Result<()> {
331    if let Some(expected) = expected_hash {
332        let actual = format!("{:x}", Sha256::digest(content.as_bytes()));
333        if actual != expected {
334            return Err(Error::Config(format!(
335                "Content integrity check failed for \"{}\": expected hash {}, got {}",
336                doc_path, expected, actual
337            )));
338        }
339    }
340    Ok(())
341}
342
343/// Ensure at least one registry is available.
344pub async fn ensure_registry() -> Result<()> {
345    if crate::cache::has_any_registry() {
346        // Auto-refresh stale remote registries (best-effort)
347        let config = load_config();
348        for source in &config.sources {
349            if source.path.is_some() {
350                continue;
351            }
352            if should_fetch_remote_registry(&source.name) {
353                let _ = fetch_remote_registry(source, false).await;
354            }
355        }
356        return Ok(());
357    }
358
359    // No registries at all — must download from remote
360    let errors = fetch_all_registries(true).await;
361    if !errors.is_empty() && !crate::cache::has_any_registry() {
362        return Err(Error::Config(format!(
363            "Failed to fetch registries: {}",
364            errors
365                .iter()
366                .map(|e| format!("{}: {}", e.source, e.error))
367                .collect::<Vec<_>>()
368                .join("; ")
369        )));
370    }
371
372    Ok(())
373}