use crate::{Buffer, Device, Error, Range, Result, Source};
use std::marker::PhantomData;
use std::ptr;
/// Selects which native fetch backend a [`WeightLoader`] uses.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum LoaderKind {
    /// Maps to `HMLL_FETCHER_AUTO`; this is the [`Default`] kind.
    Auto,

    /// io_uring-based backend (`HMLL_FETCHER_IO_URING`); only compiled on Linux.
    #[cfg(target_os = "linux")]
    IoUring,

    /// mmap-based backend (`HMLL_FETCHER_MMAP`).
    Mmap,
}
impl LoaderKind {
    /// Converts this kind into the matching `hmll_sys` fetcher constant.
    #[inline(always)]
    pub(crate) const fn to_raw(self) -> hmll_sys::hmll_loader_kind {
        match self {
            Self::Mmap => hmll_sys::HMLL_FETCHER_MMAP,
            #[cfg(target_os = "linux")]
            Self::IoUring => hmll_sys::HMLL_FETCHER_IO_URING,
            Self::Auto => hmll_sys::HMLL_FETCHER_AUTO,
        }
    }
}
impl Default for LoaderKind {
#[inline(always)]
fn default() -> Self {
LoaderKind::Auto
}
}
/// Reads byte ranges out of a fixed set of [`Source`]s through the native
/// `hmll` loader, producing [`Buffer`]s on a chosen [`Device`].
pub struct WeightLoader<'a> {
    /// Native loader state, passed by pointer to every `hmll_sys` call.
    /// Boxed so its address does not change when the loader is moved.
    context: Box<hmll_sys::hmll>,
    /// Raw descriptors copied from the borrowed sources; this vector's buffer
    /// is what `new` hands to `hmll_loader_init`.
    sources: Vec<hmll_sys::hmll_source>,
    /// Device on which fetched buffers are placed.
    device: Device,
    /// Ties the loader to the lifetime of the borrowed `Source` slice.
    _marker: PhantomData<&'a ()>,
}
impl<'a> WeightLoader<'a> {
    /// Creates a loader over `sources` that produces buffers on `device`,
    /// using the fetch backend selected by `kind`.
    ///
    /// The raw `hmll_source` descriptors are copied out of `sources`. The `'a`
    /// lifetime keeps the borrowed [`Source`] values alive for as long as the
    /// loader exists.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidRange`] if `sources` is empty. Returns the error
    /// produced by `Error::check_hmll_error` when `hmll_loader_init` fails.
    pub fn new(sources: &'a [Source], device: Device, kind: LoaderKind) -> Result<Self> {
        if sources.is_empty() {
            return Err(Error::InvalidRange);
        }
        let sources_vec: Vec<hmll_sys::hmll_source> =
            sources.iter().map(|s| *s.as_raw()).collect();
        // Start from an empty context (null fetcher/sources, success error);
        // `hmll_loader_init` fills it in. It is boxed so its address stays fixed
        // when the loader is moved, in case the native side keeps pointers into
        // it (NOTE(review): confirm whether that is required).
        let mut context = Box::new(hmll_sys::hmll {
            fetcher: ptr::null_mut(),
            sources: ptr::null(),
            num_sources: 0,
            error: hmll_sys::hmll_error {
                code: hmll_sys::HMLL_ERR_SUCCESS,
                sys_err: 0,
            },
        });
        // SAFETY: `context` is a live, exclusively borrowed `hmll`.
        // `sources_vec.as_ptr()` is valid for `sources_vec.len()` elements. The
        // vector is moved into `Self` below; moving a `Vec` does not move its
        // heap buffer, so the pointer stays valid for the loader's lifetime.
        // `self.sources` must never be grown or reallocated afterwards.
        unsafe {
            let err = hmll_sys::hmll_loader_init(
                context.as_mut(),
                sources_vec.as_ptr(),
                sources_vec.len(),
                device.to_raw(),
                kind.to_raw(),
            );
            // NOTE(review): on failure `context` is dropped without calling
            // `hmll_destroy` (no `Self` exists yet, so `Drop` never runs). If
            // init can allocate before failing, those resources leak — confirm.
            Error::check_hmll_error(err)?;
        }
        Ok(Self {
            context,
            sources: sources_vec,
            device,
            _marker: PhantomData,
        })
    }

