use super::checkpoint::{CheckpointStore, SessionCheckpoint};
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
use tokio::sync::RwLock;
/// Record describing a single fork of a session.
///
/// Instances are created with [`ForkInfo::new`], stored in a
/// [`ForkRegistry`] keyed by `fork_id`, and (de)serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForkInfo {
/// Identifier of this fork; `ForkInfo::new` fills it with a random v4 UUID string.
pub fork_id: String,
/// Identifier of the session the fork belongs to; `ForkRegistry` groups forks by this value.
pub parent_session_id: String,
/// Identifier of the checkpoint the fork refers to; `ForkManager` hands it to `CheckpointStore::load`.
pub checkpoint_id: String,
/// Optional caller-supplied branch label.
pub branch_name: Option<String>,
/// Creation timestamp (UTC); `ForkInfo::new` sets it to the current time.
pub created_at: DateTime<Utc>,
/// Current lifecycle status of the fork.
pub status: ForkStatus,
/// Free-form JSON values keyed by string. Starts empty in `ForkInfo::new`;
/// nothing else in this file reads or writes it.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Lifecycle status of a fork.
///
/// The `Display` impl renders each variant as its lowercase name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ForkStatus {
/// Initial status assigned by `ForkInfo::new`.
Creating,
/// Status set by `ForkManager::create_fork` before the fork is registered.
Active,
/// Not assigned by any non-test code in this file; presumably set by callers
/// once the fork has been merged back — TODO confirm.
Merged,
/// Status set by `ForkManager::abandon_fork`.
Abandoned,
/// Status set by `ForkManager::archive_fork`.
Archived,
}
/// Renders the status as its lowercase name (`"creating"`, `"active"`, ...).
///
/// The text is emitted through [`std::fmt::Formatter::pad`], so width, fill,
/// alignment and precision given in the format string are honoured, e.g.
/// `format!("{:<10}", status)` for column-aligned output. Plain `{}` output
/// is the bare lowercase name.
impl std::fmt::Display for ForkStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let label = match self {
            ForkStatus::Creating => "creating",
            ForkStatus::Active => "active",
            ForkStatus::Merged => "merged",
            ForkStatus::Abandoned => "abandoned",
            ForkStatus::Archived => "archived",
        };
        // `pad` (unlike `write!`/`write_str`) applies the caller's format flags.
        f.pad(label)
    }
}
impl ForkInfo {
pub fn new(
parent_session_id: impl Into<String>,
checkpoint_id: impl Into<String>,
branch_name: Option<String>,
) -> Self {
Self {
fork_id: uuid::Uuid::new_v4().to_string(),
parent_session_id: parent_session_id.into(),
checkpoint_id: checkpoint_id.into(),
branch_name,
created_at: Utc::now(),
status: ForkStatus::Creating,
metadata: HashMap::new(),
}
}
pub fn is_active(&self) -> bool {
self.status == ForkStatus::Active
}
pub fn is_usable(&self) -> bool {
matches!(self.status, ForkStatus::Active | ForkStatus::Creating)
}
}
/// In-memory index of forks, serializable as a whole.
///
/// `ForkManager` persists this structure as JSON in `forks.json`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ForkRegistry {
/// Every known fork, keyed by its `fork_id`.
pub forks: HashMap<String, ForkInfo>,
/// For each parent session id, the ids of the forks registered under it,
/// in registration order.
pub children: HashMap<String, Vec<String>>,
}
impl ForkRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `fork` to the registry and links it under its parent session.
    ///
    /// Registering an id that is already present replaces the stored record.
    /// The child index stays consistent in that case: the id is never listed
    /// twice under one parent, and it is unlinked from the previous parent
    /// when the parent session changed.
    pub fn register(&mut self, fork: ForkInfo) {
        let fork_id = fork.fork_id.clone();
        let parent_id = fork.parent_session_id.clone();

        if let Some(previous) = self.forks.insert(fork_id.clone(), fork) {
            // Re-registration under a different parent: drop the stale link so
            // the old parent no longer reports this fork as its child.
            if previous.parent_session_id != parent_id {
                self.detach_child(&previous.parent_session_id, &fork_id);
            }
        }

        let siblings = self.children.entry(parent_id).or_default();
        if !siblings.contains(&fork_id) {
            siblings.push(fork_id);
        }
    }

    /// Looks up a fork by id.
    pub fn get(&self, fork_id: &str) -> Option<&ForkInfo> {
        self.forks.get(fork_id)
    }

    /// Looks up a fork by id for in-place modification.
    pub fn get_mut(&mut self, fork_id: &str) -> Option<&mut ForkInfo> {
        self.forks.get_mut(fork_id)
    }

    /// Returns all forks registered under `parent_session_id`, in
    /// registration order. Unknown parents yield an empty vector; child ids
    /// without a matching fork record are skipped.
    pub fn get_children(&self, parent_session_id: &str) -> Vec<&ForkInfo> {
        match self.children.get(parent_session_id) {
            Some(ids) => ids.iter().filter_map(|id| self.forks.get(id)).collect(),
            None => Vec::new(),
        }
    }

    /// Like [`ForkRegistry::get_children`], restricted to forks whose
    /// `is_active()` is true.
    pub fn get_active_children(&self, parent_session_id: &str) -> Vec<&ForkInfo> {
        self.get_children(parent_session_id)
            .into_iter()
            .filter(|f| f.is_active())
            .collect()
    }

    /// Sets the status of `fork_id`.
    ///
    /// Returns `Some(())` when the fork exists and `None` when it does not.
    pub fn update_status(&mut self, fork_id: &str, status: ForkStatus) -> Option<()> {
        let fork = self.forks.get_mut(fork_id)?;
        fork.status = status;
        Some(())
    }

    /// Removes `fork_id` from the registry and from its parent's child list,
    /// returning the removed record (or `None` if the id was unknown).
    ///
    /// A parent whose last child is removed is dropped from `children`, so the
    /// index does not accumulate empty lists.
    pub fn remove(&mut self, fork_id: &str) -> Option<ForkInfo> {
        let fork = self.forks.remove(fork_id)?;
        self.detach_child(&fork.parent_session_id, fork_id);
        Some(fork)
    }

    /// Returns every registered fork, in unspecified (hash map) order.
    pub fn list_all(&self) -> Vec<&ForkInfo> {
        self.forks.values().collect()
    }

    /// Unlinks `fork_id` from `parent_session_id`'s child list and removes the
    /// list entirely once it becomes empty.
    fn detach_child(&mut self, parent_session_id: &str, fork_id: &str) {
        if let Some(ids) = self.children.get_mut(parent_session_id) {
            ids.retain(|id| id != fork_id);
            if ids.is_empty() {
                self.children.remove(parent_session_id);
            }
        }
    }
}
/// Async front end over a [`ForkRegistry`] that persists it to disk.
pub struct ForkManager {
/// Directory that holds the persisted registry file `forks.json`.
base_path: PathBuf,
/// In-memory registry, guarded by a tokio `RwLock`.
registry: RwLock<ForkRegistry>,
/// Optional checkpoint store; when `None`, checkpoints are neither
/// checked on fork creation nor loadable through the manager.
checkpoint_store: Option<CheckpointStore>,
}
impl ForkManager {
    /// Opens a fork manager rooted at `base_path`, without a checkpoint store.
    ///
    /// The directory is created if it does not exist, and a previously
    /// persisted `forks.json` inside it is loaded when present.
    ///
    /// # Errors
    /// Fails if the directory cannot be created, or if `forks.json` exists but
    /// cannot be read or parsed.
    pub async fn new(base_path: impl AsRef<Path>) -> Result<Self> {
        let base = base_path.as_ref().to_path_buf();
        fs::create_dir_all(&base).await?;
        let manager = Self {
            base_path: base,
            registry: RwLock::new(ForkRegistry::new()),
            checkpoint_store: None,
        };
        manager.load_registry().await?;
        Ok(manager)
    }

