Skip to main content

models_cat/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
3
4mod fslock;
5
6pub mod hub;
7pub mod repo;
8pub mod utils;
9
10pub use hub::{ModelsCat, MultiProgressWrapper, Progress, ProgressBarWrapper, ProgressUnit};
11pub use repo::{Repo, RepoType};
12pub use utils::OpsError;
13
14/// Shortcut for downloading a model
15pub fn download_model(repo_id: &str, filename: &str) -> Result<(), OpsError> {
16    ModelsCat::new(Repo::new_model(repo_id)).download(filename)
17}
18
19/// Shortcut for downloading a model with progress
20/// The filename including extension and parent directory, such as `models.gguf` or `gguf/models.gguf`.
21pub fn download_model_with_progress(
22    repo_id: &str,
23    filename: &str,
24    progress: impl Progress,
25) -> Result<(), OpsError> {
26    ModelsCat::new(Repo::new_model(repo_id)).download_with_progress(filename, progress)
27}
28
29/// Shortcut for downloading a dataset
30pub fn download_dataset(repo_id: &str, filename: &str) -> Result<(), OpsError> {
31    ModelsCat::new(Repo::new_dataset(repo_id)).download(filename)
32}
33
34/// Shortcut for downloading a dataset with progress
35pub fn download_dataset_with_progress(
36    repo_id: &str,
37    filename: &str,
38    progress: impl Progress,
39) -> Result<(), OpsError> {
40    ModelsCat::new(Repo::new_dataset(repo_id)).download_with_progress(filename, progress)
41}
42
43/// Shortcut pulling a model repo
44pub fn pull_model(repo_id: &str) -> Result<(), OpsError> {
45    ModelsCat::new(Repo::new_model(repo_id)).pull()
46}
47
48/// Shortcut pulling a dataset repo
49pub fn pull_dataset(repo_id: &str) -> Result<(), OpsError> {
50    ModelsCat::new(Repo::new_dataset(repo_id)).pull()
51}
52
53/// Shortcut removing a local model repo
54pub fn remove_model_repo(repo_id: &str) -> Result<(), OpsError> {
55    ModelsCat::new(Repo::new_model(repo_id)).remove_all()
56}
57
58/// Shortcut removing a local dataset repo
59pub fn remove_dataset_repo(repo_id: &str) -> Result<(), OpsError> {
60    ModelsCat::new(Repo::new_dataset(repo_id)).remove_all()
61}
62
63/// Shortcut removing a local model file
64pub fn remove_model_file(repo_id: &str, filname: &str) -> Result<(), OpsError> {
65    ModelsCat::new(Repo::new_model(repo_id)).remove(filname)
66}
67
68/// Shortcut removing a local dataset file
69pub fn remove_dataset_file(repo_id: &str, filname: &str) -> Result<(), OpsError> {
70    ModelsCat::new(Repo::new_dataset(repo_id)).remove(filname)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::path::Path;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable overridden by `test_cache_dir_env`.
    const CACHE_DIR_ENV: &str = "MODELS_CAT_CACHE_DIR";
    /// Scratch cache directory used by `test_cache_dir_env`.
    const TEST_CACHE_DIR: &str = "./test_cache";

    /// Serializes the tests in this module that download or touch the process
    /// environment. The test harness runs tests on parallel threads; without
    /// this, `test_download_model` could observe the overridden cache dir (and
    /// have it deleted underneath it) while `test_cache_dir_env` runs.
    ///
    /// NOTE(review): `asynchronous::tests` also downloads in the same process
    /// and is not covered by this lock.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so that one failing test does
    /// not cascade into failures of the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Points `CACHE_DIR_ENV` at `dir` for the lifetime of the guard. On drop,
    /// including during a panic unwind, it restores the previous value and
    /// removes `dir`.
    struct CacheDirOverride {
        previous: Option<OsString>,
        dir: &'static str,
    }

    impl CacheDirOverride {
        fn new(dir: &'static str) -> Self {
            let previous = std::env::var_os(CACHE_DIR_ENV);
            // SAFETY: callers hold `ENV_LOCK`, which serializes the tests in
            // this module that touch the environment (see the NOTE on
            // `ENV_LOCK` for tests outside this module).
            unsafe {
                std::env::set_var(CACHE_DIR_ENV, dir);
            }
            Self { previous, dir }
        }
    }

    impl Drop for CacheDirOverride {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `CacheDirOverride::new`.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(CACHE_DIR_ENV, value),
                    None => std::env::remove_var(CACHE_DIR_ENV),
                }
            }
            // Best-effort cleanup: the directory may not exist if the test
            // failed before anything was written.
            let _ = std::fs::remove_dir_all(self.dir);
        }
    }

    #[test]
    fn test_download_model() {
        let _env = lock_env();
        download_model_with_progress(
            "BAAI/bge-small-zh-v1.5",
            "model.safetensors",
            ProgressBarWrapper::default(),
        )
        .unwrap();
    }

    #[test]
    fn test_cache_dir_env() {
        let _env = lock_env();
        // Declared after `_env`, so it is dropped first: the environment is
        // restored while the lock is still held.
        let _override = CacheDirOverride::new(TEST_CACHE_DIR);

        download_model_with_progress(
            "BAAI/bge-small-zh-v1.5",
            "model.safetensors",
            ProgressBarWrapper::default(),
        )
        .unwrap();

        // The override must have been honoured: the download populated the
        // scratch directory. The guard removes it afterwards.
        assert!(
            Path::new(TEST_CACHE_DIR).is_dir(),
            "expected cache directory to be created via the env override"
        );
    }
}
102
/// Asynchronous counterparts of the crate-level shortcut functions.
///
/// Only compiled with the `tokio` feature. The re-exported types come from
/// `crate::hub::async_hub`. Every shortcut here returns a future that must be
/// `.await`ed before any work happens.
#[cfg(feature = "tokio")]
pub mod asynchronous {
    pub use crate::hub::async_hub::{
        ModelsCat, MultiProgressWrapper, Progress, ProgressBarWrapper, ProgressUnit,
    };
    pub use crate::repo::{Repo, RepoType};
    pub use crate::utils::OpsError;

