use std::io::Write;
use std::path::PathBuf;
use anyhow::Result;
use tokio::sync::mpsc;
use crate::display::renderer::Renderer;
use crate::event::AppEvent;
use crate::protocol::types::{AssistantContentBlock, InboundEvent};
use crate::session::runner::{SessionConfig, SessionRunner};
/// Settings shared by every child session spawned by [`run_fork`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForkConfig {
    /// Extra CLI arguments given to each child session. `run_fork` appends
    /// `--fork-session` to a per-child copy of this list.
    pub extra_args: Vec<String>,
    /// Working directory for the child sessions, if one is configured.
    pub working_dir: Option<PathBuf>,
}

impl ForkConfig {
    /// Builds a `ForkConfig` from borrowed settings, but only when forking is
    /// enabled.
    ///
    /// Returns `None` when `enabled` is `false`. Otherwise `extra_args` and
    /// `working_dir` are cloned into a new config.
    pub fn if_enabled(
        enabled: bool,
        extra_args: &[String],
        working_dir: &Option<PathBuf>,
    ) -> Option<Self> {
        enabled.then(|| Self {
            extra_args: extra_args.to_vec(),
            working_dir: working_dir.clone(),
        })
    }
}
pub async fn run_fork<W: Write>(
parent_session_id: &str,
tasks: Vec<String>,
config: &ForkConfig,
renderer: &mut Renderer<W>,
) -> Result<String> {
renderer.render_fork_start(&tasks);
let num_tasks = tasks.len();
let (merged_tx, mut merged_rx) = mpsc::unbounded_channel::<(usize, AppEvent)>();
let mut runners: Vec<SessionRunner> = Vec::new();
for (i, task) in tasks.iter().enumerate() {
let (child_tx, mut child_rx) = mpsc::unbounded_channel();
let mut extra_args = config.extra_args.clone();
extra_args.push("--fork-session".to_string());
let child_config = SessionConfig {
prompt: Some(format!("You were assigned '{task}'")),
resume: Some(parent_session_id.to_string()),
extra_args,
working_dir: config.working_dir.clone(),
..Default::default()
};
let runner = SessionRunner::spawn(child_config, child_tx).await?;
runners.push(runner);
let merged_tx = merged_tx.clone();
tokio::spawn(async move {
while let Some(event) = child_rx.recv().await {
if merged_tx.send((i, event)).is_err() {
break;
}
}
});
}
drop(merged_tx);
let mut results: Vec<Option<std::result::Result<String, String>>> = vec![None; num_tasks];
let mut completed = 0;
while let Some((idx, event)) = merged_rx.recv().await {
match event {
AppEvent::Claude(inbound) => match &*inbound {
InboundEvent::Assistant(msg) if msg.parent_tool_use_id.is_none() => {
for block in &msg.message.content {
if let AssistantContentBlock::ToolUse { name, input, .. } = block {
renderer.render_fork_child_tool_call(name, input);
}
}
}
InboundEvent::Result(result) => {
renderer.render_fork_child_done(&tasks[idx]);
results[idx] = Some(Ok(result.result.clone()));
completed += 1;
if completed == num_tasks {
break;
}
}
_ => {}
},
AppEvent::ParseWarning(w) => {
renderer.render_warning(&w);
}
AppEvent::ProcessExit(_) => {
if results[idx].is_none() {
results[idx] = Some(Err("Child process exited unexpectedly".to_string()));
completed += 1;
if completed == num_tasks {
break;
}
}
}
}
}
for runner in &mut runners {
runner.close_input();
let _ = runner.wait().await;
}
renderer.render_fork_complete();
let result_tuples: Vec<(String, std::result::Result<String, String>)> = tasks
.into_iter()
.zip(results)
.map(|(label, result)| {
let outcome = result.unwrap_or_else(|| Err("No result received".to_string()));
(label, outcome)
})
.collect();
Ok(compose_reintegration_message(&result_tuples))
}
/// Extracts the task labels listed inside a `<fork>…</fork>` tag in `text`.
///
/// Each non-blank line inside the tag becomes one label. Surrounding
/// whitespace is trimmed, and a leading `-` list marker is removed when it is
/// alone or followed by whitespace. Lines that are empty after that, including
/// bare `- ` bullets, are skipped.
///
/// Returns `None` if `crate::protocol::parse::extract_tag_inner` finds no
/// `fork` tag, or if the tag contains no labels.
pub fn parse_fork_tag(text: &str) -> Option<Vec<String>> {
    let inner = crate::protocol::parse::extract_tag_inner(text, "fork")?;
    let tasks: Vec<String> = inner
        .lines()
        .map(strip_list_marker)
        .filter(|label| !label.is_empty())
        .map(str::to_string)
        .collect();
    if tasks.is_empty() {
        None
    } else {
        Some(tasks)
    }
}

/// Trims `line` and removes a leading `-` bullet marker if present.
///
/// The marker is removed only when the `-` stands alone or is followed by
/// whitespace, so labels such as `-v flag handling` are left intact while an
/// empty bullet (`- `) collapses to an empty string.
fn strip_list_marker(line: &str) -> &str {
    let line = line.trim();
    match line.strip_prefix('-') {
        Some(rest) if rest.is_empty() || rest.starts_with(char::is_whitespace) => rest.trim(),
        _ => line,
    }
}
/// Builds the `<fork-results>` message that reports every child's outcome
/// back to the parent session.
///
/// Each `(label, outcome)` pair becomes one `<task>` element, in order:
///
/// - the `label` attribute holds the task label with `&`, `<` and `"`
///   escaped, so any label yields a well-formed attribute;
/// - an `Err` outcome adds `error="true"`;
/// - the success text or error text goes in a CDATA section.
///
/// Any `]]>` in the text is split across two CDATA sections, so child output
/// cannot close the section early and corrupt the surrounding markup.
pub fn compose_reintegration_message(results: &[(String, Result<String, String>)]) -> String {
    use std::fmt::Write;
    let mut xml = String::from("<fork-results>\n");
    for (label, outcome) in results {
        let safe_label = escape_xml_attr(label);
        let (error_attr, text) = match outcome {
            Ok(text) => ("", text),
            Err(err) => (" error=\"true\"", err),
        };
        let body = escape_cdata(text);
        // Formatting into a `String` cannot fail.
        let _ = write!(
            xml,
            "<task label=\"{safe_label}\"{error_attr}>\n<![CDATA[{body}]]>\n</task>\n"
        );
    }
    xml.push_str("</fork-results>");
    xml
}

