gha_toolkit/
cache.rs

1//! # GitHub Actions cache client
2//!
3//! The [`CacheClient`] is an idiomatic Rust port of
4//! [@actions/cache](https://github.com/actions/cache).
5//!
6//! See [Caching dependencies to speed up
7//! workflows](https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows)
8//! for the official GitHub documentation.
9//!
10//! ```rust
11//! use std::io::Cursor;
12//! use std::time::SystemTime;
13//!
14//! # use gha_toolkit::cache::*;
15//! #
16//! # #[tokio::main(flavor = "current_thread")]
17//! # async fn main() -> anyhow::Result<()> {
18//! let version = SystemTime::UNIX_EPOCH.elapsed()?.as_millis();
19//! let cache_key = format!("gha-toolkit-{version:#x}");
20//!
21//! let client = CacheClient::from_env()?
22//!     .cache_from([&cache_key].into_iter())
23//!     .cache_to(&cache_key)
24//!     .build()?;
25//!
26//! let scope = "gha_toolkit::cache";
27//! let data = "Hello World!";
28//!
29//! // Create a new cache entry.
30//! client.put(scope, Cursor::new(data)).await?;
31//!
32//! // Read from the cache entry.
33//! let cache_entry = client.entry(scope).await?.expect("cache entry");
34//!
35//! // Fetch the cache data.
36//! let archive_location = cache_entry.archive_location.expect("archive location");
37//! let cached_bytes = client.get(&archive_location).await?;
38//!
39//! // Decode the cache data to a UTF-8 string.
40//! let cached_data = String::from_utf8(cached_bytes)?;
41//!
42//! assert_eq!(cached_data, data);
43//! # Ok(())
44//! # }
45//! ```
46
47use std::env;
48use std::io::{prelude::*, SeekFrom};
49use std::ops::DerefMut as _;
50use std::sync::Arc;
51use std::time::Duration;
52
53use async_lock::{Mutex, Semaphore};
54use bytes::Bytes;
55use futures::prelude::*;
56use http::{header, header::HeaderName, HeaderMap, HeaderValue, StatusCode};
57use hyperx::header::{ContentRange, ContentRangeSpec, Header as _};
58use reqwest::{Body, Url};
59use reqwest_middleware::ClientWithMiddleware;
60use reqwest_retry::policies::ExponentialBackoff;
61#[cfg(doc)]
62use reqwest_retry::policies::ExponentialBackoffBuilder;
63use reqwest_retry::RetryTransientMiddleware;
64use reqwest_retry_after::RetryAfterMiddleware;
65use reqwest_tracing::TracingMiddleware;
66use sha2::{Digest, Sha256};
67use tracing::{debug, instrument, warn};
68
69use crate::{Error, Result};
70
71use serde::{Deserialize, Serialize};
72
/// Path of the artifact cache API; it is appended to the configured base URL
/// when a `CacheClient` is built.
const BASE_URL_PATH: &str = "/_apis/artifactcache/";
/// Default `User-Agent` value, `<crate name>/<package version>`, resolved at
/// compile time from Cargo's environment variables.
const DEFAULT_USER_AGENT: &str = concat!(env!("CARGO_CRATE_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// Default timeout for each chunk download request.
const DEFAULT_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(60);
/// Default timeout for each chunk upload request.
const DEFAULT_UPLOAD_TIMEOUT: Duration = Duration::from_secs(60);
77
/// GitHub Actions cache entry.
///
/// Field names are converted to `camelCase` when the entry is serialized or
/// deserialized. Every field is optional.
///
/// See [module][self] documentation.
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ArtifactCacheEntry {
    /// Cache key for looking up cache entries by the key prefix.
    pub cache_key: Option<String>,

    /// Scope for the cache entry, e.g. the source filename or a hash of the
    /// source file(s).
    pub scope: Option<String>,

    /// Creation time for the cache entry, kept as the raw string.
    pub creation_time: Option<String>,

    /// URL for downloading the cache archive.
    pub archive_location: Option<String>,
}
97
/// JSON body sent when committing an uploaded cache.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct CommitCacheRequest {
    /// Total size of the uploaded cache in bytes.
    pub size: i64,
}
103
/// JSON body sent when reserving a cache entry before uploading its data.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ReserveCacheRequest<'a> {
    /// Cache key to reserve.
    pub key: &'a str,
    /// Cache version; callers in this file pass the hashed version string.
    pub version: &'a str,
    /// Size of the cache in bytes (serialized as `cacheSize`).
    pub cache_size: i64,
}
111
/// JSON body returned by a successful cache reservation.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReserveCacheResponse {
    /// Identifier of the reserved cache (deserialized from `cacheId`); it is
    /// used to address the upload and commit requests.
    pub cache_id: i64,
}
117
/// Query-string parameters for looking up a cache entry.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct CacheQuery<'a> {
    /// Comma-separated cache key prefixes.
    pub keys: &'a str,
    /// Cache version; callers in this file pass the hashed version string.
    pub version: &'a str,
}
124
/// GitHub Actions cache client builder.
///
/// At least one of [`cache_to`](Self::cache_to) or
/// [`cache_from`](Self::cache_from) has to be set before the builder can be
/// turned into a client.
///
/// See [module][self] documentation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheClientBuilder {
    /// GitHub Actions cache API base URL.
    pub base_url: String,

    /// GitHub Actions access token.
    pub token: String,

    /// User agent for HTTP requests. Defaults to `<crate name>/<version>`.
    pub user_agent: String,

    /// Cache key to write.
    pub cache_to: Option<String>,

    /// Cache key prefixes to read.
    pub cache_from: Vec<String>,

    /// Maximum number of retries. Defaults to 2.
    pub max_retries: u32,

    /// Minimum retry interval. See [`ExponentialBackoff::min_retry_interval`].
    /// Defaults to 50 milliseconds.
    pub min_retry_interval: Duration,

    /// Maximum retry interval. See [`ExponentialBackoff::max_retry_interval`].
    /// Defaults to 10 seconds.
    pub max_retry_interval: Duration,

    /// Retry backoff factor base. See [`ExponentialBackoff::backoff_exponent`].
    /// Defaults to 3.
    pub backoff_factor_base: u32,

    /// Maximum chunk size in bytes for downloads. Defaults to 4 MiB.
    pub download_chunk_size: u64,

    /// Maximum time for each chunk download request. Defaults to 60 seconds.
    pub download_chunk_timeout: Duration,

    /// Number of parallel downloads. Defaults to 8.
    pub download_concurrency: u32,

    /// Maximum chunk size in bytes for uploads. Defaults to 1 MiB.
    pub upload_chunk_size: u64,

    /// Maximum time for each chunk upload request. Defaults to 60 seconds.
    pub upload_chunk_timeout: Duration,

    /// Number of parallel uploads. Defaults to 4.
    pub upload_concurrency: u32,
}
175
176impl Default for CacheClientBuilder {
177    fn default() -> Self {
178        Self {
179            base_url: Default::default(),
180            token: Default::default(),
181            user_agent: DEFAULT_USER_AGENT.into(),
182            cache_to: None,
183            cache_from: vec![],
184            max_retries: 2,
185            min_retry_interval: Duration::from_millis(50),
186            max_retry_interval: Duration::from_secs(10),
187            backoff_factor_base: 3,
188            download_chunk_size: 4 << 20, // 4 MiB
189            download_chunk_timeout: DEFAULT_DOWNLOAD_TIMEOUT,
190            download_concurrency: 8,
191            upload_concurrency: 4,
192            upload_chunk_size: 1 << 20, // 1 MiB
193            upload_chunk_timeout: DEFAULT_UPLOAD_TIMEOUT,
194        }
195    }
196}
197
198impl CacheClientBuilder {
199    /// Creates a new [`CacheClientBuilder`] for the given GitHub Actions cache
200    /// API base URL and access token.
201    pub fn new<B: Into<String>, T: Into<String>>(base_url: B, token: T) -> Self {
202        Self {
203            base_url: base_url.into(),
204            token: token.into(),
205            ..Default::default()
206        }
207    }
208
209    /// Creates a new [`CacheClientBuilder`] from GitHub Actions cache
210    /// environmental variables.
211    ///
212    /// The following environmental variables are read:
213    ///
214    /// - `ACTIONS_CACHE_URL` - GitHub Actions cache API base URL
215    /// - `ACTIONS_RUNTIME_TOKEN` - GitHub Actions access token
216    /// - `SEGMENT_DOWNLOAD_TIMEOUT_MINS` - download chunk timeout
217    ///
218    pub fn from_env() -> Result<Self> {
219        let url = env::var("ACTIONS_CACHE_URL").map_err(|source| Error::VarError {
220            source,
221            name: "ACTIONS_CACHE_URL",
222        })?;
223        let token = env::var("ACTIONS_RUNTIME_TOKEN").map_err(|source| Error::VarError {
224            source,
225            name: "ACTIONS_RUNTIME_TOKEN",
226        })?;
227
228        let mut builder = CacheClientBuilder::new(&url, &token);
229
230        if let Some(timeout) = std::env::var("SEGMENT_DOWNLOAD_TIMEOUT_MINS")
231            .ok()
232            .and_then(|s| s.parse().ok())
233            .map(|v: u64| Duration::from_secs(v * 60))
234        {
235            builder.download_chunk_timeout = timeout;
236        }
237
238        Ok(builder)
239    }
240
241    /// Sets the GitHub Actions cache API base URL.
242    pub fn base_url<T: Into<String>>(mut self, base_url: T) -> Self {
243        self.base_url = base_url.into();
244        self
245    }
246
247    /// Sets the cache key prefixes to read.
248    pub fn token<T: Into<String>>(mut self, token: T) -> Self {
249        self.token = token.into();
250        self
251    }
252
253    /// Sets the user agent for HTTP requests.
254    pub fn user_agent<T: Into<String>>(mut self, user_agent: T) -> Self {
255        self.user_agent = user_agent.into();
256        self
257    }
258
259    /// Sets the cache key to write.
260    pub fn cache_to<T: Into<String>>(mut self, cache_to: T) -> Self {
261        self.cache_to = Some(cache_to.into());
262        self
263    }
264    /// Sets the cache key prefixes to read.
265    pub fn cache_from<T>(mut self, cache_from: T) -> Self
266    where
267        T: Iterator,
268        T::Item: Into<String>,
269    {
270        self.cache_from = cache_from.map(Into::into).collect();
271        self
272    }
273
274    /// Sets the maximum number of retries.
275    pub fn max_retries(mut self, max_retries: u32) -> Self {
276        self.max_retries = max_retries;
277        self
278    }
279
280    /// Sets the minimum retry interval.
281    pub fn min_retry_interval(mut self, min_retry_interval: Duration) -> Self {
282        self.min_retry_interval = min_retry_interval;
283        self
284    }
285
286    /// Sets the maximum retry interval.
287    pub fn max_retry_interval(mut self, max_retry_interval: Duration) -> Self {
288        self.max_retry_interval = max_retry_interval;
289        self
290    }
291
292    /// Sets the retry backoff factor base.
293    pub fn backoff_factor_base(mut self, backoff_factor_base: u32) -> Self {
294        self.backoff_factor_base = backoff_factor_base;
295        self
296    }
297
298    /// Maximum chunk size in bytes for downloads.
299    pub fn download_chunk_size(mut self, download_chunk_size: u64) -> Self {
300        self.download_chunk_size = download_chunk_size;
301        self
302    }
303
304    /// Sets the maximum time for each chunk download request.
305    pub fn download_chunk_timeout(mut self, download_chunk_timeout: Duration) -> Self {
306        self.download_chunk_timeout = download_chunk_timeout;
307        self
308    }
309
310    /// Sets the number of parallel downloads.
311    pub fn download_concurrency(mut self, download_concurrency: u32) -> Self {
312        self.download_concurrency = download_concurrency;
313        self
314    }
315
316    /// Sets the maximum chunk size in bytes for uploads.
317    pub fn upload_chunk_size(mut self, upload_chunk_size: u64) -> Self {
318        self.upload_chunk_size = upload_chunk_size;
319        self
320    }
321
322    /// Sets the maximum time for each chunk upload request.
323    pub fn upload_chunk_timeout(mut self, upload_chunk_timeout: Duration) -> Self {
324        self.upload_chunk_timeout = upload_chunk_timeout;
325        self
326    }
327
328    /// Sets the number of parallel downloads.
329    pub fn upload_concurrency(mut self, upload_concurrency: u32) -> Self {
330        self.upload_concurrency = upload_concurrency;
331        self
332    }
333
334    /// Consumes this [`CacheClientBuilder`] and build a [`CacheClient`].
335    pub fn build(self) -> Result<CacheClient> {
336        self.try_into()
337    }
338}
339
/// GitHub Actions cache client.
///
/// Built from a [`CacheClientBuilder`].
///
/// See [module][self] documentation.
pub struct CacheClient {
    // HTTP client wrapped with the tracing and retry middleware.
    client: ClientWithMiddleware,
    // Configured base URL with the artifact cache API path appended.
    base_url: Url,
    // `Accept` and `Authorization` headers attached to cache API requests.
    api_headers: HeaderMap,

