use anyhow::Result;
use oxidized_state::{BranchRecord, CommitId, CommitRecord, SurrealHandle};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::{debug, info, instrument, warn};
use crate::metrics::METRICS;
/// Outcome of [`fork_agent_parallel`]: the branches that were created and the
/// commit each one points at.
///
/// The two vectors are filled in lockstep, so `branches[i]` and
/// `commit_ids[i]` describe the same fork.
#[derive(Clone, Debug)]
pub struct ForkResult {
    /// The parent commit id exactly as it was passed to the fork call.
    #[allow(dead_code)] // kept for callers/diagnostics even where not read
    pub parent_commit: String,
    /// Names of the newly created branches, in fork-index order.
    pub branches: Vec<String>,
    /// Commit id created for each branch, aligned with `branches`.
    pub commit_ids: Vec<CommitId>,
}
/// Progress snapshot for a single branch tracked by [`ParallelManager`].
#[allow(dead_code)] // not every consumer of this type is visible in this module
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BranchStatus {
    /// Branch name; the key used by the manager's lookups.
    pub name: String,
    /// Commit id the branch was registered at.
    pub commit_id: String,
    /// Most recent score; compared against the pruning threshold.
    pub score: f32,
    /// Cleared to `false` once the branch has been pruned.
    pub active: bool,
    /// Most recently reported step number.
    pub step: usize,
}
/// Tuning knobs for [`ParallelManager`].
#[allow(dead_code)] // some fields are not yet consulted anywhere
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Active branches whose score falls below this value get pruned.
    pub score_threshold: f32,
    /// Presumably the intended cap on concurrent branches; nothing in this
    /// module enforces it yet.
    #[allow(dead_code)]
    pub max_branches: usize,
    /// Presumably toggles automatic pruning; not read in this module yet.
    #[allow(dead_code)]
    pub auto_prune: bool,
}

