use std::{fmt::Display, path::PathBuf};
/// Progress event reported while a model's files are fetched and the model is loaded.
#[derive(Debug, Clone)]
pub enum ModelLoadingProgress {
    /// A file is being downloaded.
    Downloading {
        /// Identifier of the file being downloaded; used as its display label.
        source: String,
        /// Byte-level progress of that download.
        progress: FileLoadingProgress,
    },
    /// The model is being loaded.
    Loading {
        /// Loading progress; presumably a fraction in `0.0..=1.0` (it is shown as
        /// `progress * 100` percent) — NOTE(review): confirm with producers.
        progress: f32,
    },
}
/// Snapshot of the byte-level progress of a single file download.
///
/// Sizes are counted in bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileLoadingProgress {
    /// When the current download began; elapsed time since this instant is the basis
    /// for throughput / time-remaining estimates.
    pub start_time: std::time::Instant,
    /// Bytes of the file that were already available locally when the download began.
    pub cached_size: u64,
    /// Total size of the file.
    pub size: u64,
    /// Current position in the file, counted from its start (so it includes
    /// `cached_size`).
    pub progress: u64,
}
impl ModelLoadingProgress {
    /// Creates a [`ModelLoadingProgress::Downloading`] event for the file identified by
    /// `source`.
    pub fn downloading(source: String, file_loading_progress: FileLoadingProgress) -> Self {
        Self::Downloading {
            source,
            progress: file_loading_progress,
        }
    }

    /// Returns a callback that turns every [`FileLoadingProgress`] update into a
    /// [`ModelLoadingProgress::Downloading`] event tagged with `source`.
    pub fn downloading_progress(
        source: String,
    ) -> impl FnMut(FileLoadingProgress) -> Self + Send + Sync {
        move |progress| ModelLoadingProgress::downloading(source.clone(), progress)
    }

    /// Creates a [`ModelLoadingProgress::Loading`] event.
    pub fn loading(progress: f32) -> Self {
        Self::Loading { progress }
    }

    /// Returns the progress of the current stage as a fraction.
    ///
    /// For [`ModelLoadingProgress::Downloading`] this is the share of the bytes that
    /// still had to be fetched when the download began (`size - cached_size`) that have
    /// been fetched so far. The result is clamped to `1.0`. A download with nothing left
    /// to fetch (fully cached or empty file) reports `1.0` instead of dividing by zero.
    /// For [`ModelLoadingProgress::Loading`] the stored value is returned unchanged.
    pub fn progress(&self) -> f32 {
        match self {
            Self::Downloading { progress: file, .. } => {
                let (downloaded, to_download) = Self::session_bytes(file);
                if to_download == 0 {
                    1.0
                } else {
                    // Divide in f64 so very large files keep their precision.
                    (downloaded as f64 / to_download as f64).min(1.0) as f32
                }
            }
            Self::Loading { progress } => *progress,
        }
    }

    /// Estimates how long the current download still needs.
    ///
    /// The estimate extrapolates the average throughput since `start_time` (bytes
    /// fetched in this session divided by elapsed time) over the bytes still missing.
    ///
    /// Returns:
    /// - `None` for [`ModelLoadingProgress::Loading`].
    /// - `None` while no bytes have been fetched yet, because there is no rate to
    ///   extrapolate.
    /// - `Some(Duration::ZERO)` once nothing is left to fetch.
    pub fn estimate_time_remaining(&self) -> Option<std::time::Duration> {
        match self {
            Self::Downloading { progress: file, .. } => {
                let (downloaded, to_download) = Self::session_bytes(file);
                let remaining = to_download.saturating_sub(downloaded);
                if remaining == 0 {
                    return Some(std::time::Duration::ZERO);
                }
                if downloaded == 0 {
                    return None;
                }
                // `downloaded >= 1` and `remaining >= 1`, so this is finite and positive
                // and `from_secs_f64` cannot panic on NaN/negative input.
                let secs_per_byte = file.start_time.elapsed().as_secs_f64() / downloaded as f64;
                Some(std::time::Duration::from_secs_f64(
                    secs_per_byte * remaining as f64,
                ))
            }
            Self::Loading { .. } => None,
        }
    }

    /// Returns `(downloaded, to_download)` for a download.
    ///
    /// - `downloaded`: bytes fetched so far in this session.
    /// - `to_download`: bytes that had to be fetched in total.
    ///
    /// Both counts exclude the `cached_size` bytes that were already present when the
    /// download began.
    fn session_bytes(file: &FileLoadingProgress) -> (u64, u64) {
        // Saturate so an inconsistent update (e.g. `progress < cached_size`) yields 0
        // instead of underflowing, which would panic in debug builds.
        let downloaded = file.progress.saturating_sub(file.cached_size);
        let to_download = file.size.saturating_sub(file.cached_size);
        (downloaded, to_download)
    }

    /// Returns a progress callback that renders downloads as `indicatif` progress bars.
    ///
    /// Each distinct download `source` gets its own bar in a shared `MultiProgress`. The
    /// bar is sized to the file and starts at its cached size. On a
    /// [`ModelLoadingProgress::Loading`] event every download bar is finished and the
    /// loading percentage is printed through the `MultiProgress`.
    #[cfg(feature = "loading-progress-bar")]
    pub fn multi_bar_loading_indicator() -> impl FnMut(ModelLoadingProgress) + Send + Sync + 'static
    {
        use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
        use std::collections::HashMap;
        let m = MultiProgress::new();
        let sty = ProgressStyle::with_template(
            "{spinner:.green} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] ({decimal_bytes_per_sec}, ETA {eta})",
        )
        .expect("the progress bar template is a valid constant");
        let mut progress_bars = HashMap::new();
        move |progress| match progress {
            Self::Downloading {
                source,
                progress:
                    FileLoadingProgress {
                        progress,
                        size,
                        cached_size,
                        ..
                    },
                ..
            } => {
                let progress_bar = progress_bars.entry(source.clone()).or_insert_with(|| {
                    let pb = m.add(ProgressBar::new(size));
                    pb.set_message(format!("Downloading {source}"));
                    pb.set_style(sty.clone());
                    pb.set_position(cached_size);
                    pb
                });
                progress_bar.set_position(progress);
            }
            Self::Loading { progress } => {
                for pb in progress_bars.values() {
                    pb.finish();
                }
                let progress = progress * 100.;
                // Printing is best-effort: a failed terminal write should not abort
                // model loading by panicking inside the progress callback.
                let _ = m.println(format!("Loading {progress:.2}%"));
            }
        }
    }
}
/// Location a model file is read from.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum FileSource {
    /// A file stored in a Hugging Face repository.
    HuggingFace {
        /// Repository identifier.
        model_id: String,
        /// Repository revision to use (presumably a branch, tag, or commit — confirm).
        revision: String,
        /// Path of the file within the repository.
        file: String,
    },
    /// A file on the local filesystem.
    Local(PathBuf),
}
impl Display for FileSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FileSource::HuggingFace {
model_id,
revision,
file,
} => write!(f, "hf://{}/{}/{}", model_id, revision, file),
FileSource::Local(path) => write!(f, "{}", path.display()),
}
}
}
impl FileSource {
pub fn huggingface(
model_id: impl ToString,
revision: impl ToString,
file: impl ToString,
) -> Self {
Self::HuggingFace {
model_id: model_id.to_string(),
revision: revision.to_string(),
file: file.to_string(),
}
}
pub fn local(path: PathBuf) -> Self {
Self::Local(path)
}
}