use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
/// Outcome of a single critic evaluation of one piece of content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CriticResult {
    /// Whether the reviewer approved the content.
    pub approved: bool,
    /// Overall quality score. The LLM path requests a 0.0–1.0 range via its
    /// JSON schema; human reviewers supply this directly.
    pub score: f64,
    /// Optional per-dimension scores keyed by rubric dimension name.
    /// Both critics in this file currently leave this empty.
    pub dimension_scores: std::collections::HashMap<String, f64>,
    /// Free-form reviewer feedback.
    pub feedback: String,
    /// Identity of whoever produced this verdict (LLM model or human user).
    pub reviewer: ReviewerIdentity,
}
/// Identity of the reviewer that produced a [`CriticResult`].
///
/// Serialized with an internal `"type"` tag, i.e. `{"type":"llm",...}` or
/// `{"type":"human",...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ReviewerIdentity {
    /// Automated LLM reviewer, identified by model id.
    #[serde(rename = "llm")]
    Llm { model_id: String },
    /// Human reviewer, identified by user id plus display name.
    #[serde(rename = "human")]
    Human { user_id: String, name: String },
}
/// A request for review, sent over the [`HumanCritic`] channel to whichever
/// task is servicing human reviews.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReviewRequest {
    /// Unique id (UUID v4 in this file) correlating request and response.
    pub review_id: String,
    /// The content to be reviewed.
    pub content: String,
    /// Free-form context describing where the content came from.
    pub context: String,
    /// Rubric dimensions to score. `HumanCritic::evaluate` currently always
    /// sends an empty list.
    pub rubric_dimensions: Vec<String>,
    /// When the review was requested (UTC).
    pub requested_at: chrono::DateTime<chrono::Utc>,
    /// Latest time a response is useful, derived from the critic's timeout.
    pub deadline: chrono::DateTime<chrono::Utc>,
}
/// A completed review, sent back over the oneshot channel paired with the
/// originating [`ReviewRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReviewResponse {
    /// Echo of the request's `review_id` (not verified by the critic).
    pub review_id: String,
    /// The reviewer's verdict.
    pub result: CriticResult,
}
/// A reviewer capable of evaluating content into a [`CriticResult`].
#[async_trait::async_trait]
pub trait Critic: Send + Sync {
    /// Evaluate `content`, with `context` describing its origin, and return
    /// a structured verdict or a [`CriticError`].
    async fn evaluate(&self, content: &str, context: &str) -> Result<CriticResult, CriticError>;
    /// Short label identifying the critic implementation (e.g. "human", "llm").
    fn critic_type(&self) -> &str;
}
/// Critic that forwards content to a human reviewer over a channel and waits
/// (with a timeout) for the verdict.
pub struct HumanCritic {
    // Sends each review request paired with a oneshot responder the human
    // side uses to deliver the verdict.
    review_sender: mpsc::Sender<(ReviewRequest, oneshot::Sender<ReviewResponse>)>,
    // How long `evaluate` waits for a response before giving up.
    timeout: Duration,
    // NOTE(review): set by `new`/`with_reviewer_name` but never read in this
    // file — confirm it is used elsewhere or consider removing it.
    default_reviewer: String,
}
impl HumanCritic {
    /// Build a critic together with the receiving end of its review queue.
    ///
    /// The caller owns the returned receiver and is responsible for answering
    /// each `(ReviewRequest, oneshot::Sender<ReviewResponse>)` pair it yields;
    /// `evaluate` waits at most `timeout` for that answer.
    pub fn new(
        timeout: Duration,
    ) -> (
        Self,
        mpsc::Receiver<(ReviewRequest, oneshot::Sender<ReviewResponse>)>,
    ) {
        let (sender, receiver) = mpsc::channel(32);
        (
            Self {
                review_sender: sender,
                timeout,
                default_reviewer: String::from("human"),
            },
            receiver,
        )
    }

