use crate::engine::SynaDB;
use crate::error::{Result, SynaError};
use crate::types::Atom;
use serde::{Deserialize, Serialize};
/// Default chunk size for chunked tensor storage (1 MiB).
pub const DEFAULT_CHUNK_SIZE: usize = 1024 * 1024;
/// 1 MiB — chosen for tensors up to 10 MB.
pub const CHUNK_SIZE_SMALL: usize = 1024 * 1024;
/// 4 MiB — chosen for tensors up to 100 MB.
pub const CHUNK_SIZE_MEDIUM: usize = 4 * 1024 * 1024;
/// 16 MiB — chosen for anything larger.
pub const CHUNK_SIZE_LARGE: usize = 16 * 1024 * 1024;

/// Pick a chunk size appropriate for a tensor occupying `tensor_bytes` bytes.
///
/// Up to 10 MB uses `CHUNK_SIZE_SMALL`, up to 100 MB uses
/// `CHUNK_SIZE_MEDIUM`, and everything above that uses `CHUNK_SIZE_LARGE`.
pub fn optimal_chunk_size(tensor_bytes: usize) -> usize {
    if tensor_bytes <= 10_000_000 {
        CHUNK_SIZE_SMALL
    } else if tensor_bytes <= 100_000_000 {
        CHUNK_SIZE_MEDIUM
    } else {
        CHUNK_SIZE_LARGE
    }
}
/// Metadata describing a chunked tensor, stored as JSON under `<name>/meta`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TensorMeta {
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type name, e.g. "float64" (see `DType::name`).
    pub dtype: String,
    /// Total payload size in bytes across all chunks.
    pub total_bytes: usize,
    /// Number of chunks the payload was split into.
    pub chunk_count: usize,
    /// Nominal size of each chunk in bytes (the last chunk may be smaller).
    pub chunk_size: usize,
}
/// Supported tensor element types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    Float32,
    Float64,
    Int32,
    Int64,
}

impl DType {
    /// Width of a single element of this dtype, in bytes.
    pub fn size(&self) -> usize {
        match self {
            DType::Float64 | DType::Int64 => 8,
            DType::Float32 | DType::Int32 => 4,
        }
    }

