1use crate::error::{IoError, IoResult};
7use crate::sync::{Timestamp, TimestampedSample};
8use std::collections::{HashMap, VecDeque};
9use tokio::sync::mpsc;
10
/// Policy a multiplexer uses to decide which input stream supplies the next sample.
///
/// `StreamMultiplexer` implements every variant; `AsyncMultiplexer` only honours a subset
/// (see its `next` method).
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum MultiplexStrategy {
    /// Cycle through the registered streams in registration order.
    RoundRobin,
    /// Emit the buffered sample with the earliest timestamp across all streams.
    TimeOrdered,
    /// Always take from the first registered stream that has data buffered.
    Sequential,
    /// Select streams according to `MultiplexConfig::weights` (missing entries weigh 1,
    /// weight 0 excludes a stream unless every weight is 0).
    Weighted,
}
23
24#[derive(Debug, Clone)]
26pub struct MultiplexConfig {
27 pub strategy: MultiplexStrategy,
29 pub buffer_size: usize,
31 pub weights: HashMap<String, u32>,
33}
34
35impl Default for MultiplexConfig {
36 fn default() -> Self {
37 Self {
38 strategy: MultiplexStrategy::RoundRobin,
39 buffer_size: 1024,
40 weights: HashMap::new(),
41 }
42 }
43}
44
45pub struct StreamMultiplexer {
47 config: MultiplexConfig,
48 buffers: HashMap<String, VecDeque<TimestampedSample>>,
49 round_robin_index: usize,
50 stream_ids: Vec<String>,
51}
52
53impl StreamMultiplexer {
54 pub fn new(config: MultiplexConfig) -> Self {
56 Self {
57 config,
58 buffers: HashMap::new(),
59 round_robin_index: 0,
60 stream_ids: Vec::new(),
61 }
62 }
63
64 pub fn add_stream(&mut self, stream_id: String) {
66 if !self.buffers.contains_key(&stream_id) {
67 self.buffers.insert(
68 stream_id.clone(),
69 VecDeque::with_capacity(self.config.buffer_size),
70 );
71 self.stream_ids.push(stream_id);
72 }
73 }
74
75 pub fn push(&mut self, sample: TimestampedSample) -> IoResult<()> {
77 let buffer = self.buffers.get_mut(&sample.stream_id).ok_or_else(|| {
78 IoError::InvalidConfig(format!("Unknown stream: {}", sample.stream_id))
79 })?;
80
81 if buffer.len() >= self.config.buffer_size {
82 buffer.pop_front(); }
84
85 if self.config.strategy == MultiplexStrategy::TimeOrdered {
87 let pos = buffer
89 .iter()
90 .position(|s| s.timestamp > sample.timestamp)
91 .unwrap_or(buffer.len());
92 buffer.insert(pos, sample);
93 } else {
94 buffer.push_back(sample);
95 }
96 Ok(())
97 }
98
99 pub fn next_sample(&mut self) -> Option<TimestampedSample> {
101 match self.config.strategy {
102 MultiplexStrategy::RoundRobin => self.next_round_robin(),
103 MultiplexStrategy::TimeOrdered => self.next_time_ordered(),
104 MultiplexStrategy::Sequential => self.next_sequential(),
105 MultiplexStrategy::Weighted => self.next_weighted(),
106 }
107 }
108
109 fn next_round_robin(&mut self) -> Option<TimestampedSample> {
110 if self.stream_ids.is_empty() {
111 return None;
112 }
113
114 let start_index = self.round_robin_index;
115 loop {
116 let stream_id = &self.stream_ids[self.round_robin_index];
117 self.round_robin_index = (self.round_robin_index + 1) % self.stream_ids.len();
118
119 if let Some(buffer) = self.buffers.get_mut(stream_id) {
120 if let Some(sample) = buffer.pop_front() {
121 return Some(sample);
122 }
123 }
124
125 if self.round_robin_index == start_index {
127 break;
128 }
129 }
130
131 None
132 }
133
134 fn next_time_ordered(&mut self) -> Option<TimestampedSample> {
135 let mut earliest: Option<(String, Timestamp)> = None;
136
137 for (stream_id, buffer) in &self.buffers {
139 if let Some(sample) = buffer.front() {
140 match earliest {
141 None => earliest = Some((stream_id.clone(), sample.timestamp)),
142 Some((_, min_ts)) if sample.timestamp < min_ts => {
143 earliest = Some((stream_id.clone(), sample.timestamp));
144 }
145 _ => {}
146 }
147 }
148 }
149
150 earliest.and_then(|(stream_id, _)| self.buffers.get_mut(&stream_id)?.pop_front())
152 }
153
154 fn next_sequential(&mut self) -> Option<TimestampedSample> {
155 for stream_id in &self.stream_ids {
157 if let Some(buffer) = self.buffers.get_mut(stream_id) {
158 if let Some(sample) = buffer.pop_front() {
159 return Some(sample);
160 }
161 }
162 }
163 None
164 }
165
166 fn next_weighted(&mut self) -> Option<TimestampedSample> {
167 if self.stream_ids.is_empty() {
169 return None;
170 }
171
172 let total_weight: u32 = self
173 .stream_ids
174 .iter()
175 .map(|id| self.config.weights.get(id).copied().unwrap_or(1))
176 .sum();
177
178 if total_weight == 0 {
179 return self.next_round_robin();
180 }
181
182 for _ in 0..total_weight {
184 let stream_id = &self.stream_ids[self.round_robin_index % self.stream_ids.len()];
185 let weight = self.config.weights.get(stream_id).copied().unwrap_or(1);
186
187 self.round_robin_index += 1;
188
189 if let Some(buffer) = self.buffers.get_mut(stream_id) {
190 if !buffer.is_empty() && weight > 0 {
191 return buffer.pop_front();
192 }
193 }
194 }
195
196 None
197 }
198
199 pub fn buffered(&self, stream_id: &str) -> usize {
201 self.buffers.get(stream_id).map(|b| b.len()).unwrap_or(0)
202 }
203
204 pub fn total_buffered(&self) -> usize {
206 self.buffers.values().map(|b| b.len()).sum()
207 }
208
209 pub fn clear(&mut self) {
211 for buffer in self.buffers.values_mut() {
212 buffer.clear();
213 }
214 }
215}
216
217impl Default for StreamMultiplexer {
218 fn default() -> Self {
219 Self::new(MultiplexConfig::default())
220 }
221}
222
223pub struct StreamDemultiplexer<F>
225where
226 F: Fn(&TimestampedSample) -> String,
227{
228 router: F,
229 buffers: HashMap<String, VecDeque<TimestampedSample>>,
230 buffer_size: usize,
231}
232
233impl<F> StreamDemultiplexer<F>
234where
235 F: Fn(&TimestampedSample) -> String,
236{
237 pub fn new(router: F, buffer_size: usize) -> Self {
239 Self {
240 router,
241 buffers: HashMap::new(),
242 buffer_size,
243 }
244 }
245
246 pub fn push(&mut self, sample: TimestampedSample) {
248 let output_id = (self.router)(&sample);
249
250 let buffer = self
251 .buffers
252 .entry(output_id)
253 .or_insert_with(|| VecDeque::with_capacity(self.buffer_size));
254
255 if buffer.len() >= self.buffer_size {
256 buffer.pop_front();
257 }
258
259 buffer.push_back(sample);
260 }
261
262 pub fn pop(&mut self, output_id: &str) -> Option<TimestampedSample> {
264 self.buffers.get_mut(output_id)?.pop_front()
265 }
266
267 pub fn drain(&mut self, output_id: &str) -> Vec<TimestampedSample> {
269 self.buffers
270 .get_mut(output_id)
271 .map(|b| b.drain(..).collect())
272 .unwrap_or_default()
273 }
274
275 pub fn buffered(&self, output_id: &str) -> usize {
277 self.buffers.get(output_id).map(|b| b.len()).unwrap_or(0)
278 }
279
280 pub fn output_ids(&self) -> Vec<String> {
282 self.buffers.keys().cloned().collect()
283 }
284
285 pub fn clear(&mut self) {
287 self.buffers.clear();
288 }
289}
290
291pub struct AsyncMultiplexer {
293 receivers: Vec<mpsc::Receiver<TimestampedSample>>,
294 strategy: MultiplexStrategy,
295}
296
297impl AsyncMultiplexer {
298 pub fn new(strategy: MultiplexStrategy) -> Self {
300 Self {
301 receivers: Vec::new(),
302 strategy,
303 }
304 }
305
306 pub fn add_receiver(&mut self, receiver: mpsc::Receiver<TimestampedSample>) {
308 self.receivers.push(receiver);
309 }
310
311 pub async fn next(&mut self) -> Option<TimestampedSample> {
313 match self.strategy {
314 MultiplexStrategy::RoundRobin => self.next_round_robin().await,
315 MultiplexStrategy::TimeOrdered => {
316 self.next_round_robin().await
318 }
319 _ => self.next_round_robin().await,
320 }
321 }
322
323 async fn next_round_robin(&mut self) -> Option<TimestampedSample> {
324 for receiver in &mut self.receivers {
325 if let Ok(sample) = receiver.try_recv() {
326 return Some(sample);
327 }
328 }
329 None
330 }
331
332 pub fn num_streams(&self) -> usize {
334 self.receivers.len()
335 }
336}
337
338pub struct ChannelSplitter {
340 senders: HashMap<String, mpsc::Sender<TimestampedSample>>,
341}
342
343impl ChannelSplitter {
344 pub fn new() -> Self {
346 Self {
347 senders: HashMap::new(),
348 }
349 }
350
351 pub fn add_output(&mut self, output_id: String, sender: mpsc::Sender<TimestampedSample>) {
353 self.senders.insert(output_id, sender);
354 }
355
356 pub async fn send(&self, output_id: &str, sample: TimestampedSample) -> IoResult<()> {
358 let sender = self
359 .senders
360 .get(output_id)
361 .ok_or_else(|| IoError::InvalidConfig(format!("Unknown output: {}", output_id)))?;
362
363 sender
364 .send(sample)
365 .await
366 .map_err(|_| IoError::SendFailed("Channel send failed".to_string()))
367 }
368
369 pub async fn broadcast(&self, sample: TimestampedSample) -> IoResult<()> {
371 for sender in self.senders.values() {
372 sender
373 .send(sample.clone())
374 .await
375 .map_err(|_| IoError::SendFailed("Broadcast failed".to_string()))?;
376 }
377 Ok(())
378 }
379
380 pub fn num_outputs(&self) -> usize {
382 self.senders.len()
383 }
384}
385
386impl Default for ChannelSplitter {
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_multiplexer_round_robin() {
        let mut mux = StreamMultiplexer::default();
        for name in ["stream1", "stream2"] {
            mux.add_stream(name.to_string());
        }

        // Five samples per stream; stream2's timestamps sit half-way between stream1's.
        for i in 0..5 {
            let from_first = TimestampedSample::new(
                i * 1000,
                Array1::from_vec(vec![i as f32]),
                "stream1".to_string(),
            );
            let from_second = TimestampedSample::new(
                i * 1000 + 500,
                Array1::from_vec(vec![(i + 10) as f32]),
                "stream2".to_string(),
            );
            mux.push(from_first).unwrap();
            mux.push(from_second).unwrap();
        }

        // Round robin alternates between streams in registration order.
        for expected_stream in ["stream1", "stream2"] {
            let sample = mux.next_sample().unwrap();
            assert_eq!(sample.stream_id, expected_stream);
        }
    }

    #[test]
    fn test_demultiplexer() {
        let router = |sample: &TimestampedSample| {
            let label = if sample.data[0] > 5.0 { "high" } else { "low" };
            label.to_string()
        };
        let mut demux = StreamDemultiplexer::new(router, 100);

        for (timestamp, value) in [(0, 3.0), (1000, 7.0), (2000, 2.0)] {
            demux.push(TimestampedSample::new(
                timestamp,
                Array1::from_vec(vec![value]),
                "test".to_string(),
            ));
        }

        assert_eq!(demux.buffered("low"), 2);
        assert_eq!(demux.buffered("high"), 1);

        let first_low = demux.pop("low").unwrap();
        assert_eq!(first_low.data[0], 3.0);
    }

    #[test]
    fn test_multiplexer_time_ordered() {
        let mut mux = StreamMultiplexer::new(MultiplexConfig {
            strategy: MultiplexStrategy::TimeOrdered,
            ..Default::default()
        });
        mux.add_stream("s1".to_string());
        mux.add_stream("s2".to_string());

        // Pushed out of timestamp order and interleaved across streams.
        for (timestamp, value, stream) in [(3000, 3.0, "s1"), (1000, 1.0, "s2"), (2000, 2.0, "s1")] {
            mux.push(TimestampedSample::new(
                timestamp,
                Array1::from_vec(vec![value]),
                stream.to_string(),
            ))
            .unwrap();
        }

        // Output must be globally sorted by timestamp.
        for (expected_ts, expected_value) in [(1000, 1.0), (2000, 2.0), (3000, 3.0)] {
            let sample = mux.next_sample().unwrap();
            assert_eq!(sample.timestamp, expected_ts);
            assert_eq!(sample.data[0], expected_value);
        }
    }

    #[tokio::test]
    async fn test_channel_splitter() {
        let (tx_a, mut rx_a) = mpsc::channel(10);
        let (tx_b, mut rx_b) = mpsc::channel(10);

        let mut splitter = ChannelSplitter::new();
        splitter.add_output("out1".to_string(), tx_a);
        splitter.add_output("out2".to_string(), tx_b);

        // A targeted send is delivered to the named output.
        let single = TimestampedSample::new(0, Array1::from_vec(vec![1.0]), "test".to_string());
        splitter.send("out1", single).await.unwrap();
        let received = rx_a.recv().await.unwrap();
        assert_eq!(received.data[0], 1.0);

        // A broadcast is delivered to every output.
        let shared = TimestampedSample::new(1000, Array1::from_vec(vec![2.0]), "test".to_string());
        splitter.broadcast(shared).await.unwrap();
        assert!(rx_a.recv().await.is_some());
        assert!(rx_b.recv().await.is_some());
    }
}