    /// Builder-style override of the default reviewer label.
    pub fn with_reviewer_name(mut self, name: impl Into<String>) -> Self {
        self.default_reviewer = name.into();
        self
    }
}
#[async_trait::async_trait]
impl Critic for HumanCritic {
    /// Submit the content for human review and await the verdict.
    ///
    /// Sends a [`ReviewRequest`] over the channel handed out by
    /// [`HumanCritic::new`] and waits up to `self.timeout` on the paired
    /// oneshot channel for the response.
    ///
    /// # Errors
    /// - [`CriticError::ChannelClosed`] if the request receiver was dropped,
    ///   or the responder dropped the oneshot sender without answering.
    /// - [`CriticError::Timeout`] if no response arrives within the timeout.
    async fn evaluate(&self, content: &str, context: &str) -> Result<CriticResult, CriticError> {
        let review_id = uuid::Uuid::new_v4().to_string();
        let now = chrono::Utc::now();
        // Bug fix: the previous `from_std(..).unwrap_or_default()` silently
        // fell back to a ZERO duration when the std timeout exceeded chrono's
        // representable range, making the deadline equal to `requested_at`
        // (i.e. already expired). Saturate at the maximum representable
        // instant instead.
        let deadline = chrono::Duration::from_std(self.timeout)
            .ok()
            .and_then(|d| now.checked_add_signed(d))
            .unwrap_or(chrono::DateTime::<chrono::Utc>::MAX_UTC);
        let request = ReviewRequest {
            review_id: review_id.clone(),
            content: content.to_string(),
            context: context.to_string(),
            rubric_dimensions: Vec::new(),
            requested_at: now,
            deadline,
        };
        let (response_tx, response_rx) = oneshot::channel();
        self.review_sender
            .send((request, response_tx))
            .await
            .map_err(|_| CriticError::ChannelClosed)?;
        match tokio::time::timeout(self.timeout, response_rx).await {
            Ok(Ok(response)) => Ok(response.result),
            // Responder dropped the oneshot sender without answering.
            Ok(Err(_)) => Err(CriticError::ChannelClosed),
            Err(_) => Err(CriticError::Timeout {
                review_id,
                timeout: self.timeout,
            }),
        }
    }

    fn critic_type(&self) -> &str {
        "human"
    }
}
/// Critic that delegates evaluation to an LLM inference provider, asking for
/// a structured JSON verdict.
pub struct LlmCritic {
    // Backend used to run the completion.
    provider: Arc<dyn crate::reasoning::inference::InferenceProvider>,
    // System prompt establishing the reviewer persona/rubric.
    system_prompt: String,
    // Captured from `provider.default_model()` at construction; recorded as
    // the reviewer identity on every result.
    model_id: String,
}
impl LlmCritic {
    /// Create an LLM-backed critic that uses the provider's default model.
    ///
    /// The model id is captured once at construction time and stamped onto
    /// every [`CriticResult`] this critic produces.
    pub fn new(
        provider: Arc<dyn crate::reasoning::inference::InferenceProvider>,
        system_prompt: impl Into<String>,
    ) -> Self {
        Self {
            // Field initializers run in written order, so the model id is
            // read before `provider` is moved into the struct.
            model_id: provider.default_model().to_string(),
            system_prompt: system_prompt.into(),
            provider,
        }
    }
}
#[async_trait::async_trait]
impl Critic for LlmCritic {
    /// Ask the configured LLM to review `content` and parse its structured
    /// JSON verdict into a [`CriticResult`].
    ///
    /// # Errors
    /// - [`CriticError::InferenceError`] if the provider call fails.
    /// - [`CriticError::ParseError`] if the response is not valid JSON or is
    ///   missing/mistyped on a required field.
    async fn evaluate(&self, content: &str, context: &str) -> Result<CriticResult, CriticError> {
        use crate::reasoning::conversation::{Conversation, ConversationMessage};
        use crate::reasoning::inference::{InferenceOptions, ResponseFormat};
        let mut conv = Conversation::with_system(&self.system_prompt);
        conv.push(ConversationMessage::user(format!(
            "Context: {}\n\nContent to review:\n{}",
            context, content
        )));
        // Response schema the model is constrained to; all three fields are
        // declared required.
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "approved": {"type": "boolean"},
                "score": {"type": "number", "minimum": 0.0, "maximum": 1.0},
                "feedback": {"type": "string"}
            },
            "required": ["approved", "score", "feedback"]
        });
        let options = InferenceOptions {
            response_format: ResponseFormat::JsonSchema {
                schema,
                name: Some("critic_evaluation".into()),
            },
            ..InferenceOptions::default()
        };
        let response = self.provider.complete(&conv, &options).await.map_err(|e| {
            CriticError::InferenceError {
                message: e.to_string(),
            }
        })?;
        let parsed: serde_json::Value =
            serde_json::from_str(&response.content).map_err(|e| CriticError::ParseError {
                message: e.to_string(),
            })?;
        // Bug fix: previously, missing or mistyped required fields were
        // silently defaulted (approved=false, score=0.0, feedback=""), which
        // disguised malformed model output as a legitimate rejection. The
        // schema marks these fields required, so treat absence as a parse
        // error instead.
        let approved = parsed["approved"]
            .as_bool()
            .ok_or_else(|| CriticError::ParseError {
                message: "missing or non-boolean `approved` field".into(),
            })?;
        let score = parsed["score"]
            .as_f64()
            .ok_or_else(|| CriticError::ParseError {
                message: "missing or non-numeric `score` field".into(),
            })?
            // Defensive: keep the score inside the schema-declared [0, 1]
            // range even if the provider ignores the bounds.
            .clamp(0.0, 1.0);
        let feedback = parsed["feedback"]
            .as_str()
            .ok_or_else(|| CriticError::ParseError {
                message: "missing or non-string `feedback` field".into(),
            })?
            .to_string();
        Ok(CriticResult {
            approved,
            score,
            dimension_scores: std::collections::HashMap::new(),
            feedback,
            reviewer: ReviewerIdentity::Llm {
                model_id: self.model_id.clone(),
            },
        })
    }

    fn critic_type(&self) -> &str {
        "llm"
    }
}
/// Errors produced by [`Critic::evaluate`] implementations in this module.
#[derive(Debug, thiserror::Error)]
pub enum CriticError {
    /// The human reviewer did not answer within the critic's timeout.
    #[error("Review timed out (review_id={review_id}, timeout={timeout:?})")]
    Timeout {
        review_id: String,
        timeout: Duration,
    },
    /// The review request channel (or the response oneshot) was dropped.
    #[error("Review channel closed")]
    ChannelClosed,
    /// The LLM provider call itself failed.
    #[error("Inference error: {message}")]
    InferenceError { message: String },
    /// The LLM response could not be parsed into a verdict.
    #[error("Failed to parse critic response: {message}")]
    ParseError { message: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips a fully-populated CriticResult through serde_json.
    #[test]
    fn test_critic_result_serde() {
        let result = CriticResult {
            approved: true,
            score: 0.85,
            dimension_scores: {
                let mut m = std::collections::HashMap::new();
                m.insert("accuracy".into(), 0.9);
                m.insert("clarity".into(), 0.8);
                m
            },
            feedback: "Good analysis.".into(),
            reviewer: ReviewerIdentity::Llm {
                model_id: "claude-sonnet".into(),
            },
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: CriticResult = serde_json::from_str(&json).unwrap();
        assert!(restored.approved);
        assert!((restored.score - 0.85).abs() < f64::EPSILON);
        assert_eq!(restored.dimension_scores.len(), 2);
    }

    // Round-trips a ReviewRequest, including its chrono timestamps.
    #[test]
    fn test_review_request_serde() {
        let request = ReviewRequest {
            review_id: "test-123".into(),
            content: "Content to review".into(),
            context: "Generated by agent X".into(),
            rubric_dimensions: vec!["accuracy".into(), "completeness".into()],
            requested_at: chrono::Utc::now(),
            deadline: chrono::Utc::now() + chrono::Duration::minutes(5),
        };
        let json = serde_json::to_string(&request).unwrap();
        let restored: ReviewRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.review_id, "test-123");
        assert_eq!(restored.rubric_dimensions.len(), 2);
    }

    // Checks that the serde(tag = "type") representation emits the expected
    // discriminator for both variants.
    #[test]
    fn test_reviewer_identity_serde() {
        let llm = ReviewerIdentity::Llm {
            model_id: "gpt-4".into(),
        };
        let json = serde_json::to_string(&llm).unwrap();
        assert!(json.contains("\"type\":\"llm\""));
        let human = ReviewerIdentity::Human {
            user_id: "user-1".into(),
            name: "Alice".into(),
        };
        let json = serde_json::to_string(&human).unwrap();
        assert!(json.contains("\"type\":\"human\""));
    }

    // Timeout path: `_rx` (not `_`) keeps the receiver alive so the send
    // succeeds, but nothing ever responds, forcing CriticError::Timeout.
    #[tokio::test]
    async fn test_human_critic_timeout() {
        let (critic, _rx) = HumanCritic::new(Duration::from_millis(50));
        let result = critic.evaluate("test content", "test context").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            CriticError::Timeout { .. } => {}
            other => panic!("Expected Timeout, got {:?}", other),
        }
    }