    /// Canonical lowercase name, as recorded in tensor metadata.
    pub fn name(&self) -> &'static str {
        match self {
            DType::Float32 => "float32",
            DType::Float64 => "float64",
            DType::Int32 => "int32",
            DType::Int64 => "int64",
        }
    }
}
/// Tensor-oriented facade over a `SynaDB` key/value store.
pub struct TensorEngine {
    // Owned database; exposed via db()/db_mut()/into_db().
    db: SynaDB,
}
impl TensorEngine {
/// Wrap an existing database in a tensor engine.
pub fn new(db: SynaDB) -> Self {
    Self { db }
}
/// Shared access to the underlying database.
pub fn db(&self) -> &SynaDB {
    &self.db
}
/// Mutable access to the underlying database.
pub fn db_mut(&mut self) -> &mut SynaDB {
    &mut self.db
}
/// Consume the engine and return the underlying database.
pub fn into_db(self) -> SynaDB {
    self.db
}
/// Gather scalar atoms whose keys match `pattern` into a flat 1-D tensor.
///
/// Keys are resolved via `match_keys` (supports `prefix/*`, `prefix*`, or
/// an exact key); wildcard matches are visited in sorted key order. Atoms
/// with no numeric interpretation (Null/Text/Bytes/empty vectors) are
/// skipped. Returns the little-endian element bytes plus a shape of
/// `[count]`, where `count` is the number of elements actually appended.
pub fn get_tensor(&mut self, pattern: &str, dtype: DType) -> Result<(Vec<u8>, Vec<usize>)> {
    let keys = self.match_keys(pattern);
    let mut data = Vec::new();
    let mut count = 0;
    for key in &keys {
        if let Some(atom) = self.db.get(key)? {
            // Only count atoms that were actually converted and appended.
            if self.append_atom_as_dtype(&mut data, &atom, dtype)? {
                count += 1;
            }
        }
    }
    let shape = vec![count];
    Ok((data, shape))
}
/// Store a tensor one element per key, under `<key_prefix><index>` with a
/// zero-padded 8-digit index, so lexicographic key order equals element
/// order (which wildcard reads rely on after sorting).
///
/// `data` must hold exactly `shape.product() * dtype.size()` little-endian
/// bytes; otherwise `ShapeMismatch` is returned. Returns the number of
/// elements written.
pub fn put_tensor(
    &mut self,
    key_prefix: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    let element_size = dtype.size();
    let n_elements = shape.iter().product::<usize>();
    let expected_bytes = n_elements * element_size;
    if data.len() != expected_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: expected_bytes,
        });
    }
    let mut count = 0;
    for i in 0..n_elements {
        // Zero-padded index keeps keys sortable as strings.
        let key = format!("{}{:08}", key_prefix, i);
        let start = i * element_size;
        let end = start + element_size;
        let atom = self.bytes_to_atom(&data[start..end], dtype)?;
        self.db.append(&key, atom)?;
        count += 1;
    }
    Ok(count)
}
/// Resolve a key pattern against the database.
///
/// Three forms are supported:
/// * `"name/*"` — every key under `"name/"`, sorted;
/// * `"name*"`  — every key starting with `"name"`, sorted;
/// * anything else — an exact-key lookup (at most one result).
fn match_keys(&self, pattern: &str) -> Vec<String> {
    let keys = self.db.keys();
    // Reduce both wildcard forms to an optional literal prefix.
    let prefix = if let Some(base) = pattern.strip_suffix("/*") {
        Some(format!("{}/", base))
    } else {
        pattern.strip_suffix('*').map(str::to_string)
    };
    match prefix {
        Some(p) => {
            let mut hits: Vec<String> =
                keys.into_iter().filter(|k| k.starts_with(&p)).collect();
            // Sorted so element order is deterministic for tensor reads.
            hits.sort();
            hits
        }
        None => keys.into_iter().filter(|k| k == pattern).collect(),
    }
}
/// Append one atom to `data` as a single little-endian element of `dtype`.
///
/// Returns `Ok(true)` when a value was appended, `Ok(false)` when the atom
/// has no numeric interpretation (Null/Text/Bytes or an empty vector), in
/// which case callers simply skip it. Only the first element of a vector
/// atom is used.
fn append_atom_as_dtype(&self, data: &mut Vec<u8>, atom: &Atom, dtype: DType) -> Result<bool> {
    match atom {
        Atom::Float(f) => {
            match dtype {
                DType::Float32 => data.extend_from_slice(&(*f as f32).to_le_bytes()),
                DType::Float64 => data.extend_from_slice(&f.to_le_bytes()),
                DType::Int32 => data.extend_from_slice(&(*f as i32).to_le_bytes()),
                DType::Int64 => data.extend_from_slice(&(*f as i64).to_le_bytes()),
            }
            Ok(true)
        }
        Atom::Int(i) => {
            match dtype {
                DType::Float32 => data.extend_from_slice(&(*i as f32).to_le_bytes()),
                DType::Float64 => data.extend_from_slice(&(*i as f64).to_le_bytes()),
                DType::Int32 => data.extend_from_slice(&(*i as i32).to_le_bytes()),
                DType::Int64 => data.extend_from_slice(&i.to_le_bytes()),
            }
            Ok(true)
        }
        Atom::Vector(vec, _) => match vec.first() {
            // Empty vectors carry no value to append.
            None => Ok(false),
            Some(&v) => {
                match dtype {
                    DType::Float32 => data.extend_from_slice(&v.to_le_bytes()),
                    DType::Float64 => data.extend_from_slice(&(v as f64).to_le_bytes()),
                    DType::Int32 => data.extend_from_slice(&(v as i32).to_le_bytes()),
                    DType::Int64 => data.extend_from_slice(&(v as i64).to_le_bytes()),
                }
                Ok(true)
            }
        },
        // Null, Text, Bytes, and any other non-numeric variant are skipped.
        _ => Ok(false),
    }
}
/// Decode exactly one little-endian element of `dtype` from `bytes` into an
/// `Atom` (floats widen to f64, ints widen to i64).
///
/// A length mismatch is reported as `ShapeMismatch` against the dtype width.
fn bytes_to_atom(&self, bytes: &[u8], dtype: DType) -> Result<Atom> {
    // Shared fixed-size extraction with a uniform error shape.
    fn take<const N: usize>(bytes: &[u8]) -> Result<[u8; N]> {
        bytes.try_into().map_err(|_| SynaError::ShapeMismatch {
            data_size: bytes.len(),
            expected_size: N,
        })
    }
    match dtype {
        DType::Float32 => Ok(Atom::Float(f32::from_le_bytes(take::<4>(bytes)?) as f64)),
        DType::Float64 => Ok(Atom::Float(f64::from_le_bytes(take::<8>(bytes)?))),
        DType::Int32 => Ok(Atom::Int(i32::from_le_bytes(take::<4>(bytes)?) as i64)),
        DType::Int64 => Ok(Atom::Int(i64::from_le_bytes(take::<8>(bytes)?))),
    }
}
/// Store a tensor split into `DEFAULT_CHUNK_SIZE` chunks.
/// See `put_tensor_chunked_with_size` for the storage layout.
pub fn put_tensor_chunked(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    self.put_tensor_chunked_with_size(name, data, shape, dtype, DEFAULT_CHUNK_SIZE)
}
/// Store a tensor split into chunks of `chunk_size` bytes.
///
/// Layout: a JSON `TensorMeta` is written under `<name>/meta`, then the
/// payload is stored as `Atom::Bytes` values under `<name>/chunk/<i>`
/// (the final chunk may be shorter than `chunk_size`). Returns the number
/// of chunks written (0 for an empty payload).
///
/// # Errors
/// * `InvalidPath` if `chunk_size` is zero or metadata serialization fails.
/// * `ShapeMismatch` if `data.len()` disagrees with `shape` and `dtype`.
pub fn put_tensor_chunked_with_size(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
    chunk_size: usize,
) -> Result<usize> {
    // A zero chunk size would make `div_ceil` (and `chunks`) panic with a
    // divide-by-zero; reject it as a recoverable error instead.
    if chunk_size == 0 {
        return Err(SynaError::InvalidPath(
            "chunk_size must be greater than zero".to_string(),
        ));
    }
    let element_size = dtype.size();
    let n_elements = shape.iter().product::<usize>();
    let expected_bytes = n_elements * element_size;
    if data.len() != expected_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: expected_bytes,
        });
    }
    let chunk_count = data.len().div_ceil(chunk_size);
    let meta = TensorMeta {
        shape: shape.to_vec(),
        dtype: dtype.name().to_string(),
        total_bytes: data.len(),
        chunk_count,
        chunk_size,
    };
    let meta_json = serde_json::to_string(&meta)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to serialize metadata: {}", e)))?;
    let meta_key = format!("{}/meta", name);
    self.db.append(&meta_key, Atom::Text(meta_json))?;
    for (i, chunk_data) in data.chunks(chunk_size).enumerate() {
        let chunk_key = format!("{}/chunk/{}", name, i);
        self.db
            .append(&chunk_key, Atom::Bytes(chunk_data.to_vec()))?;
    }
    Ok(chunk_count)
}
/// Store a chunked tensor with a chunk size chosen from the payload size
/// via `optimal_chunk_size`.
pub fn put_tensor_optimized(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    let chunk_size = crate::tensor::optimal_chunk_size(data.len());
    self.put_tensor_chunked_with_size(name, data, shape, dtype, chunk_size)
}
/// Reassemble a tensor stored by `put_tensor_chunked*`.
///
/// Reads the JSON `TensorMeta` from `<name>/meta`, concatenates every
/// `<name>/chunk/<i>` payload in index order, and verifies the total byte
/// count against `meta.total_bytes`. Returns the raw bytes and the stored
/// shape.
///
/// # Errors
/// `KeyNotFound` if the metadata or a chunk is missing, `TypeConversion`
/// if a stored atom has the wrong variant, `InvalidPath` if the metadata
/// JSON is unparseable, and `ShapeMismatch` if the reassembled size
/// disagrees with the metadata.
pub fn get_tensor_chunked(&mut self, name: &str) -> Result<(Vec<u8>, Vec<usize>)> {
    let meta_key = format!("{}/meta", name);
    let meta_atom = self
        .db
        .get(&meta_key)?
        .ok_or_else(|| SynaError::KeyNotFound(meta_key.clone()))?;
    let meta_json = match meta_atom {
        Atom::Text(s) => s,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: meta_atom.type_name(),
                to_type: "Text",
            })
        }
    };
    let meta: TensorMeta = serde_json::from_str(&meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    // Pre-size the output buffer from the metadata.
    let mut data = Vec::with_capacity(meta.total_bytes);
    for i in 0..meta.chunk_count {
        let chunk_key = format!("{}/chunk/{}", name, i);
        let chunk_atom = self
            .db
            .get(&chunk_key)?
            .ok_or_else(|| SynaError::KeyNotFound(chunk_key.clone()))?;
        match chunk_atom {
            Atom::Bytes(bytes) => data.extend_from_slice(&bytes),
            _ => {
                return Err(SynaError::TypeConversion {
                    from_type: chunk_atom.type_name(),
                    to_type: "Bytes",
                })
            }
        }
    }
    // Guard against truncated or corrupted chunk data.
    if data.len() != meta.total_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: meta.total_bytes,
        });
    }
    Ok((data, meta.shape))
}
/// Read and parse the `TensorMeta` stored under `<name>/meta` without
/// touching any chunk data.
pub fn get_tensor_meta(&mut self, name: &str) -> Result<TensorMeta> {
    let meta_key = format!("{}/meta", name);
    let meta_atom = self
        .db
        .get(&meta_key)?
        .ok_or_else(|| SynaError::KeyNotFound(meta_key.clone()))?;
    let meta_json = match meta_atom {
        Atom::Text(s) => s,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: meta_atom.type_name(),
                to_type: "Text",
            })
        }
    };
    let meta: TensorMeta = serde_json::from_str(&meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    Ok(meta)
}
/// Store a tensor as a single blob under `name`.
///
/// Blob layout: `[meta_len: u32 LE][meta JSON][payload bytes]`, where the
/// metadata is a `TensorMeta` with `chunk_count = 1` and
/// `chunk_size = data.len()`. Always returns 1 on success.
pub fn put_tensor_batched(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    let element_size = dtype.size();
    let n_elements = shape.iter().product::<usize>();
    let expected_bytes = n_elements * element_size;
    if data.len() != expected_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: expected_bytes,
        });
    }
    let meta = TensorMeta {
        shape: shape.to_vec(),
        dtype: dtype.name().to_string(),
        total_bytes: data.len(),
        chunk_count: 1,
        chunk_size: data.len(),
    };
    let meta_json = serde_json::to_vec(&meta)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to serialize metadata: {}", e)))?;
    // Header + metadata + payload assembled in one allocation.
    let mut buffer = Vec::with_capacity(4 + meta_json.len() + data.len());
    buffer.extend_from_slice(&(meta_json.len() as u32).to_le_bytes());
    buffer.extend_from_slice(&meta_json);
    buffer.extend_from_slice(data);
    self.db.append(name, Atom::Bytes(buffer))?;
    Ok(1)
}
/// Load a tensor stored by `put_tensor_batched`.
///
/// Parses the `[meta_len: u32 LE][meta JSON][payload]` blob stored under
/// `name`, bounds-checks the header, and verifies the payload length
/// against `meta.total_bytes`.
pub fn get_tensor_batched(&mut self, name: &str) -> Result<(Vec<u8>, Vec<usize>)> {
    let blob_atom = self
        .db
        .get(name)?
        .ok_or_else(|| SynaError::KeyNotFound(name.to_string()))?;
    let blob = match blob_atom {
        Atom::Bytes(b) => b,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: blob_atom.type_name(),
                to_type: "Bytes",
            })
        }
    };
    // Need at least the 4-byte length header.
    if blob.len() < 4 {
        return Err(SynaError::ShapeMismatch {
            data_size: blob.len(),
            expected_size: 4,
        });
    }
    let meta_len = u32::from_le_bytes([blob[0], blob[1], blob[2], blob[3]]) as usize;
    // The declared metadata length must fit inside the blob.
    if blob.len() < 4 + meta_len {
        return Err(SynaError::ShapeMismatch {
            data_size: blob.len(),
            expected_size: 4 + meta_len,
        });
    }
    let meta_json = &blob[4..4 + meta_len];
    let meta: TensorMeta = serde_json::from_slice(meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    let data_start = 4 + meta_len;
    let data = blob[data_start..].to_vec();
    if data.len() != meta.total_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: meta.total_bytes,
        });
    }
    Ok((data, meta.shape))
}
/// Parse only the metadata header of a tensor stored by
/// `put_tensor_batched`; the payload bytes are not copied.
pub fn get_tensor_batched_meta(&mut self, name: &str) -> Result<TensorMeta> {
    let blob_atom = self
        .db
        .get(name)?
        .ok_or_else(|| SynaError::KeyNotFound(name.to_string()))?;
    let blob = match blob_atom {
        Atom::Bytes(b) => b,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: blob_atom.type_name(),
                to_type: "Bytes",
            })
        }
    };
    // Need at least the 4-byte length header.
    if blob.len() < 4 {
        return Err(SynaError::ShapeMismatch {
            data_size: blob.len(),
            expected_size: 4,
        });
    }
    let meta_len = u32::from_le_bytes([blob[0], blob[1], blob[2], blob[3]]) as usize;
    // The declared metadata length must fit inside the blob.
    if blob.len() < 4 + meta_len {
        return Err(SynaError::ShapeMismatch {
            data_size: blob.len(),
            expected_size: 4 + meta_len,
        });
    }
    let meta_json = &blob[4..4 + meta_len];
    let meta: TensorMeta = serde_json::from_slice(meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    Ok(meta)
}
/// Delete a chunked tensor: its `<name>/meta` entry and every
/// `<name>/chunk/<i>`.
///
/// The chunk count is read from the metadata first (falling back to 0 if
/// the metadata is missing or unparseable). Returns the number of delete
/// calls that returned Ok.
/// NOTE(review): if `SynaDB::delete` returns Ok for keys that do not
/// exist, `deleted` can exceed the number of entries actually removed —
/// verify against the engine's delete semantics.
pub fn delete_tensor_chunked(&mut self, name: &str) -> Result<usize> {
    let meta_key = format!("{}/meta", name);
    let chunk_count = match self.db.get(&meta_key)? {
        Some(Atom::Text(json)) => {
            // Corrupt metadata degrades to 0 chunks rather than failing.
            if let Ok(meta) = serde_json::from_str::<TensorMeta>(&json) {
                meta.chunk_count
            } else {
                0
            }
        }
        _ => 0,
    };
    let mut deleted = 0;
    if self.db.delete(&meta_key).is_ok() {
        deleted += 1;
    }
    for i in 0..chunk_count {
        let chunk_key = format!("{}/chunk/{}", name, i);
        if self.db.delete(&chunk_key).is_ok() {
            deleted += 1;
        }
    }
    Ok(deleted)
}
/// Store a tensor payload in a memory-mapped sidecar file next to the
/// database, recording an `MmapTensorMeta` JSON entry under
/// `<name>/mmap_meta`. Returns 1 on success.
///
/// The sidecar path is the database path with extension
/// `"<safe_name>.mmap"`, where `safe_name` is `name` with '/' replaced by
/// '_'.
/// NOTE(review): that replacement means names like "a/b" and "a_b" map to
/// the same sidecar file — confirm callers cannot collide.
pub fn put_tensor_mmap(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    use memmap2::MmapMut;
    use std::fs::OpenOptions;
    let element_size = dtype.size();
    let n_elements = shape.iter().product::<usize>();
    let expected_bytes = n_elements * element_size;
    if data.len() != expected_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: expected_bytes,
        });
    }
    let safe_name = name.replace('/', "_");
    let mmap_path = self.db.path.with_extension(format!("{}.mmap", safe_name));
    let file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&mmap_path)?;
    file.set_len(data.len() as u64)?;
    // SAFETY: mapping a file this function just truncated and sized;
    // assumes no other process touches it concurrently — TODO confirm.
    let mut mmap = unsafe { MmapMut::map_mut(&file)? };
    mmap.copy_from_slice(data);
    // Flush the mapping so the payload reaches the file before the
    // metadata entry is recorded.
    mmap.flush()?;
    let meta = MmapTensorMeta {
        shape: shape.to_vec(),
        dtype: dtype.name().to_string(),
        total_bytes: data.len(),
        mmap_path: mmap_path.to_string_lossy().to_string(),
    };
    let meta_json = serde_json::to_string(&meta)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to serialize metadata: {}", e)))?;
    let meta_key = format!("{}/mmap_meta", name);
    self.db.append(&meta_key, Atom::Text(meta_json))?;
    Ok(1)
}
/// Like `put_tensor_mmap` but writes the sidecar file with plain file I/O
/// instead of mapping it; on Linux the written range is then hinted out of
/// the page cache. Returns 1 on success.
/// NOTE(review): the payload is written with `write_all` but never
/// fsync'd, so crash durability depends on the OS flushing — confirm this
/// is acceptable for callers.
pub fn put_tensor_mmap_fast(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    use std::fs::OpenOptions;
    use std::io::Write;
    let element_size = dtype.size();
    let n_elements = shape.iter().product::<usize>();
    let expected_bytes = n_elements * element_size;
    if data.len() != expected_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: expected_bytes,
        });
    }
    // Same sidecar naming scheme as put_tensor_mmap.
    let safe_name = name.replace('/', "_");
    let mmap_path = self.db.path.with_extension(format!("{}.mmap", safe_name));
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&mmap_path)?;
    file.set_len(data.len() as u64)?;
    file.write_all(data)?;
    #[cfg(target_os = "linux")]
    {
        use std::os::unix::io::AsRawFd;
        // Best-effort hint to drop the just-written pages from the page
        // cache; the result is deliberately ignored.
        let _ =
            unsafe { libc::posix_fadvise(file.as_raw_fd(), 0, 0, libc::POSIX_FADV_DONTNEED) };
    }
    #[cfg(any(target_os = "macos", target_os = "windows"))]
    {
        // No equivalent page-cache hint is issued on these platforms.
    }
    let meta = MmapTensorMeta {
        shape: shape.to_vec(),
        dtype: dtype.name().to_string(),
        total_bytes: data.len(),
        mmap_path: mmap_path.to_string_lossy().to_string(),
    };
    let meta_json = serde_json::to_string(&meta)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to serialize metadata: {}", e)))?;
    let meta_key = format!("{}/mmap_meta", name);
    self.db.append(&meta_key, Atom::Text(meta_json))?;
    Ok(1)
}
/// Load an mmap-backed tensor fully into memory.
///
/// Reads the `MmapTensorMeta` from `<name>/mmap_meta`, maps the sidecar
/// file read-only, checks its length against the metadata, and copies the
/// bytes into an owned Vec. Use `get_tensor_mmap_ref` for zero-copy
/// access.
pub fn get_tensor_mmap(&mut self, name: &str) -> Result<(Vec<u8>, Vec<usize>)> {
    use memmap2::Mmap;
    use std::fs::File;
    let meta_key = format!("{}/mmap_meta", name);
    let meta_atom = self
        .db
        .get(&meta_key)?
        .ok_or_else(|| SynaError::KeyNotFound(meta_key.clone()))?;
    let meta_json = match meta_atom {
        Atom::Text(s) => s,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: meta_atom.type_name(),
                to_type: "Text",
            })
        }
    };
    let meta: MmapTensorMeta = serde_json::from_str(&meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    let file = File::open(&meta.mmap_path)?;
    // SAFETY: read-only mapping; assumes the sidecar file is not truncated
    // or rewritten while mapped — TODO confirm single-writer usage.
    let mmap = unsafe { Mmap::map(&file)? };
    if mmap.len() != meta.total_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: mmap.len(),
            expected_size: meta.total_bytes,
        });
    }
    let data = mmap.to_vec();
    Ok((data, meta.shape))
}
/// Zero-copy variant of `get_tensor_mmap`: returns an `MmapTensorRef`
/// that keeps the mapping alive instead of copying the bytes out.
pub fn get_tensor_mmap_ref(&mut self, name: &str) -> Result<MmapTensorRef> {
    use memmap2::Mmap;
    use std::fs::File;
    let meta_key = format!("{}/mmap_meta", name);
    let meta_atom = self
        .db
        .get(&meta_key)?
        .ok_or_else(|| SynaError::KeyNotFound(meta_key.clone()))?;
    let meta_json = match meta_atom {
        Atom::Text(s) => s,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: meta_atom.type_name(),
                to_type: "Text",
            })
        }
    };
    let meta: MmapTensorMeta = serde_json::from_str(&meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    let file = File::open(&meta.mmap_path)?;
    // SAFETY: read-only mapping; assumes the sidecar file is not truncated
    // or rewritten while mapped — TODO confirm single-writer usage.
    let mmap = unsafe { Mmap::map(&file)? };
    if mmap.len() != meta.total_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: mmap.len(),
            expected_size: meta.total_bytes,
        });
    }
    Ok(MmapTensorRef {
        shape: meta.shape,
        dtype: meta.dtype,
        mmap,
    })
}
/// Read and parse the `MmapTensorMeta` stored under `<name>/mmap_meta`
/// without opening the sidecar file.
pub fn get_tensor_mmap_meta(&mut self, name: &str) -> Result<MmapTensorMeta> {
    let meta_key = format!("{}/mmap_meta", name);
    let meta_atom = self
        .db
        .get(&meta_key)?
        .ok_or_else(|| SynaError::KeyNotFound(meta_key.clone()))?;
    let meta_json = match meta_atom {
        Atom::Text(s) => s,
        _ => {
            return Err(SynaError::TypeConversion {
                from_type: meta_atom.type_name(),
                to_type: "Text",
            })
        }
    };
    let meta: MmapTensorMeta = serde_json::from_str(&meta_json)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to parse metadata: {}", e)))?;
    Ok(meta)
}
/// Delete an mmap-backed tensor: best-effort removal of the sidecar file,
/// then removal of the `<name>/mmap_meta` entry. Sidecar removal errors
/// are deliberately ignored (the file may already be gone).
pub fn delete_tensor_mmap(&mut self, name: &str) -> Result<()> {
    let meta_key = format!("{}/mmap_meta", name);
    if let Some(Atom::Text(json)) = self.db.get(&meta_key)? {
        if let Ok(meta) = serde_json::from_str::<MmapTensorMeta>(&json) {
            let _ = std::fs::remove_file(&meta.mmap_path);
        }
    }
    self.db.delete(&meta_key)?;
    Ok(())
}
/// Store a chunked tensor, staging each chunk through a temporary file
/// written by a concurrent tokio task, then appending the chunks to the
/// database in index order.
///
/// Storage layout matches `put_tensor_chunked_with_size`: a JSON
/// `TensorMeta` under `<name>/meta` plus `Atom::Bytes` chunks under
/// `<name>/chunk/<i>`. Returns the chunk count.
/// NOTE(review): the metadata is appended before any chunk, and on failure
/// only the temp files already collected are removed (dropping the
/// `JoinSet` aborts still-running tasks mid-write) — a failed call can
/// leave a dangling meta entry and stray `.tmp` files. Confirm callers
/// tolerate this.
#[cfg(feature = "async")]
pub async fn put_tensor_async(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
    chunk_size: usize,
) -> Result<usize> {
    use tokio::fs::File as AsyncFile;
    use tokio::io::AsyncWriteExt;
    use tokio::task::JoinSet;
    let element_size = dtype.size();
    let n_elements = shape.iter().product::<usize>();
    let expected_bytes = n_elements * element_size;
    if data.len() != expected_bytes {
        return Err(SynaError::ShapeMismatch {
            data_size: data.len(),
            expected_size: expected_bytes,
        });
    }
    let chunk_count = data.len().div_ceil(chunk_size);
    let meta = TensorMeta {
        shape: shape.to_vec(),
        dtype: dtype.name().to_string(),
        total_bytes: data.len(),
        chunk_count,
        chunk_size,
    };
    let meta_json = serde_json::to_string(&meta)
        .map_err(|e| SynaError::InvalidPath(format!("Failed to serialize metadata: {}", e)))?;
    let meta_key = format!("{}/meta", name);
    self.db.append(&meta_key, Atom::Text(meta_json))?;
    // Temp files are created alongside the database file.
    let db_dir = self.db.path.parent().unwrap_or(std::path::Path::new("."));
    let safe_name = name.replace('/', "_");
    // Copy each chunk so the spawned tasks own their bytes.
    let chunks_data: Vec<(usize, Vec<u8>)> = data
        .chunks(chunk_size)
        .enumerate()
        .map(|(i, chunk_data)| (i, chunk_data.to_vec()))
        .collect();
    let mut tasks: JoinSet<std::result::Result<(usize, std::path::PathBuf), String>> =
        JoinSet::new();
    // One task per chunk: create, write, and flush a temp file.
    for (i, chunk_data) in chunks_data {
        let temp_path = db_dir.join(format!(".{}_chunk_{}.tmp", safe_name, i));
        tasks.spawn(async move {
            let mut file = AsyncFile::create(&temp_path)
                .await
                .map_err(|e| format!("Failed to create temp file: {}", e))?;
            file.write_all(&chunk_data)
                .await
                .map_err(|e| format!("Failed to write chunk: {}", e))?;
            file.flush()
                .await
                .map_err(|e| format!("Failed to flush chunk: {}", e))?;
            Ok((i, temp_path))
        });
    }
    // Collect results; on any failure delete the temp files gathered so far.
    let mut temp_files: Vec<(usize, std::path::PathBuf)> = Vec::with_capacity(chunk_count);
    while let Some(result) = tasks.join_next().await {
        match result {
            Ok(Ok((i, path))) => temp_files.push((i, path)),
            Ok(Err(e)) => {
                for (_, path) in &temp_files {
                    let _ = std::fs::remove_file(path);
                }
                return Err(SynaError::InvalidPath(e));
            }
            Err(join_error) => {
                for (_, path) in &temp_files {
                    let _ = std::fs::remove_file(path);
                }
                return Err(SynaError::InvalidPath(format!(
                    "Async task failed: {}",
                    join_error
                )));
            }
        }
    }
    // Tasks finish in arbitrary order; restore index order before appending.
    temp_files.sort_by_key(|(i, _)| *i);
    for (i, temp_path) in &temp_files {
        let chunk_data = std::fs::read(temp_path)?;
        let chunk_key = format!("{}/chunk/{}", name, i);
        self.db.append(&chunk_key, Atom::Bytes(chunk_data))?;
        let _ = std::fs::remove_file(temp_path);
    }
    Ok(chunk_count)
}
/// `put_tensor_async` with `DEFAULT_CHUNK_SIZE`.
#[cfg(feature = "async")]
pub async fn put_tensor_async_default(
    &mut self,
    name: &str,
    data: &[u8],
    shape: &[usize],
    dtype: DType,
) -> Result<usize> {
    self.put_tensor_async(name, data, shape, dtype, DEFAULT_CHUNK_SIZE)
        .await
}
}
/// Metadata for a tensor stored in an external memory-mapped sidecar file,
/// serialized as JSON under `<name>/mmap_meta`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MmapTensorMeta {
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type name, e.g. "float32" (see `DType::name`).
    pub dtype: String,
    /// Size of the sidecar payload in bytes.
    pub total_bytes: usize,
    /// Filesystem path of the sidecar `.mmap` file.
    pub mmap_path: String,
}
/// Zero-copy view of a memory-mapped tensor.
pub struct MmapTensorRef {
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type name, e.g. "float32".
    pub dtype: String,
    // Keeps the mapping alive for as long as this ref (and its slices) live.
    mmap: memmap2::Mmap,
}
impl MmapTensorRef {
    /// Raw little-endian bytes of the mapped tensor.
    #[inline]
    pub fn data(&self) -> &[u8] {
        &self.mmap
    }

