Skip to main content

adk_eval/
conversation_scorer.rs

1//! Multi-turn conversation metrics.
2//!
3//! The [`ConversationScorer`] evaluates multi-turn conversations using a
4//! [`StructuredJudge`] for semantic metrics (context retention, goal completion,
5//! coherence) and optionally an [`EmbeddingScorer`](crate::EmbeddingScorer) for
6//! topic drift measurement.
7//!
8//! # Example
9//!
10//! ```rust,ignore
11//! use adk_eval::{ConversationScorer, ConversationScorerConfig};
12//! use adk_eval::structured_judge::StructuredJudge;
13//! use adk_core::Content;
14//! use std::sync::Arc;
15//!
16//! let judge = StructuredJudge::new(model);
17//! let scorer = ConversationScorer::new(judge);
18//!
19//! let conversation = vec![
20//!     Content::new("user").with_text("Hello, help me plan a trip to Paris"),
21//!     Content::new("model").with_text("I'd be happy to help with your Paris trip!"),
22//!     Content::new("user").with_text("What about hotels near the Eiffel Tower?"),
23//!     Content::new("model").with_text("Here are some great hotels near the Eiffel Tower..."),
24//! ];
25//!
26//! let metrics = scorer.score(&conversation, "Plan a trip to Paris").await?;
27//! assert!((0.0..=1.0).contains(&metrics.context_retention));
28//! ```
29
30use crate::error::{EvalError, Result};
31use crate::structured_judge::StructuredJudge;
32use adk_core::Content;
33use serde::{Deserialize, Serialize};
34
35#[cfg(feature = "embedding")]
36use std::sync::Arc;
37
38#[cfg(feature = "embedding")]
39use crate::embedding_scorer::EmbeddingScorer;
40
41/// Multi-turn conversation quality metrics.
42///
43/// All scores are in the range \[0.0, 1.0\].
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ConversationMetrics {
46    /// Score measuring whether the agent correctly references information from
47    /// prior turns (0.0–1.0).
48    pub context_retention: f64,
49    /// Score measuring whether the agent achieves the stated objective across
50    /// the conversation (0.0–1.0).
51    pub goal_completion: f64,
52    /// Score measuring logical consistency between consecutive agent responses
53    /// (0.0–1.0).
54    pub coherence: f64,
55    /// Score measuring deviation from original topic (0.0–1.0, where 1.0
56    /// indicates no drift).
57    pub topic_drift: f64,
58}
59
60/// Configuration for conversation scoring thresholds.
61///
62/// Each threshold defines the minimum acceptable score for a metric.
63/// Scores below the threshold indicate a failure for that metric.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ConversationScorerConfig {
66    /// Minimum acceptable context retention score.
67    pub context_retention_threshold: f64,
68    /// Minimum acceptable goal completion score.
69    pub goal_completion_threshold: f64,
70    /// Minimum acceptable coherence score.
71    pub coherence_threshold: f64,
72    /// Minimum acceptable topic drift score (1.0 = no drift).
73    pub topic_drift_threshold: f64,
74}
75
76impl Default for ConversationScorerConfig {
77    fn default() -> Self {
78        Self {
79            context_retention_threshold: 0.7,
80            goal_completion_threshold: 0.7,
81            coherence_threshold: 0.7,
82            topic_drift_threshold: 0.7,
83        }
84    }
85}
86
87/// Scores multi-turn conversations on quality metrics.
88///
89/// Uses a [`StructuredJudge`] for context retention, goal completion, and
90/// coherence metrics. For topic drift, uses an [`EmbeddingScorer`] if available,
91/// otherwise falls back to the structured judge.
92pub struct ConversationScorer {
93    judge: StructuredJudge,
94    #[cfg(feature = "embedding")]
95    embedding_scorer: Option<Arc<EmbeddingScorer>>,
96    config: ConversationScorerConfig,
97}
98
99impl ConversationScorer {
100    /// Create a new conversation scorer with default configuration.
101    ///
102    /// Uses the structured judge for all metrics including topic drift.
103    pub fn new(judge: StructuredJudge) -> Self {
104        Self {
105            judge,
106            #[cfg(feature = "embedding")]
107            embedding_scorer: None,
108            config: ConversationScorerConfig::default(),
109        }
110    }
111
112    /// Create a conversation scorer with an embedding scorer for topic drift.
113    ///
114    /// Topic drift will be measured using cosine similarity between the first
115    /// and last turn embeddings instead of the structured judge.
116    #[cfg(feature = "embedding")]
117    pub fn with_embedding(judge: StructuredJudge, embedding: Arc<EmbeddingScorer>) -> Self {
118        Self {
119            judge,
120            embedding_scorer: Some(embedding),
121            config: ConversationScorerConfig::default(),
122        }
123    }
124
125    /// Create a conversation scorer with full configuration.
126    #[cfg(feature = "embedding")]
127    pub fn with_config(
128        judge: StructuredJudge,
129        embedding: Option<Arc<EmbeddingScorer>>,
130        config: ConversationScorerConfig,
131    ) -> Self {
132        Self { judge, embedding_scorer: embedding, config }
133    }
134
135    /// Create a conversation scorer with custom configuration (no embedding).
136    #[cfg(not(feature = "embedding"))]
137    pub fn with_config(judge: StructuredJudge, config: ConversationScorerConfig) -> Self {
138        Self { judge, config }
139    }
140
141    /// Returns the current configuration.
142    pub fn config(&self) -> &ConversationScorerConfig {
143        &self.config
144    }
145
146    /// Score a multi-turn conversation.
147    ///
148    /// Evaluates the conversation on four metrics:
149    /// - **Context retention**: Does the agent reference prior turn information?
150    /// - **Goal completion**: Does the agent achieve the stated objective?
151    /// - **Coherence**: Are consecutive responses logically consistent?
152    /// - **Topic drift**: Does the conversation stay on topic?
153    ///
154    /// All scores are clamped to \[0.0, 1.0\].
155    ///
156    /// # Errors
157    ///
158    /// Returns an error if the judge LLM calls fail or the conversation is empty.
159    pub async fn score(&self, conversation: &[Content], goal: &str) -> Result<ConversationMetrics> {
160        if conversation.is_empty() {
161            return Err(EvalError::ScoringError("cannot score an empty conversation".to_string()));
162        }
163
164        let context_retention = self.score_context_retention(conversation).await?;
165        let goal_completion = self.score_goal_completion(conversation, goal).await?;
166        let coherence = self.score_coherence(conversation).await?;
167        let topic_drift = self.score_topic_drift(conversation).await?;
168
169        Ok(ConversationMetrics {
170            context_retention: context_retention.clamp(0.0, 1.0),
171            goal_completion: goal_completion.clamp(0.0, 1.0),
172            coherence: coherence.clamp(0.0, 1.0),
173            topic_drift: topic_drift.clamp(0.0, 1.0),
174        })
175    }
176
177    /// Score context retention using the structured judge.
178    ///
179    /// Evaluates whether the agent correctly references information from
180    /// prior turns in its responses.
181    async fn score_context_retention(&self, conversation: &[Content]) -> Result<f64> {
182        let conversation_text = format_conversation(conversation);
183
184        let criterion = "Context Retention: Evaluate whether the agent correctly \
185            references and uses information from earlier turns in the conversation. \
186            A score of 1.0 means the agent perfectly retains and uses all prior context. \
187            A score of 0.0 means the agent completely ignores previous context.";
188
189        let verdict = self.judge.judge(&conversation_text, &conversation_text, criterion).await?;
190
191        Ok(verdict.score)
192    }
193
194    /// Score goal completion using the structured judge.
195    ///
196    /// Evaluates whether the agent achieves the stated objective across the
197    /// conversation.
198    async fn score_goal_completion(&self, conversation: &[Content], goal: &str) -> Result<f64> {
199        let conversation_text = format_conversation(conversation);
200
201        let criterion = format!(
202            "Goal Completion: Evaluate whether the agent successfully achieves \
203            the following goal across the conversation: \"{goal}\". \
204            A score of 1.0 means the goal is fully achieved. \
205            A score of 0.0 means no progress toward the goal was made."
206        );
207
208        let verdict = self.judge.judge(goal, &conversation_text, &criterion).await?;
209
210        Ok(verdict.score)
211    }
212
213    /// Score coherence using the structured judge.
214    ///
215    /// Evaluates logical consistency between consecutive agent responses.
216    async fn score_coherence(&self, conversation: &[Content]) -> Result<f64> {
217        let conversation_text = format_conversation(conversation);
218
219        let criterion = "Coherence: Evaluate the logical consistency between consecutive \
220            responses in this conversation. A score of 1.0 means all responses are \
221            perfectly logically consistent with each other. A score of 0.0 means \
222            responses contradict each other or are completely incoherent.";
223
224        let verdict = self.judge.judge(&conversation_text, &conversation_text, criterion).await?;
225
226        Ok(verdict.score)
227    }
228
229    /// Score topic drift.
230    ///
231    /// If an embedding scorer is available, uses cosine similarity between the
232    /// first and last turn text. Otherwise falls back to the structured judge.
233    async fn score_topic_drift(&self, conversation: &[Content]) -> Result<f64> {
234        #[cfg(feature = "embedding")]
235        if let Some(embedding) = &self.embedding_scorer {
236            return self.score_topic_drift_embedding(conversation, embedding).await;
237        }
238
239        // Fallback: use structured judge
240        self.score_topic_drift_judge(conversation).await
241    }
242
243    /// Score topic drift using embedding similarity between first and last turns.
244    #[cfg(feature = "embedding")]
245    async fn score_topic_drift_embedding(
246        &self,
247        conversation: &[Content],
248        embedding: &EmbeddingScorer,
249    ) -> Result<f64> {
250        let first_text = extract_text_from_content(&conversation[0]);
251        let last_text = extract_text_from_content(conversation.last().unwrap());
252
253        if first_text.is_empty() || last_text.is_empty() {
254            // Fall back to judge if we can't extract text
255            return self.score_topic_drift_judge(conversation).await;
256        }
257
258        // Similarity of 1.0 means no topic drift
259        embedding.score(&first_text, &last_text).await
260    }
261
262    /// Score topic drift using the structured judge as fallback.
263    async fn score_topic_drift_judge(&self, conversation: &[Content]) -> Result<f64> {
264        let conversation_text = format_conversation(conversation);
265
266        let criterion = "Topic Drift: Evaluate how well the conversation stays on its \
267            original topic. A score of 1.0 means the conversation remains perfectly \
268            on-topic throughout. A score of 0.0 means the conversation has completely \
269            diverged from its original topic.";
270
271        let verdict = self.judge.judge(&conversation_text, &conversation_text, criterion).await?;
272
273        Ok(verdict.score)
274    }
275}
276
277/// Extract all text parts from a Content item and concatenate them.
278fn extract_text_from_content(content: &Content) -> String {
279    content.parts.iter().filter_map(|part| part.text()).collect::<Vec<_>>().join(" ")
280}
281
282/// Format a conversation into a readable string for the judge.
283fn format_conversation(conversation: &[Content]) -> String {
284    let mut output = String::new();
285    for (i, content) in conversation.iter().enumerate() {
286        let text = extract_text_from_content(content);
287        if !text.is_empty() {
288            output.push_str(&format!("Turn {} [{}]: {}\n", i + 1, content.role, text));
289        }
290    }
291    output
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_conversation_scorer_config_default() {
        let config = ConversationScorerConfig::default();
        assert_eq!(config.context_retention_threshold, 0.7);
        assert_eq!(config.goal_completion_threshold, 0.7);
        assert_eq!(config.coherence_threshold, 0.7);
        assert_eq!(config.topic_drift_threshold, 0.7);
    }

    #[test]
    fn test_conversation_scorer_config_serialization() {
        let config = ConversationScorerConfig {
            context_retention_threshold: 0.8,
            goal_completion_threshold: 0.6,
            coherence_threshold: 0.75,
            topic_drift_threshold: 0.9,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ConversationScorerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.context_retention_threshold, 0.8);
        assert_eq!(deserialized.goal_completion_threshold, 0.6);
        assert_eq!(deserialized.coherence_threshold, 0.75);
        assert_eq!(deserialized.topic_drift_threshold, 0.9);
    }

    #[test]
    fn test_conversation_metrics_serialization() {
        let metrics = ConversationMetrics {
            context_retention: 0.85,
            goal_completion: 0.9,
            coherence: 0.75,
            topic_drift: 0.8,
        };
        let json = serde_json::to_string(&metrics).unwrap();
        let deserialized: ConversationMetrics = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.context_retention, 0.85);
        assert_eq!(deserialized.goal_completion, 0.9);
        assert_eq!(deserialized.coherence, 0.75);
        assert_eq!(deserialized.topic_drift, 0.8);
    }

    #[test]
    fn test_extract_text_from_content() {
        let content = Content::new("user").with_text("Hello world");
        let text = extract_text_from_content(&content);
        assert_eq!(text, "Hello world");
    }

    #[test]
    fn test_extract_text_from_content_multiple_parts() {
        let content = Content::new("model").with_text("Part one").with_text("Part two");
        let text = extract_text_from_content(&content);
        assert_eq!(text, "Part one Part two");
    }

    #[test]
    fn test_extract_text_from_empty_content() {
        let content = Content::new("user");
        let text = extract_text_from_content(&content);
        assert_eq!(text, "");
    }

    #[test]
    fn test_format_conversation() {
        let conversation = vec![
            Content::new("user").with_text("Hi there"),
            Content::new("model").with_text("Hello! How can I help?"),
            Content::new("user").with_text("Tell me about Rust"),
        ];
        let formatted = format_conversation(&conversation);
        assert!(formatted.contains("Turn 1 [user]: Hi there"));
        assert!(formatted.contains("Turn 2 [model]: Hello! How can I help?"));
        assert!(formatted.contains("Turn 3 [user]: Tell me about Rust"));
    }

    #[test]
    fn test_format_conversation_skips_empty_text() {
        let conversation = vec![
            Content::new("user").with_text("Hello"),
            Content::new("model"), // empty content
            Content::new("user").with_text("World"),
        ];
        let formatted = format_conversation(&conversation);
        assert!(formatted.contains("Turn 1 [user]: Hello"));
        assert!(!formatted.contains("Turn 2"));
        assert!(formatted.contains("Turn 3 [user]: World"));
    }

    #[test]
    fn test_format_conversation_empty_slice() {
        // No turns means no transcript lines at all.
        assert_eq!(format_conversation(&[]), "");
    }

    #[test]
    fn test_format_conversation_exact_layout() {
        // Pins the full transcript format, including the trailing newline per turn.
        let conversation = vec![
            Content::new("user").with_text("Hi"),
            Content::new("model").with_text("Hello"),
        ];
        assert_eq!(
            format_conversation(&conversation),
            "Turn 1 [user]: Hi\nTurn 2 [model]: Hello\n"
        );
    }

    #[test]
    fn test_conversation_scorer_new() {
        let model = Arc::new(adk_model::MockLlm::new("test-model"));
        let judge = StructuredJudge::new(model);
        let scorer = ConversationScorer::new(judge);
        assert_eq!(scorer.config().context_retention_threshold, 0.7);
    }

    #[tokio::test]
    async fn test_conversation_scorer_empty_conversation_error() {
        let model = Arc::new(adk_model::MockLlm::new("test-model"));
        let judge = StructuredJudge::new(model);
        let scorer = ConversationScorer::new(judge);

        let result = scorer.score(&[], "some goal").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("empty conversation"));
    }
}