kalosm_model_types/
lib.rs1use std::{fmt::Display, path::PathBuf};
4
/// Progress reported while a model is being prepared: either a file is being
/// downloaded or the model is being loaded.
#[derive(Clone, Debug)]
pub enum ModelLoadingProgress {
    /// A file is being downloaded.
    Downloading {
        /// Human-readable description of what is being downloaded. The progress
        /// bar uses it both as its label and as the key identifying the download.
        source: String,
        /// Byte-level progress of the download.
        progress: FileLoadingProgress,
    },
    /// The model is being loaded.
    Loading {
        /// Fraction of loading completed, expected in `0.0..=1.0` (the progress
        /// bar multiplies it by 100 to print a percentage).
        progress: f32,
    },
}

/// Byte-level progress of a single file download.
///
/// The methods on [`ModelLoadingProgress`] treat the counters as
/// `cached_size <= progress <= size`: `progress` counts the bytes that were
/// already cached as well as the bytes fetched since `start_time`. Counters
/// outside that range are tolerated (clamped/saturated) instead of panicking.
#[derive(Clone, Debug)]
pub struct FileLoadingProgress {
    /// When the current download session started. The time elapsed since this
    /// instant is used to estimate the remaining time.
    pub start_time: std::time::Instant,
    /// Number of bytes already available locally when this session started.
    pub cached_size: u64,
    /// Total size of the file in bytes.
    pub size: u64,
    /// Number of bytes of the file available so far, including `cached_size`.
    pub progress: u64,
}

impl ModelLoadingProgress {
    /// Creates a [`ModelLoadingProgress::Downloading`] event for `source`.
    pub fn downloading(source: String, file_loading_progress: FileLoadingProgress) -> Self {
        Self::Downloading {
            source,
            progress: file_loading_progress,
        }
    }

    /// Returns a closure that wraps each [`FileLoadingProgress`] update for
    /// `source` into a [`ModelLoadingProgress::Downloading`] event.
    ///
    /// The closure clones `source` for every event it produces.
    pub fn downloading_progress(
        source: String,
    ) -> impl FnMut(FileLoadingProgress) -> Self + Send + Sync {
        move |progress| Self::downloading(source.clone(), progress)
    }

    /// Creates a [`ModelLoadingProgress::Loading`] event.
    pub fn loading(progress: f32) -> Self {
        Self::Loading { progress }
    }

    /// Returns the fraction of the current stage that is complete.
    ///
    /// For a download, this is the share of the file that is available so far
    /// (`progress / size`), clamped to `0.0..=1.0`. A file that was fully cached
    /// therefore reports `1.0`, and a zero-byte download is reported as
    /// complete. For loading, the stored fraction is returned unchanged.
    pub fn progress(&self) -> f32 {
        match self {
            Self::Downloading {
                progress: FileLoadingProgress { progress, size, .. },
                ..
            } => {
                if *size == 0 {
                    // Nothing to download; this also avoids a 0/0 NaN.
                    1.0
                } else {
                    // Clamp so inconsistent counters can never exceed 100%.
                    let done = (*progress).min(*size) as f64;
                    (done / *size as f64) as f32
                }
            }
            Self::Loading { progress } => *progress,
        }
    }