    /// Reinterpret the mapping as a slice of `f32`.
    ///
    /// # Panics
    /// Panics if the mapping is misaligned for `f32` or its length is not a
    /// multiple of 4. (Previously this was only a `debug_assert!`, so a
    /// release build would silently hand back a slice missing the unaligned
    /// head/tail elements — a data-corruption hazard. The check is now
    /// unconditional.)
    #[inline]
    pub fn as_f32_slice(&self) -> &[f32] {
        // SAFETY: every bit pattern is a valid f32; `align_to` guarantees
        // the middle slice is correctly aligned and in-bounds.
        let (prefix, floats, suffix) = unsafe { self.mmap[..].align_to::<f32>() };
        assert!(
            prefix.is_empty() && suffix.is_empty(),
            "mmap tensor data misaligned for f32"
        );
        floats
    }

    /// Reinterpret the mapping as a slice of `f64`.
    ///
    /// # Panics
    /// Panics on misalignment or a length that is not a multiple of 8
    /// (see `as_f32_slice` for why the check is unconditional).
    #[inline]
    pub fn as_f64_slice(&self) -> &[f64] {
        // SAFETY: every bit pattern is a valid f64; `align_to` guarantees
        // the middle slice is correctly aligned and in-bounds.
        let (prefix, doubles, suffix) = unsafe { self.mmap[..].align_to::<f64>() };
        assert!(
            prefix.is_empty() && suffix.is_empty(),
            "mmap tensor data misaligned for f64"
        );
        doubles
    }

