Skip to main content

aivcs_core/
parallel.rs

1//! Parallel Simulation Module
2//!
3//! Provides concurrent agent exploration capabilities:
4//! - Fork multiple branches from a parent commit
5//! - Run agent variants concurrently using Tokio
6//! - Prune low-performing branches based on score threshold
7//!
8//! # TDD Tests:
9//! - test_five_branches_are_forked_and_run_concurrently_via_tokio
10//! - test_optimizer_kills_branch_when_score_threshold_is_missed
11
12use anyhow::Result;
13use oxidized_state::{BranchRecord, CommitId, CommitRecord, SurrealHandle};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use tokio::sync::Mutex;
17use tokio::task::JoinHandle;
18use tracing::{debug, info, instrument, warn};
19
20use crate::metrics::METRICS;
21
22/// Result of forking multiple branches
23#[derive(Debug, Clone)]
24pub struct ForkResult {
25    /// Parent commit that was forked from
26    #[allow(dead_code)]
27    pub parent_commit: String,
28    /// Branch names created
29    pub branches: Vec<String>,
30    /// Commit IDs for each branch
31    pub commit_ids: Vec<CommitId>,
32}
33
34/// Status of a running parallel branch
35#[allow(dead_code)]
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct BranchStatus {
38    /// Branch name
39    pub name: String,
40    /// Current commit ID
41    pub commit_id: String,
42    /// Performance score (0.0 - 1.0)
43    pub score: f32,
44    /// Whether the branch is still active
45    pub active: bool,
46    /// Step count in this branch
47    pub step: usize,
48}
49
/// Configuration for parallel exploration.
///
/// The [`Default`] configuration prunes below a score of `0.3`, allows up to
/// `10` branches, and enables auto-pruning.
// `max_branches` and `auto_prune` are not read by any code in this module yet.
// The struct-level allowance already covers every field, so no per-field
// `#[allow(dead_code)]` attributes are needed.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelConfig {
    /// Minimum score threshold; active branches scoring below it get pruned.
    pub score_threshold: f32,
    /// Maximum concurrent branches (not enforced by this module).
    pub max_branches: usize,
    /// Auto-prune low performers (not acted on by this module).
    pub auto_prune: bool,
}