    /// Same as [`ForkManager::new`], but attaches `checkpoint_store` so that
    /// checkpoints are checked on fork creation and can be loaded per fork.
    ///
    /// # Errors
    /// Propagates every error of [`ForkManager::new`].
    pub async fn with_checkpoint_store(
        base_path: impl AsRef<Path>,
        checkpoint_store: CheckpointStore,
    ) -> Result<Self> {
        let mut manager = Self::new(&base_path).await?;
        manager.checkpoint_store = Some(checkpoint_store);
        Ok(manager)
    }

    /// Creates an `Active` fork of `parent_session_id` at `checkpoint_id`,
    /// registers it and persists the registry.
    ///
    /// When a checkpoint store is attached, the checkpoint is loaded first and
    /// a load failure aborts the operation; the loaded value is discarded.
    ///
    /// # Errors
    /// Fails if the checkpoint cannot be loaded or the registry cannot be
    /// written. In the latter case the fork is already present in memory.
    pub async fn create_fork(
        &self,
        parent_session_id: &str,
        checkpoint_id: &str,
        branch_name: Option<String>,
    ) -> Result<ForkInfo> {
        if let Some(ref store) = self.checkpoint_store {
            store.load(checkpoint_id).await?;
        }

        let mut fork = ForkInfo::new(parent_session_id, checkpoint_id, branch_name);
        fork.status = ForkStatus::Active;

        {
            // Scope the write guard so it is released before persisting.
            let mut registry = self.registry.write().await;
            registry.register(fork.clone());
        }
        self.save_registry().await?;

        tracing::info!(
            fork_id = %fork.fork_id,
            parent = %parent_session_id,
            checkpoint = %checkpoint_id,
            "Fork created"
        );
        Ok(fork)
    }

