sklears_compose/
middleware.rs

1//! Pipeline Execution Middleware Framework
2//!
3//! This module provides a comprehensive middleware system for pipeline execution,
4//! allowing for flexible interception, modification, and extension of pipeline
5//! behavior including authentication, validation, transformation, caching,
6//! monitoring, and custom processing logic.
7
8use scirs2_core::ndarray::Array2;
9use sklears_core::{
10    error::{Result as SklResult, SklearsError},
11    traits::Estimator,
12    types::Float,
13};
14use std::collections::HashMap;
15use std::hash::Hash;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant, SystemTime};
18
19/// Middleware execution context containing request/response data and metadata
20#[derive(Debug)]
21pub struct MiddlewareContext {
22    /// Unique request identifier
23    pub request_id: String,
24    /// Request timestamp
25    pub timestamp: SystemTime,
26    /// Request metadata
27    pub metadata: HashMap<String, String>,
28    /// User/session information
29    pub user_info: Option<UserInfo>,
30    /// Processing state
31    pub state: ContextState,
32    /// Execution metrics
33    pub metrics: ExecutionMetrics,
34    /// Custom data storage
35    pub custom_data: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
36}
37
38/// User information for authentication and authorization
39#[derive(Debug, Clone)]
40pub struct UserInfo {
41    /// User identifier
42    pub user_id: String,
43    /// User roles
44    pub roles: Vec<String>,
45    /// User permissions
46    pub permissions: Vec<String>,
47    /// Session token
48    pub session_token: Option<String>,
49    /// Authentication method
50    pub auth_method: AuthenticationMethod,
51}
52
/// Mechanism a user authenticated with.
///
/// `Debug` is implemented by hand: API keys, bearer/OAuth tokens and passwords are
/// printed as `"<redacted>"` so they cannot leak through debug logging. Non-secret
/// fields (user names, provider names, certificate fingerprints) are shown as-is.
#[derive(Clone)]
pub enum AuthenticationMethod {
    /// No authentication was performed.
    None,
    /// Static API key (secret).
    ApiKey { key: String },
    /// Bearer token (secret).
    BearerToken { token: String },
    /// User name and password (password is secret).
    BasicAuth { username: String, password: String },
    /// Token issued by an OAuth provider (token is secret).
    OAuth { provider: String, token: String },
    /// Client certificate identified by its fingerprint.
    Certificate { cert_fingerprint: String },
    /// Application-specific method identified by name.
    Custom { method: String },
}

impl std::fmt::Debug for AuthenticationMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of secret material.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::None => f.write_str("None"),
            Self::ApiKey { .. } => f.debug_struct("ApiKey").field("key", &REDACTED).finish(),
            Self::BearerToken { .. } => f
                .debug_struct("BearerToken")
                .field("token", &REDACTED)
                .finish(),
            Self::BasicAuth { username, .. } => f
                .debug_struct("BasicAuth")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
            Self::OAuth { provider, .. } => f
                .debug_struct("OAuth")
                .field("provider", provider)
                .field("token", &REDACTED)
                .finish(),
            Self::Certificate { cert_fingerprint } => f
                .debug_struct("Certificate")
                .field("cert_fingerprint", cert_fingerprint)
                .finish(),
            Self::Custom { method } => f.debug_struct("Custom").field("method", method).finish(),
        }
    }
}
84
/// Lifecycle stage of a request as it moves through the middleware chain.
#[derive(Clone, Debug)]
pub enum ContextState {
    /// Context created, processing not yet started.
    Initializing,
    /// Request is being worked on.
    Processing,
    /// Request finished.
    Completed,
    /// Request failed.
    Error {
        /// Human-readable description of the failure.
        message: String,
    },
    /// Request was cancelled.
    Cancelled,
}
99
/// Measurements collected while a request is executed, for monitoring and profiling.
#[derive(Clone, Debug)]
pub struct ExecutionMetrics {
    /// Monotonic instant at which execution began.
    pub start_time: Instant,
    /// Monotonic instant at which execution ended, once known.
    pub end_time: Option<Instant>,
    /// Elapsed processing time, once known.
    pub duration: Option<Duration>,
    /// Memory used, in bytes.
    pub memory_usage: u64,
    /// CPU utilisation as a percentage.
    pub cpu_usage: f64,
    /// Operations per second.
    pub throughput: f64,
    /// Number of errors observed.
    pub error_count: usize,
    /// Additional named numeric measurements.
    pub custom_metrics: HashMap<String, f64>,
}
120
121/// Pipeline middleware trait
122pub trait PipelineMiddleware: Send + Sync {
123    /// Middleware name
124    fn name(&self) -> &str;
125
126    /// Execute before pipeline processing
127    fn before_process(
128        &self,
129        context: &mut MiddlewareContext,
130        input: &Array2<Float>,
131    ) -> SklResult<()>;
132
133    /// Execute after pipeline processing
134    fn after_process(
135        &self,
136        context: &mut MiddlewareContext,
137        output: &Array2<Float>,
138    ) -> SklResult<()>;
139
140    /// Handle errors during pipeline execution
141    fn on_error(
142        &self,
143        context: &mut MiddlewareContext,
144        error: &SklearsError,
145    ) -> SklResult<ErrorAction>;
146
147    /// Middleware priority (lower numbers execute first)
148    fn priority(&self) -> i32 {
149        100
150    }
151
152    /// Whether middleware should be executed
153    fn should_execute(&self, context: &MiddlewareContext) -> bool {
154        true
155    }
156}
157
158/// Action to take when an error occurs
159#[derive(Debug, Clone)]
160pub enum ErrorAction {
161    /// Continue processing
162    Continue,
163    /// Retry processing
164    Retry {
165        max_attempts: usize,
166        delay: Duration,
167    },
168    /// Abort processing
169    Abort,
170    /// Fallback to alternative processing
171    Fallback { fallback_data: Array2<Float> },
172}
173
174/// Middleware chain for executing multiple middleware components
175pub struct MiddlewareChain {
176    /// Registered middleware components
177    middlewares: Vec<Box<dyn PipelineMiddleware>>,
178    /// Chain configuration
179    config: MiddlewareChainConfig,
180    /// Execution statistics
181    stats: MiddlewareStats,
182}
183
/// Tunables for a [`MiddlewareChain`].
#[derive(Clone, Debug)]
pub struct MiddlewareChainConfig {
    /// Whether middleware may run in parallel when possible.
    pub parallel_execution: bool,
    /// Time budget for each individual middleware.
    pub timeout_per_middleware: Duration,
    /// Time budget for the whole chain.
    pub global_timeout: Duration,
    /// Whether to keep going when a middleware fails.
    pub continue_on_error: bool,
    /// Whether to emit verbose logs.
    pub detailed_logging: bool,
}
198
199/// Middleware execution statistics
200#[derive(Debug, Clone)]
201pub struct MiddlewareStats {
202    /// Total requests processed
203    pub total_requests: u64,
204    /// Successful requests
205    pub successful_requests: u64,
206    /// Failed requests
207    pub failed_requests: u64,
208    /// Average execution time
209    pub average_execution_time: Duration,
210    /// Per-middleware statistics
211    pub middleware_stats: HashMap<String, MiddlewareMetrics>,
212}
213
/// Counters and timings for a single middleware component.
#[derive(Clone, Debug)]
pub struct MiddlewareMetrics {
    /// How many times the middleware ran.
    pub execution_count: u64,
    /// Sum of all execution times.
    pub total_execution_time: Duration,
    /// Mean execution time.
    pub average_execution_time: Duration,
    /// How many executions failed.
    pub error_count: u64,
    /// Fraction of successful executions.
    pub success_rate: f64,
}
228
229/// Authentication middleware
230pub struct AuthenticationMiddleware {
231    /// Authentication providers
232    providers: HashMap<String, Box<dyn AuthenticationProvider>>,
233    /// Authentication configuration
234    config: AuthenticationConfig,
235}
236
237/// Authentication provider trait
238pub trait AuthenticationProvider: Send + Sync {
239    /// Provider name
240    fn name(&self) -> &str;
241
242    /// Authenticate user credentials
243    fn authenticate(&self, credentials: &AuthenticationCredentials) -> SklResult<UserInfo>;
244
245    /// Validate existing session
246    fn validate_session(&self, session_token: &str) -> SklResult<bool>;
247
248    /// Refresh authentication token
249    fn refresh_token(&self, refresh_token: &str) -> SklResult<String>;
250}
251
/// Credentials presented by a caller for authentication.
///
/// `Debug` is implemented by hand so that keys, tokens and passwords are printed as
/// `"<redacted>"` instead of leaking through debug logging.
#[derive(Clone)]
pub enum AuthenticationCredentials {
    /// Static API key (secret).
    ApiKey { key: String },
    /// Bearer token (secret).
    BearerToken { token: String },
    /// User name and password (password is secret).
    BasicAuth { username: String, password: String },
    /// Token from an OAuth provider (token is secret).
    OAuth { provider: String, token: String },
    /// Raw client certificate bytes.
    Certificate { certificate: Vec<u8> },
}

