1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use zeroize::Zeroizing;
6
7use crate::backend::{Backend, HealthStatus};
8use crate::error::{BackendError, BackendErrorKind};
9use crate::metrics::{metrics_headers, MetricsProvider};
10use crate::session::session_headers;
11use crate::url_validator::validate_cachekitio_url;
12
13pub struct CachekitIO {
17 client: reqwest::Client,
18 api_key: Zeroizing<String>,
19 api_url: String,
20 metrics_provider: Option<MetricsProvider>,
21}
22
23impl std::fmt::Debug for CachekitIO {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("CachekitIO")
27 .field("api_url", &self.api_url)
28 .field("api_key", &"<redacted>")
29 .finish()
30 }
31}
32
33impl CachekitIO {
34 pub fn builder() -> CachekitIOBuilder {
36 CachekitIOBuilder::default()
37 }
38
39 pub fn api_url(&self) -> &str {
41 &self.api_url
42 }
43
44 pub(crate) fn client(&self) -> &reqwest::Client {
46 &self.client
47 }
48
49 pub(crate) fn api_key_str(&self) -> &str {
51 self.api_key.as_str()
52 }
53
54 #[allow(dead_code)]
56 pub(crate) fn metrics_provider(&self) -> Option<&MetricsProvider> {
57 self.metrics_provider.as_ref()
58 }
59
60 fn url(&self, key: &str) -> String {
65 let encoded = urlencoding::encode(key);
66 format!("{}/v1/cache/{}", self.api_url, encoded)
67 }
68
69 fn health_url(&self) -> String {
71 format!("{}/v1/cache/health", self.api_url)
72 }
73
74 pub(crate) fn with_standard_headers(
76 &self,
77 mut req: reqwest::RequestBuilder,
78 ) -> reqwest::RequestBuilder {
79 for (name, value) in session_headers() {
80 req = req.header(name, value);
81 }
82 for (name, value) in metrics_headers(self.metrics_provider.as_ref()) {
83 req = req.header(name, value);
84 }
85 req
86 }
87
88 pub(crate) async fn error_from_response(&self, resp: reqwest::Response) -> BackendError {
90 let status = resp.status().as_u16();
91 let body = resp.bytes().await.unwrap_or_default();
92 from_http_status_sanitized(status, &body, self.api_key.as_str())
93 }
94}
95
96pub(crate) fn reqwest_err_sanitized(e: reqwest::Error, api_key: &str) -> BackendError {
100 let kind = if e.is_timeout() {
101 BackendErrorKind::Timeout
102 } else {
103 BackendErrorKind::Transient
104 };
105 BackendError {
106 kind,
107 message: BackendError::sanitize_message(&e.to_string(), api_key),
108 source: Some(Box::new(e)),
109 }
110}
111
112pub(crate) fn from_http_status_sanitized(status: u16, body: &[u8], api_key: &str) -> BackendError {
114 let sanitized =
115 BackendError::sanitize_message(std::str::from_utf8(body).unwrap_or(""), api_key);
116 BackendError::from_http_status(status, sanitized.as_bytes())
117}
118
119#[cfg(not(target_arch = "wasm32"))]
122#[cfg_attr(not(feature = "unsync"), async_trait)]
123#[cfg_attr(feature = "unsync", async_trait(?Send))]
124impl Backend for CachekitIO {
125 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, BackendError> {
126 let req = self.with_standard_headers(
127 self.client
128 .get(self.url(key))
129 .bearer_auth(self.api_key.as_str()),
130 );
131
132 let resp = req
133 .send()
134 .await
135 .map_err(|e| reqwest_err_sanitized(e, self.api_key.as_str()))?;
136
137 match resp.status().as_u16() {
138 200 => {
139 let bytes = resp
140 .bytes()
141 .await
142 .map_err(|e| reqwest_err_sanitized(e, self.api_key.as_str()))?;
143 Ok(Some(bytes.to_vec()))
144 }
145 404 => Ok(None),
146 _ => Err(self.error_from_response(resp).await),
147 }
148 }
149
150 async fn set(
151 &self,
152 key: &str,
153 value: Vec<u8>,
154 ttl: Option<Duration>,
155 ) -> Result<(), BackendError> {
156 let mut req = self
157 .client
158 .put(self.url(key))
159 .bearer_auth(self.api_key.as_str())
160 .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
161 .body(value);
162
163 if let Some(ttl) = ttl {
164 req = req.header("X-TTL", ttl.as_secs().to_string());
165 }
166
167 let req = self.with_standard_headers(req);
168
169 let resp = req
170 .send()
171 .await
172 .map_err(|e| reqwest_err_sanitized(e, self.api_key.as_str()))?;
173
174 let status = resp.status().as_u16();
175 if (200..300).contains(&status) {
176 Ok(())
177 } else {
178 Err(self.error_from_response(resp).await)
179 }
180 }
181
182 async fn delete(&self, key: &str) -> Result<bool, BackendError> {
183 let req = self.with_standard_headers(
184 self.client
185 .delete(self.url(key))
186 .bearer_auth(self.api_key.as_str()),
187 );
188
189 let resp = req
190 .send()
191 .await
192 .map_err(|e| reqwest_err_sanitized(e, self.api_key.as_str()))?;
193
194 match resp.status().as_u16() {
195 200 | 204 => Ok(true),
196 404 => Ok(false),
197 _ => Err(self.error_from_response(resp).await),
198 }
199 }
200
201 async fn exists(&self, key: &str) -> Result<bool, BackendError> {
202 let req = self.with_standard_headers(
203 self.client
204 .head(self.url(key))
205 .bearer_auth(self.api_key.as_str()),
206 );
207
208 let resp = req
209 .send()
210 .await
211 .map_err(|e| reqwest_err_sanitized(e, self.api_key.as_str()))?;
212
213 match resp.status().as_u16() {
214 200 => Ok(true),
215 404 => Ok(false),
216 status => Err(BackendError::from_http_status(status, &[])),
217 }
218 }
219
220 async fn health(&self) -> Result<HealthStatus, BackendError> {
221 let start = std::time::Instant::now();
222
223 let req = self.with_standard_headers(
224 self.client
225 .get(self.health_url())
226 .bearer_auth(self.api_key.as_str()),
227 );
228
229 let resp = req
230 .send()
231 .await
232 .map_err(|e| reqwest_err_sanitized(e, self.api_key.as_str()))?;
233
234 let latency = start.elapsed();
235 let status = resp.status().as_u16();
236
237 if (200..300).contains(&status) {
238 let mut details = HashMap::new();
239 details.insert("http_status".to_string(), status.to_string());
240 Ok(HealthStatus {
241 is_healthy: true,
242 latency_ms: latency.as_secs_f64() * 1000.0,
243 backend_type: "cachekitio".to_string(),
244 details,
245 })
246 } else {
247 Err(self.error_from_response(resp).await)
248 }
249 }
250}
251
252#[derive(Default)]
256#[must_use]
257pub struct CachekitIOBuilder {
258 api_key: Option<Zeroizing<String>>,
259 api_url: Option<String>,
260 allow_custom_host: bool,
261 metrics_provider: Option<MetricsProvider>,
262}
263
264impl CachekitIOBuilder {
265 pub fn api_key(mut self, key: impl Into<String>) -> Self {
267 self.api_key = Some(Zeroizing::new(key.into()));
268 self
269 }
270
271 pub fn api_url(mut self, url: impl Into<String>) -> Self {
273 self.api_url = Some(url.into());
274 self
275 }
276
277 pub fn allow_custom_host(mut self, allow: bool) -> Self {
279 self.allow_custom_host = allow;
280 self
281 }
282
283 pub fn metrics_provider(mut self, provider: MetricsProvider) -> Self {
285 self.metrics_provider = Some(provider);
286 self
287 }
288
289 pub fn build(self) -> Result<CachekitIO, crate::error::CachekitError> {
298 use crate::error::CachekitError;
299
300 let api_key = self
301 .api_key
302 .filter(|k| !k.is_empty())
303 .ok_or_else(|| CachekitError::Config("api_key is required".to_string()))?;
304
305 let api_url = self
306 .api_url
307 .unwrap_or_else(|| "https://api.cachekit.io".to_string());
308
309 validate_cachekitio_url(&api_url, self.allow_custom_host)?;
311
312 let api_url = api_url.trim_end_matches('/').to_string();
314
315 let client = reqwest::Client::builder();
316
317 #[cfg(not(target_arch = "wasm32"))]
318 let client = client
319 .use_rustls_tls()
320 .redirect(reqwest::redirect::Policy::none())
321 .timeout(Duration::from_secs(30))
322 .connect_timeout(Duration::from_secs(10));
323
324 let client = client
325 .build()
326 .map_err(|e| CachekitError::Config(format!("failed to build HTTP client: {e}")))?;
327
328 Ok(CachekitIO {
329 client,
330 api_key,
331 api_url,
332 metrics_provider: self.metrics_provider,
333 })
334 }
335}