sklears_compose/
lib.rs

#![allow(dead_code)]
#![allow(non_snake_case)]
#![allow(missing_docs)]
#![allow(deprecated)]
#![allow(unused_imports)]
#![allow(unused_variables)]
#![allow(unused_mut)]
#![allow(unused_assignments)]
#![allow(unused_doc_comments)]
#![allow(unused_parens)]
#![allow(unused_comparisons)]
#![allow(clippy::all)]
#![allow(clippy::pedantic)]
#![allow(clippy::nursery)]
//! Composite estimators and transformers
//!
//! This module provides meta-estimators for composing other estimators.
//! It includes tools for applying different transformers to different
//! subsets of features and for transforming target variables.
//!
//! ## Core Modules
//!
//! The following modules form the stable core of `sklears-compose`:
//!
//! - [`pipeline`] - Linear pipelines for chaining estimators and transformers
//! - [`column_transformer`] - Apply different transformers to different feature subsets
//! - [`ensemble`] - Voting classifiers/regressors, model fusion, and dynamic selection
//! - [`boosting`] - AdaBoost and Gradient Boosting meta-estimators
//! - [`dag_pipeline`] - Directed acyclic graph pipelines with branching and merging
//! - [`advanced_pipeline`] - Conditional and branching pipeline variants
//! - [`streaming`] - Streaming/online pipelines for incremental learning
//! - [`feature_engineering`] - Automatic feature interaction detection
//! - [`validation`] - Comprehensive pipeline validation (data, structure, performance)
//! - [`optimization`] - Pipeline hyperparameter optimization and robust execution
//!
//! ## Experimental Subsystems
//!
//! The following modules are experimental in v0.1.0. Their APIs may change
//! significantly in future releases. They are included to provide early access
//! to advanced capabilities but should not be relied upon for production use.
//!
//! ### Infrastructure and Resilience
//!
//! - [`circuit_breaker`] - Circuit breaker pattern for fault-tolerant pipelines.
//!   Experimental in v0.1.0. API may change.
//! - [`fault_core`] - Core fault tolerance primitives.
//!   Experimental in v0.1.0. API may change.
//! - [`middleware`] - Authentication, caching, and monitoring middleware chains.
//!   Experimental in v0.1.0. API may change.
//! - [`resource_management`] - Resource monitoring and optimization.
//!   Experimental in v0.1.0. API may change.
//! - [`scheduling`] - Task scheduling and workflow management.
//!   Experimental in v0.1.0. API may change.
//! - [`state_management`] - Pipeline state checkpointing and version control.
//!   Experimental in v0.1.0. API may change.
//! - [`external_integration`] - REST API and database integration adapters.
//!   Experimental in v0.1.0. API may change.
//!
//! ### Distributed and Parallel Computing
//!
//! - [`distributed`] - Distributed map-reduce pipelines and cluster management.
//!   Experimental in v0.1.0. API may change.
//! - [`distributed_tracing`] - Distributed tracing and span-based diagnostics.
//!   Experimental in v0.1.0. API may change.
//! - [`parallel_execution`] - Parallel pipeline execution with load balancing.
//!   Experimental in v0.1.0. API may change.
//!
//! ### Advanced ML Paradigms
//!
//! - [`automl`] - AutoML optimization and neural architecture search.
//!   Experimental in v0.1.0. API may change.
//! - [`continual_learning`] - Continual/lifelong learning with memory buffers.
//!   Experimental in v0.1.0. API may change.
//! - [`differentiable`] - Differentiable pipelines with automatic differentiation.
//!   Experimental in v0.1.0. API may change.
//! - [`few_shot`] - Few-shot learning (MAML, Prototypical Networks).
//!   Experimental in v0.1.0. API may change.
//! - [`meta_learning`] - Meta-learning pipelines with experience replay.
//!   Experimental in v0.1.0. API may change.
//! - [`transfer_learning`] - Transfer learning and domain adaptation pipelines.
//!   Experimental in v0.1.0. API may change.
//! - [`quantum`] - Quantum computing pipeline primitives.
//!   Experimental in v0.1.0. API may change.
//!
//! ### Domain-Specific Pipelines
//!
//! - [`nlp_pipelines`] - NLP pipelines (tokenization, sentiment, NER, summarization).
//!   Experimental in v0.1.0. API may change.
//! - [`cv_pipelines`] - Computer vision pipelines (detection, feature extraction).
//!   Experimental in v0.1.0. API may change.
//! - [`time_series_pipelines`] - Time series and IoT data pipelines.
//!   Experimental in v0.1.0. API may change.
//!
//! ### WebAssembly and Compilation Targets
//!
//! - [`wasm_integration`] - WebAssembly compilation and deployment.
//!   Experimental in v0.1.0. API may change.
//! - [`enhanced_wasm_integration`] - Advanced WASM features (JS bindings, worker threads).
//!   Experimental in v0.1.0. API may change.
//!
//! ### Extensibility and Tooling
//!
//! - [`modular_framework`] - Pluggable component registry and dependency graphs.
//!   Experimental in v0.1.0. API may change.
//! - [`plugin_architecture`] - Plugin loading and component schemas.
//!   Experimental in v0.1.0. API may change.
//! - [`workflow_language`] - Pipeline DSL and visual builder (requires `workflow` feature).
//!   Experimental in v0.1.0. API may change.
//! - [`zero_cost`] - Zero-cost abstraction primitives (arenas, lock-free queues).
//!   Experimental in v0.1.0. API may change.
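//!
//! Because `workflow_language` is compiled only when the `workflow` feature is
//! enabled, downstream crates have to opt in. A minimal sketch of the manifest
//! entry, assuming the crate is pulled from a registry under the name
//! `sklears-compose` at version 0.1.0 (adjust the version or source to match
//! your setup):
//!
//! ```toml
//! [dependencies]
//! # Assumption: registry name and version taken from this module's docs.
//! sklears-compose = { version = "0.1.0", features = ["workflow"] }
//! ```
//!
//! If the feature is not enabled, the `workflow_language` module is not
//! compiled and paths into it fail to resolve.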
//!
//! ### Performance and Diagnostics
//!
//! - [`profile_guided_optimization`] - Profile-guided optimization with ML performance prediction.
//!   Experimental in v0.1.0. API may change.
//! - [`simd_optimizations`] - SIMD-accelerated data layout and feature operations.
//!   Experimental in v0.1.0. API may change.
//! - [`advanced_debugging`] - Interactive debugger with breakpoints and profiling.
//!   Experimental in v0.1.0. API may change.
//! - [`performance_profiler`] - Stage-level performance profiling and bottleneck detection.
//!   Experimental in v0.1.0. API may change.
//! - [`stress_testing`] - Stress testing and edge-case generation.
//!   Experimental in v0.1.0. API may change.
//!
//! ## Known Limitations
//!
//! The following module is disabled because of ndarray HRTB (Higher-Ranked Trait Bound)
//! lifetime constraints introduced in ndarray 0.17. Re-enabling is planned for v0.2.0:
//! - `cross_validation` - Cross-validation for composed models and pipelines

// #![warn(missing_docs)]

// Module declarations
pub mod advanced_debugging;
pub mod advanced_pipeline;
pub mod api_consistency;
pub mod automated_alerting;
pub mod automl;
pub mod benchmarking;
pub mod boosting;
pub mod circuit_breaker;
pub mod column_transformer;
pub mod error;
// pub mod composable_execution;  // Migrated to execution module
pub mod config_management;
pub mod configuration_validation;
pub mod continual_learning;
// KNOWN ISSUE (v0.1.0): Module disabled due to ndarray HRTB lifetime constraints. Planned for v0.2.0.
// pub mod cross_validation;
pub mod cv_pipelines;
pub mod dag_pipeline;
pub mod debugging;
pub mod debugging_utilities;
pub mod developer_experience;
pub mod differentiable;
pub mod distributed;
pub mod distributed_tracing;
pub mod enhanced_compile_time_validation;
pub mod enhanced_error_messages;
pub mod enhanced_errors;
pub mod ensemble;
pub mod execution;
pub mod execution_config;
pub mod execution_core;
pub mod execution_hooks;
pub mod execution_strategies;
pub mod execution_types;
pub mod external_integration;
pub mod fault_core;
pub mod feature_engineering;
pub mod few_shot;
pub mod fluent_api;
pub mod memory_optimization;
pub mod meta_learning;
pub mod middleware;
pub mod mock;
pub mod modular_framework;
pub mod monitoring;
pub mod nlp_pipelines;
pub mod optimization;
pub mod parallel_execution;
// pub mod pattern_optimization; // Temporarily disabled - needs rand_distr dependency and type fixes
pub mod performance_optimization;
pub mod performance_profiler;
pub mod performance_testing;
pub mod pipeline;
pub mod pipeline_visualization;
pub mod plugin_architecture;
pub mod profile_guided_optimization;
pub mod property_testing;
pub mod quality_assurance;
pub mod quantum;
pub mod resource_management;
// pub mod retry; // Temporarily disabled - needs missing type imports and context fixes
pub mod enhanced_wasm_integration;
pub mod scheduling;
pub mod simd_optimizations;
pub mod state_management;
pub mod streaming;
pub mod stress_testing;
pub mod task_definitions;
pub mod task_scheduling;
pub mod time_series_pipelines;
pub mod transfer_learning;
pub mod type_safety;
pub mod validation;
pub mod wasm_integration;
#[cfg(feature = "workflow")]
pub mod workflow_language;
pub mod zero_cost;

// Re-export main types and traits
pub use advanced_debugging::{
    AdvancedPipelineDebugger, AdvancedProfiler, Breakpoint, BreakpointCondition, CallStackFrame,
    CpuSample, DebugConfig, DebugEvent, DebugSession, DebugSessionHandle, DebugSessionState,
    DebugStatistics, ExecutionStep, MemorySample, StepResult, VariableInspector, VariableValue,
    WatchExpression, WatchResult,
};
pub use advanced_pipeline::{
    BranchConfig, BranchingPipeline, BranchingPipelineBuilder, BranchingPipelineTrained,
    ConditionalPipeline, ConditionalPipelineBuilder, ConditionalPipelineTrained, DataCondition,
    FeatureCountCondition,
};
pub use api_consistency::{
    ApiConsistencyChecker, ApiRecommendation, ConfigSummary, ConfigValue, ConsistencyIssue,
    ConsistencyReport, ExecutionMetadata, FittedModelSummary, FittedTransformerSummary,
    IssueCategory, IssueSeverity, MetadataProvider, ModelSummary, PipelineConsistencyReport,
    RecommendationCategory, RecommendationPriority, StandardBuilder, StandardConfig,
    StandardEstimator, StandardFittedEstimator, StandardFittedTransformer, StandardResult,
    StandardTransformer,
};
pub use automated_alerting::{
    ActiveAlert, AlertChannel, AlertCondition, AlertConfig, AlertEvent, AlertSeverity, AlertStatus,
    AutomatedAlerter, ConsoleAlertChannel, EmailAlertChannel, EscalationLevel, LogicalOperator,
    PatternField, SilencePeriod, SlackAlertChannel, ThresholdOperator, WebhookAlertChannel,
};
pub use automl::{
    ActivationFunction, AlgorithmChoice, AlgorithmType, AutoMLConfig, AutoMLOptimizer, LayerType,
    NASStrategy, NeuralArchitecture, NeuralArchitectureSearch, NeuralSearchSpace,
    OptimizationHistory, OptimizationMetric, OptimizationReport, ParameterRange,
    ParameterValue as AutoMLParameterValue, SearchSpace, SearchStrategy as AutoMLSearchStrategy,
    TrialResult, TrialStatus,
};
pub use benchmarking::{
    BenchmarkConfig, BenchmarkReport, BenchmarkResult, BenchmarkSuite, ComplexityClass,
    MemoryUsage as BenchmarkMemoryUsage, ScalabilityMetrics, UseCase,
};
pub use boosting::{
    AdaBoostAlgorithm, AdaBoostClassifier, AdaBoostTrained, GradientBoostingRegressor,
    GradientBoostingTrained, LossFunction,
};
pub use circuit_breaker::{
    AdvancedCircuitBreaker, AnalyticsInsight, AnalyticsProcessor, AnalyticsRecommendation,
    AnalyticsResult, CircuitBreaker, CircuitBreakerAnalytics, CircuitBreakerBuilder,
    CircuitBreakerError, CircuitBreakerEvent, CircuitBreakerEventRecorder, CircuitBreakerEventType,
    CircuitBreakerFailureDetector, CircuitBreakerRecoveryManager, CircuitBreakerStatsAggregator,
    CircuitBreakerStatsTracker, ConsoleEventPublisher, ErrorTracker, EventPublisher,
    FileEventPublisher, HealthCheckResult, HealthMetrics, RecoveryContext, RecoveryResult,
    RecoveryStrategy as CircuitBreakerRecoveryStrategy, RequestCounters, ResponseTimeTracker,
    SlidingWindow, ValidationResult as CircuitBreakerValidationResult,
};
pub use column_transformer::{
    ColumnTransformer, ColumnTransformerBuilder, ColumnTransformerOutput, ColumnTransformerTrained,
};
pub use config_management::{
    ConfigManager, ConfigValue as ConfigManagementConfigValue, EnvironmentConfig, EstimatorConfig,
    ExecutionConfig, PipelineConfig, ResourceConfig, StepConfig,
    ValidationRule as ConfigValidationRule,
};
pub use configuration_validation::{
    CompileTimeValidator, ConfigurationValidator, Constraint, CustomValidationRule,
    DependencyConstraint, FieldConstraints, FieldType, PipelineConfigValidator, RuleType,
    RuntimeValidator, ValidationBuilder, ValidationReport, ValidationResult, ValidationSchema,
    ValidationSeverity, ValidationStatus, ValidationSummary,
};
pub use continual_learning::{
    ContinualLearningPipeline, ContinualLearningPipelineTrained, ContinualLearningStrategy,
    MemoryBuffer, MemorySample as ContinualMemorySample, SamplingStrategy, Task, TaskStatistics,
};
// KNOWN ISSUE (v0.1.0): Module disabled due to ndarray HRTB lifetime constraints. Planned for v0.2.0.
// pub use cross_validation::{
//     CVStrategy, CVSummary, ComposedModelCrossValidator, CrossValidationConfig,
//     CrossValidationResults, FoldResult, NestedCVResults, OuterFoldResult, ScoringConfig,
//     ScoringMetric as CVScoringMetric, TimeSeriesConfig,
// };
pub use cv_pipelines::{
    AdaptationAlgorithm, AdaptationMetric, AdaptiveQualityConfig, BoundingBox,
    BufferManagementConfig, CVConfig, CVMetrics, CVModel, CVPipeline, CVPipelineState,
    CVPrediction, CacheEvictionPolicy, CachingStrategy, CameraInfo, CameraIntrinsics,
    CameraSettings, ColorSpace, CompressionAlgorithm, CompressionConfig, ComputeDevice,
    ConfidenceScores, ContrastiveLearningConfig, CrossModalLearningConfig, CrossModalStrategy,
    DenoisingAlgorithm, Detection, DetectionMetadata, DistillationConfig, EncodingConfig,
    ErrorResilienceConfig, ExifData, ExtractorConfig, ExtractorType, FeatureExtractor,
    FeatureMetadata, FeatureQuality, FeatureStatistics, FeatureVector,
    FusionStrategy as CVFusionStrategy, GPSInfo, ImageData, ImageDataType, ImageFormat,
    ImageMetadata, ImageSpecification, ImageTransform, InterpolationMethod, LensInfo,
    LoadBalancingAlgorithm, MemoryOptimizationLevel, Modality, ModelConfig, ModelMetadata,
    ModelPerformance, ModelType, MultiModalConfig, NetworkOptimizationConfig, NoiseReductionConfig,
    NormalizationSpec, ObjectDetectionResult, ParallelProcessingConfig,
    ParallelStrategy as CVParallelStrategy, PerformanceConfig, PerformanceMetrics,
    PostProcessor as CVPostProcessor, PredictionMetadata, PredictionResult, PredictionType,
    ProcessedResult, ProcessingComplexity, ProcessingMode, ProcessingStatistics, ProcessorConfig,
    ProcessorType, QualityEnhancementConfig, QualityLevel, QualityMetrics as CVQualityMetrics,
    QualitySettings, RateControlMethod, RealTimeProcessingConfig,
    RecoveryStrategy as CVRecoveryStrategy, ResourceUtilization, SharpeningConfig, StreamingConfig,
    StreamingProtocol, SyncMethod, TemporalAlignmentConfig, TransformParameter, VideoCodec,
};
pub use dag_pipeline::{
    BranchCondition, ComparisonOp, DAGNode, DAGPipeline, DAGPipelineTrained,
    ExecutionRecord as DAGExecutionRecord, ExecutionStats, MergeStrategy, NodeComponent,
    NodeConfig, NodeOutput,
};
pub use debugging::{
    Bottleneck as DebuggingBottleneck, BottleneckDetector,
    BottleneckSeverity as DebuggingBottleneckSeverity, BottleneckType,
    Breakpoint as DebuggingBreakpoint, BreakpointCondition as DebuggingBreakpointCondition,
    BreakpointFrequency, ComparisonOperator, DataSnapshot, DataSummary,
    DebugConfig as DebuggingConfig, DebugLogLevel, DebugOutputFormat,
    DebugSession as DebuggingSession, ErrorAnalysis, ErrorPattern, ErrorResolutionStatus,
    ErrorStatistics, ErrorTracker as DebuggingErrorTracker, ExecutionState,
    ExecutionStep as EnhancedExecutionStep, InteractiveDebugger, IoStatistics,
    MemoryUsage as DebuggingMemoryUsage, PerformanceAnalysis as DebuggingPerformanceAnalysis,
    PerformanceMeasurement, PerformanceMetric, PerformanceProfiler, PipelineDebugger, SessionState,
    StatisticalSummary, StepError, TrackedError,
};
// Temporarily commented out to avoid import conflicts
// pub use debugging_utilities::{
//     BottleneckAnalysis, BreakpointCondition as UtilsBreakpointCondition, CacheStatistics,
//     ContextValue, DataFlowAnalysis, DataInspector, DataLineageNode, DataMetadata,
//     DataSnapshot as UtilsDataSnapshot, DataStatistics, DebugCommand, DebugReport,
//     DebugSession as UtilsDebugSession, DebugState, DebuggingConfig, ErrorContext,
//     ErrorContextManager, ErrorInfo, ErrorSuggestion, ExecutionContext as UtilsExecutionContext,
//     ExecutionState as UtilsExecutionState, ExecutionStatistics,
//     ExecutionStep as UtilsExecutionStep, ExecutionTracer, ExportFormat, InspectionConfig,
//     InteractiveConfig, InteractiveDebugger as UtilsInteractiveDebugger, MeasurementSession,
//     PerformanceMetrics as UtilsPerformanceMetrics, PerformanceProfiler as UtilsPerformanceProfiler,
//     PipelineDebugger as UtilsPipelineDebugger, ProfilingConfig, QualityMetric, StepStatus,
//     StepType as UtilsStepType, TracingConfig, TransformationGraph, TransformationNode,
//     TransformationSummary, WatchExpression as UtilsWatchExpression,
// };
pub use developer_experience::{
    Breakpoint as DeveloperBreakpoint, CodeExample, DebugState, DebugSummary,
    DeveloperFriendlyError, ErrorMessageEnhancer, ExecutionContext, FixSuggestion,
    PipelineDebugger as DeveloperPipelineDebugger, StepType, SuggestionPriority, TraceEntry,
    WatchExpression as DeveloperWatchExpression,
};
pub use differentiable::{
    ActivationFunction as DiffActivationFunction, AutoDiffConfig, AutoDiffEngine, ComputationGraph,
    ComputationNode, DifferentiableOperation, DifferentiablePipeline, DifferentiableStage,
    DifferentiationMode, DualNumber, GradientAccumulation, GradientContext, GradientRecord,
    LearningRateSchedule as DiffLearningRateSchedule, NeuralPipelineController,
    OptimizationConfig as DiffOptimizationConfig, OptimizerState,
    OptimizerType as DiffOptimizerType, Parameter as DiffParameter,
    ParameterConfig as DiffParameterConfig, PipelineComponent, TrainingMetrics, TrainingState,
};
pub use distributed::TaskPriority as DistributedTaskPriority;
pub use distributed::{
    ClusterManager, ClusterNode, DataShard, DistributedTask, FaultDetector, LoadBalancer,
    MapReducePipeline, NodeStatus, ResourceRequirements, TaskResult as DistributedTaskResult,
    TaskStatus as DistributedTaskStatus,
};
pub use distributed_tracing::{
    Bottleneck, BottleneckSeverity as TracingBottleneckSeverity, ConsoleTraceExporter,
    DistributedTracer, JsonFileTraceExporter, LogEntry, LogLevel as TracingLogLevel,
    ServiceAnalysis, ServiceInfo, SpanStatus, Trace, TraceAnalysis, TraceExporter, TraceHandle,
    TraceSpan, TraceStatistics, TracingConfig,
};
pub use enhanced_compile_time_validation::{
    BuilderConfig, CompileTimeValidator as EnhancedCompileTimeValidator, ConfigType,
    ConfigValue as EnhancedConfigValue, ConfigurationLocation, ConstraintValidator,
    CrossReferenceRule, CrossReferenceValidator, CustomValidator, DependencyValidator,
    FieldDefinition, ParameterConstraintValidator, PipelineConfigurationSchema,
    PipelineSchemaValidator, ReferenceType, SchemaConstraint, SchemaConstraintType,
    SchemaValidator, SuggestionAction, SuggestionPriority as ValidationSuggestionPriority,
    TypeConstraint, TypeSafeConfigBuilder, Unbuilt, Unvalidated, Validated,
    ValidatedPipelineConfig, ValidationConfig, ValidationError, ValidationErrorCategory,
    ValidationMetrics, ValidationProof, ValidationResult as EnhancedValidationResult,
    ValidationRule, ValidationSeverity as EnhancedValidationSeverity,
    ValidationStatus as EnhancedValidationStatus, ValidationSuggestion, ValidationWarning,
    ValueConstraint, WarningCategory,
};
pub use enhanced_error_messages::{
    ActionableSuggestion, AutoRecoveryHandler, CodeExample as EnhancedCodeExample,
    ConfigurationContext, ContextProvider, ContextType, DataContext, DataQualityMetrics,
    DataStatistics, DifficultyLevel, DocumentationLink, EnhancedErrorContext, EnhancedErrorMessage,
    EnvironmentContext, ErrorCategory, ErrorClassification, ErrorContextCollector,
    ErrorEnhancementConfig, ErrorEnhancementStatistics, ErrorFormatter, ErrorFrequency,
    ErrorMessageEnhancer as EnhancedErrorMessageEnhancer, ErrorPattern as EnhancedErrorPattern,
    ErrorPatternAnalyzer, ExpertiseLevel, ImplementationStep, IssueRelationship, MissingValueInfo,
    OutputFormat, PerformanceBottleneck, PerformanceContext as EnhancedPerformanceContext,
    PipelineContext as EnhancedPipelineContext, QualityIssue, RecoveryAdvisor,
    RecoveryStrategy as EnhancedRecoveryStrategy, RelatedIssue, ResolutionStep, ResolutionStrategy,
    SeverityLevel, SimilarIssue, StackFrame, SuggestionEngine, SuggestionGenerator,
};
pub use enhanced_errors::{
    DataShape, EnhancedErrorBuilder, ErrorContext, ImpactLevel,
    PerformanceMetrics as EnhancedPerformanceMetrics, PerformanceWarningType, PipelineError,
    ResourceType, StructureErrorType, TypeViolationType,
};
pub use ensemble::{
    CompetenceEstimation, DynamicEnsembleSelector, DynamicEnsembleSelectorBuilder,
    DynamicEnsembleSelectorTrained, FusionStrategy, HierarchicalComposition,
    HierarchicalCompositionBuilder, HierarchicalCompositionTrained, HierarchicalNode,
    HierarchicalStrategy, ModelFusion, ModelFusionBuilder, ModelFusionTrained, SelectionStrategy,
    VotingClassifier, VotingClassifierBuilder, VotingClassifierTrained, VotingRegressor,
    VotingRegressorBuilder, VotingRegressorTrained,
};
408pub use execution::{
409    ExecutionEngineConfig, ExecutionStrategy, ExecutionTask,
410    ParameterValue as ComposableParameterValue, PerformanceGoals, ResourceConstraints,
411    StrategyConfig, StrategyMetrics, TaskHandle, TaskPriority, TaskResult, TaskScheduler,
412    TaskStatus, TaskType,
413};
414pub use execution_hooks::{
415    CustomHook, CustomHookBuilder, ExecutionContext as HookExecutionContext, ExecutionHook,
416    HookData, HookManager, HookPhase, HookResult, LogLevel as HookLogLevel, LoggingHook,
417    MemoryUsage as HookMemoryUsage, PerformanceHook, PerformanceMetrics as HookPerformanceMetrics,
418    ValidationHook,
419};
420pub use external_integration::{
421    AuthConfig, AuthCredentials, AuthType, BackoffStrategy, CircuitBreakerConfig,
422    CircuitBreakerState, CircuitState, ConnectionConfig, DatabaseIntegration, ExternalIntegration,
423    ExternalIntegrationManager, HealthCheckConfig, HealthStatus, IntegrationConfig,
424    IntegrationData, IntegrationRequest, IntegrationResponse, IntegrationType, Operation,
425    OperationResult, RateLimitConfig, RefreshConfig, RestApiIntegration, RetryCondition,
426    RetryPolicy, TimeoutConfig, TlsConfig,
427};
428pub use feature_engineering::{
429    AutoFeatureEngineer, ColumnType, ColumnTypeDetector, DetectionMethod, FeatureInteraction,
430    FeatureInteractionDetector, InteractionType,
431};
432pub use few_shot::{
433    DistanceMetric, FewShotLearnerType, FewShotPipeline, FewShotPipelineTrained, MAMLLearner,
434    MAMLLearnerTrained, MetaLearnerWrapper, PrototypicalNetwork, PrototypicalNetworkTrained,
435    SupportSet,
436};
437pub use fluent_api::{
438    CacheStrategy, CachingConfiguration, DebugConfiguration, FeatureEngineeringChain,
439    FeatureUnionBuilder, FluentPipelineBuilder, ImputationStrategy, LogLevel, MemoryConfiguration,
440    PipelineConfiguration, PipelinePresets, PreprocessingChain, ValidationConfiguration,
441    ValidationLevel,
442};
443pub use memory_optimization::{
444    MemoryEfficientOps, MemoryMonitor, MemoryMonitorConfig, MemoryPool, MemoryStatistics,
445    MemoryUsage as MemoryOptimizationUsage, PoolStatistics, StreamingBuffer,
446};
447pub use meta_learning::{
448    AdaptationStrategy, Experience, ExperienceStorage, MetaLearningPipeline,
449    MetaLearningPipelineTrained,
450};
451pub use middleware::{
452    AlertManager, AlertRule, AlertSeverity as MiddlewareAlertSeverity, AuthenticationCredentials,
453    AuthenticationMethod as MiddlewareAuthMethod, AuthenticationMiddleware, AuthenticationProvider,
454    AuthorizationConfig as MiddlewareAuthConfig, AuthorizationMiddleware, CacheConfig, CacheEntry,
455    CachingMiddleware, ErrorAction, MiddlewareChain, MiddlewareChainConfig, MiddlewareContext,
456    MiddlewareStats, MonitoringMiddleware, PipelineMiddleware, TransformationMiddleware,
457    UserInfo as MiddlewareUserInfo, ValidationMiddleware,
458};
459pub use mock::{MockPredictor, MockTransformer};
460pub use modular_framework::{
461    CapabilityMismatch, CompatibilityReport, ComponentCapability, ComponentConfig,
462    ComponentDependency, ComponentFactory, ComponentInfo, ComponentMetadata, ComponentNode,
463    ComponentRegistry, ComponentStatus, CompositionContext, ConfigValue as ModularConfigValue,
464    DependencyGraph, EnvironmentSettings, ErrorHandlingStrategy, ExecutionCondition,
465    ExecutionMetadata as ModularExecutionMetadata, ExecutionStrategy as ModularExecutionStrategy,
466    LogLevel as ModularLogLevel, MissingDependency, ModularPipeline, ModularPipelineBuilder,
467    PipelineConfig as ModularPipelineConfig, PipelineStep as ModularPipelineStep,
468    PluggableComponent, ResourceLimits, ResourceManager, VersionConflict,
469};
470pub use monitoring::{
471    Anomaly, AnomalySeverity, AnomalyType, ExecutionContext as MonitoringExecutionContext,
472    ExecutionHandle, ExecutionStatus, Metric, MetricsSnapshot, MonitorConfig, PerformanceAnalysis,
473    PerformanceBaseline, PerformanceTrends, PipelineMonitor, StagePerformance, Trend,
474};
475pub use nlp_pipelines::{
476    AnalysisResult, BagOfWordsExtractor, ContextManager, ConversationResponse, ConversationTurn,
477    ConversationalAI, DocumentParser, DocumentProcessor, Entity, EvaluationConfig,
478    EvaluationMetric, FeatureExtractionConfig, FeatureExtractor as NLPFeatureExtractor,
479    LanguageDetector, LanguageModel, ModelConfig as NLPModelConfig, ModelPrediction,
480    ModelType as NLPModelType, MultiLanguageSupport, NERAnalyzer, NLPPipeline, NLPPipelineConfig,
481    OutputFormatter, PreprocessingConfig, ProcessingResult, ProcessingStats,
482    QuestionAnsweringModel, SentimentAnalyzer, SimpleLanguageModel, SummarizationStrategy,
483    TextAnalyzer, TextClassifier, TextNormalizer, TextPreprocessor, TextSummarizationModel,
484    TfIdfExtractor, TopicModelingAnalyzer, TrainingConfig as NLPTrainingConfig, TranslationModel,
485    WordEmbeddingExtractor,
486};
487pub use optimization::{
488    ErrorHandlingStrategy as OptimizationErrorHandlingStrategy, FallbackStrategy,
489    MultiObjectiveResult, OptimizationResults, ParameterSpace, ParameterType, ParetoFront,
490    PipelineOptimizer, PipelineValidator, RobustPipelineExecutor, ScoringMetric, SearchStrategy,
491};
492// pub use pattern_optimization::{
493//     MultiObjectiveOptimizer, MOOAlgorithm, ParetoFrontManager,
494//     ScalarizationMethod, ScalarizationType, PreferenceModel, PreferenceModelType,
495//     PreferenceStructure, PreferenceRelation, RelationType, IterationResult,
496//     SolutionSelector, DiversityMaintainer, MOOConvergenceDetector,
497//     MOOPerformanceIndicators, ArchiveManager, DominationAnalyzer,
498//     FrontExtractor, FrontQualityAssessor, FrontVisualizer,
499// };
500pub use parallel_execution::{
501    AsyncTask, LoadBalancingStrategy, ParallelConfig, ParallelExecutionStrategy, ParallelExecutor,
502    ParallelPipeline, ParallelTask, TaskResult as ParallelTaskResult, WorkerStatistics,
503};
504pub use performance_profiler::{
505    BottleneckMetrics, BottleneckSeverity as ProfilerBottleneckSeverity,
506    BottleneckType as ProfilerBottleneckType, ComparativeAnalysis, CpuSample as ProfilerCpuSample,
507    GpuSample, ImplementationDifficulty, MemorySample as ProfilerMemorySample,
508    OptimizationCategory, OptimizationHint, OptimizationPriority, OverallMetrics,
509    PerformanceProfiler as ProfilerPerformanceProfiler,
510    PerformanceReport as ProfilerPerformanceReport, ProfileSession, ProfilerConfig, StageProfile,
511    SummaryMetrics, TrendDirection as ProfilerTrendDirection,
512};
513pub use performance_testing::{
514    BenchmarkContext, BenchmarkResult as PerformanceBenchmarkResult, BenchmarkStorage,
515    CpuStatistics, EnvironmentConfig as PerformanceEnvironmentConfig, EnvironmentMetadata,
516    MemoryStatistics as PerformanceMemoryStatistics, OutlierDetection,
517    PerformanceMetrics as PerformanceTestingMetrics, PerformanceRegressionTester,
518    PerformanceReport, ProfilingConfig, RegressionAnalysis, RegressionSeverity,
519    RegressionThresholds, StatisticalAnalysisConfig, StatisticalTest, SystemInfo,
520    ThroughputMetrics, TimeStatistics, TrendAnalysis,
521};
522pub use pipeline::{Pipeline, PipelineBuilder, PipelinePredictor, PipelineStep, PipelineTrained};
523pub use pipeline_visualization::{
524    DataSpecification, DataType as VisualizationDataType, EdgeProperties, EdgeStyle, ExportFormat,
525    FontProperties, FontWeight, GraphEdge, GraphNode, IoSpecification, LayoutAlgorithm, NodeShape,
526    NodeSize, ParameterValue as VisualizationParameterValue, PipelineGraph, PipelineVisualizer,
    ShapeSpecification, VisualProperties, VisualizationConfig,
};
pub use plugin_architecture::{
    ComponentConfig as PluginComponentConfig, ComponentContext,
    ComponentFactory as PluginComponentFactory, ComponentSchema, ConfigValue as PluginConfigValue,
    ParameterConstraint, ParameterSchema, ParameterType as PluginParameterType, Plugin,
    PluginCapability, PluginComponent, PluginConfig, PluginContext, PluginEstimator, PluginLoader,
    PluginMetadata, PluginRegistry, PluginTransformer,
};
pub use profile_guided_optimization::{
    AccessPattern, CacheOptimizationHints, DataCharacteristics, ExecutionMetrics, HardwareContext,
    MLPerformancePredictor, MemoryLayout, OptimizationLevel, OptimizationStats,
    OptimizationStrategy, OptimizerConfig, ParallelStrategy, PerformancePredictor,
    PerformanceProfile, ProfileGuidedOptimizer, SimdFeature,
};
pub use property_testing::{
    PipelinePropertyTester, PropertyTestCase, PropertyTestGenerator, PropertyTestResult,
    StatisticalValidator, TestSuiteResult, TestSuiteRunner,
    ValidationResult as PropertyValidationResult, ValidationStatistics,
};
pub use quality_assurance::{
    AutomatedQualityAssurance, ComplianceStatus, ExecutiveSummary,
    IssueCategory as QAIssueCategory, IssueSeverity as QAIssueSeverity, QAConfig,
    QualityAssessment, QualityGates, QualityIssue as QAQualityIssue, QualityMetrics, QualityReport,
    QualityStandards, RecommendationCategory as QARecommendationCategory,
    RecommendationPriority as QARecommendationPriority, TrendDirection as QATrendDirection,
};
pub use quantum::{
    QuantumBackend, QuantumEnsemble, QuantumGate, QuantumPipeline, QuantumPipelineBuilder,
    QuantumPipelineStep, QuantumTransformer,
};
// pub use retry::{
//     AdaptiveLearningSystem, BackoffAlgorithm, ExponentialBackoffAlgorithm, RetryManager,
//     RetryStrategy, RetryContext, RetryConfig as RetryConfiguration, RetryError, RetryMetrics,
//     AdaptiveStrategy, CircuitBreakerStrategy, LinearBackoffStrategy, FeatureEngineering,
//     PolicyEvaluator, RetryPolicyEngine, ConfigurationManager, GlobalRetryConfig,
//     AlertingSystem, RetryMetricsCollector, ModelPerformanceMetrics, SystemStatistics,
// };
pub use scheduling::{
    ResourcePool, RetryConfig, ScheduledTask, SchedulerStatistics, SchedulingStrategy,
    TaskScheduler as SchedulingTaskScheduler, TaskState, Workflow, WorkflowManager,
};
pub use simd_optimizations::{SimdConfig, SimdDataLayout, SimdFeatureOps, SimdOps};
pub use state_management::{
    CheckpointConfig, ExecutionStatistics, PersistenceStrategy, PipelineVersionControl, StateData,
    StateManager, StateSnapshot, StateSynchronizer,
};
pub use streaming::{
    StateManagement, StreamConfig, StreamDataPoint, StreamStats, StreamWindow, StreamingPipeline,
    StreamingPipelineTrained, UpdateStrategy, WindowingStrategy,
};
pub use stress_testing::{
    ComputationType, EdgeCase, MemoryPattern, ResourceMonitor, ResourceStats, ResourceUsageStats,
    StressTestConfig, StressTestIssue, StressTestReport, StressTestResult, StressTestScenario,
    StressTester,
};
// TODO: Enable once time_series_pipelines modules are implemented
// pub use time_series_pipelines::{
//     AnomalyDetector, AnomalyType as TimeSeriesAnomalyType, DataStream, EdgeNode, FeatureEngineer,
//     IoTConfig, IoTDataPipeline, IoTDevice, MessageBroker, PostProcessor as TSPostProcessor,
//     QualityControlConfig, RealTimeConfig, SamplingFrequency, SeasonalPattern,
//     TimeSeriesCharacteristics, TimeSeriesConfig as TSConfig, TimeSeriesMetrics, TimeSeriesModel,
//     TimeSeriesPipeline, TimeSeriesTransform, TrendType,
// };
pub use enhanced_wasm_integration::{
    BrowserFeature, BrowserFeatureDetection, BrowserInfo,
    BrowserIntegration as EnhancedBrowserIntegration, CompilationTarget, CompiledWasmModule,
    CpuIntensity, ExecutionContext as WasmExecutionContext, ExecutionProfile,
    FeatureDetectionStrategy, FunctionHandle, FunctionSignature, GeneratedBinding, IoProfile,
    JsBindingsGenerator, JsType, LoadedWasmModule, MemoryConstraints,
    MemoryLayout as WasmMemoryLayout, MemoryPermissions, MemoryProfile, MemoryRegion, ModuleLoader,
    ModuleSource, OptimizationPass, OptimizationResult,
    OptimizationStrategy as WasmOptimizationStrategy, PerformanceHints,
    PerformanceProfile as WasmPerformanceProfile, PerformanceRequirements, ProfilingSession,
    ScalingProfile, TaskData, TaskPriority as WasmTaskPriority, TaskType as WasmTaskType,
    TypedArrayType, WasmArchitecture, WasmCompiler as EnhancedWasmCompiler, WasmExport,
    WasmExportValue, WasmImport, WasmInstance, WasmIntegrationConfig, WasmIntegrationManager,
    WasmMemoryView, WasmModuleManager, WasmModuleMetadata, WasmPerformanceOptimizer, WasmProfiler,
    WasmType, WasmValue as EnhancedWasmValue, WebApiIntegration,
    WorkerStatistics as WasmWorkerStatistics, WorkerStatus, WorkerTask, WorkerThread,
    WorkerThreadManager,
};
pub use transfer_learning::{
    domain_adaptation::{
        DomainAdaptationPipeline, DomainAdaptationPipelineTrained, DomainAdaptationStrategy,
    },
    AdaptationConfig, LearningRateSchedule, PretrainedModel, TransferLearningPipeline,
    TransferLearningPipelineTrained, TransferStrategy,
};
pub use type_safety::{
    CategoricalInput, ClassificationOutput, DataFlowValidation, DataFlowValidator, DenseOutput,
    Input, MixedInput, NumericInput, Output, PipelineValidation, PipelineValidationError,
    RegressionOutput, SparseOutput, StructureValidation, TypeCompatible, TypedEstimator,
    TypedFeatureUnion, TypedPipelineBuilder, TypedPipelineStage, TypedTransformer,
};
pub use validation::{
    ComprehensivePipelineValidator, CrossValidationResult, CrossValidator, DataValidationResult,
    DataValidator, MessageLevel, PerformanceValidationResult, PerformanceValidator,
    RobustnessTestResult, RobustnessTester, StatisticalValidationResult,
    StatisticalValidator as ComprehensiveStatisticalValidator, StructureValidationResult,
    StructureValidator, ValidationMessage, ValidationReport as ComprehensiveValidationReport,
};
pub use wasm_integration::{
    BrowserIntegration, DataSchema, OptimizationLevel as WasmOptimizationLevel, PipelineMetadata,
    WasmCompiler, WasmConfig, WasmDataType, WasmModule, WasmOptimization, WasmPipeline, WasmStep,
    WasmStepType, WasmValue,
};
#[cfg(feature = "workflow")]
pub use workflow_language::{
    CodeLanguage, ComponentRegistry as WorkflowComponentRegistry, DataType, DslError, DslLexer,
    DslParser, ExecutionConfig as WorkflowExecutionConfig, ExecutionMode, FileFormat,
    ParameterValue, PipelineDSL, StepDefinition, StepType as WorkflowStepType, Token,
    VisualPipelineBuilder, WorkflowDefinition, WorkflowMetadata,
};
pub use zero_cost::{
    AllocationInfo, Arena, AtomicRcData, ConcurrencyStats, CowData, LockFreeQueue,
    MemoryLeakConfig, MemoryLeakDetector, MemoryPool as ZeroCostMemoryPool, MemoryStats,
    PooledBuffer, QueueStats, SafeConcurrentData, SharedData, TrackedAllocation, WeakRcData,
    WorkStealingDeque, WorkStealingStats, ZeroCopySlice, ZeroCopyView, ZeroCostBuffer,
    ZeroCostBuilder, ZeroCostCompose, ZeroCostComposition, ZeroCostConditional, ZeroCostEstimator,
    ZeroCostFeatureSelector, ZeroCostFeatureUnion, ZeroCostLayout, ZeroCostParallel,
    ZeroCostPipeline, ZeroCostStep,
};

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::types::Float;
use sklears_core::{
    error::Result as SklResult,
    prelude::{Fit as CoreFit, Predict, SklearsError, Transform},
    traits::{Estimator, Fit, Untrained},
};
use std::collections::HashMap;

// Type aliases to reduce complexity
type TransformBox = Box<dyn for<'a> Transform<ArrayView1<'a, Float>, Array1<f64>> + Send + Sync>;
type TransformFunc = fn(&ArrayView1<'_, Float>) -> Array1<f64>;

/// Transformed Target Regressor
///
/// Meta-estimator that transforms the target `y` before fitting a regression model.
/// Predictions are mapped back to the original space via an inverse transform.
///
/// # Parameters
///
/// * `regressor` - The regressor fitted on the transformed target
/// * `transformer` - The transformer applied to the target before fitting
/// * `func` - Function applied to the target when no `transformer` is set
/// * `inverse_func` - Function that maps predictions back to the original target space
/// * `check_inverse` - Whether to check the inverse transform
///
/// # Examples
///
/// ```
/// use sklears_compose::TransformedTargetRegressor;
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0], [2.0], [3.0]];
/// let y = array![1.0, 4.0, 9.0];
/// ```
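/**
The idea behind `TransformedTargetRegressor` (transform `y`, fit, predict, then invert) can be
sketched without any crate dependencies. This is only an illustration under stated assumptions,
not this crate's implementation: `log_target`, `exp_target`, `fit_line`, and
`fit_predict_transformed` are hypothetical helpers, the log/exp pair is an assumed choice of
transform, and a one-feature least-squares line stands in for the boxed regressor.

```rust
/// Forward target transform (assumed for illustration): natural log.
fn log_target(y: &[f64]) -> Vec<f64> {
    y.iter().map(|v| v.ln()).collect()
}

/// Inverse target transform: exponential, undoing `log_target`.
fn exp_target(y: &[f64]) -> Vec<f64> {
    y.iter().map(|v| v.exp()).collect()
}

/// Ordinary least squares for a single feature; returns (slope, intercept).
fn fit_line(x: &[f64], y: &[f64]) -> (f64, f64) {
    let n = x.len() as f64;
    let mean_x = x.iter().sum::<f64>() / n;
    let mean_y = y.iter().sum::<f64>() / n;
    let cov: f64 = x.iter().zip(y).map(|(a, b)| (a - mean_x) * (b - mean_y)).sum();
    let var: f64 = x.iter().map(|a| (a - mean_x).powi(2)).sum();
    let slope = cov / var;
    (slope, mean_y - slope * mean_x)
}

/// Fit on the transformed target, predict, then map back to the original space.
fn fit_predict_transformed(x: &[f64], y: &[f64]) -> Vec<f64> {
    let (slope, intercept) = fit_line(x, &log_target(y));
    let in_log_space: Vec<f64> = x.iter().map(|v| slope * v + intercept).collect();
    exp_target(&in_log_space)
}
```

With `y = exp(x)` the fit in log space is exactly linear, so the predictions should match `y`
up to floating-point error.
*/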
pub struct TransformedTargetRegressor<S = Untrained> {
    state: S,
    regressor: Option<Box<dyn PipelinePredictor>>,
    transformer: Option<TransformBox>,
    func: Option<TransformFunc>,
    inverse_func: Option<TransformFunc>,
    check_inverse: bool,
}

/// Trained state for `TransformedTargetRegressor`
pub struct TransformedTargetRegressorTrained {
    fitted_regressor: Box<dyn PipelinePredictor>,
    fitted_transformer: Option<TransformBox>,
    func: Option<TransformFunc>,
    inverse_func: Option<TransformFunc>,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

impl TransformedTargetRegressor<Untrained> {
    /// Create a new `TransformedTargetRegressor`
    #[must_use]
    pub fn new(regressor: Box<dyn PipelinePredictor>) -> Self {
        Self {
            state: Untrained,
            regressor: Some(regressor),
            transformer: None,
            func: None,
            inverse_func: None,
            check_inverse: true,
        }
    }

    /// Set the transformer applied to the target before fitting
    #[must_use]
    pub fn transformer(
        mut self,
        transformer: Box<dyn for<'a> Transform<ArrayView1<'a, Float>, Array1<f64>> + Send + Sync>,
    ) -> Self {
        self.transformer = Some(transformer);
        self
    }

    /// Set the function applied to the target (used only when no transformer is set)
    #[must_use]
    pub fn func(mut self, func: fn(&ArrayView1<'_, Float>) -> Array1<f64>) -> Self {
        self.func = Some(func);
        self
    }

    /// Set the function that maps predictions back to the original target space
    #[must_use]
    pub fn inverse_func(mut self, inverse_func: fn(&ArrayView1<'_, Float>) -> Array1<f64>) -> Self {
        self.inverse_func = Some(inverse_func);
        self
    }

    /// Set whether to check the inverse transform
    ///
    /// The flag is stored, but `fit` as written in this file does not act on it.
    #[must_use]
    pub fn check_inverse(mut self, check: bool) -> Self {
        self.check_inverse = check;
        self
    }
}
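/*
`check_inverse` is stored by the builder, but nothing in this file reads it. One plausible
meaning, assumed here by analogy with scikit-learn's option of the same name, is a round-trip
check that the inverse function undoes the forward function on the training target. A minimal
sketch of such a check; `round_trips` is a hypothetical helper working on scalars, not part of
this crate:

```rust
/// Returns true when `inverse(forward(v))` reproduces `v` within `tol` for every element of `y`.
fn round_trips(y: &[f64], forward: fn(f64) -> f64, inverse: fn(f64) -> f64, tol: f64) -> bool {
    y.iter().all(|&v| (inverse(forward(v)) - v).abs() <= tol)
}
```

A caller could run this before fitting and return an error, or log a warning, when it yields
`false`; which of the two is appropriate is a design choice this sketch leaves open.
*/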

impl Estimator for TransformedTargetRegressor<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
    for TransformedTargetRegressor<Untrained>
{
    type Fitted = TransformedTargetRegressor<TransformedTargetRegressorTrained>;

    fn fit(
        self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        if let Some(y_values) = y.as_ref() {
            let mut regressor = self.regressor.ok_or_else(|| SklearsError::InvalidData {
                reason: "No regressor provided".to_string(),
            })?;

            // Transform the target: a transformer takes precedence over `func`;
            // with neither set, `y` is copied unchanged
            let transformed_y = if let Some(ref transformer) = self.transformer {
                transformer.transform(y_values)?
            } else if let Some(func) = self.func {
                func(y_values)
            } else {
                y_values.mapv(|v| v)
            };

            // Fit the regressor on the transformed target
            regressor.fit(x, &transformed_y.view())?;

            Ok(TransformedTargetRegressor {
                state: TransformedTargetRegressorTrained {
                    fitted_regressor: regressor,
                    fitted_transformer: self.transformer,
                    func: self.func,
                    inverse_func: self.inverse_func,
                    n_features_in: x.ncols(),
                    feature_names_in: None,
                },
                regressor: None,
                transformer: None,
                func: None,
                inverse_func: None,
                check_inverse: self.check_inverse,
            })
        } else {
            Err(SklearsError::InvalidInput(
                "Target values required for fitting".to_string(),
            ))
        }
    }
}

impl TransformedTargetRegressor<TransformedTargetRegressorTrained> {
    /// Predict with the fitted regressor, then map predictions back to the original target space
    ///
    /// `inverse_func` is applied when set. When only a fitted transformer is available,
    /// the code as written does not invert it, so predictions are returned in the
    /// transformed space.
    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        // Get predictions from the fitted regressor
        let transformed_predictions = self.state.fitted_regressor.predict(x)?;

        // Apply the inverse transformation where one is available
        let predictions = if let Some(inverse_func) = self.state.inverse_func {
            inverse_func(&transformed_predictions.view())
        } else if let Some(ref _transformer) = self.state.fitted_transformer {
            // Note: inverting the transformer would require an inverse_transform method,
            // which needs a proper inverse transformer trait. Until then the predictions
            // are passed through unchanged.
            transformed_predictions
        } else {
            transformed_predictions
        };

        Ok(predictions)
    }

    /// Get the fitted regressor
    #[must_use]
    pub fn regressor(&self) -> &dyn PipelinePredictor {
        &*self.state.fitted_regressor
    }
}

/// Feature Union for parallel feature extraction
///
/// Applies every transformer to the same input and concatenates the outputs
/// along the feature (column) axis. Optional per-transformer weights scale
/// each output block before concatenation.
///
/// # Examples
///
/// ```ignore
/// use sklears_compose::FeatureUnion;
/// use scirs2_core::ndarray::array;
///
/// let data = array![[1.0, 2.0], [3.0, 4.0]];
/// let union = FeatureUnion::new()
///     .transformer("trans1", Box::new(MockTransformer::new()))
///     .transformer("trans2", Box::new(MockTransformer::new()));
/// ```
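/**
The concatenation performed by `FeatureUnion` (apply each transformer to the same input,
optionally scale by a weight, then stack the outputs column-wise) can be sketched with plain
vectors. This is an illustration under stated assumptions, not the crate's code: `Matrix`,
`TransformFn`, `identity`, `squares`, and `weighted_union` are hypothetical names, and nested
`Vec`s stand in for `Array2`.

```rust
/// A row-major matrix: one inner `Vec` per sample.
type Matrix = Vec<Vec<f64>>;

/// Hypothetical stand-in for a fitted transformer: maps a matrix to a matrix.
type TransformFn = fn(&Matrix) -> Matrix;

/// Returns the input unchanged.
fn identity(x: &Matrix) -> Matrix {
    x.clone()
}

/// Squares every element.
fn squares(x: &Matrix) -> Matrix {
    x.iter().map(|row| row.iter().map(|v| v * v).collect()).collect()
}

/// Apply each transformer to the same input, scale its output by the paired weight,
/// and append the resulting columns. Assumes every transformer returns one row per sample.
fn weighted_union(x: &Matrix, steps: &[(TransformFn, f64)]) -> Matrix {
    let mut out: Matrix = vec![Vec::new(); x.len()];
    for (transform, weight) in steps {
        let block = transform(x);
        for (out_row, block_row) in out.iter_mut().zip(&block) {
            out_row.extend(block_row.iter().map(|v| v * weight));
        }
    }
    out
}
```

Calling `weighted_union` with `identity` at weight `1.0` and `squares` at weight `0.5` on a
2x2 input yields a 2x4 output: the original columns followed by the halved squares.
*/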
#[derive(Debug)]
pub struct FeatureUnion<S = Untrained> {
    state: S,
    transformers: Vec<(String, Box<dyn PipelineStep>)>,
    n_jobs: Option<i32>,
    transformer_weights: Option<HashMap<String, f64>>,
    preserve_dataframe: bool,
}

/// Trained state for `FeatureUnion`
#[derive(Debug)]
pub struct FeatureUnionTrained {
    fitted_transformers: Vec<(String, Box<dyn PipelineStep>)>,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

impl FeatureUnion<Untrained> {
    /// Create a new `FeatureUnion`
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Untrained,
            transformers: Vec::new(),
            n_jobs: None,
            transformer_weights: None,
            preserve_dataframe: false,
        }
    }

    /// Add a named transformer; output blocks are concatenated in insertion order
    #[must_use]
    pub fn transformer(mut self, name: &str, transformer: Box<dyn PipelineStep>) -> Self {
        self.transformers.push((name.to_string(), transformer));
        self
    }

    /// Set number of jobs (stored; `fit` and `transform` as written run sequentially)
    #[must_use]
    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
        self.n_jobs = n_jobs;
        self
    }

    /// Set transformer weights, keyed by transformer name
    #[must_use]
    pub fn transformer_weights(mut self, weights: HashMap<String, f64>) -> Self {
        self.transformer_weights = Some(weights);
        self
    }

    /// Set preserve dataframe option
    #[must_use]
    pub fn preserve_dataframe(mut self, preserve: bool) -> Self {
        self.preserve_dataframe = preserve;
        self
    }
}

impl Default for FeatureUnion<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for FeatureUnion<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for FeatureUnion<Untrained> {
    type Fitted = FeatureUnion<FeatureUnionTrained>;

    fn fit(
        self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        let mut fitted_transformers = Vec::new();

        for (name, mut transformer) in self.transformers {
            transformer.fit(x, y.as_ref().copied())?;
            fitted_transformers.push((name, transformer));
        }

        Ok(FeatureUnion {
            state: FeatureUnionTrained {
                fitted_transformers,
                n_features_in: x.ncols(),
                feature_names_in: None,
            },
            transformers: Vec::new(),
            n_jobs: self.n_jobs,
            transformer_weights: self.transformer_weights,
            preserve_dataframe: self.preserve_dataframe,
        })
    }
}

impl FeatureUnion<FeatureUnionTrained> {
    /// Transform data using all fitted transformers and concatenate results
    ///
    /// Outputs of transformers named in `transformer_weights` are scaled by their weight.
    /// With no transformers, a copy of the input is returned.
    pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        if self.state.fitted_transformers.is_empty() {
            return Ok(x.mapv(|v| v));
        }

        let mut results = Vec::new();

        for (name, transformer) in &self.state.fitted_transformers {
            let mut transformed = transformer.transform(x)?;

            // Apply weights if specified
            if let Some(ref weights) = self.transformer_weights {
                if let Some(&weight) = weights.get(name) {
                    transformed.mapv_inplace(|v| v * weight);
                }
            }

            results.push(transformed);
        }

        // Concatenate all results along the feature axis
        if results.len() == 1 {
            Ok(results.into_iter().next().unwrap_or_default())
        } else {
            let total_features: usize = results
                .iter()
                .map(scirs2_core::ndarray::ArrayBase::ncols)
                .sum();
            let n_samples = results[0].nrows();

            let mut concatenated = Array2::zeros((n_samples, total_features));
            let mut col_idx = 0;

            // Copy each block into its column range; every block must have `n_samples` rows
            for result in results {
                let end_idx = col_idx + result.ncols();
                concatenated
                    .slice_mut(s![.., col_idx..end_idx])
                    .assign(&result);
                col_idx = end_idx;
            }

            Ok(concatenated)
        }
    }

    /// Get the fitted transformers
    #[must_use]
    pub fn transformers(&self) -> &[(String, Box<dyn PipelineStep>)] {
        &self.state.fitted_transformers
    }
}

// Import ndarray slice macro
use scirs2_core::ndarray::s;

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mock_transformer() {
        let transformer = MockTransformer::new();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        // Fail loudly on error instead of comparing against a default (empty) array
        let result = crate::PipelineStep::transform(&transformer, &x.view())
            .expect("mock transformer should transform the input");
        assert_eq!(result, x.mapv(|v| v as f64));
    }

    #[test]
    fn test_mock_predictor() {
        let mut predictor = MockPredictor::new();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 2.0];

        predictor
            .fit(&x.view(), &y.view())
            .expect("mock predictor should fit");
        assert!(predictor.is_fitted());

        let predictions = predictor
            .predict(&x.view())
            .expect("fitted mock predictor should predict");
        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    fn test_feature_union() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let union = FeatureUnion::new()
            .transformer("trans1", Box::new(MockTransformer::new()))
            .transformer("trans2", Box::new(MockTransformer::with_scale(2.0)));

        let fitted_union = union
            .fit(&x.view(), &None)
            .expect("feature union should fit without a target");
        let result = fitted_union
            .transform(&x.view())
            .expect("fitted feature union should transform the input");

        // Should concatenate results from both transformers
        assert_eq!(result.ncols(), 4); // 2 features * 2 transformers
        assert_eq!(result.nrows(), 2); // Same number of samples
    }

    #[test]
    fn test_pipeline_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 2.0];

        let pipeline = Pipeline::builder()
            .step("scaler", Box::new(MockTransformer::new()))
            .estimator(Box::new(MockPredictor::new()))
            .build();

        let fitted_pipeline = pipeline
            .fit(&x.view(), &Some(&y.view()))
            .expect("pipeline should fit");
        let predictions = fitted_pipeline
            .predict(&x.view())
            .expect("fitted pipeline should predict");

        assert_eq!(predictions.len(), x.nrows());
    }
}