    /// Reinterpret the mapping as a slice of `i32`.
    ///
    /// # Panics
    /// Panics on misalignment or a length that is not a multiple of 4
    /// (see `as_f32_slice` for why the check is unconditional).
    #[inline]
    pub fn as_i32_slice(&self) -> &[i32] {
        // SAFETY: every bit pattern is a valid i32; `align_to` guarantees
        // the middle slice is correctly aligned and in-bounds.
        let (prefix, ints, suffix) = unsafe { self.mmap[..].align_to::<i32>() };
        assert!(
            prefix.is_empty() && suffix.is_empty(),
            "mmap tensor data misaligned for i32"
        );
        ints
    }

    /// Reinterpret the mapping as a slice of `i64`.
    ///
    /// # Panics
    /// Panics on misalignment or a length that is not a multiple of 8
    /// (see `as_f32_slice` for why the check is unconditional).
    #[inline]
    pub fn as_i64_slice(&self) -> &[i64] {
        // SAFETY: every bit pattern is a valid i64; `align_to` guarantees
        // the middle slice is correctly aligned and in-bounds.
        let (prefix, longs, suffix) = unsafe { self.mmap[..].align_to::<i64>() };
        assert!(
            prefix.is_empty() && suffix.is_empty(),
            "mmap tensor data misaligned for i64"
        );
        longs
    }

