use crate::adaptive::{quantize_adaptive, AccuracyTarget};
use crate::encode::quantize_embedding;
use crate::error::QuantError;
use crate::scheme::QuantScheme;
use crate::vector::QuantizedVector;
use lnmp_embedding::Vector;
use std::time::{Duration, Instant};

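/// Summary statistics for a single batch quantization run.
///
/// `total` counts the input embeddings, `succeeded` and `failed` count
/// per-vector outcomes, and `total_time` is the wall-clock time spent on
/// the whole batch.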
#[derive(Debug, Clone, Copy, Default)]
pub struct BatchStats {
    pub total: usize,
    pub succeeded: usize,
    pub failed: usize,
    pub total_time: Duration,
}

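/// The outcome of a batch quantization call: one `Result` per input
/// embedding, in input order, plus aggregate statistics.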
#[derive(Debug)]
pub struct BatchResult {
    pub results: Vec<Result<QuantizedVector, QuantError>>,
    pub stats: BatchStats,
}

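/// Quantizes every embedding in `embeddings` with the given scheme.
///
/// A failure on one vector is recorded in its slot rather than aborting the
/// batch, so `results` always has the same length and order as the input.
///
/// Illustrative usage sketch (imports and crate paths omitted; the tests
/// below are the runnable reference):
///
/// ```ignore
/// let vecs = vec![Vector::from_f32(vec![0.1, 0.2])];
/// let batch = quantize_batch(&vecs, QuantScheme::QInt8);
/// assert_eq!(batch.stats.total, 1);
/// assert_eq!(batch.stats.failed, 0);
/// ```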
pub fn quantize_batch(embeddings: &[Vector], scheme: QuantScheme) -> BatchResult {
    let start_time = Instant::now();
    let mut results = Vec::with_capacity(embeddings.len());
    let mut succeeded = 0;
    let mut failed = 0;

    for emb in embeddings {
        let result = quantize_embedding(emb, scheme);
        if result.is_ok() {
            succeeded += 1;
        } else {
            failed += 1;
        }
        results.push(result);
    }

    BatchResult {
        results,
        stats: BatchStats {
            total: embeddings.len(),
            succeeded,
            failed,
            total_time: start_time.elapsed(),
        },
    }
}

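/// Quantizes every embedding in `embeddings` using `quantize_adaptive`, which
/// picks a quantization scheme for the requested `AccuracyTarget`.
///
/// As with [`quantize_batch`], per-vector failures do not abort the batch.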
pub fn quantize_batch_adaptive(embeddings: &[Vector], target: AccuracyTarget) -> BatchResult {
    let start_time = Instant::now();
    let mut results = Vec::with_capacity(embeddings.len());
    let mut succeeded = 0;
    let mut failed = 0;

    for emb in embeddings {
        let result = quantize_adaptive(emb, target);
        if result.is_ok() {
            succeeded += 1;
        } else {
            failed += 1;
        }
        results.push(result);
    }

    BatchResult {
        results,
        stats: BatchStats {
            total: embeddings.len(),
            succeeded,
            failed,
            total_time: start_time.elapsed(),
        },
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use lnmp_embedding::Vector;

    #[test]
    fn test_batch_quantization() {
        let vecs = vec![
            Vector::from_f32(vec![0.1, 0.2]),
            Vector::from_f32(vec![0.3, 0.4]),
            Vector::from_f32(vec![0.5, 0.6]),
        ];

        let result = quantize_batch(&vecs, QuantScheme::QInt8);

        assert_eq!(result.stats.total, 3);
        assert_eq!(result.stats.succeeded, 3);
        assert_eq!(result.stats.failed, 0);
        assert_eq!(result.results.len(), 3);

        for res in result.results {
            assert!(res.is_ok());
            assert_eq!(res.unwrap().scheme, QuantScheme::QInt8);
        }
    }

    #[test]
    fn test_batch_adaptive() {
        let vecs = vec![
            Vector::from_f32(vec![0.1, 0.2]),
            Vector::from_f32(vec![0.3, 0.4]),
        ];

        let result = quantize_batch_adaptive(&vecs, AccuracyTarget::Compact);

        assert_eq!(result.stats.total, 2);
        assert_eq!(result.stats.succeeded, 2);
        assert_eq!(result.results.len(), 2);

        for res in result.results {
            assert!(res.is_ok());
            assert_eq!(res.unwrap().scheme, QuantScheme::Binary);
        }
    }

    #[test]
    fn test_batch_with_errors() {
        // The empty embedding in the middle should fail to quantize without
        // aborting the rest of the batch.
        let vecs = vec![
            Vector::from_f32(vec![0.1]),
            Vector::from_f32(vec![]),
            Vector::from_f32(vec![0.2]),
        ];

        let result = quantize_batch(&vecs, QuantScheme::QInt8);

        assert_eq!(result.stats.total, 3);
        assert_eq!(result.stats.succeeded, 2);
        assert_eq!(result.stats.failed, 1);

        assert!(result.results[0].is_ok());
        assert!(result.results[1].is_err());
        assert!(result.results[2].is_ok());
    }
}