crush_core/plugin/
selector.rs1use crate::error::{PluginError, Result, ValidationError};
8use crate::plugin::{list_plugins, PluginMetadata};
9
10#[derive(Debug, Clone, Copy)]
15pub struct ScoringWeights {
16 pub throughput: f64,
18
19 pub compression_ratio: f64,
21}
22
23impl ScoringWeights {
24 pub fn new(throughput: f64, compression_ratio: f64) -> Result<Self> {
44 if throughput < 0.0 || compression_ratio < 0.0 {
45 return Err(
46 ValidationError::InvalidWeights("Weights cannot be negative".to_string()).into(),
47 );
48 }
49
50 let sum = throughput + compression_ratio;
51 if (sum - 1.0).abs() > 1e-6 {
52 return Err(ValidationError::InvalidWeights(format!(
53 "Weights must sum to 1.0, got {sum}"
54 ))
55 .into());
56 }
57
58 Ok(Self {
59 throughput,
60 compression_ratio,
61 })
62 }
63}
64
65impl Default for ScoringWeights {
66 fn default() -> Self {
68 Self {
69 throughput: 0.7,
70 compression_ratio: 0.3,
71 }
72 }
73}
74
75pub fn calculate_plugin_score(
113 plugin: &PluginMetadata,
114 all_plugins: &[PluginMetadata],
115 weights: &ScoringWeights,
116) -> f64 {
117 if all_plugins.is_empty() {
118 return 0.0;
119 }
120
121 if all_plugins.len() == 1 {
123 return 1.0;
124 }
125
126 let log_throughputs: Vec<f64> = all_plugins.iter().map(|p| p.throughput.ln()).collect();
128
129 let plugin_log_throughput = plugin.throughput.ln();
130
131 let min_log_throughput = log_throughputs
133 .iter()
134 .copied()
135 .fold(f64::INFINITY, f64::min);
136 let max_log_throughput = log_throughputs
137 .iter()
138 .copied()
139 .fold(f64::NEG_INFINITY, f64::max);
140
141 let compression_ratios: Vec<f64> = all_plugins.iter().map(|p| p.compression_ratio).collect();
142
143 let min_ratio = compression_ratios
144 .iter()
145 .copied()
146 .fold(f64::INFINITY, f64::min);
147 let max_ratio = compression_ratios
148 .iter()
149 .copied()
150 .fold(f64::NEG_INFINITY, f64::max);
151
152 let norm_throughput = if (max_log_throughput - min_log_throughput).abs() < 1e-9 {
154 1.0 } else {
156 (plugin_log_throughput - min_log_throughput) / (max_log_throughput - min_log_throughput)
157 };
158
159 let norm_ratio = if (max_ratio - min_ratio).abs() < 1e-9 {
161 1.0 } else {
163 (max_ratio - plugin.compression_ratio) / (max_ratio - min_ratio)
164 };
165
166 weights.throughput * norm_throughput + weights.compression_ratio * norm_ratio
168}
169
170pub struct PluginSelector {
172 weights: ScoringWeights,
173}
174
175impl PluginSelector {
176 #[must_use]
178 pub fn new(weights: ScoringWeights) -> Self {
179 Self { weights }
180 }
181
182 pub fn select(&self) -> Result<PluginMetadata> {
191 let plugins = list_plugins();
192
193 if plugins.is_empty() {
194 return Err(PluginError::NotFound(
195 "No plugins available. Call init_plugins() first.".to_string(),
196 )
197 .into());
198 }
199
200 let mut scored_plugins: Vec<(f64, &PluginMetadata)> = plugins
202 .iter()
203 .map(|plugin| {
204 let score = calculate_plugin_score(plugin, &plugins, &self.weights);
205 (score, plugin)
206 })
207 .collect();
208
209 scored_plugins.sort_by(|a, b| {
211 b.0.partial_cmp(&a.0)
212 .unwrap_or(std::cmp::Ordering::Equal)
213 .then_with(|| a.1.name.cmp(b.1.name))
214 });
215
216 Ok(*scored_plugins[0].1)
218 }
219
220 pub fn select_by_name(&self, name: &str) -> Result<PluginMetadata> {
226 let plugins = list_plugins();
227
228 plugins.into_iter().find(|p| p.name == name).ok_or_else(|| {
229 PluginError::NotFound(format!(
230 "Plugin '{}' not found. Available plugins: {}",
231 name,
232 list_plugins()
233 .iter()
234 .map(|p| p.name)
235 .collect::<Vec<_>>()
236 .join(", ")
237 ))
238 .into()
239 })
240 }
241}
242
243impl Default for PluginSelector {
244 fn default() -> Self {
245 Self::new(ScoringWeights::default())
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-6;

    #[test]
    fn test_scoring_weights_validation() {
        // Valid: non-negative and summing to 1.0.
        assert!(ScoringWeights::new(0.7, 0.3).is_ok());
        assert!(ScoringWeights::new(0.5, 0.5).is_ok());
        assert!(ScoringWeights::new(1.0, 0.0).is_ok());

        // Invalid: components do not sum to 1.0.
        assert!(ScoringWeights::new(0.6, 0.6).is_err());
        assert!(ScoringWeights::new(0.3, 0.3).is_err());

        // Invalid: a negative component, even though the sum is 1.0.
        assert!(ScoringWeights::new(-0.1, 1.1).is_err());
    }

    #[test]
    fn test_default_weights() {
        let weights = ScoringWeights::default();
        assert!((weights.throughput - 0.7).abs() < EPS);
        assert!((weights.compression_ratio - 0.3).abs() < EPS);
    }

    #[test]
    fn test_calculate_score_empty_candidates() {
        let plugin = PluginMetadata {
            name: "test",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x00],
            throughput: 500.0,
            compression_ratio: 0.35,
            description: "Test",
        };

        let score = calculate_plugin_score(&plugin, &[], &ScoringWeights::default());
        assert!(score.abs() < EPS);
    }

    #[test]
    fn test_calculate_score_single_plugin() {
        let plugin = PluginMetadata {
            name: "test",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x00],
            throughput: 500.0,
            compression_ratio: 0.35,
            description: "Test",
        };

        let score = calculate_plugin_score(&plugin, &[plugin], &ScoringWeights::default());
        assert!((score - 1.0).abs() < EPS);
    }

    #[test]
    // Unwrapping is fine here: an invalid weight pair is itself a test failure.
    #[allow(clippy::unwrap_used)]
    fn test_calculate_score_multiple_plugins() {
        let fast = PluginMetadata {
            name: "fast",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x10],
            throughput: 1000.0,
            compression_ratio: 0.8,
            description: "Fast but poor compression",
        };

        let slow = PluginMetadata {
            name: "slow",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x11],
            throughput: 100.0,
            compression_ratio: 0.3,
            description: "Slow but good compression",
        };

        let plugins = [fast, slow];

        // `fast` is best on throughput and worst on ratio; `slow` is the reverse,
        // so each score is exactly the weight of the metric that plugin wins.
        let weights = ScoringWeights::default();
        let fast_score = calculate_plugin_score(&fast, &plugins, &weights);
        let slow_score = calculate_plugin_score(&slow, &plugins, &weights);
        assert!((fast_score - 0.7).abs() < EPS);
        assert!((slow_score - 0.3).abs() < EPS);
        assert!(fast_score > slow_score);

        let balanced = ScoringWeights::new(0.5, 0.5).unwrap();
        let fast_balanced = calculate_plugin_score(&fast, &plugins, &balanced);
        let slow_balanced = calculate_plugin_score(&slow, &plugins, &balanced);
        assert!((fast_balanced - 0.5).abs() < EPS);
        assert!((slow_balanced - 0.5).abs() < EPS);
    }
}