    /// Reads the byte range `range` of source number `file_index` into a new
    /// [`Buffer`] on the loader's device.
    ///
    /// An empty range returns an empty buffer (null pointer, length 0) without
    /// calling into `hmll_sys`.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidRange`] if `file_index` is negative or not less than
    ///   [`num_sources`](Self::num_sources).
    /// * [`Error::AllocationFailed`] if `hmll_get_buffer_for_range` returns a
    ///   null buffer.
    /// * The error recorded in the context when `hmll_fetch` returns a negative
    ///   value. The context's error slot is reset to success afterwards, so a
    ///   later call does not observe a stale error.
    pub fn fetch<R: Into<Range>>(&mut self, range: R, file_index: i32) -> Result<Buffer> {
        let range = range.into();
        // Negative indices must be rejected here; otherwise they would be
        // forwarded verbatim to `hmll_fetch`. The `as usize` cast only runs once
        // `file_index` is known to be non-negative. Comparing in `usize` also
        // avoids truncating `len()` to `i32`.
        if file_index < 0 || file_index as usize >= self.sources.len() {
            return Err(Error::InvalidRange);
        }
        if range.is_empty() {
            // SAFETY: assumes `Buffer` accepts a null pointer when the length
            // is 0; this is the empty-buffer representation used here.
            return Ok(unsafe { Buffer::from_raw_parts(ptr::null_mut(), 0, self.device, false) });
        }
        // SAFETY: `self.context` was initialised by `hmll_loader_init` in `new`
        // (a `WeightLoader` only exists once init succeeded).
        let iobuf = unsafe {
            hmll_sys::hmll_get_buffer_for_range(
                self.context.as_mut(),
                self.device.to_raw(),
                range.to_raw(),
            )
        };
        if iobuf.ptr.is_null() {
            return Err(Error::AllocationFailed);
        }
        // SAFETY: context is initialised; `file_index` was bounds-checked above;
        // `iobuf` was obtained for this same range.
        let res = unsafe {
            hmll_sys::hmll_fetch(self.context.as_mut(), file_index, &iobuf, range.to_raw())
        };
        if res < 0 {
            // Take the recorded error and leave the slot cleared for the next call.
            let err = std::mem::replace(
                &mut self.context.error,
                hmll_sys::hmll_error {
                    code: hmll_sys::HMLL_ERR_SUCCESS,
                    sys_err: 0,
                },
            );
            // NOTE(review): `iobuf` is not released on this path; confirm
            // whether the context reclaims it or whether it leaks.
            return Err(Error::from_hmll_error(err));
        }
        // SAFETY: `iobuf` describes `iobuf.size` bytes that `hmll_fetch`
        // reported as successfully filled (non-negative return).
        Ok(unsafe { Buffer::from_raw_parts(iobuf.ptr as *mut u8, iobuf.size, self.device, false) })
    }

    /// Returns the device on which this loader places fetched buffers.
    #[inline(always)]
    pub const fn device(&self) -> Device {
        self.device
    }

    /// Returns how many sources the loader was created with.
    #[inline(always)]
    pub fn num_sources(&self) -> usize {
        self.sources.len()
    }

    /// Returns the size (and, on Unix, the file descriptor) recorded for source
    /// `index`. Returns `None` if `index` is out of range.
    #[inline]
    pub fn source_info(&self, index: usize) -> Option<SourceInfo> {
        self.sources.get(index).map(|src| SourceInfo {
            size: src.size,
            #[cfg(target_family = "unix")]
            fd: src.fd,
        })
    }
}
impl Drop for WeightLoader<'_> {
    /// Releases the native loader state through `hmll_destroy`.
    fn drop(&mut self) {
        let ctx: &mut hmll_sys::hmll = &mut self.context;
        // SAFETY: a `WeightLoader` is only constructed after `hmll_loader_init`
        // succeeded in `new`. `drop` runs at most once, so the context is
        // destroyed exactly once.
        unsafe { hmll_sys::hmll_destroy(ctx) };
    }
}
// SAFETY: needed because `hmll` holds raw pointers (`fetcher`, `sources`),
// which makes the auto-derived `Send` unavailable. This asserts that the native
// context may be moved to and used from another thread. NOTE(review): this
// assumes hmll state has no thread affinity — confirm. `Sync` is deliberately
// not implemented; `fetch` requires `&mut self`.
unsafe impl Send for WeightLoader<'_> {}
/// Snapshot of one source's raw descriptor, as returned by
/// `WeightLoader::source_info`.
#[derive(Clone, Copy, Debug)]
pub struct SourceInfo {
    /// Size of the source in bytes, as recorded in its `hmll_source`.
    pub size: usize,
    /// File descriptor of the source (Unix only).
    #[cfg(target_family = "unix")]
    pub fd: i32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a named temporary file pre-filled with `content`.
    ///
    /// The file is flushed before returning so that a later `Source::open`
    /// sees every byte.
    fn create_test_file(content: &[u8]) -> NamedTempFile {
        let mut tmp = NamedTempFile::new().expect("Failed to create temp file");
        tmp.write_all(content).expect("Failed to write test content");
        tmp.flush().expect("Failed to flush");
        tmp
    }

    /// Writes `content` to a temp file and opens it as a one-element source
    /// array.
    ///
    /// The temp file is returned as well; callers bind it (as `_tmp`) so it
    /// outlives the source.
    fn single_source(content: &[u8]) -> (NamedTempFile, [Source; 1]) {
        let tmp = create_test_file(content);
        let source = Source::open(tmp.path()).expect("Failed to open source");
        (tmp, [source])
    }