    /// Returns a copy of the fork with the given id, if it is registered.
    pub async fn get_fork(&self, fork_id: &str) -> Option<ForkInfo> {
        let registry = self.registry.read().await;
        registry.get(fork_id).cloned()
    }

    /// Returns copies of all forks registered under `session_id`.
    pub async fn list_forks(&self, session_id: &str) -> Vec<ForkInfo> {
        let registry = self.registry.read().await;
        registry
            .get_children(session_id)
            .into_iter()
            .cloned()
            .collect()
    }

    /// Sets the status of `fork_id` and persists the registry.
    ///
    /// An unknown `fork_id` is not an error: a warning is logged, nothing is
    /// written to disk and `Ok(())` is returned.
    ///
    /// # Errors
    /// Fails if the registry cannot be written.
    pub async fn update_fork_status(&self, fork_id: &str, status: ForkStatus) -> Result<()> {
        let updated = {
            let mut registry = self.registry.write().await;
            registry.update_status(fork_id, status).is_some()
        };
        if !updated {
            // Nothing changed, so there is nothing to persist.
            tracing::warn!(
                fork_id = %fork_id,
                status = %status,
                "Fork not found; status not updated"
            );
            return Ok(());
        }
        self.save_registry().await
    }

    /// Marks the fork as `Abandoned`; see [`ForkManager::update_fork_status`].
    pub async fn abandon_fork(&self, fork_id: &str) -> Result<()> {
        self.update_fork_status(fork_id, ForkStatus::Abandoned)
            .await
    }

    /// Marks the fork as `Archived`; see [`ForkManager::update_fork_status`].
    pub async fn archive_fork(&self, fork_id: &str) -> Result<()> {
        self.update_fork_status(fork_id, ForkStatus::Archived).await
    }

    /// Removes `fork_id` from the registry and persists the change.
    ///
    /// An unknown `fork_id` is not an error: a warning is logged, nothing is
    /// written to disk and `Ok(())` is returned.
    ///
    /// # Errors
    /// Fails if the registry cannot be written.
    pub async fn delete_fork(&self, fork_id: &str) -> Result<()> {
        let removed = {
            let mut registry = self.registry.write().await;
            registry.remove(fork_id).is_some()
        };
        if !removed {
            tracing::warn!(fork_id = %fork_id, "Fork not found; nothing deleted");
            return Ok(());
        }
        self.save_registry().await?;
        tracing::info!(fork_id = %fork_id, "Fork deleted");
        Ok(())
    }

