pub mod optimizer;
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use scirs2_core::ndarray::{s, IxDyn};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
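/// Policy used to decide which pooled tensors to evict when the pool
/// exceeds its size budget.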
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryEvictionPolicy {
LRU,
LFU,
SizeBased,
ARC,
Hybrid,
}
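/// Strategy for adjusting the pool's dynamic size limit at runtime.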
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdaptiveStrategy {
Fixed,
MemoryPressure,
HitRate,
Predictive,
}
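/// Configuration for the tensor memory pool, zero-copy views, and
/// memory-mapped tensor loading.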
#[derive(Debug, Clone)]
pub struct MemoryConfig {
pub enable_memory_pool: bool,
pub max_pool_size: usize,
pub min_pool_size: usize,
pub enable_zero_copy: bool,
pub enable_mmap: bool,
pub mmap_threshold: usize,
pub cleanup_interval: Duration,
pub eviction_policy: MemoryEvictionPolicy,
pub adaptive_strategy: AdaptiveStrategy,
pub target_hit_rate: f64,
pub enable_prefetching: bool,
pub enable_defragmentation: bool,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
enable_memory_pool: true,
max_pool_size: 1024 * 1024 * 1024,
min_pool_size: 64 * 1024 * 1024,
enable_zero_copy: true,
enable_mmap: true,
mmap_threshold: 100 * 1024 * 1024,
cleanup_interval: Duration::from_secs(60),
eviction_policy: MemoryEvictionPolicy::Hybrid,
adaptive_strategy: AdaptiveStrategy::HitRate,
target_hit_rate: 0.85,
enable_prefetching: true,
enable_defragmentation: true,
}
}
}
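/// A pooled tensor plus the bookkeeping needed for eviction decisions.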
#[derive(Debug, Clone)]
struct PoolEntry {
tensor: Tensor,
last_used: Instant,
ref_count: usize,
access_count: usize,
#[allow(dead_code)]
created_at: Instant,
#[allow(dead_code)]
pool_time: Duration,
size_bytes: usize,
}
impl PoolEntry {
fn new(tensor: Tensor, size_bytes: usize) -> Self {
let now = Instant::now();
Self {
tensor,
last_used: now,
ref_count: 0,
access_count: 0,
created_at: now,
pool_time: Duration::ZERO,
size_bytes,
}
}
fn mark_accessed(&mut self) {
self.last_used = Instant::now();
self.access_count += 1;
}
/// Computes a retention score for this entry under the given policy.
/// Callers sort ascending and evict the lowest-scoring entries first,
/// so a smaller value means "evict sooner".
fn eviction_priority(&self, policy: MemoryEvictionPolicy) -> f64 {
match policy {
// Least recently used: the longer since the last access, the lower the score.
MemoryEvictionPolicy::LRU => -(self.last_used.elapsed().as_secs_f64()),
// Least frequently used: fewer accesses means a lower score.
MemoryEvictionPolicy::LFU => self.access_count as f64,
// Size-based: larger entries score lower so they are evicted first.
MemoryEvictionPolicy::SizeBased => -(self.size_bytes as f64),
// ARC-style: combine recency and frequency; cold entries score lowest.
MemoryEvictionPolicy::ARC => {
let recency_score = 1.0 / (1.0 + self.last_used.elapsed().as_secs_f64());
let frequency_score = self.access_count as f64;
recency_score + frequency_score
},
// Hybrid: weighted mix of recency, frequency, and a small-size bonus.
MemoryEvictionPolicy::Hybrid => {
let recency = 1.0 / (1.0 + self.last_used.elapsed().as_secs_f64());
let frequency = self.access_count as f64;
let size_factor = 1.0 / (1.0 + (self.size_bytes as f64 / 1_000_000.0));
recency * 0.4 + frequency * 0.4 + size_factor * 0.2
},
}
}
}
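/// A view over a contiguous region of a shared tensor. The view itself only
/// holds an `Arc` to the original; `as_tensor` materializes the viewed region
/// as a new tensor (currently implemented for `Tensor::F32` only).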
#[derive(Debug)]
pub struct TensorView {
original: Arc<Tensor>,
offset: usize,
shape: Vec<usize>,
#[allow(dead_code)]
strides: Vec<usize>,
}
impl TensorView {
pub fn slice(tensor: Arc<Tensor>, start: usize, end: usize) -> Result<Self> {
let original_shape = tensor.shape();
if start >= end || end > original_shape.iter().product::<usize>() {
return Err(TrustformersError::invalid_input(
"Invalid slice bounds".to_string(),
));
}
let slice_len = end - start;
Ok(Self {
original: tensor,
offset: start,
shape: vec![slice_len],
strides: vec![1],
})
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn as_tensor(&self) -> Result<Tensor> {
match &*self.original {
Tensor::F32(arr) => {
let flat = arr
.view()
.into_shape_with_order(arr.len())
.map_err(|e| TrustformersError::shape_error(e.to_string()))?;
let slice = flat.slice(s![
self.offset..self.offset + self.shape.iter().product::<usize>()
]);
let sliced_arr = slice
.to_owned()
.into_shape_with_order(IxDyn(&self.shape))
.map_err(|e| TrustformersError::shape_error(e.to_string()))?;
Ok(Tensor::F32(sliced_arr))
},
_ => Err(TrustformersError::tensor_op_error(
"Zero-copy slicing not implemented for this tensor type",
"zero_copy_slice",
)),
}
}
}
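/// Running counters used to derive hit/miss rates and eviction totals.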
#[derive(Debug, Clone)]
struct PoolStatistics {
total_requests: usize,
cache_hits: usize,
cache_misses: usize,
total_evictions: usize,
evictions_by_policy: HashMap<String, usize>,
total_allocated_bytes: usize,
peak_memory_usage: usize,
#[allow(dead_code)]
average_tensor_lifetime: Duration,
#[allow(dead_code)]
last_reset: Instant,
}
impl Default for PoolStatistics {
fn default() -> Self {
Self {
total_requests: 0,
cache_hits: 0,
cache_misses: 0,
total_evictions: 0,
evictions_by_policy: HashMap::new(),
total_allocated_bytes: 0,
peak_memory_usage: 0,
average_tensor_lifetime: Duration::ZERO,
last_reset: Instant::now(),
}
}
}
impl PoolStatistics {
fn hit_rate(&self) -> f64 {
if self.total_requests == 0 {
0.0
} else {
self.cache_hits as f64 / self.total_requests as f64
}
}
fn miss_rate(&self) -> f64 {
if self.total_requests == 0 {
0.0
} else {
self.cache_misses as f64 / self.total_requests as f64
}
}
}
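/// Shape-keyed pool of reusable tensors with configurable eviction and an
/// adaptively sized byte budget.
///
/// Minimal usage sketch (doc-test ignored; the exact `DType` path is assumed):
///
/// ```ignore
/// let pool = TensorMemoryPool::new(MemoryConfig::default());
/// let t = pool.get_tensor(&[2, 3], crate::tensor::DType::F32)?;
/// pool.return_tensor(t)?;
/// println!("hit rate: {:.2}", pool.hit_rate());
/// ```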
pub struct TensorMemoryPool {
config: MemoryConfig,
pool: Arc<RwLock<HashMap<Vec<usize>, Vec<PoolEntry>>>>,
current_size: Arc<Mutex<usize>>,
last_cleanup: Arc<Mutex<Instant>>,
statistics: Arc<Mutex<PoolStatistics>>,
access_patterns: Arc<Mutex<HashMap<Vec<usize>, Vec<Instant>>>>,
dynamic_max_size: Arc<Mutex<usize>>,
}
impl TensorMemoryPool {
pub fn new(config: MemoryConfig) -> Self {
let dynamic_max_size = config.max_pool_size;
Self {
config,
pool: Arc::new(RwLock::new(HashMap::new())),
current_size: Arc::new(Mutex::new(0)),
last_cleanup: Arc::new(Mutex::new(Instant::now())),
statistics: Arc::new(Mutex::new(PoolStatistics::default())),
access_patterns: Arc::new(Mutex::new(HashMap::new())),
dynamic_max_size: Arc::new(Mutex::new(dynamic_max_size)),
}
}
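/// Returns a tensor of the requested shape and dtype, reusing a pooled tensor
/// when one is available. Reused tensors keep their previous contents, so
/// callers that need zeroed data should clear them explicitly.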
pub fn get_tensor(&self, shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
if self.config.enable_prefetching {
let mut patterns = self.access_patterns.lock().expect("lock should not be poisoned");
patterns.entry(shape.to_vec()).or_default().push(Instant::now());
}
{
let mut stats = self.statistics.lock().expect("lock should not be poisoned");
stats.total_requests += 1;
}
if !self.config.enable_memory_pool {
return self.create_tensor(shape, dtype);
}
if let Some(tensor) = self.try_get_from_pool(shape)? {
let mut stats = self.statistics.lock().expect("lock should not be poisoned");
stats.cache_hits += 1;
return Ok(tensor);
}
{
let mut stats = self.statistics.lock().expect("lock should not be poisoned");
stats.cache_misses += 1;
}
self.apply_adaptive_sizing()?;
self.create_tensor(shape, dtype)
}
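/// Returns a tensor to the pool for later reuse and updates size accounting;
/// a no-op when pooling is disabled.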
pub fn return_tensor(&self, tensor: Tensor) -> Result<()> {
if !self.config.enable_memory_pool {
return Ok(());
}
let shape = tensor.shape().to_vec();
let tensor_size = self.estimate_tensor_size(&tensor);
let entry = PoolEntry::new(tensor, tensor_size);
let mut pool = self.pool.write().expect("lock should not be poisoned");
pool.entry(shape).or_default().push(entry);
{
let mut current = self.current_size.lock().expect("lock should not be poisoned");
*current += tensor_size;
let mut stats = self.statistics.lock().expect("lock should not be poisoned");
if *current > stats.peak_memory_usage {
stats.peak_memory_usage = *current;
}
stats.total_allocated_bytes += tensor_size;
}
self.cleanup_if_needed()?;
Ok(())
}
fn try_get_from_pool(&self, shape: &[usize]) -> Result<Option<Tensor>> {
let mut pool = self.pool.write().expect("lock should not be poisoned");
if let Some(entries) = pool.get_mut(shape) {
if let Some(mut entry) = entries.pop() {
entry.mark_accessed();
let tensor_size = entry.size_bytes;
*self.current_size.lock().expect("lock should not be poisoned") -= tensor_size;
return Ok(Some(entry.tensor));
}
}
Ok(None)
}
fn create_tensor(&self, shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
match dtype {
crate::tensor::DType::F32 => Tensor::zeros(shape),
crate::tensor::DType::F64 => Tensor::zeros_f64(shape),
crate::tensor::DType::F16 => Tensor::zeros_f16(shape),
crate::tensor::DType::BF16 => Tensor::zeros_bf16(shape),
crate::tensor::DType::I64 => Tensor::zeros_i64(shape),
crate::tensor::DType::C32 => Tensor::zeros_c32(shape),
crate::tensor::DType::C64 => Tensor::zeros_c64(shape),
crate::tensor::DType::CF16 => Tensor::zeros_cf16(shape),
crate::tensor::DType::CBF16 => Tensor::zeros_cbf16(shape),
_ => Err(TrustformersError::tensor_op_error(
&format!("Tensor creation not implemented for dtype: {:?} - only supported types are F32, F64, F16, BF16, I64, C32, C64, CF16, CBF16", dtype),
"create_tensor"
)),
}
}
fn estimate_tensor_size(&self, tensor: &Tensor) -> usize {
let elements = tensor.shape().iter().product::<usize>();
match tensor {
Tensor::F32(_) => elements * 4,
Tensor::F64(_) => elements * 8,
Tensor::F16(_) => elements * 2,
Tensor::BF16(_) => elements * 2,
Tensor::I64(_) => elements * 8,
Tensor::C32(_) => elements * 8,
Tensor::C64(_) => elements * 16,
Tensor::CF16(_) => elements * 4,
Tensor::CBF16(_) => elements * 4,
// Backend tensors are estimated assuming 4-byte elements.
#[cfg(feature = "torch")]
Tensor::Torch(_) => elements * 4,
#[cfg(feature = "candle")]
Tensor::Candle(_) => elements * 4,
#[cfg(all(target_os = "macos", feature = "metal"))]
Tensor::Metal(data) => elements * data.dtype.size_in_bytes(),
#[cfg(feature = "cuda")]
Tensor::CUDA(data) => elements * data.dtype.size_in_bytes(),
Tensor::Sparse(sparse) => {
let nnz = sparse.nnz();
// Rough estimate: f32 values plus one index per non-zero element.
nnz * 4 + nnz * std::mem::size_of::<usize>()
},
}
}
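/// Evicts unreferenced pooled tensors when the cleanup interval has elapsed
/// or the pool has grown past its dynamic size limit.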
fn cleanup_if_needed(&self) -> Result<()> {
let mut last_cleanup = self.last_cleanup.lock().expect("lock should not be poisoned");
let should_cleanup_time = last_cleanup.elapsed() >= self.config.cleanup_interval;
let current_size = *self.current_size.lock().expect("lock should not be poisoned");
let dynamic_max = *self.dynamic_max_size.lock().expect("lock should not be poisoned");
let should_cleanup_size = current_size > dynamic_max;
if !should_cleanup_time && !should_cleanup_size {
return Ok(());
}
let mut pool = self.pool.write().expect("lock should not be poisoned");
let mut total_freed = 0;
let mut eviction_count = 0;
let policy = self.config.eviction_policy;
// Free down to 85% of the dynamic cap; nothing is evicted if we are already under it.
let target_size = (dynamic_max as f64 * 0.85) as usize;
let need_to_free = current_size.saturating_sub(target_size);
// Collect unreferenced entries together with their retention score and size.
let mut all_entries: Vec<(Vec<usize>, usize, f64, usize)> = Vec::new();
for (shape, entries) in pool.iter() {
for (idx, entry) in entries.iter().enumerate() {
if entry.ref_count == 0 {
let priority = entry.eviction_priority(policy);
all_entries.push((shape.clone(), idx, priority, entry.size_bytes));
}
}
}
// Lowest-scoring entries are evicted first.
all_entries.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
let mut freed_so_far = 0;
let mut indices_to_remove: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
for (shape, idx, _, size) in all_entries.iter() {
if freed_so_far >= need_to_free {
break;
}
freed_so_far += size;
total_freed += size;
eviction_count += 1;
indices_to_remove.entry(shape.clone()).or_default().push(*idx);
}
// Remove selected entries, highest index first so earlier indices stay valid.
for (shape, mut indices) in indices_to_remove {
if let Some(entries) = pool.get_mut(&shape) {
indices.sort_unstable_by(|a, b| b.cmp(a));
for idx in indices {
if idx < entries.len() {
entries.remove(idx);
}
}
}
}
pool.retain(|_, entries| !entries.is_empty());
drop(pool);
{
let mut stats = self.statistics.lock().expect("lock should not be poisoned");
stats.total_evictions += eviction_count;
*stats.evictions_by_policy.entry(format!("{:?}", policy)).or_insert(0) +=
eviction_count;
}
*self.current_size.lock().expect("lock should not be poisoned") -= total_freed;
*last_cleanup = Instant::now();
if self.config.enable_defragmentation {
self.defragment_pool()?;
}
Ok(())
}
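/// Adjusts the dynamic size limit according to the configured strategy.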
fn apply_adaptive_sizing(&self) -> Result<()> {
match self.config.adaptive_strategy {
AdaptiveStrategy::Fixed => Ok(()),
AdaptiveStrategy::HitRate => self.adapt_by_hit_rate(),
AdaptiveStrategy::MemoryPressure => self.adapt_by_memory_pressure(),
AdaptiveStrategy::Predictive => self.adapt_by_prediction(),
}
}
fn adapt_by_hit_rate(&self) -> Result<()> {
let stats = self.statistics.lock().expect("lock should not be poisoned");
let hit_rate = stats.hit_rate();
drop(stats);
let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
let target_rate = self.config.target_hit_rate;
if hit_rate < target_rate {
let increase = (*dynamic_max as f64 * 0.1) as usize;
let new_size = (*dynamic_max + increase).min(self.config.max_pool_size);
if new_size > *dynamic_max {
*dynamic_max = new_size;
}
} else if hit_rate > target_rate + 0.1 {
let decrease = (*dynamic_max as f64 * 0.05) as usize;
let new_size = (*dynamic_max - decrease).max(self.config.min_pool_size);
if new_size < *dynamic_max {
*dynamic_max = new_size;
}
}
Ok(())
}
fn adapt_by_memory_pressure(&self) -> Result<()> {
let current_size = *self.current_size.lock().expect("lock should not be poisoned");
let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
let utilization = current_size as f64 / *dynamic_max as f64;
if utilization > 0.9 {
let new_size = (*dynamic_max as f64 * 0.9) as usize;
*dynamic_max = new_size.max(self.config.min_pool_size);
} else if utilization < 0.5 {
let new_size = (*dynamic_max as f64 * 1.1) as usize;
*dynamic_max = new_size.min(self.config.max_pool_size);
}
Ok(())
}
fn adapt_by_prediction(&self) -> Result<()> {
let patterns = self.access_patterns.lock().expect("lock should not be poisoned");
let mut total_recent_accesses = 0;
let recent_window = Duration::from_secs(60);
let now = Instant::now();
for timestamps in patterns.values() {
total_recent_accesses +=
timestamps.iter().filter(|t| now.duration_since(**t) < recent_window).count();
}
drop(patterns);
let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
if total_recent_accesses > 1000 {
let new_size = (*dynamic_max as f64 * 1.15) as usize;
*dynamic_max = new_size.min(self.config.max_pool_size);
} else if total_recent_accesses < 100 {
let new_size = (*dynamic_max as f64 * 0.9) as usize;
*dynamic_max = new_size.max(self.config.min_pool_size);
}
Ok(())
}
fn defragment_pool(&self) -> Result<()> {
let mut pool = self.pool.write().expect("lock should not be poisoned");
for entries in pool.values_mut() {
entries.sort_by_key(|entry| std::cmp::Reverse(entry.access_count));
}
Ok(())
}
pub fn get_stats(&self) -> MemoryPoolStats {
let pool = self.pool.read().expect("lock should not be poisoned");
let current_size = *self.current_size.lock().expect("lock should not be poisoned");
let stats = self.statistics.lock().expect("lock should not be poisoned");
let dynamic_max = *self.dynamic_max_size.lock().expect("lock should not be poisoned");
let total_tensors = pool.values().map(|v| v.len()).sum();
let total_shapes = pool.len();
MemoryPoolStats {
total_tensors,
total_shapes,
current_size_bytes: current_size,
max_size_bytes: self.config.max_pool_size,
dynamic_max_size_bytes: dynamic_max,
utilization: current_size as f64 / dynamic_max as f64,
hit_rate: stats.hit_rate(),
miss_rate: stats.miss_rate(),
total_requests: stats.total_requests,
cache_hits: stats.cache_hits,
cache_misses: stats.cache_misses,
total_evictions: stats.total_evictions,
peak_memory_usage_bytes: stats.peak_memory_usage,
eviction_policy: self.config.eviction_policy,
adaptive_strategy: self.config.adaptive_strategy,
}
}
pub fn reset_statistics(&self) {
let mut stats = self.statistics.lock().expect("lock should not be poisoned");
*stats = PoolStatistics::default();
}
pub fn hit_rate(&self) -> f64 {
let stats = self.statistics.lock().expect("lock should not be poisoned");
stats.hit_rate()
}
pub fn eviction_policy(&self) -> MemoryEvictionPolicy {
self.config.eviction_policy
}
pub fn adaptive_strategy(&self) -> AdaptiveStrategy {
self.config.adaptive_strategy
}
pub fn get_predicted_shapes(&self, window: Duration) -> Vec<Vec<usize>> {
let patterns = self.access_patterns.lock().expect("lock should not be poisoned");
let now = Instant::now();
let mut frequent_shapes: Vec<(Vec<usize>, usize)> = patterns
.iter()
.map(|(shape, timestamps)| {
let count = timestamps.iter().filter(|t| now.duration_since(**t) < window).count();
(shape.clone(), count)
})
.filter(|(_, count)| *count > 0)
.collect();
frequent_shapes.sort_by_key(|item| std::cmp::Reverse(item.1));
frequent_shapes.into_iter().map(|(shape, _)| shape).collect()
}
}
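/// Snapshot of pool contents, size accounting, and hit/miss statistics
/// returned by [`TensorMemoryPool::get_stats`].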
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
pub total_tensors: usize,
pub total_shapes: usize,
pub current_size_bytes: usize,
pub max_size_bytes: usize,
pub dynamic_max_size_bytes: usize,
pub utilization: f64,
pub hit_rate: f64,
pub miss_rate: f64,
pub total_requests: usize,
pub cache_hits: usize,
pub cache_misses: usize,
pub total_evictions: usize,
pub peak_memory_usage_bytes: usize,
pub eviction_policy: MemoryEvictionPolicy,
pub adaptive_strategy: AdaptiveStrategy,
}
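/// Describes a tensor stored in a file on disk. Despite the name, `load`
/// currently reads the whole file into memory and decodes it into a `Tensor`
/// rather than mapping it directly.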
pub struct MemoryMappedTensor {
file_path: String,
shape: Vec<usize>,
dtype: crate::tensor::DType,
_file: Option<File>,
file_size: u64,
}
impl MemoryMappedTensor {
pub fn new(file_path: String, shape: Vec<usize>, dtype: crate::tensor::DType) -> Result<Self> {
let mut file = File::open(&file_path).map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to open file for memory mapping: {}", e),
"mmap_new",
)
})?;
let file_size = file.seek(SeekFrom::End(0)).map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to get file size: {}", e),
"mmap_new",
)
})?;
let element_size = dtype.size_in_bytes();
let total_elements: usize = shape.iter().product();
let expected_size = total_elements * element_size;
if file_size != expected_size as u64 {
return Err(TrustformersError::tensor_op_error(
&format!(
"File size {} doesn't match expected tensor size {}",
file_size, expected_size
),
"mmap_new",
));
}
Ok(Self {
file_path,
shape,
dtype,
_file: Some(file),
file_size,
})
}
pub fn load(&self) -> Result<Tensor> {
let mut file = File::open(&self.file_path).map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to open file for reading: {}", e),
"mmap_load",
)
})?;
let mut buffer = vec![0u8; self.file_size as usize];
file.read_exact(&mut buffer).map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to read file data: {}", e),
"mmap_load",
)
})?;
match self.dtype {
crate::tensor::DType::F32 => {
let float_data = buffer
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect::<Vec<f32>>();
Tensor::from_slice(&float_data, &self.shape)
},
crate::tensor::DType::F64 => {
let float_data = buffer
.chunks_exact(8)
.map(|chunk| {
f64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect::<Vec<f64>>();
Tensor::from_slice_f64(&float_data, &self.shape)
},
crate::tensor::DType::I64 => {
let int_data = buffer
.chunks_exact(8)
.map(|chunk| {
i64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect::<Vec<i64>>();
Tensor::from_slice_i64(&int_data, &self.shape)
},
crate::tensor::DType::I32 => {
let int_data = buffer
.chunks_exact(4)
.map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect::<Vec<i32>>();
Tensor::from_slice_i32(&int_data, &self.shape)
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported dtype for memory mapped tensor",
"mmap_load",
)),
}
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn file_path(&self) -> &str {
&self.file_path
}
}
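/// Process-wide tensor pool, set once via [`init_memory_manager`].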
static MEMORY_MANAGER: std::sync::OnceLock<TensorMemoryPool> = std::sync::OnceLock::new();
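/// Initializes the global memory manager; fails if it was already initialized.
///
/// Sketch of intended use (doc-test ignored; the exact `DType` path is assumed):
///
/// ```ignore
/// init_memory_manager(MemoryConfig::default())?;
/// let t = get_tensor(&[4, 4], crate::tensor::DType::F32)?;
/// return_tensor(t)?;
/// ```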
pub fn init_memory_manager(config: MemoryConfig) -> Result<()> {
let pool = TensorMemoryPool::new(config);
MEMORY_MANAGER.set(pool).map_err(|_| {
TrustformersError::invalid_input("Memory manager already initialized".to_string())
})?;
Ok(())
}
pub fn get_memory_manager() -> Option<&'static TensorMemoryPool> {
MEMORY_MANAGER.get()
}
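/// Convenience wrapper that uses the global pool when initialized and falls
/// back to direct allocation (F32, F64, and I64 only) otherwise.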
pub fn get_tensor(shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
if let Some(manager) = get_memory_manager() {
manager.get_tensor(shape, dtype)
} else {
match dtype {
crate::tensor::DType::F32 => Tensor::zeros(shape),
crate::tensor::DType::F64 => Tensor::zeros_f64(shape),
crate::tensor::DType::I64 => Tensor::zeros_i64(shape),
_ => Err(TrustformersError::tensor_op_error(
"Unsupported dtype",
"get_tensor",
)),
}
}
}
pub fn return_tensor(tensor: Tensor) -> Result<()> {
if let Some(manager) = get_memory_manager() {
manager.return_tensor(tensor)
} else {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_memory_config_default() {
let config = MemoryConfig::default();
assert!(config.enable_memory_pool);
assert!(config.enable_zero_copy);
assert!(config.enable_mmap);
assert_eq!(config.max_pool_size, 1024 * 1024 * 1024);
}
#[test]
fn test_tensor_pool_creation() {
let config = MemoryConfig::default();
let pool = TensorMemoryPool::new(config);
let stats = pool.get_stats();
assert_eq!(stats.total_tensors, 0);
assert_eq!(stats.current_size_bytes, 0);
}
#[test]
fn test_tensor_pool_get_and_return() -> Result<()> {
let config = MemoryConfig::default();
let pool = TensorMemoryPool::new(config);
let shape = vec![2, 3];
let tensor = pool.get_tensor(&shape, crate::tensor::DType::F32)?;
assert_eq!(tensor.shape(), shape.as_slice());
pool.return_tensor(tensor)?;
let tensor2 = pool.get_tensor(&shape, crate::tensor::DType::F32)?;
assert_eq!(tensor2.shape(), shape.as_slice());
Ok(())
}
#[test]
fn test_zero_copy_tensor_view() -> Result<()> {
let tensor = Arc::new(Tensor::ones(&[10])?);
let view = TensorView::slice(tensor, 2, 8)?;
assert_eq!(view.shape(), &[6]);
let viewed_tensor = view.as_tensor()?;
assert_eq!(viewed_tensor.shape(), &[6]);
Ok(())
}
#[test]
fn test_memory_mapped_tensor() -> Result<()> {
use std::fs::File;
use std::io::Write;
let temp_file = "test_temp.bin";
let data_size = 100 * 100 * std::mem::size_of::<f32>();
let data: Vec<u8> = vec![0; data_size];
{
let mut file = File::create(temp_file).map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to create test file: {}", e),
"test_setup",
)
})?;
file.write_all(&data).map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to write test data: {}", e),
"test_setup",
)
})?;
}
let mmap_tensor = MemoryMappedTensor::new(
temp_file.to_string(),
vec![100, 100],
crate::tensor::DType::F32,
)?;
assert_eq!(mmap_tensor.shape(), &[100, 100]);
assert_eq!(mmap_tensor.file_path(), temp_file);
let loaded = mmap_tensor.load()?;
assert_eq!(loaded.shape(), &[100, 100]);
std::fs::remove_file(temp_file).ok();
Ok(())
}
#[test]
fn test_global_memory_manager() -> Result<()> {
let config = MemoryConfig::default();
init_memory_manager(config)?;
let tensor = get_tensor(&[5, 5], crate::tensor::DType::F32)?;
assert_eq!(tensor.shape(), [5, 5].as_slice());
return_tensor(tensor)?;
Ok(())
}
#[test]
fn test_memory_config_custom_values() {
let config = MemoryConfig {
enable_memory_pool: false,
max_pool_size: 512 * 1024 * 1024,
min_pool_size: 32 * 1024 * 1024,
enable_zero_copy: false,
enable_mmap: false,
mmap_threshold: 50 * 1024 * 1024,
cleanup_interval: Duration::from_secs(30),
eviction_policy: MemoryEvictionPolicy::LRU,
adaptive_strategy: AdaptiveStrategy::Fixed,
target_hit_rate: 0.9,
enable_prefetching: false,
enable_defragmentation: false,
};
assert!(!config.enable_memory_pool);
assert_eq!(config.max_pool_size, 512 * 1024 * 1024);
assert_eq!(config.eviction_policy, MemoryEvictionPolicy::LRU);
assert_eq!(config.adaptive_strategy, AdaptiveStrategy::Fixed);
}
#[test]
fn test_memory_eviction_policy_lru() {
let config = MemoryConfig {
eviction_policy: MemoryEvictionPolicy::LRU,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.eviction_policy(), MemoryEvictionPolicy::LRU);
}
#[test]
fn test_memory_eviction_policy_lfu() {
let config = MemoryConfig {
eviction_policy: MemoryEvictionPolicy::LFU,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.eviction_policy(), MemoryEvictionPolicy::LFU);
}
#[test]
fn test_memory_eviction_policy_size_based() {
let config = MemoryConfig {
eviction_policy: MemoryEvictionPolicy::SizeBased,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.eviction_policy(), MemoryEvictionPolicy::SizeBased);
}
#[test]
fn test_memory_eviction_policy_arc() {
let config = MemoryConfig {
eviction_policy: MemoryEvictionPolicy::ARC,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.eviction_policy(), MemoryEvictionPolicy::ARC);
}
#[test]
fn test_adaptive_strategy_fixed() {
let config = MemoryConfig {
adaptive_strategy: AdaptiveStrategy::Fixed,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.adaptive_strategy(), AdaptiveStrategy::Fixed);
}
#[test]
fn test_adaptive_strategy_memory_pressure() {
let config = MemoryConfig {
adaptive_strategy: AdaptiveStrategy::MemoryPressure,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.adaptive_strategy(), AdaptiveStrategy::MemoryPressure);
}
#[test]
fn test_adaptive_strategy_predictive() {
let config = MemoryConfig {
adaptive_strategy: AdaptiveStrategy::Predictive,
..Default::default()
};
let pool = TensorMemoryPool::new(config);
assert_eq!(pool.adaptive_strategy(), AdaptiveStrategy::Predictive);
}
#[test]
fn test_pool_stats_initial_zero() {
let pool = TensorMemoryPool::new(MemoryConfig::default());
let stats = pool.get_stats();
assert_eq!(stats.total_tensors, 0);
assert_eq!(stats.current_size_bytes, 0);
assert_eq!(stats.cache_hits, 0);
assert_eq!(stats.cache_misses, 0);
}
#[test]
fn test_pool_initial_hit_rate() {
let pool = TensorMemoryPool::new(MemoryConfig::default());
let hr = pool.hit_rate();
assert!(
hr == 0.0 || hr.is_nan(),
"initial hit rate should be 0.0 or NaN, got {hr}"
);
}
#[test]
fn test_pool_multiple_shapes() -> Result<()> {
let pool = TensorMemoryPool::new(MemoryConfig::default());
let t1 = pool.get_tensor(&[2, 3], crate::tensor::DType::F32)?;
let t2 = pool.get_tensor(&[4, 5], crate::tensor::DType::F32)?;
assert_eq!(t1.shape(), &[2, 3]);
assert_eq!(t2.shape(), &[4, 5]);
Ok(())
}
#[test]
fn test_pool_f64_dtype() -> Result<()> {
let pool = TensorMemoryPool::new(MemoryConfig::default());
let t = pool.get_tensor(&[3, 3], crate::tensor::DType::F64)?;
assert_eq!(t.shape(), &[3, 3]);
Ok(())
}
#[test]
fn test_pool_i64_dtype() -> Result<()> {
let pool = TensorMemoryPool::new(MemoryConfig::default());
let t = pool.get_tensor(&[5], crate::tensor::DType::I64)?;
assert_eq!(t.shape(), &[5]);
Ok(())
}
#[test]
fn test_pool_reset_statistics() -> Result<()> {
let pool = TensorMemoryPool::new(MemoryConfig::default());
let t1 = pool.get_tensor(&[2, 2], crate::tensor::DType::F32)?;
pool.return_tensor(t1)?;
let _t2 = pool.get_tensor(&[2, 2], crate::tensor::DType::F32)?;
pool.reset_statistics();
let stats = pool.get_stats();
assert_eq!(stats.cache_hits, 0);
assert_eq!(stats.cache_misses, 0);
Ok(())
}
#[test]
fn test_tensor_view_slice_middle() -> Result<()> {
let tensor = Arc::new(Tensor::ones(&[10])?);
let view = TensorView::slice(tensor, 3, 7)?;
assert_eq!(view.shape(), &[4]);
Ok(())
}
#[test]
fn test_tensor_view_as_tensor_values() -> Result<()> {
let tensor = Arc::new(Tensor::ones(&[10])?);
let view = TensorView::slice(tensor, 0, 5)?;
let viewed = view.as_tensor()?;
assert_eq!(viewed.shape(), &[5]);
if let Tensor::F32(arr) = &viewed {
for v in arr.iter() {
assert!((*v - 1.0_f32).abs() < 1e-6, "expected 1.0, got {v}");
}
}
Ok(())
}
#[test]
fn test_tensor_view_full_range() -> Result<()> {
let tensor = Arc::new(Tensor::ones(&[8])?);
let view = TensorView::slice(tensor, 0, 8)?;
assert_eq!(view.shape(), &[8]);
Ok(())
}
#[test]
fn test_mmap_shape_stored() -> Result<()> {
use std::io::Write;
let tmp_dir = std::env::temp_dir();
let path = tmp_dir.join("trustformers_mmap_shape_test.bin");
let path_str = path.to_string_lossy().to_string();
let data = vec![0u8; 10 * 20 * std::mem::size_of::<f32>()];
{
let mut f = std::fs::File::create(&path).map_err(|e| {
TrustformersError::tensor_op_error(&e.to_string(), "test_mmap_shape_stored")
})?;
f.write_all(&data).map_err(|e| {
TrustformersError::tensor_op_error(&e.to_string(), "test_mmap_shape_stored")
})?;
}
let mmap =
MemoryMappedTensor::new(path_str.clone(), vec![10, 20], crate::tensor::DType::F32)?;
assert_eq!(mmap.shape(), &[10, 20]);
std::fs::remove_file(&path).ok();
Ok(())
}
#[test]
fn test_mmap_file_path_stored() -> Result<()> {
use std::io::Write;
let tmp_dir = std::env::temp_dir();
let path = tmp_dir.join("trustformers_mmap_path_test.bin");
let path_str = path.to_string_lossy().to_string();
let data = vec![0u8; 4 * std::mem::size_of::<f32>()];
{
let mut f = std::fs::File::create(&path).map_err(|e| {
TrustformersError::tensor_op_error(&e.to_string(), "test_mmap_file_path_stored")
})?;
f.write_all(&data).map_err(|e| {
TrustformersError::tensor_op_error(&e.to_string(), "test_mmap_file_path_stored")
})?;
}
let mmap = MemoryMappedTensor::new(path_str.clone(), vec![4], crate::tensor::DType::F32)?;
assert_eq!(mmap.file_path(), path_str);
std::fs::remove_file(&path).ok();
Ok(())
}
#[test]
fn test_global_get_tensor_without_explicit_init() -> Result<()> {
let tensor = get_tensor(&[3, 3], crate::tensor::DType::F32)?;
assert_eq!(tensor.shape(), &[3, 3]);
Ok(())
}
}