use super::types::*;
use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tokio::fs;
/// Asynchronous storage backend for MapReduce checkpoints.
///
/// Implementors must be `Send + Sync`. Every operation is fallible and
/// reports failures through `anyhow::Result`.
#[async_trait::async_trait]
pub trait CheckpointStorage: Send + Sync {
/// Persists `checkpoint` in the backend.
async fn save_checkpoint(&self, checkpoint: &MapReduceCheckpoint) -> Result<()>;
/// Loads the checkpoint identified by `checkpoint_id`.
async fn load_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<MapReduceCheckpoint>;
/// Returns summary information for the stored checkpoints of job `job_id`.
async fn list_checkpoints(&self, job_id: &str) -> Result<Vec<CheckpointInfo>>;
/// Removes the checkpoint identified by `checkpoint_id`.
async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()>;
/// Reports whether a checkpoint identified by `checkpoint_id` is stored.
async fn checkpoint_exists(&self, checkpoint_id: &CheckpointId) -> Result<bool>;
}
/// Compression scheme applied to serialized checkpoint bytes.
///
/// Each variant also determines the file extension reported by
/// `CompressionAlgorithm::extension`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
/// No compression: bytes are passed through unchanged (extension `json`).
None,
/// Gzip compression via the `flate2` crate (extension `gz`).
Gzip,
/// Zstandard compression via the `zstd` crate (extension `zst`).
Zstd,
/// LZ4 block compression via the `lz4` crate (extension `lz4`).
Lz4,
}
impl CompressionAlgorithm {
    /// Compression level handed to `zstd::encode_all`.
    const ZSTD_LEVEL: i32 = 3;

    /// Returns the file extension (without a leading dot) used for data
    /// stored with this algorithm.
    ///
    /// The returned string is a literal, so it is `'static` and does not
    /// borrow from `self`.
    pub fn extension(&self) -> &'static str {
        match self {
            Self::None => "json",
            Self::Gzip => "gz",
            Self::Zstd => "zst",
            Self::Lz4 => "lz4",
        }
    }

    /// Compresses `data` with this algorithm and returns the encoded bytes.
    ///
    /// `None` returns a plain copy of the input.
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the algorithm name, when the
    /// underlying encoder fails.
    pub async fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        match self {
            Self::None => Ok(data.to_vec()),
            Self::Gzip => {
                use flate2::write::GzEncoder;
                use flate2::Compression;
                use std::io::Write;

                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
                encoder
                    .write_all(data)
                    .context("Failed to compress with gzip")?;
                encoder.finish().context("Failed to compress with gzip")
            }
            Self::Zstd => zstd::encode_all(data, Self::ZSTD_LEVEL)
                .context("Failed to compress with zstd"),
            Self::Lz4 => {
                // `true` prepends the uncompressed size so that `decompress`
                // can be called without knowing the original length.
                lz4::block::compress(data, None, true).context("Failed to compress with lz4")
            }
        }
    }

    /// Decompresses `data` that was produced by [`Self::compress`] with the
    /// same algorithm.
    ///
    /// `None` returns a plain copy of the input.
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the algorithm name, when the input
    /// cannot be decoded.
    pub async fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
        match self {
            Self::None => Ok(data.to_vec()),
            Self::Gzip => {
                use flate2::read::GzDecoder;
                use std::io::Read;

                let mut decoder = GzDecoder::new(data);
                let mut result = Vec::new();
                decoder
                    .read_to_end(&mut result)
                    .context("Failed to decompress with gzip")?;
                Ok(result)
            }
            Self::Zstd => zstd::decode_all(data).context("Failed to decompress with zstd"),
            Self::Lz4 => {
                // No explicit size: it is taken from the prefix written by `compress`.
                lz4::block::decompress(data, None).context("Failed to decompress with lz4")
            }
        }
    }
}
impl Default for CompressionAlgorithm {
fn default() -> Self {
Self::Gzip
}
}
/// Checkpoint storage that keeps each checkpoint as one file under a base
/// directory, optionally compressed.
pub struct FileCheckpointStorage {
/// Directory that checkpoint files are written to and read from.
base_path: PathBuf,
/// Algorithm applied to checkpoint bytes; also selects the file extension.
compression_algorithm: CompressionAlgorithm,
}
impl FileCheckpointStorage {
    /// Creates a storage rooted at `base_path`.
    ///
    /// When `compression_enabled` is `true` the default
    /// [`CompressionAlgorithm`] is used; otherwise data is stored
    /// uncompressed.
    pub fn new(base_path: PathBuf, compression_enabled: bool) -> Self {
        let algorithm = if compression_enabled {
            CompressionAlgorithm::default()
        } else {
            CompressionAlgorithm::None
        };
        Self::with_compression(base_path, algorithm)
    }

