use crate::recursive::best_of::{self, FnScorer, Scorer};
use crate::recursive::ensemble::{self, Aggregate};
use crate::recursive::executor::{CodeExecutor, DynCodeExecutor};
use crate::recursive::llm::Llm;
use crate::recursive::program;
use crate::recursive::reason;
use crate::recursive::refine;
use crate::recursive::rewrite::extract_code;
use crate::recursive::shared;
use crate::recursive::validate::Validate;
use futures::stream::{FuturesUnordered, StreamExt};
use std::borrow::Cow;
use std::future::Future;
use std::time::{Duration, Instant};
pub fn pipeline<'a, L: Llm>(llm: &'a L, prompt: &'a str) -> Pipeline<'a, L> {
Pipeline::new(llm, prompt)
}
/// One named branch of a fan-out step: a label plus the steps run for it.
struct FanOutBranch {
    /// Label copied into the branch's `FanOutBranchResult`.
    name: String,
    /// Steps executed in order for this branch.
    steps: Vec<PipelineStep>,
}
/// A single stage of a pipeline; each variant carries the configuration
/// that the matching `run_*` helper needs.
enum PipelineStep {
    /// Refinement driven by `validator`, bounded by `max_iter` and `target`.
    Refine {
        validator: Box<dyn Validate>,
        max_iter: u32,
        target: f64,
    },
    /// Extract a code block of language `lang`; falls back to the input unchanged.
    Extract { lang: String },
    /// Best-of-`n` selection, optionally with a custom scorer.
    BestOf {
        n: usize,
        scorer: Option<Box<dyn Scorer>>,
    },
    /// Ensemble of `n` runs combined with `aggregate`.
    Ensemble { n: usize, aggregate: Aggregate },
    /// Reasoning step with no extra configuration.
    Reason,
    /// Program step using the supplied code executor.
    Program { executor: Box<dyn DynCodeExecutor> },
    /// Pure string transformation; never calls the LLM and reports zero tokens.
    Map {
        f: Box<dyn Fn(String) -> String + Send + Sync>,
    },
    /// Run several branches on the same input and merge their outputs.
    FanOut {
        branches: Vec<FanOutBranch>,
        merge: MergeStrategy,
    },
    /// Steps taken from another pipeline, run as one sub-sequence.
    Nested { steps: Vec<PipelineStep> },
}
/// Record of one executed pipeline step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Step kind label, e.g. `"refine"`, `"map"`, `"fan_out"`, or `"generate"`.
    pub name: String,
    /// Text the step received.
    pub input: String,
    /// Text the step produced.
    pub output: String,
    /// Score reported by the step, if it reports one.
    pub score: Option<f64>,
    /// Tokens attributed to the step.
    pub tokens: u32,
    /// Wall-clock time spent in the step.
    pub elapsed: Duration,
}
/// Final outcome of a pipeline run.
#[derive(Debug, Clone)]
pub struct PipelineResult {
    /// Output of the last step (empty if nothing produced output).
    pub output: String,
    /// Per-step records in execution order.
    pub steps: Vec<StepResult>,
    /// Sum of the tokens reported by all steps.
    pub total_tokens: u32,
    /// Wall-clock time for the whole run.
    pub elapsed: Duration,
}
/// Progress events yielded by `Pipeline::run_stream`.
#[derive(Debug, Clone)]
pub enum PipelineEvent {
    /// A step is about to run.
    StepStart {
        index: usize,
        name: String,
    },
    /// A step finished; carries its full record.
    StepComplete {
        index: usize,
        result: StepResult,
    },
    /// A fan-out branch is about to run.
    ///
    /// NOTE(review): not constructed by any code visible in this file — confirm
    /// whether it is emitted elsewhere.
    FanOutBranchStart {
        step_index: usize,
        branch_index: usize,
        branch_name: String,
    },
    /// A fan-out branch finished.
    ///
    /// NOTE(review): not constructed by any code visible in this file — confirm
    /// whether it is emitted elsewhere.
    FanOutBranchComplete {
        step_index: usize,
        branch_result: FanOutBranchResult,
    },
    /// The whole pipeline finished; always the last event of the stream.
    Complete(PipelineResult),
}
/// How the outputs of fan-out branches are combined into one string.
pub enum MergeStrategy {
    /// Output of the first branch (in declaration order) whose output is non-empty.
    FirstSuccess,
    /// Output of the branch with the highest score; falls back to the first
    /// branch when no branch has a score.
    BestScore,
    /// All branch outputs joined in declaration order.
    Concat {
        /// Text inserted between consecutive branch outputs.
        separator: String,
    },
    /// Caller-supplied merge over all branch results (sorted by branch index).
    Custom(Box<dyn Fn(&[FanOutBranchResult]) -> String + Send + Sync>),
}
impl Clone for MergeStrategy {
    /// Duplicates the named strategies.
    ///
    /// # Panics
    ///
    /// Panics for [`MergeStrategy::Custom`], because the boxed closure it
    /// holds cannot be duplicated.
    fn clone(&self) -> Self {
        match self {
            Self::Custom(_) => {
                panic!("MergeStrategy::Custom cannot be cloned; use a named variant or build a new Custom")
            }
            Self::Concat { separator } => {
                let separator = separator.to_owned();
                Self::Concat { separator }
            }
            Self::BestScore => Self::BestScore,
            Self::FirstSuccess => Self::FirstSuccess,
        }
    }
}
impl std::fmt::Debug for MergeStrategy {
    /// Writes the variant name; `Concat` also shows its separator, and the
    /// closure inside `Custom` is elided as `...`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // The only variant with data worth printing.
            Self::Concat { separator } => return write!(f, "Concat({:?})", separator),
            Self::FirstSuccess => "FirstSuccess",
            Self::BestScore => "BestScore",
            Self::Custom(_) => "Custom(...)",
        };
        f.write_str(label)
    }
}
/// Outcome of a single fan-out branch, as handed to merge strategies.
#[derive(Debug, Clone)]
pub struct FanOutBranchResult {
    /// Position of the branch in the order it was declared.
    pub index: usize,
    /// Branch label.
    pub name: String,
    /// Final text produced by the branch.
    pub output: String,
    /// Highest score reported by any step of the branch, if any step reported one.
    pub score: Option<f64>,
    /// Tokens summed over the branch's steps.
    pub tokens: u32,
    /// Wall-clock time the branch took.
    pub elapsed: Duration,
}
/// Chainable builder for one fan-out branch.
pub struct BranchBuilder {
    /// Label given to the branch.
    name: String,
    /// Steps accumulated so far, in execution order.
    steps: Vec<PipelineStep>,
}
impl BranchBuilder {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
steps: Vec::new(),
}
}
pub fn refine(mut self, validator: impl Validate + 'static) -> Self {
self.steps.push(PipelineStep::Refine {
validator: Box::new(validator),
max_iter: 5,
target: 1.0,
});
self
}
pub fn refine_with(
mut self,
validator: impl Validate + 'static,
max_iter: u32,
target: f64,
) -> Self {
self.steps.push(PipelineStep::Refine {
validator: Box::new(validator),
max_iter,
target,
});
self
}
pub fn extract(mut self, lang: &str) -> Self {
self.steps.push(PipelineStep::Extract {
lang: lang.to_string(),
});
self
}
pub fn best_of(mut self, n: usize) -> Self {
self.steps.push(PipelineStep::BestOf { n, scorer: None });
self
}
pub fn ensemble(mut self, n: usize) -> Self {
self.steps.push(PipelineStep::Ensemble {
n,
aggregate: Aggregate::MajorityVote,
});
self
}
pub fn reason(mut self) -> Self {
self.steps.push(PipelineStep::Reason);
self
}
pub fn map(mut self, f: impl Fn(String) -> String + Send + Sync + 'static) -> Self {
self.steps.push(PipelineStep::Map { f: Box::new(f) });
self
}
fn build(self) -> FanOutBranch {
FanOutBranch {
name: self.name,
steps: self.steps,
}
}
}
/// Accumulates branches for the closure-based `Pipeline::fan_out_with` API.
pub struct FanOutCollector {
    /// Branches in the order they were added.
    branches: Vec<FanOutBranch>,
}
impl FanOutCollector {
fn new() -> Self {
Self {
branches: Vec::new(),
}
}
pub fn branch<F>(mut self, name: &str, f: F) -> Self
where
F: FnOnce(BranchBuilder) -> BranchBuilder,
{
let builder = f(BranchBuilder::new(name));
self.branches.push(builder.build());
self
}
}
/// Chainable builder and runner for a sequence of LLM processing steps.
pub struct Pipeline<'a, L: Llm> {
    /// Model used by every step that calls an LLM.
    llm: &'a L,
    /// Initial prompt; borrowed or owned depending on the constructor used.
    prompt: Cow<'a, str>,
    /// Steps in execution order.
    steps: Vec<PipelineStep>,
}
impl<'a, L: Llm> Pipeline<'a, L> {
/// Creates an empty pipeline that borrows `prompt`.
pub fn new(llm: &'a L, prompt: &'a str) -> Self {
    let prompt = Cow::Borrowed(prompt);
    let steps = Vec::new();
    Self { llm, prompt, steps }
}
/// Creates an empty pipeline that takes ownership of `prompt`.
pub fn new_owned(llm: &'a L, prompt: String) -> Self {
    let prompt = Cow::Owned(prompt);
    let steps = Vec::new();
    Self { llm, prompt, steps }
}
/// Adds a refine step with the defaults: at most 5 iterations, target score 1.0.
pub fn refine(self, validator: impl Validate + 'static) -> Self {
    self.refine_with(validator, 5, 1.0)
}
/// Adds a refine step with an explicit iteration cap and target score.
pub fn refine_with(
    mut self,
    validator: impl Validate + 'static,
    max_iter: u32,
    target: f64,
) -> Self {
    let validator: Box<dyn Validate> = Box::new(validator);
    let step = PipelineStep::Refine {
        validator,
        max_iter,
        target,
    };
    self.steps.push(step);
    self
}
/// Adds a step that extracts a code block of language `lang`.
pub fn extract(mut self, lang: &str) -> Self {
    let lang = lang.to_owned();
    self.steps.push(PipelineStep::Extract { lang });
    self
}
/// Adds a best-of-`n` step using the default scorer.
pub fn best_of(mut self, n: usize) -> Self {
    let step = PipelineStep::BestOf { n, scorer: None };
    self.steps.push(step);
    self
}
pub fn best_of_scored(
mut self,
n: usize,
scorer: impl Fn(&str) -> f64 + Send + Sync + 'static,
) -> Self {
self.steps.push(PipelineStep::BestOf {
n,
scorer: Some(Box::new(FnScorer(scorer))),
});
self
}
/// Adds an ensemble of `n` runs aggregated by majority vote.
pub fn ensemble(self, n: usize) -> Self {
    self.ensemble_with(n, Aggregate::MajorityVote)
}
/// Adds an ensemble of `n` runs combined with the given `aggregate`.
pub fn ensemble_with(mut self, n: usize, aggregate: Aggregate) -> Self {
    let step = PipelineStep::Ensemble { n, aggregate };
    self.steps.push(step);
    self
}
/// Adds a reasoning step.
pub fn reason(mut self) -> Self {
    let step = PipelineStep::Reason;
    self.steps.push(step);
    self
}
pub fn program(mut self, executor: impl CodeExecutor + 'static) -> Self {
self.steps.push(PipelineStep::Program {
executor: Box::new(executor),
});
self
}
/// Adds a pure string transformation; it never calls the LLM.
pub fn map(mut self, f: impl Fn(String) -> String + Send + Sync + 'static) -> Self {
    let f: Box<dyn Fn(String) -> String + Send + Sync> = Box::new(f);
    self.steps.push(PipelineStep::Map { f });
    self
}
pub fn fan_out(mut self, branches: Vec<BranchBuilder>, merge: MergeStrategy) -> Self {
self.steps.push(PipelineStep::FanOut {
branches: branches.into_iter().map(|b| b.build()).collect(),
merge,
});
self
}
pub fn fan_out_with<F>(mut self, merge: MergeStrategy, f: F) -> Self
where
F: FnOnce(FanOutCollector) -> FanOutCollector,
{
let collector = f(FanOutCollector::new());
self.steps.push(PipelineStep::FanOut {
branches: collector.branches,
merge,
});
self
}
/// Runs the pipeline to completion on the current thread, driving
/// [`Pipeline::run`] through `shared::block_on`.
pub fn go(self) -> PipelineResult {
    let pending = self.run();
    shared::block_on(pending)
}
/// Runs the pipeline as a stream of [`PipelineEvent`]s.
///
/// For each configured step the stream yields `StepStart`, then `StepComplete`;
/// the final event is always `Complete` with the aggregated result.
///
/// With no steps configured, a single `llm.generate` call is made; on success
/// only a `StepComplete` (index 0, name `"generate"`) is yielded — no `StepStart`
/// precedes it — and on failure nothing is yielded for it.
pub fn run_stream(self) -> impl futures::stream::Stream<Item = PipelineEvent> + 'a {
    async_stream::stream! {
        let start = Instant::now();
        let mut current_output = String::new();
        let mut step_results: Vec<StepResult> = Vec::with_capacity(self.steps.len());
        let mut total_tokens: u32 = 0;
        let llm = self.llm;
        let prompt = self.prompt;
        let pipeline_steps = self.steps;
        // No steps: fall back to one plain generation from the prompt.
        if pipeline_steps.is_empty() {
            let step_start = Instant::now();
            // A generation error is dropped silently; the output stays empty.
            if let Ok(output) = llm.generate(&prompt, "", None).await {
                let tokens = output.total_tokens();
                total_tokens += tokens;
                current_output = output.text.to_string();
                let sr = StepResult {
                    name: "generate".to_string(),
                    input: prompt.to_string(),
                    output: current_output.clone(),
                    score: None,
                    tokens,
                    elapsed: step_start.elapsed(),
                };
                step_results.push(sr.clone());
                yield PipelineEvent::StepComplete { index: 0, result: sr };
            }
        }
        for (i, step) in pipeline_steps.into_iter().enumerate() {
            // Resolve the label up front so StepStart can be sent before the step is consumed.
            let step_name = match &step {
                PipelineStep::Refine { .. } => "refine",
                PipelineStep::Extract { .. } => "extract",
                PipelineStep::BestOf { .. } => "best_of",
                PipelineStep::Ensemble { .. } => "ensemble",
                PipelineStep::Reason => "reason",
                PipelineStep::Program { .. } => "program",
                PipelineStep::Map { .. } => "map",
                PipelineStep::FanOut { .. } => "fan_out",
                PipelineStep::Nested { .. } => "nested",
            };
            yield PipelineEvent::StepStart { index: i, name: step_name.to_string() };
            let step_start = Instant::now();
            // The first step consumes the prompt; later steps consume the previous output.
            let input = if i == 0 {
                prompt.to_string()
            } else {
                current_output.clone()
            };
            let (name, step_output) = match step {
                PipelineStep::Refine { validator, max_iter, target } => (
                    "refine",
                    run_refine(llm, &input, validator, max_iter, target).await,
                ),
                PipelineStep::Extract { lang } => {
                    ("extract", run_extract(&input, &lang))
                }
                PipelineStep::BestOf { n, scorer } => {
                    ("best_of", run_best_of(llm, &input, n, scorer).await)
                }
                PipelineStep::Ensemble { n, aggregate } => {
                    ("ensemble", run_ensemble(llm, &input, n, aggregate).await)
                }
                PipelineStep::Reason => ("reason", run_reason(llm, &input).await),
                PipelineStep::Program { executor } => {
                    ("program", run_program(llm, &input, executor).await)
                }
                PipelineStep::Map { f } => {
                    let output = f(input.clone());
                    ("map", StepOutput { output, score: None, tokens: 0 })
                }
                PipelineStep::FanOut { branches, merge } => {
                    ("fan_out", run_fan_out(llm, &input, branches, &merge).await)
                }
                PipelineStep::Nested { steps } => {
                    ("nested", {
                        let br = run_branch_steps(llm, input.clone(), steps).await;
                        StepOutput { output: br.output, score: br.best_score, tokens: br.tokens }
                    })
                }
            };
            total_tokens += step_output.tokens;
            current_output = step_output.output.clone();
            let sr = StepResult {
                name: name.to_string(),
                input,
                output: step_output.output,
                score: step_output.score,
                tokens: step_output.tokens,
                elapsed: step_start.elapsed(),
            };
            step_results.push(sr.clone());
            yield PipelineEvent::StepComplete { index: i, result: sr };
        }
        // Final event: the same aggregate that `run` returns.
        yield PipelineEvent::Complete(PipelineResult {
            output: current_output,
            steps: step_results,
            total_tokens,
            elapsed: start.elapsed(),
        });
    }
}
pub fn nest(mut self, _inner: Pipeline<'a, L>) -> Self {
self.steps.push(PipelineStep::Nested {
steps: _inner.steps,
});
self
}
/// Runs every step in order and returns the final output together with a
/// record for each step.
///
/// The first step receives the prompt; every later step receives the output
/// of the step before it. With no steps configured, a single `llm.generate`
/// call is made and recorded as a `"generate"` step; if that call fails the
/// error is dropped and the result has empty output and no step records.
///
/// With the `tracing` feature enabled the whole run is wrapped in a
/// `pipeline` span.
pub async fn run(self) -> PipelineResult {
    // The span is attached with `Instrument` rather than an `entered()` guard:
    // a guard held across `.await` points produces incorrect traces when tasks
    // interleave, and it also makes the returned future `!Send`.
    #[cfg(feature = "tracing")]
    let span = tracing::info_span!("pipeline", steps = self.steps.len());
    let fut = self.run_inner();
    #[cfg(feature = "tracing")]
    let fut = tracing::Instrument::instrument(fut, span);
    fut.await
}

