use crate::error::MullamaError;
use crate::sys;
/// Wrapper around a llama.cpp memory handle (`llama_memory_t`) that forwards
/// sequence/position operations to the C API and counts how often each kind of
/// operation was issued.
///
/// A manager whose handle is null (e.g. [`MemoryManager::default`]) is "invalid":
/// mutating operations return an error and queries return neutral values.
#[derive(Debug)]
pub struct MemoryManager {
    /// Raw handle passed to every `llama_memory_*` call; may be null.
    memory_ptr: sys::llama_memory_t,
    /// Whether the manager was constructed as owner of the handle. Recorded only:
    /// nothing in this module reads it (there is no `Drop` impl releasing the handle),
    /// hence the targeted `dead_code` allowance.
    #[allow(dead_code)]
    owned: bool,
    /// Per-operation counters exposed through `MemoryManager::stats`.
    stats: MemoryStats,
}
/// Counters describing the operations a `MemoryManager` has performed.
///
/// All counters start at zero (`Default`) and can be compared directly (`PartialEq`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Successful `MemoryManager::clear` calls.
    pub clears: u64,
    /// `MemoryManager::remove_sequence_tokens` calls that reached llama.cpp,
    /// regardless of the flag llama.cpp returned.
    pub seq_removals: u64,
    /// Successful `MemoryManager::copy_sequence_tokens` calls.
    pub seq_copies: u64,
    /// Successful `MemoryManager::shift_positions` calls (divisions are not counted).
    pub pos_shifts: u64,
    /// Coarse tracker: reset to 0 by `clear` and set to 1 by `keep_sequence`;
    /// not otherwise maintained.
    pub active_sequences: usize,
}
/// Snapshot of the position range occupied by one sequence in memory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequenceInfo {
    /// Sequence identifier the snapshot describes.
    pub seq_id: i32,
    /// Smallest stored position (inclusive).
    pub pos_min: i32,
    /// Largest stored position (inclusive).
    pub pos_max: i32,
    /// `pos_max - pos_min + 1`, i.e. the number of positions in the span.
    pub token_count: usize,
}
impl MemoryManager {
    /// Creates a detached manager with no memory handle.
    ///
    /// Equivalent to [`MemoryManager::default`]: [`is_valid`](Self::is_valid) returns
    /// `false` and every mutating operation returns `MullamaError::MemoryError`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps an existing raw memory handle with zeroed statistics.
    ///
    /// `owned` is stored but not acted upon anywhere in this module.
    pub(crate) fn from_ptr(ptr: sys::llama_memory_t, owned: bool) -> Self {
        Self {
            memory_ptr: ptr,
            owned,
            stats: MemoryStats::default(),
        }
    }

    /// Fetches the memory handle of a llama context via `llama_get_memory`.
    ///
    /// Returns `None` when the context reports a null handle. The resulting manager is
    /// marked as not owning the handle.
    ///
    /// # Safety
    ///
    /// `ctx_ptr` must point to a valid, live `llama_context`. The manager stores the raw
    /// handle without a lifetime, so the caller must also keep the context alive for as
    /// long as the manager is used.
    pub unsafe fn from_context(ctx_ptr: *mut sys::llama_context) -> Option<Self> {
        // SAFETY: the caller upholds this function's `# Safety` contract.
        let memory_ptr = sys::llama_get_memory(ctx_ptr);
        if memory_ptr.is_null() {
            None
        } else {
            Some(Self::from_ptr(memory_ptr, false))
        }
    }

    /// Returns `true` when a non-null memory handle is attached.
    pub fn is_valid(&self) -> bool {
        !self.memory_ptr.is_null()
    }

    /// Shared guard: fails with `MemoryError("Invalid memory handle")` when no handle
    /// is attached, so no operation hands a null pointer to the C API.
    fn ensure_valid(&self) -> Result<(), MullamaError> {
        if self.is_valid() {
            Ok(())
        } else {
            Err(MullamaError::MemoryError(
                "Invalid memory handle".to_string(),
            ))
        }
    }

