1pub use crate::types::{EvaluationDecision, EvaluationStrategy};
23
24use super::config::AgentLoopConfig;
25use super::core::agent_loop;
26use crate::types::*;
27use tokio::sync::mpsc;
28use tokio_util::sync::CancellationToken;
29
30pub struct TransparentEvaluation;
39
40#[async_trait::async_trait]
41impl EvaluationStrategy for TransparentEvaluation {
42 async fn evaluate(
43 &self,
44 _prompts: &[AgentMessage],
45 outcomes: &[ParallelLoopOutcome],
46 _tx: &mpsc::UnboundedSender<AgentEvent>,
47 _cancel: CancellationToken,
48 ) -> (EvaluationDecision, Usage) {
49 assert_eq!(
50 outcomes.len(),
51 1,
52 "TransparentEvaluation requires exactly one branch, got {}",
53 outcomes.len()
54 );
55 (EvaluationDecision::Select(0), Usage::default())
56 }
57}
58
59pub struct PickFirstEvaluation;
66
67#[async_trait::async_trait]
68impl EvaluationStrategy for PickFirstEvaluation {
69 async fn evaluate(
70 &self,
71 _prompts: &[AgentMessage],
72 _outcomes: &[ParallelLoopOutcome],
73 _tx: &mpsc::UnboundedSender<AgentEvent>,
74 _cancel: CancellationToken,
75 ) -> (EvaluationDecision, Usage) {
76 (EvaluationDecision::Select(0), Usage::default())
77 }
78}
79
80pub struct TokenEfficientEvaluation;
87
88#[async_trait::async_trait]
89impl EvaluationStrategy for TokenEfficientEvaluation {
90 async fn evaluate(
91 &self,
92 _prompts: &[AgentMessage],
93 outcomes: &[ParallelLoopOutcome],
94 _tx: &mpsc::UnboundedSender<AgentEvent>,
95 _cancel: CancellationToken,
96 ) -> (EvaluationDecision, Usage) {
97 let idx = outcomes
98 .iter()
99 .enumerate()
100 .min_by_key(|(_, o)| o.usage.total_tokens)
101 .map(|(i, _)| i)
102 .unwrap_or(0);
103 (EvaluationDecision::Select(idx), Usage::default())
104 }
105}
106
107pub struct ElaborateEvaluation;
114
115#[async_trait::async_trait]
116impl EvaluationStrategy for ElaborateEvaluation {
117 async fn evaluate(
118 &self,
119 _prompts: &[AgentMessage],
120 outcomes: &[ParallelLoopOutcome],
121 _tx: &mpsc::UnboundedSender<AgentEvent>,
122 _cancel: CancellationToken,
123 ) -> (EvaluationDecision, Usage) {
124 let idx = outcomes
125 .iter()
126 .enumerate()
127 .max_by_key(|(_, o)| o.usage.total_tokens)
128 .map(|(i, _)| i)
129 .unwrap_or(0);
130 (EvaluationDecision::Select(idx), Usage::default())
131 }
132}
133
/// Evaluation strategy that runs a separate "judge" agent loop to pick the
/// best branch response.
pub struct LlmJudgeEvaluation {
    /// Configuration used to run the judge's own agent loop. When its
    /// `context_config` is set, `max_context_tokens` is the budget against
    /// which the judge prompt is compacted.
    pub judge_config: AgentLoopConfig,
    /// Optional system prompt for the judge; a built-in impartial-judge
    /// prompt is used when this is `None`.
    pub system_prompt: Option<String>,
}
204
205fn extract_text_only(content: &[Content]) -> String {
209 content
210 .iter()
211 .filter_map(|c| match c {
212 Content::Text { text } => Some(text.as_str()),
213 _ => None,
214 })
215 .collect::<Vec<_>>()
216 .join("\n")
217}
218
219fn extract_query_text(prompts: &[AgentMessage]) -> String {
222 prompts
223 .iter()
224 .filter_map(|m| match m {
225 AgentMessage::Llm(LlmMessage {
226 message: Message::User { content, .. },
227 ..
228 }) => Some(content),
229 _ => None,
230 })
231 .flat_map(|content| {
232 content.iter().filter_map(|c| match c {
233 Content::Text { text } => Some(text.as_str()),
234 _ => None,
235 })
236 })
237 .collect::<Vec<_>>()
238 .join("\n")
239}
240
241fn extract_last_user_text(messages: &[AgentMessage]) -> Option<String> {
244 messages.iter().rev().find_map(|m| match m {
245 AgentMessage::Llm(LlmMessage {
246 message: Message::User { content, .. },
247 ..
248 }) => {
249 let text = extract_text_only(content);
250 if text.is_empty() {
251 None
252 } else {
253 Some(text)
254 }
255 }
256 _ => None,
257 })
258}
259
260fn format_prior_context(messages: &[AgentMessage]) -> String {
265 let mut parts: Vec<String> = Vec::new();
266 for m in messages {
267 match m {
268 AgentMessage::Llm(LlmMessage {
269 message: Message::User { content, .. },
270 ..
271 }) => {
272 let text = extract_text_only(content);
273 if !text.is_empty() {
274 parts.push(format!("User: {}", text));
275 }
276 }
277 AgentMessage::Llm(LlmMessage {
278 message: Message::Assistant { content, .. },
279 ..
280 }) => {
281 let text = extract_text_only(content);
282 if !text.is_empty() {
283 parts.push(format!("Assistant: {}", text));
284 }
285 }
286 AgentMessage::Llm(LlmMessage {
287 message:
288 Message::ToolResult {
289 tool_name, content, ..
290 },
291 ..
292 }) => {
293 let text = extract_text_only(content);
294 if !text.is_empty() {
295 parts.push(format!("Tool [{}]: {}", tool_name, text));
296 }
297 }
298 _ => {}
299 }
300 }
301 parts.join("\n")
302}
303
304fn extract_final_assistant_text(messages: &[AgentMessage]) -> String {
308 messages
309 .iter()
310 .rev()
311 .find_map(|m| match m {
312 AgentMessage::Llm(LlmMessage {
313 message: Message::Assistant { content, .. },
314 ..
315 }) => {
316 let text = extract_text_only(content);
317 if text.is_empty() {
318 None
319 } else {
320 Some(text)
321 }
322 }
323 _ => None,
324 })
325 .unwrap_or_default()
326}
327
/// Tier-1 compaction: keep only the last `max_lines` lines of `text`.
///
/// Text that already has at most `max_lines` lines is returned unchanged.
/// Otherwise the kept lines are re-joined with `\n`.
fn compact_tier1(text: &str, max_lines: usize) -> String {
    let total = text.lines().count();
    if total <= max_lines {
        return text.to_owned();
    }
    text.lines()
        .skip(total - max_lines)
        .collect::<Vec<_>>()
        .join("\n")
}
337
/// Tier-2 compaction: keep only the first and last paragraphs of `text`.
///
/// Paragraphs are separated by blank lines (`\n\n`), trimmed, and empty ones
/// are ignored. With two or more paragraphs the result is the first and the
/// last, joined by a `...` marker paragraph. A single paragraph is returned
/// trimmed; text with no non-empty paragraph is returned unchanged.
fn compact_tier2(text: &str) -> String {
    let mut paragraphs = text
        .split("\n\n")
        .map(str::trim)
        .filter(|block| !block.is_empty());

    let Some(head) = paragraphs.next() else {
        return text.to_owned();
    };
    match paragraphs.last() {
        Some(tail) => format!("{}\n\n...\n\n{}", head, tail),
        None => head.to_owned(),
    }
}
355
/// Tier-3 compaction: hard-truncate `text` to at most `max_chars` bytes.
///
/// Despite the parameter name, the limit is measured in bytes (`str::len`).
/// Text within the limit is returned unchanged. Longer text is cut so that,
/// together with a trailing `...` marker, it fits the limit. The cut is moved
/// back to the nearest UTF-8 character boundary so that multi-byte text is
/// never split mid-character (slicing there would panic).
///
/// When `max_chars` is smaller than 3 the result is just `...`.
fn compact_tier3(text: &str, max_chars: usize) -> String {
    if text.len() <= max_chars {
        return text.to_string();
    }
    // Leave room for the three-byte "..." marker.
    let mut cut = max_chars.saturating_sub(3);
    // `cut < text.len()` here, and offset 0 is always a boundary, so this
    // loop terminates after at most three steps.
    while !text.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &text[..cut])
}
366
/// Rough token estimate for `s`: its byte length divided by four, rounded up.
fn estimate_tokens(s: &str) -> usize {
    let bytes = s.len();
    let whole_groups = bytes / 4;
    // A partial trailing group still counts as one token.
    whole_groups + usize::from(bytes % 4 != 0)
}
371
372fn compact_responses(responses: Vec<String>, token_budget: usize) -> (Vec<String>, bool) {
377 let mut current = responses;
379 if current.iter().map(|r| estimate_tokens(r)).sum::<usize>() <= token_budget {
380 return (current, true);
381 }
382
383 current = current.into_iter().map(|r| compact_tier1(&r, 80)).collect();
385 if current.iter().map(|r| estimate_tokens(r)).sum::<usize>() <= token_budget {
386 return (current, true);
387 }
388
389 current = current.into_iter().map(|r| compact_tier2(&r)).collect();
391 if current.iter().map(|r| estimate_tokens(r)).sum::<usize>() <= token_budget {
392 return (current, true);
393 }
394
395 let n = current.len().max(1);
397 let max_chars = std::cmp::max(200, (token_budget * 4) / n);
398 current = current
399 .into_iter()
400 .map(|r| compact_tier3(&r, max_chars))
401 .collect();
402
403 let satisfied = current.iter().map(|r| estimate_tokens(r)).sum::<usize>() <= token_budget;
404 (current, satisfied)
405}
406
407fn compact_for_judge(
421 prior_context: String,
422 outputs: Vec<String>,
423 token_budget: usize,
424) -> (String, Vec<String>, bool) {
425 let out_tokens = || outputs.iter().map(|o| estimate_tokens(o)).sum::<usize>();
426
427 if estimate_tokens(&prior_context) + out_tokens() <= token_budget {
429 return (prior_context, outputs, true);
430 }
431
432 let ctx1 = compact_tier1(&prior_context, 80);
434 if estimate_tokens(&ctx1) + out_tokens() <= token_budget {
435 return (ctx1, outputs, true);
436 }
437
438 let ctx2 = compact_tier2(&ctx1);
439 if estimate_tokens(&ctx2) + out_tokens() <= token_budget {
440 return (ctx2, outputs, true);
441 }
442
443 let n_out = outputs.len().max(1);
444 let ctx_budget_chars = (token_budget.saturating_sub(out_tokens()) * 4).max(200);
445 let ctx3 = compact_tier3(&ctx2, ctx_budget_chars);
446 if estimate_tokens(&ctx3) + out_tokens() <= token_budget {
447 return (ctx3, outputs, true);
448 }
449
450 let out_budget = token_budget
452 .saturating_sub(estimate_tokens(&ctx3))
453 .max(200 * n_out);
454 let (compacted_outputs, satisfied) = compact_responses(outputs, out_budget);
455 (ctx3, compacted_outputs, satisfied)
456}
457
/// Assembles the user message sent to the judge.
///
/// Layout, in order: an optional "Prior conversation context" section
/// (omitted when `prior_context` is `None` or blank), the original query,
/// each response numbered from 1, and the closing instruction to answer with
/// the response number only.
fn build_judge_user_message(
    prior_context: Option<&str>,
    query: &str,
    responses: &[String],
) -> String {
    let mut prompt = String::new();

    match prior_context {
        Some(ctx) if !ctx.trim().is_empty() => {
            prompt += "Prior conversation context:\n";
            prompt += ctx;
            prompt += "\n\n";
        }
        _ => {}
    }

    prompt += &format!("Original query:\n{}\n\n", query);

    // Responses are presented to the judge with 1-based numbering.
    for (number, body) in (1usize..).zip(responses) {
        prompt += &format!("Response {}:\n{}\n\n", number, body);
    }

    prompt +=
        "Which response is best? Reply with ONLY the response number (e.g., \"1\" or \"2\").";
    prompt
}
480
/// Extracts the judge's choice from its reply as a 0-based branch index.
///
/// Each whitespace-separated word is reduced to its ASCII digits and parsed;
/// the first resulting number in `1..=max_index + 1` wins and is returned
/// minus one. Falls back to 0 when no word yields a number in that range.
fn parse_judge_selection(text: &str, max_index: usize) -> usize {
    text.split_whitespace()
        .filter_map(|token| {
            let digits: String = token.chars().filter(char::is_ascii_digit).collect();
            digits.parse::<usize>().ok()
        })
        .find(|label| (1..=max_index + 1).contains(label))
        .map_or(0, |label| label - 1)
}
494
495#[async_trait::async_trait]
498impl EvaluationStrategy for LlmJudgeEvaluation {
499 async fn evaluate(
500 &self,
501 prompts: &[AgentMessage],
502 outcomes: &[ParallelLoopOutcome],
503 tx: &mpsc::UnboundedSender<AgentEvent>,
504 cancel: CancellationToken,
505 ) -> (EvaluationDecision, Usage) {
506 let orig_len = outcomes
512 .first()
513 .map(|o| o.original_context_len)
514 .unwrap_or(0);
515 let orig_ctx_msgs: &[AgentMessage] = outcomes
516 .first()
517 .map(|o| &o.context.messages[..orig_len])
518 .unwrap_or(&[]);
519
520 let (query, prior_context_msgs): (String, &[AgentMessage]) = if !prompts.is_empty() {
521 (extract_query_text(prompts), orig_ctx_msgs)
524 } else {
525 let last_user_pos = orig_ctx_msgs.iter().rposition(|m| {
528 matches!(
529 m,
530 AgentMessage::Llm(LlmMessage {
531 message: Message::User { .. },
532 ..
533 })
534 )
535 });
536 match last_user_pos {
537 Some(pos) => (
538 extract_last_user_text(&orig_ctx_msgs[pos..pos + 1]).unwrap_or_default(),
539 &orig_ctx_msgs[..pos],
540 ),
541 None => (String::new(), orig_ctx_msgs),
542 }
543 };
544
545 let prior_context_text = format_prior_context(prior_context_msgs);
546
547 let raw_responses: Vec<String> = outcomes
549 .iter()
550 .map(|o| extract_final_assistant_text(&o.new_messages))
551 .collect();
552
553 let token_budget = self
558 .judge_config
559 .context_config
560 .as_ref()
561 .map(|c| c.max_context_tokens);
562
563 let (prior_ctx_for_judge, responses) = if let Some(budget) = token_budget {
564 let content_budget = (budget * 4) / 5;
566 let (pc, resp, satisfied) =
567 compact_for_judge(prior_context_text, raw_responses, content_budget);
568 if !satisfied {
569 tx.send(AgentEvent::ProgressMessage {
570 loop_id: String::new(),
571 tool_call_id: "judge-compaction".into(),
572 tool_name: "LlmJudgeEvaluation".into(),
573 text: format!(
574 "LlmJudgeEvaluation: could not fit prior context + {} branch \
575 responses within the judge's context budget ({} tokens) after \
576 2-iteration compaction. Proceeding best-effort — judge comparison \
577 may be incomplete.",
578 outcomes.len(),
579 budget
580 ),
581 })
582 .ok();
583 }
584 (pc, resp)
585 } else {
586 (prior_context_text, raw_responses)
587 };
588
589 let default_system = "You are an impartial judge evaluating AI assistant responses. \
591 Select the response that best answers the user's query. \
592 Reply with ONLY the response number (e.g., \"1\" or \"2\").";
593 let system_prompt = self
594 .system_prompt
595 .as_deref()
596 .unwrap_or(default_system)
597 .to_string();
598
599 let judge_user_text =
600 build_judge_user_message(Some(&prior_ctx_for_judge), &query, &responses);
601
602 let session_id = outcomes.first().and_then(|o| o.context.session_id.clone());
604
605 let mut judge_context = AgentContext {
606 system_prompt,
607 messages: vec![],
608 tools: vec![],
609 agent_id: None,
610 session_id,
611 loop_id: None,
612 parent_loop_id: None,
613 continuation_kind: None,
614 session: None,
615 user_context: Vec::new(),
616 inrun_context: Vec::new(),
617 };
618
619 let judge_prompts = vec![AgentMessage::Llm(LlmMessage::new(Message::user(
620 judge_user_text,
621 )))];
622
623 let (judge_tx, judge_rx) = mpsc::unbounded_channel::<AgentEvent>();
628 let (usage_tx, usage_rx) = tokio::sync::oneshot::channel::<Usage>();
629
630 let main_tx = tx.clone();
631 tokio::spawn(async move {
632 let mut judge_rx = judge_rx;
633 let mut last_usage = Usage::default();
634 while let Some(event) = judge_rx.recv().await {
635 if let AgentEvent::AgentEnd { ref usage, .. } = event {
636 last_usage = usage.clone();
637 }
638 main_tx.send(event).ok();
639 }
640 usage_tx.send(last_usage).ok();
643 });
644
645 let judge_messages = agent_loop(
646 judge_prompts,
647 &mut judge_context,
648 &self.judge_config,
649 judge_tx,
650 cancel,
651 )
652 .await;
653
654 let judge_usage = usage_rx.await.unwrap_or_default();
655
656 let judge_text = extract_final_assistant_text(&judge_messages);
658 let selected = parse_judge_selection(&judge_text, outcomes.len() - 1);
659
660 (EvaluationDecision::Select(selected), judge_usage)
661 }
662}
663
#[cfg(test)]
mod tests {
    //! Unit tests for the built-in evaluation strategies and the helper
    //! functions used to build and compact the judge prompt.

