Skip to main content

sklears_simd/
plugin_architecture.rs

1//! Plugin architecture for custom SIMD operations
2//!
3//! This module provides a plugin system that allows users to register and use
4//! custom SIMD operations at runtime.
5
6#[cfg(not(feature = "no-std"))]
7use std::collections::HashMap;
8#[cfg(not(feature = "no-std"))]
9use std::fmt;
10#[cfg(not(feature = "no-std"))]
11use std::string::ToString;
12#[cfg(not(feature = "no-std"))]
13use std::sync::{Arc, Mutex, RwLock};
14
15#[cfg(feature = "no-std")]
16use alloc::collections::BTreeMap as HashMap;
17#[cfg(feature = "no-std")]
18use alloc::format;
19#[cfg(feature = "no-std")]
20use alloc::string::{String, ToString};
21#[cfg(feature = "no-std")]
22use alloc::sync::Arc;
23#[cfg(feature = "no-std")]
24use alloc::vec::Vec;
25#[cfg(feature = "no-std")]
26use core::fmt;
27#[cfg(feature = "no-std")]
28use spin::{Mutex, RwLock};
29
30/// Trait for custom SIMD operations
31pub trait SimdOperation: Send + Sync {
32    /// The name of the operation
33    fn name(&self) -> &str;
34
35    /// The version of the operation
36    fn version(&self) -> &str;
37
38    /// Description of what the operation does
39    fn description(&self) -> &str;
40
41    /// Execute the operation on f32 data
42    fn execute_f32(&self, input: &[f32], output: &mut [f32]) -> Result<(), PluginError>;
43
44    /// Execute the operation on f64 data
45    fn execute_f64(&self, input: &[f64], output: &mut [f64]) -> Result<(), PluginError>;
46
47    /// Get the required input size for a given output size
48    fn required_input_size(&self, output_size: usize) -> usize {
49        output_size // Default 1:1 mapping
50    }
51
52    /// Check if the operation supports in-place execution
53    fn supports_inplace(&self) -> bool {
54        false
55    }
56
57    /// Get the SIMD width requirements
58    fn simd_requirements(&self) -> SimdRequirements {
59        SimdRequirements::default()
60    }
61}
62
/// Hardware constraints an operation declares through
/// [`SimdOperation::simd_requirements`].
#[derive(Debug, Clone)]
pub struct SimdRequirements {
    /// Smallest SIMD width the operation can work with (unit not fixed here —
    /// presumably lanes; confirm with callers).
    pub min_width: usize,
    /// Width the operation would ideally run at (same unit as `min_width`).
    pub preferred_width: usize,
    /// Whether input/output buffers must be aligned.
    pub requires_aligned_memory: bool,
    /// Names of CPU features the operation needs.
    pub requires_specific_features: Vec<String>,
}

impl Default for SimdRequirements {
    /// Minimal requirements: minimum width 1, preferred width 4, no alignment
    /// constraint and no required CPU features.
    fn default() -> Self {
        let no_features: Vec<String> = Vec::new();
        SimdRequirements {
            requires_specific_features: no_features,
            requires_aligned_memory: false,
            preferred_width: 4,
            min_width: 1,
        }
    }
}
82
/// Errors reported by plugin operations and by the plugin registry.
#[derive(Debug, Clone)]
pub enum PluginError {
    /// The input data was rejected; the payload explains why.
    InvalidInput(String),
    /// The output buffer was rejected; the payload explains why.
    InvalidOutput(String),
    /// Input and output lengths do not fit together, given as `(input, output)`.
    IncompatibleSizes(usize, usize),
    /// The requested operation is not supported; the payload names it.
    UnsupportedOperation(String),
    /// The operation failed while running; the payload explains why.
    ExecutionFailed(String),
    /// A plugin could not be registered; the payload explains why.
    RegistrationFailed(String),
    /// No plugin is registered under the given name.
    NotFound(String),
}