    /// Creates a storage rooted at `base_path` that uses exactly `algorithm`.
    pub fn with_compression(base_path: PathBuf, algorithm: CompressionAlgorithm) -> Self {
        Self {
            base_path,
            compression_algorithm: algorithm,
        }
    }

    /// Builds the on-disk location for `checkpoint_id`:
    /// `<base_path>/<checkpoint_id>.checkpoint.<extension>`.
    fn checkpoint_path(&self, checkpoint_id: &CheckpointId) -> PathBuf {
        let file_name = format!(
            "{}.checkpoint.{}",
            checkpoint_id,
            self.compression_algorithm.extension()
        );
        self.base_path.join(file_name)
    }

    /// Compresses `data` with the configured algorithm.
    async fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
        let algorithm = self.compression_algorithm;
        algorithm.compress(data).await
    }

    /// Decompresses `data` with the configured algorithm.
    async fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
        let algorithm = self.compression_algorithm;
        algorithm.decompress(data).await
    }
}
#[async_trait::async_trait]
impl CheckpointStorage for FileCheckpointStorage {
async fn save_checkpoint(&self, checkpoint: &MapReduceCheckpoint) -> Result<()> {
let checkpoint_id = CheckpointId::from_string(checkpoint.metadata.checkpoint_id.clone());
let path = self.checkpoint_path(&checkpoint_id);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?;
}
let json = serde_json::to_vec_pretty(checkpoint)?;
let data = self.compress_data(&json).await?;
let temp_path = path.with_extension("tmp");
fs::write(&temp_path, &data).await?;
fs::rename(&temp_path, &path).await?;
Ok(())
}
async fn load_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<MapReduceCheckpoint> {
let path = self.checkpoint_path(checkpoint_id);
if !path.exists() {
return Err(anyhow!("Checkpoint {} not found", checkpoint_id));
}
let data = fs::read(&path).await?;
let json = self.decompress_data(&data).await?;
let checkpoint: MapReduceCheckpoint = serde_json::from_slice(&json)?;
Ok(checkpoint)
}
async fn list_checkpoints(&self, job_id: &str) -> Result<Vec<CheckpointInfo>> {
let mut checkpoints = Vec::new();
if !self.base_path.exists() {
return Ok(checkpoints);
}
let mut entries = fs::read_dir(&self.base_path).await?;
while let Some(entry) = entries.next_entry().await? {
if let Some(name) = entry.file_name().to_str() {
if name.contains(".checkpoint") {
if let Some(checkpoint_id_str) = name.split('.').next() {
let checkpoint_id =
CheckpointId::from_string(checkpoint_id_str.to_string());
if let Ok(checkpoint) = self.load_checkpoint(&checkpoint_id).await {
if checkpoint.metadata.job_id == job_id {
checkpoints.push(CheckpointInfo {
id: checkpoint.metadata.checkpoint_id,
job_id: checkpoint.metadata.job_id,
created_at: checkpoint.metadata.created_at,
phase: checkpoint.metadata.phase,
completed_items: checkpoint.metadata.completed_items,
total_items: checkpoint.metadata.total_work_items,
is_final: checkpoint.metadata.phase == PhaseType::Complete,
});
}
}
}
}
}
}
Ok(checkpoints)
}
async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()> {
let path = self.checkpoint_path(checkpoint_id);
if path.exists() {
fs::remove_file(path).await?;
}
Ok(())
}
async fn checkpoint_exists(&self, checkpoint_id: &CheckpointId) -> Result<bool> {
Ok(self.checkpoint_path(checkpoint_id).exists())
}
}