use crate::agent::self_reflection::types::Critique;
use crate::error::Result;
use futures::future::BoxFuture;
/// A pluggable evaluator that scores an `answer` produced for a `task`.
///
/// Implementors must be `Send + Sync`. Evaluation is asynchronous: `critique`
/// returns a boxed future that borrows `self` and all three string inputs for `'a`.
pub trait Critic: Send + Sync {
    /// Human-readable identifier for this critic.
    ///
    /// Implementors that do not override this report `"anonymous"`.
    fn name(&self) -> &str {
        "anonymous"
    }

    /// Evaluates `answer` for `task`, resolving to a [`Critique`].
    ///
    /// `context` is opaque caller-supplied text; each implementation decides how
    /// (or whether) to use it. Failures are reported through the crate's `Result`.
    fn critique<'a>(
        &'a self,
        task: &'a str,
        answer: &'a str,
        context: &'a str,
    ) -> BoxFuture<'a, Result<Critique>>;
}
/// A [`Critic`] that ignores its inputs and always reports the same preconfigured critique.
///
/// Handy as a deterministic stand-in for a real critic (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct StaticCritic {
    /// Score reported in every critique.
    score: f64,
    /// Pass/fail verdict reported in every critique.
    passed: bool,
    /// Feedback text reported in every critique.
    feedback: String,
    /// Suggestions reported in every critique; empty unless set via `with_suggestions`.
    suggestions: Vec<String>,
}
impl StaticCritic {
    /// Creates a critic that always reports `score`, `passed` and `feedback`,
    /// with an empty suggestion list.
    pub fn new(score: f64, passed: bool, feedback: impl Into<String>) -> Self {
        let feedback: String = feedback.into();
        Self {
            suggestions: Vec::new(),
            feedback,
            passed,
            score,
        }
    }

    /// Builder-style setter that replaces the suggestions this critic reports.
    pub fn with_suggestions(self, suggestions: Vec<String>) -> Self {
        Self {
            suggestions,
            ..self
        }
    }

    /// A critic that always passes with a score of `9.0`.
    pub fn always_pass() -> Self {
        Self::new(9.0, true, "Response quality is excellent")
    }

    /// A critic that always fails with a score of `3.0`.
    pub fn always_fail() -> Self {
        Self::new(3.0, false, "Response does not meet requirements")
    }
}
impl Critic for StaticCritic {
    /// Returns a copy of the preconfigured critique; `task`, `answer` and
    /// `context` are ignored.
    fn critique<'a>(
        &'a self,
        _task: &'a str,
        _answer: &'a str,
        _context: &'a str,
    ) -> BoxFuture<'a, Result<Critique>> {
        let Self {
            score,
            passed,
            feedback,
            suggestions,
        } = self;
        Box::pin(async move {
            // `Critique` owns its strings, so the stored values are cloned on every call.
            Ok(Critique {
                score: *score,
                passed: *passed,
                feedback: feedback.clone(),
                suggestions: suggestions.clone(),
            })
        })
    }

    /// Always `"static"`.
    fn name(&self) -> &str {
        "static"
    }
}
/// Wraps another critic and replaces its pass/fail verdict with a score threshold.
///
/// The wrapped critic still produces the score, feedback and suggestions; only
/// `passed` is recomputed as `score >= threshold`.
// No trait bound on the struct itself: the `Critic` bound lives on the `impl`
// blocks that actually need it, so the type can be named/derived without it.
#[derive(Debug, Clone)]
pub struct ThresholdCritic<C> {
    /// The critic whose output is post-processed.
    inner: C,
    /// Minimum score (inclusive) for a critique to pass.
    threshold: f64,
}
impl<C> ThresholdCritic<C>
where
    C: Critic,
{
    /// Wraps `inner`; its critiques will pass exactly when `score >= threshold`.
    pub fn new(inner: C, threshold: f64) -> Self {
        Self { threshold, inner }
    }
}
impl<C: Critic> Critic for ThresholdCritic<C> {
fn critique<'a>(
&'a self,
task: &'a str,
answer: &'a str,
context: &'a str,
) -> BoxFuture<'a, Result<Critique>> {
Box::pin(async move {
let mut critique = self.inner.critique(task, answer, context).await?;
critique.passed = critique.score >= self.threshold;
Ok(critique)
})
}
fn name(&self) -> &str {
self.inner.name()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_static_critic_always_pass() {
        let critic = StaticCritic::always_pass();
        let critique = critic.critique("task", "answer", "").await.unwrap();
        assert!(critique.passed);
        assert!(critique.score >= 8.0);
    }

    #[tokio::test]
    async fn test_static_critic_always_fail() {
        let critic = StaticCritic::always_fail();
        let critique = critic.critique("task", "answer", "").await.unwrap();
        assert!(!critique.passed);
    }

    #[tokio::test]
    async fn test_static_critic_returns_feedback_and_suggestions() {
        let critic = StaticCritic::new(5.0, false, "Needs work")
            .with_suggestions(vec!["Add examples".to_string()]);
        let critique = critic.critique("task", "answer", "").await.unwrap();
        assert_eq!(critique.score, 5.0);
        assert_eq!(critique.feedback, "Needs work");
        assert_eq!(critique.suggestions, vec!["Add examples".to_string()]);
        assert_eq!(critic.name(), "static");
    }

    #[tokio::test]
    async fn test_threshold_critic() {
        // The inner critic claims a pass, but 6.0 is below the 8.0 threshold.
        let inner = StaticCritic::new(6.0, true, "OK");
        let critic = ThresholdCritic::new(inner, 8.0);
        let critique = critic.critique("task", "answer", "").await.unwrap();
        assert!(!critique.passed);
        assert_eq!(critique.score, 6.0);
    }

    #[tokio::test]
    async fn test_threshold_critic_passes() {
        // The inner critic claims a failure, but 9.0 clears the 8.0 threshold.
        let inner = StaticCritic::new(9.0, false, "Very good");
        let critic = ThresholdCritic::new(inner, 8.0);
        let critique = critic.critique("task", "answer", "").await.unwrap();
        assert!(critique.passed);
    }

    #[tokio::test]
    async fn test_threshold_critic_boundary_and_passthrough() {
        let inner = StaticCritic::new(8.0, false, "Borderline")
            .with_suggestions(vec!["Tighten wording".to_string()]);
        let critic = ThresholdCritic::new(inner, 8.0);
        let critique = critic.critique("task", "answer", "").await.unwrap();
        // The comparison is `>=`, so a score equal to the threshold passes.
        assert!(critique.passed);
        // Everything other than `passed` comes straight from the inner critic.
        assert_eq!(critique.feedback, "Borderline");
        assert_eq!(critique.suggestions, vec!["Tighten wording".to_string()]);
        assert_eq!(critic.name(), "static");
    }
}