1use crate::error::{EvalError, Result};
31use crate::structured_judge::StructuredJudge;
32use adk_core::Content;
33use serde::{Deserialize, Serialize};
34
35#[cfg(feature = "embedding")]
36use std::sync::Arc;
37
38#[cfg(feature = "embedding")]
39use crate::embedding_scorer::EmbeddingScorer;
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ConversationMetrics {
46 pub context_retention: f64,
49 pub goal_completion: f64,
52 pub coherence: f64,
55 pub topic_drift: f64,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ConversationScorerConfig {
66 pub context_retention_threshold: f64,
68 pub goal_completion_threshold: f64,
70 pub coherence_threshold: f64,
72 pub topic_drift_threshold: f64,
74}
75
76impl Default for ConversationScorerConfig {
77 fn default() -> Self {
78 Self {
79 context_retention_threshold: 0.7,
80 goal_completion_threshold: 0.7,
81 coherence_threshold: 0.7,
82 topic_drift_threshold: 0.7,
83 }
84 }
85}
86
87pub struct ConversationScorer {
93 judge: StructuredJudge,
94 #[cfg(feature = "embedding")]
95 embedding_scorer: Option<Arc<EmbeddingScorer>>,
96 config: ConversationScorerConfig,
97}
98
99impl ConversationScorer {
100 pub fn new(judge: StructuredJudge) -> Self {
104 Self {
105 judge,
106 #[cfg(feature = "embedding")]
107 embedding_scorer: None,
108 config: ConversationScorerConfig::default(),
109 }
110 }
111
112 #[cfg(feature = "embedding")]
117 pub fn with_embedding(judge: StructuredJudge, embedding: Arc<EmbeddingScorer>) -> Self {
118 Self {
119 judge,
120 embedding_scorer: Some(embedding),
121 config: ConversationScorerConfig::default(),
122 }
123 }
124
125 #[cfg(feature = "embedding")]
127 pub fn with_config(
128 judge: StructuredJudge,
129 embedding: Option<Arc<EmbeddingScorer>>,
130 config: ConversationScorerConfig,
131 ) -> Self {
132 Self { judge, embedding_scorer: embedding, config }
133 }
134
135 #[cfg(not(feature = "embedding"))]
137 pub fn with_config(judge: StructuredJudge, config: ConversationScorerConfig) -> Self {
138 Self { judge, config }
139 }
140
141 pub fn config(&self) -> &ConversationScorerConfig {
143 &self.config
144 }
145
146 pub async fn score(&self, conversation: &[Content], goal: &str) -> Result<ConversationMetrics> {
160 if conversation.is_empty() {
161 return Err(EvalError::ScoringError("cannot score an empty conversation".to_string()));
162 }
163
164 let context_retention = self.score_context_retention(conversation).await?;
165 let goal_completion = self.score_goal_completion(conversation, goal).await?;
166 let coherence = self.score_coherence(conversation).await?;
167 let topic_drift = self.score_topic_drift(conversation).await?;
168
169 Ok(ConversationMetrics {
170 context_retention: context_retention.clamp(0.0, 1.0),
171 goal_completion: goal_completion.clamp(0.0, 1.0),
172 coherence: coherence.clamp(0.0, 1.0),
173 topic_drift: topic_drift.clamp(0.0, 1.0),
174 })
175 }
176
177 async fn score_context_retention(&self, conversation: &[Content]) -> Result<f64> {
182 let conversation_text = format_conversation(conversation);
183
184 let criterion = "Context Retention: Evaluate whether the agent correctly \
185 references and uses information from earlier turns in the conversation. \
186 A score of 1.0 means the agent perfectly retains and uses all prior context. \
187 A score of 0.0 means the agent completely ignores previous context.";
188
189 let verdict = self.judge.judge(&conversation_text, &conversation_text, criterion).await?;
190
191 Ok(verdict.score)
192 }
193
194 async fn score_goal_completion(&self, conversation: &[Content], goal: &str) -> Result<f64> {
199 let conversation_text = format_conversation(conversation);
200
201 let criterion = format!(
202 "Goal Completion: Evaluate whether the agent successfully achieves \
203 the following goal across the conversation: \"{goal}\". \
204 A score of 1.0 means the goal is fully achieved. \
205 A score of 0.0 means no progress toward the goal was made."
206 );
207
208 let verdict = self.judge.judge(goal, &conversation_text, &criterion).await?;
209
210 Ok(verdict.score)
211 }
212
213 async fn score_coherence(&self, conversation: &[Content]) -> Result<f64> {
217 let conversation_text = format_conversation(conversation);
218
219 let criterion = "Coherence: Evaluate the logical consistency between consecutive \
220 responses in this conversation. A score of 1.0 means all responses are \
221 perfectly logically consistent with each other. A score of 0.0 means \
222 responses contradict each other or are completely incoherent.";
223
224 let verdict = self.judge.judge(&conversation_text, &conversation_text, criterion).await?;
225
226 Ok(verdict.score)
227 }
228
229 async fn score_topic_drift(&self, conversation: &[Content]) -> Result<f64> {
234 #[cfg(feature = "embedding")]
235 if let Some(embedding) = &self.embedding_scorer {
236 return self.score_topic_drift_embedding(conversation, embedding).await;
237 }
238
239 self.score_topic_drift_judge(conversation).await
241 }
242
243 #[cfg(feature = "embedding")]
245 async fn score_topic_drift_embedding(
246 &self,
247 conversation: &[Content],
248 embedding: &EmbeddingScorer,
249 ) -> Result<f64> {
250 let first_text = extract_text_from_content(&conversation[0]);
251 let last_text = extract_text_from_content(conversation.last().unwrap());
252
253 if first_text.is_empty() || last_text.is_empty() {
254 return self.score_topic_drift_judge(conversation).await;
256 }
257
258 embedding.score(&first_text, &last_text).await
260 }
261
262 async fn score_topic_drift_judge(&self, conversation: &[Content]) -> Result<f64> {
264 let conversation_text = format_conversation(conversation);
265
266 let criterion = "Topic Drift: Evaluate how well the conversation stays on its \
267 original topic. A score of 1.0 means the conversation remains perfectly \
268 on-topic throughout. A score of 0.0 means the conversation has completely \
269 diverged from its original topic.";
270
271 let verdict = self.judge.judge(&conversation_text, &conversation_text, criterion).await?;
272
273 Ok(verdict.score)
274 }
275}
276
277fn extract_text_from_content(content: &Content) -> String {
279 content.parts.iter().filter_map(|part| part.text()).collect::<Vec<_>>().join(" ")
280}
281
282fn format_conversation(conversation: &[Content]) -> String {
284 let mut output = String::new();
285 for (i, content) in conversation.iter().enumerate() {
286 let text = extract_text_from_content(content);
287 if !text.is_empty() {
288 output.push_str(&format!("Turn {} [{}]: {}\n", i + 1, content.role, text));
289 }
290 }
291 output
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_conversation_scorer_config_default() {
        let config = ConversationScorerConfig::default();
        assert_eq!(config.context_retention_threshold, 0.7);
        assert_eq!(config.goal_completion_threshold, 0.7);
        assert_eq!(config.coherence_threshold, 0.7);
        assert_eq!(config.topic_drift_threshold, 0.7);
    }

    #[test]
    fn test_conversation_scorer_config_serialization() {
        let config = ConversationScorerConfig {
            context_retention_threshold: 0.8,
            goal_completion_threshold: 0.6,
            coherence_threshold: 0.75,
            topic_drift_threshold: 0.9,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ConversationScorerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.context_retention_threshold, 0.8);
        assert_eq!(deserialized.goal_completion_threshold, 0.6);
        assert_eq!(deserialized.coherence_threshold, 0.75);
        assert_eq!(deserialized.topic_drift_threshold, 0.9);
    }

    #[test]
    fn test_conversation_metrics_serialization() {
        let metrics = ConversationMetrics {
            context_retention: 0.85,
            goal_completion: 0.9,
            coherence: 0.75,
            topic_drift: 0.8,
        };
        let json = serde_json::to_string(&metrics).unwrap();
        let deserialized: ConversationMetrics = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.context_retention, 0.85);
        assert_eq!(deserialized.goal_completion, 0.9);
        assert_eq!(deserialized.coherence, 0.75);
        assert_eq!(deserialized.topic_drift, 0.8);
    }

    #[test]
    fn test_extract_text_from_content() {
        let content = Content::new("user").with_text("Hello world");
        let text = extract_text_from_content(&content);
        assert_eq!(text, "Hello world");
    }

    #[test]
    fn test_extract_text_from_content_multiple_parts() {
        let content = Content::new("model").with_text("Part one").with_text("Part two");
        let text = extract_text_from_content(&content);
        assert_eq!(text, "Part one Part two");
    }

    #[test]
    fn test_extract_text_from_empty_content() {
        let content = Content::new("user");
        let text = extract_text_from_content(&content);
        assert_eq!(text, "");
    }

    #[test]
    fn test_format_conversation() {
        let conversation = vec![
            Content::new("user").with_text("Hi there"),
            Content::new("model").with_text("Hello! How can I help?"),
            Content::new("user").with_text("Tell me about Rust"),
        ];
        let formatted = format_conversation(&conversation);
        assert!(formatted.contains("Turn 1 [user]: Hi there"));
        assert!(formatted.contains("Turn 2 [model]: Hello! How can I help?"));
        assert!(formatted.contains("Turn 3 [user]: Tell me about Rust"));
    }

    #[test]
    fn test_format_conversation_exact_layout() {
        // Pins the full transcript layout: one line per turn, each terminated by '\n'.
        let conversation =
            vec![Content::new("user").with_text("Hi"), Content::new("model").with_text("Hello")];
        assert_eq!(
            format_conversation(&conversation),
            "Turn 1 [user]: Hi\nTurn 2 [model]: Hello\n"
        );
    }

    #[test]
    fn test_format_conversation_empty_slice() {
        assert_eq!(format_conversation(&[]), "");
    }

    #[test]
    fn test_format_conversation_skips_empty_text() {
        let conversation = vec![
            Content::new("user").with_text("Hello"),
            Content::new("model"),
            Content::new("user").with_text("World"),
        ];
        let formatted = format_conversation(&conversation);
        assert!(formatted.contains("Turn 1 [user]: Hello"));
        assert!(!formatted.contains("Turn 2"));
        assert!(formatted.contains("Turn 3 [user]: World"));
    }

    #[test]
    fn test_conversation_scorer_new() {
        let model = Arc::new(adk_model::MockLlm::new("test-model"));
        let judge = StructuredJudge::new(model);
        let scorer = ConversationScorer::new(judge);
        assert_eq!(scorer.config().context_retention_threshold, 0.7);
    }

    #[test]
    fn test_conversation_scorer_with_config() {
        let model = Arc::new(adk_model::MockLlm::new("test-model"));
        let judge = StructuredJudge::new(model);
        let config = ConversationScorerConfig {
            context_retention_threshold: 0.5,
            ..ConversationScorerConfig::default()
        };

        // `with_config` has a different signature depending on the `embedding` feature.
        #[cfg(feature = "embedding")]
        let scorer = ConversationScorer::with_config(judge, None, config);
        #[cfg(not(feature = "embedding"))]
        let scorer = ConversationScorer::with_config(judge, config);

        assert_eq!(scorer.config().context_retention_threshold, 0.5);
        assert_eq!(scorer.config().coherence_threshold, 0.7);
    }

    #[tokio::test]
    async fn test_conversation_scorer_empty_conversation_error() {
        let model = Arc::new(adk_model::MockLlm::new("test-model"));
        let judge = StructuredJudge::new(model);
        let scorer = ConversationScorer::new(judge);

        let result = scorer.score(&[], "some goal").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("empty conversation"));
    }
}
407}