/// Step-execution loop behind [`Pipeline::run`], kept separate so the
/// tracing span can wrap it as a single future.
async fn run_inner(self) -> PipelineResult {
    let start = Instant::now();
    let mut current_output = String::new();
    let mut step_results: Vec<StepResult> = Vec::with_capacity(self.steps.len());
    let mut total_tokens: u32 = 0;
    let llm = self.llm;
    let prompt = self.prompt;
    let pipeline_steps = self.steps;
    // No steps: fall back to one plain generation from the prompt.
    if pipeline_steps.is_empty() {
        let step_start = Instant::now();
        // A generation error is dropped on purpose; the output stays empty.
        if let Ok(output) = llm.generate(&prompt, "", None).await {
            let tokens = output.total_tokens();
            total_tokens += tokens;
            current_output = output.text.to_string();
            step_results.push(StepResult {
                name: "generate".to_string(),
                input: prompt.to_string(),
                output: current_output.clone(),
                score: None,
                tokens,
                elapsed: step_start.elapsed(),
            });
        }
    }
    for (i, step) in pipeline_steps.into_iter().enumerate() {
        let step_start = Instant::now();
        // The first step consumes the prompt; later steps consume the previous output.
        let input = if i == 0 {
            prompt.to_string()
        } else {
            current_output.clone()
        };
        let (name, step_output) = match step {
            PipelineStep::Refine {
                validator,
                max_iter,
                target,
            } => (
                "refine",
                run_refine(llm, &input, validator, max_iter, target).await,
            ),
            PipelineStep::Extract { lang } => ("extract", run_extract(&input, &lang)),
            PipelineStep::BestOf { n, scorer } => {
                ("best_of", run_best_of(llm, &input, n, scorer).await)
            }
            PipelineStep::Ensemble { n, aggregate } => {
                ("ensemble", run_ensemble(llm, &input, n, aggregate).await)
            }
            PipelineStep::Reason => ("reason", run_reason(llm, &input).await),
            PipelineStep::Program { executor } => {
                ("program", run_program(llm, &input, executor).await)
            }
            PipelineStep::Map { f } => {
                let output = f(input.clone());
                (
                    "map",
                    StepOutput {
                        output,
                        score: None,
                        tokens: 0,
                    },
                )
            }
            PipelineStep::FanOut { branches, merge } => {
                ("fan_out", run_fan_out(llm, &input, branches, &merge).await)
            }
            PipelineStep::Nested { steps } => ("nested", {
                let br = run_branch_steps(llm, input.clone(), steps).await;
                StepOutput {
                    output: br.output,
                    score: br.best_score,
                    tokens: br.tokens,
                }
            }),
        };
        total_tokens += step_output.tokens;
        current_output = step_output.output.clone();
        let step_elapsed = step_start.elapsed();
        #[cfg(feature = "tracing")]
        tracing::debug!(
            step = name,
            step_index = i,
            score = ?step_output.score,
            tokens = step_output.tokens,
            elapsed_ms = step_elapsed.as_millis() as u64,
            "pipeline step complete"
        );
        step_results.push(StepResult {
            name: name.to_string(),
            input,
            output: step_output.output,
            score: step_output.score,
            tokens: step_output.tokens,
            elapsed: step_elapsed,
        });
    }
    PipelineResult {
        output: current_output,
        steps: step_results,
        total_tokens,
        elapsed: start.elapsed(),
    }
}
}
/// Adapter that exposes a pipeline's steps through the `DynStep` trait.
pub struct PipelineAsStep<'a, L: Llm> {
    /// Model the steps run against.
    llm: &'a L,
    /// Step list, moved out by the first `run_dyn` call (leaving `None` behind).
    steps: std::sync::Mutex<Option<Vec<PipelineStep>>>,
}
/// Wrapper that unconditionally marks the wrapped future as `Send`.
///
/// NOTE(review): the `Send` impl below applies to *every* `F: Future`, so its
/// soundness rests entirely on call sites wrapping only futures that really
/// are safe to move between threads — confirm for each use.
struct AssertSend<F>(F);
// SAFETY: NOTE(review) — nothing visible in this file proves that the wrapped
// future is `Send`; this is an assertion by the author that must be upheld by
// whoever constructs an `AssertSend`.
unsafe impl<F: Future> Send for AssertSend<F> {}
impl<F: Future> Future for AssertSend<F> {
    type Output = F::Output;
    /// Forwards the poll to the wrapped future.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: this only projects the pin from the wrapper to its single
        // field; the field is not moved out of the wrapper here.
        let inner = unsafe { self.map_unchecked_mut(|s| &mut s.0) };
        inner.poll(cx)
    }
}
impl<L: Llm + 'static> crate::recursive::step::DynStep for PipelineAsStep<'_, L> {
    /// Runs the stored steps on `input` and reports the result as a step output.
    ///
    /// The step list is *moved out* on the first call. Any later call finds
    /// `None`, runs with an empty step list, and therefore falls into
    /// `run_branch_steps`' empty-steps path (a single `llm.generate` call on
    /// `input`) — the adapter is effectively one-shot.
    ///
    /// The reported score is the best score seen across the steps, or `1.0`
    /// when no step reported one. Elapsed milliseconds are attached under the
    /// `"elapsed_ms"` meta key.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn run_dyn<'b>(
        &'b self,
        input: &'b str,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = crate::error::Result<crate::recursive::step::StepOutput>,
                > + Send
                + 'b,
        >,
    > {
        // The lock is released at the end of this statement, before any await.
        let steps = self
            .steps
            .lock()
            .expect("pipeline steps lock poisoned")
            .take()
            .unwrap_or_default();
        let llm = self.llm;
        let input_owned = input.to_string();
        // NOTE(review): `AssertSend` forces the future to be `Send` without the
        // compiler checking it — confirm everything captured here is thread-safe.
        Box::pin(AssertSend(async move {
            let start = std::time::Instant::now();
            let result = run_branch_steps(llm, input_owned, steps).await;
            Ok(crate::recursive::step::StepOutput::new(
                result.output,
                result.best_score.unwrap_or(1.0),
                result.tokens,
            )
            .with_meta("elapsed_ms", start.elapsed().as_millis().to_string()))
        }))
    }
    /// Static name under which this adapter identifies itself.
    fn dyn_name(&self) -> &'static str {
        "pipeline"
    }
}
impl<'a, L: Llm> Pipeline<'a, L> {
pub fn as_step(self) -> PipelineAsStep<'a, L> {
PipelineAsStep {
llm: self.llm,
steps: std::sync::Mutex::new(Some(self.steps)),
}
}
}
/// Runs a refine loop on `input` with the given validator and limits.
///
/// On success the refined output and its score are returned. On error the
/// input is passed through unchanged with a score of `0.0`. Tokens are always
/// reported as `0` by this helper.
async fn run_refine<L: Llm>(
    llm: &L,
    input: &str,
    validator: Box<dyn Validate>,
    max_iter: u32,
    target: f64,
) -> StepOutput {
    let outcome = refine::refine(llm, input)
        .validate(validator)
        .max_iter(max_iter)
        .target(target)
        .run()
        .await;
    let (output, score) = match outcome {
        Ok(refined) => (refined.output, refined.score),
        // Failure is not fatal for the pipeline: keep the text, zero the score.
        Err(_) => (input.to_string(), 0.0),
    };
    StepOutput {
        output,
        score: Some(score),
        tokens: 0,
    }
}
/// Extracts the code block of language `lang` from `input`.
///
/// When `extract_code` finds nothing, the input is returned unchanged.
/// No score is produced and no tokens are used.
fn run_extract(input: &str, lang: &str) -> StepOutput {
    let output = match extract_code(input, lang) {
        Some(code) => code.to_string(),
        None => input.to_string(),
    };
    StepOutput {
        output,
        score: None,
        tokens: 0,
    }
}
/// Runs a best-of-`n` selection on `input`, ranking candidates with `scorer`
/// when one is supplied and with the builder's default otherwise.
async fn run_best_of<L: Llm>(
    llm: &L,
    input: &str,
    n: usize,
    scorer: Option<Box<dyn Scorer>>,
) -> StepOutput {
    // Each arm reduces its result to plain fields so the two builder
    // configurations need not share a result type.
    let (output, score, tokens) = match scorer {
        Some(s) => {
            let picked = best_of::best_of(llm, input)
                .n(n)
                .metric(move |text| s.score(text))
                .run()
                .await;
            (picked.output, picked.score, picked.tokens)
        }
        None => {
            let picked = best_of::best_of(llm, input).n(n).run().await;
            (picked.output, picked.score, picked.tokens)
        }
    };
    StepOutput {
        output,
        score: Some(score),
        tokens,
    }
}
/// Runs an ensemble of `n` generations on `input` and combines them with
/// `aggregate`; the step score is the ensemble's agreement ratio.
async fn run_ensemble<L: Llm>(llm: &L, input: &str, n: usize, aggregate: Aggregate) -> StepOutput {
    let voted = ensemble::ensemble(llm, input)
        .n(n)
        .aggregate(aggregate)
        .run()
        .await;
    let score = Some(voted.agreement_ratio);
    let tokens = voted.tokens;
    StepOutput {
        output: voted.output,
        score,
        tokens,
    }
}
/// Runs a reasoning step on `input` and forwards its output, score and tokens.
async fn run_reason<L: Llm>(llm: &L, input: &str) -> StepOutput {
    let reasoned = reason::reason(llm, input).run().await;
    let score = Some(reasoned.score);
    let tokens = reasoned.tokens;
    StepOutput {
        output: reasoned.output,
        score,
        tokens,
    }
}
/// Runs a program step on `input` with the supplied executor.
///
/// The score is `1.0` when the program result reports success and `0.0`
/// otherwise.
async fn run_program<L: Llm>(
    llm: &L,
    input: &str,
    executor: Box<dyn DynCodeExecutor>,
) -> StepOutput {
    let executed = program::program(llm, input)
        .executor_dyn(executor)
        .run()
        .await;
    let score = if executed.success { 1.0 } else { 0.0 };
    StepOutput {
        output: executed.output,
        score: Some(score),
        tokens: executed.tokens,
    }
}
async fn run_fan_out<L: Llm>(
llm: &L,
input: &str,
branches: Vec<FanOutBranch>,
merge: &MergeStrategy,
) -> StepOutput {
let mut futs = FuturesUnordered::new();
for (idx, branch) in branches.into_iter().enumerate() {
let input_owned = input.to_string();
let name = branch.name;
let steps = branch.steps;
futs.push(async move {
let result = run_branch_steps(llm, input_owned, steps).await;
(idx, name, result)
});
}
let mut branch_results: Vec<FanOutBranchResult> = Vec::new();
while let Some((idx, name, sub_output)) = futs.next().await {
branch_results.push(FanOutBranchResult {
index: idx,
name,
output: sub_output.output,
score: sub_output.best_score,
tokens: sub_output.tokens,
elapsed: sub_output.elapsed,
});
}
branch_results.sort_by_key(|b| b.index);
let total_tokens: u32 = branch_results.iter().map(|b| b.tokens).sum();
let best_score = branch_results
.iter()
.filter_map(|b| b.score)
.fold(None, |acc: Option<f64>, s| {
Some(acc.map_or(s, |a| a.max(s)))
});
let merged_output = apply_merge(&branch_results, merge);
StepOutput {
output: merged_output,
score: best_score,
tokens: total_tokens,
}
}
/// Aggregate result of running a sequence of steps via `run_branch_steps`.
struct BranchOutput {
    /// Text produced by the last step.
    output: String,
    /// Highest score reported by any step, if any step reported one.
    best_score: Option<f64>,
    /// Tokens summed over all steps.
    tokens: u32,
    /// Wall-clock time for the whole sequence.
    elapsed: Duration,
}
/// Runs `steps` in order, starting from `input`, and aggregates the outcome.
///
/// Each step consumes the previous step's output. With an empty step list a
/// single `llm.generate` call is made on `input`; if that call fails the input
/// itself is returned as the output. The returned score is the maximum of all
/// scores reported along the way.
async fn run_branch_steps<L: Llm>(
    llm: &L,
    input: String,
    steps: Vec<PipelineStep>,
) -> BranchOutput {
    let started = Instant::now();
    let mut text = input.clone();
    let mut tokens_used: u32 = 0;
    let mut top_score: Option<f64> = None;

    if steps.is_empty() {
        if let Ok(generated) = llm.generate(&input, "", None).await {
            tokens_used += generated.total_tokens();
            text = generated.text.to_string();
        }
    }

    for step in steps {
        // Move the current text out; it is replaced by the step's output below.
        let step_input = std::mem::take(&mut text);
        let produced = match step {
            PipelineStep::Refine {
                validator,
                max_iter,
                target,
            } => run_refine(llm, &step_input, validator, max_iter, target).await,
            PipelineStep::Extract { lang } => run_extract(&step_input, &lang),
            PipelineStep::BestOf { n, scorer } => run_best_of(llm, &step_input, n, scorer).await,
            PipelineStep::Ensemble { n, aggregate } => {
                run_ensemble(llm, &step_input, n, aggregate).await
            }
            PipelineStep::Reason => run_reason(llm, &step_input).await,
            PipelineStep::Program { executor } => run_program(llm, &step_input, executor).await,
            PipelineStep::Map { f } => StepOutput {
                output: f(step_input),
                score: None,
                tokens: 0,
            },
            PipelineStep::FanOut { branches, merge } => {
                run_fan_out(llm, &step_input, branches, &merge).await
            }
            PipelineStep::Nested { steps } => {
                // Boxed because this is a recursive async call.
                let inner = Box::pin(run_branch_steps(llm, step_input, steps)).await;
                StepOutput {
                    output: inner.output,
                    score: inner.best_score,
                    tokens: inner.tokens,
                }
            }
        };
        tokens_used += produced.tokens;
        text = produced.output;
        top_score = match (top_score, produced.score) {
            (Some(prev), Some(new)) => Some(prev.max(new)),
            (prev, new) => new.or(prev),
        };
    }

    BranchOutput {
        output: text,
        best_score: top_score,
        tokens: tokens_used,
        elapsed: started.elapsed(),
    }
}
/// Combines fan-out branch results into one string according to `strategy`.
///
/// * `FirstSuccess` — output of the first result whose output is non-empty,
///   or an empty string when there is none.
/// * `BestScore` — output of the result with the highest score. Results with
///   no score or a NaN score are ignored for the comparison; on equal scores
///   the later result wins. When no result has a usable score the first
///   result is used, and an empty slice gives an empty string.
/// * `Concat` — all outputs joined with the separator, in slice order.
/// * `Custom` — whatever the supplied closure returns.
fn apply_merge(results: &[FanOutBranchResult], strategy: &MergeStrategy) -> String {
    match strategy {
        MergeStrategy::FirstSuccess => results
            .iter()
            .find(|r| !r.output.is_empty())
            .map(|r| r.output.clone())
            .unwrap_or_default(),
        MergeStrategy::BestScore => results
            .iter()
            // NaN compares as "equal" to everything under `partial_cmp`, which
            // would let a NaN-scored branch displace a genuinely better one
            // depending on branch order, so NaN scores are dropped up front.
            .filter_map(|r| r.score.filter(|s| !s.is_nan()).map(|s| (s, r)))
            .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(_, r)| r)
            .or_else(|| results.first())
            .map(|r| r.output.clone())
            .unwrap_or_default(),
        MergeStrategy::Concat { separator } => results
            .iter()
            .map(|r| r.output.as_str())
            .collect::<Vec<_>>()
            .join(separator),
        MergeStrategy::Custom(f) => f(results),
    }
}
/// Internal result of a single step, before timing and naming are attached.
struct StepOutput {
    /// Text produced by the step.
    output: String,
    /// Score reported by the step, if it reports one.
    score: Option<f64>,
    /// Tokens attributed to the step.
    tokens: u32,
}
// Unit tests for the pipeline builder; every test drives the pipeline with a `MockLlm`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recursive::checks::checks;
    use crate::recursive::llm::MockLlm;
    // A single refine step yields one step record with a positive score.
    #[test]
    fn test_pipeline_single_refine() {
        let llm = MockLlm::new(|_, feedback| {
            if feedback.is_some() {
                "fn add(a: i32, b: i32) -> i32 { a + b }".to_string()
            } else {
                "fn add(a, b) { a + b }".to_string()
            }
        });
        let result = pipeline(&llm, "Write an add function")
            .refine(checks().require("->"))
            .go();
        assert!(result.output.contains("->"));
        assert_eq!(result.steps.len(), 1);
        assert!(result.steps[0].score.unwrap() > 0.0);
    }
    // Refine followed by extract strips the code fence from the output.
    #[test]
    fn test_pipeline_refine_then_extract() {
        let llm = MockLlm::new(|_, _| {
            "Here's the code:\n```rust\nfn main() { println!(\"hello\"); }\n```".to_string()
        });
        let result = pipeline(&llm, "Write a hello world")
            .refine(checks().require("```"))
            .extract("rust")
            .go();
        assert!(result.output.contains("fn main"));
        assert!(!result.output.contains("```"));
        assert_eq!(result.steps.len(), 2);
    }
    // A map step transforms the previous step's output.
    #[test]
    fn test_pipeline_map() {
        let llm = MockLlm::new(|_, _| "  hello world  ".to_string());
        let result = pipeline(&llm, "greet")
            .refine(checks().min_len(1))
            .map(|s| s.trim().to_uppercase())
            .go();
        assert_eq!(result.output, "HELLO WORLD");
    }
    // best_of with identical candidates returns that candidate.
    #[test]
    fn test_pipeline_best_of() {
        let llm = MockLlm::new(|_, _| "candidate output".to_string());
        let result = pipeline(&llm, "generate something").best_of(3).go();
        assert_eq!(result.output, "candidate output");
        assert_eq!(result.steps.len(), 1);
    }
    // ensemble over identical answers returns that answer.
    #[test]
    fn test_pipeline_ensemble() {
        let llm = MockLlm::new(|_, _| "Paris".to_string());
        let result = pipeline(&llm, "Capital of France?").ensemble(3).go();
        assert!(result.output.contains("Paris"));
    }
    // The reason step keeps the final answer in its output.
    #[test]
    fn test_pipeline_reason() {
        let llm = MockLlm::new(|_, _| {
            "Let me think step by step.\n1. 2^10 = 1024\n\nTherefore: 1024".to_string()
        });
        let result = pipeline(&llm, "What is 2^10?").reason().go();
        assert!(result.output.contains("1024"));
    }
    // With no steps, the pipeline performs one plain generation.
    #[test]
    fn test_pipeline_empty() {
        let llm = MockLlm::new(|_, _| "direct output".to_string());
        let result = pipeline(&llm, "just generate").go();
        assert_eq!(result.output, "direct output");
        assert_eq!(result.steps.len(), 1);
    }
    // Output flows from refine into extract across multiple steps.
    #[test]
    fn test_pipeline_multi_step() {
        let call_count = std::sync::atomic::AtomicU32::new(0);
        let llm = MockLlm::new(|_prompt, _| {
            let n = call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if n == 0 {
                "```python\ndef solve(): return 42\n```".to_string()
            } else {
                "42".to_string()
            }
        });
        let result = pipeline(&llm, "Write a solver")
            .refine(checks().require("```"))
            .extract("python")
            .go();
        assert!(result.output.contains("def solve"));
        assert!(!result.output.contains("```"));
        assert_eq!(result.steps.len(), 2);
    }
    // A refine-only pipeline reports zero tokens (run_refine always reports 0).
    #[test]
    fn test_pipeline_result_tokens() {
        let llm = MockLlm::new(|_, _| "output".to_string());
        let result = pipeline(&llm, "prompt")
            .refine(checks().require("output"))
            .go();
        assert_eq!(result.total_tokens, 0);
    }
    // The run records a non-zero elapsed time.
    #[test]
    fn test_pipeline_elapsed() {
        let llm = MockLlm::new(|_, _| "fast".to_string());
        let result = pipeline(&llm, "prompt")
            .refine(checks().require("fast"))
            .go();
        assert!(result.elapsed.as_nanos() > 0);
    }
    // best_of_scored uses the supplied closure and surfaces its score.
    #[test]
    fn test_pipeline_best_of_scored() {
        let llm = MockLlm::new(|_, _| "a]b long enough text here".to_string());
        let result = pipeline(&llm, "generate")
            .best_of_scored(3, |text: &str| text.len() as f64 / 100.0)
            .go();
        assert!(!result.output.is_empty());
        assert!(result.steps[0].score.unwrap() > 0.0);
    }
    // FirstSuccess picks the first declared branch with non-empty output.
    #[test]
    fn test_pipeline_fan_out_first_success() {
        let llm = MockLlm::new(|prompt, _| {
            if prompt.contains("A:") {
                "result_a".to_string()
            } else {
                "result_b".to_string()
            }
        });
        let result = pipeline(&llm, "base prompt")
            .fan_out(
                vec![
                    BranchBuilder::new("branch_a").map(|s| format!("A: {}", s)),
                    BranchBuilder::new("branch_b").map(|s| format!("B: {}", s)),
                ],
                MergeStrategy::FirstSuccess,
            )
            .go();
        assert_eq!(result.output, "A: base prompt");
        assert_eq!(result.steps.len(), 1);
        assert_eq!(result.steps[0].name, "fan_out");
    }
    // Concat joins branch outputs in declaration order.
    #[test]
    fn test_pipeline_fan_out_concat() {
        let llm = MockLlm::new(|_, _| "output".to_string());
        let result = pipeline(&llm, "input")
            .fan_out(
                vec![
                    BranchBuilder::new("x").map(|s| format!("X:{}", s)),
                    BranchBuilder::new("y").map(|s| format!("Y:{}", s)),
                ],
                MergeStrategy::Concat {
                    separator: "|".to_string(),
                },
            )
            .go();
        assert_eq!(result.output, "X:input|Y:input");
    }
    // BestScore selects the branch whose refine step scored highest.
    #[test]
    fn test_pipeline_fan_out_best_score() {
        let llm = MockLlm::new(|prompt, _| {
            if prompt.contains("good") {
                "fn add(a: i32, b: i32) -> i32 { a + b }".to_string()
            } else {
                "bad output".to_string()
            }
        });
        let result = pipeline(&llm, "Write code")
            .fan_out(
                vec![
                    BranchBuilder::new("good_branch")
                        .map(|s| format!("good {}", s))
                        .refine(checks().require("fn ").require("->")),
                    BranchBuilder::new("bad_branch").refine(checks().require("impossible_string")),
                ],
                MergeStrategy::BestScore,
            )
            .go();
        assert!(result.output.contains("fn add"));
    }
    // The closure-based fan_out_with API builds the same kind of step.
    #[test]
    fn test_pipeline_fan_out_with_closure() {
        let llm = MockLlm::new(|_, _| "output".to_string());
        let result = pipeline(&llm, "input")
            .fan_out_with(
                MergeStrategy::Concat {
                    separator: " + ".to_string(),
                },
                |fan| {
                    fan.branch("first", |b| b.map(|s| format!("1:{}", s)))
                        .branch("second", |b| b.map(|s| format!("2:{}", s)))
                },
            )
            .go();
        assert_eq!(result.output, "1:input + 2:input");
    }
    // A Custom merge closure sees branch results in declaration order.
    #[test]
    fn test_pipeline_fan_out_custom_merge() {
        let llm = MockLlm::new(|_, _| "out".to_string());
        let result = pipeline(&llm, "in")
            .fan_out(
                vec![
                    BranchBuilder::new("a").map(|_| "alpha".to_string()),
                    BranchBuilder::new("b").map(|_| "beta".to_string()),
                ],
                MergeStrategy::Custom(Box::new(|results| {
                    results
                        .iter()
                        .map(|r| format!("[{}]", r.output))
                        .collect::<Vec<_>>()
                        .join(",")
                })),
            )
            .go();
        assert_eq!(result.output, "[alpha],[beta]");
    }
    // Branches may chain map and refine steps before being concatenated.
    #[test]
    fn test_pipeline_fan_out_with_refine_branches() {
        let llm = MockLlm::new(|prompt, _| {
            if prompt.contains("rust") {
                "fn solve() -> i32 { 42 }".to_string()
            } else {
                "def solve(): return 42".to_string()
            }
        });
        let result = pipeline(&llm, "Write a solve function")
            .fan_out(
                vec![
                    BranchBuilder::new("rust")
                        .map(|s| format!("{} in rust", s))
                        .refine(checks().require("fn ")),
                    BranchBuilder::new("python")
                        .map(|s| format!("{} in python", s))
                        .refine(checks().require("def ")),
                ],
                MergeStrategy::Concat {
                    separator: "\n---\n".to_string(),
                },
            )
            .go();
        assert!(result.output.contains("fn solve"));
        assert!(result.output.contains("def solve"));
        assert!(result.output.contains("---"));
    }
    // A fan-out with no branches produces empty output.
    #[test]
    fn test_pipeline_fan_out_empty_branches() {
        let llm = MockLlm::new(|_, _| "generated".to_string());
        let result = pipeline(&llm, "input")
            .fan_out(vec![], MergeStrategy::FirstSuccess)
            .go();
        assert_eq!(result.output, "");
    }
    // A single branch passes its output straight through the merge.
    #[test]
    fn test_pipeline_fan_out_single_branch() {
        let llm = MockLlm::new(|_, _| "single".to_string());
        let result = pipeline(&llm, "input")
            .fan_out(
                vec![BranchBuilder::new("only").map(|s| format!("mapped:{}", s))],
                MergeStrategy::FirstSuccess,
            )
            .go();
        assert_eq!(result.output, "mapped:input");
    }
    // nest records the inner pipeline as one "nested" step.
    #[test]
    fn test_pipeline_nest() {
        let llm = MockLlm::new(|_, _| "inner output".to_string());
        let inner = pipeline(&llm, "ignored inner prompt");
        let result = pipeline(&llm, "outer prompt")
            .refine(checks().min_len(1))
            .nest(inner)
            .go();
        assert!(!result.output.is_empty());
        assert_eq!(result.steps.len(), 2);
        assert_eq!(result.steps[1].name, "nested");
    }
    // as_step runs the pipeline's steps on the input given to run_dyn.
    #[test]
    fn test_pipeline_as_step() {
        use crate::recursive::step::DynStep;
        let llm = MockLlm::new(|_, _| "step output".to_string());
        let step = pipeline(&llm, "ignored")
            .map(|s| format!("mapped: {}", s))
            .as_step();
        let result = shared::block_on(step.run_dyn("my input"));
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.text, "mapped: my input");
    }
}