    use super::*;

    /// Builds an outcome whose only new message is an assistant reply with
    /// `final_text`, and whose usage reports `total_tokens`.
    fn make_outcome(loop_id: &str, total_tokens: u64, final_text: &str) -> ParallelLoopOutcome {
        let msg = AgentMessage::Llm(LlmMessage::new(Message::Assistant {
            content: vec![Content::Text {
                text: final_text.to_string(),
            }],
            stop_reason: StopReason::Stop,
            model: "test".into(),
            provider: "test".into(),
            usage: Usage {
                total_tokens,
                ..Default::default()
            },
            timestamp: 0,
            error_message: None,
        }));
        ParallelLoopOutcome {
            config_index: 0,
            loop_id: loop_id.to_string(),
            context: AgentContext {
                system_prompt: String::new(),
                messages: vec![],
                tools: vec![],
                agent_id: None,
                session_id: None,
                loop_id: None,
                parent_loop_id: None,
                continuation_kind: None,
                session: None,
                user_context: Vec::new(),
                inrun_context: Vec::new(),
            },
            new_messages: vec![msg],
            usage: Usage {
                total_tokens,
                ..Default::default()
            },
            original_context_len: 0,
        }
    }

    /// Sender whose receiver is dropped immediately; the strategies under
    /// test here never need the events to be delivered.
    fn dummy_tx() -> mpsc::UnboundedSender<AgentEvent> {
        let (tx, _rx) = mpsc::unbounded_channel();
        tx
    }

