Skip to main content

docs_mcp/
cache.rs

1use std::path::{Path, PathBuf};
2use std::time::{Duration, SystemTime};
3
4use directories::ProjectDirs;
5use hex::encode as hex_encode;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8
9use crate::error::{DocsError, Result};
10
/// Maximum age of a cache entry, in seconds (one day). Older entries count as expired.
const CACHE_TTL_SECS: u64 = 60 * 60 * 24;
12
13#[derive(Serialize, Deserialize)]
14struct CacheEntry {
15    cached_at: u64, // Unix timestamp (secs)
16    url: String,
17    body: String, // JSON body as string
18}
19
/// File-backed cache of HTTP response bodies, one JSON file per URL.
///
/// Derives `Debug` so the cache can be logged or embedded in other `Debug`
/// types, and `Clone` so a handle can be shared cheaply (only the directory
/// path is copied; the files on disk are shared).
#[derive(Debug, Clone)]
pub struct DiskCache {
    /// Directory in which the cache entry files are stored.
    cache_dir: PathBuf,
}
23
24impl DiskCache {
25    pub fn new() -> Result<Self> {
26        let cache_dir = resolve_cache_dir()?;
27        std::fs::create_dir_all(&cache_dir)?;
28        let cache = Self { cache_dir };
29        cache.prune_expired()?;
30        Ok(cache)
31    }
32
33    fn cache_path(&self, key: &str) -> PathBuf {
34        self.cache_dir.join(format!("{key}.json"))
35    }
36
37    fn cache_key(url: &str) -> String {
38        let mut hasher = Sha256::new();
39        hasher.update(url.as_bytes());
40        hex_encode(hasher.finalize())
41    }
42
43    pub async fn get_json<T>(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<T>
44    where
45        T: serde::de::DeserializeOwned,
46    {
47        let key = Self::cache_key(url);
48        let path = self.cache_path(&key);
49
50        if let Some(body) = self.read_valid_cache(&path)? {
51            return serde_json::from_str(&body).map_err(DocsError::Json);
52        }
53
54        let resp = client.get(url).send().await?;
55        if !resp.status().is_success() {
56            return Err(DocsError::Other(format!(
57                "HTTP {} for {}",
58                resp.status(),
59                url
60            )));
61        }
62        let body = resp.text().await?;
63        let value = serde_json::from_str(&body).map_err(DocsError::Json)?;
64        self.write_cache(&path, url, &body)?;
65        Ok(value)
66    }
67
68    /// Download a zstd-compressed JSON file and return the deserialized value.
69    ///
70    /// docs.rs serves rustdoc JSON as `Content-Type: application/zstd` bodies.
71    /// The decompressed JSON text is cached so repeat calls skip the download.
72    pub async fn get_zstd_json<T>(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<T>
73    where
74        T: serde::de::DeserializeOwned,
75    {
76        let key = Self::cache_key(url);
77        let path = self.cache_path(&key);
78
79        if let Some(body) = self.read_valid_cache(&path)? {
80            return serde_json::from_str(&body).map_err(DocsError::Json);
81        }
82
83        let resp = client.get(url).send().await?;
84        if !resp.status().is_success() {
85            return Err(DocsError::Other(format!(
86                "HTTP {} for {}",
87                resp.status(),
88                url
89            )));
90        }
91        let bytes = resp.bytes().await?;
92        let body = decompress_zstd(&bytes)?;
93        let value = serde_json::from_str(&body).map_err(DocsError::Json)?;
94        self.write_cache(&path, url, &body)?;
95        Ok(value)
96    }
97
98    pub async fn get_text(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<String> {
99        let key = Self::cache_key(url);
100        let path = self.cache_path(&key);
101
102        if let Some(body) = self.read_valid_cache(&path)? {
103            // body was stored as JSON string, decode it
104            return serde_json::from_str::<String>(&body).map_err(DocsError::Json);
105        }
106
107        let resp = client.get(url).send().await?;
108        if !resp.status().is_success() {
109            return Err(DocsError::Other(format!(
110                "HTTP {} for {}",
111                resp.status(),
112                url
113            )));
114        }
115        let text = resp.text().await?;
116        // Store text as JSON string
117        let body = serde_json::to_string(&text)?;
118        self.write_cache(&path, url, &body)?;
119        Ok(text)
120    }
121
122    /// Returns true if URL returns success (200), false for 404, error for other failures.
123    pub async fn head_check(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<bool> {
124        let resp = client.head(url).send().await?;
125        Ok(resp.status().is_success())
126    }
127
128    fn read_valid_cache(&self, path: &Path) -> Result<Option<String>> {
129        if !path.exists() {
130            return Ok(None);
131        }
132        let raw = std::fs::read_to_string(path)?;
133        let entry: CacheEntry = match serde_json::from_str(&raw) {
134            Ok(e) => e,
135            Err(_) => {
136                let _ = std::fs::remove_file(path);
137                return Ok(None);
138            }
139        };
140        let now = unix_now();
141        if now.saturating_sub(entry.cached_at) > CACHE_TTL_SECS {
142            let _ = std::fs::remove_file(path);
143            return Ok(None);
144        }
145        Ok(Some(entry.body))
146    }
147
148    fn write_cache(&self, path: &Path, url: &str, body: &str) -> Result<()> {
149        let entry = CacheEntry {
150            cached_at: unix_now(),
151            url: url.to_string(),
152            body: body.to_string(),
153        };
154        let raw = serde_json::to_string(&entry)?;
155        std::fs::write(path, raw)?;
156        Ok(())
157    }
158
159    fn prune_expired(&self) -> Result<()> {
160        let now = unix_now();
161        let Ok(entries) = std::fs::read_dir(&self.cache_dir) else {
162            return Ok(());
163        };
164        for entry in entries.flatten() {
165            let path = entry.path();
166            if path.extension().and_then(|e| e.to_str()) != Some("json") {
167                continue;
168            }
169            if let Ok(raw) = std::fs::read_to_string(&path) {
170                if let Ok(entry) = serde_json::from_str::<CacheEntry>(&raw) {
171                    if now.saturating_sub(entry.cached_at) > CACHE_TTL_SECS {
172                        let _ = std::fs::remove_file(&path);
173                    }
174                }
175            }
176        }
177        Ok(())
178    }
179}
180
181/// Decompress a zstd-compressed byte slice and return it as a UTF-8 string.
182///
183/// docs.rs serves rustdoc JSON as `Content-Type: application/zstd` with a
184/// `.json.zst` filename. This decompresses the raw bytes to a JSON string.
185pub fn decompress_zstd(bytes: &[u8]) -> Result<String> {
186    let decompressed = zstd::decode_all(std::io::Cursor::new(bytes))
187        .map_err(|e| DocsError::Other(format!("Zstd decompression failed: {e}")))?;
188    String::from_utf8(decompressed)
189        .map_err(|e| DocsError::Other(format!("Decompressed content is not valid UTF-8: {e}")))
190}
191
/// Returns the current time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`.
fn unix_now() -> u64 {
    let since_epoch: Duration = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Clock set before 1970: clamp to the epoch rather than fail.
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_secs()
}
198
199fn resolve_cache_dir() -> Result<PathBuf> {
200    if let Some(dirs) = ProjectDirs::from("", "", "docs-mcp") {
201        Ok(dirs.cache_dir().to_path_buf())
202    } else {
203        // Fallback to current directory
204        Ok(PathBuf::from(".cache/docs-mcp"))
205    }
206}