kalosm_model_types/
lib.rs

//! Common types for Kalosm models
2
3use std::{fmt::Display, path::PathBuf};
4
5/// The progress starting a model
6#[derive(Clone, Debug)]
7pub enum ModelLoadingProgress {
8    /// The model is downloading
9    Downloading {
10        /// The source of the download. This is not a path or URL, but a description of the source
11        source: String,
12        progress: FileLoadingProgress,
13    },
14    /// The model is loading
15    Loading {
16        /// The progress of the loading, from 0 to 1
17        progress: f32,
18    },
19}
20
/// Byte-level progress of a single file download
#[derive(Debug, Clone)]
pub struct FileLoadingProgress {
    /// The moment the download began
    pub start_time: std::time::Instant,
    /// How many bytes of the file were already cached before downloading
    pub cached_size: u64,
    /// Total size of the file in bytes
    pub size: u64,
    /// Bytes of the file available so far, between 0 and `size`
    ///
    /// NOTE(review): the progress-bar code in this module starts the bar at
    /// `cached_size` and then sets it to this value, i.e. it treats this count as
    /// including the cached bytes — confirm against the producers.
    pub progress: u64,
}
33
34impl ModelLoadingProgress {
35    /// Create a new downloading progress
36    pub fn downloading(source: String, file_loading_progress: FileLoadingProgress) -> Self {
37        Self::Downloading {
38            source,
39            progress: file_loading_progress,
40        }
41    }
42
43    /// Create a new downloading progress
44    pub fn downloading_progress(
45        source: String,
46    ) -> impl FnMut(FileLoadingProgress) -> Self + Send + Sync {
47        move |progress| ModelLoadingProgress::downloading(source.clone(), progress)
48    }
49
50    /// Create a new loading progress
51    pub fn loading(progress: f32) -> Self {
52        Self::Loading { progress }
53    }
54
55    /// Return the percent complete
56    pub fn progress(&self) -> f32 {
57        match self {
58            Self::Downloading {
59                progress:
60                    FileLoadingProgress {
61                        progress,
62                        size,
63                        cached_size,
64                        ..
65                    },
66                ..
67            } => (*progress - *cached_size) as f32 / *size as f32,
68            Self::Loading { progress } => *progress,
69        }
70    }
71
72    /// Try to estimate the time remaining for a download
73    pub fn estimate_time_remaining(&self) -> Option<std::time::Duration> {
74        match self {
75            Self::Downloading {
76                progress: FileLoadingProgress { start_time, .. },
77                ..
78            } => {
79                let elapsed = start_time.elapsed();
80                let progress = self.progress();
81                let remaining = (1. - progress) * elapsed.as_secs_f32();
82                Some(std::time::Duration::from_secs_f32(remaining))
83            }
84            _ => None,
85        }
86    }
87
88    #[cfg(feature = "loading-progress-bar")]
89    /// A default loading progress bar
90    pub fn multi_bar_loading_indicator() -> impl FnMut(ModelLoadingProgress) + Send + Sync + 'static
91    {
92        use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
93        use std::collections::HashMap;
94        let m = MultiProgress::new();
95        let sty = ProgressStyle::with_template(
96            "{spinner:.green} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] ({decimal_bytes_per_sec}, ETA {eta})",
97        )
98        .unwrap();
99        let mut progress_bars = HashMap::new();
100
101        move |progress| match progress {
102            Self::Downloading {
103                source,
104                progress:
105                    FileLoadingProgress {
106                        progress,
107                        size,
108                        cached_size,
109                        ..
110                    },
111                ..
112            } => {
113                let progress_bar = progress_bars.entry(source.clone()).or_insert_with(|| {
114                    let pb = m.add(ProgressBar::new(size));
115                    pb.set_message(format!("Downloading {source}"));
116                    pb.set_style(sty.clone());
117                    pb.set_position(cached_size);
118                    pb
119                });
120
121                progress_bar.set_position(progress);
122            }
123            ModelLoadingProgress::Loading { progress } => {
124                for pb in progress_bars.values_mut() {
125                    pb.finish();
126                }
127                let progress = progress * 100.;
128                m.println(format!("Loading {progress:.2}%")).unwrap();
129            }
130        }
131    }
132}
133
/// Where a file comes from: a Hugging Face repository or a local path
#[derive(Debug, Clone)]
pub enum FileSource {
    /// A file hosted on Hugging Face
    HuggingFace {
        /// The id of the model repository
        model_id: String,
        /// The revision of the repository to fetch
        revision: String,
        /// The file to fetch from the repository
        file: String,
    },
    /// A file on the local filesystem
    Local(PathBuf),
}
149
150impl Display for FileSource {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        match self {
153            FileSource::HuggingFace {
154                model_id,
155                revision,
156                file,
157            } => write!(f, "hf://{}/{}/{}", model_id, revision, file),
158            FileSource::Local(path) => write!(f, "{}", path.display()),
159        }
160    }
161}
162
163impl FileSource {
164    /// Create a new source for a file from Hugging Face
165    pub fn huggingface(
166        model_id: impl ToString,
167        revision: impl ToString,
168        file: impl ToString,
169    ) -> Self {
170        Self::HuggingFace {
171            model_id: model_id.to_string(),
172            revision: revision.to_string(),
173            file: file.to_string(),
174        }
175    }
176
177    /// Create a new source for a local file
178    pub fn local(path: PathBuf) -> Self {
179        Self::Local(path)
180    }
181}