    #[tokio::test]
    async fn test_transparent_single_branch() {
        let outcomes = vec![make_outcome("loop1", 100, "hello")];
        let (decision, usage) = TransparentEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(0)));
        assert_eq!(usage.total_tokens, 0);
    }

    #[tokio::test]
    #[should_panic(expected = "TransparentEvaluation requires exactly one branch")]
    async fn test_transparent_panics_on_multiple() {
        let outcomes = vec![
            make_outcome("loop1", 100, "a"),
            make_outcome("loop2", 200, "b"),
        ];
        TransparentEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
    }

    #[tokio::test]
    async fn test_pick_first() {
        let outcomes = vec![
            make_outcome("loop1", 300, "verbose"),
            make_outcome("loop2", 50, "concise"),
        ];
        let (decision, _) = PickFirstEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(0)));
    }

    #[tokio::test]
    async fn test_token_efficient() {
        let outcomes = vec![
            make_outcome("loop1", 500, "long verbose response"),
            make_outcome("loop2", 50, "short"),
            make_outcome("loop3", 200, "medium"),
        ];
        let (decision, _) = TokenEfficientEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(1)));
    }

    #[tokio::test]
    async fn test_elaborate() {
        let outcomes = vec![
            make_outcome("loop1", 500, "long verbose response"),
            make_outcome("loop2", 50, "short"),
            make_outcome("loop3", 200, "medium"),
        ];
        let (decision, _) = ElaborateEvaluation
            .evaluate(&[], &outcomes, &dummy_tx(), CancellationToken::new())
            .await;
        assert!(matches!(decision, EvaluationDecision::Select(0)));
    }

    #[test]
    fn test_parse_judge_selection() {
        assert_eq!(parse_judge_selection("2", 2), 1);
        assert_eq!(parse_judge_selection("Response 1 is best.", 2), 0);
        assert_eq!(parse_judge_selection("I pick 3.", 3), 2);
        // No number at all: falls back to the first branch.
        assert_eq!(parse_judge_selection("unclear", 2), 0);
        // Number out of range: falls back to the first branch.
        assert_eq!(parse_judge_selection("5", 2), 0);
    }

    #[test]
    fn test_compact_tier1() {
        let text = (0..100)
            .map(|i| format!("line {}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let compacted = compact_tier1(&text, 80);
        assert_eq!(compacted.lines().count(), 80);
    }

    #[test]
    fn test_compact_tier2() {
        let text = "First paragraph.\n\nMiddle paragraph.\n\nLast paragraph.";
        let compacted = compact_tier2(text);
        assert!(compacted.contains("First paragraph."));
        assert!(compacted.contains("Last paragraph."));
        assert!(!compacted.contains("Middle paragraph."));
    }

    #[test]
    fn test_extract_query_text() {
        let prompts = vec![
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "Hello".into(),
                }],
                timestamp: 0,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "World".into(),
                }],
                timestamp: 0,
            })),
        ];
        let query = extract_query_text(&prompts);
        assert_eq!(query, "Hello\nWorld");
    }

    #[test]
    fn test_extract_final_assistant_text() {
        let messages = vec![
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "first".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "final".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
        ];
        assert_eq!(extract_final_assistant_text(&messages), "final");
    }

    #[test]
    fn test_extract_last_user_text() {
        let messages = vec![
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "first query".into(),
                }],
                timestamp: 0,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "answer".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "follow-up".into(),
                }],
                timestamp: 0,
            })),
        ];
        // The most recent user message wins, not the first one.
        assert_eq!(
            extract_last_user_text(&messages),
            Some("follow-up".to_string())
        );
    }

    #[test]
    fn test_extract_last_user_text_none() {
        let messages: Vec<AgentMessage> = vec![];
        assert_eq!(extract_last_user_text(&messages), None);
    }

    #[test]
    fn test_format_prior_context() {
        let messages = vec![
            AgentMessage::Llm(LlmMessage::new(Message::User {
                content: vec![Content::Text {
                    text: "Hello".into(),
                }],
                timestamp: 0,
            })),
            AgentMessage::Llm(LlmMessage::new(Message::Assistant {
                content: vec![Content::Text {
                    text: "Hi there!".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "m".into(),
                provider: "p".into(),
                usage: Usage::default(),
                timestamp: 0,
                error_message: None,
            })),
        ];
        let transcript = format_prior_context(&messages);
        assert!(transcript.contains("User: Hello"));
        assert!(transcript.contains("Assistant: Hi there!"));
    }

    #[test]
    fn test_compact_for_judge_no_compaction_needed() {
        let ctx = "short context".to_string();
        let outputs = vec!["short response".to_string()];
        let (c, o, satisfied) = compact_for_judge(ctx.clone(), outputs.clone(), 10_000);
        assert!(satisfied);
        assert_eq!(c, ctx);
        assert_eq!(o, outputs);
    }

    #[test]
    fn test_compact_for_judge_iter1_compacts_context_only() {
        // 200 short lines: roughly 423 estimated tokens, far above the budget.
        let many_lines: String = (0..200).map(|i| format!("line {}\n", i)).collect();
        let outputs = vec!["tiny".to_string()];
        let budget = 100;
        let (c, o, satisfied) = compact_for_judge(many_lines, outputs.clone(), budget);
        // The outputs fit easily, so only the context may be compacted.
        assert_eq!(o, outputs);
        // Truncating the context to the space the outputs leave free is
        // enough here, so the budget must be met exactly as reported.
        assert!(satisfied);
        assert!(estimate_tokens(&c) + estimate_tokens("tiny") <= budget);
        assert!(c.ends_with("..."));
    }
}