use crate::error::{Error, Result};
use std::ffi::CString;
use std::path::Path;
/// A source opened through `hmll_source_open`, owning the raw `hmll_sys::hmll_source`
/// handle; the handle is closed when the `Source` is dropped.
#[derive(Debug)]
pub struct Source {
    // Raw FFI handle. A negative `fd` means there is nothing for `Drop` to close
    // (`into_raw` sets it to -1 after handing ownership out).
    inner: hmll_sys::hmll_source,
    // The UTF-8 path that was passed to `Source::open`, if any.
    path: Option<String>,
}
impl Source {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path_ref = path.as_ref();
let path_str = path_ref
.to_str()
.ok_or_else(|| Error::FileNotFound("Invalid UTF-8 in path".to_string()))?;
let c_path = CString::new(path_str)
.map_err(|_| Error::FileNotFound("Path contains null byte".to_string()))?;
let mut source = hmll_sys::hmll_source { fd: -1, size: 0 };
unsafe {
let err = hmll_sys::hmll_source_open(c_path.as_ptr(), &mut source);
Error::check_hmll_error(err)?;
}
Ok(Self {
inner: source,
path: Some(path_str.to_string()),
})
}
#[inline(always)]
pub const fn size(&self) -> usize {
self.inner.size
}
#[cfg(target_family = "unix")]
#[inline(always)]
pub const fn fd(&self) -> i32 {
self.inner.fd
}
#[inline]
pub fn path(&self) -> Option<&str> {
self.path.as_deref()
}
#[inline(always)]
pub const fn as_raw(&self) -> &hmll_sys::hmll_source {
&self.inner
}
#[allow(dead_code)]
#[inline(always)]
pub(crate) unsafe fn into_raw(mut self) -> hmll_sys::hmll_source {
let source = self.inner;
self.inner.fd = -1;
source
}
}
impl Drop for Source {
    /// Closes the owned handle unless `fd` is negative, which marks a handle this wrapper
    /// does not own (for example after `into_raw`).
    fn drop(&mut self) {
        let raw = &self.inner;
        if raw.fd < 0 {
            return;
        }
        // SAFETY: a non-negative `fd` means this `Source` still owns the handle filled in
        // by `hmll_source_open`, and `raw` points to that live struct for the whole call.
        unsafe { hmll_sys::hmll_source_close(raw) };
    }
}
// SAFETY: NOTE(review): these impls assert that the raw `hmll_source` handle may be moved
// to and shared between threads. `Source`'s own `&self` API only reads `fd`/`size`, and
// closing requires ownership (`Drop`/`into_raw`). Confirm that the C library's
// `hmll_source_*` functions keep no thread-affine state and tolerate concurrent use of the
// pointer handed out by `as_raw`.
unsafe impl Send for Source {}
unsafe impl Sync for Source {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `content` to a fresh temporary file and returns its handle; the file is
    /// removed when the handle is dropped.
    fn create_test_file(content: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().expect("Failed to create temp file");
        file.write_all(content)
            .expect("Failed to write test content");
        file.flush().expect("Failed to flush");
        file
    }

    #[test]
    fn test_source_invalid_path() {
        let result = Source::open("/nonexistent/file.safetensors");
        assert!(result.is_err());
    }

    #[test]
    fn test_source_null_byte() {
        let result = Source::open("file\0name.safetensors");
        assert!(result.is_err());
    }

    #[test]
    fn test_source_open_and_size() {
        let content = b"Hello, HMLL! This is test data for the integration test.";
        let temp_file = create_test_file(content);
        let source = Source::open(temp_file.path()).expect("Failed to open source");
        assert_eq!(source.size(), content.len());
        assert_eq!(source.path(), temp_file.path().to_str());
    }

    // `Source::fd` only exists on Unix targets, so this check is gated the same way;
    // calling it unconditionally breaks the test build elsewhere.
    #[cfg(target_family = "unix")]
    #[test]
    fn test_source_fd_is_valid() {
        let temp_file = create_test_file(b"fd check");
        let source = Source::open(temp_file.path()).expect("Failed to open source");
        assert!(source.fd() >= 0);
    }
}