use crate::error::{Error, Result};
use crate::{DType, Source};
use std::ffi::CStr;
/// Owning wrapper around a native `hmll_registry` tensor table.
///
/// Instances are produced by [`Registry::from_safetensors`] or
/// [`Registry::from_sharded_safetensors`]; the native registry is handed to
/// `hmll_free_registry` when the wrapper is dropped.
pub struct Registry {
/// Raw registry filled in by the native library and released in `Drop`.
inner: hmll_sys::hmll_registry,
}
impl Registry {
/// Builds a registry from a single safetensors source.
///
/// The native `hmll_safetensors_populate_registry` call is issued once with
/// `0` for both trailing arguments (the file id and tensor offset that the
/// sharded constructor varies per shard).
///
/// # Errors
///
/// * Whatever `Error::check_hmll_error` reports for the error code the native
///   call left in its context.
/// * [`Error::TableEmpty`] when the call succeeded but registered no tensors.
pub fn from_safetensors(source: &Source) -> Result<Self> {
    // SAFETY: NOTE(review): assumes the all-zero bit pattern is a valid initial
    // state for both bindgen-generated structs (plain integers / raw pointers).
    let mut ctx: hmll_sys::hmll = unsafe { std::mem::zeroed() };
    let mut inner: hmll_sys::hmll_registry = unsafe { std::mem::zeroed() };

    // SAFETY: both out-pointers refer to live locals for the duration of the call.
    unsafe {
        hmll_sys::hmll_safetensors_populate_registry(
            &mut ctx,
            &mut inner,
            *source.as_raw(),
            0,
            0,
        );
    }
    // NOTE(review): on a reported failure the registry is left untouched here,
    // on the assumption that the native side cleans up after itself — confirm.
    Error::check_hmll_error(ctx.error)?;

    // Take ownership before the emptiness check: returning early with the raw
    // struct would skip `hmll_free_registry` and leak whatever was allocated.
    let registry = Self { inner };
    if registry.inner.num_tensors == 0 {
        return Err(Error::TableEmpty);
    }
    Ok(registry)
}
/// Builds a registry from a sharded safetensors checkpoint.
///
/// `index` is parsed first with `hmll_safetensors_index`, which reports how
/// many shard files it references. Each entry of `shards` is then registered
/// in order, using its position in the slice as the file id and the running
/// tensor count returned by the previous calls as the offset.
///
/// # Errors
///
/// * Whatever `Error::check_hmll_error` reports after the index parse or after
///   any shard is processed.
/// * [`Error::TableEmpty`] when the index references no files, or when no
///   tensors were registered after all shards were processed.
/// * [`Error::InvalidRange`] when `shards.len()` differs from the number of
///   files reported by the index.
pub fn from_sharded_safetensors(index: &Source, shards: &[&Source]) -> Result<Self> {
    // SAFETY: NOTE(review): assumes the all-zero bit pattern is a valid initial
    // state for both bindgen-generated structs (plain integers / raw pointers).
    let mut ctx: hmll_sys::hmll = unsafe { std::mem::zeroed() };
    let mut inner: hmll_sys::hmll_registry = unsafe { std::mem::zeroed() };

    // SAFETY: both out-pointers refer to live locals for the duration of the call.
    let num_files =
        unsafe { hmll_sys::hmll_safetensors_index(&mut ctx, &mut inner, *index.as_raw()) };
    // NOTE(review): on a reported failure the registry is left untouched here,
    // on the assumption that the native side cleans up after itself — confirm.
    Error::check_hmll_error(ctx.error)?;

    // From here on the registry holds native allocations. Wrapping it now means
    // every early return below (including `?` inside the loop) runs `Drop`,
    // which calls `hmll_free_registry`; previously only the shard-count
    // mismatch path released it and the others leaked.
    let mut registry = Self { inner };

    if num_files == 0 {
        return Err(Error::TableEmpty);
    }
    if shards.len() != num_files {
        return Err(Error::InvalidRange);
    }

    let mut offset = 0;
    for (fid, shard) in shards.iter().enumerate() {
        // SAFETY: `ctx` and `registry.inner` are live for the duration of the call.
        let count = unsafe {
            hmll_sys::hmll_safetensors_populate_registry(
                &mut ctx,
                &mut registry.inner,
                *shard.as_raw(),
                fid,
                offset,
            )
        };
        Error::check_hmll_error(ctx.error)?;
        // The next shard's tensors are appended after the ones just registered.
        offset += count;
    }

    if registry.inner.num_tensors == 0 {
        return Err(Error::TableEmpty);
    }
    Ok(registry)
}
/// Returns the number of tensors recorded in the native registry.
#[inline]
pub fn len(&self) -> usize {
self.inner.num_tensors
}
/// Returns `true` when the registry holds no tensors.
#[inline]
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Looks up the tensor stored at position `index`.
///
/// Returns `None` when `index` is not below the registry's tensor count.
/// The tensor name falls back to `""` when the stored pointer is null or the
/// bytes are not valid UTF-8, and `source_index` falls back to `0` when the
/// registry carries no `indexes` array.
///
/// # Panics
///
/// Panics if the stored rank exceeds the length of the stored shape array.
pub fn get(&self, index: usize) -> Option<TensorInfo<'_>> {
    if index >= self.inner.num_tensors {
        return None;
    }