/// Escapes `value` for use inside a double-quoted XML attribute.
fn escape_xml_attr(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '"' => out.push_str("&quot;"),
            _ => out.push(ch),
        }
    }
    out
}

/// Makes `text` safe to place between `<![CDATA[` and `]]>`.
///
/// Each `]]>` becomes `]]]]><![CDATA[>`, which ends the current section after
/// `]]` and starts a new one beginning with `>`.
fn escape_cdata(text: &str) -> String {
    text.replace("]]>", "]]]]><![CDATA[>")
}
/// System-prompt snippet explaining how the model can request parallel forks.
///
/// It shows an example `<fork>` block with one `- ` task label per line, and
/// says that results come back in a `<fork-results>` message once every child
/// has finished.
pub fn fork_system_prompt() -> &'static str {
    concat!(
        "To parallelize work, emit a <fork> tag containing a YAML list of short task labels:\n",
        "<fork>\n",
        "- Refactor auth module\n",
        "- Add tests for user API\n",
        "</fork>\n",
        "Each fork inherits your full context and runs in parallel. ",
        "You'll receive the results in a <fork-results> message when all children complete."
    )
}
#[cfg(test)]
mod tests {
    //! Unit tests for fork-tag parsing, result composition and the system prompt.
    use super::*;

    #[test]
    fn parse_fork_tag_basic() {
        let text = "Let me split this up.\n<fork>\n- Refactor auth\n- Add tests\n</fork>\nDone.";
        assert_eq!(
            parse_fork_tag(text),
            Some(vec!["Refactor auth".to_string(), "Add tests".to_string()])
        );
    }

    #[test]
    fn parse_fork_tag_single_task() {
        let text = "<fork>\n- Just one thing\n</fork>";
        assert_eq!(
            parse_fork_tag(text),
            Some(vec!["Just one thing".to_string()])
        );
    }

    #[test]
    fn parse_fork_tag_no_tag() {
        assert_eq!(parse_fork_tag("no fork here"), None);
    }

    #[test]
    fn parse_fork_tag_empty_list() {
        let text = "<fork>\n\n</fork>";
        assert_eq!(parse_fork_tag(text), None);
    }

    #[test]
    fn parse_fork_tag_partial() {
        let text = "<fork>\n- item\n but no closing tag";
        assert_eq!(parse_fork_tag(text), None);
    }

    #[test]
    fn parse_fork_tag_extra_whitespace() {
        let text = "<fork>\n - spaced out \n - another \n</fork>";
        assert_eq!(
            parse_fork_tag(text),
            Some(vec!["spaced out".to_string(), "another".to_string()])
        );
    }

    #[test]
    fn compose_reintegration_message_success() {
        let results = vec![
            ("Task A".to_string(), Ok("Result A".to_string())),
            ("Task B".to_string(), Ok("Result B".to_string())),
        ];
        let msg = compose_reintegration_message(&results);
        assert!(msg.starts_with("<fork-results>"));
        assert!(msg.ends_with("</fork-results>"));
        assert!(msg.contains("<task label=\"Task A\">"));
        assert!(msg.contains("<![CDATA[Result A]]>"));
        assert!(msg.contains("<task label=\"Task B\">"));
        assert!(msg.contains("<![CDATA[Result B]]>"));
    }

    #[test]
    fn compose_reintegration_message_with_error() {
        let results = vec![
            ("Good".to_string(), Ok("worked".to_string())),
            ("Bad".to_string(), Err("process crashed".to_string())),
        ];
        let msg = compose_reintegration_message(&results);
        assert!(msg.contains("<task label=\"Good\">"));
        assert!(msg.contains("<task label=\"Bad\" error=\"true\">"));
        assert!(msg.contains("<![CDATA[process crashed]]>"));
    }

    #[test]
    fn compose_reintegration_message_handles_angle_brackets() {
        let results = vec![(
            "Fix code".to_string(),
            Ok("Changed Vec<String> to Vec<&str>".to_string()),
        )];
        let msg = compose_reintegration_message(&results);
        assert!(msg.contains("<![CDATA[Changed Vec<String> to Vec<&str>]]>"));
    }

    #[test]
    fn compose_reintegration_message_escapes_label() {
        let results = vec![("Fix \"quotes\"".to_string(), Ok("done".to_string()))];
        let msg = compose_reintegration_message(&results);
        assert!(msg.contains("label=\"Fix &quot;quotes&quot;\""));
    }

    #[test]
    fn compose_reintegration_message_empty() {
        // No children still yields a well-formed (empty) wrapper.
        let msg = compose_reintegration_message(&[]);
        assert_eq!(msg, "<fork-results>\n</fork-results>");
    }

    #[test]
    fn fork_system_prompt_contains_tag() {
        let prompt = fork_system_prompt();
        assert!(prompt.contains("<fork>"));
        assert!(prompt.contains("</fork>"));
        assert!(prompt.contains("<fork-results>"));
    }
}