use crate::agent::self_reflection::critic::Critic;
use crate::agent::self_reflection::types::{Critique, CritiqueOutput, critique_output_schema};
use crate::error::{ReactError, Result};
use crate::llm;
use crate::llm::ResponseFormat;
use crate::llm::types::Message;
use futures::future::BoxFuture;
use reqwest::Client;
use std::sync::Arc;
use tracing::{debug, info, warn};
/// A [`Critic`] that grades an answer by asking an LLM for a structured critique.
///
/// Cloning is cheap with respect to the HTTP client: the client is held behind
/// an [`Arc`], so clones share it.
#[derive(Debug, Clone)]
pub struct LlmCritic {
    /// Model identifier forwarded to `llm::chat` on every critique.
    model: String,
    /// HTTP client handed to `llm::chat`.
    client: Arc<Client>,
    /// System prompt sent as the first message of every critique request.
    system_prompt: String,
    /// Minimum score (inclusive) at which a critique is marked as passed.
    pass_threshold: f64,
}
impl LlmCritic {
    /// Request timeout configured on the HTTP client built by [`LlmCritic::new`].
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);

    /// Pass threshold used until [`LlmCritic::with_pass_threshold`] overrides it.
    const DEFAULT_PASS_THRESHOLD: f64 = 7.0;

    /// Score assigned when the LLM output cannot be parsed even after clean-up.
    const FALLBACK_SCORE: f64 = 5.0;

    /// Creates a critic for `model` with the default system prompt and a pass
    /// threshold of 7.0.
    ///
    /// The HTTP client is built with a 120 second timeout. If building that
    /// client fails, a default client is used instead (best effort) and a
    /// warning is logged, because the configured timeout is lost in that case.
    pub fn new(model: impl Into<String>) -> Self {
        let client = Client::builder()
            .timeout(Self::REQUEST_TIMEOUT)
            .build()
            .unwrap_or_else(|e| {
                // Keep the best-effort fallback, but do not hide that the
                // timeout configuration was dropped.
                warn!(
                    error = %e,
                    "Failed to build HTTP client with timeout; falling back to default client"
                );
                Client::default()
            });

        Self {
            model: model.into(),
            client: Arc::new(client),
            system_prompt: Self::default_system_prompt().to_string(),
            pass_threshold: Self::DEFAULT_PASS_THRESHOLD,
        }
    }

    /// Replaces the system prompt sent with every critique request.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the minimum score (inclusive) at which a critique counts as passed.
    ///
    /// The value is stored as given; no range or NaN validation is performed.
    #[must_use]
    pub fn with_pass_threshold(mut self, threshold: f64) -> Self {
        self.pass_threshold = threshold;
        self
    }

    /// Returns the built-in system prompt (written in Chinese).
    ///
    /// It casts the model as a strict quality evaluator, lists four evaluation
    /// dimensions (accuracy, completeness, clarity, usefulness), defines score
    /// bands on a 0.0-10.0 scale, and asks for output that follows the JSON
    /// Schema.
    fn default_system_prompt() -> &'static str {
        "你是一个严格的质量评估专家。你需要评估给定回答的质量。\n\n\
         评估维度:\n\
         1. 准确性:事实是否正确\n\
         2. 完整性:是否覆盖了所有要点\n\
         3. 清晰度:表述是否清楚易懂\n\
         4. 实用性:是否提供了有价值的信息\n\n\
         评分标准:\n\
         - 9.0-10.0: 优秀,几乎无可挑剔\n\
         - 7.0-8.9: 良好,基本正确但有小瑕疵\n\
         - 5.0-6.9: 一般,有明显不足\n\
         - 0.0-4.9: 较差,存在严重错误\n\n\
         请严格按 JSON Schema 返回结构化数据。"
    }

    /// Parses the raw LLM response into a [`CritiqueOutput`].
    ///
    /// Three attempts are made in order:
    /// 1. parse `content` directly as JSON;
    /// 2. parse the result of `extract_json_from_markdown(content)`;
    /// 3. hand that extracted text to [`Self::try_auto_fix`].
    ///
    /// Because the last step always yields a value, this function currently
    /// never returns `Err`.
    fn parse_critique_output(content: &str) -> Result<CritiqueOutput> {
        // Fast path: the response is already valid JSON for the schema.
        if let Ok(output) = serde_json::from_str::<CritiqueOutput>(content) {
            return Ok(output);
        }

        // Second chance: the JSON may be wrapped in Markdown.
        let json_str = crate::utils::json_parse::extract_json_from_markdown(content);
        if let Ok(output) = serde_json::from_str::<CritiqueOutput>(&json_str) {
            return Ok(output);
        }

        Self::try_auto_fix(&json_str)
    }

    /// Runs `clean_json` over `json_str` and parses the cleaned text.
    ///
    /// If the cleaned text still does not parse, a fallback critique is
    /// returned instead of an error: score 5.0, `passed` false, the trimmed
    /// input as feedback, and no suggestions. A warning with the parse error is
    /// logged in that case.
    fn try_auto_fix(json_str: &str) -> Result<CritiqueOutput> {
        let fixed = crate::utils::json_parse::clean_json(json_str);
        match serde_json::from_str::<CritiqueOutput>(&fixed) {
            Ok(output) => {
                info!("Auto-fix succeeded for LLM critique output");
                Ok(output)
            }
            Err(e) => {
                warn!(error = %e, "Failed to parse critique output");
                // Surface the unparsed text as feedback so it is not lost.
                Ok(CritiqueOutput {
                    score: Self::FALLBACK_SCORE,
                    passed: false,
                    feedback: json_str.trim().to_string(),
                    suggestions: vec![],
                })
            }
        }
    }
}
impl Critic for LlmCritic {
fn critique<'a>(
&'a self,
task: &'a str,
answer: &'a str,
context: &'a str,
) -> BoxFuture<'a, Result<Critique>> {
Box::pin(async move {
info!(model = %self.model, "LlmCritic: evaluating answer");
let user_content = if context.is_empty() {
format!("原始任务:\n{}\n\n待评估的回答:\n{}", task, answer)
} else {
format!(
"原始任务:\n{}\n\n待评估的回答:\n{}\n\n附加上下文:\n{}",
task, answer, context
)
};
let messages = vec![
Message::system(self.system_prompt.clone()),
Message::user(user_content),
];
let response_format = Some(ResponseFormat::json_schema(
"critique_output",
critique_output_schema(),
));
let response = llm::chat(
self.client.clone(),
&self.model,
&messages,
Some(0.3),
Some(2048u32),
Some(false),
None,
None,
response_format,
)
.await
.map_err(|e| ReactError::Other(format!("LLM critique call failed: {}", e)))?;
let content = response
.choices
.first()
.and_then(|c| c.message.content.as_text())
.unwrap_or_default();
debug!(response = %content, "LlmCritic raw response");
let output = Self::parse_critique_output(&content)?;
let mut critique: Critique = output.into();
critique.passed = critique.score >= self.pass_threshold;
info!(
score = critique.score,
passed = critique.passed,
"LlmCritic: evaluation complete"
);
Ok(critique)
})
}
fn name(&self) -> &str {
"llm_critic"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed JSON is accepted as-is.
    #[test]
    fn test_parse_critique_output_json() {
        let raw = r#"{"score": 8.5, "passed": true, "feedback": "回答准确", "suggestions": ["可以更详细"]}"#;
        let parsed = LlmCritic::parse_critique_output(raw).unwrap();

        assert_eq!(parsed.score, 8.5);
        assert!(parsed.passed);
    }

    /// JSON wrapped in a Markdown code fence is extracted and parsed.
    #[test]
    fn test_parse_critique_output_markdown() {
        let fenced = r#"```json
{"score": 6.0, "passed": false, "feedback": "不够完整", "suggestions": ["增加示例"]}
```"#;
        let parsed = LlmCritic::parse_critique_output(fenced).unwrap();

        assert_eq!(parsed.score, 6.0);
        assert!(!parsed.passed);
    }

    /// JSON with a trailing comma goes through the auto-fix path and keeps its score.
    #[test]
    fn test_parse_critique_auto_fix() {
        let trailing_comma = r#"{"score": 7.0, "passed": true, "feedback": "良好",}"#;
        let parsed = LlmCritic::parse_critique_output(trailing_comma).unwrap();

        assert_eq!(parsed.score, 7.0);
    }

    /// Text that is not JSON at all yields the fallback critique.
    #[test]
    fn test_parse_critique_fallback() {
        let not_json = "无法解析的文本";
        let parsed = LlmCritic::parse_critique_output(not_json).unwrap();

        assert!(!parsed.passed);
        assert_eq!(parsed.score, 5.0);
    }
}