1use crate::traits::SimdError;
7
8#[cfg(feature = "no-std")]
9extern crate alloc;
10
11#[cfg(feature = "no-std")]
12use alloc::{
13 boxed::Box,
14 collections::BTreeMap as HashMap,
15 format,
16 string::{String, ToString},
17 vec::Vec,
18};
19#[cfg(feature = "no-std")]
20use core::any::Any;
21#[cfg(not(feature = "no-std"))]
22use std::{any::Any, collections::HashMap, string::ToString};
23
24#[cfg(feature = "no-std")]
27use alloc::sync::Arc;
28#[cfg(not(feature = "no-std"))]
29use std::sync::Arc;
30
/// Result type returned by every [`Middleware`] stage and by [`Pipeline::execute`];
/// failures are reported as [`SimdError`].
pub type MiddlewareResult<T> = Result<T, SimdError>;
33
/// Mutable state threaded through every stage of a [`Pipeline`].
///
/// `Default` yields an empty context (no samples, no metadata, no typed values),
/// equivalent to `PipelineContext::new(Vec::new())`.
#[derive(Default)]
pub struct PipelineContext {
    /// Sample buffer that middleware read and rewrite in place.
    pub data: Vec<f32>,
    /// String key/value annotations written by the pipeline and its stages
    /// (e.g. `"pipeline_name"`, `"filtered_count"`).
    pub metadata: HashMap<String, String>,
    /// Arbitrary typed values keyed by name; read back with
    /// [`PipelineContext::get_context`], which downcasts to the requested type.
    pub context: HashMap<String, Box<dyn Any + Send + Sync>>,
}
43
44impl PipelineContext {
45 pub fn new(data: Vec<f32>) -> Self {
47 Self {
48 data,
49 metadata: HashMap::new(),
50 context: HashMap::new(),
51 }
52 }
53
54 pub fn set_metadata(&mut self, key: String, value: String) {
56 self.metadata.insert(key, value);
57 }
58
59 pub fn get_metadata(&self, key: &str) -> Option<&String> {
61 self.metadata.get(key)
62 }
63
64 pub fn set_context<T: Any + Send + Sync>(&mut self, key: String, value: T) {
66 self.context.insert(key, Box::new(value));
67 }
68
69 pub fn get_context<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
71 self.context.get(key).and_then(|v| v.downcast_ref::<T>())
72 }
73}
74
75pub trait Middleware: Send + Sync {
77 fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()>;
79
80 fn name(&self) -> &str;
82
83 fn should_execute(&self, context: &PipelineContext) -> bool {
85 let _ = context; true
87 }
88}
89
/// An ordered chain of [`Middleware`] stages run by [`Pipeline::execute`].
pub struct Pipeline {
    /// Stages in execution order.
    middleware: Vec<Arc<dyn Middleware>>,
    /// Label written to the context's `"pipeline_name"` metadata on every run.
    name: String,
    /// When `true`, `execute` returns the first stage error; otherwise errors are
    /// recorded in metadata and the remaining stages still run.
    fail_fast: bool,
}
99
100impl Pipeline {
101 pub fn new(name: String) -> Self {
103 Self {
104 middleware: Vec::new(),
105 name,
106 fail_fast: true,
107 }
108 }
109
110 pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
112 self.middleware.push(Arc::new(middleware));
113 self
114 }
115
116 pub fn fail_fast(mut self, fail_fast: bool) -> Self {
118 self.fail_fast = fail_fast;
119 self
120 }
121
122 pub fn execute(&self, mut context: PipelineContext) -> MiddlewareResult<PipelineContext> {
124 context.set_metadata("pipeline_name".to_string(), self.name.clone());
125
126 for middleware in &self.middleware {
127 if middleware.should_execute(&context) {
128 if let Err(e) = middleware.process(&mut context) {
129 if self.fail_fast {
130 return Err(e);
131 }
132 context.set_metadata(
134 format!("error_{}", middleware.name()),
135 format!("Error: {}", e),
136 );
137 }
138 }
139 }
140
141 Ok(context)
142 }
143
144 pub fn name(&self) -> &str {
146 &self.name
147 }
148
149 pub fn middleware_count(&self) -> usize {
151 self.middleware.len()
152 }
153}
154
/// Step-by-step builder that produces a [`Pipeline`] via [`PipelineBuilder::build`].
pub struct PipelineBuilder {
    /// Label handed to the built pipeline.
    name: String,
    /// Stages collected so far, in the order they will run.
    middleware: Vec<Arc<dyn Middleware>>,
    /// Fail-fast setting handed to the built pipeline.
    fail_fast: bool,
}
161
162impl PipelineBuilder {
163 pub fn new(name: String) -> Self {
165 Self {
166 name,
167 middleware: Vec::new(),
168 fail_fast: true,
169 }
170 }
171
172 pub fn with_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
174 self.middleware.push(Arc::new(middleware));
175 self
176 }
177
178 pub fn fail_fast(mut self, fail_fast: bool) -> Self {
180 self.fail_fast = fail_fast;
181 self
182 }
183
184 pub fn build(self) -> Pipeline {
186 Pipeline {
187 middleware: self.middleware,
188 name: self.name,
189 fail_fast: self.fail_fast,
190 }
191 }
192}
193
/// Middleware that rescales [`PipelineContext::data`] in place using a [`NormType`].
///
/// Derives `Copy`/`PartialEq`/`Eq`/`Hash` so stage configurations can be compared and
/// duplicated freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NormalizationMiddleware {
    /// Which normalization scheme `process` applies.
    norm_type: NormType,
}