impl std::fmt::Debug for AuthenticationCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of secret material.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::ApiKey { .. } => f.debug_struct("ApiKey").field("key", &REDACTED).finish(),
            Self::BearerToken { .. } => f
                .debug_struct("BearerToken")
                .field("token", &REDACTED)
                .finish(),
            Self::BasicAuth { username, .. } => f
                .debug_struct("BasicAuth")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
            Self::OAuth { provider, .. } => f
                .debug_struct("OAuth")
                .field("provider", provider)
                .field("token", &REDACTED)
                .finish(),
            Self::Certificate { certificate } => f
                .debug_struct("Certificate")
                .field("certificate", certificate)
                .finish(),
        }
    }
}
266
/// Settings controlling how [`AuthenticationMiddleware`] treats requests.
#[derive(Clone, Debug)]
pub struct AuthenticationConfig {
    /// Names of authentication methods that must be used.
    pub required_methods: Vec<String>,
    /// Whether requests without user information are accepted.
    pub allow_anonymous: bool,
    /// Lifetime of a session.
    pub session_timeout: Duration,
    /// Remaining-lifetime threshold at which tokens should be refreshed.
    pub token_refresh_threshold: Duration,
    /// Number of failed attempts allowed before lockout.
    pub max_failed_attempts: usize,
    /// How long a lockout lasts.
    pub lockout_duration: Duration,
}
283
284/// Authorization middleware
285pub struct AuthorizationMiddleware {
286    /// Access control policies
287    policies: Vec<AccessPolicy>,
288    /// Role-based access control
289    rbac: RoleBasedAccessControl,
290    /// Authorization configuration
291    config: AuthorizationConfig,
292}
293
294/// Access control policy
295#[derive(Debug, Clone)]
296pub struct AccessPolicy {
297    /// Policy name
298    pub name: String,
299    /// Resource pattern
300    pub resource_pattern: String,
301    /// Required permissions
302    pub required_permissions: Vec<String>,
303    /// Allowed roles
304    pub allowed_roles: Vec<String>,
305    /// Conditions
306    pub conditions: Vec<AccessCondition>,
307    /// Policy effect
308    pub effect: PolicyEffect,
309}
310
/// Additional constraint that can be attached to an [`AccessPolicy`].
#[derive(Clone, Debug)]
pub enum AccessCondition {
    /// Time window, given as start and end strings.
    TimeWindow { start: String, end: String },
    /// IP address range in CIDR notation.
    IpRange { cidr: String },
    /// User attribute that must equal a value.
    UserAttribute { attribute: String, value: String },
    /// Resource attribute that must equal a value.
    ResourceAttribute { attribute: String, value: String },
    /// Free-form condition expression.
    Custom { condition: String },
}
325
/// Outcome an access policy (or the default configuration) produces.
///
/// Derives `PartialEq`/`Eq` so effects can be compared with `==`, and `Copy` because
/// the enum carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyEffect {
    /// Grant access.
    Allow,
    /// Refuse access.
    Deny,
    /// Access depends on the policy's conditions.
    Conditional,
}
336
337/// Role-based access control
338#[derive(Debug, Clone)]
339pub struct RoleBasedAccessControl {
340    /// Role definitions
341    pub roles: HashMap<String, Role>,
342    /// Permission definitions
343    pub permissions: HashMap<String, Permission>,
344    /// Role hierarchy
345    pub role_hierarchy: HashMap<String, Vec<String>>,
346}
347
/// A named role and the permissions attached to it.
#[derive(Clone, Debug)]
pub struct Role {
    /// Role identifier.
    pub name: String,
    /// Human-readable explanation of the role.
    pub description: String,
    /// Names of permissions granted by the role.
    pub permissions: Vec<String>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
}
360
/// A named permission over a type of resource.
#[derive(Clone, Debug)]
pub struct Permission {
    /// Permission identifier.
    pub name: String,
    /// Human-readable explanation of the permission.
    pub description: String,
    /// Kind of resource the permission applies to.
    pub resource_type: String,
    /// Actions the permission allows.
    pub actions: Vec<String>,
}
373
374/// Authorization configuration
375#[derive(Debug, Clone)]
376pub struct AuthorizationConfig {
377    /// Default policy effect
378    pub default_effect: PolicyEffect,
379    /// Enable role inheritance
380    pub enable_role_inheritance: bool,
381    /// Cache authorization decisions
382    pub cache_decisions: bool,
383    /// Cache TTL
384    pub cache_ttl: Duration,
385}
386
387/// Validation middleware
388pub struct ValidationMiddleware {
389    /// Input validators
390    input_validators: Vec<Box<dyn InputValidator>>,
391    /// Output validators
392    output_validators: Vec<Box<dyn OutputValidator>>,
393    /// Validation configuration
394    config: ValidationConfig,
395}
396
397/// Input validation trait
398pub trait InputValidator: Send + Sync {
399    /// Validator name
400    fn name(&self) -> &str;
401
402    /// Validate input data
403    fn validate(
404        &self,
405        input: &Array2<Float>,
406        context: &MiddlewareContext,
407    ) -> SklResult<ValidationResult>;
408
409    /// Validation severity
410    fn severity(&self) -> ValidationSeverity;
411}
412
413/// Output validation trait
414pub trait OutputValidator: Send + Sync {
415    /// Validator name
416    fn name(&self) -> &str;
417
418    /// Validate output data
419    fn validate(
420        &self,
421        output: &Array2<Float>,
422        context: &MiddlewareContext,
423    ) -> SklResult<ValidationResult>;
424
425    /// Validation severity
426    fn severity(&self) -> ValidationSeverity;
427}
428
429/// Validation result
430#[derive(Debug, Clone)]
431pub struct ValidationResult {
432    /// Validation passed
433    pub valid: bool,
434    /// Validation messages
435    pub messages: Vec<ValidationMessage>,
436    /// Suggested corrections
437    pub corrections: Vec<ValidationCorrection>,
438}
439
440/// Validation message
441#[derive(Debug, Clone)]
442pub struct ValidationMessage {
443    /// Message text
444    pub message: String,
445    /// Severity level
446    pub severity: ValidationSeverity,
447    /// Field or location
448    pub field: Option<String>,
449    /// Error code
450    pub code: Option<String>,
451}
452
/// How serious a validation finding is, declared from least to most severe.
#[derive(Clone, Debug)]
pub enum ValidationSeverity {
    /// Informational only.
    Info,
    /// Potential problem.
    Warning,
    /// Definite problem.
    Error,
    /// Severe problem.
    Critical,
}
465
466/// Validation correction suggestion
467#[derive(Debug, Clone)]
468pub struct ValidationCorrection {
469    /// Description of the correction
470    pub description: String,
471    /// Corrected value
472    pub corrected_value: Option<Array2<Float>>,
473    /// Confidence in correction
474    pub confidence: f64,
475}
476
/// Settings for [`ValidationMiddleware`].
#[derive(Clone, Debug)]
pub struct ValidationConfig {
    /// Whether validation errors fail the request.
    pub fail_on_error: bool,
    /// Whether proposed corrections are applied automatically.
    pub auto_correct: bool,
    /// Time budget for validation.
    pub timeout: Duration,
    /// Upper bound on corrections applied per request.
    pub max_corrections: usize,
}
489
490/// Transformation middleware
491pub struct TransformationMiddleware {
492    /// Pre-processing transformations
493    pre_transformations: Vec<Box<dyn DataTransformer>>,
494    /// Post-processing transformations
495    post_transformations: Vec<Box<dyn DataTransformer>>,
496    /// Transformation configuration
497    config: TransformationConfig,
498}
499
500/// Data transformation trait
501pub trait DataTransformer: Send + Sync {
502    /// Transformer name
503    fn name(&self) -> &str;
504
505    /// Transform data
506    fn transform(
507        &self,
508        data: &Array2<Float>,
509        context: &MiddlewareContext,
510    ) -> SklResult<Array2<Float>>;
511
512    /// Check if transformation should be applied
513    fn should_transform(&self, data: &Array2<Float>, context: &MiddlewareContext) -> bool;
514
515    /// Get transformation metadata
516    fn get_metadata(&self) -> TransformationMetadata;
517}
518
519/// Transformation metadata
520#[derive(Debug, Clone)]
521pub struct TransformationMetadata {
522    /// Transformation type
523    pub transformation_type: String,
524    /// Input requirements
525    pub input_requirements: Vec<String>,
526    /// Output characteristics
527    pub output_characteristics: Vec<String>,
528    /// Performance impact
529    pub performance_impact: PerformanceImpact,
530}
531
/// Coarse estimate of an operation's cost, from cheapest to most expensive.
#[derive(Clone, Debug)]
pub enum PerformanceImpact {
    /// Negligible cost.
    Minimal,
    /// Small cost.
    Low,
    /// Moderate cost.
    Medium,
    /// Large cost.
    High,
    /// Very large cost.
    Extreme,
}
546
/// Settings for [`TransformationMiddleware`].
#[derive(Clone, Debug)]
pub struct TransformationConfig {
    /// Whether transformations may run in parallel.
    pub parallel_transformations: bool,
    /// Time budget for transformations.
    pub timeout: Duration,
    /// Whether transformed data is cached.
    pub cache_results: bool,
    /// How long cached results stay valid.
    pub cache_ttl: Duration,
}
559
560/// Caching middleware
561pub struct CachingMiddleware {
562    /// Cache storage
563    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
564    /// Cache configuration
565    config: CacheConfig,
566    /// Cache statistics
567    stats: CacheStats,
568}
569
570/// Cache entry
571#[derive(Debug, Clone)]
572pub struct CacheEntry {
573    /// Cached data
574    pub data: Array2<Float>,
575    /// Creation timestamp
576    pub created_at: SystemTime,
577    /// Last accessed timestamp
578    pub last_accessed: SystemTime,
579    /// Access count
580    pub access_count: u64,
581    /// Entry metadata
582    pub metadata: HashMap<String, String>,
583}
584
585/// Cache configuration
586#[derive(Debug, Clone)]
587pub struct CacheConfig {
588    /// Maximum cache size (entries)
589    pub max_size: usize,
590    /// Entry TTL
591    pub ttl: Duration,
592    /// Cache eviction policy
593    pub eviction_policy: EvictionPolicy,
594    /// Enable cache statistics
595    pub enable_stats: bool,
596    /// Cache key strategy
597    pub key_strategy: CacheKeyStrategy,
598}
599
/// Strategy for choosing which cache entries to evict.
// The all-caps variant names are part of the public API, so the acronym lint is
// silenced rather than renaming them.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub enum EvictionPolicy {
    /// Least Recently Used.
    LRU,
    /// Least Frequently Used.
    LFU,
    /// First In, First Out.
    FIFO,
    /// Time To Live.
    TTL,
    /// Random choice.
    Random,
}
614
/// How a cache key is derived for a request.
#[derive(Clone, Debug)]
pub enum CacheKeyStrategy {
    /// Hash of the input data.
    HashInput,
    /// Hash of the input data combined with request context.
    HashInputAndContext,
    /// Named custom key generator.
    Custom { generator: String },
}
625
/// Counters describing cache effectiveness and size.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Lookups that found a valid entry.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// Fraction of lookups that were hits.
    pub hit_ratio: f64,
    /// Total cached size in bytes.
    pub total_size: u64,
    /// Number of stored entries.
    pub entry_count: usize,
    /// Number of entries evicted.
    pub evictions: u64,
}
642
643/// Monitoring middleware
644pub struct MonitoringMiddleware {
645    /// Metrics collectors
646    collectors: Vec<Box<dyn MetricsCollector>>,
647    /// Monitoring configuration
648    config: MonitoringConfig,
649    /// Alert manager
650    alert_manager: AlertManager,
651}
652
653/// Metrics collector trait
654pub trait MetricsCollector: Send + Sync {
655    /// Collector name
656    fn name(&self) -> &str;
657
658    /// Collect metrics
659    fn collect(&self, context: &MiddlewareContext, data: &Array2<Float>) -> SklResult<Vec<Metric>>;
660
661    /// Get supported metric types
662    fn supported_metrics(&self) -> Vec<String>;
663}
664
665/// Metric definition
666#[derive(Debug, Clone)]
667pub struct Metric {
668    /// Metric name
669    pub name: String,
670    /// Metric value
671    pub value: f64,
672    /// Metric type
673    pub metric_type: MetricType,
674    /// Timestamp
675    pub timestamp: SystemTime,
676    /// Labels
677    pub labels: HashMap<String, String>,
678}
679
/// Kind of a [`Metric`].
#[derive(Clone, Debug)]
pub enum MetricType {
    /// Counter.
    Counter,
    /// Gauge.
    Gauge,
    /// Histogram.
    Histogram,
    /// Summary.
    Summary,
    /// Timer.
    Timer,
}
694
695/// Alert manager
696#[derive(Debug, Clone)]
697pub struct AlertManager {
698    /// Alert rules
699    pub rules: Vec<AlertRule>,
700    /// Active alerts
701    pub active_alerts: Vec<Alert>,
702    /// Alert channels
703    pub channels: Vec<AlertChannel>,
704}
705
706/// Alert rule
707#[derive(Debug, Clone)]
708pub struct AlertRule {
709    /// Rule name
710    pub name: String,
711    /// Metric to monitor
712    pub metric: String,
713    /// Threshold condition
714    pub condition: AlertCondition,
715    /// Alert severity
716    pub severity: AlertSeverity,
717    /// Evaluation interval
718    pub evaluation_interval: Duration,
719}
720
/// Condition under which an [`AlertRule`] fires.
#[derive(Clone, Debug)]
pub enum AlertCondition {
    /// Comparison of the metric against `value` using `operator`.
    Threshold { operator: String, value: f64 },
    /// Metric bounds `min` and `max`.
    Range { min: f64, max: f64 },
    /// Percentage change over a time window.
    Rate {
        change_percent: f64,
        time_window: Duration,
    },
    /// Anomaly detection with the given sensitivity.
    Anomaly { sensitivity: f64 },
}
736
/// Severity of an alert, declared from least to most urgent.
#[derive(Clone, Debug)]
pub enum AlertSeverity {
    /// Informational.
    Info,
    /// Needs attention.
    Warning,
    /// Serious problem.
    Critical,
    /// Most urgent level.
    Emergency,
}
749
750/// Active alert
751#[derive(Debug, Clone)]
752pub struct Alert {
753    /// Alert ID
754    pub id: String,
755    /// Rule that triggered the alert
756    pub rule_name: String,
757    /// Current value
758    pub current_value: f64,
759    /// Alert message
760    pub message: String,
761    /// Triggered at
762    pub triggered_at: SystemTime,
763    /// Status
764    pub status: AlertStatus,
765}
766
/// Handling status of an [`Alert`].
#[derive(Clone, Debug)]
pub enum AlertStatus {
    /// Raised and not yet handled.
    Triggered,
    /// Seen by an operator.
    Acknowledged,
    /// The underlying issue is resolved.
    Resolved,
    /// Notifications are suppressed.
    Suppressed,
}
779
/// Destination an alert can be delivered to.
#[derive(Clone, Debug)]
pub enum AlertChannel {
    /// E-mail to the given addresses.
    Email { addresses: Vec<String> },
    /// HTTP webhook.
    Webhook { url: String },
    /// Slack webhook and channel.
    Slack { webhook_url: String, channel: String },
    /// Console output.
    Console,
    /// Log file at the given path.
    Log { file_path: String },
}
797
/// Settings for [`MonitoringMiddleware`].
#[derive(Clone, Debug)]
pub struct MonitoringConfig {
    /// Whether metrics are monitored in real time.
    pub real_time: bool,
    /// How often metrics are collected.
    pub collection_interval: Duration,
    /// How long metrics are kept.
    pub retention_period: Duration,
    /// Whether alerting is enabled.
    pub enable_alerting: bool,
    /// How often alert rules are evaluated.
    pub alert_evaluation_interval: Duration,
}
812
813impl MiddlewareChain {
814    /// Create a new middleware chain
815    #[must_use]
816    pub fn new(config: MiddlewareChainConfig) -> Self {
817        Self {
818            middlewares: Vec::new(),
819            config,
820            stats: MiddlewareStats {
821                total_requests: 0,
822                successful_requests: 0,
823                failed_requests: 0,
824                average_execution_time: Duration::from_millis(0),
825                middleware_stats: HashMap::new(),
826            },
827        }
828    }
829
830    /// Add middleware to the chain
831    pub fn add_middleware(&mut self, middleware: Box<dyn PipelineMiddleware>) {
832        self.middlewares.push(middleware);
833        self.middlewares.sort_by_key(|m| m.priority());
834    }
835
836    /// Execute the middleware chain
837    pub fn execute(
838        &mut self,
839        context: &mut MiddlewareContext,
840        input: &Array2<Float>,
841        processor: &dyn Fn(&Array2<Float>) -> SklResult<Array2<Float>>,
842    ) -> SklResult<Array2<Float>> {
843        let start_time = Instant::now();
844        self.stats.total_requests += 1;
845
846        // Execute before_process hooks
847        for middleware in &self.middlewares {
848            if middleware.should_execute(context) {
849                if let Err(e) = middleware.before_process(context, input) {
850                    let action = middleware.on_error(context, &e)?;
851                    match action {
852                        ErrorAction::Continue => {}
853                        ErrorAction::Abort => {
854                            self.stats.failed_requests += 1;
855                            return Err(e);
856                        }
857                        ErrorAction::Retry {
858                            max_attempts,
859                            delay,
860                        } => {
861                            // Implement retry logic
862                            std::thread::sleep(delay);
863                            return self.execute(context, input, processor);
864                        }
865                        ErrorAction::Fallback { fallback_data } => {
866                            return Ok(fallback_data);
867                        }
868                    }
869                }
870            }
871        }
872
873        // Execute main processor
874        let result = processor(input)?;
875
876        // Execute after_process hooks
877        for middleware in &self.middlewares {
878            if middleware.should_execute(context) {
879                if let Err(e) = middleware.after_process(context, &result) {
880                    let action = middleware.on_error(context, &e)?;
881                    match action {
882                        ErrorAction::Continue => {}
883                        ErrorAction::Abort => {
884                            self.stats.failed_requests += 1;
885                            return Err(e);
886                        }
887                        _ => {}
888                    }
889                }
890            }
891        }
892
893        // Update statistics
894        let execution_time = start_time.elapsed();
895        self.stats.successful_requests += 1;
896        self.update_execution_stats(execution_time);
897
898        context.state = ContextState::Completed;
899        context.metrics.end_time = Some(Instant::now());
900        context.metrics.duration = Some(execution_time);
901
902        Ok(result)
903    }
904
905    /// Get execution statistics
906    #[must_use]
907    pub fn get_stats(&self) -> &MiddlewareStats {
908        &self.stats
909    }
910
911    /// Update execution statistics
912    fn update_execution_stats(&mut self, execution_time: Duration) {
913        let total_time = self.stats.average_execution_time.as_nanos() as f64
914            * (self.stats.total_requests - 1) as f64;
915        let new_avg_nanos =
916            (total_time + execution_time.as_nanos() as f64) / self.stats.total_requests as f64;
917        self.stats.average_execution_time = Duration::from_nanos(new_avg_nanos as u64);
918    }
919}
920
921impl AuthenticationMiddleware {
922    /// Create new authentication middleware
923    #[must_use]
924    pub fn new(config: AuthenticationConfig) -> Self {
925        Self {
926            providers: HashMap::new(),
927            config,
928        }
929    }
930
931    /// Add authentication provider
932    pub fn add_provider(&mut self, provider: Box<dyn AuthenticationProvider>) {
933        self.providers.insert(provider.name().to_string(), provider);
934    }
935
936    /// Authenticate request
937    pub fn authenticate(&self, credentials: &AuthenticationCredentials) -> SklResult<UserInfo> {
938        for provider in self.providers.values() {
939            if let Ok(user_info) = provider.authenticate(credentials) {
940                return Ok(user_info);
941            }
942        }
943        Err(SklearsError::InvalidInput(
944            "Authentication failed".to_string(),
945        ))
946    }
947}
948
949impl PipelineMiddleware for AuthenticationMiddleware {
950    fn name(&self) -> &'static str {
951        "authentication"
952    }
953
954    fn before_process(
955        &self,
956        context: &mut MiddlewareContext,
957        _input: &Array2<Float>,
958    ) -> SklResult<()> {
959        if !self.config.allow_anonymous && context.user_info.is_none() {
960            return Err(SklearsError::InvalidInput(
961                "Authentication required".to_string(),
962            ));
963        }
964        Ok(())
965    }
966
967    fn after_process(
968        &self,
969        _context: &mut MiddlewareContext,
970        _output: &Array2<Float>,
971    ) -> SklResult<()> {
972        Ok(())
973    }
974
975    fn on_error(
976        &self,
977        _context: &mut MiddlewareContext,
978        _error: &SklearsError,
979    ) -> SklResult<ErrorAction> {
980        Ok(ErrorAction::Abort)
981    }
982
983    fn priority(&self) -> i32 {
984        10 // High priority - authenticate early
985    }
986}
987
988impl AuthorizationMiddleware {
989    /// Create new authorization middleware
990    #[must_use]
991    pub fn new(config: AuthorizationConfig) -> Self {
992        Self {
993            policies: Vec::new(),
994            rbac: RoleBasedAccessControl {
995                roles: HashMap::new(),
996                permissions: HashMap::new(),
997                role_hierarchy: HashMap::new(),
998            },
999            config,
1000        }
1001    }
1002
1003    /// Add access policy
1004    pub fn add_policy(&mut self, policy: AccessPolicy) {
1005        self.policies.push(policy);
1006    }
1007
1008    /// Check authorization
1009    pub fn authorize(&self, user_info: &UserInfo, resource: &str, action: &str) -> SklResult<bool> {
1010        for policy in &self.policies {
1011            if self.policy_matches(policy, resource)
1012                && self.check_permissions(policy, user_info, action)
1013            {
1014                return Ok(policy.effect == PolicyEffect::Allow);
1015            }
1016        }
1017
1018        // Default to configured default effect
1019        Ok(matches!(self.config.default_effect, PolicyEffect::Allow))
1020    }
1021
1022    /// Check if policy matches resource
1023    fn policy_matches(&self, policy: &AccessPolicy, resource: &str) -> bool {
1024        // Simplified pattern matching
1025        policy.resource_pattern == "*" || policy.resource_pattern == resource
1026    }
1027
1028    /// Check user permissions against policy
1029    fn check_permissions(&self, policy: &AccessPolicy, user_info: &UserInfo, action: &str) -> bool {
1030        // Check role-based access
1031        for role in &user_info.roles {
1032            if policy.allowed_roles.contains(role) {
1033                return true;
1034            }
1035        }
1036
1037        // Check permission-based access
1038        for permission in &user_info.permissions {
1039            if policy.required_permissions.contains(permission) {
1040                return true;
1041            }
1042        }
1043
1044        false
1045    }
1046}
1047
1048impl PipelineMiddleware for AuthorizationMiddleware {
1049    fn name(&self) -> &'static str {
1050        "authorization"
1051    }
1052
1053    fn before_process(
1054        &self,
1055        context: &mut MiddlewareContext,
1056        _input: &Array2<Float>,
1057    ) -> SklResult<()> {
1058        if let Some(user_info) = &context.user_info {
1059            if !self.authorize(user_info, "pipeline", "execute")? {
1060                return Err(SklearsError::InvalidInput("Access denied".to_string()));
1061            }
1062        }
1063        Ok(())
1064    }
1065
1066    fn after_process(
1067        &self,
1068        _context: &mut MiddlewareContext,
1069        _output: &Array2<Float>,
1070    ) -> SklResult<()> {
1071        Ok(())
1072    }
1073
1074    fn on_error(
1075        &self,
1076        _context: &mut MiddlewareContext,
1077        _error: &SklearsError,
1078    ) -> SklResult<ErrorAction> {
1079        Ok(ErrorAction::Abort)
1080    }
1081
1082    fn priority(&self) -> i32 {
1083        20 // Execute after authentication
1084    }
1085}
1086
1087impl CachingMiddleware {
1088    /// Create new caching middleware
1089    #[must_use]
1090    pub fn new(config: CacheConfig) -> Self {
1091        Self {
1092            cache: Arc::new(Mutex::new(HashMap::new())),
1093            config,
1094            stats: CacheStats {
1095                hits: 0,
1096                misses: 0,
1097                hit_ratio: 0.0,
1098                total_size: 0,
1099                entry_count: 0,
1100                evictions: 0,
1101            },
1102        }
1103    }
1104
1105    /// Generate cache key
1106    fn generate_cache_key(&self, input: &Array2<Float>, context: &MiddlewareContext) -> String {
1107        use std::collections::hash_map::DefaultHasher;
1108        use std::hash::Hasher;
1109
1110        match &self.config.key_strategy {
1111            CacheKeyStrategy::HashInput => {
1112                let mut hasher = DefaultHasher::new();
1113                if let Some(slice) = input.as_slice() {
1114                    for &x in slice {
1115                        (x.to_bits()).hash(&mut hasher);
1116                    }
1117                }
1118                format!("{:x}", hasher.finish())
1119            }
1120            CacheKeyStrategy::HashInputAndContext => {
1121                let mut hasher = DefaultHasher::new();
1122                if let Some(slice) = input.as_slice() {
1123                    for &x in slice {
1124                        (x.to_bits()).hash(&mut hasher);
1125                    }
1126                }
1127                context.request_id.hash(&mut hasher);
1128                format!("{:x}", hasher.finish())
1129            }
1130            CacheKeyStrategy::Custom { .. } => {
1131                let mut hasher = DefaultHasher::new();
1132                if let Some(slice) = input.as_slice() {
1133                    for &x in slice {
1134                        (x.to_bits()).hash(&mut hasher);
1135                    }
1136                }
1137                format!("{:x}", hasher.finish())
1138            }
1139        }
1140    }
1141
1142    /// Get cached data
1143    pub fn get(&mut self, key: &str) -> Option<Array2<Float>> {
1144        let result = {
1145            let mut cache = self.cache.lock().unwrap();
1146
1147            if let Some(entry) = cache.get_mut(key) {
1148                // Check if entry is still valid
1149                if entry.created_at.elapsed().unwrap_or(Duration::MAX) <= self.config.ttl {
1150                    entry.last_accessed = SystemTime::now();
1151                    entry.access_count += 1;
1152                    Some((entry.data.clone(), true)) // (data, is_hit)
1153                } else {
1154                    // Entry expired, remove it
1155                    cache.remove(key);
1156                    Some((Array2::zeros((0, 0)), false)) // Dummy data, will indicate eviction
1157                }
1158            } else {
1159                None
1160            }
1161        };
1162
1163        match result {
1164            Some((data, true)) => {
1165                self.stats.hits += 1;
1166                self.update_hit_ratio();
1167                Some(data)
1168            }
1169            Some((_, false)) => {
1170                self.stats.evictions += 1;
1171                self.stats.misses += 1;
1172                self.update_hit_ratio();
1173                None
1174            }
1175            None => {
1176                self.stats.misses += 1;
1177                self.update_hit_ratio();
1178                None
1179            }
1180        }
1181    }
1182
1183    /// Put data in cache
1184    pub fn put(&mut self, key: String, data: Array2<Float>) {
1185        let max_size = self.config.max_size;
1186        let eviction_policy = self.config.eviction_policy.clone();
1187
1188        let (evicted, final_size) = {
1189            let mut cache = self.cache.lock().unwrap();
1190
1191            // Check if cache is full
1192            let mut evicted = false;
1193            if cache.len() >= max_size {
1194                // Perform eviction within the lock
1195                match eviction_policy {
1196                    EvictionPolicy::LRU => {
1197                        if let Some(lru_key) = cache
1198                            .iter()
1199                            .min_by_key(|(_, entry)| entry.last_accessed)
1200                            .map(|(key, _)| key.clone())
1201                        {
1202                            cache.remove(&lru_key);
1203                            evicted = true;
1204                        }
1205                    }
1206                    EvictionPolicy::LFU => {
1207                        if let Some(lfu_key) = cache
1208                            .iter()
1209                            .min_by_key(|(_, entry)| entry.access_count)
1210                            .map(|(key, _)| key.clone())
1211                        {
1212                            cache.remove(&lfu_key);
1213                            evicted = true;
1214                        }
1215                    }
1216                    _ => {
1217                        // Simple FIFO for other policies
1218                        if let Some(first_key) = cache.keys().next().cloned() {
1219                            cache.remove(&first_key);
1220                            evicted = true;
1221                        }
1222                    }
1223                }
1224            }
1225
1226            let entry = CacheEntry {
1227                data,
1228                created_at: SystemTime::now(),
1229                last_accessed: SystemTime::now(),
1230                access_count: 1,
1231                metadata: HashMap::new(),
1232            };
1233
1234            cache.insert(key, entry);
1235            (evicted, cache.len())
1236        };
1237
1238        if evicted {
1239            self.stats.evictions += 1;
1240        }
1241        self.stats.entry_count = final_size;
1242    }
1243
1244    /// Evict entries based on policy (internal method that doesn't borrow self)
1245    fn evict_entries_internal(&mut self, cache: &mut HashMap<String, CacheEntry>) {
1246        let eviction_policy = self.config.eviction_policy.clone();
1247        match eviction_policy {
1248            EvictionPolicy::LRU => {
1249                if let Some(lru_key) = cache
1250                    .iter()
1251                    .min_by_key(|(_, entry)| entry.last_accessed)
1252                    .map(|(key, _)| key.clone())
1253                {
1254                    cache.remove(&lru_key);
1255                    self.stats.evictions += 1;
1256                }
1257            }
1258            EvictionPolicy::LFU => {
1259                if let Some(lfu_key) = cache
1260                    .iter()
1261                    .min_by_key(|(_, entry)| entry.access_count)
1262                    .map(|(key, _)| key.clone())
1263                {
1264                    cache.remove(&lfu_key);
1265                    self.stats.evictions += 1;
1266                }
1267            }
1268            _ => {
1269                // Simple FIFO for other policies
1270                if let Some(first_key) = cache.keys().next().cloned() {
1271                    cache.remove(&first_key);
1272                    self.stats.evictions += 1;
1273                }
1274            }
1275        }
1276    }
1277
1278    /// Update hit ratio
1279    fn update_hit_ratio(&mut self) {
1280        let total = self.stats.hits + self.stats.misses;
1281        if total > 0 {
1282            self.stats.hit_ratio = self.stats.hits as f64 / total as f64;
1283        }
1284    }
1285}
1286
1287impl PipelineMiddleware for CachingMiddleware {
1288    fn name(&self) -> &'static str {
1289        "caching"
1290    }
1291
1292    fn before_process(
1293        &self,
1294        context: &mut MiddlewareContext,
1295        input: &Array2<Float>,
1296    ) -> SklResult<()> {
1297        let cache_key = self.generate_cache_key(input, context);
1298        context.metadata.insert("cache_key".to_string(), cache_key);
1299        Ok(())
1300    }
1301
1302    fn after_process(
1303        &self,
1304        context: &mut MiddlewareContext,
1305        output: &Array2<Float>,
1306    ) -> SklResult<()> {
1307        if let Some(cache_key) = context.metadata.get("cache_key") {
1308            self.cache.lock().unwrap().insert(
1309                cache_key.clone(),
1310                /// CacheEntry
1311                CacheEntry {
1312                    data: output.clone(),
1313                    created_at: SystemTime::now(),
1314                    last_accessed: SystemTime::now(),
1315                    access_count: 1,
1316                    metadata: HashMap::new(),
1317                },
1318            );
1319        }
1320        Ok(())
1321    }
1322
1323    fn on_error(
1324        &self,
1325        _context: &mut MiddlewareContext,
1326        _error: &SklearsError,
1327    ) -> SklResult<ErrorAction> {
1328        Ok(ErrorAction::Continue)
1329    }
1330
1331    fn priority(&self) -> i32 {
1332        50
1333    }
1334}
1335
1336impl Default for MiddlewareChainConfig {
1337    fn default() -> Self {
1338        Self {
1339            parallel_execution: false,
1340            timeout_per_middleware: Duration::from_secs(30),
1341            global_timeout: Duration::from_secs(300),
1342            continue_on_error: false,
1343            detailed_logging: false,
1344        }
1345    }
1346}
1347
1348impl Default for AuthenticationConfig {
1349    fn default() -> Self {
1350        Self {
1351            required_methods: Vec::new(),
1352            allow_anonymous: true,
1353            session_timeout: Duration::from_secs(3600),
1354            token_refresh_threshold: Duration::from_secs(300),
1355            max_failed_attempts: 3,
1356            lockout_duration: Duration::from_secs(300),
1357        }
1358    }
1359}
1360
1361impl Default for AuthorizationConfig {
1362    fn default() -> Self {
1363        Self {
1364            default_effect: PolicyEffect::Deny,
1365            enable_role_inheritance: true,
1366            cache_decisions: true,
1367            cache_ttl: Duration::from_secs(300),
1368        }
1369    }
1370}
1371
1372impl Default for CacheConfig {
1373    fn default() -> Self {
1374        Self {
1375            max_size: 1000,
1376            ttl: Duration::from_secs(3600),
1377            eviction_policy: EvictionPolicy::LRU,
1378            enable_stats: true,
1379            key_strategy: CacheKeyStrategy::HashInput,
1380        }
1381    }
1382}
1383
1384impl Default for ValidationConfig {
1385    fn default() -> Self {
1386        Self {
1387            fail_on_error: true,
1388            auto_correct: false,
1389            timeout: Duration::from_secs(30),
1390            max_corrections: 10,
1391        }
1392    }
1393}
1394
1395impl Default for TransformationConfig {
1396    fn default() -> Self {
1397        Self {
1398            parallel_transformations: false,
1399            timeout: Duration::from_secs(60),
1400            cache_results: false,
1401            cache_ttl: Duration::from_secs(300),
1402        }
1403    }
1404}
1405
1406impl Default for MonitoringConfig {
1407    fn default() -> Self {
1408        Self {
1409            real_time: true,
1410            collection_interval: Duration::from_secs(60),
1411            retention_period: Duration::from_secs(86400),
1412            enable_alerting: true,
1413            alert_evaluation_interval: Duration::from_secs(60),
1414        }
1415    }
1416}
1417
1418impl PartialEq for PolicyEffect {
1419    fn eq(&self, other: &Self) -> bool {
1420        matches!(
1421            (self, other),
1422            (PolicyEffect::Allow, PolicyEffect::Allow)
1423                | (PolicyEffect::Deny, PolicyEffect::Deny)
1424                | (PolicyEffect::Conditional, PolicyEffect::Conditional)
1425        )
1426    }
1427}
1428
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a context for `request_id` in `state`, with zeroed metrics and
    /// no user, metadata or custom data.
    fn sample_context(request_id: &str, state: ContextState) -> MiddlewareContext {
        MiddlewareContext {
            request_id: request_id.to_string(),
            timestamp: SystemTime::now(),
            metadata: HashMap::new(),
            user_info: None,
            state,
            metrics: ExecutionMetrics {
                start_time: Instant::now(),
                end_time: None,
                duration: None,
                memory_usage: 0,
                cpu_usage: 0.0,
                throughput: 0.0,
                error_count: 0,
                custom_metrics: HashMap::new(),
            },
            custom_data: HashMap::new(),
        }
    }

    #[test]
    fn test_middleware_context_creation() {
        let context = sample_context("test-123", ContextState::Initializing);

        assert_eq!(context.request_id, "test-123");
        assert!(matches!(context.state, ContextState::Initializing));
    }

    #[test]
    fn test_middleware_chain_creation() {
        let config = MiddlewareChainConfig::default();
        let chain = MiddlewareChain::new(config);

        assert_eq!(chain.middlewares.len(), 0);
        assert_eq!(chain.stats.total_requests, 0);
    }

    #[test]
    fn test_authentication_middleware() {
        let config = AuthenticationConfig::default();
        let auth_middleware = AuthenticationMiddleware::new(config);

        assert_eq!(auth_middleware.name(), "authentication");
        assert_eq!(auth_middleware.priority(), 10);
    }

    #[test]
    fn test_caching_middleware() {
        let config = CacheConfig::default();
        let cache_middleware = CachingMiddleware::new(config);

        assert_eq!(cache_middleware.name(), "caching");
        assert_eq!(cache_middleware.stats.hit_ratio, 0.0);
    }

    #[test]
    fn test_cache_key_generation() {
        let cache_middleware = CachingMiddleware::new(CacheConfig::default());
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let context = sample_context("test", ContextState::Processing);

        let key = cache_middleware.generate_cache_key(&input, &context);
        assert!(!key.is_empty());
        // The default strategy depends only on the input, so the key is reproducible.
        assert_eq!(key, cache_middleware.generate_cache_key(&input, &context));
    }

    #[test]
    fn test_cache_put_get_updates_stats() {
        let mut cache = CachingMiddleware::new(CacheConfig::default());
        let data = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();

        cache.put("k".to_string(), data.clone());
        assert_eq!(cache.get("k"), Some(data));
        assert!(cache.get("missing").is_none());

        assert_eq!(cache.stats.hits, 1);
        assert_eq!(cache.stats.misses, 1);
        assert_eq!(cache.stats.hit_ratio, 0.5);
        assert_eq!(cache.stats.entry_count, 1);
    }

    #[test]
    fn test_cache_evicts_when_full() {
        let config = CacheConfig {
            max_size: 1,
            ..CacheConfig::default()
        };
        let mut cache = CachingMiddleware::new(config);
        let first = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let second = Array2::from_shape_vec((1, 1), vec![2.0]).unwrap();

        cache.put("a".to_string(), first);
        cache.put("b".to_string(), second.clone());

        assert!(cache.get("a").is_none());
        assert_eq!(cache.get("b"), Some(second));
        assert_eq!(cache.stats.evictions, 1);
        assert_eq!(cache.stats.entry_count, 1);
    }

    #[test]
    fn test_access_policy() {
        let policy = AccessPolicy {
            name: "test_policy".to_string(),
            resource_pattern: "/api/*".to_string(),
            required_permissions: vec!["read".to_string()],
            allowed_roles: vec!["user".to_string()],
            conditions: Vec::new(),
            effect: PolicyEffect::Allow,
        };

        assert_eq!(policy.name, "test_policy");
        assert_eq!(policy.effect, PolicyEffect::Allow);
    }

    #[test]
    fn test_validation_result() {
        let result = ValidationResult {
            valid: true,
            messages: Vec::new(),
            corrections: Vec::new(),
        };

        assert!(result.valid);
        assert_eq!(result.messages.len(), 0);
    }

    #[test]
    fn test_cache_stats() {
        let mut stats = CacheStats {
            hits: 10,
            misses: 5,
            hit_ratio: 0.0,
            total_size: 1024,
            entry_count: 15,
            evictions: 2,
        };

        // Calculate hit ratio
        let total = stats.hits + stats.misses;
        stats.hit_ratio = stats.hits as f64 / total as f64;

        assert_eq!(stats.hit_ratio, 10.0 / 15.0);
    }

    #[test]
    fn test_user_info() {
        let user_info = UserInfo {
            user_id: "user123".to_string(),
            roles: vec!["admin".to_string(), "user".to_string()],
            permissions: vec!["read".to_string(), "write".to_string()],
            session_token: Some("token123".to_string()),
            auth_method: AuthenticationMethod::ApiKey {
                key: "api_key_123".to_string(),
            },
        };

        assert_eq!(user_info.user_id, "user123");
        assert_eq!(user_info.roles.len(), 2);
        assert_eq!(user_info.permissions.len(), 2);
    }

    #[test]
    fn test_metric_creation() {
        let metric = Metric {
            name: "response_time".to_string(),
            value: 150.5,
            metric_type: MetricType::Timer,
            timestamp: SystemTime::now(),
            labels: HashMap::from([
                ("service".to_string(), "api".to_string()),
                ("version".to_string(), "1.0".to_string()),
            ]),
        };

        assert_eq!(metric.name, "response_time");
        assert_eq!(metric.value, 150.5);
        assert!(matches!(metric.metric_type, MetricType::Timer));
    }
}