1use crate::StreamEvent;
19use anyhow::Result;
20use serde::{Deserialize, Serialize};
21use std::sync::Arc;
22
23use scirs2_core::ndarray_ext::{Array1, Array2};
25
/// Tuning knobs for SIMD-style batch processing of stream events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimdBatchConfig {
    /// Number of events that make up one logical batch; used by
    /// `SimdBatchProcessor::process_batch` when counting SIMD operations.
    pub batch_size: usize,
    /// Request auto-vectorization. NOTE(review): not read anywhere in this
    /// file — confirm it is consumed elsewhere.
    pub auto_vectorize: bool,
    /// Prefetch distance hint. NOTE(review): unused in this file — confirm.
    pub prefetch_distance: usize,
    /// Enable parallel processing. NOTE(review): unused in this file — confirm.
    pub enable_parallel: bool,
}
38
39impl Default for SimdBatchConfig {
40 fn default() -> Self {
41 Self {
42 batch_size: 1024, auto_vectorize: true,
44 prefetch_distance: 64,
45 enable_parallel: true,
46 }
47 }
48}
49
/// Batch processor for stream events with SIMD-style accounting.
///
/// NOTE(review): the implementations visible here are scalar; "SIMD" refers
/// to batch-granular accounting rather than explicit vector intrinsics —
/// confirm whether vectorization happens inside `scirs2_core` helpers.
pub struct SimdBatchProcessor {
    /// Batch-size and vectorization configuration.
    config: SimdBatchConfig,
    /// Running counters; see `SimdProcessorStats`.
    stats: SimdProcessorStats,
}
55
/// Running counters and timings accumulated by `SimdBatchProcessor`.
#[derive(Debug, Clone, Default)]
pub struct SimdProcessorStats {
    /// Number of calls to `process_batch` since the last reset.
    pub batches_processed: u64,
    /// Total events seen by `process_batch`, filtered or not.
    pub events_processed: u64,
    /// Count of logical SIMD operations performed across all methods.
    pub simd_operations: u64,
    /// Batch processing time in microseconds, maintained by `process_batch`
    /// as an exponential moving average.
    pub avg_batch_time_us: f64,
    /// Events/second measured during the most recent `process_batch` call.
    pub throughput_events_per_sec: f64,
}
64
65impl SimdBatchProcessor {
66 pub fn new(config: SimdBatchConfig) -> Self {
68 Self {
69 config,
70 stats: SimdProcessorStats::default(),
71 }
72 }
73
74 pub fn process_batch<F>(
76 &mut self,
77 events: &[StreamEvent],
78 processor: F,
79 ) -> Result<Vec<StreamEvent>>
80 where
81 F: Fn(&StreamEvent) -> bool + Send + Sync,
82 {
83 let start = std::time::Instant::now();
84
85 let filtered_events: Vec<StreamEvent> =
87 events.iter().filter(|e| processor(e)).cloned().collect();
88
89 let elapsed_us = start.elapsed().as_micros() as f64;
91 self.stats.batches_processed += 1;
92 self.stats.events_processed += events.len() as u64;
93 self.stats.simd_operations += (events.len() / self.config.batch_size) as u64;
94
95 let alpha = 0.1;
97 self.stats.avg_batch_time_us =
98 alpha * elapsed_us + (1.0 - alpha) * self.stats.avg_batch_time_us;
99
100 if elapsed_us > 0.0 {
102 self.stats.throughput_events_per_sec = (events.len() as f64 / elapsed_us) * 1_000_000.0;
103 }
104
105 Ok(filtered_events)
106 }
107
108 pub fn extract_numeric_batch(
110 &self,
111 events: &[StreamEvent],
112 field: &str,
113 ) -> Result<Array1<f64>> {
114 let values: Vec<f64> = events
115 .iter()
116 .filter_map(|e| self.extract_numeric_value(e, field))
117 .collect();
118
119 Ok(Array1::from_vec(values))
120 }
121
122 pub fn aggregate_batch(
124 &mut self,
125 events: &[StreamEvent],
126 field: &str,
127 ) -> Result<SimdAggregateResult> {
128 let start = std::time::Instant::now();
129
130 let values = self.extract_numeric_batch(events, field)?;
132
133 if values.is_empty() {
134 return Ok(SimdAggregateResult::default());
135 }
136
137 let sum = values.sum();
139 let mean = values.mean().unwrap_or(0.0);
140 let std_dev = values.std(0.0);
141 let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
142 let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
143
144 let elapsed_us = start.elapsed().as_micros() as f64;
146 self.stats.simd_operations += 1;
147
148 Ok(SimdAggregateResult {
149 count: values.len(),
150 sum,
151 mean,
152 std_dev,
153 min,
154 max,
155 processing_time_us: elapsed_us,
156 })
157 }
158
159 pub fn batch_pattern_match(
161 &mut self,
162 events: &[StreamEvent],
163 patterns: &[String],
164 ) -> Result<Vec<(usize, String)>> {
165 let start = std::time::Instant::now();
166 let mut matches = Vec::new();
167
168 for (idx, event) in events.iter().enumerate() {
170 for pattern in patterns {
171 if self.matches_pattern(event, pattern) {
172 matches.push((idx, pattern.clone()));
173 }
174 }
175 }
176
177 let elapsed_us = start.elapsed().as_micros() as f64;
179 self.stats.simd_operations += 1;
180 self.stats.avg_batch_time_us = elapsed_us;
181
182 Ok(matches)
183 }
184
185 pub fn correlation_matrix(
187 &mut self,
188 events: &[StreamEvent],
189 fields: &[String],
190 ) -> Result<Array2<f64>> {
191 let n_fields = fields.len();
192 let mut matrix = Array2::zeros((n_fields, n_fields));
193
194 let field_data: Vec<Array1<f64>> = fields
196 .iter()
197 .map(|field| self.extract_numeric_batch(events, field))
198 .collect::<Result<Vec<_>>>()?;
199
200 for i in 0..n_fields {
202 for j in i..n_fields {
203 let correlation = if i == j {
204 1.0
205 } else {
206 compute_simd_correlation(&field_data[i], &field_data[j])?
207 };
208
209 matrix[[i, j]] = correlation;
210 matrix[[j, i]] = correlation; }
212 }
213
214 self.stats.simd_operations += (n_fields * n_fields) as u64;
215
216 Ok(matrix)
217 }
218
219 pub fn deduplicate_batch(&mut self, events: &[StreamEvent]) -> Result<Vec<StreamEvent>> {
221 let start = std::time::Instant::now();
222
223 let mut seen = std::collections::HashSet::new();
225 let mut unique = Vec::new();
226
227 for event in events {
228 let hash = self.compute_event_hash(event);
229 if seen.insert(hash) {
230 unique.push(event.clone());
231 }
232 }
233
234 let elapsed_us = start.elapsed().as_micros() as f64;
235 self.stats.avg_batch_time_us = elapsed_us;
236 self.stats.simd_operations += 1;
237
238 Ok(unique)
239 }
240
241 pub fn moving_average(
243 &mut self,
244 events: &[StreamEvent],
245 field: &str,
246 window_size: usize,
247 ) -> Result<Array1<f64>> {
248 let values = self.extract_numeric_batch(events, field)?;
249
250 if values.len() < window_size {
251 return Ok(Array1::from_vec(vec![]));
252 }
253
254 let mut moving_avgs = Vec::new();
256
257 for i in window_size..=values.len() {
258 let window = values.slice(s![i - window_size..i]);
259 let avg = window.mean().unwrap_or(0.0);
260 moving_avgs.push(avg);
261 }
262
263 self.stats.simd_operations += 1;
264
265 Ok(Array1::from_vec(moving_avgs))
266 }
267
268 fn extract_numeric_value(&self, event: &StreamEvent, field: &str) -> Option<f64> {
270 match event {
272 StreamEvent::TripleAdded { object, .. } | StreamEvent::TripleRemoved { object, .. } => {
273 if field == "object" {
274 object.parse::<f64>().ok()
275 } else {
276 None
277 }
278 }
279 _ => None,
280 }
281 }
282
283 fn matches_pattern(&self, event: &StreamEvent, pattern: &str) -> bool {
285 match event {
287 StreamEvent::TripleAdded { subject, .. } => subject.contains(pattern),
288 StreamEvent::QuadAdded { subject, .. } => subject.contains(pattern),
289 _ => false,
290 }
291 }
292
293 fn compute_event_hash(&self, event: &StreamEvent) -> u64 {
295 use std::collections::hash_map::DefaultHasher;
296 use std::hash::{Hash, Hasher};
297
298 let mut hasher = DefaultHasher::new();
299
300 match event {
302 StreamEvent::TripleAdded {
303 subject,
304 predicate,
305 object,
306 ..
307 } => {
308 "triple_added".hash(&mut hasher);
309 subject.hash(&mut hasher);
310 predicate.hash(&mut hasher);
311 object.hash(&mut hasher);
312 }
313 StreamEvent::QuadAdded {
314 subject,
315 predicate,
316 object,
317 graph,
318 ..
319 } => {
320 "quad_added".hash(&mut hasher);
321 subject.hash(&mut hasher);
322 predicate.hash(&mut hasher);
323 object.hash(&mut hasher);
324 graph.hash(&mut hasher);
325 }
326 _ => {
327 format!("{:?}", event).hash(&mut hasher);
328 }
329 }
330
331 hasher.finish()
332 }
333
334 pub fn stats(&self) -> &SimdProcessorStats {
336 &self.stats
337 }
338
339 pub fn reset_stats(&mut self) {
341 self.stats = SimdProcessorStats::default();
342 }
343}
344
/// Aggregate statistics over one batch of numeric values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimdAggregateResult {
    /// Number of numeric values aggregated.
    pub count: usize,
    /// Sum of all values.
    pub sum: f64,
    /// Arithmetic mean.
    pub mean: f64,
    /// Population standard deviation (ddof = 0).
    pub std_dev: f64,
    /// Smallest value; `+inf` for the empty aggregate (see `Default`).
    pub min: f64,
    /// Largest value; `-inf` for the empty aggregate (see `Default`).
    pub max: f64,
    /// Wall-clock time spent aggregating, in microseconds.
    pub processing_time_us: f64,
}
356
357impl Default for SimdAggregateResult {
358 fn default() -> Self {
359 Self {
360 count: 0,
361 sum: 0.0,
362 mean: 0.0,
363 std_dev: 0.0,
364 min: f64::INFINITY,
365 max: f64::NEG_INFINITY,
366 processing_time_us: 0.0,
367 }
368 }
369}
370
371fn compute_simd_correlation(a: &Array1<f64>, b: &Array1<f64>) -> Result<f64> {
373 if a.len() != b.len() || a.len() < 2 {
374 return Ok(0.0);
375 }
376
377 let mean_a = a.mean().unwrap_or(0.0);
378 let mean_b = b.mean().unwrap_or(0.0);
379
380 let mut sum_product = 0.0;
381 let mut sum_sq_a = 0.0;
382 let mut sum_sq_b = 0.0;
383
384 for i in 0..a.len() {
386 let diff_a = a[i] - mean_a;
387 let diff_b = b[i] - mean_b;
388 sum_product += diff_a * diff_b;
389 sum_sq_a += diff_a * diff_a;
390 sum_sq_b += diff_b * diff_b;
391 }
392
393 let denominator = (sum_sq_a * sum_sq_b).sqrt();
394 if denominator == 0.0 {
395 Ok(0.0)
396 } else {
397 Ok(sum_product / denominator)
398 }
399}
400
401type EventPredicate = Arc<dyn Fn(&StreamEvent) -> bool + Send + Sync>;
403
/// Composable event filter: an event passes only if it satisfies every
/// registered predicate.
pub struct SimdEventFilter {
    /// Batch configuration. NOTE(review): not read by any method in this
    /// file — confirm it is intended for future batched evaluation.
    config: SimdBatchConfig,
    /// Conjunction of predicates applied by `filter_batch`.
    predicates: Vec<EventPredicate>,
}
409
410impl SimdEventFilter {
411 pub fn new(config: SimdBatchConfig) -> Self {
413 Self {
414 config,
415 predicates: Vec::new(),
416 }
417 }
418
419 pub fn add_predicate<F>(&mut self, predicate: F)
421 where
422 F: Fn(&StreamEvent) -> bool + Send + Sync + 'static,
423 {
424 self.predicates.push(Arc::new(predicate));
425 }
426
427 pub fn filter_batch(&self, events: &[StreamEvent]) -> Vec<StreamEvent> {
429 if self.predicates.is_empty() {
430 return events.to_vec();
431 }
432
433 events
435 .iter()
436 .filter(|event| self.predicates.iter().all(|pred| pred(event)))
437 .cloned()
438 .collect()
439 }
440}
441
442use scirs2_core::ndarray_ext::s;
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventMetadata;

    /// Builds a `TripleAdded` event with the given subject and object value.
    fn create_test_event(subject: &str, value: &str) -> StreamEvent {
        StreamEvent::TripleAdded {
            subject: subject.to_string(),
            predicate: "hasValue".to_string(),
            object: value.to_string(),
            graph: None,
            metadata: EventMetadata::default(),
        }
    }

    #[test]
    fn test_simd_batch_processor() {
        let mut processor = SimdBatchProcessor::new(SimdBatchConfig::default());

        let events: Vec<StreamEvent> = (0..1000)
            .map(|i| create_test_event(&format!("subject_{}", i), &i.to_string()))
            .collect();

        // Every event is a TripleAdded, so the filter keeps all of them.
        let result =
            processor.process_batch(&events, |e| matches!(e, StreamEvent::TripleAdded { .. }));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1000);

        let stats = processor.stats();
        assert_eq!(stats.batches_processed, 1);
        assert!(stats.throughput_events_per_sec > 0.0);
    }

    #[test]
    fn test_simd_aggregation() {
        let mut processor = SimdBatchProcessor::new(SimdBatchConfig::default());

        let events: Vec<StreamEvent> = (1..=100)
            .map(|i| create_test_event(&format!("subject_{}", i), &i.to_string()))
            .collect();

        let result = processor.aggregate_batch(&events, "object").unwrap();

        // 1 + 2 + ... + 100 = 5050; all expected values are exactly
        // representable, so exact float comparison is safe here.
        assert_eq!(result.count, 100);
        assert_eq!(result.sum, 5050.0);
        assert_eq!(result.mean, 50.5);
        assert_eq!(result.min, 1.0);
        assert_eq!(result.max, 100.0);
    }

    #[test]
    fn test_simd_deduplication() {
        let mut processor = SimdBatchProcessor::new(SimdBatchConfig::default());

        // Three copies of one event plus a single distinct event.
        let events = vec![
            create_test_event("subject_1", "10"),
            create_test_event("subject_1", "10"),
            create_test_event("subject_2", "20"),
            create_test_event("subject_1", "10"),
        ];

        let unique = processor.deduplicate_batch(&events).unwrap();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn test_simd_moving_average() {
        let mut processor = SimdBatchProcessor::new(SimdBatchConfig::default());

        let events: Vec<StreamEvent> = (1..=10)
            .map(|i| create_test_event(&format!("subject_{}", i), &i.to_string()))
            .collect();

        let moving_avg = processor.moving_average(&events, "object", 3).unwrap();

        // Windows of 3 over 10 values yield 8 averages; the first window
        // (1, 2, 3) averages to 2.0.
        assert_eq!(moving_avg.len(), 8);
        assert!((moving_avg[0] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_simd_event_filter() {
        let mut filter = SimdEventFilter::new(SimdBatchConfig::default());
        filter.add_predicate(|e| matches!(e, StreamEvent::TripleAdded { .. }));

        let events = vec![
            create_test_event("subject_1", "10"),
            create_test_event("subject_2", "20"),
        ];

        // Both events are TripleAdded, so both survive the predicate.
        assert_eq!(filter.filter_batch(&events).len(), 2);
    }
}