use super::pickle::PValue;
use super::types::{PtContents, PtTensorMeta};
use crate::error::{Error, Result};
use numr::dtype::DType;
use std::collections::HashMap;
pub(super) fn apply_reduce(func: PValue, args: Vec<PValue>) -> Result<PValue> {
let (module, name) = match &func {
PValue::Global(m, n) => (m.as_str(), n.as_str()),
_ => {
return Err(Error::ModelError {
reason: format!("REDUCE called on non-Global: {func:?}"),
});
}
};
match (module, name) {
("torch._utils", "_rebuild_tensor_v2") | ("torch._utils", "_rebuild_tensor") => {
if args.len() < 3 {
return Err(Error::ModelError {
reason: format!("_rebuild_tensor_v2 expects ≥3 args, got {}", args.len()),
});
}
let storage_ref = args[0].clone();
let storage_offset = as_i64(&args[1])? as usize;
let size = as_usize_tuple(&args[2])?;
let (dtype, storage_id, numel) = match storage_ref {
PValue::PersistentRef {
dtype,
storage_id,
numel,
} => (dtype, storage_id, numel),
other => {
return Err(Error::ModelError {
reason: format!("storage ref is not a persistent id: {other:?}"),
});
}
};
Ok(PValue::Tensor(PtTensorMeta {
dtype,
shape: size,
storage_id,
storage_offset,
storage_numel: numel,
storage_elem_size: dtype.size_in_bytes(),
}))
}
("collections", "OrderedDict") => Ok(PValue::Dict(Vec::new())),
_ => Ok(PValue::Opaque(module.to_string(), name.to_string())),
}
}
pub(super) fn resolve_persistent_id(pid: PValue) -> Result<PValue> {
let items = match pid {
PValue::Tuple(t) => t,
other => {
return Err(Error::ModelError {
reason: format!("persistent_load(pid) expects tuple, got {other:?}"),
});
}
};
if items.len() < 5 {
return Err(Error::ModelError {
reason: format!("storage persistent_id expects 5 items, got {}", items.len()),
});
}
let kind = match &items[0] {
PValue::Str(s) => s.clone(),
_ => {
return Err(Error::ModelError {
reason: "persistent_id kind must be a string".into(),
});
}
};
if kind != "storage" {
return Err(Error::ModelError {
reason: format!("unsupported persistent_id kind: {kind}"),
});
}
let dtype_global = &items[1];
let storage_id = match &items[2] {
PValue::Str(s) => s.clone(),
_ => {
return Err(Error::ModelError {
reason: "storage_id must be a string".into(),
});
}
};
let numel = as_i64(&items[4])? as usize;
let dtype = dtype_from_global(dtype_global)?;
Ok(PValue::PersistentRef {
dtype,
storage_id,
numel,
})
}
pub(super) fn dtype_from_global(v: &PValue) -> Result<DType> {
let (module, name) = match v {
PValue::Global(m, n) => (m.as_str(), n.as_str()),
_ => {
return Err(Error::ModelError {
reason: format!("expected dtype global, got {v:?}"),
});
}
};
if module != "torch" {
return Err(Error::ModelError {
reason: format!("dtype module must be torch, got {module}"),
});
}
match name {
"FloatStorage" | "float32" => Ok(DType::F32),
"DoubleStorage" | "float64" => Ok(DType::F64),
"HalfStorage" | "float16" => Ok(DType::F16),
"BFloat16Storage" | "bfloat16" => Ok(DType::BF16),
"LongStorage" | "int64" => Ok(DType::I64),
"IntStorage" | "int32" => Ok(DType::I32),
"BoolStorage" | "bool" => Ok(DType::Bool),
other => Err(Error::ModelError {
reason: format!("unsupported torch dtype: torch.{other}"),
}),
}
}
pub(super) fn as_i64(v: &PValue) -> Result<i64> {
match v {
PValue::Int(i) => Ok(*i),
PValue::Bool(b) => Ok(*b as i64),
other => Err(Error::ModelError {
reason: format!("expected int, got {other:?}"),
}),
}
}
pub(super) fn as_usize_tuple(v: &PValue) -> Result<Vec<usize>> {
match v {
PValue::Tuple(items) | PValue::List(items) => items
.iter()
.map(|it| as_i64(it).map(|i| i as usize))
.collect(),
other => Err(Error::ModelError {
reason: format!("expected shape tuple, got {other:?}"),
}),
}
}
/// Converts the decoded pickle root into a [`PtContents`] tensor table.
///
/// - A bare [`PValue::Tensor`] root is stored under the empty name.
/// - A [`PValue::Dict`] root is flattened with `flatten_into`, producing
///   dot-separated names for nested dictionaries.
///
/// Every resulting name then has its `module` path segments removed by
/// `strip_module_segments`. If two names become equal after stripping, only one
/// of the entries is kept.
///
/// # Errors
/// Returns [`Error::ModelError`] when a dict root yields no tensors, or when the
/// root is neither a tensor nor a dict.
pub(super) fn finalize_root(root: PValue) -> Result<PtContents> {
    let mut collected: HashMap<String, PtTensorMeta> = HashMap::new();
    match root {
        PValue::Tensor(meta) => {
            collected.insert(String::new(), meta);
        }
        dict @ PValue::Dict(_) => {
            flatten_into("", &dict, &mut collected);
            if collected.is_empty() {
                return Err(Error::ModelError {
                    reason: "dict contained no recognizable tensors".into(),
                });
            }
        }
        other => {
            return Err(Error::ModelError {
                reason: format!("unexpected pickle root: {other:?}"),
            });
        }
    }

    // Re-key every entry with its `module` segments dropped.
    let mut tensors: HashMap<String, PtTensorMeta> = HashMap::new();
    for (name, meta) in collected {
        tensors.insert(strip_module_segments(&name), meta);
    }
    Ok(PtContents { tensors })
}
/// Removes every dot-separated segment that is exactly `module` from `key`.
///
/// All other segments (including empty ones) are kept in order and re-joined
/// with `.`.
fn strip_module_segments(key: &str) -> String {
    let mut cleaned = String::with_capacity(key.len());
    let mut wrote_segment = false;
    for segment in key.split('.') {
        if segment == "module" {
            continue;
        }
        // Separator goes between kept segments only, never before the first.
        if wrote_segment {
            cleaned.push('.');
        }
        cleaned.push_str(segment);
        wrote_segment = true;
    }
    cleaned
}
/// Recursively collects tensors from `value` into `out`, keyed by dotted path.
///
/// - A [`PValue::Tensor`] is cloned into `out` under `prefix`.
/// - A [`PValue::Dict`] is walked entry by entry; each string key is appended to
///   `prefix` (joined with `.` unless `prefix` is empty) and the entry's value is
///   visited recursively. Entries whose key is not a string are skipped.
/// - Any other value is ignored.
fn flatten_into(prefix: &str, value: &PValue, out: &mut HashMap<String, PtTensorMeta>) {
    if let PValue::Tensor(meta) = value {
        out.insert(prefix.to_string(), meta.clone());
        return;
    }
    if let PValue::Dict(entries) = value {
        for (entry_key, entry_value) in entries {
            if let PValue::Str(name) = entry_key {
                let name = name.as_str();
                let child_prefix = if prefix.is_empty() {
                    name.to_string()
                } else {
                    format!("{prefix}.{name}")
                };
                flatten_into(&child_prefix, entry_value, out);
            }
        }
    }
}