    // SAFETY: NOTE(review): relies on the native library keeping `tensors` and
    // `names` valid for at least `num_tensors` entries; `index` was checked
    // against that count above.
    let specs = unsafe { &*self.inner.tensors.add(index) };
    let raw_name = unsafe { *self.inner.names.add(index) };

    let name = if raw_name.is_null() {
        ""
    } else {
        // SAFETY: NOTE(review): assumes non-null names are NUL-terminated
        // strings that live as long as the registry.
        let c_name = unsafe { CStr::from_ptr(raw_name) };
        c_name.to_str().unwrap_or("")
    };

    let source_index = if self.inner.indexes.is_null() {
        0
    } else {
        // SAFETY: NOTE(review): assumes `indexes`, when present, also has at
        // least `num_tensors` entries.
        (unsafe { *self.inner.indexes.add(index) }) as usize
    };

    Some(TensorInfo {
        name,
        dtype: DType::from_raw(specs.dtype),
        // Only the first `rank` slots of the shape storage are exposed.
        shape: &specs.shape[..specs.rank as usize],
        start: specs.start,
        end: specs.end,
        source_index,
    })
}
/// Iterates over every tensor in the registry, in index order.
///
/// Each position from `0` up to the tensor count is resolved through
/// [`Registry::get`]; positions for which `get` yields `None` are skipped.
pub fn iter(&self) -> impl Iterator<Item = TensorInfo<'_>> {
    let count = self.len();
    (0..count).filter_map(move |position| self.get(position))
}
}
impl Drop for Registry {
    /// Hands the wrapped registry back to the native library for release.
    fn drop(&mut self) {
        let registry = &mut self.inner;
        // SAFETY: `registry` points at the field owned by `self`, which is
        // still alive here. NOTE(review): assumes `hmll_free_registry` accepts
        // every state a constructed `Registry` can be in.
        unsafe {
            hmll_sys::hmll_free_registry(registry);
        }
    }
}
/// Borrowed description of a single tensor, as produced by `Registry::get`.
#[derive(Debug, Clone)]
pub struct TensorInfo<'a> {
    /// Tensor name; empty when the registry had no usable (non-null, UTF-8) name.
    pub name: &'a str,
    /// Element type, converted from the raw dtype stored in the registry.
    pub dtype: DType,
    /// Dimensions of the tensor (one entry per axis).
    pub shape: &'a [usize],
    /// Start of the tensor's data range (presumably a byte offset — see `size_bytes`).
    pub start: usize,
    /// End of the tensor's data range; `size_bytes` treats it as exclusive.
    pub end: usize,
    /// Entry from the registry's `indexes` array, or `0` when that array is
    /// absent (presumably identifies the shard holding the data — confirm).
    pub source_index: usize,
}

impl TensorInfo<'_> {
    /// Returns `end - start`, clamped to `0` when `end` is smaller than `start`.
    #[inline]
    pub fn size_bytes(&self) -> usize {
        self.end.saturating_sub(self.start)
    }

    /// Returns the number of elements, i.e. the product of all dimensions.
    ///
    /// An empty shape yields `1`. The product saturates at `usize::MAX`
    /// instead of panicking (debug) or wrapping (release) when the dimensions
    /// overflow, mirroring the saturating behaviour of `size_bytes`.
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim))
    }
}