1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
use std::collections::HashMap;
use std::path::Path;
use candle_core::{Device, Tensor};

pub fn load_model(path: &Path) -> candle_core::Result<HashMap<String, Tensor>> {
    let tensors = match path.extension() {
        Some(s) if s.eq("pt") => {
            candle_core::pickle::read_all_with_key(path, None)?.into_iter().collect()
        }
        Some(s) if s.eq("ckpt") => {
            unimplemented!("loading `ckpt` checkpoints is not yet supported")
        }
        _ => {
            candle_core::safetensors::load(path, &Device::Cpu)?
        }
    };
    Ok(tensors)
}