use std::future::Future;
use std::pin::Pin;
use chrono::Utc;
use crate::ctx::CruxCtx;
use crate::types::error::CruxErr;
use crate::types::step::{Step, StepKind, StepStatus};
pub struct SpecArm<T> {
pub name: String,
pub fut: Pin<Box<dyn Future<Output = Result<T, CruxErr>> + Send>>,
}
/// Runs the [`SpecArm`]s of one named speculation and records every arm as a
/// `StepKind::Speculation` step on the borrowed [`CruxCtx`].
///
/// Arms are awaited only when a consuming method (`pick_best_by` or `first_ok`) is
/// awaited; dropping the builder drops the arms without awaiting them.
#[must_use = "a SpeculationBuilder does nothing until `pick_best_by` or `first_ok` is awaited"]
pub struct SpeculationBuilder<'a, T> {
    /// Context that receives one step per recorded arm.
    ctx: &'a mut CruxCtx,
    /// Speculation name: used for the ordinal lookup and as the prefix of each arm's step name.
    name: String,
    /// Candidate arms, awaited in this order.
    arms: Vec<SpecArm<T>>,
}
impl<'a, T> SpeculationBuilder<'a, T>
where
T: serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
{
pub(crate) fn new(ctx: &'a mut CruxCtx, name: &str, arms: Vec<SpecArm<T>>) -> Self {
Self {
ctx,
name: name.to_string(),
arms,
}
}
pub async fn pick_best_by<F>(self, f: F) -> Result<T, CruxErr>
where
F: Fn(&T) -> f32,
{
trace_speculate!(&self.name, self.arms.len());
let (_ordinal, input_hash) = self.ctx.recorder_mut().next_ordinal(&self.name);
let mut completed: Vec<(String, Result<T, CruxErr>)> = Vec::new();
for arm in self.arms {
let result = arm.fut.await;
completed.push((arm.name, result));
}
let mut best_idx: Option<usize> = None;
let mut best_score: f32 = f32::NEG_INFINITY;
for (i, (_, result)) in completed.iter().enumerate() {
if let Ok(val) = result {
let score = f(val);
if score > best_score {
best_score = score;
best_idx = Some(i);
}
}
}
let Some(winner_idx) = best_idx else {
for (arm_name, result) in &completed {
let error = match result {
Err(e) => e.to_string(),
Ok(_) => unreachable!(),
};
self.ctx.push_step(Step {
name: format!("{}::{}", self.name, arm_name),
kind: StepKind::Speculation,
status: StepStatus::Err,
confidence: 0.0,
started_at: Utc::now(),
duration_ms: 0,
input_hash,
content_hash: None,
output: None,
error: Some(error),
attempt: 1,
events: vec![],
});
}
return Err(CruxErr::step_failed(
&self.name,
"all speculation arms failed",
));
};
let mut winner_val: Option<T> = None;
for (i, (arm_name, result)) in completed.into_iter().enumerate() {
if i == winner_idx {
let val = result.unwrap();
self.ctx.push_step(Step {
name: format!("{}::{}", self.name, arm_name),
kind: StepKind::Speculation,
status: StepStatus::Ok,
confidence: best_score,
started_at: Utc::now(),
duration_ms: 0,
input_hash,
content_hash: None,
output: serde_json::to_value(&val).ok(),
error: None,
attempt: 1,
events: vec![],
});
winner_val = Some(val);
} else {
let (status, output, error) = match result {
Ok(val) => (StepStatus::Rejected, serde_json::to_value(&val).ok(), None),
Err(e) => (StepStatus::Err, None, Some(e.to_string())),
};
self.ctx.push_step(Step {
name: format!("{}::{}", self.name, arm_name),
kind: StepKind::Speculation,
status,
confidence: 0.0,
started_at: Utc::now(),
duration_ms: 0,
input_hash,
content_hash: None,
output,
error,
attempt: 1,
events: vec![],
});
}
}
Ok(winner_val.unwrap())
}
pub async fn first_ok(self) -> Result<T, CruxErr> {
trace_speculate!(&self.name, self.arms.len());
let (_ordinal, input_hash) = self.ctx.recorder_mut().next_ordinal(&self.name);
let mut last_err = None;
for arm in self.arms {
match arm.fut.await {
Ok(val) => {
self.ctx.push_step(Step {
name: format!("{}::{}", self.name, arm.name),
kind: StepKind::Speculation,
status: StepStatus::Ok,
confidence: 1.0,
started_at: Utc::now(),
duration_ms: 0,
input_hash,
content_hash: None,
output: serde_json::to_value(&val).ok(),
error: None,
attempt: 1,
events: vec![],
});
return Ok(val);
}
Err(e) => {
self.ctx.push_step(Step {
name: format!("{}::{}", self.name, arm.name),
kind: StepKind::Speculation,
status: StepStatus::Rejected,
confidence: 0.0,
started_at: Utc::now(),
duration_ms: 0,
input_hash,
content_hash: None,
output: None,
error: Some(e.to_string()),
attempt: 1,
events: vec![],
});
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| CruxErr::step_failed(&self.name, "no speculation arms")))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ctx::CruxCtx;
    use crate::types::error::CruxErr;
    use crate::types::step::StepStatus;

    /// Arm whose future resolves immediately to `Ok(val)`.
    fn ok_arm<T: Send + 'static>(name: &str, val: T) -> SpecArm<T> {
        SpecArm {
            name: name.to_string(),
            fut: Box::pin(async move { Ok(val) }),
        }
    }

    /// Arm whose future resolves immediately to a `step_failed` error named after the arm.
    fn err_arm<T: Send + 'static>(name: &str) -> SpecArm<T> {
        let name = name.to_string();
        SpecArm {
            name: name.clone(),
            fut: Box::pin(async move { Err(CruxErr::step_failed(&name, "forced failure")) }),
        }
    }

    #[tokio::test]
    async fn pick_best_selects_highest_score() {
        let mut ctx = CruxCtx::new("test");
        // The labels match the values, and the best arm sits in the middle so that
        // neither "first wins" nor "last wins" passes by accident.
        let arms = vec![
            ok_arm("low", serde_json::json!(1)),
            ok_arm("high", serde_json::json!(3)),
            ok_arm("mid", serde_json::json!(2)),
        ];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let result = builder
            .pick_best_by(|v| v.as_i64().unwrap_or(0) as f32)
            .await
            .unwrap();
        assert_eq!(result.as_i64().unwrap(), 3);
        let cruxx = ctx.finalize::<()>(Ok(()));
        let winner = cruxx.steps.iter().find(|s| s.status == StepStatus::Ok);
        assert_eq!(winner.map(|s| s.name.as_str()), Some("spec::high"));
    }

    #[tokio::test]
    async fn pick_best_all_fail_returns_err() {
        let mut ctx = CruxCtx::new("test");
        let arms: Vec<SpecArm<serde_json::Value>> = vec![err_arm("a"), err_arm("b"), err_arm("c")];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let result = builder.pick_best_by(|_| 1.0).await;
        assert!(result.is_err(), "expected Err when all arms fail");
    }

    #[tokio::test]
    async fn pick_best_all_fail_records_err_steps() {
        let mut ctx = CruxCtx::new("test");
        let arms: Vec<SpecArm<serde_json::Value>> = vec![err_arm("x"), err_arm("y")];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let _ = builder.pick_best_by(|_| 1.0).await;
        let cruxx = ctx.finalize::<()>(Ok(()));
        assert_eq!(cruxx.steps.len(), 2);
        assert!(cruxx.steps.iter().all(|s| s.status == StepStatus::Err));
    }

    #[tokio::test]
    async fn pick_best_winner_ok_losers_rejected() {
        let mut ctx = CruxCtx::new("test");
        let arms = vec![
            ok_arm("winner", serde_json::json!(10)),
            ok_arm("loser", serde_json::json!(1)),
        ];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let _ = builder
            .pick_best_by(|v| v.as_i64().unwrap_or(0) as f32)
            .await
            .unwrap();
        let cruxx = ctx.finalize::<()>(Ok(()));
        assert_eq!(cruxx.steps.len(), 2);
        let winner = cruxx.steps.iter().find(|s| s.status == StepStatus::Ok);
        let loser = cruxx
            .steps
            .iter()
            .find(|s| s.status == StepStatus::Rejected);
        assert_eq!(
            winner.map(|s| s.name.as_str()),
            Some("spec::winner"),
            "expected the `winner` arm to be the Ok step"
        );
        assert_eq!(
            loser.map(|s| s.name.as_str()),
            Some("spec::loser"),
            "expected the `loser` arm to be the Rejected step"
        );
    }

    #[tokio::test]
    async fn pick_best_tie_break_favors_first_arm() {
        let mut ctx = CruxCtx::new("test");
        let arms = vec![
            ok_arm("first", serde_json::json!("a")),
            ok_arm("second", serde_json::json!("b")),
        ];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let result = builder.pick_best_by(|_| 5.0).await.unwrap();
        assert_eq!(result.as_str().unwrap(), "a");
        let cruxx = ctx.finalize::<()>(Ok(()));
        let winner = cruxx.steps.iter().find(|s| s.status == StepStatus::Ok);
        assert_eq!(winner.map(|s| s.name.as_str()), Some("spec::first"));
    }

    #[tokio::test]
    async fn first_ok_returns_first_success() {
        let mut ctx = CruxCtx::new("test");
        let arms = vec![
            err_arm("fail1"),
            ok_arm("pass", serde_json::json!("found")),
            ok_arm("never", serde_json::json!("skipped")),
        ];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let result = builder.first_ok().await.unwrap();
        assert_eq!(result.as_str().unwrap(), "found");
        // Arms after the first success are not recorded.
        let cruxx = ctx.finalize::<()>(Ok(()));
        assert_eq!(cruxx.steps.len(), 2);
        assert!(cruxx.steps.iter().all(|s| s.name != "spec::never"));
    }

    #[tokio::test]
    async fn first_ok_all_fail_returns_last_err() {
        let mut ctx = CruxCtx::new("test");
        let arms: Vec<SpecArm<serde_json::Value>> = vec![err_arm("a"), err_arm("b")];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let err = builder
            .first_ok()
            .await
            .expect_err("expected Err when every arm fails");
        assert_eq!(
            err.to_string(),
            CruxErr::step_failed("b", "forced failure").to_string()
        );
        let cruxx = ctx.finalize::<()>(Ok(()));
        assert_eq!(cruxx.steps.len(), 2);
        assert!(cruxx.steps.iter().all(|s| s.status == StepStatus::Rejected));
    }

    #[tokio::test]
    async fn first_ok_empty_arms_returns_err() {
        let mut ctx = CruxCtx::new("test");
        let arms: Vec<SpecArm<serde_json::Value>> = vec![];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let result = builder.first_ok().await;
        assert!(result.is_err());
        let cruxx = ctx.finalize::<()>(Ok(()));
        assert_eq!(cruxx.steps.len(), 0);
    }

    #[tokio::test]
    async fn first_ok_records_failed_as_rejected() {
        let mut ctx = CruxCtx::new("test");
        let arms = vec![err_arm("fail1"), ok_arm("pass", serde_json::json!(true))];
        let builder = SpeculationBuilder::new(&mut ctx, "spec", arms);
        let _ = builder.first_ok().await.unwrap();
        let cruxx = ctx.finalize::<()>(Ok(()));
        assert_eq!(cruxx.steps.len(), 2);
        let rejected = cruxx
            .steps
            .iter()
            .find(|s| s.status == StepStatus::Rejected);
        assert!(rejected.is_some());
        let ok = cruxx.steps.iter().find(|s| s.status == StepStatus::Ok);
        assert!(ok.is_some());
    }
}