impl Default for ParallelConfig {
    /// Threshold `0.3`, `max_branches` `10`, `auto_prune` enabled.
    fn default() -> Self {
        Self {
            score_threshold: 0.3,
            max_branches: 10,
            auto_prune: true,
        }
    }
}
73
74/// Fork multiple agent branches from a parent commit
75///
76/// Creates `count` new branches, each starting from the same parent commit.
77/// The branches are created concurrently using Tokio.
78///
79/// # TDD: test_five_branches_are_forked_and_run_concurrently_via_tokio
80///
81/// # Arguments
82/// * `handle` - SurrealDB handle
83/// * `parent_commit` - Commit ID to fork from
84/// * `count` - Number of branches to create
85/// * `prefix` - Branch name prefix (branches named "{prefix}-0", "{prefix}-1", etc.)
86///
87/// # Returns
88/// * `ForkResult` containing the created branch names and commit IDs
89#[instrument(skip(handle), fields(parent = %&parent_commit[..8.min(parent_commit.len())]))]
90pub async fn fork_agent_parallel(
91    handle: Arc<SurrealHandle>,
92    parent_commit: &str,
93    count: u8,
94    prefix: &str,
95) -> Result<ForkResult> {
96    METRICS.inc_forks();
97    info!(
98        "Forking {} parallel branches from {}",
99        count,
100        &parent_commit[..8.min(parent_commit.len())]
101    );
102
103    // Get parent snapshot to clone state
104    let parent_snapshot = handle.load_snapshot(parent_commit).await?;
105
106    // Spawn concurrent tasks to create branches
107    let mut tasks: Vec<JoinHandle<Result<(String, CommitId)>>> = Vec::new();
108
109    for i in 0..count {
110        let handle_clone = Arc::clone(&handle);
111        let parent_id = parent_commit.to_string();
112        let branch_name = format!("{}-{}", prefix, i);
113        let state = parent_snapshot.state.clone();
114
115        let task = tokio::spawn(async move {
116            // Create commit ID for this branch
117            let fork_data = format!("fork:{}:{}", parent_id, branch_name);
118            let commit_id = CommitId::from_state(fork_data.as_bytes());
119
120            // Save forked snapshot
121            handle_clone.save_snapshot(&commit_id, state).await?;
122
123            // Create commit record
124            let commit = CommitRecord::new(
125                commit_id.clone(),
126                vec![parent_id.clone()],
127                &format!("Fork branch {}", branch_name),
128                "parallel-fork",
129            );
130            handle_clone.save_commit(&commit).await?;
131
132            // Create graph edge
133            handle_clone
134                .save_commit_graph_edge(&commit_id.hash, &parent_id)
135                .await?;
136
137            // Create branch pointer
138            let branch = BranchRecord::new(&branch_name, &commit_id.hash, false);
139            handle_clone.save_branch(&branch).await?;
140
141            debug!(
142                "Created fork branch: {} at {}",
143                branch_name,
144                commit_id.short()
145            );
146            Ok((branch_name, commit_id))
147        });
148
149        tasks.push(task);
150    }
151
152    // Wait for all forks to complete
153    let mut branches = Vec::new();
154    let mut commit_ids = Vec::new();
155
156    for task in tasks {
157        let (name, id) = task.await??;
158        branches.push(name);
159        commit_ids.push(id);
160    }
161
162    info!("Created {} parallel branches", branches.len());
163
164    Ok(ForkResult {
165        parent_commit: parent_commit.to_string(),
166        branches,
167        commit_ids,
168    })
169}
170
171/// Parallel branch manager for tracking and pruning branches
172#[allow(dead_code)]
173pub struct ParallelManager {
174    #[allow(dead_code)]
175    handle: Arc<SurrealHandle>,
176    config: ParallelConfig,
177    branch_status: Arc<Mutex<Vec<BranchStatus>>>,
178}
179
180#[allow(dead_code)]
181impl ParallelManager {
182    /// Create a new parallel manager
183    pub fn new(handle: Arc<SurrealHandle>, config: ParallelConfig) -> Self {
184        Self {
185            handle,
186            config,
187            branch_status: Arc::new(Mutex::new(Vec::new())),
188        }
189    }
190
191    /// Register a branch for tracking
192    pub async fn register_branch(&self, name: &str, commit_id: &str) {
193        let mut status = self.branch_status.lock().await;
194        status.push(BranchStatus {
195            name: name.to_string(),
196            commit_id: commit_id.to_string(),
197            score: 1.0, // Start with perfect score
198            active: true,
199            step: 0,
200        });
201    }
202
203    /// Update branch score
204    pub async fn update_score(&self, branch_name: &str, score: f32) {
205        let mut status = self.branch_status.lock().await;
206        if let Some(branch) = status.iter_mut().find(|b| b.name == branch_name) {
207            branch.score = score;
208        }
209    }
210
211    /// Update branch step count
212    pub async fn update_step(&self, branch_name: &str, step: usize) {
213        let mut status = self.branch_status.lock().await;
214        if let Some(branch) = status.iter_mut().find(|b| b.name == branch_name) {
215            branch.step = step;
216        }
217    }
218
219    /// Get all branch statuses
220    pub async fn get_statuses(&self) -> Vec<BranchStatus> {
221        self.branch_status.lock().await.clone()
222    }
223
224    /// Prune branches that fall below the score threshold
225    ///
226    /// # TDD: test_optimizer_kills_branch_when_score_threshold_is_missed
227    ///
228    /// Returns the names of pruned branches
229    pub async fn prune_low_performing_branches(&self) -> Result<Vec<String>> {
230        let mut status = self.branch_status.lock().await;
231        let threshold = self.config.score_threshold;
232
233        let mut pruned = Vec::new();
234
235        for branch in status.iter_mut() {
236            if branch.active && branch.score < threshold {
237                warn!(
238                    "Pruning branch '{}' - score {} below threshold {}",
239                    branch.name, branch.score, threshold
240                );
241                branch.active = false;
242                pruned.push(branch.name.clone());
243            }
244        }
245
246        if !pruned.is_empty() {
247            info!("Pruned {} low-performing branches", pruned.len());
248        }
249
250        Ok(pruned)
251    }
252
253    /// Get active branch count
254    pub async fn active_count(&self) -> usize {
255        let status = self.branch_status.lock().await;
256        status.iter().filter(|b| b.active).count()
257    }
258
259    /// Check if a specific branch is still active
260    pub async fn is_active(&self, branch_name: &str) -> bool {
261        let status = self.branch_status.lock().await;
262        status
263            .iter()
264            .find(|b| b.name == branch_name)
265            .map(|b| b.active)
266            .unwrap_or(false)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Forks five branches from a freshly seeded parent and checks that every
    /// branch is persisted with the right head commit and parent edge.
    #[tokio::test]
    async fn test_five_branches_are_forked_and_run_concurrently_via_tokio() {
        let db = Arc::new(SurrealHandle::setup_db().await.unwrap());

        // Seed the parent commit that every branch forks from.
        let root_id = CommitId::from_state(b"parent-state");
        let root_state = serde_json::json!({
            "agent": "optimizer",
            "strategy": "baseline",
            "step": 0
        });
        db.save_snapshot(&root_id, root_state).await.unwrap();
        db.save_commit(&CommitRecord::new(
            root_id.clone(),
            vec![],
            "Parent commit",
            "test",
        ))
        .await
        .unwrap();

        let forked = fork_agent_parallel(Arc::clone(&db), &root_id.hash, 5, "experiment")
            .await
            .unwrap();

        assert_eq!(forked.branches.len(), 5, "Should create 5 branches");
        assert_eq!(forked.commit_ids.len(), 5, "Should have 5 commit IDs");

        let distinct: HashSet<&String> = forked.branches.iter().collect();
        assert_eq!(distinct.len(), 5, "Branch names should be unique");

        // Each branch must exist and point at its own fork commit.
        for (name, id) in forked.branches.iter().zip(&forked.commit_ids) {
            let record = db.get_branch(name).await.unwrap();
            assert!(record.is_some(), "Branch {} should exist", name);
            assert_eq!(
                record.unwrap().head_commit_id,
                id.hash,
                "Branch head should match commit ID"
            );
        }

        // Each fork commit must have a graph edge back to the root.
        for id in &forked.commit_ids {
            let parent = db.get_parent(&id.hash).await.unwrap();
            assert_eq!(
                parent,
                Some(root_id.hash.clone()),
                "Fork should have parent edge"
            );
        }
    }

    /// Branches scoring under the 0.5 threshold are deactivated; the rest survive.
    #[tokio::test]
    async fn test_optimizer_kills_branch_when_score_threshold_is_missed() {
        let db = Arc::new(SurrealHandle::setup_db().await.unwrap());
        let manager = ParallelManager::new(
            db,
            ParallelConfig {
                score_threshold: 0.5,
                max_branches: 10,
                auto_prune: true,
            },
        );

        // (branch, commit, score, survives pruning)
        let fixtures = [
            ("high-performer", "commit-1", 0.9, true),
            ("medium-performer", "commit-2", 0.6, true),
            ("low-performer", "commit-3", 0.3, false),
            ("very-low-performer", "commit-4", 0.1, false),
        ];

        for &(name, commit, _, _) in &fixtures {
            manager.register_branch(name, commit).await;
        }
        for &(name, _, score, _) in &fixtures {
            manager.update_score(name, score).await;
        }

        let pruned = manager.prune_low_performing_branches().await.unwrap();
        assert_eq!(pruned.len(), 2, "Should prune 2 branches");

        for &(name, _, _, survives) in &fixtures {
            if !survives {
                assert!(pruned.contains(&name.to_string()));
            }
            assert_eq!(manager.is_active(name).await, survives);
        }

        assert_eq!(manager.active_count().await, 2);
    }

    /// Step and score updates are reflected in the reported status.
    #[tokio::test]
    async fn test_parallel_manager_tracks_branch_progress() {
        let manager = ParallelManager::new(
            Arc::new(SurrealHandle::setup_db().await.unwrap()),
            ParallelConfig::default(),
        );

        manager.register_branch("branch-1", "commit-abc").await;
        manager.update_step("branch-1", 5).await;
        manager.update_score("branch-1", 0.75).await;

        let snapshot = manager.get_statuses().await;
        assert_eq!(snapshot.len(), 1);

        let only = &snapshot[0];
        assert_eq!(only.step, 5);
        assert_eq!(only.score, 0.75);
        assert!(only.active);
    }
}