use crate::common::RusTorchResult;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock, RwLock};
use std::thread;
use std::time::{Duration, Instant};
/// Tuning knobs for a [`GarbageCollector`].
#[derive(Debug, Clone)]
pub struct GCConfig {
    /// How long the background auto-GC thread lets pass after the collector's last-GC
    /// timestamp before it attempts another collection.
    pub gc_interval: Duration,
    /// Total tracked size (sum of `GCObject::size`) below which `collect` reclaims nothing.
    /// In generational mode the old generation is only swept once its own size exceeds
    /// half of this value.
    pub memory_threshold: usize,
    /// NOTE(review): not read by any code in this file — confirm whether it is used elsewhere.
    pub usage_threshold: f32,
    /// How long `force_collect` blocks its caller after the collection finishes.
    pub forced_gc_cooldown: Duration,
    /// When `false`, `start_auto_gc` does not spawn a background thread.
    pub enable_auto_gc: bool,
    /// Selects young/old generational collection instead of a single full sweep.
    pub enable_generational: bool,
    /// Survival count at which a young object is promoted to the old generation.
    pub young_generation_limit: u32,
}

impl Default for GCConfig {
    /// Automatic, generational collection checked every 30 s, with a `1 << 30` memory
    /// threshold (1 GiB if object sizes are bytes) and promotion after 3 survivals.
    fn default() -> Self {
        const ONE_GIB: usize = 1 << 30;
        GCConfig {
            enable_auto_gc: true,
            enable_generational: true,
            gc_interval: Duration::from_secs(30),
            forced_gc_cooldown: Duration::from_secs(5),
            memory_threshold: ONE_GIB,
            usage_threshold: 0.8,
            young_generation_limit: 3,
        }
    }
}
/// Cumulative statistics about the collection passes recorded by a garbage collector.
#[derive(Debug, Clone, Default)]
pub struct GCStats {
    /// Number of passes recorded through [`GCStats::update_collection`].
    pub total_collections: u64,
    /// When the most recent pass was recorded (taken at `update_collection` time).
    pub last_collection_time: Option<Instant>,
    /// Duration of the most recently recorded pass.
    pub last_collection_duration: Option<Duration>,
    /// Sum of all recorded pass durations.
    pub total_gc_time: Duration,
    /// Lifetime total of reclaimed size units (rendered as MB by the report, i.e. bytes
    /// are presumed). Saturates at `usize::MAX` instead of overflowing.
    pub memory_reclaimed: usize,
    /// Lifetime total of reclaimed objects.
    pub objects_reclaimed: u64,
    /// Passes recorded with `is_young == true`.
    pub young_gc_count: u64,
    /// Passes recorded with `is_young == false`.
    pub old_gc_count: u64,
    /// `total_gc_time / total_collections`, truncated to whole nanoseconds.
    pub average_gc_time: Duration,
}

impl GCStats {
    /// Records one finished collection pass.
    ///
    /// * `duration` – how long the pass took.
    /// * `memory_freed` / `objects_freed` – what the pass reclaimed.
    /// * `is_young` – selects whether `young_gc_count` or `old_gc_count` is incremented.
    pub fn update_collection(
        &mut self,
        duration: Duration,
        memory_freed: usize,
        objects_freed: u64,
        is_young: bool,
    ) {
        self.total_collections += 1;
        self.last_collection_time = Some(Instant::now());
        self.last_collection_duration = Some(duration);
        self.total_gc_time += duration;
        // `usize` is only 32 bits on some targets and this is a lifetime total: saturate
        // instead of overflowing (which panics in debug builds).
        self.memory_reclaimed = self.memory_reclaimed.saturating_add(memory_freed);
        self.objects_reclaimed += objects_freed;
        if is_young {
            self.young_gc_count += 1;
        } else {
            self.old_gc_count += 1;
        }
        // `total_collections` was just incremented, so it is at least 1 here.
        self.average_gc_time = Self::mean_duration(self.total_gc_time, self.total_collections);
    }