    /// Estimates how long the download needs to finish.
    ///
    /// The estimate is the remaining bytes divided by the average transfer rate
    /// since `start_time`. Only bytes fetched in this session count toward the
    /// rate; cached bytes do not.
    ///
    /// Returns:
    /// - `Some(Duration::ZERO)` when nothing is left to download.
    /// - `None` while loading.
    /// - `None` when no bytes have been fetched yet, or no time has elapsed, so
    ///   the rate is unknown.
    /// - `None` when the estimate does not fit in a `Duration`.
    pub fn estimate_time_remaining(&self) -> Option<std::time::Duration> {
        match self {
            Self::Downloading {
                progress:
                    FileLoadingProgress {
                        start_time,
                        cached_size,
                        size,
                        progress,
                    },
                ..
            } => {
                let remaining_bytes = size.saturating_sub(*progress);
                if remaining_bytes == 0 {
                    return Some(std::time::Duration::ZERO);
                }
                // Cached bytes were never transferred, so only the bytes fetched
                // since `start_time` say anything about the transfer rate.
                let downloaded_bytes = progress.saturating_sub(*cached_size);
                let elapsed_secs = start_time.elapsed().as_secs_f64();
                if downloaded_bytes == 0 || elapsed_secs <= 0.0 {
                    return None;
                }
                let secs = remaining_bytes as f64 * elapsed_secs / downloaded_bytes as f64;
                // `Duration::from_secs_f64` panics on non-finite or out-of-range input.
                (secs.is_finite() && secs < std::time::Duration::MAX.as_secs_f64())
                    .then(|| std::time::Duration::from_secs_f64(secs))
            }
            Self::Loading { .. } => None,
        }
    }

    /// Returns a callback that renders progress with `indicatif`.
    ///
    /// Each download `source` gets its own progress bar, which starts at the
    /// cached offset. A loading update marks every download bar as finished and
    /// prints a `Loading NN.NN%` line above the bars.
    #[cfg(feature = "loading-progress-bar")]
    pub fn multi_bar_loading_indicator() -> impl FnMut(ModelLoadingProgress) + Send + Sync + 'static
    {
        use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
        use std::collections::HashMap;

        let multi = MultiProgress::new();
        let style = ProgressStyle::with_template(
            "{spinner:.green} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] ({decimal_bytes_per_sec}, ETA {eta})",
        )
        .expect("progress bar template is a valid constant");
        // One bar per download source, created on the first update for it.
        let mut bars: HashMap<String, ProgressBar> = HashMap::new();

        move |update| match update {
            Self::Downloading {
                source,
                progress:
                    FileLoadingProgress {
                        progress,
                        size,
                        cached_size,
                        ..
                    },
            } => {
                let bar = bars.entry(source).or_insert_with_key(|source| {
                    let bar = multi.add(ProgressBar::new(size));
                    bar.set_message(format!("Downloading {source}"));
                    bar.set_style(style.clone());
                    // Start at the cached offset so bytes already present show as done.
                    bar.set_position(cached_size);
                    bar
                });
                bar.set_position(progress);
            }
            Self::Loading { progress } => {
                // A loading update finishes every download bar.
                for bar in bars.values() {
                    bar.finish();
                }
                // Failing to print a status line must not abort model loading.
                let _ = multi.println(format!("Loading {:.2}%", progress * 100.));
            }
        }
    }
}
133
/// Where a model file comes from.
#[derive(Clone, Debug)]
pub enum FileSource {
    /// A file stored in a Hugging Face repository.
    HuggingFace {
        /// Identifier of the model repository.
        model_id: String,
        /// Revision of the repository to use (presumably a branch, tag or
        /// commit — confirm against the downloader).
        revision: String,
        /// Path of the file inside the repository.
        file: String,
    },
    /// A file on the local filesystem.
    Local(PathBuf),
}

impl Display for FileSource {
    /// Renders Hugging Face sources as `hf://<model_id>/<revision>/<file>` and
    /// local sources as their path.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Local(path) => write!(f, "{}", path.display()),
            Self::HuggingFace {
                model_id,
                revision,
                file,
            } => write!(f, "hf://{model_id}/{revision}/{file}"),
        }
    }
}

impl FileSource {
    /// Builds a [`FileSource::HuggingFace`] from any values that can be turned
    /// into strings.
    pub fn huggingface(
        model_id: impl ToString,
        revision: impl ToString,
        file: impl ToString,
    ) -> Self {
        // Convert in declaration order: model id, then revision, then file.
        let (model_id, revision, file) =
            (model_id.to_string(), revision.to_string(), file.to_string());
        FileSource::HuggingFace {
            model_id,
            revision,
            file,
        }
    }

    /// Builds a [`FileSource::Local`] pointing at `path`.
    pub fn local(path: PathBuf) -> Self {
        FileSource::Local(path)
    }
}