    // Happy path: a spawned task plays the human reviewer, echoing the
    // review_id and approving the content.
    #[tokio::test]
    async fn test_human_critic_response() {
        let (critic, mut rx) = HumanCritic::new(Duration::from_secs(5));
        tokio::spawn(async move {
            if let Some((request, response_tx)) = rx.recv().await {
                let _ = response_tx.send(ReviewResponse {
                    review_id: request.review_id,
                    result: CriticResult {
                        approved: true,
                        score: 0.9,
                        dimension_scores: std::collections::HashMap::new(),
                        feedback: "Looks good!".into(),
                        reviewer: ReviewerIdentity::Human {
                            user_id: "tester".into(),
                            name: "Test User".into(),
                        },
                    },
                });
            }
        });
        let result = critic
            .evaluate("test content", "test context")
            .await
            .unwrap();
        assert!(result.approved);
        assert!((result.score - 0.9).abs() < f64::EPSILON);
        assert_eq!(result.feedback, "Looks good!");
    }

    // Dropping the receiver before evaluate() makes the send fail with
    // ChannelClosed.
    #[tokio::test]
    async fn test_human_critic_channel_closed() {
        let (critic, rx) = HumanCritic::new(Duration::from_secs(5));
        drop(rx);
        let result = critic.evaluate("test", "context").await;
        assert!(matches!(result.unwrap_err(), CriticError::ChannelClosed));
    }
}