    /// Downloads a single file from a model repository.
    ///
    /// `filename` is the file's path inside the repository, including
    /// extension and any parent directory (e.g. `models.gguf` or
    /// `gguf/models.gguf`).
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::download`.
    pub async fn download_model(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id))
            .download(filename)
            .await
    }

    /// Downloads a single file from a model repository, reporting progress.
    ///
    /// `filename` is the file's path inside the repository, including
    /// extension and any parent directory. `progress` is forwarded unchanged
    /// to `ModelsCat::download_with_progress`.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::download_with_progress`.
    pub async fn download_model_with_progress(
        repo_id: &str,
        filename: &str,
        progress: impl Progress,
    ) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id))
            .download_with_progress(filename, progress)
            .await
    }

    /// Downloads a single file from a dataset repository.
    ///
    /// `filename` is the file's path inside the repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::download`.
    pub async fn download_dataset(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .download(filename)
            .await
    }

    /// Downloads a single file from a dataset repository, reporting progress.
    ///
    /// `filename` is the file's path inside the repository. `progress` is
    /// forwarded unchanged to `ModelsCat::download_with_progress`.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::download_with_progress`.
    pub async fn download_dataset_with_progress(
        repo_id: &str,
        filename: &str,
        progress: impl Progress,
    ) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .download_with_progress(filename, progress)
            .await
    }

    /// Pulls a model repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::pull`.
    pub async fn pull_model(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id)).pull().await
    }

    /// Pulls a dataset repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::pull`.
    pub async fn pull_dataset(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id)).pull().await
    }

    /// Removes the local copy of a model repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::remove_all`.
    pub async fn remove_model_repo(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id)).remove_all().await
    }

    /// Removes the local copy of a dataset repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::remove_all`.
    pub async fn remove_dataset_repo(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .remove_all()
            .await
    }

    /// Removes a single local file from a model repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::remove`.
    // Parameter renamed from the misspelled `filname`; positional-only in
    // Rust, so callers are unaffected.
    pub async fn remove_model_file(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id))
            .remove(filename)
            .await
    }

    /// Removes a single local file from a dataset repository.
    ///
    /// # Errors
    ///
    /// Returns the [`OpsError`] reported by `ModelsCat::remove`.
    // Parameter renamed from the misspelled `filname`; positional-only in
    // Rust, so callers are unaffected.
    pub async fn remove_dataset_file(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .remove(filename)
            .await
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        // Spelled out as `#[tokio::test]` rather than importing `tokio::test`,
        // which would shadow the built-in `#[test]` attribute for this module.
        //
        // NOTE(review): this test runs in the same process as the crate-root
        // `test_cache_dir_env`, which temporarily overrides
        // `MODELS_CAT_CACHE_DIR`; the two can race when run in parallel.
        #[tokio::test]
        async fn test_download_model() {
            download_model_with_progress(
                "BAAI/bge-small-zh-v1.5",
                "model.safetensors",
                ProgressBarWrapper::default(),
            )
            .await
            .unwrap();
        }
    }
}