/// Normalization schemes supported by [`NormalizationMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormType {
    /// Divide every sample by the sum of absolute values.
    L1,
    /// Divide every sample by the Euclidean norm.
    L2,
    /// Map the observed minimum..maximum range onto 0..1.
    MinMax,
}

impl NormalizationMiddleware {
    /// Creates a normalization stage using `norm_type`.
    pub fn new(norm_type: NormType) -> Self {
        Self { norm_type }
    }
}
214
215impl Middleware for NormalizationMiddleware {
216 fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
217 let data = &mut context.data;
218
219 match self.norm_type {
220 NormType::L1 => {
221 let sum: f32 = data.iter().map(|x| x.abs()).sum();
222 if sum != 0.0 {
223 data.iter_mut().for_each(|x| *x /= sum);
224 }
225 }
226 NormType::L2 => {
227 let norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
228 if norm != 0.0 {
229 data.iter_mut().for_each(|x| *x /= norm);
230 }
231 }
232 NormType::MinMax => {
233 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
234 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
235 let range = max_val - min_val;
236 if range != 0.0 {
237 data.iter_mut().for_each(|x| *x = (*x - min_val) / range);
238 }
239 }
240 }
241
242 context.set_metadata("normalized".to_string(), format!("{:?}", self.norm_type));
243 Ok(())
244 }
245
246 fn name(&self) -> &str {
247 "normalization"
248 }
249}
250
/// Middleware that keeps only samples within an inclusive threshold window.
///
/// Derives `Copy`/`PartialEq` so stage configurations can be compared and duplicated
/// (no `Eq`/`Hash`: the thresholds are `f32`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FilteringMiddleware {
    /// Smallest sample value that is kept (inclusive).
    min_threshold: f32,
    /// Largest sample value that is kept (inclusive).
    max_threshold: f32,
}

impl FilteringMiddleware {
    /// Creates a filter that keeps samples `x` with `min_threshold <= x <= max_threshold`.
    ///
    /// No validation is done: if `min_threshold > max_threshold`, every sample is dropped.
    pub fn new(min_threshold: f32, max_threshold: f32) -> Self {
        Self {
            min_threshold,
            max_threshold,
        }
    }
}
268
269impl Middleware for FilteringMiddleware {
270 fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
271 let original_len = context.data.len();
272 context
273 .data
274 .retain(|&x| x >= self.min_threshold && x <= self.max_threshold);
275
276 let filtered_count = original_len - context.data.len();
277 context.set_metadata("filtered_count".to_string(), filtered_count.to_string());
278
279 Ok(())
280 }
281
282 fn name(&self) -> &str {
283 "filtering"
284 }
285}
286
/// Middleware that applies an element-wise [`TransformType`] to every sample.
///
/// Derives `Copy`/`PartialEq`/`Eq`/`Hash` so stage configurations can be compared and
/// duplicated freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransformationMiddleware {
    /// Which element-wise function `process` applies.
    transform_type: TransformType,
}

