1use std::path::{Path, PathBuf};
2use std::time::{Duration, SystemTime};
3
4use directories::ProjectDirs;
5use hex::encode as hex_encode;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8
9use crate::error::{DocsError, Result};
10
11const CACHE_TTL_SECS: u64 = 24 * 60 * 60; #[derive(Serialize, Deserialize)]
14struct CacheEntry {
15 cached_at: u64, url: String,
17 body: String, }
19
/// File-backed cache of HTTP response bodies.
///
/// Each entry lives in its own JSON file inside `cache_dir`.
pub struct DiskCache {
    /// Directory holding one `<sha256-of-url>.json` file per cached URL.
    cache_dir: PathBuf,
}
23
24impl DiskCache {
25 pub fn new() -> Result<Self> {
26 let cache_dir = resolve_cache_dir()?;
27 std::fs::create_dir_all(&cache_dir)?;
28 let cache = Self { cache_dir };
29 cache.prune_expired()?;
30 Ok(cache)
31 }
32
33 fn cache_path(&self, key: &str) -> PathBuf {
34 self.cache_dir.join(format!("{key}.json"))
35 }
36
37 fn cache_key(url: &str) -> String {
38 let mut hasher = Sha256::new();
39 hasher.update(url.as_bytes());
40 hex_encode(hasher.finalize())
41 }
42
43 pub async fn get_json<T>(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<T>
44 where
45 T: serde::de::DeserializeOwned,
46 {
47 let key = Self::cache_key(url);
48 let path = self.cache_path(&key);
49
50 if let Some(body) = self.read_valid_cache(&path)? {
51 return serde_json::from_str(&body).map_err(DocsError::Json);
52 }
53
54 let resp = client.get(url).send().await?;
55 if !resp.status().is_success() {
56 return Err(DocsError::Other(format!(
57 "HTTP {} for {}",
58 resp.status(),
59 url
60 )));
61 }
62 let body = resp.text().await?;
63 let value = serde_json::from_str(&body).map_err(DocsError::Json)?;
64 self.write_cache(&path, url, &body)?;
65 Ok(value)
66 }
67
68 pub async fn get_zstd_json<T>(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<T>
73 where
74 T: serde::de::DeserializeOwned,
75 {
76 let key = Self::cache_key(url);
77 let path = self.cache_path(&key);
78
79 if let Some(body) = self.read_valid_cache(&path)? {
80 return serde_json::from_str(&body).map_err(DocsError::Json);
81 }
82
83 let resp = client.get(url).send().await?;
84 if !resp.status().is_success() {
85 return Err(DocsError::Other(format!(
86 "HTTP {} for {}",
87 resp.status(),
88 url
89 )));
90 }
91 let bytes = resp.bytes().await?;
92 let body = decompress_zstd(&bytes)?;
93 let value = serde_json::from_str(&body).map_err(DocsError::Json)?;
94 self.write_cache(&path, url, &body)?;
95 Ok(value)
96 }
97
98 pub async fn get_text(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<String> {
99 let key = Self::cache_key(url);
100 let path = self.cache_path(&key);
101
102 if let Some(body) = self.read_valid_cache(&path)? {
103 return serde_json::from_str::<String>(&body).map_err(DocsError::Json);
105 }
106
107 let resp = client.get(url).send().await?;
108 if !resp.status().is_success() {
109 return Err(DocsError::Other(format!(
110 "HTTP {} for {}",
111 resp.status(),
112 url
113 )));
114 }
115 let text = resp.text().await?;
116 let body = serde_json::to_string(&text)?;
118 self.write_cache(&path, url, &body)?;
119 Ok(text)
120 }
121
122 pub async fn head_check(&self, client: &reqwest_middleware::ClientWithMiddleware, url: &str) -> Result<bool> {
124 let resp = client.head(url).send().await?;
125 Ok(resp.status().is_success())
126 }
127
128 fn read_valid_cache(&self, path: &Path) -> Result<Option<String>> {
129 if !path.exists() {
130 return Ok(None);
131 }
132 let raw = std::fs::read_to_string(path)?;
133 let entry: CacheEntry = match serde_json::from_str(&raw) {
134 Ok(e) => e,
135 Err(_) => {
136 let _ = std::fs::remove_file(path);
137 return Ok(None);
138 }
139 };
140 let now = unix_now();
141 if now.saturating_sub(entry.cached_at) > CACHE_TTL_SECS {
142 let _ = std::fs::remove_file(path);
143 return Ok(None);
144 }
145 Ok(Some(entry.body))
146 }
147
148 fn write_cache(&self, path: &Path, url: &str, body: &str) -> Result<()> {
149 let entry = CacheEntry {
150 cached_at: unix_now(),
151 url: url.to_string(),
152 body: body.to_string(),
153 };
154 let raw = serde_json::to_string(&entry)?;
155 std::fs::write(path, raw)?;
156 Ok(())
157 }
158
159 fn prune_expired(&self) -> Result<()> {
160 let now = unix_now();
161 let Ok(entries) = std::fs::read_dir(&self.cache_dir) else {
162 return Ok(());
163 };
164 for entry in entries.flatten() {
165 let path = entry.path();
166 if path.extension().and_then(|e| e.to_str()) != Some("json") {
167 continue;
168 }
169 if let Ok(raw) = std::fs::read_to_string(&path) {
170 if let Ok(entry) = serde_json::from_str::<CacheEntry>(&raw) {
171 if now.saturating_sub(entry.cached_at) > CACHE_TTL_SECS {
172 let _ = std::fs::remove_file(&path);
173 }
174 }
175 }
176 }
177 Ok(())
178 }
179}
180
181pub fn decompress_zstd(bytes: &[u8]) -> Result<String> {
186 let decompressed = zstd::decode_all(std::io::Cursor::new(bytes))
187 .map_err(|e| DocsError::Other(format!("Zstd decompression failed: {e}")))?;
188 String::from_utf8(decompressed)
189 .map_err(|e| DocsError::Other(format!("Decompressed content is not valid UTF-8: {e}")))
190}
191
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn unix_now() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_else(|_| Duration::ZERO);
    since_epoch.as_secs()
}
198
199fn resolve_cache_dir() -> Result<PathBuf> {
200 if let Some(dirs) = ProjectDirs::from("", "", "docs-mcp") {
201 Ok(dirs.cache_dir().to_path_buf())
202 } else {
203 Ok(PathBuf::from(".cache/docs-mcp"))
205 }
206}