    /// Renders a human-readable multi-line summary; all durations in milliseconds are
    /// shown with two decimals so sub-millisecond passes remain visible.
    pub fn generate_report(&self) -> String {
        let last_duration = self
            .last_collection_duration
            .map(|d| format!("{:.2}ms", Self::as_millis_f64(d)))
            .unwrap_or_else(|| "N/A".to_string());
        format!(
            "GC Statistics Report:\n\
             - Total Collections: {}\n\
             - Young Generation GCs: {}\n\
             - Old Generation GCs: {}\n\
             - Total Memory Reclaimed: {:.2} MB\n\
             - Total Objects Reclaimed: {}\n\
             - Total GC Time: {:.2}s\n\
             - Average GC Time: {:.2}ms\n\
             - Last GC Duration: {}\n",
            self.total_collections,
            self.young_gc_count,
            self.old_gc_count,
            self.memory_reclaimed as f64 / 1024.0 / 1024.0,
            self.objects_reclaimed,
            self.total_gc_time.as_secs_f64(),
            Self::as_millis_f64(self.average_gc_time),
            last_duration
        )
    }

    /// `total / count` truncated to whole nanoseconds; `Duration::ZERO` when `count == 0`.
    fn mean_duration(total: Duration, count: u64) -> Duration {
        if count == 0 {
            return Duration::ZERO;
        }
        // Divide in u128 nanoseconds: `Duration / u32` would need `count as u32`, which
        // truncates (and hits 0, panicking) once more than u32::MAX passes are recorded.
        let nanos = total.as_nanos() / u128::from(count);
        // `nanos <= total.as_nanos()`, so the whole-second part always fits in a u64.
        Duration::new((nanos / 1_000_000_000) as u64, (nanos % 1_000_000_000) as u32)
    }

    /// Fractional milliseconds. `Duration::as_millis` returns an integer, and `{:.2}`
    /// precision is ignored for integers, so it cannot show sub-millisecond detail.
    fn as_millis_f64(d: Duration) -> f64 {
        d.as_secs_f64() * 1000.0
    }
}
/// An object whose lifetime is managed by a [`GarbageCollector`].
///
/// Boxed implementations live in a registry that the collector shares with its background
/// auto-GC thread, hence the `Send + Sync` bounds; `Debug` is required because registry
/// entries derive `Debug`.
///
/// Within this file the collector decides what to reclaim from `size`, `last_accessed` and
/// `survival_count` only: it never traces `references` and never calls `is_marked` /
/// `set_marked`. NOTE(review): those three look reserved for a mark phase — confirm.
pub trait GCObject: Send + Sync + std::fmt::Debug {
/// Size charged to this object for memory accounting. The collector sums these against
/// `GCConfig::memory_threshold`; the stats report renders reclaimed totals as MB, so bytes
/// are presumably expected — confirm with implementors.
fn size(&self) -> usize;
/// Time of the most recent access; objects idle for longer than a fixed age are reclaimed.
fn last_accessed(&self) -> Instant;
/// Number of generational passes survived while young; compared against
/// `GCConfig::young_generation_limit` to decide promotion to the old generation.
fn survival_count(&self) -> u32;
/// Called by the collector on young objects during generational passes.
fn increment_survival(&mut self);
/// Mark bit; not read by the collector in this file.
fn is_marked(&self) -> bool;
/// Sets the mark bit; not called by the collector in this file.
fn set_marked(&mut self, marked: bool);
/// Ids of objects this one refers to (presumably registry ids — confirm); not traced by
/// the collector in this file.
fn references(&self) -> Vec<u64>;
}
/// One registry slot: a managed object plus the collector's bookkeeping for it.
#[derive(Debug)]
struct GCEntry {
/// The managed object.
object: Box<dyn GCObject>,
/// Id handed out by `register_object`; equal to this entry's key in the registry map.
id: u64,
// `generation`: 0 = young (every new registration); generational passes promote survivors
// to 1, and any non-zero value is counted as old by `get_generational_stats`.
// `created_at`: registration time; not read anywhere in this file.
generation: u8, created_at: Instant,
}
/// Registry of [`GCObject`]s that reclaims idle objects, either on demand or from an
/// optional background thread.
pub struct GarbageCollector {
/// Configuration the collector was created with.
config: GCConfig,
/// Registered objects keyed by id; shared with the background thread when it runs.
objects: Arc<Mutex<HashMap<u64, GCEntry>>>,
/// Cumulative collection statistics.
stats: Arc<RwLock<GCStats>>,
/// Next id `register_object` hands out (starts at 1).
next_id: Arc<Mutex<u64>>,
/// Timestamp the background thread compares against `config.gc_interval`; written by the
/// internal collection routine.
last_gc: Arc<Mutex<Instant>>,
/// Handle of the background auto-GC thread, if one was started; joined on drop.
gc_thread: Option<thread::JoinHandle<()>>,
/// Set to `true` on drop to tell the background thread to exit.
shutdown_signal: Arc<Mutex<bool>>,
}
impl GarbageCollector {
    /// Young objects idle (per [`GCObject::last_accessed`]) longer than this are reclaimed.
    const YOUNG_MAX_IDLE: Duration = Duration::from_secs(60);
    /// Old-generation objects idle longer than this are reclaimed by an old-gen sweep.
    const OLD_MAX_IDLE: Duration = Duration::from_secs(300);
    /// Non-generational (full) sweeps reclaim objects idle longer than this.
    const FULL_MAX_IDLE: Duration = Duration::from_secs(120);
    /// How often the background thread wakes to check the shutdown flag and the schedule.
    const AUTO_GC_POLL: Duration = Duration::from_millis(100);