    /// Clears the memory via `llama_memory_clear`, forwarding `clear_data` unchanged.
    ///
    /// On success increments `stats.clears` and resets `stats.active_sequences` to 0.
    ///
    /// # Errors
    /// `MullamaError::MemoryError` if no handle is attached.
    pub fn clear(&mut self, clear_data: bool) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_clear(self.memory_ptr, clear_data) };
        self.stats.clears += 1;
        self.stats.active_sequences = 0;
        Ok(())
    }

    /// Removes positions `[pos_start, pos_end)` of `seq_id` via `llama_memory_seq_rm`.
    ///
    /// Bounds are forwarded unchanged; this module passes `-1` as `pos_end` to mean
    /// "through the end of the sequence" (the open-ended convention documented in
    /// llama.cpp's `llama.h`). Returns the flag reported by llama.cpp; `false` means the
    /// removal was refused. `stats.seq_removals` counts every call that reaches llama.cpp.
    ///
    /// # Errors
    /// `MullamaError::MemoryError` if no handle is attached.
    pub fn remove_sequence_tokens(
        &mut self,
        seq_id: i32,
        pos_start: i32,
        pos_end: i32,
    ) -> Result<bool, MullamaError> {
        self.ensure_valid()?;
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        let removed =
            unsafe { sys::llama_memory_seq_rm(self.memory_ptr, seq_id, pos_start, pos_end) };
        self.stats.seq_removals += 1;
        Ok(removed)
    }

    /// Copies positions `[pos_start, pos_end)` of `src_seq_id` into `dst_seq_id` via
    /// `llama_memory_seq_cp`, then increments `stats.seq_copies`.
    ///
    /// # Errors
    /// `MullamaError::MemoryError` if no handle is attached.
    pub fn copy_sequence_tokens(
        &mut self,
        src_seq_id: i32,
        dst_seq_id: i32,
        pos_start: i32,
        pos_end: i32,
    ) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe {
            sys::llama_memory_seq_cp(self.memory_ptr, src_seq_id, dst_seq_id, pos_start, pos_end)
        };
        self.stats.seq_copies += 1;
        Ok(())
    }

    /// Calls `llama_memory_seq_keep` for `seq_id` and sets `stats.active_sequences` to 1.
    ///
    /// # Errors
    /// `MullamaError::MemoryError` if no handle is attached.
    pub fn keep_sequence(&mut self, seq_id: i32) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_seq_keep(self.memory_ptr, seq_id) };
        self.stats.active_sequences = 1;
        Ok(())
    }

    /// Adds `delta` to positions `[pos_start, pos_end)` of `seq_id` via
    /// `llama_memory_seq_add`, then increments `stats.pos_shifts`.
    ///
    /// # Errors
    /// `MullamaError::MemoryError` if no handle is attached.
    pub fn shift_positions(
        &mut self,
        seq_id: i32,
        pos_start: i32,
        pos_end: i32,
        delta: i32,
    ) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_seq_add(self.memory_ptr, seq_id, pos_start, pos_end, delta) };
        self.stats.pos_shifts += 1;
        Ok(())
    }

    /// Divides positions `[pos_start, pos_end)` of `seq_id` by `divisor` via
    /// `llama_memory_seq_div`. Not reflected in [`MemoryStats`].
    ///
    /// # Errors
    /// `MullamaError::MemoryError` if no handle is attached (checked first);
    /// `MullamaError::InvalidInput` if `divisor <= 1`.
    pub fn divide_positions(
        &mut self,
        seq_id: i32,
        pos_start: i32,
        pos_end: i32,
        divisor: i32,
    ) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        if divisor <= 1 {
            return Err(MullamaError::InvalidInput(
                "Divisor must be greater than 1".to_string(),
            ));
        }
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_seq_div(self.memory_ptr, seq_id, pos_start, pos_end, divisor) };
        Ok(())
    }

    /// Smallest stored position of `seq_id` (`llama_memory_seq_pos_min`).
    ///
    /// Returns `-1` when no handle is attached; callers in this module treat any
    /// negative value as "sequence is empty".
    pub fn get_min_position(&self, seq_id: i32) -> i32 {
        if !self.is_valid() {
            return -1;
        }
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_seq_pos_min(self.memory_ptr, seq_id) }
    }

    /// Largest stored position of `seq_id` (`llama_memory_seq_pos_max`).
    ///
    /// Returns `-1` when no handle is attached; callers in this module treat any
    /// negative value as "sequence is empty".
    pub fn get_max_position(&self, seq_id: i32) -> i32 {
        if !self.is_valid() {
            return -1;
        }
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_seq_pos_max(self.memory_ptr, seq_id) }
    }

    /// Reports `llama_memory_can_shift`; `false` when no handle is attached.
    pub fn can_shift(&self) -> bool {
        if !self.is_valid() {
            return false;
        }
        // SAFETY: the handle is non-null (checked above); its context is assumed live.
        unsafe { sys::llama_memory_can_shift(self.memory_ptr) }
    }

    /// Describes the position span of `seq_id`, or `None` if no handle is attached or
    /// either bound is negative (empty sequence).
    pub fn get_sequence_info(&self, seq_id: i32) -> Option<SequenceInfo> {
        if !self.is_valid() {
            return None;
        }
        let pos_min = self.get_min_position(seq_id);
        let pos_max = self.get_max_position(seq_id);
        if pos_min < 0 || pos_max < 0 {
            return None;
        }
        Some(SequenceInfo {
            seq_id,
            pos_min,
            pos_max,
            token_count: (pos_max - pos_min + 1) as usize,
        })
    }

    /// Returns the accumulated operation counters.
    pub fn stats(&self) -> &MemoryStats {
        &self.stats
    }

    /// Resets all counters to zero.
    pub fn reset_stats(&mut self) {
        self.stats = MemoryStats::default();
    }

    /// Returns the raw handle (possibly null) for direct FFI use.
    pub fn as_ptr(&self) -> sys::llama_memory_t {
        self.memory_ptr
    }

    /// Drops the oldest positions of `seq_id` so that at most `keep_count` remain.
    /// It then shifts the survivors down so the sequence again starts at its previous
    /// minimum position.
    ///
    /// The token count is taken from the position span (`pos_max - pos_min + 1`). Does
    /// nothing if the sequence is empty or already within `keep_count`.
    ///
    /// # Errors
    /// - `MullamaError::MemoryError` if no handle is attached.
    /// - `MullamaError::InvalidInput` if `keep_count` is negative.
    /// - `MullamaError::MemoryError` if llama.cpp refuses the removal. Positions are not
    ///   shifted in that case, so tokens that are still present are never renumbered.
    pub fn context_shift(&mut self, seq_id: i32, keep_count: i32) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        if keep_count < 0 {
            return Err(MullamaError::InvalidInput(format!(
                "keep_count must be non-negative, got {}",
                keep_count
            )));
        }
        let pos_max = self.get_max_position(seq_id);
        let pos_min = self.get_min_position(seq_id);
        if pos_min < 0 || pos_max < 0 {
            // Empty sequence: nothing to shift.
            return Ok(());
        }
        let total_tokens = pos_max - pos_min + 1;
        if total_tokens <= keep_count {
            return Ok(());
        }
        let remove_count = total_tokens - keep_count;
        let remove_end = pos_min + remove_count;
        if !self.remove_sequence_tokens(seq_id, pos_min, remove_end)? {
            return Err(MullamaError::MemoryError(format!(
                "Failed to remove positions [{}, {}) from sequence {}",
                pos_min, remove_end, seq_id
            )));
        }
        // `-1` as the end bound: shift everything from `remove_end` onwards.
        self.shift_positions(seq_id, remove_end, -1, -remove_count)?;
        Ok(())
    }

    /// Copies the whole of `src_seq_id` into `dst_seq_id` and describes the copied span.
    ///
    /// # Errors
    /// - `MullamaError::MemoryError` if no handle is attached.
    /// - `MullamaError::MemoryError` if the source sequence is empty.
    pub fn fork_sequence(
        &mut self,
        src_seq_id: i32,
        dst_seq_id: i32,
    ) -> Result<SequenceInfo, MullamaError> {
        self.ensure_valid()?;
        let pos_min = self.get_min_position(src_seq_id);
        let pos_max = self.get_max_position(src_seq_id);
        if pos_min < 0 || pos_max < 0 {
            return Err(MullamaError::MemoryError(format!(
                "Source sequence {} is empty",
                src_seq_id
            )));
        }
        // End bound is exclusive, hence `pos_max + 1`.
        self.copy_sequence_tokens(src_seq_id, dst_seq_id, pos_min, pos_max + 1)?;
        Ok(SequenceInfo {
            seq_id: dst_seq_id,
            pos_min,
            pos_max,
            token_count: (pos_max - pos_min + 1) as usize,
        })
    }

    /// Removes every position of `seq_id` beyond the first `max_length` (counted from
    /// the sequence's minimum position). Does nothing if the sequence is empty or
    /// already short enough.
    ///
    /// # Errors
    /// - `MullamaError::MemoryError` if no handle is attached.
    /// - `MullamaError::InvalidInput` if `max_length` is negative.
    /// - `MullamaError::MemoryError` if llama.cpp refuses the removal.
    pub fn truncate_sequence(&mut self, seq_id: i32, max_length: i32) -> Result<(), MullamaError> {
        self.ensure_valid()?;
        if max_length < 0 {
            return Err(MullamaError::InvalidInput(format!(
                "max_length must be non-negative, got {}",
                max_length
            )));
        }
        let pos_min = self.get_min_position(seq_id);
        let pos_max = self.get_max_position(seq_id);
        if pos_min < 0 || pos_max < 0 {
            return Ok(());
        }
        let current_length = pos_max - pos_min + 1;
        if current_length <= max_length {
            return Ok(());
        }
        let truncate_start = pos_min + max_length;
        // `-1` as the end bound: remove through the end of the sequence.
        if !self.remove_sequence_tokens(seq_id, truncate_start, -1)? {
            return Err(MullamaError::MemoryError(format!(
                "Failed to truncate sequence {} at position {}",
                seq_id, truncate_start
            )));
        }
        Ok(())
    }
}
impl Default for MemoryManager {
fn default() -> Self {
Self {
memory_ptr: std::ptr::null_mut(),
owned: false,
stats: MemoryStats::default(),
}
}
}
/// Thin wrapper that forwards KV-cache operations to the memory of a raw
/// `llama_context`.
///
/// NOTE(review): the context pointer is stored without a lifetime and nothing here
/// keeps the context alive; callers must ensure it outlives this value.
#[derive(Debug)]
pub struct KVCacheManager {
    /// Context whose memory handle is looked up on every call; may be null.
    ctx_ptr: *mut sys::llama_context,
}
impl KVCacheManager {
    /// Wraps a raw `llama_context` pointer.
    ///
    /// A null pointer is tolerated: every operation then becomes a no-op or returns a
    /// neutral value. A dangling non-null pointer cannot be detected.
    pub fn new(ctx_ptr: *mut sys::llama_context) -> Self {
        Self { ctx_ptr }
    }

