Skip to main content

crush_core/plugin/
selector.rs

1//! Plugin selection and scoring logic
2//!
3//! Implements intelligent plugin selection based on performance metadata.
4//! Uses configurable scoring weights (default 70% throughput, 30% compression ratio)
5//! with logarithmic throughput scaling and min-max normalization.
6
7use crate::error::{PluginError, Result, ValidationError};
8use crate::plugin::{list_plugins, PluginMetadata};
9
/// Relative importance of throughput versus compression ratio during plugin selection.
///
/// Both weights are expected to be non-negative and to sum to 1.0. Because the fields
/// are public, that invariant is only checked when the value is built through
/// [`ScoringWeights::new`]; the `Default` implementation yields a 70/30 split.
#[derive(Clone, Copy, Debug)]
pub struct ScoringWeights {
    /// Share of the score driven by throughput (MB/s); 0.7 in the default weights.
    pub throughput: f64,

    /// Share of the score driven by compression ratio; 0.3 in the default weights.
    pub compression_ratio: f64,
}
22
23impl ScoringWeights {
24    /// Create new scoring weights
25    ///
26    /// # Errors
27    ///
28    /// Returns an error if:
29    /// - Either weight is negative
30    /// - Weights don't sum to 1.0 (within epsilon)
31    ///
32    /// # Examples
33    ///
34    /// ```
35    /// use crush_core::ScoringWeights;
36    ///
37    /// // Default 70/30 weighting
38    /// let weights = ScoringWeights::new(0.7, 0.3).expect("Valid weights");
39    ///
40    /// // Balanced weighting
41    /// let balanced = ScoringWeights::new(0.5, 0.5).expect("Valid weights");
42    /// ```
43    pub fn new(throughput: f64, compression_ratio: f64) -> Result<Self> {
44        if throughput < 0.0 || compression_ratio < 0.0 {
45            return Err(
46                ValidationError::InvalidWeights("Weights cannot be negative".to_string()).into(),
47            );
48        }
49
50        let sum = throughput + compression_ratio;
51        if (sum - 1.0).abs() > 1e-6 {
52            return Err(ValidationError::InvalidWeights(format!(
53                "Weights must sum to 1.0, got {sum}"
54            ))
55            .into());
56        }
57
58        Ok(Self {
59            throughput,
60            compression_ratio,
61        })
62    }
63}
64
65impl Default for ScoringWeights {
66    /// Default weights: 70% throughput, 30% compression ratio
67    fn default() -> Self {
68        Self {
69            throughput: 0.7,
70            compression_ratio: 0.3,
71        }
72    }
73}
74
75/// Calculate plugin score using weighted metadata
76///
77/// Implements the scoring algorithm from research:
78/// 1. Logarithmic throughput scaling (prevents throughput dominance)
79/// 2. Min-max normalization (scales values to `[0,1]`)
80/// 3. Weighted sum based on user preferences
81///
82/// # Arguments
83///
84/// * `plugin` - The plugin to score
85/// * `all_plugins` - All available plugins (for min-max normalization)
86/// * `weights` - Scoring weights (throughput vs compression ratio)
87///
88/// # Returns
89///
90/// Score in range [0.0, 1.0] where higher is better
91///
92/// # Examples
93///
94/// ```
95/// use crush_core::{calculate_plugin_score, PluginMetadata, ScoringWeights};
96///
97/// let plugin = PluginMetadata {
98///     name: "test",
99///     version: "1.0.0",
100///     magic_number: [0x43, 0x52, 0x01, 0x00],
101///     throughput: 500.0,
102///     compression_ratio: 0.35,
103///     description: "Test plugin",
104/// };
105///
106/// let plugins = vec![plugin];
107/// let weights = ScoringWeights::default();
108/// let score = calculate_plugin_score(&plugin, &plugins, &weights);
109///
110/// assert!(score >= 0.0 && score <= 1.0);
111/// ```
112pub fn calculate_plugin_score(
113    plugin: &PluginMetadata,
114    all_plugins: &[PluginMetadata],
115    weights: &ScoringWeights,
116) -> f64 {
117    if all_plugins.is_empty() {
118        return 0.0;
119    }
120
121    // Special case: single plugin always scores 1.0
122    if all_plugins.len() == 1 {
123        return 1.0;
124    }
125
126    // Apply logarithmic scaling to throughput
127    let log_throughputs: Vec<f64> = all_plugins.iter().map(|p| p.throughput.ln()).collect();
128
129    let plugin_log_throughput = plugin.throughput.ln();
130
131    // Find min/max for normalization
132    let min_log_throughput = log_throughputs
133        .iter()
134        .copied()
135        .fold(f64::INFINITY, f64::min);
136    let max_log_throughput = log_throughputs
137        .iter()
138        .copied()
139        .fold(f64::NEG_INFINITY, f64::max);
140
141    let compression_ratios: Vec<f64> = all_plugins.iter().map(|p| p.compression_ratio).collect();
142
143    let min_ratio = compression_ratios
144        .iter()
145        .copied()
146        .fold(f64::INFINITY, f64::min);
147    let max_ratio = compression_ratios
148        .iter()
149        .copied()
150        .fold(f64::NEG_INFINITY, f64::max);
151
152    // Normalize throughput (higher is better)
153    let norm_throughput = if (max_log_throughput - min_log_throughput).abs() < 1e-9 {
154        1.0 // All same throughput
155    } else {
156        (plugin_log_throughput - min_log_throughput) / (max_log_throughput - min_log_throughput)
157    };
158
159    // Normalize compression ratio (lower is better, so invert)
160    let norm_ratio = if (max_ratio - min_ratio).abs() < 1e-9 {
161        1.0 // All same ratio
162    } else {
163        (max_ratio - plugin.compression_ratio) / (max_ratio - min_ratio)
164    };
165
166    // Weighted score
167    weights.throughput * norm_throughput + weights.compression_ratio * norm_ratio
168}
169
170/// Plugin selector with scoring logic
171pub struct PluginSelector {
172    weights: ScoringWeights,
173}
174
175impl PluginSelector {
176    /// Create a new plugin selector with custom weights
177    #[must_use]
178    pub fn new(weights: ScoringWeights) -> Self {
179        Self { weights }
180    }
181
182    /// Select the best plugin based on scoring
183    ///
184    /// Returns the plugin with the highest score. In case of ties,
185    /// selects alphabetically by name.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if no plugins are available.
190    pub fn select(&self) -> Result<PluginMetadata> {
191        let plugins = list_plugins();
192
193        if plugins.is_empty() {
194            return Err(PluginError::NotFound(
195                "No plugins available. Call init_plugins() first.".to_string(),
196            )
197            .into());
198        }
199
200        // Calculate scores for all plugins
201        let mut scored_plugins: Vec<(f64, &PluginMetadata)> = plugins
202            .iter()
203            .map(|plugin| {
204                let score = calculate_plugin_score(plugin, &plugins, &self.weights);
205                (score, plugin)
206            })
207            .collect();
208
209        // Sort by score (descending), then by name (ascending) for ties
210        scored_plugins.sort_by(|a, b| {
211            b.0.partial_cmp(&a.0)
212                .unwrap_or(std::cmp::Ordering::Equal)
213                .then_with(|| a.1.name.cmp(b.1.name))
214        });
215
216        // Return the highest-scoring plugin
217        Ok(*scored_plugins[0].1)
218    }
219
220    /// Select a plugin by name (manual override)
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if the specified plugin is not found.
225    pub fn select_by_name(&self, name: &str) -> Result<PluginMetadata> {
226        let plugins = list_plugins();
227
228        plugins.into_iter().find(|p| p.name == name).ok_or_else(|| {
229            PluginError::NotFound(format!(
230                "Plugin '{}' not found. Available plugins: {}",
231                name,
232                list_plugins()
233                    .iter()
234                    .map(|p| p.name)
235                    .collect::<Vec<_>>()
236                    .join(", ")
237            ))
238            .into()
239        })
240    }
241}
242
243impl Default for PluginSelector {
244    fn default() -> Self {
245        Self::new(ScoringWeights::default())
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOLERANCE: f64 = 1e-6;

    #[test]
    fn test_scoring_weights_validation() {
        // Non-negative pairs that sum to 1.0 are accepted.
        for (throughput, ratio) in [(0.7, 0.3), (0.5, 0.5), (1.0, 0.0)] {
            assert!(ScoringWeights::new(throughput, ratio).is_ok());
        }

        // Pairs that miss 1.0, or that contain a negative weight, are rejected.
        for (throughput, ratio) in [(0.6, 0.6), (0.3, 0.3), (-0.1, 1.1)] {
            assert!(ScoringWeights::new(throughput, ratio).is_err());
        }
    }

    #[test]
    fn test_default_weights() {
        let ScoringWeights {
            throughput,
            compression_ratio,
        } = ScoringWeights::default();
        assert!((throughput - 0.7).abs() < TOLERANCE);
        assert!((compression_ratio - 0.3).abs() < TOLERANCE);
    }

    #[test]
    fn test_calculate_score_single_plugin() {
        let only = PluginMetadata {
            name: "test",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x00],
            throughput: 500.0,
            compression_ratio: 0.35,
            description: "Test",
        };

        let score = calculate_plugin_score(&only, &[only], &ScoringWeights::default());

        // A lone candidate is always the best available choice.
        assert!((score - 1.0).abs() < TOLERANCE);
    }

    #[test]
    // Weights are literal constants known to be valid, so unwrapping cannot fail.
    #[allow(clippy::unwrap_used)]
    fn test_calculate_score_multiple_plugins() {
        let fast = PluginMetadata {
            name: "fast",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x10],
            throughput: 1000.0,
            compression_ratio: 0.8,
            description: "Fast but poor compression",
        };
        let slow = PluginMetadata {
            name: "slow",
            version: "1.0.0",
            magic_number: [0x43, 0x52, 0x01, 0x11],
            throughput: 100.0,
            compression_ratio: 0.3,
            description: "Slow but good compression",
        };
        let pool = vec![fast, slow];

        let throughput_heavy = ScoringWeights::default();
        let fast_default = calculate_plugin_score(&fast, &pool, &throughput_heavy);
        let slow_default = calculate_plugin_score(&slow, &pool, &throughput_heavy);

        // A 70% throughput weight favours the fast plugin.
        assert!(fast_default > slow_default);

        let even = ScoringWeights::new(0.5, 0.5).unwrap();
        let fast_even = calculate_plugin_score(&fast, &pool, &even);
        let slow_even = calculate_plugin_score(&slow, &pool, &even);

        // Re-weighting must move both scores.
        assert!((fast_even - fast_default).abs() > TOLERANCE);
        assert!((slow_even - slow_default).abs() > TOLERANCE);
    }
}