distx_schema/
explain.rs

//! Explainability for structured similarity results
//!
//! Provides output structures that explain how similarity scores were computed,
//! showing per-field contributions for transparency.
5
6use crate::rerank::RankedResult;
7use distx_core::PointId;
8use serde::Serialize;
9use serde_json::Value;
10use std::collections::HashMap;
11
12/// An explained similarity result with per-field score breakdown
13#[derive(Debug, Clone, Serialize)]
14pub struct ExplainedResult {
15    /// Point ID
16    pub id: PointIdSer,
17    /// Overall weighted similarity score
18    pub score: f32,
19    /// Point payload
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub payload: Option<Value>,
22    /// Per-field score contributions (already weighted)
23    pub explain: HashMap<String, f32>,
24}
25
26/// Serializable wrapper for PointId
27#[derive(Debug, Clone, Serialize)]
28#[serde(untagged)]
29pub enum PointIdSer {
30    String(String),
31    Integer(u64),
32}
33
34impl From<&PointId> for PointIdSer {
35    fn from(id: &PointId) -> Self {
36        match id {
37            PointId::String(s) => PointIdSer::String(s.clone()),
38            PointId::Uuid(u) => PointIdSer::String(u.to_string()),
39            PointId::Integer(i) => PointIdSer::Integer(*i),
40        }
41    }
42}
43
44impl ExplainedResult {
45    /// Create an explained result from a ranked result
46    pub fn from_ranked(ranked: RankedResult, include_payload: bool) -> Self {
47        Self {
48            id: PointIdSer::from(&ranked.point.id),
49            score: ranked.score,
50            payload: if include_payload { ranked.point.payload.clone() } else { None },
51            explain: ranked.field_scores,
52        }
53    }
54
55    /// Create a list of explained results from ranked results
56    pub fn from_ranked_list(ranked_list: Vec<RankedResult>, include_payload: bool) -> Vec<Self> {
57        ranked_list
58            .into_iter()
59            .map(|r| Self::from_ranked(r, include_payload))
60            .collect()
61    }
62}
63
64/// Response structure for the /similar endpoint
65#[derive(Debug, Clone, Serialize)]
66pub struct SimilarResponse {
67    /// List of similar items with explanations
68    pub result: Vec<ExplainedResult>,
69}
70
71impl SimilarResponse {
72    /// Create a new similar response
73    pub fn new(results: Vec<ExplainedResult>) -> Self {
74        Self { result: results }
75    }
76
77    /// Create from ranked results
78    pub fn from_ranked(ranked_list: Vec<RankedResult>, include_payload: bool) -> Self {
79        Self {
80            result: ExplainedResult::from_ranked_list(ranked_list, include_payload),
81        }
82    }
83}
84
85/// Summary statistics for a similarity query
86#[derive(Debug, Clone, Serialize)]
87pub struct SimilarityStats {
88    /// Number of candidates considered
89    pub candidates_count: usize,
90    /// Number of results returned
91    pub results_count: usize,
92    /// Average score of results
93    pub avg_score: f32,
94    /// Score of best result
95    pub best_score: f32,
96    /// Field that contributed most to best result
97    pub top_contributing_field: Option<String>,
98}
99
100impl SimilarityStats {
101    /// Compute stats from ranked results
102    pub fn compute(results: &[RankedResult], candidates_count: usize) -> Self {
103        if results.is_empty() {
104            return Self {
105                candidates_count,
106                results_count: 0,
107                avg_score: 0.0,
108                best_score: 0.0,
109                top_contributing_field: None,
110            };
111        }
112
113        let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
114        let avg_score = scores.iter().sum::<f32>() / scores.len() as f32;
115        let best_score = scores[0]; // Results are sorted
116
117        // Find top contributing field in best result
118        let top_contributing_field = results[0]
119            .field_scores
120            .iter()
121            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
122            .map(|(name, _)| name.clone());
123
124        Self {
125            candidates_count,
126            results_count: results.len(),
127            avg_score,
128            best_score,
129            top_contributing_field,
130        }
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use distx_core::{Point, Vector};
    use serde_json::json;

    /// Build a ranked result with a fixed payload and three field scores
    /// (`name` = 0.4, `price` = 0.25, `category` = 0.15).
    fn create_test_ranked_result(id: &str, score: f32) -> RankedResult {
        let field_scores: HashMap<String, f32> =
            [("name", 0.4), ("price", 0.25), ("category", 0.15)]
                .iter()
                .map(|&(field, contribution)| (field.to_string(), contribution))
                .collect();

        let point = Point::new(
            PointId::String(id.to_string()),
            Vector::new(vec![0.0; 10]),
            Some(json!({"name": "Test", "price": 1.99})),
        );

        RankedResult {
            point,
            score,
            field_scores,
        }
    }

    #[test]
    fn test_explained_result_creation() {
        let explained = ExplainedResult::from_ranked(create_test_ranked_result("1", 0.85), true);

        let id_is_expected = matches!(&explained.id, PointIdSer::String(s) if s == "1");
        assert!(id_is_expected);
        assert_eq!(explained.score, 0.85);
        assert!(explained.payload.is_some());
        assert_eq!(explained.explain.len(), 3);
    }

    #[test]
    fn test_explained_result_without_payload() {
        let explained = ExplainedResult::from_ranked(create_test_ranked_result("1", 0.85), false);

        assert!(explained.payload.is_none());
    }

    #[test]
    fn test_similar_response_serialization() {
        let ranked: Vec<RankedResult> = [("1", 0.95), ("2", 0.85)]
            .iter()
            .map(|&(id, score)| create_test_ranked_result(id, score))
            .collect();

        let response = SimilarResponse::from_ranked(ranked, true);
        let body = serde_json::to_string(&response).unwrap();

        // Every top-level and per-item key must appear in the output.
        for key in ["\"result\"", "\"score\"", "\"explain\""].iter() {
            assert!(body.contains(*key));
        }
    }

    #[test]
    fn test_similarity_stats() {
        let results: Vec<RankedResult> = [("1", 0.95), ("2", 0.85), ("3", 0.75)]
            .iter()
            .map(|&(id, score)| create_test_ranked_result(id, score))
            .collect();

        let stats = SimilarityStats::compute(&results, 10);

        assert_eq!(stats.candidates_count, 10);
        assert_eq!(stats.results_count, 3);
        assert_eq!(stats.best_score, 0.95);
        assert!((stats.avg_score - 0.85).abs() < 0.01);
        assert_eq!(stats.top_contributing_field, Some("name".to_string()));
    }

    #[test]
    fn test_empty_stats() {
        let stats = SimilarityStats::compute(&[], 5);

        assert_eq!(stats.candidates_count, 5);
        assert_eq!(stats.results_count, 0);
        assert_eq!(stats.best_score, 0.0);
    }

    #[test]
    fn test_point_id_ser_variants() {
        let string_json =
            serde_json::to_string(&PointIdSer::from(&PointId::String("test".to_string())))
                .unwrap();
        let int_json = serde_json::to_string(&PointIdSer::from(&PointId::Integer(42))).unwrap();

        assert_eq!(string_json, "\"test\"");
        assert_eq!(int_json, "42");
    }
}