1use crate::rerank::RankedResult;
7use distx_core::PointId;
8use serde::Serialize;
9use serde_json::Value;
10use std::collections::HashMap;
11
/// One entry of a similarity response: the matched point's id, its final score, an
/// optional payload and the per-field score breakdown.
#[derive(Debug, Clone, Serialize)]
pub struct ExplainedResult {
    /// Identifier of the matched point, serialized via `PointIdSer`.
    pub id: PointIdSer,
    /// Final score of the result (copied from `RankedResult::score` by `from_ranked`).
    pub score: f32,
    /// Point payload; the `payload` key is omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<Value>,
    /// Per-field scores keyed by field name (taken from `RankedResult::field_scores`).
    pub explain: HashMap<String, f32>,
}
25
26#[derive(Debug, Clone, Serialize)]
28#[serde(untagged)]
29pub enum PointIdSer {
30 String(String),
31 Integer(u64),
32}
33
34impl From<&PointId> for PointIdSer {
35 fn from(id: &PointId) -> Self {
36 match id {
37 PointId::String(s) => PointIdSer::String(s.clone()),
38 PointId::Uuid(u) => PointIdSer::String(u.to_string()),
39 PointId::Integer(i) => PointIdSer::Integer(*i),
40 }
41 }
42}
43
44impl ExplainedResult {
45 pub fn from_ranked(ranked: RankedResult, include_payload: bool) -> Self {
47 Self {
48 id: PointIdSer::from(&ranked.point.id),
49 score: ranked.score,
50 payload: if include_payload { ranked.point.payload.clone() } else { None },
51 explain: ranked.field_scores,
52 }
53 }
54
55 pub fn from_ranked_list(ranked_list: Vec<RankedResult>, include_payload: bool) -> Vec<Self> {
57 ranked_list
58 .into_iter()
59 .map(|r| Self::from_ranked(r, include_payload))
60 .collect()
61 }
62}
63
/// Response body for a similarity query; serializes as an object with a single `result`
/// array.
#[derive(Debug, Clone, Serialize)]
pub struct SimilarResponse {
    /// Explained results, in the order they were supplied to the constructor.
    pub result: Vec<ExplainedResult>,
}
70
71impl SimilarResponse {
72 pub fn new(results: Vec<ExplainedResult>) -> Self {
74 Self { result: results }
75 }
76
77 pub fn from_ranked(ranked_list: Vec<RankedResult>, include_payload: bool) -> Self {
79 Self {
80 result: ExplainedResult::from_ranked_list(ranked_list, include_payload),
81 }
82 }
83}
84
/// Summary statistics over a ranked result set, produced by `SimilarityStats::compute`.
#[derive(Debug, Clone, Serialize)]
pub struct SimilarityStats {
    /// Number of candidates considered, exactly as passed in by the caller.
    pub candidates_count: usize,
    /// Number of results that were summarised.
    pub results_count: usize,
    /// Arithmetic mean of the result scores (`0.0` when there are no results).
    pub avg_score: f32,
    /// Score of the top-ranked result (`0.0` when there are no results).
    pub best_score: f32,
    /// Name of the highest entry in the top-ranked result's `field_scores`, if any.
    pub top_contributing_field: Option<String>,
}
99
100impl SimilarityStats {
101 pub fn compute(results: &[RankedResult], candidates_count: usize) -> Self {
103 if results.is_empty() {
104 return Self {
105 candidates_count,
106 results_count: 0,
107 avg_score: 0.0,
108 best_score: 0.0,
109 top_contributing_field: None,
110 };
111 }
112
113 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
114 let avg_score = scores.iter().sum::<f32>() / scores.len() as f32;
115 let best_score = scores[0]; let top_contributing_field = results[0]
119 .field_scores
120 .iter()
121 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
122 .map(|(name, _)| name.clone());
123
124 Self {
125 candidates_count,
126 results_count: results.len(),
127 avg_score,
128 best_score,
129 top_contributing_field,
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use distx_core::{Point, Vector};
    use serde_json::json;

    /// Builds a ranked result with a string id, the given score, a fixed payload and three
    /// field scores in which `name` is the largest.
    fn create_test_ranked_result(id: &str, score: f32) -> RankedResult {
        let mut field_scores = HashMap::new();
        field_scores.insert("name".to_string(), 0.4);
        field_scores.insert("price".to_string(), 0.25);
        field_scores.insert("category".to_string(), 0.15);

        RankedResult {
            point: Point::new(
                PointId::String(id.to_string()),
                Vector::new(vec![0.0; 10]),
                Some(json!({"name": "Test", "price": 1.99})),
            ),
            score,
            field_scores,
        }
    }

    #[test]
    fn test_explained_result_creation() {
        let ranked = create_test_ranked_result("1", 0.85);
        let explained = ExplainedResult::from_ranked(ranked, true);

        assert!(matches!(explained.id, PointIdSer::String(ref s) if s == "1"));
        assert_eq!(explained.score, 0.85);
        assert_eq!(
            explained.payload,
            Some(json!({"name": "Test", "price": 1.99}))
        );
        assert_eq!(explained.explain.len(), 3);
        assert_eq!(explained.explain.get("name"), Some(&0.4));
    }

    #[test]
    fn test_explained_result_without_payload() {
        let ranked = create_test_ranked_result("1", 0.85);
        let explained = ExplainedResult::from_ranked(ranked, false);

        assert!(explained.payload.is_none());
        // `skip_serializing_if` must drop the key entirely rather than emit `null`.
        let json = serde_json::to_string(&explained).unwrap();
        assert!(!json.contains("\"payload\""));
    }

    #[test]
    fn test_explained_result_integer_id_serializes_as_number() {
        let ranked = RankedResult {
            point: Point::new(PointId::Integer(7), Vector::new(vec![0.0; 10]), None),
            score: 0.5,
            field_scores: HashMap::new(),
        };

        let value = serde_json::to_value(ExplainedResult::from_ranked(ranked, true)).unwrap();

        assert_eq!(value["id"].as_u64(), Some(7));
        assert!(value.get("payload").is_none());
    }

    #[test]
    fn test_similar_response_serialization() {
        let ranked = vec![
            create_test_ranked_result("1", 0.95),
            create_test_ranked_result("2", 0.85),
        ];

        let response = SimilarResponse::from_ranked(ranked, true);
        let json = serde_json::to_string(&response).unwrap();

        assert!(json.contains("\"result\""));
        assert!(json.contains("\"score\""));
        assert!(json.contains("\"explain\""));
    }

    #[test]
    fn test_similarity_stats() {
        let results = vec![
            create_test_ranked_result("1", 0.95),
            create_test_ranked_result("2", 0.85),
            create_test_ranked_result("3", 0.75),
        ];

        let stats = SimilarityStats::compute(&results, 10);

        assert_eq!(stats.candidates_count, 10);
        assert_eq!(stats.results_count, 3);
        assert_eq!(stats.best_score, 0.95);
        assert!((stats.avg_score - 0.85).abs() < 0.01);
        assert_eq!(stats.top_contributing_field, Some("name".to_string()));
    }

    #[test]
    fn test_single_result_stats() {
        let results = vec![create_test_ranked_result("1", 0.5)];

        let stats = SimilarityStats::compute(&results, 1);

        assert_eq!(stats.candidates_count, 1);
        assert_eq!(stats.results_count, 1);
        assert_eq!(stats.best_score, 0.5);
        assert_eq!(stats.avg_score, 0.5);
        assert_eq!(stats.top_contributing_field.as_deref(), Some("name"));
    }

    #[test]
    fn test_empty_stats() {
        let stats = SimilarityStats::compute(&[], 5);

        assert_eq!(stats.candidates_count, 5);
        assert_eq!(stats.results_count, 0);
        assert_eq!(stats.best_score, 0.0);
        assert_eq!(stats.avg_score, 0.0);
        assert!(stats.top_contributing_field.is_none());
    }

    #[test]
    fn test_point_id_ser_variants() {
        let string_id = PointIdSer::from(&PointId::String("test".to_string()));
        let int_id = PointIdSer::from(&PointId::Integer(42));

        let string_json = serde_json::to_string(&string_id).unwrap();
        let int_json = serde_json::to_string(&int_id).unwrap();

        assert_eq!(string_json, "\"test\"");
        assert_eq!(int_json, "42");
    }
}