    // Cache key to write, if any.
    cache_to: Option<String>,
    // Cache key prefixes to read, joined with commas; `None` when no prefix
    // was configured.
    cache_from: Option<String>,

    // Maximum chunk size in bytes for downloads.
    download_chunk_size: u64,
    // Timeout applied to each chunk download request.
    download_chunk_timeout: Duration,
    // Number of parallel downloads.
    download_concurrency: u32,

    // Maximum chunk size in bytes for uploads.
    upload_chunk_size: u64,
    // Timeout applied to each chunk upload request.
    upload_chunk_timeout: Duration,
    // Number of parallel uploads.
    upload_concurrency: u32,
}
359
360impl TryInto<CacheClient> for CacheClientBuilder {
361    type Error = Error;
362
363    fn try_into(self) -> Result<CacheClient, Self::Error> {
364        if self.cache_to.is_none() && self.cache_from.is_empty() {
365            return Err(Error::MissingKey);
366        }
367
368        let cache_to = if let Some(cache_to) = self.cache_to {
369            check_key(&cache_to)?;
370            Some(cache_to)
371        } else {
372            None
373        };
374
375        let cache_from = if !self.cache_from.is_empty() {
376            for key in &self.cache_from {
377                check_key(key)?;
378            }
379            Some(self.cache_from.join(","))
380        } else {
381            None
382        };
383
384        let mut api_headers = HeaderMap::new();
385        api_headers.insert(
386            header::ACCEPT,
387            HeaderValue::from_static("application/json;api-version=6.0-preview.1"),
388        );
389
390        let auth_value = Bytes::from(format!("Bearer {}", self.token));
391        let mut auth_value = header::HeaderValue::from_maybe_shared(auth_value)?;
392        auth_value.set_sensitive(true);
393        api_headers.insert(http::header::AUTHORIZATION, auth_value);
394
395        let retry_policy = ExponentialBackoff::builder()
396            .retry_bounds(self.min_retry_interval, self.max_retry_interval)
397            .backoff_exponent(self.backoff_factor_base)
398            .build_with_max_retries(self.max_retries);
399
400        let client = reqwest::ClientBuilder::new()
401            .user_agent(self.user_agent)
402            .build()?;
403        let client = reqwest_middleware::ClientBuilder::new(client)
404            .with(TracingMiddleware::default())
405            .with(RetryAfterMiddleware::new())
406            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
407            .build();
408
409        let base_url = Url::parse(&format!(
410            "{}{}",
411            self.base_url.trim_end_matches('/'),
412            BASE_URL_PATH
413        ))?;
414
415        Ok(CacheClient {
416            client,
417            base_url,
418            api_headers,
419            cache_to,
420            cache_from,
421            download_chunk_size: self.download_chunk_size,
422            download_chunk_timeout: self.download_chunk_timeout,
423            download_concurrency: self.download_concurrency,
424            upload_concurrency: self.upload_concurrency,
425            upload_chunk_timeout: self.upload_chunk_timeout,
426            upload_chunk_size: self.upload_chunk_size,
427        })
428    }
429}
430
431impl CacheClient {
    /// Creates a new [`CacheClientBuilder`] for the given cache API base URL
    /// and access token.
    ///
    /// Shorthand for [`CacheClientBuilder::new`].
    pub fn builder<B: Into<String>, T: Into<String>>(base_url: B, token: T) -> CacheClientBuilder {
        CacheClientBuilder::new(base_url, token)
    }
438
    /// Creates a new [`CacheClientBuilder`] from environment variables.
    ///
    /// Note that this returns a builder, not a client; call
    /// [`CacheClientBuilder::build`] to obtain the [`CacheClient`].
    ///
    /// See [`CacheClientBuilder::from_env`].
    pub fn from_env() -> Result<CacheClientBuilder> {
        CacheClientBuilder::from_env()
    }
445
446    /// Gets the GitHub Actions cache API base URL.
447    ///
448    /// See [`CacheClientBuilder::base_url`].
449    pub fn base_url(&self) -> &str {
450        let base_url = self.base_url.as_str();
451        &base_url[..base_url.len() - BASE_URL_PATH.len()]
452    }
453
    /// Gets the cache key to write, or `None` when the client is read-only.
    ///
    /// See [`CacheClientBuilder::cache_to`].
    pub fn cache_to(&self) -> Option<&str> {
        self.cache_to.as_deref()
    }
460
    /// Gets the cache key prefixes to read as a single comma-separated
    /// string, or `None` when no prefix was configured.
    ///
    /// See [`CacheClientBuilder::cache_from`].
    pub fn cache_from(&self) -> Option<&str> {
        self.cache_from.as_deref()
    }
467
468    /// Gets the cache entry identified by the given `version`.
469    #[instrument(skip(self))]
470    pub async fn entry(&self, version: &str) -> Result<Option<ArtifactCacheEntry>> {
471        let cache_from = if let Some(cache_from) = self.cache_from.as_ref() {
472            cache_from
473        } else {
474            return Ok(None);
475        };
476
477        let query = serde_urlencoded::to_string(&CacheQuery {
478            keys: cache_from,
479            version: &get_cache_version(version),
480        })?;
481
482        let mut url = self.base_url.join("cache")?;
483        url.set_query(Some(&query));
484
485        let response = self
486            .client
487            .get(url)
488            .headers(self.api_headers.clone())
489            .send()
490            .await?;
491        let status = response.status();
492        if status == http::StatusCode::NO_CONTENT {
493            return Ok(None);
494        };
495        if !status.is_success() {
496            let message = response.text().await.unwrap_or_else(|err| err.to_string());
497            return Err(Error::CacheServiceStatus { status, message });
498        }
499
500        let cache_result: ArtifactCacheEntry = response.json().await?;
501        debug!("Cache Result: {}", serde_json::to_string(&cache_result)?);
502
503        if let Some(cache_download_url) = cache_result.archive_location.as_ref() {
504            println!(
505                "::add-mask::{}",
506                shell_escape::escape(cache_download_url.into())
507            );
508        } else {
509            return Err(Error::CacheNotFound);
510        }
511
512        Ok(Some(cache_result))
513    }
514
515    /// Gets the cache archive as a byte array.
516    #[instrument(skip(self))]
517    pub async fn get(&self, url: &str) -> Result<Vec<u8>> {
518        let uri = Url::parse(url)?;
519
520        let (data, cache_size) = self.download_first_chunk(uri.clone()).await?;
521
522        if cache_size.is_none() {
523            return Ok(data.to_vec());
524        }
525
526        if let Some(ContentRange(ContentRangeSpec::Bytes {
527            instance_length: Some(cache_size),
528            ..
529        })) = cache_size
530        {
531            let actual_size = data.len() as u64;
532            if actual_size == cache_size {
533                return Ok(data.to_vec());
534            }
535            if actual_size != self.download_chunk_size {
536                return Err(Error::CacheChunkSize {
537                    expected_size: self.download_chunk_size as usize,
538                    actual_size: actual_size as usize,
539                    message: "verifying the first chunk size using the content-range header",
540                });
541            }
542
543            // Download chunks in parallel
544            if cache_size as usize
545                <= self.download_chunk_size as usize * self.download_concurrency as usize
546            {
547                let mut chunks = Vec::new();
548                let mut start = self.download_chunk_size;
549                while start < cache_size {
550                    let chunk_size = u64::min(cache_size, self.download_chunk_size);
551                    let uri = uri.clone();
552                    chunks.push(self.download_chunk(uri, start, chunk_size));
553                    start += self.download_chunk_size;
554                }
555
556                let mut chunks = future::try_join_all(chunks.into_iter()).await?;
557                chunks.insert(0, data);
558
559                return Ok(chunks.concat());
560            }
561
562            // Download chunks with max concurrency
563            let permit = Arc::new(Semaphore::new(self.download_concurrency as usize));
564
565            let mut chunks = Vec::new();
566            let mut start = self.download_chunk_size;
567            while start < cache_size {
568                let chunk_size = u64::min(cache_size, self.download_chunk_size);
569                let uri = uri.clone();
570                let permit = permit.clone();
571
572                chunks.push(async move {
573                    let _guard = permit.acquire().await;
574                    self.download_chunk(uri, start, chunk_size).await
575                });
576
577                start += self.upload_chunk_size;
578            }
579
580            let mut chunks = future::try_join_all(chunks).await?;
581            chunks.insert(0, data);
582
583            return Ok(chunks.concat());
584        }
585
586        debug!("Unable to validate download, no Content-Range header or unknown size");
587
588        let actual_size = data.len() as u64;
589        if actual_size < self.download_chunk_size {
590            return Ok(data.to_vec());
591        }
592        if actual_size != self.download_chunk_size {
593            return Err(Error::CacheChunkSize {
594                expected_size: self.download_chunk_size as usize,
595                actual_size: actual_size as usize,
596                message: "verifying the first chunk size without the content-range header",
597            });
598        }
599
600        let mut start = self.download_chunk_size;
601        let mut chunks = vec![data];
602        loop {
603            let chunk = self
604                .download_chunk(uri.clone(), start, self.download_chunk_size)
605                .await?;
606            if chunk.is_empty() {
607                break;
608            }
609
610            let chunk_size = chunk.len() as u64;
611            chunks.push(chunk);
612
613            if chunk_size < self.download_chunk_size {
614                break;
615            }
616            if chunk_size != self.download_chunk_size {
617                return Err(Error::CacheChunkSize {
618                    expected_size: self.download_chunk_size as usize,
619                    actual_size: chunk_size as usize,
620                    message: "verifying a chunk size without the content-range header",
621                });
622            }
623
624            start += self.download_chunk_size;
625        }
626
627        Ok(chunks.concat())
628    }
629
    /// Downloads the first `download_chunk_size` bytes of the archive and
    /// returns them together with the content-range information reported for
    /// the response, if any.
    #[instrument(skip(self, uri))]
    async fn download_first_chunk(&self, uri: Url) -> Result<(Bytes, Option<ContentRange>)> {
        self.do_download_chunk(uri, 0, self.download_chunk_size, true)
            .await
    }
635
    /// Downloads up to `size` bytes starting at byte offset `start` and
    /// returns only the body, discarding the content-range information.
    // NOTE(review): `skip_all` combined with value-less `fields(...)` looks
    // like it declares the span fields without ever recording them — confirm
    // against the `tracing::instrument` documentation.
    #[instrument(skip_all, fields(uri, start, size))]
    async fn download_chunk(&self, uri: Url, start: u64, size: u64) -> Result<Bytes> {
        let (bytes, _) = self.do_download_chunk(uri, start, size, false).await?;
        Ok(bytes)
    }
641
642    #[instrument(skip(self, uri))]
643    async fn do_download_chunk(
644        &self,
645        uri: Url,
646        start: u64,
647        size: u64,
648        expect_partial: bool,
649    ) -> Result<(Bytes, Option<ContentRange>)> {
650        let range = format!("bytes={start}-{}", start + size - 1);
651
652        let response = self
653            .client
654            .get(uri)
655            .header(header::RANGE, HeaderValue::from_str(&range)?)
656            .header(
657                HeaderName::from_static("x-ms-range-get-content-md5"),
658                HeaderValue::from_static("true"),
659            )
660            .timeout(self.download_chunk_timeout)
661            .send()
662            .await?;
663
664        let status = response.status();
665        let partial_content = expect_partial && status == StatusCode::PARTIAL_CONTENT;
666        if !status.is_success() {
667            let message = response.text().await.unwrap_or_else(|err| err.to_string());
668            return Err(Error::CacheServiceStatus { status, message });
669        }
670
671        let content_length = response.content_length();
672        let headers = response.headers();
673
674        let content_range = if partial_content {
675            headers
676                .get(header::CONTENT_RANGE)
677                .and_then(|v| ContentRange::parse_header(&v).ok())
678        } else {
679            Some(ContentRange(ContentRangeSpec::Bytes {
680                range: None,
681                instance_length: content_length,
682            }))
683        };
684
685        let md5sum = response
686            .headers()
687            .get(HeaderName::from_static("content-md5"))
688            .and_then(|v| v.to_str().ok())
689            .and_then(|s| hex::decode(s).ok());
690
691        let bytes = response.bytes().await?;
692        let actual_size = bytes.len() as u64;
693        if actual_size != content_length.unwrap_or(actual_size) || actual_size > size {
694            return Err(Error::CacheChunkSize {
695                expected_size: size as usize,
696                actual_size: bytes.len(),
697                message: if expect_partial {
698                    "downloading a chunk"
699                } else {
700                    "downloading the first chunk"
701                },
702            });
703        }
704
705        if let Some(md5sum) = md5sum {
706            use md5::Digest as _;
707            let checksum = md5::Md5::digest(&bytes);
708            if md5sum[..] != checksum[..] {
709                return Err(Error::CacheChunkChecksum);
710            }
711        }
712
713        Ok((bytes, content_range))
714    }
715
716    /// Puts the cache archive as the given `version`.
717    #[instrument(skip(self, data))]
718    pub async fn put<T: Read + Seek>(&self, version: &str, mut data: T) -> Result<()> {
719        let cache_to = if let Some(cache_to) = self.cache_to.as_ref() {
720            cache_to
721        } else {
722            return Ok(());
723        };
724
725        let cache_size = data.seek(SeekFrom::End(0))?;
726        if cache_size > i64::MAX as u64 {
727            return Err(Error::CacheSizeTooLarge(cache_size as usize));
728        }
729
730        let version = &get_cache_version(version);
731        let cache_id = self.reserve(cache_to, version, cache_size).await?;
732
733        if let Some(cache_id) = cache_id {
734            data.rewind()?;
735            self.upload(cache_id, cache_size, data).await?;
736            self.commit(cache_id, cache_size).await?;
737        }
738
739        Ok(())
740    }
741
742    #[instrument(skip(self))]
743    async fn reserve(&self, key: &str, version: &str, cache_size: u64) -> Result<Option<i64>> {
744        let url = self.base_url.join("caches")?;
745
746        let reserve_cache_request = ReserveCacheRequest {
747            key,
748            version,
749            cache_size: cache_size as i64,
750        };
751
752        let response = self
753            .client
754            .post(url)
755            .headers(self.api_headers.clone())
756            .json(&reserve_cache_request)
757            .send()
758            .await?;
759
760        let status = response.status();
761        match status {
762            http::StatusCode::NO_CONTENT | http::StatusCode::CONFLICT => {
763                warn!("No cache ID for key {} version {version}: {status:?}", key);
764                return Ok(None);
765            }
766            _ if !status.is_success() => {
767                let message = response.text().await.unwrap_or_else(|err| err.to_string());
768                return Err(Error::CacheServiceStatus { status, message });
769            }
770            _ => {}
771        }
772
773        let ReserveCacheResponse { cache_id } = response.json().await?;
774        Ok(Some(cache_id))
775    }
776
777    #[instrument(skip(self, data))]
778    async fn upload<T: Read + Seek>(
779        &self,
780        cache_id: i64,
781        cache_size: u64,
782        mut data: T,
783    ) -> Result<()> {
784        let uri = self.base_url.join(&format!("caches/{cache_id}"))?;
785
786        // Upload all data
787        if cache_size <= self.upload_chunk_size {
788            let mut buf = Vec::new();
789            let _ = data.read_to_end(&mut buf)?;
790            return self.upload_chunk(uri, buf, 0, cache_size).await;
791        }
792
793        // Upload chunks in parallel
794        if cache_size as usize <= self.upload_chunk_size as usize * self.upload_concurrency as usize
795        {
796            let mut chunks = Vec::new();
797            let mut start = 0;
798            while start < cache_size {
799                let mut chunk = Vec::new();
800                let chunk_size = u64::min(cache_size, self.upload_chunk_size);
801                let _ = (&mut data).take(chunk_size).read_to_end(&mut chunk)?;
802                chunks.push(self.upload_chunk(uri.clone(), chunk, start, chunk_size));
803                start += self.upload_chunk_size;
804            }
805
806            let _ = future::try_join_all(chunks).await?;
807
808            return Ok(());
809        }
810
811        // Upload chunks with max concurrency
812        let data = Arc::new(Mutex::new(data));
813        let permit = Arc::new(Semaphore::new(self.upload_concurrency as usize));
814
815        let mut chunks = Vec::new();
816        let mut start = 0;
817        while start < cache_size {
818            let chunk_size = u64::min(cache_size, self.upload_chunk_size);
819            let uri = uri.clone();
820            let data = data.clone();
821            let permit = permit.clone();
822
823            chunks.push(async move {
824                let _guard = permit.acquire().await;
825
826                let mut data = data.lock().await;
827                let data = data.deref_mut();
828
829                let mut chunk = Vec::new();
830                let _ = data.seek(SeekFrom::Start(start))?;
831                let _ = data.take(chunk_size).read_to_end(&mut chunk)?;
832
833                self.upload_chunk(uri, chunk, start, chunk_size).await
834            });
835
836            start += self.upload_chunk_size;
837        }
838
839        let _ = future::try_join_all(chunks).await?;
840
841        Ok(())
842    }
843
844    #[instrument(skip(self, uri, body))]
845    async fn upload_chunk<T: Into<Body>>(
846        &self,
847        uri: Url,
848        body: T,
849        start: u64,
850        size: u64,
851    ) -> Result<()> {
852        let content_range = format!("bytes {start}-{}/*", start + size - 1);
853
854        let response = self
855            .client
856            .patch(uri)
857            .headers(self.api_headers.clone())
858            .header(
859                header::CONTENT_TYPE,
860                HeaderValue::from_static("application/octet-stream"),
861            )
862            .header(
863                header::CONTENT_RANGE,
864                HeaderValue::from_str(&content_range)?,
865            )
866            .body(body)
867            .timeout(self.upload_chunk_timeout)
868            .send()
869            .await?;
870
871        let status = response.status();
872        if status.is_success() {
873            Ok(())
874        } else {
875            let message = response.text().await.unwrap_or_else(|err| err.to_string());
876            Err(Error::CacheServiceStatus { status, message })
877        }
878    }
879
880    #[instrument(skip(self))]
881    async fn commit(&self, cache_id: i64, cache_size: u64) -> Result<()> {
882        let url = self.base_url.join(&format!("caches/{cache_id}"))?;
883        let commit_cache_request = CommitCacheRequest {
884            size: cache_size as i64,
885        };
886
887        let response = self
888            .client
889            .post(url)
890            .headers(self.api_headers.clone())
891            .json(&commit_cache_request)
892            .send()
893            .await?;
894
895        let status = response.status();
896        if status.is_success() {
897            Ok(())
898        } else {
899            let message = response.text().await.unwrap_or_else(|err| err.to_string());
900            Err(Error::CacheServiceStatus { status, message })
901        }
902    }
903}
904
905fn get_cache_version(version: &str) -> String {
906    let mut hasher = Sha256::new();
907
908    hasher.update(version);
909    hasher.update("|");
910
911    // Add salt to cache version to support breaking changes in cache entry
912    hasher.update(env!("CARGO_PKG_VERSION_MAJOR"));
913    hasher.update(".");
914    hasher.update(env!("CARGO_PKG_VERSION_MINOR"));
915
916    let result = hasher.finalize();
917    hex::encode(&result[..])
918}
919
920pub fn check_key(key: &str) -> Result<()> {
921    if key.len() > 512 {
922        return Err(Error::InvalidKeyLength(key.to_string()));
923    }
924    if key.chars().any(|c| c == ',') {
925        return Err(Error::InvalidKeyComma(key.to_string()));
926    }
927    Ok(())
928}