kalosm_common/
cache.rs

1use hf_hub::{Repo, RepoType};
2use httpdate::parse_http_date;
3use kalosm_model_types::{FileLoadingProgress, FileSource};
4use reqwest::{
5    header::{HeaderValue, CONTENT_LENGTH, LAST_MODIFIED, RANGE},
6    IntoUrl,
7};
8use reqwest::{Response, StatusCode};
9use std::path::PathBuf;
10use std::str::FromStr;
11use tokio::fs::{File, OpenOptions};
12use tokio::io::AsyncWriteExt;
13
14#[derive(Debug, thiserror::Error)]
15pub enum CacheError {
16    #[error("Hugging Face API error: {0}")]
17    HuggingFaceApi(#[from] hf_hub::api::sync::ApiError),
18    #[error("Unable to get file metadata for {0}: {1}")]
19    UnableToGetFileMetadata(PathBuf, #[source] tokio::io::Error),
20    #[error("IO error: {0}")]
21    Io(#[from] std::io::Error),
22    #[error("HTTP error: {0}")]
23    Http(#[from] reqwest::Error),
24    #[error("Unexpected status code: {0}")]
25    UnexpectedStatusCode(StatusCode),
26}
27
/// A local file cache for model files.
///
/// Local sources are used in place; Hugging Face sources are downloaded into `location` and
/// reused on later requests.
#[derive(Clone)]
pub struct Cache {
    /// Root directory for downloads; Hugging Face files are stored under
    /// `<location>/<model_id>/<revision>/<file>`.
    location: PathBuf,
    /// The huggingface token to use (defaults to the token set with `huggingface-cli login`)
    huggingface_token: Option<String>,
}

// Hand-written instead of derived so `{:?}` (e.g. in logs or panics) never prints the
// Hugging Face token itself, only whether one is configured.
impl std::fmt::Debug for Cache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Cache")
            .field("location", &self.location)
            .field(
                "huggingface_token",
                &self.huggingface_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
34
35impl Cache {
36    /// Create a new cache with a specific location
37    pub fn new(location: PathBuf) -> Self {
38        Self {
39            location,
40            huggingface_token: None,
41        }
42    }
43
44    /// Set the Hugging Face token to use for downloading (defaults to the token set with `huggingface-cli login`, and then the environment variable `HF_TOKEN`)
45    pub fn with_huggingface_token(mut self, token: Option<String>) -> Self {
46        self.huggingface_token = token;
47        self
48    }
49
50    /// Check if the file exists locally (if it is a local file or if it has been downloaded)
51    pub fn exists(&self, source: &FileSource) -> bool {
52        match source {
53            FileSource::HuggingFace {
54                model_id,
55                revision,
56                file,
57                ..
58            } => {
59                let path = self.location.join(model_id).join(revision);
60                let complete_download = path.join(file);
61                complete_download.exists()
62            }
63            FileSource::Local(path) => path.exists(),
64        }
65    }
66
67    /// Get the file from the cache, downloading it if necessary
68    pub async fn get(
69        &self,
70        source: &FileSource,
71        progress: impl FnMut(FileLoadingProgress),
72    ) -> Result<PathBuf, CacheError> {
73        match source {
74            FileSource::HuggingFace {
75                model_id,
76                revision,
77                file,
78            } => {
79                let token = self.huggingface_token.clone().or_else(huggingface_token);
80
81                let path = self.location.join(model_id).join(revision);
82                let complete_download = path.join(file);
83
84                let repo = Repo::with_revision(
85                    model_id.to_string(),
86                    RepoType::Model,
87                    revision.to_string(),
88                );
89                let api = hf_hub::api::sync::Api::new()?.repo(repo);
90                let url = api.url(file);
91                let client = reqwest::Client::new();
92                tracing::trace!("Fetching metadata for {file} from {url}");
93                let response = client
94                    .head(&url)
95                    .with_authorization_header(token.clone())
96                    .send()
97                    .await;
98
99                if complete_download.exists() {
100                    let metadata = tokio::fs::metadata(&complete_download).await.map_err(|e| {
101                        CacheError::UnableToGetFileMetadata(complete_download.clone(), e)
102                    })?;
103                    let file_last_modified = metadata.modified()?;
104                    // If the server says the file hasn't been modified since we downloaded it, we can use the local file
105                    if let Some(last_updated) = response
106                        .as_ref()
107                        .ok()
108                        .and_then(|response| response.headers().get(LAST_MODIFIED))
109                        .and_then(|last_updated| last_updated.to_str().ok())
110                        .and_then(|s| parse_http_date(s).ok())
111                    {
112                        if last_updated <= file_last_modified {
113                            return Ok(complete_download);
114                        }
115                    } else {
116                        // Or if we are offline, we can use the local file
117                        return Ok(complete_download);
118                    }
119                }
120                let incomplete_download = path.join(format!("{}.partial", file));
121
122                tracing::trace!("Downloading into {:?}", incomplete_download);
123
124                download_into(
125                    url,
126                    &incomplete_download,
127                    response?,
128                    client,
129                    token,
130                    progress,
131                )
132                .await?;
133
134                // Rename the file to remove the .partial extension
135                tokio::fs::rename(&incomplete_download, &complete_download).await?;
136
137                Ok(complete_download)
138            }
139            FileSource::Local(path) => Ok(path.clone()),
140        }
141    }
142}
143
144impl Default for Cache {
145    fn default() -> Self {
146        Self {
147            location: dirs::data_dir().unwrap().join("kalosm").join("cache"),
148            huggingface_token: None,
149        }
150    }
151}
152
153async fn download_into<U: IntoUrl>(
154    url: U,
155    file: &PathBuf,
156    head: Response,
157    client: reqwest::Client,
158    token: Option<String>,
159    mut progress: impl FnMut(FileLoadingProgress),
160) -> Result<(), CacheError> {
161    let length = head
162        .headers()
163        .get(CONTENT_LENGTH)
164        .ok_or("response doesn't include the content length")
165        .unwrap();
166    let length = length.to_str().ok().and_then(|s| u64::from_str(s).ok());
167
168    let (start, mut output_file) = if let Ok(metadata) = tokio::fs::metadata(file).await {
169        let start = metadata.len();
170        let output_file = OpenOptions::new().append(true).open(file).await.unwrap();
171        (start, output_file)
172    } else {
173        tokio::fs::create_dir_all(file.parent().unwrap()).await?;
174        (0, File::create(file).await.unwrap())
175    };
176
177    if let Some(length) = length {
178        progress(FileLoadingProgress {
179            progress: start,
180            cached_size: start,
181            size: length,
182            start_time: std::time::Instant::now(),
183        });
184    }
185
186    if Some(start) == length {
187        tracing::trace!("File {} already downloaded", file.display());
188        progress(FileLoadingProgress {
189            progress: start,
190            cached_size: start,
191            size: length.unwrap_or(0),
192            start_time: std::time::Instant::now(),
193        });
194        return Ok(());
195    }
196
197    let range = length
198        .and_then(|length| HeaderValue::from_str(&format!("bytes={}-{}", start, length - 1)).ok());
199
200    tracing::trace!("Fetching range {:?}", range);
201    let mut request = client.get(url).with_authorization_header(token);
202    if let Some(range) = range {
203        request = request.header(RANGE, range);
204    }
205    let mut response = request.send().await?;
206
207    let status = response.status();
208    if !(status == StatusCode::OK || status == StatusCode::PARTIAL_CONTENT) {
209        return Err(CacheError::UnexpectedStatusCode(status));
210    }
211
212    let mut current_progress = start;
213
214    while let Some(chunk) = response.chunk().await? {
215        output_file.write_all(&chunk).await?;
216        tracing::trace!("wrote chunk of size {}", chunk.len());
217        current_progress += chunk.len() as u64;
218        if let Some(length) = length {
219            progress(FileLoadingProgress {
220                progress: current_progress,
221                cached_size: start,
222                size: length,
223                start_time: std::time::Instant::now(),
224            });
225        }
226    }
227
228    tracing::trace!("Download of {} complete", file.display());
229
230    Ok(())
231}
232
233trait RequestBuilderExt {
234    fn with_authorization_header(self, token: Option<String>) -> Self;
235}
236
237impl RequestBuilderExt for reqwest::RequestBuilder {
238    fn with_authorization_header(self, token: Option<String>) -> Self {
239        if let Some(token) = token {
240            self.header(reqwest::header::AUTHORIZATION, format!("Bearer {token}"))
241        } else {
242            self
243        }
244    }
245}
246
#[cfg(test)]
#[tokio::test]
async fn downloads_work() {
    let url = "https://httpbin.org/range/102400?duration=2";
    // Per-process path in the temp dir: keeps the working directory clean and guarantees we
    // start from scratch, instead of resuming from a file left behind by an earlier failed run.
    let file = std::env::temp_dir().join(format!(
        "kalosm-download-test-{}.bin",
        std::process::id()
    ));
    // Ignoring the result is deliberate: the file usually does not exist.
    let _ = std::fs::remove_file(&file);

    let progress = |p| {
        println!("Progress: {:?}", p);
    };
    let client = reqwest::Client::new();
    let response = client.head(url).send().await.unwrap();
    download_into(url, &file, response, client, None, progress)
        .await
        .unwrap();

    let downloaded = tokio::fs::metadata(&file).await.unwrap().len();
    // Clean up before asserting so a failure does not leave the file behind.
    tokio::fs::remove_file(&file).await.unwrap();
    assert_eq!(downloaded, 102400);
}
263
264fn huggingface_token() -> Option<String> {
265    let cache = hf_hub::Cache::default();
266    cache.token().or_else(|| std::env::var("HF_TOKEN").ok())
267}