use super::async_file::AsyncSpillFile;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Mutex;
pub struct AsyncSpillManager {
spill_dir: PathBuf,
next_file_id: AtomicU64,
active_files: Mutex<Vec<PathBuf>>,
total_spilled_bytes: AtomicU64,
}
impl AsyncSpillManager {
pub async fn new(spill_dir: impl Into<PathBuf>) -> std::io::Result<Self> {
let spill_dir = spill_dir.into();
tokio::fs::create_dir_all(&spill_dir).await?;
Ok(Self {
spill_dir,
next_file_id: AtomicU64::new(0),
active_files: Mutex::new(Vec::new()),
total_spilled_bytes: AtomicU64::new(0),
})
}
pub async fn with_temp_dir() -> std::io::Result<Self> {
let temp_dir = std::env::temp_dir().join("grafeo_spill");
Self::new(temp_dir).await
}
#[must_use]
pub fn spill_dir(&self) -> &Path {
&self.spill_dir
}
pub async fn create_file(&self, prefix: &str) -> std::io::Result<AsyncSpillFile> {
let file_id = self.next_file_id.fetch_add(1, Ordering::Relaxed);
let file_name = format!("{prefix}_{file_id}.spill");
let file_path = self.spill_dir.join(file_name);
self.active_files.lock().await.push(file_path.clone());
AsyncSpillFile::new(file_path).await
}
pub fn register_spilled_bytes(&self, bytes: u64) {
self.total_spilled_bytes.fetch_add(bytes, Ordering::Relaxed);
}
pub fn unregister_spilled_bytes(&self, bytes: u64) {
self.total_spilled_bytes.fetch_sub(bytes, Ordering::Relaxed);
}
pub async fn unregister_file(&self, path: &Path) {
let mut files = self.active_files.lock().await;
files.retain(|p| p != path);
}
#[must_use]
pub fn spilled_bytes(&self) -> u64 {
self.total_spilled_bytes.load(Ordering::Relaxed)
}
pub async fn active_file_count(&self) -> usize {
self.active_files.lock().await.len()
}
pub async fn cleanup(&self) -> std::io::Result<()> {
let files = std::mem::take(&mut *self.active_files.lock().await);
let mut last_error = None;
for path in files {
if let Err(e) = tokio::fs::remove_file(&path).await {
if e.kind() != std::io::ErrorKind::NotFound {
last_error = Some(e);
}
}
}
self.total_spilled_bytes.store(0, Ordering::Relaxed);
match last_error {
Some(e) => Err(e),
None => Ok(()),
}
}
fn cleanup_sync(&self) {
if let Ok(mut files) = self.active_files.try_lock() {
let files_to_remove = std::mem::take(&mut *files);
for path in files_to_remove {
let _ = std::fs::remove_file(&path);
}
}
}
}
impl Drop for AsyncSpillManager {
fn drop(&mut self) {
self.cleanup_sync();
}
}
impl std::fmt::Debug for AsyncSpillManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncSpillManager")
.field("spill_dir", &self.spill_dir)
.field("spilled_bytes", &self.spilled_bytes())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_async_manager_creation() {
let temp_dir = TempDir::new().unwrap();
let manager = AsyncSpillManager::new(temp_dir.path()).await.unwrap();
assert_eq!(manager.spilled_bytes(), 0);
assert_eq!(manager.active_file_count().await, 0);
assert_eq!(manager.spill_dir(), temp_dir.path());
}
#[tokio::test]
async fn test_async_create_spill_file() {
let temp_dir = TempDir::new().unwrap();
let manager = AsyncSpillManager::new(temp_dir.path()).await.unwrap();
let file1 = manager.create_file("sort").await.unwrap();
let file2 = manager.create_file("sort").await.unwrap();
let file3 = manager.create_file("agg").await.unwrap();
assert_eq!(manager.active_file_count().await, 3);
assert_ne!(file1.path(), file2.path());
assert!(file1.path().to_str().unwrap().contains("sort_0"));
assert!(file2.path().to_str().unwrap().contains("sort_1"));
assert!(file3.path().to_str().unwrap().contains("agg_2"));
}
#[tokio::test]
async fn test_async_cleanup() {
let temp_dir = TempDir::new().unwrap();
let manager = AsyncSpillManager::new(temp_dir.path()).await.unwrap();
let _file1 = manager.create_file("test").await.unwrap();
let _file2 = manager.create_file("test").await.unwrap();
assert_eq!(manager.active_file_count().await, 2);
manager.cleanup().await.unwrap();
assert_eq!(manager.active_file_count().await, 0);
}
#[tokio::test]
async fn test_async_spilled_bytes_tracking() {
let temp_dir = TempDir::new().unwrap();
let manager = AsyncSpillManager::new(temp_dir.path()).await.unwrap();
manager.register_spilled_bytes(1000);
manager.register_spilled_bytes(500);
assert_eq!(manager.spilled_bytes(), 1500);
manager.unregister_spilled_bytes(300);
assert_eq!(manager.spilled_bytes(), 1200);
}
#[tokio::test]
async fn test_async_cleanup_on_drop() {
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().to_path_buf();
let file_path = {
let manager = AsyncSpillManager::new(&temp_path).await.unwrap();
let file = manager.create_file("test").await.unwrap();
file.path().to_path_buf()
};
assert!(!file_path.exists());
}
}