impl Default for ParallelConfig {
    /// Threshold `0.3`, at most `10` branches, auto-pruning enabled.
    fn default() -> Self {
        let score_threshold = 0.3;
        let max_branches = 10;
        let auto_prune = true;
        Self {
            score_threshold,
            max_branches,
            auto_prune,
        }
    }
}
#[instrument(skip(handle), fields(parent = %&parent_commit[..8.min(parent_commit.len())]))]
pub async fn fork_agent_parallel(
handle: Arc<SurrealHandle>,
parent_commit: &str,
count: u8,
prefix: &str,
) -> Result<ForkResult> {
METRICS.inc_forks();
info!(
"Forking {} parallel branches from {}",
count,
&parent_commit[..8.min(parent_commit.len())]
);
let parent_snapshot = handle.load_snapshot(parent_commit).await?;
let mut tasks: Vec<JoinHandle<Result<(String, CommitId)>>> = Vec::new();
for i in 0..count {
let handle_clone = Arc::clone(&handle);
let parent_id = parent_commit.to_string();
let branch_name = format!("{}-{}", prefix, i);
let state = parent_snapshot.state.clone();
let task = tokio::spawn(async move {
let fork_data = format!("fork:{}:{}", parent_id, branch_name);
let commit_id = CommitId::from_state(fork_data.as_bytes());
handle_clone.save_snapshot(&commit_id, state).await?;
let commit = CommitRecord::new(
commit_id.clone(),
vec![parent_id.clone()],
&format!("Fork branch {}", branch_name),
"parallel-fork",
);
handle_clone.save_commit(&commit).await?;
handle_clone
.save_commit_graph_edge(&commit_id.hash, &parent_id)
.await?;
let branch = BranchRecord::new(&branch_name, &commit_id.hash, false);
handle_clone.save_branch(&branch).await?;
debug!(
"Created fork branch: {} at {}",
branch_name,
commit_id.short()
);
Ok((branch_name, commit_id))
});
tasks.push(task);
}
let mut branches = Vec::new();
let mut commit_ids = Vec::new();
for task in tasks {
let (name, id) = task.await??;
branches.push(name);
commit_ids.push(id);
}
info!("Created {} parallel branches", branches.len());
Ok(ForkResult {
parent_commit: parent_commit.to_string(),
branches,
commit_ids,
})
}
/// Tracks the live status of forked branches and deactivates those whose
/// score drops below the configured threshold.
#[allow(dead_code)] // not every entry point is wired up yet
pub struct ParallelManager {
    /// Storage handle; appears unused by the tracking logic so far.
    #[allow(dead_code)]
    handle: Arc<SurrealHandle>,
    /// Thresholds and limits used when pruning.
    config: ParallelConfig,
    /// One entry per registered branch, kept in registration order and
    /// shared behind an async mutex.
    branch_status: Arc<Mutex<Vec<BranchStatus>>>,
}
#[allow(dead_code)]
impl ParallelManager {
    /// Creates a manager over `handle` using `config`, with no branches registered.
    pub fn new(handle: Arc<SurrealHandle>, config: ParallelConfig) -> Self {
        Self {
            handle,
            config,
            branch_status: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Starts tracking `name` at `commit_id` as an active branch with score
    /// `1.0` and step `0`.
    ///
    /// Names are not deduplicated. Registering a name twice appends a second
    /// entry, and the name-based methods below only ever touch the first one.
    pub async fn register_branch(&self, name: &str, commit_id: &str) {
        let mut status = self.branch_status.lock().await;
        status.push(BranchStatus {
            name: name.to_string(),
            commit_id: commit_id.to_string(),
            score: 1.0,
            active: true,
            step: 0,
        });
    }

    /// Sets the score of the first branch named `branch_name`.
    /// Unknown names are silently ignored.
    pub async fn update_score(&self, branch_name: &str, score: f32) {
        let mut status = self.branch_status.lock().await;
        if let Some(branch) = status.iter_mut().find(|b| b.name == branch_name) {
            branch.score = score;
        }
    }

    /// Sets the step of the first branch named `branch_name`.
    /// Unknown names are silently ignored.
    pub async fn update_step(&self, branch_name: &str, step: usize) {
        let mut status = self.branch_status.lock().await;
        if let Some(branch) = status.iter_mut().find(|b| b.name == branch_name) {
            branch.step = step;
        }
    }

    /// Returns a copy of every tracked status, in registration order.
    pub async fn get_statuses(&self) -> Vec<BranchStatus> {
        self.branch_status.lock().await.clone()
    }

    /// Deactivates every active branch whose score is below
    /// `config.score_threshold`, or is NaN, and returns the pruned names in
    /// registration order.
    ///
    /// Branches that are already inactive are skipped, so repeated calls never
    /// report the same branch twice.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub async fn prune_low_performing_branches(&self) -> Result<Vec<String>> {
        let threshold = self.config.score_threshold;
        let mut status = self.branch_status.lock().await;
        let mut pruned = Vec::new();
        for branch in status.iter_mut().filter(|b| b.active) {
            // `NaN < threshold` is always false, so without the explicit check
            // a NaN score would keep the branch alive indefinitely.
            if branch.score.is_nan() || branch.score < threshold {
                warn!(
                    "Pruning branch '{}' - score {} below threshold {}",
                    branch.name, branch.score, threshold
                );
                branch.active = false;
                pruned.push(branch.name.clone());
            }
        }
        if !pruned.is_empty() {
            info!("Pruned {} low-performing branches", pruned.len());
        }
        Ok(pruned)
    }

    /// Number of registered branches that have not been pruned.
    pub async fn active_count(&self) -> usize {
        let status = self.branch_status.lock().await;
        status.iter().filter(|b| b.active).count()
    }

    /// Whether the first branch named `branch_name` is still active.
    /// Returns `false` for unknown names.
    pub async fn is_active(&self, branch_name: &str) -> bool {
        let status = self.branch_status.lock().await;
        status
            .iter()
            .find(|b| b.name == branch_name)
            .map_or(false, |b| b.active)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Forks five branches from a stored parent and verifies each branch
    /// record and its commit-graph edge back to the parent.
    #[tokio::test]
    async fn test_five_branches_are_forked_and_run_concurrently_via_tokio() {
        let db = Arc::new(SurrealHandle::setup_db().await.unwrap());

        let parent_id = CommitId::from_state(b"parent-state");
        let snapshot = serde_json::json!({
            "agent": "optimizer",
            "strategy": "baseline",
            "step": 0
        });
        db.save_snapshot(&parent_id, snapshot).await.unwrap();
        let root = CommitRecord::new(parent_id.clone(), vec![], "Parent commit", "test");
        db.save_commit(&root).await.unwrap();

        let result = fork_agent_parallel(Arc::clone(&db), &parent_id.hash, 5, "experiment")
            .await
            .unwrap();

        assert_eq!(result.branches.len(), 5, "Should create 5 branches");
        assert_eq!(result.commit_ids.len(), 5, "Should have 5 commit IDs");
        let distinct: HashSet<&String> = result.branches.iter().collect();
        assert_eq!(distinct.len(), 5, "Branch names should be unique");

        for (name, id) in result.branches.iter().zip(&result.commit_ids) {
            let fetched = db.get_branch(name).await.unwrap();
            assert!(fetched.is_some(), "Branch {} should exist", name);
            assert_eq!(
                fetched.unwrap().head_commit_id,
                id.hash,
                "Branch head should match commit ID"
            );
        }

        let expected_parent = Some(parent_id.hash.clone());
        for id in &result.commit_ids {
            let edge = db.get_parent(&id.hash).await.unwrap();
            assert_eq!(edge, expected_parent, "Fork should have parent edge");
        }
    }

    /// Branches scoring under the 0.5 threshold are deactivated; the rest survive.
    #[tokio::test]
    async fn test_optimizer_kills_branch_when_score_threshold_is_missed() {
        let db = Arc::new(SurrealHandle::setup_db().await.unwrap());
        let manager = ParallelManager::new(
            db,
            ParallelConfig {
                score_threshold: 0.5,
                max_branches: 10,
                auto_prune: true,
            },
        );

        let fixtures = [
            ("high-performer", "commit-1", 0.9),
            ("medium-performer", "commit-2", 0.6),
            ("low-performer", "commit-3", 0.3),
            ("very-low-performer", "commit-4", 0.1),
        ];
        for &(name, commit, _) in &fixtures {
            manager.register_branch(name, commit).await;
        }
        for &(name, _, score) in &fixtures {
            manager.update_score(name, score).await;
        }

        let pruned = manager.prune_low_performing_branches().await.unwrap();
        assert_eq!(pruned.len(), 2, "Should prune 2 branches");
        assert!(pruned.iter().any(|p| p == "low-performer"));
        assert!(pruned.iter().any(|p| p == "very-low-performer"));

        for survivor in ["high-performer", "medium-performer"] {
            assert!(manager.is_active(survivor).await);
        }
        for casualty in ["low-performer", "very-low-performer"] {
            assert!(!manager.is_active(casualty).await);
        }
        assert_eq!(manager.active_count().await, 2);
    }

    /// Step and score updates are reflected in the reported status.
    #[tokio::test]
    async fn test_parallel_manager_tracks_branch_progress() {
        let db = Arc::new(SurrealHandle::setup_db().await.unwrap());
        let manager = ParallelManager::new(db, ParallelConfig::default());

        manager.register_branch("branch-1", "commit-abc").await;
        manager.update_step("branch-1", 5).await;
        manager.update_score("branch-1", 0.75).await;

        let statuses = manager.get_statuses().await;
        assert_eq!(statuses.len(), 1);
        let only = &statuses[0];
        assert_eq!(only.step, 5);
        assert_eq!(only.score, 0.75);
        assert!(only.active);
    }
}