1#![deny(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
3
4mod fslock;
5
6pub mod hub;
7pub mod repo;
8pub mod utils;
9
10pub use hub::{ModelsCat, MultiProgressWrapper, Progress, ProgressBarWrapper, ProgressUnit};
11pub use repo::{Repo, RepoType};
12pub use utils::OpsError;
13
14pub fn download_model(repo_id: &str, filename: &str) -> Result<(), OpsError> {
16 ModelsCat::new(Repo::new_model(repo_id)).download(filename)
17}
18
19pub fn download_model_with_progress(
22 repo_id: &str,
23 filename: &str,
24 progress: impl Progress,
25) -> Result<(), OpsError> {
26 ModelsCat::new(Repo::new_model(repo_id)).download_with_progress(filename, progress)
27}
28
29pub fn download_dataset(repo_id: &str, filename: &str) -> Result<(), OpsError> {
31 ModelsCat::new(Repo::new_dataset(repo_id)).download(filename)
32}
33
34pub fn download_dataset_with_progress(
36 repo_id: &str,
37 filename: &str,
38 progress: impl Progress,
39) -> Result<(), OpsError> {
40 ModelsCat::new(Repo::new_dataset(repo_id)).download_with_progress(filename, progress)
41}
42
43pub fn pull_model(repo_id: &str) -> Result<(), OpsError> {
45 ModelsCat::new(Repo::new_model(repo_id)).pull()
46}
47
48pub fn pull_dataset(repo_id: &str) -> Result<(), OpsError> {
50 ModelsCat::new(Repo::new_dataset(repo_id)).pull()
51}
52
53pub fn remove_model_repo(repo_id: &str) -> Result<(), OpsError> {
55 ModelsCat::new(Repo::new_model(repo_id)).remove_all()
56}
57
58pub fn remove_dataset_repo(repo_id: &str) -> Result<(), OpsError> {
60 ModelsCat::new(Repo::new_dataset(repo_id)).remove_all()
61}
62
63pub fn remove_model_file(repo_id: &str, filname: &str) -> Result<(), OpsError> {
65 ModelsCat::new(Repo::new_model(repo_id)).remove(filname)
66}
67
68pub fn remove_dataset_file(repo_id: &str, filname: &str) -> Result<(), OpsError> {
70 ModelsCat::new(Repo::new_dataset(repo_id)).remove(filname)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable overridden by `test_cache_dir_env`.
    const CACHE_DIR_ENV: &str = "MODELS_CAT_CACHE_DIR";

    /// Temporary cache directory used by `test_cache_dir_env`.
    const TEST_CACHE_DIR: &str = "./test_cache";

    /// Serializes the tests in this module.
    ///
    /// The test harness runs tests on parallel threads, while `test_cache_dir_env`
    /// mutates the process environment. `std::env::set_var` is unsound if another
    /// thread reads the environment at the same time. Without this lock, a concurrent
    /// download could also land in, or be deleted along with, the temporary cache dir.
    /// NOTE(review): tests in other modules (e.g. `asynchronous::tests`) do not take
    /// this lock.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning so that one failed test does
    /// not cascade into failures of the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Undoes `test_cache_dir_env`'s side effects when dropped, including on panic.
    struct CacheDirCleanup;

    impl Drop for CacheDirCleanup {
        fn drop(&mut self) {
            // SAFETY: only constructed while `ENV_LOCK` is held and dropped before the
            // guard is released, so no other test in this module touches the env.
            unsafe {
                std::env::remove_var(CACHE_DIR_ENV);
            }
            // Best effort: on success the test body has already removed the directory.
            let _ = std::fs::remove_dir_all(TEST_CACHE_DIR);
        }
    }

    #[test]
    fn test_download_model() {
        let _env = lock_env();
        download_model_with_progress(
            "BAAI/bge-small-zh-v1.5",
            "model.safetensors",
            ProgressBarWrapper::default(),
        )
        .unwrap();
    }

    #[test]
    fn test_cache_dir_env() {
        let _env = lock_env();
        // SAFETY: `_env` holds `ENV_LOCK`, so no other test in this module reads or
        // writes the environment concurrently.
        unsafe {
            std::env::set_var(CACHE_DIR_ENV, TEST_CACHE_DIR);
        }
        // Declared after `_env`, so it is dropped (restoring env and removing the dir)
        // while the lock is still held.
        let _cleanup = CacheDirCleanup;

        download_model_with_progress(
            "BAAI/bge-small-zh-v1.5",
            "model.safetensors",
            ProgressBarWrapper::default(),
        )
        .unwrap();

        // `remove_dir_all` errors if the directory does not exist, so this unwrap
        // presumably checks that the download honoured the env var.
        std::fs::remove_dir_all(std::path::Path::new(TEST_CACHE_DIR)).unwrap();
    }
}
102
#[cfg(feature = "tokio")]
pub mod asynchronous {
    //! Asynchronous counterparts of the crate's top-level helper functions.
    //!
    //! Only compiled with the `tokio` cargo feature. Each function mirrors the
    //! synchronous function of the same name, but uses the async `ModelsCat` from
    //! `hub::async_hub` and must be `.await`ed.