/// Element-wise functions supported by [`TransformationMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Natural logarithm; inputs below `f32::EPSILON` are first raised to `f32::EPSILON`.
    Log,
    /// Natural exponential `e^x`.
    Exp,
    /// Square root; negative inputs are first clamped to `0.0`.
    Sqrt,
    /// `x * x`.
    Square,
    /// Absolute value.
    Abs,
}

impl TransformationMiddleware {
    /// Creates a transformation stage applying `transform_type`.
    pub fn new(transform_type: TransformType) -> Self {
        Self { transform_type }
    }
}
308
309impl Middleware for TransformationMiddleware {
310 fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
311 let data = &mut context.data;
312
313 match self.transform_type {
314 TransformType::Log => {
315 data.iter_mut().for_each(|x| *x = x.max(f32::EPSILON).ln());
316 }
317 TransformType::Exp => {
318 data.iter_mut().for_each(|x| *x = x.exp());
319 }
320 TransformType::Sqrt => {
321 data.iter_mut().for_each(|x| *x = x.max(0.0).sqrt());
322 }
323 TransformType::Square => {
324 data.iter_mut().for_each(|x| *x = *x * *x);
325 }
326 TransformType::Abs => {
327 data.iter_mut().for_each(|x| *x = x.abs());
328 }
329 }
330
331 context.set_metadata(
332 "transformed".to_string(),
333 format!("{:?}", self.transform_type),
334 );
335 Ok(())
336 }
337
338 fn name(&self) -> &str {
339 "transformation"
340 }
341}
342
/// Middleware that reduces the samples to one statistic and records it as metadata.
///
/// Derives `Copy`/`PartialEq`/`Eq`/`Hash` so stage configurations can be compared and
/// duplicated freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AggregationMiddleware {
    /// Which statistic `process` computes.
    agg_type: AggregationType,
}

/// Statistics supported by [`AggregationMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationType {
    /// Sum of all samples.
    Sum,
    /// Arithmetic mean.
    Mean,
    /// Largest sample.
    Max,
    /// Smallest sample.
    Min,
    /// Population standard deviation (squared deviations divided by the sample count).
    StdDev,
}

