use futures::StreamExt;
use crate::provider::{Brain, BrainEvent, BrainRequest, ContentBlock, Msg, PromptCacheConfig};
use crate::router::TaskTier;
/// System prompt for the critic call made by `self_refine_from`.
///
/// It asks for the single most important flaw in an answer, or the bare word
/// `APPROVED` when there is none; `is_approved` is what recognises that verdict.
const CRITIC_SYSTEM: &str = "You are a ruthless senior reviewer. Find the single most important flaw in the answer below: a wrong claim, a missing case, an unproven assertion, a violated constraint, or an unclear step. Be concrete and specific. If — and only if — the answer is genuinely correct, complete, and well-justified, reply with exactly the word APPROVED and nothing else.";
/// System prompt for the judge call made by `select_best`.
///
/// It asks for nothing but the index of the best candidate; `parse_index`
/// extracts that number from the reply.
const JUDGE_SYSTEM: &str = "You are an impartial judge. You will see a task and several candidate answers. Pick the single best one: most correct, complete, and well-justified. Reply with ONLY its index number, nothing else.";
/// How much extra inference effort to spend on a single task.
///
/// The two fields correspond to the two stages of `reason_max`: drafting
/// (`samples`) and critique-driven refinement (`refine_rounds`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReasoningBudget {
    /// Number of drafts to sample. `0` and `1` both result in a single draft
    /// with no judge call; anything larger adds one judge call to pick a winner.
    pub samples: usize,
    /// Maximum number of critique-then-revise rounds applied to the draft.
    pub refine_rounds: usize,
}

impl ReasoningBudget {
    /// Highest `samples` value that capability scaling will produce.
    const SAMPLE_CAP: usize = 6;
    /// Highest `refine_rounds` value that capability scaling will produce.
    const REFINE_CAP: usize = 4;

    /// Baseline budget for a task tier.
    ///
    /// Trivial and small tasks get one draft and no refinement, medium and
    /// vision tasks get one draft plus one refine round, and hard tasks get
    /// three drafts plus two refine rounds.
    pub fn for_tier(tier: &TaskTier) -> Self {
        match tier {
            TaskTier::Trivial | TaskTier::Small => Self {
                samples: 1,
                refine_rounds: 0,
            },
            TaskTier::Medium | TaskTier::Vision => Self {
                samples: 1,
                refine_rounds: 1,
            },
            TaskTier::Hard => Self {
                samples: 3,
                refine_rounds: 2,
            },
        }
    }

    /// Adjusts the budget for a capability band.
    ///
    /// * bands `0..=2`: two more samples and one more refine round, capped at
    ///   6 samples and 4 rounds;
    /// * band `3`: one more sample, capped at 6;
    /// * any higher band: the budget is returned unchanged.
    ///
    /// The caps are applied to the result, so a budget that already exceeds a
    /// cap is reduced to it in the bands that apply that cap.
    ///
    /// NOTE(review): lower bands presumably denote less capable models that
    /// benefit from more attempts; confirm against the caller.
    pub fn scaled_for_capability(self, band: u8) -> Self {
        match band {
            0..=2 => Self {
                // Saturating: the fields are public, so any `usize` can arrive here.
                samples: self.samples.saturating_add(2).min(Self::SAMPLE_CAP),
                refine_rounds: self.refine_rounds.saturating_add(1).min(Self::REFINE_CAP),
            },
            3 => Self {
                samples: self.samples.saturating_add(1).min(Self::SAMPLE_CAP),
                refine_rounds: self.refine_rounds,
            },
            _ => self,
        }
    }

