use crate::fsutil;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// Current on-disk schema version, stamped into `ParallelStateFile::schema_version`.
/// Bump this whenever the serialized layout changes so `migrate_state` can detect
/// files written by older builds.
pub const PARALLEL_STATE_SCHEMA_VERSION: u32 = 3;

/// Lifecycle phase of a single parallel worker.
///
/// Serialized in snake_case (e.g. `BlockedPush` -> `"blocked_push"`).
/// `Running` is the `Default`, which is also what `#[serde(default)]` on
/// `WorkerRecord::lifecycle` yields when the field is absent from disk.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
#[serde(rename_all = "snake_case")]
pub enum WorkerLifecycle {
    /// Worker is actively executing its task (initial state).
    #[default]
    Running,
    /// Worker finished its task and its result is being integrated.
    Integrating,
    /// Terminal: worker finished successfully.
    Completed,
    /// Terminal: worker failed.
    Failed,
    /// Terminal: worker's push was blocked (see `push_attempts`/`last_error`).
    BlockedPush,
}
/// Persistent record for one parallel worker, stored inside `ParallelStateFile`.
///
/// Fields marked `#[serde(default)]` tolerate older state files that predate
/// them; `Option` fields are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerRecord {
    /// Unique task identifier; used as the lookup key in `ParallelStateFile`.
    pub task_id: String,
    /// Filesystem path of the worker's workspace.
    pub workspace_path: PathBuf,
    /// Current lifecycle phase; defaults to `Running` when missing on disk.
    #[serde(default)]
    pub lifecycle: WorkerLifecycle,
    /// Timestamp string recorded when the worker started.
    /// NOTE(review): format not enforced here — presumably RFC 3339; confirm at call sites.
    pub started_at: String,
    /// Timestamp set when the worker reaches a terminal state (completed/failed/blocked).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<String>,
    /// Number of push attempts made so far (see `increment_push_attempt`).
    #[serde(default)]
    pub push_attempts: u32,
    /// Last error message, set by `mark_failed`/`mark_blocked`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_error: Option<String>,
}
impl WorkerRecord {
pub fn new(task_id: impl Into<String>, workspace_path: PathBuf, started_at: String) -> Self {
Self {
task_id: task_id.into(),
workspace_path,
lifecycle: WorkerLifecycle::Running,
started_at,
completed_at: None,
push_attempts: 0,
last_error: None,
}
}
pub fn start_integration(&mut self) {
self.lifecycle = WorkerLifecycle::Integrating;
}
pub fn mark_completed(&mut self, timestamp: String) {
self.lifecycle = WorkerLifecycle::Completed;
self.completed_at = Some(timestamp);
}
pub fn mark_failed(&mut self, timestamp: String, error: impl Into<String>) {
self.lifecycle = WorkerLifecycle::Failed;
self.completed_at = Some(timestamp);
self.last_error = Some(error.into());
}
pub fn mark_blocked(&mut self, timestamp: String, error: impl Into<String>) {
self.lifecycle = WorkerLifecycle::BlockedPush;
self.completed_at = Some(timestamp);
self.last_error = Some(error.into());
}
pub fn increment_push_attempt(&mut self) {
self.push_attempts += 1;
}
pub fn is_terminal(&self) -> bool {
matches!(
self.lifecycle,
WorkerLifecycle::Completed | WorkerLifecycle::Failed | WorkerLifecycle::BlockedPush
)
}
}
/// Top-level structure of the on-disk parallel state file.
///
/// Every field carries a serde default so partially-written or older files
/// still deserialize; schema mismatches are then handled by `migrate_state`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStateFile {
    /// Schema version of the file; files lacking the field are treated as v1
    /// via `default_schema_version` so they get migrated.
    #[serde(default = "default_schema_version")]
    pub schema_version: u32,
    /// Timestamp string recorded when the parallel run started.
    #[serde(default)]
    pub started_at: String,
    /// Branch that worker results are integrated into.
    #[serde(default)]
    pub target_branch: String,
    /// One record per worker, keyed (uniquely, by construction) on `task_id`.
    #[serde(default)]
    pub workers: Vec<WorkerRecord>,
}
/// Schema version assumed for state files that predate the `schema_version`
/// field. Deliberately returns 1 — NOT `PARALLEL_STATE_SCHEMA_VERSION` — so
/// such files compare as stale and `migrate_state` resets them.
fn default_schema_version() -> u32 {
    1
}
impl ParallelStateFile {
pub fn new(started_at: impl Into<String>, target_branch: impl Into<String>) -> Self {
Self {
schema_version: PARALLEL_STATE_SCHEMA_VERSION,
started_at: started_at.into(),
target_branch: target_branch.into(),
workers: Vec::new(),
}
}
pub fn upsert_worker(&mut self, record: WorkerRecord) {
if let Some(existing) = self
.workers
.iter_mut()
.find(|w| w.task_id == record.task_id)
{
*existing = record;
} else {
self.workers.push(record);
}
}
pub fn remove_worker(&mut self, task_id: &str) {
self.workers.retain(|w| w.task_id != task_id);
}
pub fn get_worker(&self, task_id: &str) -> Option<&WorkerRecord> {
self.workers.iter().find(|w| w.task_id == task_id)
}
pub fn get_worker_mut(&mut self, task_id: &str) -> Option<&mut WorkerRecord> {
self.workers.iter_mut().find(|w| w.task_id == task_id)
}
pub fn has_worker(&self, task_id: &str) -> bool {
self.workers.iter().any(|w| w.task_id == task_id)
}
pub fn workers_by_lifecycle(
&self,
lifecycle: WorkerLifecycle,
) -> impl Iterator<Item = &WorkerRecord> {
self.workers
.iter()
.filter(move |w| w.lifecycle == lifecycle)
}
pub fn active_worker_count(&self) -> usize {
self.workers.iter().filter(|w| !w.is_terminal()).count()
}
pub fn blocked_worker_count(&self) -> usize {
self.workers_by_lifecycle(WorkerLifecycle::BlockedPush)
.count()
}
}
/// Absolute path of the parallel state file under `repo_root`
/// (`<repo_root>/.ralph/cache/parallel/state.json`).
///
/// Joins component-by-component instead of embedding `/` in a single literal
/// so the platform's native separator is used consistently in the result.
pub fn state_file_path(repo_root: &Path) -> PathBuf {
    repo_root
        .join(".ralph")
        .join("cache")
        .join("parallel")
        .join("state.json")
}
/// Upgrade `state` to the current schema version.
///
/// There is no per-version migration path: a stale file is logged and its
/// worker list is discarded wholesale, keeping only the top-level metadata.
fn migrate_state(mut state: ParallelStateFile) -> ParallelStateFile {
    // Already current (or newer): nothing to do.
    if state.schema_version >= PARALLEL_STATE_SCHEMA_VERSION {
        return state;
    }
    log::info!(
        "Migrating parallel state from schema v{} to v{}",
        state.schema_version,
        PARALLEL_STATE_SCHEMA_VERSION
    );
    state.schema_version = PARALLEL_STATE_SCHEMA_VERSION;
    state.workers.clear();
    state
}
/// Load and migrate the parallel state file at `path`.
///
/// Returns `Ok(None)` when the file does not exist; errors carry context
/// naming the path for read failures, and parse errors come from the JSONC
/// parser.
pub fn load_state(path: &Path) -> Result<Option<ParallelStateFile>> {
    // A missing file is not an error — it just means no parallel run yet.
    if !path.exists() {
        return Ok(None);
    }
    let raw = std::fs::read_to_string(path)
        .with_context(|| format!("read parallel state {}", path.display()))?;
    let parsed = crate::jsonc::parse_jsonc::<ParallelStateFile>(&raw, "parallel state")?;
    Ok(Some(migrate_state(parsed)))
}
/// Serialize `state` as pretty JSON and write it atomically to `path`,
/// creating the parent directory first if needed.
pub fn save_state(path: &Path, state: &ParallelStateFile) -> Result<()> {
    // Ensure the containing directory exists (no-op when `path` has no parent).
    path.parent()
        .map(|dir| {
            std::fs::create_dir_all(dir)
                .with_context(|| format!("create parallel state dir {}", dir.display()))
        })
        .transpose()?;
    let rendered = serde_json::to_string_pretty(state).context("serialize parallel state")?;
    // Atomic write so readers never observe a half-written state file.
    fsutil::write_atomic(path, rendered.as_bytes())
        .with_context(|| format!("write parallel state {}", path.display()))?;
    Ok(())
}
#[cfg(test)]
#[path = "state/tests.rs"]
mod tests;