    /// Creates a collector with an empty registry. No background thread is started until
    /// [`start_auto_gc`](Self::start_auto_gc) is called.
    pub fn new(config: GCConfig) -> Self {
        Self {
            config,
            objects: Arc::new(Mutex::new(HashMap::new())),
            stats: Arc::new(RwLock::new(GCStats::default())),
            next_id: Arc::new(Mutex::new(1)),
            last_gc: Arc::new(Mutex::new(Instant::now())),
            gc_thread: None,
            shutdown_signal: Arc::new(Mutex::new(false)),
        }
    }

    /// Spawns the background thread that runs a collection whenever `config.gc_interval`
    /// has elapsed since the last-GC timestamp.
    ///
    /// Does nothing when `config.enable_auto_gc` is `false` or when the thread is already
    /// running, so a second call never leaves an un-joined worker behind.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub fn start_auto_gc(&mut self) -> RusTorchResult<()> {
        if !self.config.enable_auto_gc || self.gc_thread.is_some() {
            return Ok(());
        }
        let config = self.config.clone();
        let objects = Arc::clone(&self.objects);
        let stats = Arc::clone(&self.stats);
        let last_gc = Arc::clone(&self.last_gc);
        let shutdown = Arc::clone(&self.shutdown_signal);
        let handle = thread::spawn(move || loop {
            if *shutdown.lock().unwrap() {
                break;
            }
            let due = last_gc.lock().unwrap().elapsed() >= config.gc_interval;
            if due {
                // Best effort: errors from a background pass are ignored.
                let _ = Self::perform_gc_internal(&config, &objects, &stats, &last_gc, false);
            }
            thread::sleep(Self::AUTO_GC_POLL);
        });
        self.gc_thread = Some(handle);
        Ok(())
    }

    /// Adds `object` to the registry as a young (generation 0) object and returns its id.
    /// Ids start at 1 and grow by one per registration.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn register_object(&self, object: Box<dyn GCObject>) -> u64 {
        let id = {
            let mut next_id = self.next_id.lock().unwrap();
            let id = *next_id;
            *next_id += 1;
            id
        };
        let entry = GCEntry {
            object,
            id,
            generation: 0,
            created_at: Instant::now(),
        };
        self.objects.lock().unwrap().insert(id, entry);
        id
    }

    /// Removes the object registered under `id`; returns whether it was present.
    pub fn unregister_object(&self, id: u64) -> bool {
        self.objects.lock().unwrap().remove(&id).is_some()
    }

    /// Runs a collection pass now, unless the total tracked size is below
    /// `config.memory_threshold`.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub fn collect(&self) -> RusTorchResult<()> {
        Self::perform_gc_internal(&self.config, &self.objects, &self.stats, &self.last_gc, false)
    }

    /// Runs a collection pass regardless of `config.memory_threshold`, then blocks the
    /// caller for `config.forced_gc_cooldown`.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub fn force_collect(&self) -> RusTorchResult<()> {
        // `force = true`: delegating to `collect()` would let the memory threshold veto
        // the pass, so a "forced" collection could silently do nothing.
        Self::perform_gc_internal(&self.config, &self.objects, &self.stats, &self.last_gc, true)?;
        thread::sleep(self.config.forced_gc_cooldown);
        Ok(())
    }

    /// Returns a snapshot of the collection statistics.
    pub fn get_stats(&self) -> GCStats {
        self.stats.read().unwrap().clone()
    }

    /// Sum of [`GCObject::size`] over all registered objects.
    pub fn get_memory_usage(&self) -> usize {
        let objects = self.objects.lock().unwrap();
        objects.values().map(|entry| entry.object.size()).sum()
    }

    /// Number of registered objects.
    pub fn get_object_count(&self) -> usize {
        self.objects.lock().unwrap().len()
    }

    /// Returns `(young, old)` object counts; any non-zero generation counts as old.
    pub fn get_generational_stats(&self) -> (usize, usize) {
        let objects = self.objects.lock().unwrap();
        let young = objects.values().filter(|e| e.generation == 0).count();
        (young, objects.len() - young)
    }

    /// Runs one collection pass over `objects` and records it in `stats`.
    ///
    /// Unless `force` is set, the pass is skipped when the total tracked size is below
    /// `config.memory_threshold`. `last_gc` is refreshed either way, so the background
    /// thread waits a full `gc_interval` before checking again instead of re-summing the
    /// registry under its lock on every poll tick.
    fn perform_gc_internal(
        config: &GCConfig,
        objects: &Arc<Mutex<HashMap<u64, GCEntry>>>,
        stats: &Arc<RwLock<GCStats>>,
        last_gc: &Arc<Mutex<Instant>>,
        force: bool,
    ) -> RusTorchResult<()> {
        let start_time = Instant::now();
        // `Some((bytes_freed, objects_freed, is_young))` when a pass actually ran.
        let outcome = {
            let mut objects_guard = objects.lock().unwrap();
            let current_memory: usize = objects_guard.values().map(|e| e.object.size()).sum();
            if !force && current_memory < config.memory_threshold {
                None
            } else if config.enable_generational {
                let (bytes, count, swept_old) =
                    Self::perform_generational_gc(&mut objects_guard, config);
                // A generational pass is "young" only if the old generation was untouched.
                Some((bytes, count, !swept_old))
            } else {
                let (bytes, count) = Self::perform_full_gc(&mut objects_guard);
                Some((bytes, count, false))
            }
        };
        if let Some((memory_freed, objects_freed, is_young)) = outcome {
            stats.write().unwrap().update_collection(
                start_time.elapsed(),
                memory_freed,
                objects_freed,
                is_young,
            );
        }
        *last_gc.lock().unwrap() = Instant::now();
        Ok(())
    }

    /// Generational pass: reclaims idle young objects, ages and promotes the survivors,
    /// and sweeps idle old objects when the old generation exceeds half of
    /// `config.memory_threshold`.
    ///
    /// Returns `(bytes_freed, objects_freed, old_generation_swept)`.
    fn perform_generational_gc(
        objects: &mut HashMap<u64, GCEntry>,
        config: &GCConfig,
    ) -> (usize, u64, bool) {
        let mut memory_freed = 0usize;
        let mut objects_freed = 0u64;
        objects.retain(|_, entry| {
            if entry.generation != 0 {
                return true;
            }
            if entry.object.last_accessed().elapsed() > Self::YOUNG_MAX_IDLE {
                memory_freed += entry.object.size();
                objects_freed += 1;
                return false;
            }
            // Promotion uses the survival count from *before* this pass is counted.
            if entry.object.survival_count() >= config.young_generation_limit {
                entry.generation = 1;
            }
            entry.object.increment_survival();
            true
        });
        // Includes objects promoted just above.
        let old_memory: usize = objects
            .values()
            .filter(|e| e.generation == 1)
            .map(|e| e.object.size())
            .sum();
        let sweep_old = old_memory > config.memory_threshold / 2;
        if sweep_old {
            let (bytes, count) = Self::perform_old_generation_gc(objects);
            memory_freed += bytes;
            objects_freed += count;
        }
        (memory_freed, objects_freed, sweep_old)
    }

    /// Reclaims old-generation objects idle for longer than [`Self::OLD_MAX_IDLE`].
    /// Returns `(bytes_freed, objects_freed)`.
    fn perform_old_generation_gc(objects: &mut HashMap<u64, GCEntry>) -> (usize, u64) {
        Self::sweep_idle(objects, Self::OLD_MAX_IDLE, |entry| entry.generation == 1)
    }

    /// Reclaims every object idle for longer than [`Self::FULL_MAX_IDLE`], regardless of
    /// generation. Returns `(bytes_freed, objects_freed)`.
    fn perform_full_gc(objects: &mut HashMap<u64, GCEntry>) -> (usize, u64) {
        Self::sweep_idle(objects, Self::FULL_MAX_IDLE, |_| true)
    }

    /// Removes, in one in-place pass, every entry accepted by `eligible` whose object has
    /// been idle for longer than `max_idle`. Returns `(bytes_freed, objects_freed)`.
    fn sweep_idle(
        objects: &mut HashMap<u64, GCEntry>,
        max_idle: Duration,
        eligible: impl Fn(&GCEntry) -> bool,
    ) -> (usize, u64) {
        let mut memory_freed = 0usize;
        let mut objects_freed = 0u64;
        objects.retain(|_, entry| {
            let reclaim =
                eligible(&*entry) && entry.object.last_accessed().elapsed() > max_idle;
            if reclaim {
                memory_freed += entry.object.size();
                objects_freed += 1;
            }
            !reclaim
        });
        (memory_freed, objects_freed)
    }
}
impl Drop for GarbageCollector {
    /// Stops the background auto-GC thread, if one was started, and waits for it to exit.
    ///
    /// The worker only checks the shutdown flag between sleeps, so this can block for
    /// roughly one poll interval plus any collection pass already in progress.
    fn drop(&mut self) {
        *self.shutdown_signal.lock().unwrap() = true;
        if let Some(worker) = self.gc_thread.take() {
            // A worker that panicked has nothing left to clean up; its result is discarded.
            worker.join().ok();
        }
    }
}
/// Process-wide collector installed once by `init_global_gc` and used by the `gc_*`
/// helpers. Values in a `static` are never dropped, so this collector's `Drop` (which stops
/// its background thread) does not run at process exit.
static GLOBAL_GC: OnceLock<Arc<Mutex<GarbageCollector>>> = OnceLock::new();
/// Creates the process-wide collector from `config` and starts its background thread
/// (when `config.enable_auto_gc` is set).
///
/// Only the first call installs a collector; later calls return `Ok(())` without effect
/// and their `config` is ignored.
///
/// # Errors
/// Propagates any error from `GarbageCollector::start_auto_gc`.
pub fn init_global_gc(config: GCConfig) -> RusTorchResult<()> {
    // Fast path: don't build a collector and spawn (then immediately join) a worker
    // thread just to lose the `set` below.
    if GLOBAL_GC.get().is_some() {
        return Ok(());
    }
    let mut gc = GarbageCollector::new(config);
    gc.start_auto_gc()?;
    // Another thread may have initialised the global meanwhile; `set` then hands our
    // collector back and dropping it stops its worker thread.
    let _ = GLOBAL_GC.set(Arc::new(Mutex::new(gc)));
    Ok(())
}
/// Returns a shared handle to the process-wide collector, or `None` if `init_global_gc`
/// has not installed one yet.
pub fn get_global_gc() -> Option<Arc<Mutex<GarbageCollector>>> {
    GLOBAL_GC.get().map(Arc::clone)
}
/// Registers `object` with the global collector and returns its id.
///
/// Returns `None` (and drops `object`) if the global collector is not initialised or its
/// lock is poisoned.
pub fn gc_register(object: Box<dyn GCObject>) -> Option<u64> {
    let collector = get_global_gc()?;
    let id = collector.lock().ok()?.register_object(object);
    Some(id)
}
/// Removes `id` from the global collector.
///
/// Returns `false` when the object was not registered, the global collector is not
/// initialised, or its lock is poisoned.
pub fn gc_unregister(id: u64) -> bool {
    let Some(collector) = get_global_gc() else {
        return false;
    };
    collector.lock().is_ok_and(|gc| gc.unregister_object(id))
}
/// Runs `collect` on the global collector while holding its lock.
///
/// Returns `Ok(())` without collecting when the global collector is not initialised or
/// its lock is poisoned.
pub fn gc_collect() -> RusTorchResult<()> {
    let Some(collector) = get_global_gc() else {
        return Ok(());
    };
    collector.lock().map_or(Ok(()), |gc| gc.collect())
}
/// Snapshot of the global collector's statistics, or `None` if it is not initialised or
/// its lock is poisoned.
pub fn gc_stats() -> Option<GCStats> {
    get_global_gc().and_then(|collector| collector.lock().ok().map(|gc| gc.get_stats()))
}