    /// Length of the mapping in bytes (not elements).
    #[inline]
    pub fn len(&self) -> usize {
        self.mmap.len()
    }

    /// True when the mapping contains no bytes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.mmap.is_empty()
    }
}
/// Helpers for unbuffered (O_DIRECT) file I/O, available on Linux with
/// portable buffered fallbacks elsewhere.
pub mod direct_io {
    use std::fs::{File, OpenOptions};
    use std::io::{self, Write};
    use std::path::Path;

    /// Alignment (in bytes) required for O_DIRECT transfers.
    pub const DIRECT_IO_ALIGNMENT: usize = 4096;
    /// Payloads below this size are not worth the direct-I/O overhead.
    pub const DIRECT_IO_MIN_SIZE: usize = 1024 * 1024;

    /// Whether this platform supports O_DIRECT at all.
    #[inline]
    pub fn is_direct_io_available() -> bool {
        cfg!(target_os = "linux")
    }

    /// Open `path` for writing with O_DIRECT (created and truncated).
    #[cfg(target_os = "linux")]
    pub fn open_direct(path: &Path) -> io::Result<File> {
        use std::os::unix::fs::OpenOptionsExt;
        const O_DIRECT: i32 = 0o40000;
        OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .custom_flags(O_DIRECT)
            .open(path)
    }

