1use anyhow::Result;
13use oxidized_state::{BranchRecord, CommitId, CommitRecord, SurrealHandle};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use tokio::sync::Mutex;
17use tokio::task::JoinHandle;
18use tracing::{debug, info, instrument, warn};
19
20use crate::metrics::METRICS;
21
22#[derive(Debug, Clone)]
24pub struct ForkResult {
25 #[allow(dead_code)]
27 pub parent_commit: String,
28 pub branches: Vec<String>,
30 pub commit_ids: Vec<CommitId>,
32}
33
34#[allow(dead_code)]
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct BranchStatus {
38 pub name: String,
40 pub commit_id: String,
42 pub score: f32,
44 pub active: bool,
46 pub step: usize,
48}
49
/// Settings consumed by `ParallelManager`.
// NOTE(review): the reason for this allowance is not visible here — confirm it
// is still needed.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Active branches whose score is strictly below this value are deactivated
    /// by `ParallelManager::prune_low_performing_branches`.
    pub score_threshold: f32,
    /// Presumably a cap on the number of branches — TODO confirm. Not read
    /// anywhere in this module, hence the lint allowance.
    #[allow(dead_code)]
    pub max_branches: usize,
    /// Presumably toggles automatic pruning — TODO confirm. Not read anywhere
    /// in this module, hence the lint allowance.
    #[allow(dead_code)]
    pub auto_prune: bool,
}
63
64impl Default for ParallelConfig {
65 fn default() -> Self {
66 Self {
67 score_threshold: 0.3,
68 max_branches: 10,
69 auto_prune: true,
70 }
71 }
72}
73
74#[instrument(skip(handle), fields(parent = %&parent_commit[..8.min(parent_commit.len())]))]
90pub async fn fork_agent_parallel(
91 handle: Arc<SurrealHandle>,
92 parent_commit: &str,
93 count: u8,
94 prefix: &str,
95) -> Result<ForkResult> {
96 METRICS.inc_forks();
97 info!(
98 "Forking {} parallel branches from {}",
99 count,
100 &parent_commit[..8.min(parent_commit.len())]
101 );
102
103 let parent_snapshot = handle.load_snapshot(parent_commit).await?;
105
106 let mut tasks: Vec<JoinHandle<Result<(String, CommitId)>>> = Vec::new();
108
109 for i in 0..count {
110 let handle_clone = Arc::clone(&handle);
111 let parent_id = parent_commit.to_string();
112 let branch_name = format!("{}-{}", prefix, i);
113 let state = parent_snapshot.state.clone();
114
115 let task = tokio::spawn(async move {
116 let fork_data = format!("fork:{}:{}", parent_id, branch_name);
118 let commit_id = CommitId::from_state(fork_data.as_bytes());
119
120 handle_clone.save_snapshot(&commit_id, state).await?;
122
123 let commit = CommitRecord::new(
125 commit_id.clone(),
126 vec![parent_id.clone()],
127 &format!("Fork branch {}", branch_name),
128 "parallel-fork",
129 );
130 handle_clone.save_commit(&commit).await?;
131
132 handle_clone
134 .save_commit_graph_edge(&commit_id.hash, &parent_id)
135 .await?;
136
137 let branch = BranchRecord::new(&branch_name, &commit_id.hash, false);
139 handle_clone.save_branch(&branch).await?;
140
141 debug!(
142 "Created fork branch: {} at {}",
143 branch_name,
144 commit_id.short()
145 );
146 Ok((branch_name, commit_id))
147 });
148
149 tasks.push(task);
150 }
151
152 let mut branches = Vec::new();
154 let mut commit_ids = Vec::new();
155
156 for task in tasks {
157 let (name, id) = task.await??;
158 branches.push(name);
159 commit_ids.push(id);
160 }
161
162 info!("Created {} parallel branches", branches.len());
163
164 Ok(ForkResult {
165 parent_commit: parent_commit.to_string(),
166 branches,
167 commit_ids,
168 })
169}
170
171#[allow(dead_code)]
173pub struct ParallelManager {
174 #[allow(dead_code)]
175 handle: Arc<SurrealHandle>,
176 config: ParallelConfig,
177 branch_status: Arc<Mutex<Vec<BranchStatus>>>,
178}
179
180#[allow(dead_code)]
181impl ParallelManager {
182 pub fn new(handle: Arc<SurrealHandle>, config: ParallelConfig) -> Self {
184 Self {
185 handle,
186 config,
187 branch_status: Arc::new(Mutex::new(Vec::new())),
188 }
189 }
190
191 pub async fn register_branch(&self, name: &str, commit_id: &str) {
193 let mut status = self.branch_status.lock().await;
194 status.push(BranchStatus {
195 name: name.to_string(),
196 commit_id: commit_id.to_string(),
197 score: 1.0, active: true,
199 step: 0,
200 });
201 }
202
203 pub async fn update_score(&self, branch_name: &str, score: f32) {
205 let mut status = self.branch_status.lock().await;
206 if let Some(branch) = status.iter_mut().find(|b| b.name == branch_name) {
207 branch.score = score;
208 }
209 }
210
211 pub async fn update_step(&self, branch_name: &str, step: usize) {
213 let mut status = self.branch_status.lock().await;
214 if let Some(branch) = status.iter_mut().find(|b| b.name == branch_name) {
215 branch.step = step;
216 }
217 }
218
219 pub async fn get_statuses(&self) -> Vec<BranchStatus> {
221 self.branch_status.lock().await.clone()
222 }
223
224 pub async fn prune_low_performing_branches(&self) -> Result<Vec<String>> {
230 let mut status = self.branch_status.lock().await;
231 let threshold = self.config.score_threshold;
232
233 let mut pruned = Vec::new();
234
235 for branch in status.iter_mut() {
236 if branch.active && branch.score < threshold {
237 warn!(
238 "Pruning branch '{}' - score {} below threshold {}",
239 branch.name, branch.score, threshold
240 );
241 branch.active = false;
242 pruned.push(branch.name.clone());
243 }
244 }
245
246 if !pruned.is_empty() {
247 info!("Pruned {} low-performing branches", pruned.len());
248 }
249
250 Ok(pruned)
251 }
252
253 pub async fn active_count(&self) -> usize {
255 let status = self.branch_status.lock().await;
256 status.iter().filter(|b| b.active).count()
257 }
258
259 pub async fn is_active(&self, branch_name: &str) -> bool {
261 let status = self.branch_status.lock().await;
262 status
263 .iter()
264 .find(|b| b.name == branch_name)
265 .map(|b| b.active)
266 .unwrap_or(false)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Forks five branches from a stored parent commit and checks that each one
    /// has a distinct name, a branch record pointing at its commit, and a
    /// commit-graph edge back to the parent.
    #[tokio::test]
    async fn test_five_branches_are_forked_and_run_concurrently_via_tokio() {
        let handle = Arc::new(SurrealHandle::setup_db().await.unwrap());

        let parent_state = serde_json::json!({
            "agent": "optimizer",
            "strategy": "baseline",
            "step": 0
        });
        let parent_id = CommitId::from_state(b"parent-state");
        handle.save_snapshot(&parent_id, parent_state).await.unwrap();
        handle
            .save_commit(&CommitRecord::new(
                parent_id.clone(),
                vec![],
                "Parent commit",
                "test",
            ))
            .await
            .unwrap();

        let result = fork_agent_parallel(Arc::clone(&handle), &parent_id.hash, 5, "experiment")
            .await
            .unwrap();

        assert_eq!(result.branches.len(), 5, "Should create 5 branches");
        assert_eq!(result.commit_ids.len(), 5, "Should have 5 commit IDs");

        let distinct: HashSet<&String> = result.branches.iter().collect();
        assert_eq!(distinct.len(), 5, "Branch names should be unique");

        for (branch_name, commit_id) in result.branches.iter().zip(&result.commit_ids) {
            let record = handle.get_branch(branch_name).await.unwrap();
            assert!(record.is_some(), "Branch {} should exist", branch_name);
            assert_eq!(
                record.unwrap().head_commit_id,
                commit_id.hash,
                "Branch head should match commit ID"
            );
        }

        for commit_id in &result.commit_ids {
            let parent = handle.get_parent(&commit_id.hash).await.unwrap();
            assert_eq!(
                parent,
                Some(parent_id.hash.clone()),
                "Fork should have parent edge"
            );
        }
    }

    /// With a 0.5 threshold, the two branches scored below it are pruned and
    /// the two at or above it stay active.
    #[tokio::test]
    async fn test_optimizer_kills_branch_when_score_threshold_is_missed() {
        let handle = Arc::new(SurrealHandle::setup_db().await.unwrap());
        let manager = ParallelManager::new(
            handle,
            ParallelConfig {
                score_threshold: 0.5,
                max_branches: 10,
                auto_prune: true,
            },
        );

        let fixtures = [
            ("high-performer", "commit-1", 0.9),
            ("medium-performer", "commit-2", 0.6),
            ("low-performer", "commit-3", 0.3),
            ("very-low-performer", "commit-4", 0.1),
        ];

        for (name, commit, _) in &fixtures {
            manager.register_branch(name, commit).await;
        }
        for (name, _, score) in &fixtures {
            manager.update_score(name, *score).await;
        }

        let pruned = manager.prune_low_performing_branches().await.unwrap();

        assert_eq!(pruned.len(), 2, "Should prune 2 branches");
        assert!(pruned.iter().any(|p| p == "low-performer"));
        assert!(pruned.iter().any(|p| p == "very-low-performer"));

        let expected_activity = [
            ("high-performer", true),
            ("medium-performer", true),
            ("low-performer", false),
            ("very-low-performer", false),
        ];
        for &(name, expected) in &expected_activity {
            assert_eq!(manager.is_active(name).await, expected);
        }

        assert_eq!(manager.active_count().await, 2);
    }

    /// Step and score updates are reflected in the reported status.
    #[tokio::test]
    async fn test_parallel_manager_tracks_branch_progress() {
        let manager = ParallelManager::new(
            Arc::new(SurrealHandle::setup_db().await.unwrap()),
            ParallelConfig::default(),
        );

        manager.register_branch("branch-1", "commit-abc").await;
        manager.update_step("branch-1", 5).await;
        manager.update_score("branch-1", 0.75).await;

        let statuses = manager.get_statuses().await;
        assert_eq!(statuses.len(), 1);
        let only = &statuses[0];
        assert_eq!(only.step, 5);
        assert_eq!(only.score, 0.75);
        assert!(only.active);
    }
}