    pub use crate::hub::async_hub::{
        ModelsCat, MultiProgressWrapper, Progress, ProgressBarWrapper, ProgressUnit,
    };
    pub use crate::repo::{Repo, RepoType};
    pub use crate::utils::OpsError;

    /// Downloads `filename` from the model repository `repo_id`.
    ///
    /// Builds a [`ModelsCat`] for [`Repo::new_model`] and awaits [`ModelsCat::download`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::download`].
    pub async fn download_model(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id))
            .download(filename)
            .await
    }

    /// Downloads `filename` from the model repository `repo_id`, reporting progress
    /// through `progress` (for example [`ProgressBarWrapper::default()`]).
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::download_with_progress`].
    pub async fn download_model_with_progress(
        repo_id: &str,
        filename: &str,
        progress: impl Progress,
    ) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id))
            .download_with_progress(filename, progress)
            .await
    }

    /// Downloads `filename` from the dataset repository `repo_id`.
    ///
    /// Builds a [`ModelsCat`] for [`Repo::new_dataset`] and awaits [`ModelsCat::download`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::download`].
    pub async fn download_dataset(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .download(filename)
            .await
    }

    /// Downloads `filename` from the dataset repository `repo_id`, reporting progress
    /// through `progress`.
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::download_with_progress`].
    pub async fn download_dataset_with_progress(
        repo_id: &str,
        filename: &str,
        progress: impl Progress,
    ) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .download_with_progress(filename, progress)
            .await
    }

    /// Pulls the model repository `repo_id` by awaiting [`ModelsCat::pull`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::pull`].
    pub async fn pull_model(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id)).pull().await
    }

    /// Pulls the dataset repository `repo_id` by awaiting [`ModelsCat::pull`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::pull`].
    pub async fn pull_dataset(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id)).pull().await
    }

    /// Removes the model repository `repo_id` by awaiting [`ModelsCat::remove_all`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::remove_all`].
    pub async fn remove_model_repo(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id)).remove_all().await
    }

    /// Removes the dataset repository `repo_id` by awaiting [`ModelsCat::remove_all`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::remove_all`].
    pub async fn remove_dataset_repo(repo_id: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .remove_all()
            .await
    }

    /// Removes `filename` from the model repository `repo_id` by awaiting
    /// [`ModelsCat::remove`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::remove`].
    pub async fn remove_model_file(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_model(repo_id))
            .remove(filename)
            .await
    }

    /// Removes `filename` from the dataset repository `repo_id` by awaiting
    /// [`ModelsCat::remove`].
    ///
    /// # Errors
    ///
    /// Returns any [`OpsError`] produced by [`ModelsCat::remove`].
    pub async fn remove_dataset_file(repo_id: &str, filename: &str) -> Result<(), OpsError> {
        ModelsCat::new(Repo::new_dataset(repo_id))
            .remove(filename)
            .await
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        // `#[tokio::test]` spelled out explicitly rather than shadowing the builtin
        // `#[test]` attribute with `use tokio::test;`.
        #[tokio::test]
        async fn test_download_model() {
            download_model_with_progress(
                "BAAI/bge-small-zh-v1.5",
                "model.safetensors",
                ProgressBarWrapper::default(),
            )
            .await
            .unwrap();
        }
    }
}