use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use std::collections::HashMap;
use std::path::Path;
/// Reads every safetensors file in `paths` onto the CPU device and merges the
/// results into one name → tensor map.
///
/// Files are processed in slice order. When two files contain a tensor with
/// the same name, the entry from the later file replaces the earlier one.
///
/// # Errors
///
/// Stops at the first file that fails to load and returns that error with the
/// offending path attached as context.
pub(crate) fn load_tensors_to_cpu(paths: &[impl AsRef<Path>]) -> Result<HashMap<String, Tensor>> {
    paths
        .iter()
        .try_fold(HashMap::<String, Tensor>::new(), |mut merged, entry| {
            let file = entry.as_ref();
            let shard = candle_core::safetensors::load(file, &Device::Cpu)
                .with_context(|| format!("failed to park-load {}", file.display()))?;
            merged.extend(shard);
            Ok(merged)
        })
}
/// Builds a [`candle_nn::VarBuilder`] over a copy of the parked tensor map.
///
/// `parked` is only borrowed: the map is cloned and the clone is handed to
/// `VarBuilder::from_tensors`, so the same parked map can back any number of
/// builders. `dtype` and `target_device` are forwarded unchanged.
///
/// NOTE(review): presumably cloning a `Tensor` is a cheap handle copy rather
/// than a deep copy of the data — confirm against the candle docs.
pub(crate) fn varbuilder_from_parked<'a>(
    parked: &HashMap<String, Tensor>,
    dtype: DType,
    target_device: &Device,
) -> candle_nn::VarBuilder<'a> {
    let snapshot: HashMap<String, Tensor> = parked.clone();
    candle_nn::VarBuilder::from_tensors(snapshot, dtype, target_device)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::{Linear, Module, VarBuilder};
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap as StdHashMap;

    /// Temporary safetensors fixture that deletes its file when dropped.
    ///
    /// Cleanup lives in `Drop` so the file is removed even when a test panics
    /// on a failed assertion or `unwrap`. Tests bind the guard before any
    /// other local, which makes it the last value dropped (locals drop in
    /// reverse declaration order) — in particular after any builder that was
    /// created from the file.
    struct TempSafetensors {
        path: std::path::PathBuf,
    }

    impl Drop for TempSafetensors {
        fn drop(&mut self) {
            // Best-effort: a failed removal must not mask the test outcome.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Writes the given F32 tensors to a uniquely named safetensors file in
    /// the system temp directory and returns a guard owning that file.
    ///
    /// Each entry of `kvs` is `(tensor name, values, shape)`. Values are
    /// encoded as little-endian `f32` bytes. The file name combines `name`,
    /// the process id and a nanosecond timestamp so tests do not collide.
    fn temp_safetensors(name: &str, kvs: &[(&str, Vec<f32>, Vec<usize>)]) -> TempSafetensors {
        let mut path = std::env::temp_dir();
        path.push(format!(
            "mold-park-{}-{}-{}.safetensors",
            name,
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        // Arm the guard before writing so a failed serialization cannot leave
        // a partial file behind.
        let file = TempSafetensors { path };

        // The byte buffers must outlive the `TensorView`s that borrow them.
        let mut bufs: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::with_capacity(kvs.len());
        for (k, v, shape) in kvs {
            let mut bytes = Vec::with_capacity(v.len() * 4);
            for f in v {
                bytes.extend_from_slice(&f.to_le_bytes());
            }
            bufs.push(((*k).to_string(), bytes, shape.clone()));
        }
        let mut tensors: StdHashMap<String, TensorView> = StdHashMap::new();
        for (k, b, shape) in &bufs {
            tensors.insert(
                k.clone(),
                TensorView::new(SafeDtype::F32, shape.clone(), b).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, &file.path).unwrap();
        file
    }

    #[test]
    fn load_tensors_to_cpu_returns_owned_cpu_tensors() {
        let file = temp_safetensors(
            "load",
            &[
                ("weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]),
                ("bias", vec![5.0, 6.0], vec![2]),
            ],
        );
        let map = load_tensors_to_cpu(std::slice::from_ref(&file.path)).unwrap();
        assert_eq!(map.len(), 2);
        let w = map.get("weight").unwrap();
        assert_eq!(w.shape().dims(), &[2, 2]);
        assert!(w.device().is_cpu());
        assert_eq!(w.dtype(), DType::F32);
    }

    #[test]
    fn test_park_unpark_roundtrip_linear() {
        // Declared first so it is dropped last, after `vb_orig`.
        let file = temp_safetensors(
            "linear",
            &[
                ("weight", vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]),
                ("bias", vec![0.5, -0.5], vec![2]),
            ],
        );
        // SAFETY: the file was just written by this test under a unique name,
        // nothing in this test modifies it afterwards, and the guard removes
        // it only after `vb_orig` has been dropped.
        // NOTE(review): this assumes the only soundness requirement of
        // `from_mmaped_safetensors` is that the mapped file is not changed
        // while mapped — confirm against the candle docs.
        let vb_orig = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&file.path], DType::F32, &Device::Cpu).unwrap()
        };
        let lin_orig = Linear::new(
            vb_orig.get((2, 2), "weight").unwrap(),
            Some(vb_orig.get(2, "bias").unwrap()),
        );
        let parked = load_tensors_to_cpu(std::slice::from_ref(&file.path)).unwrap();
        let vb_unpark = varbuilder_from_parked(&parked, DType::F32, &Device::Cpu);
        let lin_new = Linear::new(
            vb_unpark.get((2, 2), "weight").unwrap(),
            Some(vb_unpark.get(2, "bias").unwrap()),
        );
        let x = Tensor::from_slice(&[1.0f32, 2.0], (1, 2), &Device::Cpu).unwrap();
        let y_orig = lin_orig.forward(&x).unwrap();
        let y_new = lin_new.forward(&x).unwrap();
        let v_orig: Vec<f32> = y_orig.flatten_all().unwrap().to_vec1().unwrap();
        let v_new: Vec<f32> = y_new.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(
            v_orig, v_new,
            "park→unpark must be bit-identical (same dtype, same device, no lossy ops)"
        );
        // The parked map is only borrowed, so it can back a second builder.
        let vb_again = varbuilder_from_parked(&parked, DType::F32, &Device::Cpu);
        let _ = vb_again.get((2, 2), "weight").unwrap();
    }

    #[test]
    fn varbuilder_from_parked_errors_on_missing_tensor() {
        let parked: HashMap<String, Tensor> = HashMap::new();
        let vb = varbuilder_from_parked(&parked, DType::F32, &Device::Cpu);
        let err = vb.get((2, 2), "weight").unwrap_err();
        assert!(
            err.to_string().contains("weight") || err.to_string().contains("Cannot find"),
            "expected missing-tensor error, got: {err}"
        );
    }
}