    /// Loads the checkpoint that `fork_id` refers to.
    ///
    /// Returns `Ok(None)` when the fork is unknown or no checkpoint store is
    /// attached.
    ///
    /// # Errors
    /// Propagates the checkpoint store's load error.
    pub async fn get_fork_checkpoint(&self, fork_id: &str) -> Result<Option<SessionCheckpoint>> {
        // Copy the id out and release the read guard before awaiting the
        // store, so registry writers are not blocked during the load.
        let checkpoint_id = {
            let registry = self.registry.read().await;
            match registry.get(fork_id) {
                Some(fork) => fork.checkpoint_id.clone(),
                None => return Ok(None),
            }
        };

        match self.checkpoint_store {
            Some(ref store) => Ok(Some(store.load(&checkpoint_id).await?)),
            None => Ok(None),
        }
    }

    /// Serializes the registry as pretty-printed JSON into
    /// `<base_path>/forks.json`.
    async fn save_registry(&self) -> Result<()> {
        // The read guard is deliberately held until the write completes: no
        // registry mutation can start while this snapshot is being written.
        let registry = self.registry.read().await;
        let data = serde_json::to_vec_pretty(&*registry)?;
        fs::write(self.base_path.join("forks.json"), data).await?;
        Ok(())
    }

    /// Replaces the in-memory registry with the contents of
    /// `<base_path>/forks.json`; a missing file leaves the registry untouched.
    ///
    /// # Errors
    /// Fails if the file exists but cannot be read, or is not valid JSON for
    /// a [`ForkRegistry`].
    async fn load_registry(&self) -> Result<()> {
        let path = self.base_path.join("forks.json");
        // Read directly instead of probing with the blocking `Path::exists`;
        // this also avoids a check-then-read race.
        let data = match fs::read(&path).await {
            Ok(bytes) => bytes,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
            Err(err) => return Err(err.into()),
        };
        let loaded: ForkRegistry = serde_json::from_slice(&data)?;
        *self.registry.write().await = loaded;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A new fork keeps the supplied ids and branch and starts as `Creating`.
    #[test]
    fn test_fork_info_creation() {
        let branch = Some("feature-branch".to_string());
        let info = ForkInfo::new("session-1", "checkpoint-1", branch.clone());

        assert_eq!(info.parent_session_id, "session-1");
        assert_eq!(info.checkpoint_id, "checkpoint-1");
        assert_eq!(info.branch_name, branch);
        assert_eq!(info.status, ForkStatus::Creating);
    }

    /// Children are tracked per parent, and only `Active` ones count as active.
    #[test]
    fn test_fork_registry() {
        let active_fork = |checkpoint: &str| {
            let mut info = ForkInfo::new("parent-1", checkpoint, None);
            info.status = ForkStatus::Active;
            info
        };

        let first = active_fork("cp-1");
        let first_id = first.fork_id.clone();
        let second = active_fork("cp-2");

        let mut registry = ForkRegistry::new();
        registry.register(first);
        registry.register(second);
        assert_eq!(registry.get_children("parent-1").len(), 2);

        registry.update_status(&first_id, ForkStatus::Merged);
        assert_eq!(registry.get_active_children("parent-1").len(), 1);
    }

    /// End-to-end: create a fork, list it, abandon it, read the new status.
    #[tokio::test]
    async fn test_fork_manager() {
        let temp = tempdir().expect("Failed to create temp dir");
        let manager = ForkManager::new(temp.path()).await.unwrap();

        let branch = Some("test-branch".to_string());
        let created = manager
            .create_fork("session-1", "checkpoint-1", branch)
            .await
            .unwrap();
        assert!(created.is_active());

        let listed = manager.list_forks("session-1").await;
        assert_eq!(listed.len(), 1);

        manager.abandon_fork(&created.fork_id).await.unwrap();
        let reloaded = manager.get_fork(&created.fork_id).await.unwrap();
        assert_eq!(reloaded.status, ForkStatus::Abandoned);
    }
}