    /// Builds a CPU loader with the `Auto` backend over `sources`.
    fn auto_cpu_loader(sources: &[Source]) -> WeightLoader<'_> {
        WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto)
            .expect("Failed to create loader")
    }

    #[test]
    fn test_empty_sources() {
        let outcome = WeightLoader::new(&[], Device::Cpu, LoaderKind::Auto);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_loader_kind_default() {
        assert_eq!(LoaderKind::default(), LoaderKind::Auto);
    }

    #[test]
    fn test_device_default() {
        assert_eq!(Device::default(), Device::Cpu);
    }

    #[test]
    fn test_loader_creation() {
        let content = b"Test file content for loader creation test.";
        let (_tmp, sources) = single_source(content);
        let loader = auto_cpu_loader(&sources);

        assert_eq!(loader.device(), Device::Cpu);
        assert_eq!(loader.num_sources(), 1);
        let info = loader.source_info(0).expect("Failed to get source info");
        assert_eq!(info.size, content.len());
    }

    #[test]
    fn test_fetch_full_file() {
        let content = b"This is the complete file content that we want to fetch entirely.";
        let (_tmp, sources) = single_source(content);
        let mut loader = auto_cpu_loader(&sources);

        let fetched = loader
            .fetch(0..content.len(), 0)
            .expect("Failed to fetch data");
        assert_eq!(fetched.len(), content.len());
        assert_eq!(fetched.device(), Device::Cpu);
        let bytes = fetched.as_slice().expect("Failed to get slice");
        assert_eq!(bytes, content);
    }

    #[test]
    fn test_fetch_partial_range() {
        let content = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let (_tmp, sources) = single_source(content);
        let mut loader = auto_cpu_loader(&sources);

        // Bytes 10..20 are the letters A through J.
        let fetched = loader
            .fetch(10..20, 0)
            .expect("Failed to fetch partial data");
        assert_eq!(fetched.len(), 10);
        let bytes = fetched.as_slice().expect("Failed to get slice");
        assert_eq!(bytes, b"ABCDEFGHIJ");
    }

    #[test]
    fn test_fetch_empty_range() {
        let (_tmp, sources) = single_source(b"Some content");
        let mut loader = auto_cpu_loader(&sources);

        let fetched = loader.fetch(5..5, 0).expect("Failed to fetch empty range");
        assert!(fetched.is_empty());
        assert_eq!(fetched.len(), 0);
    }

    #[test]
    fn test_fetch_invalid_file_index() {
        let (_tmp, sources) = single_source(b"Test content");
        let mut loader = auto_cpu_loader(&sources);

        let outcome = loader.fetch(0..10, 99);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_multiple_sources() {
        let content1 = b"First file content here.";
        let content2 = b"Second file with different data.";
        let content3 = b"Third file completes the set.";

        let (tmp1, tmp2, tmp3) = (
            create_test_file(content1),
            create_test_file(content2),
            create_test_file(content3),
        );
        let sources = [
            Source::open(tmp1.path()).expect("Failed to open source 1"),
            Source::open(tmp2.path()).expect("Failed to open source 2"),
            Source::open(tmp3.path()).expect("Failed to open source 3"),
        ];
        let mut loader = auto_cpu_loader(&sources);
        assert_eq!(loader.num_sources(), 3);

        // Fetch all three before comparing so the buffers coexist.
        let first = loader
            .fetch(0..content1.len(), 0)
            .expect("Failed to fetch file 1");
        let second = loader
            .fetch(0..content2.len(), 1)
            .expect("Failed to fetch file 2");
        let third = loader
            .fetch(0..content3.len(), 2)
            .expect("Failed to fetch file 3");

        assert_eq!(first.as_slice().unwrap(), content1);
        assert_eq!(second.as_slice().unwrap(), content2);
        assert_eq!(third.as_slice().unwrap(), content3);
    }

    #[test]
    fn test_large_file() {
        // 1 MiB of a repeating 0..=255 byte ramp.
        let size: usize = 1024 * 1024;
        let content: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
        let (_tmp, sources) = single_source(&content);
        let mut loader = auto_cpu_loader(&sources);

        let fetched = loader
            .fetch(0..size, 0)
            .expect("Failed to fetch large file");
        assert_eq!(fetched.len(), size);
        let bytes = fetched.as_slice().expect("Failed to get slice");
        assert_eq!(bytes, content.as_slice());
    }

    #[test]
    fn test_source_info() {
        let content = b"Source info test content";
        let (_tmp, sources) = single_source(content);
        let loader = auto_cpu_loader(&sources);

        let present = loader.source_info(0);
        assert!(present.is_some());
        assert_eq!(present.unwrap().size, content.len());

        let missing = loader.source_info(100);
        assert!(missing.is_none());
    }

    #[test]
    fn test_buffer_to_vec() {
        let content = b"Convert me to a Vec!";
        let (_tmp, sources) = single_source(content);
        let mut loader = auto_cpu_loader(&sources);

        let fetched = loader.fetch(0..content.len(), 0).expect("Failed to fetch");
        let owned = fetched.to_vec();
        assert_eq!(owned, content.to_vec());
    }

    #[test]
    fn test_mmap_loader_kind() {
        let content = b"Testing mmap loader backend explicitly.";
        let (_tmp, sources) = single_source(content);
        let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap)
            .expect("Failed to create mmap loader");

        let fetched = loader
            .fetch(0..content.len(), 0)
            .expect("Failed to fetch with mmap");
        assert_eq!(fetched.as_slice().unwrap(), content);
    }
}