    /// Looks up the context's memory handle; null when no context is attached.
    fn get_memory(&self) -> sys::llama_memory_t {
        if self.ctx_ptr.is_null() {
            return std::ptr::null_mut();
        }
        // SAFETY: `ctx_ptr` is non-null; `new`'s caller is responsible for it being live.
        unsafe { sys::llama_get_memory(self.ctx_ptr) }
    }

    /// Returns the memory handle only when it is non-null, i.e. safe to pass to FFI.
    fn memory(&self) -> Option<sys::llama_memory_t> {
        let mem = self.get_memory();
        if mem.is_null() {
            None
        } else {
            Some(mem)
        }
    }

    /// Calls `llama_memory_clear` with `data = false`; no-op without memory.
    pub fn clear(&mut self) {
        if let Some(mem) = self.memory() {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_clear(mem, false) };
        }
    }

    /// Forwards to `llama_memory_seq_rm`; returns `false` when there is no memory.
    pub fn seq_rm(&mut self, seq_id: i32, p0: i32, p1: i32) -> bool {
        self.memory().map_or(false, |mem| {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_rm(mem, seq_id, p0, p1) }
        })
    }

    /// Forwards to `llama_memory_seq_cp`; no-op without memory.
    pub fn seq_cp(&mut self, seq_id_src: i32, seq_id_dst: i32, p0: i32, p1: i32) {
        if let Some(mem) = self.memory() {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_cp(mem, seq_id_src, seq_id_dst, p0, p1) };
        }
    }

    /// Forwards to `llama_memory_seq_keep`; no-op without memory.
    pub fn seq_keep(&mut self, seq_id: i32) {
        if let Some(mem) = self.memory() {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_keep(mem, seq_id) };
        }
    }

    /// Forwards to `llama_memory_seq_add`; no-op without memory.
    pub fn seq_add(&mut self, seq_id: i32, p0: i32, p1: i32, delta: i32) {
        if let Some(mem) = self.memory() {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_add(mem, seq_id, p0, p1, delta) };
        }
    }

    /// Forwards to `llama_memory_seq_div`; no-op without memory.
    pub fn seq_div(&mut self, seq_id: i32, p0: i32, p1: i32, d: i32) {
        if let Some(mem) = self.memory() {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_div(mem, seq_id, p0, p1, d) };
        }
    }

    /// Forwards to `llama_memory_seq_pos_min`; `-1` when there is no memory.
    pub fn seq_pos_min(&self, seq_id: i32) -> i32 {
        self.memory().map_or(-1, |mem| {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_pos_min(mem, seq_id) }
        })
    }

    /// Forwards to `llama_memory_seq_pos_max`; `-1` when there is no memory.
    pub fn seq_pos_max(&self, seq_id: i32) -> i32 {
        self.memory().map_or(-1, |mem| {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_seq_pos_max(mem, seq_id) }
        })
    }

    /// Forwards to `llama_memory_can_shift`; `false` when there is no memory.
    pub fn can_shift(&self) -> bool {
        self.memory().map_or(false, |mem| {
            // SAFETY: `mem` is a non-null handle just obtained from the context.
            unsafe { sys::llama_memory_can_shift(mem) }
        })
    }
}
/// Loader/context settings recommended for running a model within a memory budget,
/// as produced by `recommend_constrained_config`.
#[derive(Debug, Clone)]
pub struct ConstrainedMemoryConfig {
    /// Whether to memory-map the model file.
    pub use_mmap: bool,
    /// Whether to lock the model pages in RAM.
    pub use_mlock: bool,
    /// KV-cache element type for keys.
    pub cache_type_k: crate::context::KvCacheType,
    /// KV-cache element type for values.
    pub cache_type_v: crate::context::KvCacheType,
    /// Recommended context length in tokens.
    pub context_size: u32,
    /// Recommended batch size.
    pub batch_size: u32,
    /// Recommended number of layers to offload to the GPU.
    pub gpu_layers: i32,
    /// `true` when the model weights alone exceed `available_memory_bytes`.
    pub model_larger_than_ram: bool,
    /// Estimated bytes for weights + KV cache + fixed overhead under the chosen cache type.
    pub estimated_memory_bytes: u64,
    /// Memory budget (bytes) the recommendation was computed against.
    pub available_memory_bytes: u64,
}
/// Recommends loader/context settings for running a model under memory pressure.
///
/// # Arguments
/// * `model_size_bytes` - size of the model weights.
/// * `context_size` - requested context length; the recommendation never exceeds it.
/// * `n_layers`, `n_embd` - model dimensions for the KV-cache estimate; negative values
///   are treated as 0.
/// * `available_ram_bytes` - memory budget; `0` means "ask the system" via
///   `MemoryMonitor` (free = total - used).
/// * `gpu_layers_requested` - requested GPU offload. When the model exceeds RAM, a
///   negative value selects `min(n_layers / 2, 20)`, otherwise it is capped at
///   `n_layers / 2`. When the model fits, it is passed through unchanged.
///
/// The KV cache type is the highest-precision option (F16, then Q8_0, then Q4_0) whose
/// estimated total fits the budget. If none fits, Q4_0 is used and the context is
/// scaled down, with a floor of 512 tokens or the requested size if that is smaller.
pub fn recommend_constrained_config(
    model_size_bytes: u64,
    context_size: u32,
    n_layers: i32,
    n_embd: i32,
    available_ram_bytes: u64,
    gpu_layers_requested: i32,
) -> ConstrainedMemoryConfig {
    // Only consult the system monitor when the caller did not supply a budget.
    let available = if available_ram_bytes > 0 {
        available_ram_bytes
    } else {
        let monitor = crate::memory_monitor::MemoryMonitor::with_defaults();
        monitor.update_stats();
        let (mem_used, mem_total) = monitor.system_memory();
        mem_total.saturating_sub(mem_used)
    };
    let model_larger_than_ram = model_size_bytes > available;

    // KV cache = K and V tensors of n_layers * n_embd * context_size elements each.
    // Negative dimensions are clamped instead of wrapping to huge u64 values, and the
    // arithmetic saturates rather than overflowing.
    // NOTE(review): uses n_embd, so models with grouped-query attention (narrower K/V)
    // are over-estimated.
    let kv_elements = 2u64
        .saturating_mul(n_layers.max(0) as u64)
        .saturating_mul(n_embd.max(0) as u64)
        .saturating_mul(u64::from(context_size));
    // F16: 2 bytes per element, i.e. 64 bytes per 32 elements.
    let kv_cache_f16 = kv_elements.saturating_mul(2);
    // ggml packs 32 elements per block: q8_0 = 34 bytes, q4_0 = 18 bytes (f16 scale +
    // quants), i.e. 34/64 = 17/32 and 18/64 = 9/32 of the F16 size.
    let kv_cache_q8 = kv_cache_f16.saturating_mul(17) / 32;
    let kv_cache_q4 = kv_cache_f16.saturating_mul(9) / 32;

    // Heuristic fixed allowance (512 MiB) for everything besides weights and KV cache.
    let overhead: u64 = 512 * 1024 * 1024;
    let total_for = |kv_bytes: u64| {
        model_size_bytes
            .saturating_add(kv_bytes)
            .saturating_add(overhead)
    };
    let total_f16 = total_for(kv_cache_f16);
    let total_q8 = total_for(kv_cache_q8);
    let total_q4 = total_for(kv_cache_q4);

    // Prefer the most precise cache that fits. Every total is >= model_size_bytes, so
    // when the weights alone exceed the budget nothing fits and Q4_0 is chosen.
    let (cache_type_k, cache_type_v, estimated_memory) = if total_f16 <= available {
        (
            crate::context::KvCacheType::F16,
            crate::context::KvCacheType::F16,
            total_f16,
        )
    } else if total_q8 <= available {
        (
            crate::context::KvCacheType::Q8_0,
            crate::context::KvCacheType::Q8_0,
            total_q8,
        )
    } else {
        (
            crate::context::KvCacheType::Q4_0,
            crate::context::KvCacheType::Q4_0,
            total_q4,
        )
    };

    // Even the smallest cache does not fit: shrink the context proportionally (with a
    // 20% margin). Keep a 512-token floor, but never exceed the requested size.
    let adjusted_context = if !model_larger_than_ram && estimated_memory > available {
        let ratio = available as f64 / estimated_memory as f64;
        let scaled = (context_size as f64 * ratio * 0.8) as u32;
        scaled.max(512).min(context_size)
    } else {
        context_size
    };

    let gpu_layers = if model_larger_than_ram {
        if gpu_layers_requested < 0 {
            (n_layers / 2).min(20)
        } else {
            gpu_layers_requested.min(n_layers / 2)
        }
    } else {
        gpu_layers_requested
    };

    let use_mmap = true;
    // Lock pages only with generous headroom: even an F16 cache would fit.
    let use_mlock = !model_larger_than_ram && total_f16 <= available;

    ConstrainedMemoryConfig {
        use_mmap,
        use_mlock,
        cache_type_k,
        cache_type_v,
        context_size: adjusted_context,
        batch_size: if model_larger_than_ram { 256 } else { 512 },
        gpu_layers,
        model_larger_than_ram,
        estimated_memory_bytes: estimated_memory,
        available_memory_bytes: available,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_stats_default() {
        let MemoryStats {
            clears,
            seq_removals,
            seq_copies,
            pos_shifts,
            active_sequences,
        } = MemoryStats::default();
        assert_eq!((clears, seq_removals, seq_copies, pos_shifts), (0, 0, 0, 0));
        assert_eq!(active_sequences, 0);
    }

    #[test]
    fn test_sequence_info() {
        let info = SequenceInfo {
            seq_id: 0,
            pos_min: 0,
            pos_max: 99,
            token_count: 100,
        };
        assert_eq!(info.seq_id, 0);
        assert_eq!(info.token_count, 100);
    }

    #[test]
    fn test_memory_manager_default() {
        let manager = MemoryManager::default();
        assert!(!manager.is_valid());
        assert_eq!(manager.get_min_position(0), -1);
        assert_eq!(manager.get_max_position(0), -1);
        assert!(!manager.can_shift());
    }

    #[test]
    fn test_new_is_detached() {
        let manager = MemoryManager::new();
        assert!(!manager.is_valid());
        assert!(manager.as_ptr().is_null());
    }

    #[test]
    fn test_invalid_handle_has_no_sequence_info() {
        let manager = MemoryManager::default();
        assert!(manager.get_sequence_info(0).is_none());
    }

    #[test]
    fn test_invalid_handle_mutations_fail_without_touching_stats() {
        let mut manager = MemoryManager::default();
        assert!(matches!(manager.clear(true), Err(MullamaError::MemoryError(_))));
        assert!(matches!(
            manager.remove_sequence_tokens(0, 0, -1),
            Err(MullamaError::MemoryError(_))
        ));
        assert!(matches!(
            manager.copy_sequence_tokens(0, 1, 0, -1),
            Err(MullamaError::MemoryError(_))
        ));
        assert!(matches!(manager.keep_sequence(0), Err(MullamaError::MemoryError(_))));
        assert!(matches!(
            manager.shift_positions(0, 0, -1, 1),
            Err(MullamaError::MemoryError(_))
        ));
        assert!(matches!(
            manager.divide_positions(0, 0, -1, 2),
            Err(MullamaError::MemoryError(_))
        ));
        let stats = manager.stats();
        assert_eq!(
            stats.clears + stats.seq_removals + stats.seq_copies + stats.pos_shifts,
            0
        );
        assert_eq!(stats.active_sequences, 0);
    }
}