pub use crate::types::{EvaluationDecision, EvaluationStrategy};
use super::config::AgentLoopConfig;
use super::core::agent_loop;
use crate::types::*;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Evaluation strategy for runs that have exactly one branch.
///
/// No judging is performed: branch `0` is always selected and zero evaluation
/// usage is reported.
///
/// # Panics
///
/// Panics if `outcomes` does not contain exactly one element.
pub struct TransparentEvaluation;
#[async_trait::async_trait]
impl EvaluationStrategy for TransparentEvaluation {
    async fn evaluate(
        &self,
        _prompts: &[AgentMessage],
        outcomes: &[ParallelLoopOutcome],
        _tx: &mpsc::UnboundedSender<AgentEvent>,
        _cancel: CancellationToken,
    ) -> (EvaluationDecision, Usage) {
        let branch_count = outcomes.len();
        assert_eq!(
            branch_count,
            1,
            "TransparentEvaluation requires exactly one branch, got {}",
            branch_count
        );
        let decision = EvaluationDecision::Select(0);
        (decision, Usage::default())
    }
}
/// Evaluation strategy that unconditionally selects the first branch.
///
/// The outcomes are never inspected and no evaluation usage is reported.
pub struct PickFirstEvaluation;
#[async_trait::async_trait]
impl EvaluationStrategy for PickFirstEvaluation {
    async fn evaluate(
        &self,
        _prompts: &[AgentMessage],
        _outcomes: &[ParallelLoopOutcome],
        _tx: &mpsc::UnboundedSender<AgentEvent>,
        _cancel: CancellationToken,
    ) -> (EvaluationDecision, Usage) {
        const FIRST_BRANCH: usize = 0;
        (EvaluationDecision::Select(FIRST_BRANCH), Usage::default())
    }
}
/// Evaluation strategy that selects the branch with the lowest
/// `usage.total_tokens`.
///
/// On ties the earliest branch wins. An empty `outcomes` slice falls back to
/// index `0`. No evaluation usage is reported.
pub struct TokenEfficientEvaluation;
#[async_trait::async_trait]
impl EvaluationStrategy for TokenEfficientEvaluation {
    async fn evaluate(
        &self,
        _prompts: &[AgentMessage],
        outcomes: &[ParallelLoopOutcome],
        _tx: &mpsc::UnboundedSender<AgentEvent>,
        _cancel: CancellationToken,
    ) -> (EvaluationDecision, Usage) {
        let mut cheapest: Option<(usize, _)> = None;
        for (position, outcome) in outcomes.iter().enumerate() {
            let tokens = outcome.usage.total_tokens;
            // Only a strictly smaller count replaces the current best, so the
            // first of several equal minima is kept.
            let replace = match cheapest {
                Some((_, best_tokens)) => tokens < best_tokens,
                None => true,
            };
            if replace {
                cheapest = Some((position, tokens));
            }
        }
        let chosen = cheapest.map_or(0, |(position, _)| position);
        (EvaluationDecision::Select(chosen), Usage::default())
    }
}
/// Evaluation strategy that selects the branch with the highest
/// `usage.total_tokens`.
///
/// On ties the earliest branch wins, consistent with
/// `TokenEfficientEvaluation` and with the index-`0` fallback used for an
/// empty `outcomes` slice. No evaluation usage is reported.
pub struct ElaborateEvaluation;
#[async_trait::async_trait]
impl EvaluationStrategy for ElaborateEvaluation {
    async fn evaluate(
        &self,
        _prompts: &[AgentMessage],
        outcomes: &[ParallelLoopOutcome],
        _tx: &mpsc::UnboundedSender<AgentEvent>,
        _cancel: CancellationToken,
    ) -> (EvaluationDecision, Usage) {
        let idx = outcomes
            .iter()
            .enumerate()
            .max_by(|(ia, a), (ib, b)| {
                a.usage
                    .total_tokens
                    .cmp(&b.usage.total_tokens)
                    // `max_by_key` returns the *last* maximum. Reversing the
                    // index order on equal counts makes the earliest branch
                    // the maximum, so ties are resolved deterministically
                    // towards branch 0.
                    .then_with(|| ib.cmp(ia))
            })
            .map(|(i, _)| i)
            .unwrap_or(0);
        (EvaluationDecision::Select(idx), Usage::default())
    }
}
/// Evaluation strategy that runs a separate "judge" agent loop (via
/// `agent_loop`) to pick the best branch response.
pub struct LlmJudgeEvaluation {
    /// Configuration for the judge's own agent loop.
    ///
    /// When its `context_config` is set, 4/5 of `max_context_tokens` is used
    /// as the budget for the prior-context transcript plus branch responses
    /// sent to the judge.
    pub judge_config: AgentLoopConfig,
    /// Replaces the built-in judge system prompt when `Some`.
    pub system_prompt: Option<String>,
}
/// Joins the `Content::Text` parts of `content` with `'\n'`, skipping every
/// non-text part. Empty text parts still contribute a separator.
fn extract_text_only(content: &[Content]) -> String {
    let mut joined = String::new();
    let mut is_first = true;
    for piece in content {
        if let Content::Text { text } = piece {
            if !is_first {
                joined.push('\n');
            }
            joined.push_str(text);
            is_first = false;
        }
    }
    joined
}
/// Concatenates the text of every user message in `prompts`.
///
/// All `Content::Text` parts of all user messages are joined, in order, with
/// `'\n'`. Non-text parts and non-user messages are ignored.
fn extract_query_text(prompts: &[AgentMessage]) -> String {
    let mut pieces: Vec<&str> = Vec::new();
    for prompt in prompts {
        if let AgentMessage::Llm(LlmMessage {
            message: Message::User { content, .. },
            ..
        }) = prompt
        {
            for part in content.iter() {
                if let Content::Text { text } = part {
                    pieces.push(text.as_str());
                }
            }
        }
    }
    pieces.join("\n")
}
/// Returns the text of the most recent user message whose text is non-empty.
///
/// Messages are scanned from the end. A user message whose text parts join to
/// an empty string is skipped in favour of an earlier one. Returns `None` when
/// no user message carries any text.
fn extract_last_user_text(messages: &[AgentMessage]) -> Option<String> {
    for msg in messages.iter().rev() {
        if let AgentMessage::Llm(LlmMessage {
            message: Message::User { content, .. },
            ..
        }) = msg
        {
            let text = extract_text_only(content);
            if !text.is_empty() {
                return Some(text);
            }
        }
    }
    None
}
/// Renders a plain-text transcript of `messages` for the judge.
///
/// User, assistant and tool-result messages become `User: …`,
/// `Assistant: …` and `Tool [tool_name]: …` lines built from their text parts.
/// Messages whose text is empty, and all other message kinds, are omitted.
/// Lines are joined with `'\n'`.
fn format_prior_context(messages: &[AgentMessage]) -> String {
    // Produces "<label>: <text>", or nothing when the text is empty.
    let render = |label: &str, text: String| -> Option<String> {
        (!text.is_empty()).then(|| format!("{}: {}", label, text))
    };
    messages
        .iter()
        .filter_map(|m| match m {
            AgentMessage::Llm(LlmMessage {
                message: Message::User { content, .. },
                ..
            }) => render("User", extract_text_only(content)),
            AgentMessage::Llm(LlmMessage {
                message: Message::Assistant { content, .. },
                ..
            }) => render("Assistant", extract_text_only(content)),
            AgentMessage::Llm(LlmMessage {
                message:
                    Message::ToolResult {
                        tool_name, content, ..
                    },
                ..
            }) => render(&format!("Tool [{}]", tool_name), extract_text_only(content)),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("\n")
}
/// Returns the text of the last assistant message that has non-empty text,
/// or an empty string if there is none.
fn extract_final_assistant_text(messages: &[AgentMessage]) -> String {
    messages
        .iter()
        .rev()
        .filter_map(|m| match m {
            AgentMessage::Llm(LlmMessage {
                message: Message::Assistant { content, .. },
                ..
            }) => Some(extract_text_only(content)),
            _ => None,
        })
        // Assistant turns with no text (e.g. only non-text parts) are skipped.
        .find(|text| !text.is_empty())
        .unwrap_or_default()
}
/// Tier-1 compaction: keeps only the last `max_lines` lines of `text`.
///
/// Text already within the limit is returned unchanged. Otherwise the kept
/// lines are re-joined with `'\n'`, so `"\r\n"` endings and a trailing
/// newline are not preserved.
fn compact_tier1(text: &str, max_lines: usize) -> String {
    let line_count = text.lines().count();
    if line_count <= max_lines {
        return text.to_string();
    }
    let kept: Vec<&str> = text.lines().skip(line_count - max_lines).collect();
    kept.join("\n")
}
/// Tier-2 compaction: keeps only the first and last paragraphs of `text`.
///
/// Paragraphs are split on `"\n\n"`, trimmed, and empty ones are dropped.
/// With three or more paragraphs the interior ones are replaced by a `...`
/// marker. With one or two paragraphs nothing is elided, so they are returned
/// (trimmed, re-joined with `"\n\n"`) without a marker. Adding one there would
/// only make the "compacted" text longer. Text with no non-blank paragraph is
/// returned unchanged.
fn compact_tier2(text: &str) -> String {
    let paragraphs: Vec<&str> = text
        .split("\n\n")
        .map(str::trim)
        .filter(|p| !p.is_empty())
        .collect();
    match paragraphs.as_slice() {
        [] => text.to_string(),
        [only] => (*only).to_string(),
        [first, last] => format!("{}\n\n{}", first, last),
        [first, .., last] => format!("{}\n\n...\n\n{}", first, last),
    }
}
/// Tier-3 compaction: hard-truncates `text` to at most `max_chars` bytes,
/// appending `...` when anything was cut.
///
/// Lengths are measured in UTF-8 bytes, the same unit `estimate_tokens` uses.
/// The cut point is moved back to the nearest character boundary so a
/// multi-byte character is never split; slicing mid-character would panic.
/// When `max_chars < 3` the result is just `...`, which may exceed
/// `max_chars`.
fn compact_tier3(text: &str, max_chars: usize) -> String {
    if text.len() <= max_chars {
        return text.to_string();
    }
    let mut cut = max_chars.saturating_sub(3);
    // Index 0 is always a char boundary, so this loop terminates.
    while !text.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &text[..cut])
}
/// Rough token estimate used for judge-budget decisions: one token per four
/// UTF-8 bytes, rounded up.
fn estimate_tokens(s: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = s.len();
    byte_len.div_ceil(BYTES_PER_TOKEN)
}
/// Shrinks branch responses until their combined estimated token count fits
/// within `token_budget`, escalating through progressively lossier tiers.
///
/// Each tier is applied to every response and then checked against the
/// budget, returning as soon as the responses fit:
/// 1. keep each response's last 80 lines;
/// 2. keep each response's first and last paragraphs;
/// 3. truncate each to `max(200, token_budget * 4 / n)` bytes, where `n` is
///    the number of responses (at least 1).
///
/// Returns the possibly-compacted responses and whether they fit.
fn compact_responses(responses: Vec<String>, token_budget: usize) -> (Vec<String>, bool) {
    let fits = |batch: &[String]| {
        batch.iter().map(|r| estimate_tokens(r)).sum::<usize>() <= token_budget
    };
    if fits(&responses) {
        return (responses, true);
    }
    let tier1: Vec<String> = responses.iter().map(|r| compact_tier1(r, 80)).collect();
    if fits(&tier1) {
        return (tier1, true);
    }
    let tier2: Vec<String> = tier1.iter().map(|r| compact_tier2(r)).collect();
    if fits(&tier2) {
        return (tier2, true);
    }
    let divisor = tier2.len().max(1);
    let per_response_chars = std::cmp::max(200, (token_budget * 4) / divisor);
    let tier3: Vec<String> = tier2
        .iter()
        .map(|r| compact_tier3(r, per_response_chars))
        .collect();
    let satisfied = fits(&tier3);
    (tier3, satisfied)
}
/// Fits the prior-conversation transcript plus all branch responses into
/// `token_budget` estimated tokens for the judge prompt.
///
/// The context is compacted first, one tier at a time, returning as soon as
/// everything fits:
/// 1. keep its last 80 lines;
/// 2. keep its first and last paragraphs;
/// 3. truncate it to the bytes the responses leave free, but never below 200.
///
/// If that is still not enough, the responses are compacted as well. Their
/// budget is whatever the context leaves, floored at 200 tokens per response.
///
/// Returns `(context, responses, satisfied)`. `satisfied` reports whether the
/// returned context and responses *together* fit within `token_budget`.
fn compact_for_judge(
    prior_context: String,
    outputs: Vec<String>,
    token_budget: usize,
) -> (String, Vec<String>, bool) {
    let sum_tokens = |items: &[String]| items.iter().map(|o| estimate_tokens(o)).sum::<usize>();
    // The responses are untouched until the last phase, so their size is
    // computed once.
    let out_tokens = sum_tokens(&outputs);
    let fits = |ctx: &str| estimate_tokens(ctx) + out_tokens <= token_budget;
    if fits(&prior_context) {
        return (prior_context, outputs, true);
    }
    let ctx1 = compact_tier1(&prior_context, 80);
    if fits(&ctx1) {
        return (ctx1, outputs, true);
    }
    let ctx2 = compact_tier2(&ctx1);
    if fits(&ctx2) {
        return (ctx2, outputs, true);
    }
    let ctx_budget_chars = token_budget
        .saturating_sub(out_tokens)
        .saturating_mul(4)
        .max(200);
    let ctx3 = compact_tier3(&ctx2, ctx_budget_chars);
    if fits(&ctx3) {
        return (ctx3, outputs, true);
    }
    let ctx_tokens = estimate_tokens(&ctx3);
    let n_out = outputs.len().max(1);
    // The per-response floor keeps responses legible, but it can exceed what
    // the context actually left over. `compact_responses` only checks against
    // `out_budget`, so its flag is not trusted. The fit is re-checked against
    // the full budget so callers are told when the judge prompt is over budget.
    let out_budget = token_budget.saturating_sub(ctx_tokens).max(200 * n_out);
    let (compacted_outputs, _) = compact_responses(outputs, out_budget);
    let satisfied = ctx_tokens + sum_tokens(&compacted_outputs) <= token_budget;
    (ctx3, compacted_outputs, satisfied)
}
/// Builds the user message sent to the judge.
///
/// Layout, in order:
/// - an optional "Prior conversation context" section, included only when
///   `prior_context` is non-blank;
/// - the original query;
/// - each response, numbered from 1;
/// - the instruction to answer with only the response number.
fn build_judge_user_message(
    prior_context: Option<&str>,
    query: &str,
    responses: &[String],
) -> String {
    let context_section = prior_context
        .filter(|ctx| !ctx.trim().is_empty())
        .map(|ctx| format!("Prior conversation context:\n{}\n\n", ctx))
        .unwrap_or_default();
    let response_sections: String = responses
        .iter()
        .enumerate()
        .map(|(idx, body)| format!("Response {}:\n{}\n\n", idx + 1, body))
        .collect();
    format!(
        "{}Original query:\n{}\n\n{}Which response is best? Reply with ONLY the response number (e.g., \"1\" or \"2\").",
        context_section, query, response_sections
    )
}
/// Extracts the judge's choice from its reply as a 0-based branch index.
///
/// The reply is scanned for maximal runs of ASCII digits. The first run that
/// parses to a 1-based response number in `1..=max_index + 1` wins. Runs are
/// kept separate, so `"2/3"` reads as `2` rather than `23`. If no valid number
/// is found, `0` (the first branch) is returned.
fn parse_judge_selection(text: &str, max_index: usize) -> usize {
    text.split(|c: char| !c.is_ascii_digit())
        .filter(|run| !run.is_empty())
        .filter_map(|run| run.parse::<usize>().ok())
        // `n - 1 <= max_index` avoids the overflow `max_index + 1` hits at
        // `usize::MAX`.
        .find(|&n| n >= 1 && n - 1 <= max_index)
        .map_or(0, |n| n - 1)
}
/// Splits the judged conversation into the query the branches answered and
/// the history that preceded it.
///
/// When `prompts` is non-empty, the query is the text of its user messages
/// and all of `history` is prior context. Otherwise the last user message in
/// `history` is the query and only the messages before it are prior context.
/// If `history` has no user message, the query is empty and all of `history`
/// is prior context.
fn split_query_and_prior_context<'a>(
    prompts: &[AgentMessage],
    history: &'a [AgentMessage],
) -> (String, &'a [AgentMessage]) {
    if !prompts.is_empty() {
        return (extract_query_text(prompts), history);
    }
    let last_user_pos = history.iter().rposition(|m| {
        matches!(
            m,
            AgentMessage::Llm(LlmMessage {
                message: Message::User { .. },
                ..
            })
        )
    });
    match last_user_pos {
        Some(pos) => (
            extract_last_user_text(&history[pos..pos + 1]).unwrap_or_default(),
            &history[..pos],
        ),
        None => (String::new(), history),
    }
}
/// Spawns a task that forwards every judge event to `main_tx`.
///
/// Once the judge's event channel closes, the task reports the usage carried
/// by the last `AgentEnd` event it saw, or `Usage::default()` if there was
/// none, on the returned receiver.
fn spawn_judge_event_forwarder(
    mut judge_rx: mpsc::UnboundedReceiver<AgentEvent>,
    main_tx: mpsc::UnboundedSender<AgentEvent>,
) -> tokio::sync::oneshot::Receiver<Usage> {
    let (usage_tx, usage_rx) = tokio::sync::oneshot::channel::<Usage>();
    tokio::spawn(async move {
        let mut last_usage = Usage::default();
        while let Some(event) = judge_rx.recv().await {
            if let AgentEvent::AgentEnd { ref usage, .. } = event {
                last_usage = usage.clone();
            }
            main_tx.send(event).ok();
        }
        usage_tx.send(last_usage).ok();
    });
    usage_rx
}
#[async_trait::async_trait]
impl EvaluationStrategy for LlmJudgeEvaluation {
    /// Runs a judge agent loop over the branch responses and selects the one
    /// it names.
    ///
    /// Inputs taken from the branches:
    /// - The prior-context transcript and query come from the first branch's
    ///   original context.
    /// - Each branch contributes its final assistant text.
    ///
    /// Budgeting:
    /// - When `judge_config.context_config` is set, 4/5 of
    ///   `max_context_tokens` is the budget for that text.
    /// - If compaction cannot meet the budget, a `ProgressMessage` warning is
    ///   sent on `tx`.
    ///
    /// Judge events are forwarded to `tx`. The returned usage is the one
    /// carried by the judge's final `AgentEnd` event. An unparseable judge
    /// reply selects branch 0. An empty `outcomes` slice selects index 0
    /// without running the judge.
    async fn evaluate(
        &self,
        prompts: &[AgentMessage],
        outcomes: &[ParallelLoopOutcome],
        tx: &mpsc::UnboundedSender<AgentEvent>,
        cancel: CancellationToken,
    ) -> (EvaluationDecision, Usage) {
        // Nothing to judge. Fall back to index 0 like the other strategies
        // instead of paying for a judge call and then underflowing
        // `outcomes.len() - 1` below.
        let Some(first) = outcomes.first() else {
            return (EvaluationDecision::Select(0), Usage::default());
        };
        // Clamp so an inconsistent `original_context_len` cannot panic the
        // slice.
        let orig_len = first
            .original_context_len
            .min(first.context.messages.len());
        let orig_ctx_msgs: &[AgentMessage] = &first.context.messages[..orig_len];
        let (query, prior_context_msgs) = split_query_and_prior_context(prompts, orig_ctx_msgs);
        let prior_context_text = format_prior_context(prior_context_msgs);
        let raw_responses: Vec<String> = outcomes
            .iter()
            .map(|o| extract_final_assistant_text(&o.new_messages))
            .collect();
        let token_budget = self
            .judge_config
            .context_config
            .as_ref()
            .map(|c| c.max_context_tokens);
        let (prior_ctx_for_judge, responses) = if let Some(budget) = token_budget {
            // Reserve 1/5 of the judge's context for the system prompt,
            // instructions and the judge's reply.
            let content_budget = (budget * 4) / 5;
            let (pc, resp, satisfied) =
                compact_for_judge(prior_context_text, raw_responses, content_budget);
            if !satisfied {
                tx.send(AgentEvent::ProgressMessage {
                    loop_id: String::new(),
                    tool_call_id: "judge-compaction".into(),
                    tool_name: "LlmJudgeEvaluation".into(),
                    text: format!(
                        "LlmJudgeEvaluation: could not fit prior context + {} branch \
                         responses within the judge's context budget ({} tokens) after \
                         2-iteration compaction. Proceeding best-effort — judge comparison \
                         may be incomplete.",
                        outcomes.len(),
                        budget
                    ),
                })
                .ok();
            }
            (pc, resp)
        } else {
            (prior_context_text, raw_responses)
        };
        let default_system = "You are an impartial judge evaluating AI assistant responses. \
                              Select the response that best answers the user's query. \
                              Reply with ONLY the response number (e.g., \"1\" or \"2\").";
        let system_prompt = self
            .system_prompt
            .as_deref()
            .unwrap_or(default_system)
            .to_string();
        let judge_user_text =
            build_judge_user_message(Some(&prior_ctx_for_judge), &query, &responses);
        let session_id = first.context.session_id.clone();
        let mut judge_context = AgentContext {
            system_prompt,
            messages: vec![],
            tools: vec![],
            agent_id: None,
            session_id,
            loop_id: None,
            parent_loop_id: None,
            continuation_kind: None,
            session: None,
            user_context: Vec::new(),
            inrun_context: Vec::new(),
            active_node_id: None,
            next_node_id: 0,
        };
        let judge_prompts = vec![AgentMessage::Llm(LlmMessage::new(Message::user(
            judge_user_text,
        )))];
        let (judge_tx, judge_rx) = mpsc::unbounded_channel::<AgentEvent>();
        let usage_rx = spawn_judge_event_forwarder(judge_rx, tx.clone());
        let judge_messages = agent_loop(
            judge_prompts,
            &mut judge_context,
            &self.judge_config,
            judge_tx,
            cancel,
        )
        .await;
        // `judge_tx` was moved into `agent_loop`. Once it is dropped, the
        // forwarder drains the channel and reports the judge's usage.
        let judge_usage = usage_rx.await.unwrap_or_default();
        let judge_text = extract_final_assistant_text(&judge_messages);
        // `outcomes` is non-empty here, so this subtraction cannot underflow.
        let selected = parse_judge_selection(&judge_text, outcomes.len() - 1);
        (EvaluationDecision::Select(selected), judge_usage)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Builds a branch outcome for strategy tests.
    ///
    /// - `new_messages` holds one assistant message whose text is
    ///   `final_text`.
    /// - `total_tokens` is recorded both on that message's usage and on the
    ///   outcome's `usage`.
    /// - The context is empty and `original_context_len` is 0.
    fn make_outcome(loop_id: &str, total_tokens: u64, final_text: &str) -> ParallelLoopOutcome {
        let msg = AgentMessage::Llm(LlmMessage::new(Message::Assistant {
            content: vec![Content::Text {
                text: final_text.to_string(),
            }],
            stop_reason: StopReason::Stop,
            model: "test".into(),
            provider: "test".into(),
            usage: Usage {
                total_tokens,
                ..Default::default()
            },
            timestamp: 0,
            error_message: None,
        }));
        ParallelLoopOutcome {
            config_index: 0,
            loop_id: loop_id.to_string(),
            context: AgentContext {
                system_prompt: String::new(),
                messages: vec![],
                tools: vec![],
                agent_id: None,
                session_id: None,
                loop_id: None,
                parent_loop_id: None,
                continuation_kind: None,
                session: None,
                user_context: Vec::new(),
                inrun_context: Vec::new(),
                active_node_id: None,
                next_node_id: 0,
            },
            new_messages: vec![msg],
            usage: Usage {
                total_tokens,
                ..Default::default()
            },
            original_context_len: 0,
        }
    }
    /// Returns a sender whose receiver is dropped when this function returns,
    /// so any send on it fails.
    fn dummy_tx() -> mpsc::UnboundedSender<AgentEvent> {
        let (tx, _rx) = mpsc::unbounded_channel();
        tx
    }
    /// A single branch is selected with zero evaluation usage.
    #[tokio::test]
    async fn test_transparent_single_branch() {
        let outcomes = vec![make_outcome("loop1", 100, "hello")];
        let (decision, usage) = TransparentEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(0)));
        assert_eq!(usage.total_tokens, 0);
    }
    /// More than one branch violates TransparentEvaluation's precondition.
    #[tokio::test]
    #[should_panic(expected = "TransparentEvaluation requires exactly one branch")]
    async fn test_transparent_panics_on_multiple() {
        let outcomes = vec![
            make_outcome("loop1", 100, "a"),
            make_outcome("loop2", 200, "b"),
        ];
        TransparentEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
    }
    /// PickFirst ignores token counts and always selects index 0.
    #[tokio::test]
    async fn test_pick_first() {
        let outcomes = vec![
            make_outcome("loop1", 300, "verbose"),
            make_outcome("loop2", 50, "concise"),
        ];
        let (decision, _) = PickFirstEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(0)));
    }
    /// The branch with the fewest total tokens (index 1) is selected.
    #[tokio::test]
    async fn test_token_efficient() {
        let outcomes = vec![
            make_outcome("loop1", 500, "long verbose response"),
            make_outcome("loop2", 50, "short"),
            make_outcome("loop3", 200, "medium"),
        ];
        let (decision, _) = TokenEfficientEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(1)));
    }
    /// The branch with the most total tokens (index 0) is selected.
    #[tokio::test]
    async fn test_elaborate() {
        let outcomes = vec![
            make_outcome("loop1", 500, "long verbose response"),
            make_outcome("loop2", 50, "short"),
            make_outcome("loop3", 200, "medium"),
        ];
        let (decision, _) = ElaborateEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(0)));
    }
    /// 1-based numbers in the judge reply map to 0-based indices; anything
    /// unusable falls back to 0.
    #[test]
    fn test_parse_judge_selection() {
        assert_eq!(parse_judge_selection("2", 2), 1);
        assert_eq!(parse_judge_selection("Response 1 is best.", 2), 0);
        assert_eq!(parse_judge_selection("I pick 3.", 3), 2);
        assert_eq!(parse_judge_selection("unclear", 2), 0);
        // An out-of-range number falls back to the first branch.
        assert_eq!(parse_judge_selection("5", 2), 0);
    }
    /// Tier 1 keeps exactly `max_lines` lines of a longer text.
    #[test]
    fn test_compact_tier1() {
        let text = (0..100)
            .map(|i| format!("line {}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let compacted = compact_tier1(&text, 80);
        assert_eq!(compacted.lines().count(), 80);
    }
    /// Tier 2 keeps the first and last paragraphs and drops the middle one.
    #[test]
    fn test_compact_tier2() {
        let text = "First paragraph.\n\nMiddle paragraph.\n\nLast paragraph.";
        let compacted = compact_tier2(text);
        assert!(compacted.contains("First paragraph."));
        assert!(compacted.contains("Last paragraph."));
        assert!(!compacted.contains("Middle paragraph."));
    }
    /// Text from multiple user prompts is joined with a newline.
    #[test]
    fn test_extract_query_text() {
        let prompts = vec![
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "Hello".into(),
                }],
                timestamp: 0,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "World".into(),
                }],
                timestamp: 0,
            })),
        ];
        let query = extract_query_text(&prompts);
        assert_eq!(query, "Hello\nWorld");
    }
    /// The most recent assistant message's text is returned.
    #[test]
    fn test_extract_final_assistant_text() {
        let messages = vec![
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "first".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "final".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
        ];
        assert_eq!(extract_final_assistant_text(&messages), "final");
    }
    /// The latest user message wins over earlier ones.
    #[test]
    fn test_extract_last_user_text() {
        let messages = vec![
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "first query".into(),
                }],
                timestamp: 0,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "answer".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "follow-up".into(),
                }],
                timestamp: 0,
            })),
        ];
        assert_eq!(
            extract_last_user_text(&messages),
            Some("follow-up".to_string())
        );
    }
    /// No messages means no user text.
    #[test]
    fn test_extract_last_user_text_none() {
        let messages: Vec<AgentMessage> = vec![];
        assert_eq!(extract_last_user_text(&messages), None);
    }
    /// User and assistant turns are rendered with role prefixes.
    #[test]
    fn test_format_prior_context() {
        let messages = vec![
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "Hello".into(),
                }],
                timestamp: 0,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "Hi there!".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
        ];
        let transcript = format_prior_context(&messages);
        assert!(transcript.contains("User: Hello"));
        assert!(transcript.contains("Assistant: Hi there!"));
    }
    /// Inputs already within budget are returned untouched.
    #[test]
    fn test_compact_for_judge_no_compaction_needed() {
        let ctx = "short context".to_string();
        let outputs = vec!["short response".to_string()];
        let (c, o, satisfied) = compact_for_judge(ctx.clone(), outputs.clone(), 10_000);
        assert!(satisfied);
        assert_eq!(c, ctx);
        assert_eq!(o, outputs);
    }
    /// An oversized context is shrunk while small outputs are left as-is.
    #[test]
    fn test_compact_for_judge_iter1_compacts_context_only() {
        let many_lines: String = (0..200).map(|i| format!("line {}\n", i)).collect();
        let outputs = vec!["tiny".to_string()];
        let budget = 100;
        let (c, o, satisfied) = compact_for_judge(many_lines, outputs.clone(), budget);
        assert_eq!(o, outputs);
        assert!(estimate_tokens(&c) < 1000);
        // The fit flag is deliberately not asserted. This test only pins that
        // the outputs are untouched and the context shrinks.
        let _ = satisfied;
    }
}