impl fmt::Display for PluginError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
            Self::InvalidOutput(msg) => write!(f, "Invalid output: {}", msg),
            Self::IncompatibleSizes(input, output) => write!(
                f,
                "Incompatible sizes: input {} vs output {}",
                input, output
            ),
            Self::UnsupportedOperation(op) => write!(f, "Unsupported operation: {}", op),
            Self::ExecutionFailed(msg) => write!(f, "Execution failed: {}", msg),
            Self::RegistrationFailed(msg) => write!(f, "Registration failed: {}", msg),
            Self::NotFound(name) => write!(f, "Plugin not found: {}", name),
        }
    }
}

// `Error` comes from `std` normally and from `core` in no-std builds.
#[cfg(not(feature = "no-std"))]
impl std::error::Error for PluginError {}

#[cfg(feature = "no-std")]
impl core::error::Error for PluginError {}
122
123/// Plugin metadata
124#[derive(Debug, Clone)]
125pub struct PluginMetadata {
126    pub name: String,
127    pub version: String,
128    pub description: String,
129    pub author: String,
130    pub license: String,
131    pub dependencies: Vec<String>,
132    pub simd_requirements: SimdRequirements,
133}
134
135impl Default for PluginMetadata {
136    fn default() -> Self {
137        Self {
138            name: "Unknown".to_string(),
139            version: "0.1.0".to_string(),
140            description: "Custom SIMD operation".to_string(),
141            author: "Unknown".to_string(),
142            license: "MIT".to_string(),
143            dependencies: Vec::new(),
144            simd_requirements: SimdRequirements::default(),
145        }
146    }
147}
148
149/// Plugin wrapper that includes metadata
150pub struct Plugin {
151    pub metadata: PluginMetadata,
152    pub operation: Arc<dyn SimdOperation>,
153}
154
155impl Plugin {
156    pub fn new(operation: Arc<dyn SimdOperation>) -> Self {
157        let metadata = PluginMetadata {
158            name: operation.name().to_string(),
159            version: operation.version().to_string(),
160            description: operation.description().to_string(),
161            ..Default::default()
162        };
163
164        Self {
165            metadata,
166            operation,
167        }
168    }
169
170    pub fn with_metadata(operation: Arc<dyn SimdOperation>, metadata: PluginMetadata) -> Self {
171        Self {
172            metadata,
173            operation,
174        }
175    }
176}
177
178/// Plugin registry for managing custom SIMD operations
179pub struct PluginRegistry {
180    plugins: RwLock<HashMap<String, Arc<Plugin>>>,
181    execution_stats: Mutex<HashMap<String, ExecutionStats>>,
182}
183
/// Counters the registry keeps for each plugin across `execute_*` calls.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Number of executions, counting calls that returned an error.
    pub total_calls: u64,
    /// Sum of the input lengths passed to the executions.
    pub total_elements_processed: u64,
    /// Accumulated time spent inside the operation, measured with `Instant`
    /// (only tracked when `std` is available).
    #[cfg(not(feature = "no-std"))]
    pub total_execution_time: std::time::Duration,
    /// Display text of the most recent failed execution, if any.
    pub last_error: Option<String>,
}
192
193impl Default for PluginRegistry {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199impl PluginRegistry {
200    /// Create a new plugin registry
201    pub fn new() -> Self {
202        Self {
203            plugins: RwLock::new(HashMap::new()),
204            execution_stats: Mutex::new(HashMap::new()),
205        }
206    }
207
208    /// Helper function to handle RwLock read locking in both std and no-std environments
209    #[cfg(not(feature = "no-std"))]
210    fn read_plugins(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, Arc<Plugin>>> {
211        self.plugins.read().expect("operation should succeed")
212    }
213
214    #[cfg(feature = "no-std")]
215    fn read_plugins(&self) -> spin::RwLockReadGuard<'_, HashMap<String, Arc<Plugin>>> {
216        self.plugins.read()
217    }
218
219    /// Helper function to handle RwLock write locking in both std and no-std environments
220    #[cfg(not(feature = "no-std"))]
221    fn write_plugins(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Arc<Plugin>>> {
222        self.plugins.write().expect("operation should succeed")
223    }
224
225    #[cfg(feature = "no-std")]
226    fn write_plugins(&self) -> spin::RwLockWriteGuard<'_, HashMap<String, Arc<Plugin>>> {
227        self.plugins.write()
228    }
229
230    /// Helper function to handle Mutex locking in both std and no-std environments
231    #[cfg(not(feature = "no-std"))]
232    fn lock_stats(&self) -> std::sync::MutexGuard<'_, HashMap<String, ExecutionStats>> {
233        self.execution_stats
234            .lock()
235            .expect("lock should not be poisoned")
236    }
237
238    #[cfg(feature = "no-std")]
239    fn lock_stats(&self) -> spin::MutexGuard<'_, HashMap<String, ExecutionStats>, spin::Spin> {
240        self.execution_stats.lock()
241    }
242
243    /// Register a new plugin
244    pub fn register(&self, plugin: Plugin) -> Result<(), PluginError> {
245        let name = plugin.metadata.name.clone();
246
247        // Validate plugin
248        self.validate_plugin(&plugin)?;
249
250        // Register the plugin
251        let mut plugins = self.write_plugins();
252        plugins.insert(name.clone(), Arc::new(plugin));
253
254        // Initialize stats
255        let mut stats = self.lock_stats();
256        stats.insert(name, ExecutionStats::default());
257
258        Ok(())
259    }
260
261    /// Unregister a plugin
262    pub fn unregister(&self, name: &str) -> Result<(), PluginError> {
263        let mut plugins = self.write_plugins();
264        plugins.remove(name);
265
266        let mut stats = self.lock_stats();
267        stats.remove(name);
268
269        Ok(())
270    }
271
272    /// Get a plugin by name
273    pub fn get(&self, name: &str) -> Result<Arc<Plugin>, PluginError> {
274        let plugins = self.read_plugins();
275        plugins
276            .get(name)
277            .cloned()
278            .ok_or_else(|| PluginError::NotFound(name.to_string()))
279    }
280
281    /// List all registered plugins
282    pub fn list(&self) -> Vec<String> {
283        let plugins = self.read_plugins();
284        plugins.keys().cloned().collect()
285    }
286
287    /// Execute a plugin operation on f32 data
288    pub fn execute_f32(
289        &self,
290        name: &str,
291        input: &[f32],
292        output: &mut [f32],
293    ) -> Result<(), PluginError> {
294        let plugin = self.get(name)?;
295
296        #[cfg(not(feature = "no-std"))]
297        let start_time = std::time::Instant::now();
298        let result = plugin.operation.execute_f32(input, output);
299        #[cfg(not(feature = "no-std"))]
300        let execution_time = start_time.elapsed();
301
302        // Update stats
303        #[cfg(not(feature = "no-std"))]
304        self.update_stats(name, input.len(), execution_time, result.as_ref().err());
305        #[cfg(feature = "no-std")]
306        self.update_stats(name, input.len(), result.as_ref().err());
307
308        result
309    }
310
311    /// Execute a plugin operation on f64 data
312    pub fn execute_f64(
313        &self,
314        name: &str,
315        input: &[f64],
316        output: &mut [f64],
317    ) -> Result<(), PluginError> {
318        let plugin = self.get(name)?;
319
320        #[cfg(not(feature = "no-std"))]
321        let start_time = std::time::Instant::now();
322        let result = plugin.operation.execute_f64(input, output);
323        #[cfg(not(feature = "no-std"))]
324        let execution_time = start_time.elapsed();
325
326        // Update stats
327        #[cfg(not(feature = "no-std"))]
328        self.update_stats(name, input.len(), execution_time, result.as_ref().err());
329        #[cfg(feature = "no-std")]
330        self.update_stats(name, input.len(), result.as_ref().err());
331
332        result
333    }
334
335    /// Get execution statistics for a plugin
336    pub fn get_stats(&self, name: &str) -> Option<ExecutionStats> {
337        let stats = self.lock_stats();
338        stats.get(name).cloned()
339    }
340
341    /// Clear execution statistics
342    pub fn clear_stats(&self) {
343        let mut stats = self.lock_stats();
344        for stat in stats.values_mut() {
345            *stat = ExecutionStats::default();
346        }
347    }
348
349    /// Find plugins by capability
350    pub fn find_by_capability(&self, requires_inplace: bool, min_width: usize) -> Vec<String> {
351        let plugins = self.read_plugins();
352        plugins
353            .iter()
354            .filter(|(_, plugin)| {
355                let op = &plugin.operation;
356                (!requires_inplace || op.supports_inplace())
357                    && op.simd_requirements().min_width <= min_width
358            })
359            .map(|(name, _)| name.clone())
360            .collect()
361    }
362
363    fn validate_plugin(&self, plugin: &Plugin) -> Result<(), PluginError> {
364        let name = &plugin.metadata.name;
365
366        // Check if already registered
367        let plugins = self.read_plugins();
368        if plugins.contains_key(name) {
369            return Err(PluginError::RegistrationFailed(format!(
370                "Plugin '{}' is already registered",
371                name
372            )));
373        }
374
375        // Basic validation
376        if name.is_empty() {
377            return Err(PluginError::RegistrationFailed(
378                "Plugin name cannot be empty".to_string(),
379            ));
380        }
381
382        Ok(())
383    }
384
385    #[cfg(not(feature = "no-std"))]
386    fn update_stats(
387        &self,
388        name: &str,
389        elements: usize,
390        time: std::time::Duration,
391        error: Option<&PluginError>,
392    ) {
393        let mut stats = self.lock_stats();
394        if let Some(stat) = stats.get_mut(name) {
395            stat.total_calls += 1;
396            stat.total_elements_processed += elements as u64;
397            stat.total_execution_time += time;
398            if let Some(err) = error {
399                stat.last_error = Some(err.to_string());
400            }
401        }
402    }
403
404    #[cfg(feature = "no-std")]
405    fn update_stats(&self, name: &str, elements: usize, error: Option<&PluginError>) {
406        let mut stats = self.lock_stats();
407        if let Some(stat) = stats.get_mut(name) {
408            stat.total_calls += 1;
409            stat.total_elements_processed += elements as u64;
410            if let Some(err) = error {
411                stat.last_error = Some(err.to_string());
412            }
413        }
414    }
415}
416
417/// Global plugin registry instance
418pub static GLOBAL_REGISTRY: once_cell::sync::Lazy<PluginRegistry> =
419    once_cell::sync::Lazy::new(PluginRegistry::new);
420
421/// Convenience functions for global registry
422pub mod global {
423    use super::*;
424
425    /// Register a plugin globally
426    pub fn register(plugin: Plugin) -> Result<(), PluginError> {
427        GLOBAL_REGISTRY.register(plugin)
428    }
429
430    /// Execute a plugin operation globally (f32)
431    pub fn execute_f32(name: &str, input: &[f32], output: &mut [f32]) -> Result<(), PluginError> {
432        GLOBAL_REGISTRY.execute_f32(name, input, output)
433    }
434
435    /// Execute a plugin operation globally (f64)
436    pub fn execute_f64(name: &str, input: &[f64], output: &mut [f64]) -> Result<(), PluginError> {
437        GLOBAL_REGISTRY.execute_f64(name, input, output)
438    }
439
440    /// List all globally registered plugins
441    pub fn list() -> Vec<String> {
442        GLOBAL_REGISTRY.list()
443    }
444
445    /// Get stats for a globally registered plugin
446    pub fn get_stats(name: &str) -> Option<ExecutionStats> {
447        GLOBAL_REGISTRY.get_stats(name)
448    }
449}
450
451/// Example plugin implementations
452pub mod examples {
453    use super::*;
454
455    /// Example: Custom square operation
456    pub struct SquareOperation;
457
458    impl SimdOperation for SquareOperation {
459        fn name(&self) -> &str {
460            "square"
461        }
462        fn version(&self) -> &str {
463            "1.0.0"
464        }
465        fn description(&self) -> &str {
466            "Square each element"
467        }
468
469        fn execute_f32(&self, input: &[f32], output: &mut [f32]) -> Result<(), PluginError> {
470            if input.len() != output.len() {
471                return Err(PluginError::IncompatibleSizes(input.len(), output.len()));
472            }
473
474            for (i, &val) in input.iter().enumerate() {
475                output[i] = val * val;
476            }
477            Ok(())
478        }
479
480        fn execute_f64(&self, input: &[f64], output: &mut [f64]) -> Result<(), PluginError> {
481            if input.len() != output.len() {
482                return Err(PluginError::IncompatibleSizes(input.len(), output.len()));
483            }
484
485            for (i, &val) in input.iter().enumerate() {
486                output[i] = val * val;
487            }
488            Ok(())
489        }
490
491        fn supports_inplace(&self) -> bool {
492            true
493        }
494    }
495
496    /// Example: Custom moving average operation
497    pub struct MovingAverageOperation {
498        window_size: usize,
499    }
500
501    impl MovingAverageOperation {
502        pub fn new(window_size: usize) -> Self {
503            Self { window_size }
504        }
505    }
506
507    impl SimdOperation for MovingAverageOperation {
508        fn name(&self) -> &str {
509            "moving_average"
510        }
511        fn version(&self) -> &str {
512            "1.0.0"
513        }
514        fn description(&self) -> &str {
515            "Compute moving average with configurable window"
516        }
517
518        fn execute_f32(&self, input: &[f32], output: &mut [f32]) -> Result<(), PluginError> {
519            if input.len() < self.window_size {
520                return Err(PluginError::InvalidInput(
521                    "Input too small for window size".to_string(),
522                ));
523            }
524
525            let expected_output_size = input.len() - self.window_size + 1;
526            if output.len() != expected_output_size {
527                return Err(PluginError::IncompatibleSizes(
528                    expected_output_size,
529                    output.len(),
530                ));
531            }
532
533            for i in 0..output.len() {
534                let sum: f32 = input[i..i + self.window_size].iter().sum();
535                output[i] = sum / self.window_size as f32;
536            }
537            Ok(())
538        }
539
540        fn execute_f64(&self, input: &[f64], output: &mut [f64]) -> Result<(), PluginError> {
541            if input.len() < self.window_size {
542                return Err(PluginError::InvalidInput(
543                    "Input too small for window size".to_string(),
544                ));
545            }
546
547            let expected_output_size = input.len() - self.window_size + 1;
548            if output.len() != expected_output_size {
549                return Err(PluginError::IncompatibleSizes(
550                    expected_output_size,
551                    output.len(),
552                ));
553            }
554
555            for i in 0..output.len() {
556                let sum: f64 = input[i..i + self.window_size].iter().sum();
557                output[i] = sum / self.window_size as f64;
558            }
559            Ok(())
560        }
561
562        fn required_input_size(&self, output_size: usize) -> usize {
563            output_size + self.window_size - 1
564        }
565    }
566}
567
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::examples::*;
    use super::*;

    /// Build a registry that already contains the example `square` plugin.
    fn registry_with_square() -> PluginRegistry {
        let registry = PluginRegistry::new();
        registry
            .register(Plugin::new(Arc::new(SquareOperation)))
            .expect("registering into an empty registry should succeed");
        registry
    }

    #[test]
    fn test_plugin_registration() {
        let registry = PluginRegistry::new();
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);

        assert!(registry.register(plugin).is_ok());
        assert!(registry.list().contains(&"square".to_string()));
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let registry = registry_with_square();

        let second = registry.register(Plugin::new(Arc::new(SquareOperation)));
        assert!(matches!(second, Err(PluginError::RegistrationFailed(_))));
        assert_eq!(registry.list(), vec!["square".to_string()]);
    }

    #[test]
    fn test_empty_name_is_rejected() {
        let registry = PluginRegistry::new();
        let metadata = PluginMetadata {
            name: String::new(),
            ..Default::default()
        };

        let result = registry.register(Plugin::with_metadata(Arc::new(SquareOperation), metadata));
        assert!(matches!(result, Err(PluginError::RegistrationFailed(_))));
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_unregister_removes_plugin_and_stats() {
        let registry = registry_with_square();

        registry
            .unregister("square")
            .expect("unregister should succeed");
        assert!(registry.list().is_empty());
        assert!(registry.get_stats("square").is_none());
        assert!(matches!(registry.get("square"), Err(PluginError::NotFound(_))));
    }

    #[test]
    fn test_plugin_execution() {
        let registry = registry_with_square();

        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];

        registry
            .execute_f32("square", &input, &mut output)
            .expect("operation should succeed");
        assert_eq!(output, vec![1.0, 4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_moving_average_plugin() {
        let registry = PluginRegistry::new();
        let operation = Arc::new(MovingAverageOperation::new(3));
        let plugin = Plugin::new(operation);

        registry.register(plugin).expect("operation should succeed");

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; 3]; // 5 - 3 + 1 = 3

        registry
            .execute_f32("moving_average", &input, &mut output)
            .expect("operation should succeed");

        // Expected: [(1+2+3)/3, (2+3+4)/3, (3+4+5)/3] = [2.0, 3.0, 4.0]
        assert_eq!(output, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_plugin_stats() {
        let registry = registry_with_square();

        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 2];

        registry
            .execute_f32("square", &input, &mut output)
            .expect("operation should succeed");

        let stats = registry
            .get_stats("square")
            .expect("operation should succeed");
        assert_eq!(stats.total_calls, 1);
        assert_eq!(stats.total_elements_processed, 2);
        assert!(stats.last_error.is_none());
    }

    #[test]
    fn test_failed_execution_is_recorded() {
        let registry = registry_with_square();

        let input = vec![1.0f32, 2.0];
        let mut output = vec![0.0f32; 1];
        assert!(registry.execute_f32("square", &input, &mut output).is_err());

        let stats = registry
            .get_stats("square")
            .expect("registered plugin should have stats");
        assert_eq!(stats.total_calls, 1);
        assert_eq!(
            stats.last_error.as_deref(),
            Some("Incompatible sizes: input 2 vs output 1")
        );
    }

    #[test]
    fn test_clear_stats_resets_counters() {
        let registry = registry_with_square();

        let input = vec![3.0f64];
        let mut output = vec![0.0f64];
        registry
            .execute_f64("square", &input, &mut output)
            .expect("operation should succeed");
        assert_eq!(output, vec![9.0]);

        registry.clear_stats();
        let stats = registry
            .get_stats("square")
            .expect("stats entry survives clear_stats");
        assert_eq!(stats.total_calls, 0);
        assert_eq!(stats.total_elements_processed, 0);
        assert_eq!(registry.list(), vec!["square".to_string()]);
    }

    #[test]
    fn test_find_by_capability() {
        let registry = registry_with_square();
        registry
            .register(Plugin::new(Arc::new(MovingAverageOperation::new(2))))
            .expect("operation should succeed");

        assert_eq!(registry.find_by_capability(true, 4), vec!["square".to_string()]);

        let mut any = registry.find_by_capability(false, 1);
        any.sort();
        assert_eq!(
            any,
            vec!["moving_average".to_string(), "square".to_string()]
        );

        assert!(registry.find_by_capability(false, 0).is_empty());
    }

    #[test]
    fn test_global_registry() {
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);

        global::register(plugin).expect("operation should succeed");

        let input = vec![2.0, 3.0];
        let mut output = vec![0.0; 2];

        global::execute_f32("square", &input, &mut output).expect("operation should succeed");
        assert_eq!(output, vec![4.0, 9.0]);

        let plugins = global::list();
        assert!(plugins.contains(&"square".to_string()));
    }

    #[test]
    fn test_error_handling() {
        let registry = PluginRegistry::new();

        // Test plugin not found
        let input = vec![1.0];
        let mut output = vec![0.0];
        let result = registry.execute_f32("nonexistent", &input, &mut output);
        assert!(matches!(result, Err(PluginError::NotFound(_))));

        // Test incompatible sizes
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);
        registry.register(plugin).expect("operation should succeed");

        let input = vec![1.0, 2.0];
        let mut output = vec![0.0]; // Wrong size
        let result = registry.execute_f32("square", &input, &mut output);
        assert!(matches!(result, Err(PluginError::IncompatibleSizes(_, _))));
    }
}