use async_trait::async_trait;
use oharness_core::{AssistantTurn, ConversationView, Task, TrajectoryView};
use serde::{Deserialize, Serialize};
/// An assessor that inspects an [`AssessmentContext`] and produces a
/// [`CriticVerdict`].
///
/// Implementors must be `Send + Sync`. The trait is declared through
/// `#[async_trait]` so that `assess` can be written as an `async fn`.
#[async_trait]
pub trait Critic: Send + Sync {
/// Returns this critic's name as a borrowed string.
fn name(&self) -> &str;
/// Asynchronously evaluates the borrowed context `ctx` and returns a verdict.
///
/// The return type is a plain [`CriticVerdict`], not a `Result`, so the
/// signature offers no separate error channel.
async fn assess(&self, ctx: &AssessmentContext<'_>) -> CriticVerdict;
}
/// Forwards [`Critic`] to the value behind an `Arc`, so an `Arc<T>` (including
/// an `Arc<dyn Critic>`, thanks to the `?Sized` bound) can be used wherever a
/// critic is expected.
#[async_trait]
impl<T: Critic + ?Sized> Critic for std::sync::Arc<T> {
    /// Returns the name reported by the wrapped critic.
    fn name(&self) -> &str {
        // `&Arc<T>` deref-coerces to `&T` at the call site.
        T::name(self)
    }

    /// Delegates the assessment to the wrapped critic and awaits its verdict.
    async fn assess(&self, ctx: &AssessmentContext<'_>) -> CriticVerdict {
        T::assess(self, ctx).await
    }
}
/// Borrowed inputs handed to [`Critic::assess`].
///
/// Every field is a reference or a view tied to the lifetime `'a`; the
/// struct holds no owned fields.
pub struct AssessmentContext<'a> {
/// The task under assessment.
pub task: &'a Task,
/// View over the conversation (presumably the messages exchanged so far —
/// confirm against `oharness_core::ConversationView`).
pub conversation: ConversationView<'a>,
/// The assistant turn being judged (presumably the most recent one, per the
/// field name — confirm against the caller).
pub latest_turn: &'a AssistantTurn,
/// View over the trajectory (semantics defined by
/// `oharness_core::TrajectoryView`).
pub trajectory: TrajectoryView<'a>,
}
impl<'a> AssessmentContext<'a> {
pub fn new(
task: &'a Task,
conversation: ConversationView<'a>,
latest_turn: &'a AssistantTurn,
trajectory: TrajectoryView<'a>,
) -> Self {
Self {
task,
conversation,
latest_turn,
trajectory,
}
}
}
/// Outcome returned by [`Critic::assess`].
///
/// The helper methods on this type classify the variants as follows:
/// `Accept` and `AcceptWithNote` are "accepting", `Reject` is "rejecting",
/// `Abort` is "terminal", and `Revise` is matched by none of the helpers.
#[derive(Debug, Clone)]
pub enum CriticVerdict {
/// Acceptance with no additional information.
Accept,
/// Acceptance that carries a free-form note.
AcceptWithNote(String),
/// Rejection, with the reason for it.
Reject {
/// Why the critic rejected.
reason: String,
},
/// Carries a replacement assistant turn together with the reason for it.
/// NOTE(review): how the replacement is applied is decided by the caller —
/// not visible in this file.
Revise {
/// The assistant turn proposed by the critic.
replacement: oharness_core::AssistantTurn,
/// Why the critic proposed a replacement.
reason: String,
},
/// Terminal verdict (the only variant for which `is_terminal` is true).
Abort {
/// Why the critic aborted.
reason: String,
},
}
impl CriticVerdict {
pub fn is_accepting(&self) -> bool {
matches!(
self,
CriticVerdict::Accept | CriticVerdict::AcceptWithNote(_)
)
}
pub fn is_rejecting(&self) -> bool {
matches!(self, CriticVerdict::Reject { .. })
}
pub fn is_terminal(&self) -> bool {
matches!(self, CriticVerdict::Abort { .. })
}
}
/// Selects when a critic is invoked.
///
/// Serialized with serde's adjacently tagged representation: the variant
/// name, in `snake_case`, goes under the `"kind"` key and any payload under
/// `"value"`. For example, `AfterEveryNTurns(5)` is written as
/// `{"kind":"after_every_n_turns","value":5}`.
///
/// NOTE(review): the per-variant timing described below is inferred from the
/// variant names; the scheduling logic lives outside this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum CriticTrigger {
/// The default trigger; presumably fires after each assistant turn.
#[default]
AfterAssistant,
/// Presumably fires after a tool result is produced.
AfterToolResult,
/// Carries a `u32` turn interval `N`; presumably fires once every `N` turns.
AfterEveryNTurns(u32),
/// Presumably fires only when explicitly requested.
OnDemand,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each classification helper answers `true` for its own variants and
    /// `false` for the others exercised here.
    #[test]
    fn verdict_classification_helpers() {
        let accept = CriticVerdict::Accept;
        let noted = CriticVerdict::AcceptWithNote(String::from("n"));
        let reject = CriticVerdict::Reject {
            reason: String::from("no"),
        };
        let abort = CriticVerdict::Abort {
            reason: String::from("x"),
        };

        assert!(accept.is_accepting());
        assert!(noted.is_accepting());
        assert!(reject.is_rejecting());
        assert!(!reject.is_accepting());
        assert!(abort.is_terminal());
        assert!(!accept.is_terminal());
    }

    /// A payload-carrying trigger survives a JSON encode/decode cycle.
    #[test]
    fn trigger_round_trips_through_serde() {
        let original = CriticTrigger::AfterEveryNTurns(5);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<CriticTrigger>(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// `Default` yields the `AfterAssistant` trigger.
    #[test]
    fn trigger_default_is_after_assistant() {
        let fallback: CriticTrigger = Default::default();
        assert_eq!(fallback, CriticTrigger::AfterAssistant);
    }
}