kizzasi_io/
multiplex.rs

1//! Stream multiplexing and demultiplexing utilities
2//!
3//! This module provides tools for combining multiple streams into one
4//! and splitting one stream into multiple outputs.
5
6use crate::error::{IoError, IoResult};
7use crate::sync::{Timestamp, TimestampedSample};
8use std::collections::{HashMap, VecDeque};
9use tokio::sync::mpsc;
10
/// How a multiplexer chooses which input stream supplies the next sample.
///
/// [`AsyncMultiplexer`] currently serves every strategy as round-robin.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MultiplexStrategy {
    /// Take one sample from each stream in turn, skipping streams with nothing buffered.
    RoundRobin,
    /// Always emit the buffered sample with the smallest timestamp across all streams.
    TimeOrdered,
    /// Read streams in registration order; a later stream is only read once every
    /// earlier stream's buffer is empty.
    Sequential,
    /// Round-robin driven by the per-stream weights in [`MultiplexConfig::weights`].
    Weighted,
}
23
24/// Multiplexer configuration
25#[derive(Debug, Clone)]
26pub struct MultiplexConfig {
27    /// Multiplexing strategy
28    pub strategy: MultiplexStrategy,
29    /// Buffer size per input stream
30    pub buffer_size: usize,
31    /// Weights for weighted round-robin (stream_id -> weight)
32    pub weights: HashMap<String, u32>,
33}
34
35impl Default for MultiplexConfig {
36    fn default() -> Self {
37        Self {
38            strategy: MultiplexStrategy::RoundRobin,
39            buffer_size: 1024,
40            weights: HashMap::new(),
41        }
42    }
43}
44
45/// Stream multiplexer - combines multiple streams into one
46pub struct StreamMultiplexer {
47    config: MultiplexConfig,
48    buffers: HashMap<String, VecDeque<TimestampedSample>>,
49    round_robin_index: usize,
50    stream_ids: Vec<String>,
51}
52
53impl StreamMultiplexer {
54    /// Create new multiplexer
55    pub fn new(config: MultiplexConfig) -> Self {
56        Self {
57            config,
58            buffers: HashMap::new(),
59            round_robin_index: 0,
60            stream_ids: Vec::new(),
61        }
62    }
63
64    /// Add input stream
65    pub fn add_stream(&mut self, stream_id: String) {
66        if !self.buffers.contains_key(&stream_id) {
67            self.buffers.insert(
68                stream_id.clone(),
69                VecDeque::with_capacity(self.config.buffer_size),
70            );
71            self.stream_ids.push(stream_id);
72        }
73    }
74
75    /// Push sample from a stream
76    pub fn push(&mut self, sample: TimestampedSample) -> IoResult<()> {
77        let buffer = self.buffers.get_mut(&sample.stream_id).ok_or_else(|| {
78            IoError::InvalidConfig(format!("Unknown stream: {}", sample.stream_id))
79        })?;
80
81        if buffer.len() >= self.config.buffer_size {
82            buffer.pop_front(); // Remove oldest
83        }
84
85        // For TimeOrdered strategy, insert in sorted order by timestamp
86        if self.config.strategy == MultiplexStrategy::TimeOrdered {
87            // Find insertion position to maintain ascending timestamp order
88            let pos = buffer
89                .iter()
90                .position(|s| s.timestamp > sample.timestamp)
91                .unwrap_or(buffer.len());
92            buffer.insert(pos, sample);
93        } else {
94            buffer.push_back(sample);
95        }
96        Ok(())
97    }
98
99    /// Get next multiplexed sample
100    pub fn next_sample(&mut self) -> Option<TimestampedSample> {
101        match self.config.strategy {
102            MultiplexStrategy::RoundRobin => self.next_round_robin(),
103            MultiplexStrategy::TimeOrdered => self.next_time_ordered(),
104            MultiplexStrategy::Sequential => self.next_sequential(),
105            MultiplexStrategy::Weighted => self.next_weighted(),
106        }
107    }
108
109    fn next_round_robin(&mut self) -> Option<TimestampedSample> {
110        if self.stream_ids.is_empty() {
111            return None;
112        }
113
114        let start_index = self.round_robin_index;
115        loop {
116            let stream_id = &self.stream_ids[self.round_robin_index];
117            self.round_robin_index = (self.round_robin_index + 1) % self.stream_ids.len();
118
119            if let Some(buffer) = self.buffers.get_mut(stream_id) {
120                if let Some(sample) = buffer.pop_front() {
121                    return Some(sample);
122                }
123            }
124
125            // Check if we've looped back without finding anything
126            if self.round_robin_index == start_index {
127                break;
128            }
129        }
130
131        None
132    }
133
134    fn next_time_ordered(&mut self) -> Option<TimestampedSample> {
135        let mut earliest: Option<(String, Timestamp)> = None;
136
137        // Find stream with earliest timestamp
138        for (stream_id, buffer) in &self.buffers {
139            if let Some(sample) = buffer.front() {
140                match earliest {
141                    None => earliest = Some((stream_id.clone(), sample.timestamp)),
142                    Some((_, min_ts)) if sample.timestamp < min_ts => {
143                        earliest = Some((stream_id.clone(), sample.timestamp));
144                    }
145                    _ => {}
146                }
147            }
148        }
149
150        // Pop from that stream
151        earliest.and_then(|(stream_id, _)| self.buffers.get_mut(&stream_id)?.pop_front())
152    }
153
154    fn next_sequential(&mut self) -> Option<TimestampedSample> {
155        // Drain first stream completely before moving to next
156        for stream_id in &self.stream_ids {
157            if let Some(buffer) = self.buffers.get_mut(stream_id) {
158                if let Some(sample) = buffer.pop_front() {
159                    return Some(sample);
160                }
161            }
162        }
163        None
164    }
165
166    fn next_weighted(&mut self) -> Option<TimestampedSample> {
167        // Weighted round-robin based on configured weights
168        if self.stream_ids.is_empty() {
169            return None;
170        }
171
172        let total_weight: u32 = self
173            .stream_ids
174            .iter()
175            .map(|id| self.config.weights.get(id).copied().unwrap_or(1))
176            .sum();
177
178        if total_weight == 0 {
179            return self.next_round_robin();
180        }
181
182        // Try streams proportional to their weights
183        for _ in 0..total_weight {
184            let stream_id = &self.stream_ids[self.round_robin_index % self.stream_ids.len()];
185            let weight = self.config.weights.get(stream_id).copied().unwrap_or(1);
186
187            self.round_robin_index += 1;
188
189            if let Some(buffer) = self.buffers.get_mut(stream_id) {
190                if !buffer.is_empty() && weight > 0 {
191                    return buffer.pop_front();
192                }
193            }
194        }
195
196        None
197    }
198
199    /// Get number of buffered samples for a stream
200    pub fn buffered(&self, stream_id: &str) -> usize {
201        self.buffers.get(stream_id).map(|b| b.len()).unwrap_or(0)
202    }
203
204    /// Get total buffered samples across all streams
205    pub fn total_buffered(&self) -> usize {
206        self.buffers.values().map(|b| b.len()).sum()
207    }
208
209    /// Clear all buffers
210    pub fn clear(&mut self) {
211        for buffer in self.buffers.values_mut() {
212            buffer.clear();
213        }
214    }
215}
216
217impl Default for StreamMultiplexer {
218    fn default() -> Self {
219        Self::new(MultiplexConfig::default())
220    }
221}
222
223/// Stream demultiplexer - splits one stream into multiple outputs
224pub struct StreamDemultiplexer<F>
225where
226    F: Fn(&TimestampedSample) -> String,
227{
228    router: F,
229    buffers: HashMap<String, VecDeque<TimestampedSample>>,
230    buffer_size: usize,
231}
232
233impl<F> StreamDemultiplexer<F>
234where
235    F: Fn(&TimestampedSample) -> String,
236{
237    /// Create new demultiplexer with routing function
238    pub fn new(router: F, buffer_size: usize) -> Self {
239        Self {
240            router,
241            buffers: HashMap::new(),
242            buffer_size,
243        }
244    }
245
246    /// Push sample (will be routed to appropriate output)
247    pub fn push(&mut self, sample: TimestampedSample) {
248        let output_id = (self.router)(&sample);
249
250        let buffer = self
251            .buffers
252            .entry(output_id)
253            .or_insert_with(|| VecDeque::with_capacity(self.buffer_size));
254
255        if buffer.len() >= self.buffer_size {
256            buffer.pop_front();
257        }
258
259        buffer.push_back(sample);
260    }
261
262    /// Get samples from specific output
263    pub fn pop(&mut self, output_id: &str) -> Option<TimestampedSample> {
264        self.buffers.get_mut(output_id)?.pop_front()
265    }
266
267    /// Get all samples from specific output
268    pub fn drain(&mut self, output_id: &str) -> Vec<TimestampedSample> {
269        self.buffers
270            .get_mut(output_id)
271            .map(|b| b.drain(..).collect())
272            .unwrap_or_default()
273    }
274
275    /// Get number of buffered samples for output
276    pub fn buffered(&self, output_id: &str) -> usize {
277        self.buffers.get(output_id).map(|b| b.len()).unwrap_or(0)
278    }
279
280    /// Get all output IDs
281    pub fn output_ids(&self) -> Vec<String> {
282        self.buffers.keys().cloned().collect()
283    }
284
285    /// Clear all buffers
286    pub fn clear(&mut self) {
287        self.buffers.clear();
288    }
289}
290
291/// Async multiplexer using channels
292pub struct AsyncMultiplexer {
293    receivers: Vec<mpsc::Receiver<TimestampedSample>>,
294    strategy: MultiplexStrategy,
295}
296
297impl AsyncMultiplexer {
298    /// Create new async multiplexer
299    pub fn new(strategy: MultiplexStrategy) -> Self {
300        Self {
301            receivers: Vec::new(),
302            strategy,
303        }
304    }
305
306    /// Add input channel
307    pub fn add_receiver(&mut self, receiver: mpsc::Receiver<TimestampedSample>) {
308        self.receivers.push(receiver);
309    }
310
311    /// Get next multiplexed sample (async)
312    pub async fn next(&mut self) -> Option<TimestampedSample> {
313        match self.strategy {
314            MultiplexStrategy::RoundRobin => self.next_round_robin().await,
315            MultiplexStrategy::TimeOrdered => {
316                // Time-ordered requires buffering, not efficient for async
317                self.next_round_robin().await
318            }
319            _ => self.next_round_robin().await,
320        }
321    }
322
323    async fn next_round_robin(&mut self) -> Option<TimestampedSample> {
324        for receiver in &mut self.receivers {
325            if let Ok(sample) = receiver.try_recv() {
326                return Some(sample);
327            }
328        }
329        None
330    }
331
332    /// Get number of input streams
333    pub fn num_streams(&self) -> usize {
334        self.receivers.len()
335    }
336}
337
338/// Channel-based stream splitter
339pub struct ChannelSplitter {
340    senders: HashMap<String, mpsc::Sender<TimestampedSample>>,
341}
342
343impl ChannelSplitter {
344    /// Create new channel splitter
345    pub fn new() -> Self {
346        Self {
347            senders: HashMap::new(),
348        }
349    }
350
351    /// Add output channel
352    pub fn add_output(&mut self, output_id: String, sender: mpsc::Sender<TimestampedSample>) {
353        self.senders.insert(output_id, sender);
354    }
355
356    /// Send sample to specific output
357    pub async fn send(&self, output_id: &str, sample: TimestampedSample) -> IoResult<()> {
358        let sender = self
359            .senders
360            .get(output_id)
361            .ok_or_else(|| IoError::InvalidConfig(format!("Unknown output: {}", output_id)))?;
362
363        sender
364            .send(sample)
365            .await
366            .map_err(|_| IoError::SendFailed("Channel send failed".to_string()))
367    }
368
369    /// Broadcast sample to all outputs
370    pub async fn broadcast(&self, sample: TimestampedSample) -> IoResult<()> {
371        for sender in self.senders.values() {
372            sender
373                .send(sample.clone())
374                .await
375                .map_err(|_| IoError::SendFailed("Broadcast failed".to_string()))?;
376        }
377        Ok(())
378    }
379
380    /// Get number of outputs
381    pub fn num_outputs(&self) -> usize {
382        self.senders.len()
383    }
384}
385
386impl Default for ChannelSplitter {
387    fn default() -> Self {
388        Self::new()
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Build a one-element sample tagged with `stream_id`.
    fn sample(timestamp: Timestamp, value: f32, stream_id: &str) -> TimestampedSample {
        TimestampedSample::new(timestamp, Array1::from_vec(vec![value]), stream_id.to_string())
    }

    #[test]
    fn test_multiplexer_round_robin() {
        let mut mux = StreamMultiplexer::default();
        for id in ["stream1", "stream2"] {
            mux.add_stream(id.to_string());
        }

        for i in 0..5 {
            mux.push(sample(i * 1000, i as f32, "stream1")).unwrap();
            mux.push(sample(i * 1000 + 500, (i + 10) as f32, "stream2"))
                .unwrap();
        }

        // Consecutive pulls alternate between the two streams.
        assert_eq!(mux.next_sample().unwrap().stream_id, "stream1");
        assert_eq!(mux.next_sample().unwrap().stream_id, "stream2");
    }

    #[test]
    fn test_demultiplexer() {
        let router = |s: &TimestampedSample| {
            let bucket = if s.data[0] > 5.0 { "high" } else { "low" };
            bucket.to_string()
        };
        let mut demux = StreamDemultiplexer::new(router, 100);

        for (ts, value) in [(0, 3.0), (1000, 7.0), (2000, 2.0)] {
            demux.push(sample(ts, value, "test"));
        }

        assert_eq!(demux.buffered("low"), 2);
        assert_eq!(demux.buffered("high"), 1);

        let first_low = demux.pop("low").unwrap();
        assert_eq!(first_low.data[0], 3.0);
    }

    #[test]
    fn test_multiplexer_time_ordered() {
        let mut mux = StreamMultiplexer::new(MultiplexConfig {
            strategy: MultiplexStrategy::TimeOrdered,
            ..Default::default()
        });
        mux.add_stream("s1".to_string());
        mux.add_stream("s2".to_string());

        // Pushed out of order, across two streams.
        for (ts, value, id) in [(3000, 3.0, "s1"), (1000, 1.0, "s2"), (2000, 2.0, "s1")] {
            mux.push(sample(ts, value, id)).unwrap();
        }

        // Pulled back in ascending timestamp order.
        for (expected_ts, expected_value) in [(1000, 1.0), (2000, 2.0), (3000, 3.0)] {
            let next = mux.next_sample().unwrap();
            assert_eq!(next.timestamp, expected_ts);
            assert_eq!(next.data[0], expected_value);
        }
    }

    #[tokio::test]
    async fn test_channel_splitter() {
        let (tx1, mut rx1) = mpsc::channel(10);
        let (tx2, mut rx2) = mpsc::channel(10);

        let mut splitter = ChannelSplitter::new();
        splitter.add_output("out1".to_string(), tx1);
        splitter.add_output("out2".to_string(), tx2);

        // Targeted send reaches only the named output.
        let first = sample(0, 1.0, "test");
        splitter.send("out1", first.clone()).await.unwrap();
        assert_eq!(rx1.recv().await.unwrap().data[0], 1.0);

        // Broadcast reaches every output.
        splitter.broadcast(sample(1000, 2.0, "test")).await.unwrap();
        assert!(rx1.recv().await.is_some());
        assert!(rx2.recv().await.is_some());
    }
}