    /// Upper bound on the number of model calls `reason_max` makes with this
    /// budget: the drafts, one judge call when more than one draft is sampled,
    /// and a critique plus a revision for every refine round.
    ///
    /// The result saturates at `usize::MAX` instead of overflowing.
    pub fn max_calls(&self) -> usize {
        // `reason_max` always produces one draft, even when `samples` is 0.
        let drafts = self.samples.max(1);
        let judge = usize::from(self.samples > 1);
        let refinement = self.refine_rounds.saturating_mul(2);
        drafts.saturating_add(judge).saturating_add(refinement)
    }
}
/// Builds a single-turn request: `system` as the system prompt and `user` as
/// the only message, sent with role `"user"`.
///
/// No tools and no stop sequences are attached, and prompt caching is disabled.
fn user_req(system: &str, user: &str, temperature: f32, max_tokens: u32) -> BrainRequest {
    let prompt = ContentBlock::Text {
        text: user.to_owned(),
    };
    let turn = Msg {
        role: "user".into(),
        content: vec![prompt],
    };
    BrainRequest {
        system: Some(system.to_owned()),
        messages: vec![turn],
        tools: Vec::new(),
        max_tokens,
        temperature,
        stop: Vec::new(),
        cache: PromptCacheConfig::disabled(),
    }
}
pub async fn collect_text(brain: &dyn Brain, req: BrainRequest) -> Option<String> {
let mut stream = brain.complete(req).await.ok()?;
let mut out = String::new();
while let Some(ev) = stream.next().await {
match ev {
BrainEvent::TextDelta(t) => out.push_str(&t),
BrainEvent::Done(_) => break,
BrainEvent::Error(_) => return None,
_ => {}
}
}
Some(out)
}
/// Reports whether a critique is an approval verdict.
///
/// After trimming surrounding whitespace, the critique counts as approved when
/// it begins with `APPROVED` in any ASCII letter case; whatever follows the
/// word (punctuation, further text) is ignored.
fn is_approved(critique: &str) -> bool {
    const VERDICT: &[u8] = b"APPROVED";
    // Byte-wise comparison: an ASCII-only prefix needs no uppercased copy of
    // the whole critique, and a short or non-ASCII head simply fails to match.
    let head = critique.trim().as_bytes();
    match head.get(..VERDICT.len()) {
        Some(prefix) => prefix.eq_ignore_ascii_case(VERDICT),
        None => false,
    }
}
/// Extracts a candidate index from a judge reply.
///
/// Only the first run of ASCII digits in `s` is considered. The result is `0`
/// when there are no digits, when the number does not fit in `usize`, or when
/// it is not a valid index into `n` candidates (`>= n`).
fn parse_index(s: &str, n: usize) -> usize {
    let start = match s.find(|c: char| c.is_ascii_digit()) {
        Some(pos) => pos,
        None => return 0,
    };
    let tail = &s[start..];
    let len = tail
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(tail.len());
    match tail[..len].parse::<usize>() {
        Ok(index) if index < n => index,
        _ => 0,
    }
}
/// Asks the model to judge `candidates` for `task` and returns the index of
/// the winner.
///
/// With zero or one candidate no call is made and `0` is returned. The judge
/// runs at temperature 0 with an 8-token limit; if the call fails or the reply
/// holds no usable index, the result falls back to `0`.
pub async fn select_best(brain: &dyn Brain, task: &str, candidates: &[String]) -> usize {
    let count = candidates.len();
    if count < 2 {
        return 0;
    }
    // Each candidate is labelled with its index so the judge can answer with a number.
    let listing: String = candidates
        .iter()
        .enumerate()
        .map(|(i, c)| format!("[{i}]\n{}\n---\n", c.trim()))
        .collect();
    let prompt = format!(
        "Task:\n{task}\n\nCandidate answers:\n{listing}\nReply with ONLY the index number of the single best candidate."
    );
    let verdict = collect_text(brain, user_req(JUDGE_SYSTEM, &prompt, 0.0, 8)).await;
    verdict.map_or(0, |reply| parse_index(&reply, count))
}
/// Samples up to `n` drafts for `task` and returns the one the judge prefers.
///
/// `n` is treated as at least 1. A single draft is sampled at temperature 0;
/// otherwise draft `i` is sampled at `0.2 + 0.2 * i`, clamped to 1.0, so the
/// drafts differ from one another. Drafts whose call fails or whose text is
/// blank are skipped.
///
/// Returns `None` when no draft survives. When exactly one survives it is
/// returned without a judge call; otherwise `select_best` picks the winner.
pub async fn best_of_n(brain: &dyn Brain, system: &str, task: &str, n: usize) -> Option<String> {
    const BASE_TEMPERATURE: f32 = 0.2;
    const TEMPERATURE_STEP: f32 = 0.2;
    // Without a ceiling the ladder reaches 1.2 at n = 6 and keeps climbing;
    // APIs that only accept 0.0..=1.0 reject such a request and the draft is
    // then silently lost.
    const MAX_TEMPERATURE: f32 = 1.0;

    let n = n.max(1);
    let mut candidates = Vec::with_capacity(n);
    for i in 0..n {
        let temperature = if n == 1 {
            0.0
        } else {
            (BASE_TEMPERATURE + TEMPERATURE_STEP * i as f32).min(MAX_TEMPERATURE)
        };
        if let Some(c) = collect_text(brain, user_req(system, task, temperature, 2048)).await {
            if !c.trim().is_empty() {
                candidates.push(c);
            }
        }
    }
    match candidates.len() {
        0 => None,
        1 => candidates.pop(),
        _ => {
            let idx = select_best(brain, task, &candidates).await;
            candidates.into_iter().nth(idx)
        }
    }
}
/// Improves `answer` through at most `rounds` critique-then-revise cycles.
///
/// Each round asks the critic (temperature 0, `CRITIC_SYSTEM`) for the most
/// important flaw, then asks the model under `system` for a revision that
/// fixes it. Refinement stops early, keeping the current answer, when:
///
/// * the critic approves the answer,
/// * the critic returns blank text, so there is no issue to fix, or
/// * the revision comes back blank, which would otherwise overwrite a usable
///   answer with nothing.
///
/// Returns `None` if a critic or revision call fails outright; the answer held
/// at that point is not returned. With `rounds == 0` the input answer comes
/// back unchanged and no call is made.
pub async fn self_refine_from(
    brain: &dyn Brain,
    system: &str,
    task: &str,
    mut answer: String,
    rounds: usize,
) -> Option<String> {
    for _ in 0..rounds {
        let critique_prompt =
            format!("Task:\n{task}\n\nAnswer to review:\n{answer}\n\nYour critique:");
        let critique =
            collect_text(brain, user_req(CRITIC_SYSTEM, &critique_prompt, 0.0, 512)).await?;
        // A blank critique names no flaw; asking for a revision would only
        // invite arbitrary changes to the answer.
        if critique.trim().is_empty() || is_approved(&critique) {
            break;
        }
        let revise_prompt = format!(
            "Task:\n{task}\n\nYour previous answer:\n{answer}\n\nA reviewer raised this issue:\n{critique}\n\nProduce an improved answer that fully fixes it. Output only the improved answer.",
        );
        let revised = collect_text(brain, user_req(system, &revise_prompt, 0.2, 2048)).await?;
        // Same rule `best_of_n` applies to drafts: blank text is not an answer.
        if revised.trim().is_empty() {
            break;
        }
        answer = revised;
    }
    Some(answer)
}
/// Runs the full pipeline for `task`: draft, then refine, within `budget`.
///
/// With more than one sample the draft is the `best_of_n` winner; otherwise it
/// is a single completion at temperature 0.2. The draft is returned as-is when
/// the budget has no refine rounds, and is otherwise passed through
/// `self_refine_from`.
///
/// Returns `None` when drafting or refinement yields `None`.
pub async fn reason_max(
    brain: &dyn Brain,
    system: &str,
    task: &str,
    budget: ReasoningBudget,
) -> Option<String> {
    let draft = match budget.samples {
        0 | 1 => collect_text(brain, user_req(system, task, 0.2, 2048)).await?,
        samples => best_of_n(brain, system, task, samples).await?,
    };
    match budget.refine_rounds {
        0 => Some(draft),
        rounds => self_refine_from(brain, system, task, draft, rounds).await,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::StopReason;
    use crate::provider::{BrainStream, ModelCaps};
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Mock brain that replays a fixed script: every `complete` call pops the
    /// next scripted reply (or `"DEFAULT"` once the script is exhausted) and
    /// bumps a call counter.
    struct ScriptBrain {
        responses: Mutex<VecDeque<String>>,
        calls: AtomicUsize,
    }

    impl ScriptBrain {
        /// Creates a brain that answers with `responses`, in order.
        fn new(responses: &[&str]) -> Self {
            let queue: VecDeque<String> = responses.iter().map(|&r| r.to_owned()).collect();
            Self {
                responses: Mutex::new(queue),
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of `complete` calls made so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl Brain for ScriptBrain {
        fn id(&self) -> &str {
            "mock:script"
        }

        fn caps(&self) -> ModelCaps {
            ModelCaps::default()
        }

        async fn complete(&self, _req: BrainRequest) -> anyhow::Result<BrainStream> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            // The lock guard is a temporary and is released at the end of this statement.
            let scripted = self.responses.lock().unwrap().pop_front();
            let reply = scripted.unwrap_or_else(|| "DEFAULT".into());
            let script = vec![
                BrainEvent::TextDelta(reply),
                BrainEvent::Done(StopReason::EndTurn),
            ];
            Ok(Box::pin(futures::stream::iter(script)))
        }
    }

    #[test]
    fn budget_scales_with_tier_and_capability() {
        let trivial = ReasoningBudget::for_tier(&TaskTier::Trivial);
        let expected_trivial = ReasoningBudget {
            samples: 1,
            refine_rounds: 0,
        };
        assert_eq!(trivial, expected_trivial);

        let hard = ReasoningBudget::for_tier(&TaskTier::Hard);
        let expected_hard = ReasoningBudget {
            samples: 3,
            refine_rounds: 2,
        };
        assert_eq!(hard, expected_hard);

        let boosted = hard.scaled_for_capability(1);
        assert!(boosted.max_calls() > hard.max_calls());
        assert_eq!(hard.scaled_for_capability(5), hard);
    }

    #[test]
    fn parse_index_is_robust() {
        assert_eq!(parse_index("2", 3), 2);
        assert_eq!(parse_index("The best is [1].", 3), 1);
        assert_eq!(parse_index("99", 3), 0);
        assert_eq!(parse_index("none", 3), 0);
    }

    #[tokio::test]
    async fn self_refine_revises_then_stops_on_approval() {
        // Script: critique (flaw) -> revision -> critique (approval).
        let brain = ScriptBrain::new(&["ISSUE: missing empty-input case", "answer v2", "APPROVED"]);
        let refined = self_refine_from(&brain, "sys", "task", "answer v1".into(), 3).await;
        assert_eq!(refined.unwrap(), "answer v2");
        assert_eq!(brain.calls(), 3, "critique + revise + critique");
    }

    #[tokio::test]
    async fn self_refine_stops_immediately_when_approved() {
        let brain = ScriptBrain::new(&["APPROVED"]);
        let refined = self_refine_from(&brain, "sys", "task", "good answer".into(), 3).await;
        assert_eq!(refined.unwrap(), "good answer", "unchanged");
        assert_eq!(brain.calls(), 1, "only the critique");
    }

    #[tokio::test]
    async fn best_of_n_samples_then_judges() {
        // Three drafts followed by the judge's verdict, "1".
        let brain = ScriptBrain::new(&["cand A", "cand B", "cand C", "1"]);
        let winner = best_of_n(&brain, "sys", "task", 3).await;
        assert_eq!(winner.unwrap(), "cand B");
        assert_eq!(brain.calls(), 4, "3 drafts + 1 judge");
    }

    #[tokio::test]
    async fn best_of_one_skips_the_judge() {
        let brain = ScriptBrain::new(&["only candidate"]);
        let winner = best_of_n(&brain, "sys", "task", 1).await;
        assert_eq!(winner.unwrap(), "only candidate");
        assert_eq!(brain.calls(), 1, "no judge call for n=1");
    }

    #[tokio::test]
    async fn reason_max_runs_the_full_pipeline() {
        // Two drafts, judge picks index 1, critic finds a flaw, reviser answers.
        let brain = ScriptBrain::new(&["d0", "d1", "1", "ISSUE: x", "refined answer"]);
        let budget = ReasoningBudget {
            samples: 2,
            refine_rounds: 1,
        };
        let answer = reason_max(&brain, "sys", "task", budget).await;
        assert_eq!(answer.unwrap(), "refined answer");
        assert_eq!(brain.calls(), 5, "2 drafts + judge + critique + revise");
    }
}