    /// Fallback: plain buffered open on platforms without O_DIRECT.
    #[cfg(not(target_os = "linux"))]
    pub fn open_direct(path: &Path) -> io::Result<File> {
        OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)
    }

    /// Open `path` for reading with O_DIRECT.
    #[cfg(target_os = "linux")]
    pub fn open_direct_read(path: &Path) -> io::Result<File> {
        use std::os::unix::fs::OpenOptionsExt;
        const O_DIRECT: i32 = 0o40000;
        OpenOptions::new()
            .read(true)
            .custom_flags(O_DIRECT)
            .open(path)
    }

    /// Fallback: plain buffered read-open on platforms without O_DIRECT.
    #[cfg(not(target_os = "linux"))]
    pub fn open_direct_read(path: &Path) -> io::Result<File> {
        OpenOptions::new().read(true).open(path)
    }

    /// Round `size` up to the next multiple of `DIRECT_IO_ALIGNMENT`.
    #[inline]
    pub fn align_size(size: usize) -> usize {
        let mask = DIRECT_IO_ALIGNMENT - 1;
        (size + mask) & !mask
    }

    /// Allocate a zeroed buffer whose *length* is `size` rounded up to the
    /// direct-I/O alignment (the buffer address itself is not aligned).
    pub fn create_aligned_buffer(size: usize) -> Vec<u8> {
        vec![0; align_size(size)]
    }

    /// Write `data`, zero-padding the tail so the total written length is a
    /// multiple of the alignment. Returns the number of bytes written.
    pub fn write_aligned(file: &mut File, data: &[u8]) -> io::Result<usize> {
        let padded_len = align_size(data.len());
        if padded_len == data.len() {
            // Already aligned: write as-is, no copy needed.
            file.write_all(data)?;
            return Ok(data.len());
        }
        let mut padded = create_aligned_buffer(data.len());
        padded[..data.len()].copy_from_slice(data);
        file.write_all(&padded)?;
        Ok(padded_len)
    }

    /// Heuristic: use direct I/O only where available and for payloads of
    /// at least `DIRECT_IO_MIN_SIZE` bytes.
    #[inline]
    pub fn should_use_direct_io(size: usize) -> bool {
        is_direct_io_available() && size >= DIRECT_IO_MIN_SIZE
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_dtype_size() {
assert_eq!(DType::Float32.size(), 4);
assert_eq!(DType::Float64.size(), 8);
assert_eq!(DType::Int32.size(), 4);
assert_eq!(DType::Int64.size(), 8);
}
#[test]
fn test_get_tensor_empty() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let (data, shape) = engine.get_tensor("nonexistent/*", DType::Float64).unwrap();
assert!(data.is_empty());
assert_eq!(shape, vec![0]);
}
#[test]
fn test_get_tensor_floats() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let mut db = SynaDB::new(&db_path).unwrap();
db.append("data/00", Atom::Float(1.0)).unwrap();
db.append("data/01", Atom::Float(2.0)).unwrap();
db.append("data/02", Atom::Float(3.0)).unwrap();
let mut engine = TensorEngine::new(db);
let (data, shape) = engine.get_tensor("data/*", DType::Float64).unwrap();
assert_eq!(shape, vec![3]);
assert_eq!(data.len(), 3 * 8);
let values: Vec<f64> = data
.chunks(8)
.map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()))
.collect();
assert_eq!(values, vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_put_tensor() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let values = vec![1.0f64, 2.0, 3.0, 4.0];
let data: Vec<u8> = values.iter().flat_map(|f| f.to_le_bytes()).collect();
let count = engine
.put_tensor("tensor/", &data, &[4], DType::Float64)
.unwrap();
assert_eq!(count, 4);
let db = engine.db_mut();
assert_eq!(db.get("tensor/00000000").unwrap(), Some(Atom::Float(1.0)));
assert_eq!(db.get("tensor/00000001").unwrap(), Some(Atom::Float(2.0)));
assert_eq!(db.get("tensor/00000002").unwrap(), Some(Atom::Float(3.0)));
assert_eq!(db.get("tensor/00000003").unwrap(), Some(Atom::Float(4.0)));
}
#[test]
fn test_put_tensor_shape_mismatch() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let data = vec![0u8; 16]; let result = engine.put_tensor("tensor/", &data, &[4], DType::Float64);
assert!(matches!(result, Err(SynaError::ShapeMismatch { .. })));
}
#[test]
fn test_roundtrip_float64() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let original = vec![1.5f64, 2.5, 3.5];
let data: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
engine
.put_tensor("rt/", &data, &[3], DType::Float64)
.unwrap();
let (loaded_data, shape) = engine.get_tensor("rt/*", DType::Float64).unwrap();
assert_eq!(shape, vec![3]);
let loaded: Vec<f64> = loaded_data
.chunks(8)
.map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()))
.collect();
assert_eq!(loaded, original);
}
#[test]
fn test_int_to_float_conversion() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let mut db = SynaDB::new(&db_path).unwrap();
db.append("int/0", Atom::Int(10)).unwrap();
db.append("int/1", Atom::Int(20)).unwrap();
let mut engine = TensorEngine::new(db);
let (data, shape) = engine.get_tensor("int/*", DType::Float64).unwrap();
assert_eq!(shape, vec![2]);
let values: Vec<f64> = data
.chunks(8)
.map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()))
.collect();
assert_eq!(values, vec![10.0, 20.0]);
}
#[test]
fn test_skip_non_numeric() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let mut db = SynaDB::new(&db_path).unwrap();
db.append("mix/0", Atom::Float(1.0)).unwrap();
db.append("mix/1", Atom::Text("skip me".to_string()))
.unwrap();
db.append("mix/2", Atom::Float(2.0)).unwrap();
db.append("mix/3", Atom::Null).unwrap();
db.append("mix/4", Atom::Float(3.0)).unwrap();
let mut engine = TensorEngine::new(db);
let (data, shape) = engine.get_tensor("mix/*", DType::Float64).unwrap();
assert_eq!(shape, vec![3]);
let values: Vec<f64> = data
.chunks(8)
.map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()))
.collect();
assert_eq!(values, vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_put_get_tensor_chunked_small() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let original = vec![1.0f64, 2.0, 3.0, 4.0];
let data: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
let chunks = engine
.put_tensor_chunked("small", &data, &[4], DType::Float64)
.unwrap();
assert_eq!(chunks, 1);
let (loaded, shape) = engine.get_tensor_chunked("small").unwrap();
assert_eq!(loaded, data);
assert_eq!(shape, vec![4]);
}
#[test]
fn test_put_get_tensor_chunked_large() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let num_elements = 10_000;
let original: Vec<f64> = (0..num_elements).map(|i| i as f64 * 0.1).collect();
let data: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
let chunk_size = 1024;
let chunks = engine
.put_tensor_chunked_with_size(
"large",
&data,
&[num_elements],
DType::Float64,
chunk_size,
)
.unwrap();
let expected_chunks = (data.len() + chunk_size - 1) / chunk_size;
assert_eq!(chunks, expected_chunks);
assert!(chunks > 1);
let (loaded, shape) = engine.get_tensor_chunked("large").unwrap();
assert_eq!(loaded, data);
assert_eq!(shape, vec![num_elements]);
let loaded_values: Vec<f64> = loaded
.chunks(8)
.map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()))
.collect();
assert_eq!(loaded_values, original);
}
#[test]
fn test_get_tensor_meta() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let data: Vec<u8> = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
engine
.put_tensor_chunked("tensor", &data, &[2, 3], DType::Float64)
.unwrap();
let meta = engine.get_tensor_meta("tensor").unwrap();
assert_eq!(meta.shape, vec![2, 3]);
assert_eq!(meta.dtype, "float64");
assert_eq!(meta.total_bytes, 48); assert_eq!(meta.chunk_count, 1);
}
#[test]
fn test_delete_tensor_chunked() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let num_elements = 5000;
let data: Vec<u8> = (0..num_elements)
.map(|i| i as f64)
.flat_map(|f| f.to_le_bytes())
.collect();
engine
.put_tensor_chunked_with_size("to_delete", &data, &[num_elements], DType::Float64, 1024)
.unwrap();
assert!(engine.get_tensor_meta("to_delete").is_ok());
let deleted = engine.delete_tensor_chunked("to_delete").unwrap();
assert!(deleted > 0);
assert!(engine.get_tensor_meta("to_delete").is_err());
}
#[test]
fn test_chunked_shape_mismatch() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let data = vec![0u8; 16]; let result = engine.put_tensor_chunked("bad", &data, &[4], DType::Float64);
assert!(matches!(result, Err(SynaError::ShapeMismatch { .. })));
}
#[test]
fn test_chunked_not_found() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let result = engine.get_tensor_chunked("nonexistent");
assert!(matches!(result, Err(SynaError::KeyNotFound(_))));
}
#[test]
fn test_chunked_roundtrip_all_dtypes() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let f32_data: Vec<u8> = vec![1.0f32, 2.0, 3.0, 4.0]
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
engine
.put_tensor_chunked("f32", &f32_data, &[4], DType::Float32)
.unwrap();
let (loaded, _) = engine.get_tensor_chunked("f32").unwrap();
assert_eq!(loaded, f32_data);
let i32_data: Vec<u8> = vec![1i32, 2, 3, 4]
.iter()
.flat_map(|i| i.to_le_bytes())
.collect();
engine
.put_tensor_chunked("i32", &i32_data, &[4], DType::Int32)
.unwrap();
let (loaded, _) = engine.get_tensor_chunked("i32").unwrap();
assert_eq!(loaded, i32_data);
let i64_data: Vec<u8> = vec![1i64, 2, 3, 4]
.iter()
.flat_map(|i| i.to_le_bytes())
.collect();
engine
.put_tensor_chunked("i64", &i64_data, &[4], DType::Int64)
.unwrap();
let (loaded, _) = engine.get_tensor_chunked("i64").unwrap();
assert_eq!(loaded, i64_data);
}
#[test]
fn test_put_get_tensor_batched_small() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let db = SynaDB::new(&db_path).unwrap();
let mut engine = TensorEngine::new(db);
let original = vec![1.0f64, 2.0, 3.0, 4.0];
let data: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
let count = engine
.put_tensor_batched("small", &data, &[4], DType::Float64)
.unwrap();
assert_eq!(count, 1);
let (loaded, shape) = engine.get_tensor_batched("small").unwrap();
assert_eq!(loaded, data);
assert_eq!(shape, vec![4]);
}
#[test]
fn test_put_get_tensor_batched_large() {
    // 10k f64 values: check both byte-level and decoded-value round-trip.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let n = 10_000;
    let values: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    let batches = engine
        .put_tensor_batched("large", &bytes, &[n], DType::Float64)
        .unwrap();
    assert_eq!(batches, 1);
    let (loaded, shape) = engine.get_tensor_batched("large").unwrap();
    assert_eq!(loaded, bytes);
    assert_eq!(shape, vec![n]);
    // Decode the stored bytes back to f64 and compare with the source values.
    let decoded: Vec<f64> = loaded
        .chunks(8)
        .map(|c| f64::from_le_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(decoded, values);
}
#[test]
fn test_get_tensor_batched_meta() {
    // Batched-tensor metadata must report shape, dtype name, and byte count.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let values = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    engine
        .put_tensor_batched("tensor", &bytes, &[2, 3], DType::Float64)
        .unwrap();
    let meta = engine.get_tensor_batched_meta("tensor").unwrap();
    assert_eq!(meta.shape, vec![2, 3]);
    assert_eq!(meta.dtype, "float64");
    // 6 elements * 8 bytes each.
    assert_eq!(meta.total_bytes, 48);
}
#[test]
fn test_batched_shape_mismatch() {
    // 16 bytes cannot hold 4 f64 values (32 bytes), so the put must fail.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let too_short = vec![0u8; 16];
    let result = engine.put_tensor_batched("bad", &too_short, &[4], DType::Float64);
    assert!(matches!(result, Err(SynaError::ShapeMismatch { .. })));
}
#[test]
fn test_batched_not_found() {
    // Fetching a key that was never stored must yield KeyNotFound.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    assert!(matches!(
        engine.get_tensor_batched("nonexistent"),
        Err(SynaError::KeyNotFound(_))
    ));
}
#[test]
fn test_batched_roundtrip_all_dtypes() {
    // Each supported dtype must survive a batched put/get cycle unchanged.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);

    // Stores `bytes` under `key` and asserts the loaded bytes are identical.
    fn roundtrip(engine: &mut TensorEngine, key: &str, bytes: &[u8], dtype: DType) {
        engine.put_tensor_batched(key, bytes, &[4], dtype).unwrap();
        let (loaded, _) = engine.get_tensor_batched(key).unwrap();
        assert_eq!(loaded, bytes);
    }

    let f32_bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "f32", &f32_bytes, DType::Float32);
    let f64_bytes: Vec<u8> = [1.0f64, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "f64", &f64_bytes, DType::Float64);
    let i32_bytes: Vec<u8> = [1i32, 2, 3, 4]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "i32", &i32_bytes, DType::Int32);
    let i64_bytes: Vec<u8> = [1i64, 2, 3, 4]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "i64", &i64_bytes, DType::Int64);
}
#[test]
fn test_batched_multidimensional_shape() {
    // A 2x3x4 tensor must keep both its bytes and its full shape vector.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let shape = [2usize, 3, 4];
    let element_count: usize = shape.iter().product();
    let values: Vec<f64> = (0..element_count).map(|i| i as f64).collect();
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    engine
        .put_tensor_batched("3d", &bytes, &shape, DType::Float64)
        .unwrap();
    let (loaded, loaded_shape) = engine.get_tensor_batched("3d").unwrap();
    assert_eq!(loaded, bytes);
    assert_eq!(loaded_shape, vec![2, 3, 4]);
}
#[test]
fn test_put_get_tensor_mmap_small() {
    // A tiny f64 tensor round-trips through the mmap-backed storage path.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let values = [1.0f64, 2.0, 3.0, 4.0];
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    let count = engine
        .put_tensor_mmap("small", &bytes, &[4], DType::Float64)
        .unwrap();
    assert_eq!(count, 1);
    let (loaded, shape) = engine.get_tensor_mmap("small").unwrap();
    assert_eq!(loaded, bytes);
    assert_eq!(shape, vec![4]);
}
#[test]
fn test_put_get_tensor_mmap_large() {
    // 128Ki f64 values (1 MiB): verify raw bytes and decoded values survive.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let n = 128 * 1024;
    let values: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    let count = engine
        .put_tensor_mmap("large", &bytes, &[n], DType::Float64)
        .unwrap();
    assert_eq!(count, 1);
    let (loaded, shape) = engine.get_tensor_mmap("large").unwrap();
    assert_eq!(loaded, bytes);
    assert_eq!(shape, vec![n]);
    // Decode back to f64 and compare against the source values.
    let decoded: Vec<f64> = loaded
        .chunks(8)
        .map(|c| f64::from_le_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(decoded, values);
}
#[test]
fn test_get_tensor_mmap_ref() {
    // The mmap ref exposes shape, dtype, byte length, and a typed f64 view.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let values = [1.0f64, 2.0, 3.0, 4.0];
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    engine
        .put_tensor_mmap("ref_test", &bytes, &[4], DType::Float64)
        .unwrap();
    let view = engine.get_tensor_mmap_ref("ref_test").unwrap();
    assert_eq!(view.shape, vec![4]);
    assert_eq!(view.dtype, "float64");
    // 4 elements * 8 bytes each.
    assert_eq!(view.len(), 32);
    assert_eq!(view.as_f64_slice(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_get_tensor_mmap_meta() {
    // Metadata must record shape, dtype, size, and a backing-file path that
    // is derived from the tensor key.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let values = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    engine
        .put_tensor_mmap("meta_test", &bytes, &[2, 3], DType::Float64)
        .unwrap();
    let meta = engine.get_tensor_mmap_meta("meta_test").unwrap();
    assert_eq!(meta.shape, vec![2, 3]);
    assert_eq!(meta.dtype, "float64");
    // 6 elements * 8 bytes each.
    assert_eq!(meta.total_bytes, 48);
    assert!(meta.mmap_path.contains("meta_test"));
}
#[test]
fn test_mmap_shape_mismatch() {
    // 16 bytes cannot hold 4 f64 values, so the mmap put must be rejected.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let too_short = vec![0u8; 16];
    let result = engine.put_tensor_mmap("bad", &too_short, &[4], DType::Float64);
    assert!(matches!(result, Err(SynaError::ShapeMismatch { .. })));
}
#[test]
fn test_mmap_not_found() {
    // An unknown key on the mmap path must yield KeyNotFound.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    assert!(matches!(
        engine.get_tensor_mmap("nonexistent"),
        Err(SynaError::KeyNotFound(_))
    ));
}
#[test]
fn test_delete_tensor_mmap() {
    // After deletion, the tensor's metadata must no longer resolve.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let bytes: Vec<u8> = [1.0f64, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    engine
        .put_tensor_mmap("to_delete", &bytes, &[4], DType::Float64)
        .unwrap();
    assert!(engine.get_tensor_mmap_meta("to_delete").is_ok());
    engine.delete_tensor_mmap("to_delete").unwrap();
    assert!(engine.get_tensor_mmap_meta("to_delete").is_err());
}
#[test]
fn test_mmap_roundtrip_all_dtypes() {
    // Every supported dtype must survive an mmap put/get cycle byte-for-byte.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);

    // Stores `bytes` under `key` and asserts the loaded bytes are identical.
    fn roundtrip(engine: &mut TensorEngine, key: &str, bytes: &[u8], dtype: DType) {
        engine.put_tensor_mmap(key, bytes, &[4], dtype).unwrap();
        let (loaded, _) = engine.get_tensor_mmap(key).unwrap();
        assert_eq!(loaded, bytes);
    }

    let f32_bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "f32", &f32_bytes, DType::Float32);
    let f64_bytes: Vec<u8> = [1.0f64, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "f64", &f64_bytes, DType::Float64);
    let i32_bytes: Vec<u8> = [1i32, 2, 3, 4]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "i32", &i32_bytes, DType::Int32);
    let i64_bytes: Vec<u8> = [1i64, 2, 3, 4]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    roundtrip(&mut engine, "i64", &i64_bytes, DType::Int64);
}
#[test]
fn test_mmap_tensor_ref_slices() {
    // Each typed-slice accessor on the mmap tensor ref must reinterpret the
    // underlying little-endian bytes correctly for its dtype.
    let dir = tempdir().unwrap();
    let db_path = dir.path().join("test.db");
    let db = SynaDB::new(&db_path).unwrap();
    let mut engine = TensorEngine::new(db);
    let f32_values = vec![1.0f32, 2.0, 3.0, 4.0];
    let f32_data: Vec<u8> = f32_values.iter().flat_map(|f| f.to_le_bytes()).collect();
    engine
        .put_tensor_mmap("f32_ref", &f32_data, &[4], DType::Float32)
        .unwrap();
    let tensor_ref = engine.get_tensor_mmap_ref("f32_ref").unwrap();
    assert_eq!(tensor_ref.as_f32_slice(), &[1.0f32, 2.0, 3.0, 4.0]);
    // as_f64_slice was previously the only accessor not covered by this
    // test; added so all four typed views are exercised together.
    let f64_values = vec![1.5f64, 2.5, 3.5, 4.5];
    let f64_data: Vec<u8> = f64_values.iter().flat_map(|f| f.to_le_bytes()).collect();
    engine
        .put_tensor_mmap("f64_ref", &f64_data, &[4], DType::Float64)
        .unwrap();
    let tensor_ref = engine.get_tensor_mmap_ref("f64_ref").unwrap();
    assert_eq!(tensor_ref.as_f64_slice(), &[1.5f64, 2.5, 3.5, 4.5]);
    let i32_values = vec![10i32, 20, 30, 40];
    let i32_data: Vec<u8> = i32_values.iter().flat_map(|i| i.to_le_bytes()).collect();
    engine
        .put_tensor_mmap("i32_ref", &i32_data, &[4], DType::Int32)
        .unwrap();
    let tensor_ref = engine.get_tensor_mmap_ref("i32_ref").unwrap();
    assert_eq!(tensor_ref.as_i32_slice(), &[10i32, 20, 30, 40]);
    let i64_values = vec![100i64, 200, 300, 400];
    let i64_data: Vec<u8> = i64_values.iter().flat_map(|i| i.to_le_bytes()).collect();
    engine
        .put_tensor_mmap("i64_ref", &i64_data, &[4], DType::Int64)
        .unwrap();
    let tensor_ref = engine.get_tensor_mmap_ref("i64_ref").unwrap();
    assert_eq!(tensor_ref.as_i64_slice(), &[100i64, 200, 300, 400]);
}
#[test]
fn test_mmap_with_slashes_in_name() {
    // Hierarchical keys round-trip; slashes are flattened to underscores in
    // the backing mmap file path.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let key = "model/layer1/weights";
    let bytes: Vec<u8> = [1.0f64, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    engine
        .put_tensor_mmap(key, &bytes, &[4], DType::Float64)
        .unwrap();
    let (loaded, shape) = engine.get_tensor_mmap(key).unwrap();
    assert_eq!(loaded, bytes);
    assert_eq!(shape, vec![4]);
    let meta = engine.get_tensor_mmap_meta(key).unwrap();
    assert!(meta.mmap_path.contains("model_layer1_weights"));
}
#[test]
fn test_optimal_chunk_size() {
    // Boundary and interior points for each tier of optimal_chunk_size:
    // <=10 MB -> small, <=100 MB -> medium, above -> large.
    let cases = [
        (0, CHUNK_SIZE_SMALL),
        (1_000_000, CHUNK_SIZE_SMALL),
        (5_000_000, CHUNK_SIZE_SMALL),
        (10_000_000, CHUNK_SIZE_SMALL),
        (10_000_001, CHUNK_SIZE_MEDIUM),
        (50_000_000, CHUNK_SIZE_MEDIUM),
        (100_000_000, CHUNK_SIZE_MEDIUM),
        (100_000_001, CHUNK_SIZE_LARGE),
        (200_000_000, CHUNK_SIZE_LARGE),
        (1_000_000_000, CHUNK_SIZE_LARGE),
    ];
    for (tensor_bytes, expected) in cases {
        assert_eq!(
            optimal_chunk_size(tensor_bytes),
            expected,
            "optimal_chunk_size({tensor_bytes})"
        );
    }
}
#[test]
fn test_chunk_size_constants() {
    // Tier sizes are 1 MiB, 4 MiB, and 16 MiB respectively.
    assert_eq!(CHUNK_SIZE_SMALL, 1024 * 1024);
    assert_eq!(CHUNK_SIZE_MEDIUM, 4 * 1024 * 1024);
    assert_eq!(CHUNK_SIZE_LARGE, 16 * 1024 * 1024);
}
#[test]
fn test_put_tensor_optimized_small() {
    // A small tensor stored via the optimized path lands in one chunk and
    // records the small-tier chunk size in its metadata.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let bytes: Vec<u8> = [1.0f64, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    let chunk_count = engine
        .put_tensor_optimized("small_tensor", &bytes, &[4], DType::Float64)
        .unwrap();
    assert_eq!(chunk_count, 1);
    let (loaded, shape) = engine.get_tensor_chunked("small_tensor").unwrap();
    assert_eq!(loaded, bytes);
    assert_eq!(shape, vec![4]);
    let meta = engine.get_tensor_meta("small_tensor").unwrap();
    assert_eq!(meta.chunk_size, CHUNK_SIZE_SMALL);
}
#[test]
fn test_put_tensor_optimized_roundtrip() {
    // Tensors written via put_tensor_optimized must be readable back through
    // the plain chunked getter.
    let dir = tempdir().unwrap();
    let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
    let mut engine = TensorEngine::new(db);
    let f64_bytes: Vec<u8> = [1.0f64, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    engine
        .put_tensor_optimized("f64_opt", &f64_bytes, &[4], DType::Float64)
        .unwrap();
    let (loaded, _) = engine.get_tensor_chunked("f64_opt").unwrap();
    assert_eq!(loaded, f64_bytes);
    let i32_bytes: Vec<u8> = [1i32, 2, 3, 4]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    engine
        .put_tensor_optimized("i32_opt", &i32_bytes, &[4], DType::Int32)
        .unwrap();
    let (loaded, _) = engine.get_tensor_chunked("i32_opt").unwrap();
    assert_eq!(loaded, i32_bytes);
}
#[cfg(feature = "async")]
mod async_tests {
    use super::*;

    #[tokio::test]
    async fn test_put_tensor_async_small() {
        // A 4-element f64 tensor fits in a single 1024-byte chunk.
        let dir = tempdir().unwrap();
        let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
        let mut engine = TensorEngine::new(db);
        let values = [1.0f64, 2.0, 3.0, 4.0];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let chunks = engine
            .put_tensor_async("async_small", &bytes, &[4], DType::Float64, 1024)
            .await
            .unwrap();
        assert_eq!(chunks, 1);
        let (loaded, shape) = engine.get_tensor_chunked("async_small").unwrap();
        assert_eq!(loaded, bytes);
        assert_eq!(shape, vec![4]);
    }

    #[tokio::test]
    async fn test_put_tensor_async_large() {
        // 10k f64 values in 1 KiB chunks: the reported chunk count must
        // match the ceiling division of bytes by chunk size.
        let dir = tempdir().unwrap();
        let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
        let mut engine = TensorEngine::new(db);
        let n = 10_000;
        let values: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let chunk_size = 1024;
        let chunks = engine
            .put_tensor_async("async_large", &bytes, &[n], DType::Float64, chunk_size)
            .await
            .unwrap();
        assert_eq!(chunks, bytes.len().div_ceil(chunk_size));
        let (loaded, shape) = engine.get_tensor_chunked("async_large").unwrap();
        assert_eq!(loaded, bytes);
        assert_eq!(shape, vec![n]);
    }

    #[tokio::test]
    async fn test_put_tensor_async_default() {
        // The default-chunk-size variant stores small tensors as one chunk.
        let dir = tempdir().unwrap();
        let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
        let mut engine = TensorEngine::new(db);
        let values = [1.0f64, 2.0, 3.0, 4.0];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let chunks = engine
            .put_tensor_async_default("async_default", &bytes, &[4], DType::Float64)
            .await
            .unwrap();
        assert_eq!(chunks, 1);
        let (loaded, shape) = engine.get_tensor_chunked("async_default").unwrap();
        assert_eq!(loaded, bytes);
        assert_eq!(shape, vec![4]);
    }

    #[tokio::test]
    async fn test_put_tensor_async_shape_mismatch() {
        // 16 bytes is too short for 4 f64 values; the async put must reject it.
        let dir = tempdir().unwrap();
        let db = SynaDB::new(&dir.path().join("test.db")).unwrap();
        let mut engine = TensorEngine::new(db);
        let too_short = vec![0u8; 16];
        let result = engine
            .put_tensor_async("mismatch", &too_short, &[4], DType::Float64, 1024)
            .await;
        assert!(matches!(result, Err(SynaError::ShapeMismatch { .. })));
    }
}
mod direct_io_tests {
    use super::super::direct_io::*;
    use tempfile::tempdir;

    #[test]
    fn test_is_direct_io_available() {
        // Direct I/O availability is only expected on Linux.
        let available = is_direct_io_available();
        #[cfg(target_os = "linux")]
        assert!(available);
        #[cfg(not(target_os = "linux"))]
        assert!(!available);
    }

    #[test]
    fn test_align_size() {
        // Sizes round up to the next multiple of DIRECT_IO_ALIGNMENT; zero
        // stays zero and exact multiples are unchanged.
        let cases = [
            (0, 0),
            (1, DIRECT_IO_ALIGNMENT),
            (100, DIRECT_IO_ALIGNMENT),
            (4096, 4096),
            (4097, 8192),
            (8192, 8192),
            (10000, 12288),
        ];
        for (input, expected) in cases {
            assert_eq!(align_size(input), expected, "align_size({input})");
        }
    }

    #[test]
    fn test_create_aligned_buffer() {
        // Requested sizes round up to the alignment; len equals the aligned
        // size so the whole buffer is immediately usable.
        let small = create_aligned_buffer(100);
        assert!(small.capacity() >= DIRECT_IO_ALIGNMENT);
        assert_eq!(small.len(), DIRECT_IO_ALIGNMENT);
        let larger = create_aligned_buffer(5000);
        assert!(larger.capacity() >= 8192);
        assert_eq!(larger.len(), 8192);
    }

    #[test]
    fn test_should_use_direct_io() {
        // Below the minimum size, direct I/O is never chosen; at or above
        // it, the answer depends on platform support.
        for size in [100, 1000, DIRECT_IO_MIN_SIZE - 1] {
            assert!(!should_use_direct_io(size));
        }
        #[cfg(target_os = "linux")]
        {
            assert!(should_use_direct_io(DIRECT_IO_MIN_SIZE));
            assert!(should_use_direct_io(10_000_000));
        }
        #[cfg(not(target_os = "linux"))]
        {
            assert!(!should_use_direct_io(DIRECT_IO_MIN_SIZE));
            assert!(!should_use_direct_io(10_000_000));
        }
    }

    #[test]
    fn test_open_direct_write() {
        // Opening for direct write succeeds and creates the file on disk.
        let dir = tempdir().unwrap();
        let path = dir.path().join("direct_test.bin");
        assert!(open_direct(&path).is_ok());
        assert!(path.exists());
    }

    #[test]
    fn test_write_aligned() {
        // A 1000-byte payload is padded up to one alignment unit on disk.
        let dir = tempdir().unwrap();
        let path = dir.path().join("aligned_write.bin");
        let mut file = open_direct(&path).unwrap();
        let payload = vec![42u8; 1000];
        let written = write_aligned(&mut file, &payload).unwrap();
        assert_eq!(written, DIRECT_IO_ALIGNMENT);
        drop(file);
        let on_disk = std::fs::metadata(&path).unwrap().len();
        assert_eq!(on_disk as usize, DIRECT_IO_ALIGNMENT);
    }

    #[test]
    fn test_write_aligned_exact() {
        // An already-aligned payload is written without extra padding.
        let dir = tempdir().unwrap();
        let path = dir.path().join("aligned_exact.bin");
        let mut file = open_direct(&path).unwrap();
        let payload = vec![42u8; DIRECT_IO_ALIGNMENT];
        let written = write_aligned(&mut file, &payload).unwrap();
        assert_eq!(written, DIRECT_IO_ALIGNMENT);
    }

    #[test]
    fn test_open_direct_read() {
        // An existing, alignment-sized file can be opened for direct reads.
        let dir = tempdir().unwrap();
        let path = dir.path().join("read_test.bin");
        std::fs::write(&path, vec![0u8; 4096]).unwrap();
        assert!(open_direct_read(&path).is_ok());
    }
}
}