impl AggregationMiddleware {
    /// Creates an aggregation stage computing `agg_type`.
    pub fn new(agg_type: AggregationType) -> Self {
        Self { agg_type }
    }
}
364
365impl Middleware for AggregationMiddleware {
366 fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
367 let data = &context.data;
368
369 if data.is_empty() {
370 return Err(SimdError::InvalidInput(
371 "Empty data for aggregation".to_string(),
372 ));
373 }
374
375 let result = match self.agg_type {
376 AggregationType::Sum => data.iter().sum::<f32>(),
377 AggregationType::Mean => data.iter().sum::<f32>() / data.len() as f32,
378 AggregationType::Max => data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)),
379 AggregationType::Min => data.iter().fold(f32::INFINITY, |a, &b| a.min(b)),
380 AggregationType::StdDev => {
381 let mean = data.iter().sum::<f32>() / data.len() as f32;
382 let variance =
383 data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
384 variance.sqrt()
385 }
386 };
387
388 context.set_metadata("aggregation_result".to_string(), result.to_string());
389 context.set_metadata(
390 "aggregation_type".to_string(),
391 format!("{:?}", self.agg_type),
392 );
393
394 Ok(())
395 }
396
397 fn name(&self) -> &str {
398 "aggregation"
399 }
400}
401
/// Wrapper that runs another [`Middleware`] only when a predicate on the context holds.
pub struct ConditionalMiddleware {
    /// Predicate evaluated against the current context to decide whether to run.
    condition: Box<dyn Fn(&PipelineContext) -> bool + Send + Sync>,
    /// The wrapped stage.
    middleware: Arc<dyn Middleware>,
}
409
410impl ConditionalMiddleware {
411 pub fn new<F, M>(condition: F, middleware: M) -> Self
412 where
413 F: Fn(&PipelineContext) -> bool + Send + Sync + 'static,
414 M: Middleware + 'static,
415 {
416 Self {
417 condition: Box::new(condition),
418 middleware: Arc::new(middleware),
419 }
420 }
421}
422
423impl Middleware for ConditionalMiddleware {
424 fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
425 if (self.condition)(context) {
426 self.middleware.process(context)
427 } else {
428 Ok(())
429 }
430 }
431
432 fn name(&self) -> &str {
433 "conditional"
434 }
435
436 fn should_execute(&self, context: &PipelineContext) -> bool {
437 (self.condition)(context)
438 }
439}
440
// The module is compiled only with `std` available (see the `cfg` below), so the std
// prelude already supplies `vec!`, `String`, `ToString` and `Vec`.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_context_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let context = PipelineContext::new(data.clone());

        assert_eq!(context.data, data);
        assert!(context.metadata.is_empty());
        assert!(context.context.is_empty());
    }

    #[test]
    fn test_pipeline_context_metadata() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0]);

        context.set_metadata("test_key".to_string(), "test_value".to_string());
        assert_eq!(
            context.get_metadata("test_key"),
            Some(&"test_value".to_string())
        );
        assert_eq!(context.get_metadata("nonexistent"), None);
    }

    #[test]
    fn test_pipeline_context_context_data() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0]);

        context.set_context("test_int".to_string(), 42i32);
        assert_eq!(context.get_context::<i32>("test_int"), Some(&42i32));
        assert_eq!(context.get_context::<f32>("test_int"), None);
    }

    #[test]
    fn test_normalization_middleware_l1() {
        let mut context = PipelineContext::new(vec![1.0, -3.0]);
        NormalizationMiddleware::new(NormType::L1)
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data, vec![0.25, -0.75]);
        assert_eq!(context.get_metadata("normalized"), Some(&"L1".to_string()));
    }

    #[test]
    fn test_normalization_middleware_l2() {
        let mut context = PipelineContext::new(vec![3.0, 4.0, 0.0]);
        let middleware = NormalizationMiddleware::new(NormType::L2);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        assert!((context.data[0] - 0.6).abs() < 1e-6);
        assert!((context.data[1] - 0.8).abs() < 1e-6);
        assert!((context.data[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalization_middleware_min_max() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0]);
        NormalizationMiddleware::new(NormType::MinMax)
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data, vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_normalization_leaves_zero_vector_untouched() {
        let mut context = PipelineContext::new(vec![0.0, 0.0]);
        NormalizationMiddleware::new(NormType::L2)
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data, vec![0.0, 0.0]);
    }

    #[test]
    fn test_filtering_middleware() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let middleware = FilteringMiddleware::new(2.0, 4.0);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data, vec![2.0, 3.0, 4.0]);
        assert_eq!(
            context.get_metadata("filtered_count"),
            Some(&"2".to_string())
        );
    }

    #[test]
    fn test_transformation_middleware_sqrt() {
        let mut context = PipelineContext::new(vec![1.0, 4.0, 9.0, 16.0]);
        let middleware = TransformationMiddleware::new(TransformType::Sqrt);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_transformation_middleware_log_clamps_non_positive() {
        let mut context = PipelineContext::new(vec![1.0, 0.0, -5.0]);
        TransformationMiddleware::new(TransformType::Log)
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data[0], 0.0);
        assert_eq!(context.data[1], f32::EPSILON.ln());
        assert_eq!(context.data[2], f32::EPSILON.ln());
        assert_eq!(
            context.get_metadata("transformed"),
            Some(&"Log".to_string())
        );
    }

    #[test]
    fn test_aggregation_middleware_mean() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0, 4.0]);
        let middleware = AggregationMiddleware::new(AggregationType::Mean);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(
            context.get_metadata("aggregation_result"),
            Some(&"2.5".to_string())
        );
        assert_eq!(
            context.get_metadata("aggregation_type"),
            Some(&"Mean".to_string())
        );
    }

    #[test]
    fn test_aggregation_middleware_std_dev() {
        let mut context = PipelineContext::new(vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        AggregationMiddleware::new(AggregationType::StdDev)
            .process(&mut context)
            .expect("operation should succeed");

        // Population standard deviation: mean 5, variance 32 / 8 = 4.
        assert_eq!(
            context.get_metadata("aggregation_result"),
            Some(&"2".to_string())
        );
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline = PipelineBuilder::new("test_pipeline".to_string())
            .with_middleware(NormalizationMiddleware::new(NormType::L2))
            .with_middleware(FilteringMiddleware::new(0.1, 0.9))
            .fail_fast(false)
            .build();

        assert_eq!(pipeline.name(), "test_pipeline");
        assert_eq!(pipeline.middleware_count(), 2);
    }

    #[test]
    fn test_pipeline_execution() {
        let pipeline = Pipeline::new("test_pipeline".to_string())
            .add_middleware(NormalizationMiddleware::new(NormType::L2))
            .add_middleware(TransformationMiddleware::new(TransformType::Square));

        let context = PipelineContext::new(vec![3.0, 4.0, 0.0]);
        let result = pipeline.execute(context).expect("operation should succeed");

        assert!((result.data[0] - 0.36).abs() < 1e-6);
        assert!((result.data[1] - 0.64).abs() < 1e-6);
        assert!((result.data[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pipeline_fail_fast_returns_error() {
        let pipeline = Pipeline::new("strict".to_string())
            .add_middleware(AggregationMiddleware::new(AggregationType::Sum));

        assert!(pipeline.execute(PipelineContext::new(vec![])).is_err());
    }

    #[test]
    fn test_pipeline_records_errors_without_fail_fast() {
        let pipeline = Pipeline::new("lenient".to_string())
            .fail_fast(false)
            .add_middleware(AggregationMiddleware::new(AggregationType::Sum))
            .add_middleware(TransformationMiddleware::new(TransformType::Abs));

        let result = pipeline
            .execute(PipelineContext::new(vec![]))
            .expect("errors should be recorded, not returned");

        let recorded = result
            .get_metadata("error_aggregation")
            .expect("the aggregation error should be recorded");
        assert!(recorded.starts_with("Error: "));
        // Later stages still run after a recorded error.
        assert_eq!(result.get_metadata("transformed"), Some(&"Abs".to_string()));
    }

    #[test]
    fn test_conditional_middleware() {
        let condition = |context: &PipelineContext| context.data.len() > 2;
        let middleware =
            ConditionalMiddleware::new(condition, NormalizationMiddleware::new(NormType::L2));

        let mut context = PipelineContext::new(vec![3.0, 4.0, 0.0]);
        middleware
            .process(&mut context)
            .expect("operation should succeed");
        assert!((context.data[0] - 0.6).abs() < 1e-6);

        let mut context = PipelineContext::new(vec![3.0, 4.0]);
        let original_data = context.data.clone();
        middleware
            .process(&mut context)
            .expect("operation should succeed");
        assert_eq!(context.data, original_data);
    }

    #[test]
    fn test_pipeline_skips_conditional_stage_when_condition_fails() {
        let pipeline = Pipeline::new("gated".to_string()).add_middleware(
            ConditionalMiddleware::new(
                |context: &PipelineContext| context.data.len() > 10,
                NormalizationMiddleware::new(NormType::L2),
            ),
        );

        let result = pipeline
            .execute(PipelineContext::new(vec![3.0, 4.0]))
            .expect("operation should succeed");

        assert_eq!(result.data, vec![3.0, 4.0]);
        assert_eq!(result.get_metadata("normalized"), None);
    }

    #[test]
    fn test_empty_data_handling() {
        let mut context = PipelineContext::new(vec![]);
        let middleware = AggregationMiddleware::new(AggregationType::Mean);

        let result = middleware.process(&mut context);
        assert!(result.is_err());
    }

    #[test]
    fn test_pipeline_metadata() {
        let pipeline = Pipeline::new("test_pipeline".to_string())
            .add_middleware(NormalizationMiddleware::new(NormType::L2));

        let context = PipelineContext::new(vec![1.0, 2.0, 3.0]);
        let result = pipeline.execute(context).expect("operation should succeed");

        assert_eq!(
            result.get_metadata("pipeline_name"),
            Some(&"test_pipeline".to_string())
        );
        assert_eq!(result.get_metadata("normalized"), Some(&"L2".to_string()));
    }
}