disk_mpmc/
lib.rs

1use std::{marker::PhantomData, sync::Arc, time::Duration};
2
3use mmapcell::MmapCell;
4
5mod datapage;
6pub mod manager;
7
8use datapage::DataPage;
9use manager::DataPagesManager;
10
/// Marker type selecting the grouped read mode of [`Receiver`].
///
/// A grouped receiver takes each read index from a per-group counter kept on the
/// current data page (`DataPage::increment_group_count`), so receivers sharing a
/// group presumably split the message stream between them.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Grouped;
13
/// Marker type selecting the anonymous read mode of [`Receiver`].
///
/// An anonymous receiver keeps its own private read index instead of using a
/// consumer-group counter.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Anonymous;
16
17#[derive(Clone)]
18pub struct Receiver<T> {
19    group: usize,
20    anon_count: u32,
21    manager: DataPagesManager,
22    datapage_count: usize,
23    datapage: Arc<MmapCell<DataPage>>,
24    _type: std::marker::PhantomData<T>,
25}
26
/// Read interface shared by every receiver flavor in this crate.
pub trait GenReceiver {
    /// Returns the next message payload as a byte slice borrowed from the receiver.
    ///
    /// Because the slice borrows `self` mutably, the receiver cannot be used again
    /// until the returned slice is dropped.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the implementation cannot produce a message; the
    /// receivers in this crate propagate failures from fetching the next data page.
    fn pop(&mut self) -> Result<&[u8], std::io::Error>;
}
30
31impl Receiver<Grouped> {
32    pub fn new(group: usize, manager: DataPagesManager) -> Result<Self, std::io::Error> {
33        let (datapage_count, datapage) = manager.get_or_create_datapage(0)?;
34        //let datapage_count = RefCell::new(datapage_count);
35        //let datapage = RefCell::new(datapage);
36
37        Ok(Receiver {
38            group,
39            anon_count: 0,
40            manager,
41            datapage_count,
42            datapage,
43            _type: PhantomData,
44        })
45    }
46
47    pub fn pop_with_timeout(&mut self, timeout: Duration) -> Result<Option<&[u8]>, std::io::Error> {
48        loop {
49            let count = self.datapage.get().increment_group_count(self.group, 1);
50
51            match self.datapage.get().get_with_timeout(count, timeout) {
52                Ok(data) => return Ok(data),
53                // WARN: if you add more errors in the future make sure to match on them!!!
54                Err(_e) => {}
55            };
56
57            let (dp_count, datapage) = self
58                .manager
59                .get_or_create_datapage(self.datapage_count.wrapping_add(1))?;
60
61            self.datapage_count = dp_count;
62            self.datapage = datapage;
63        }
64    }
65}
66impl GenReceiver for Receiver<Grouped> {
67    fn pop(&mut self) -> Result<&[u8], std::io::Error> {
68        loop {
69            let count = self.datapage.get().increment_group_count(self.group, 1);
70
71            match self.datapage.get().get(count) {
72                Ok(data) => return Ok(data),
73                // WARN: if you add more errors in the future make sure to match on them!!!
74                Err(_e) => {}
75            };
76
77            let (dp_count, datapage) = self
78                .manager
79                .get_or_create_datapage(self.datapage_count.wrapping_add(1))?;
80
81            self.datapage_count = dp_count;
82            self.datapage = datapage;
83        }
84    }
85}
86
87impl Receiver<Anonymous> {
88    pub fn new_anon(manager: DataPagesManager) -> Result<Self, std::io::Error> {
89        Ok(Receiver::new(0, manager)?.into())
90    }
91}
92
93impl GenReceiver for Receiver<Anonymous> {
94    fn pop(&mut self) -> Result<&[u8], std::io::Error> {
95        loop {
96            let count = self.anon_count;
97            self.anon_count += 1;
98
99            match self.datapage.get().get(count) {
100                Ok(data) => return Ok(data),
101                // WARN: if you add more errors in the future make sure to match on them!!!
102                Err(_e) => {}
103            };
104
105            self.anon_count = 0;
106
107            let (dp_count, datapage) = self
108                .manager
109                .get_or_create_datapage(self.datapage_count.wrapping_add(1))?;
110
111            self.datapage_count = dp_count;
112            self.datapage = datapage;
113        }
114    }
115}
116
117impl From<Receiver<Grouped>> for Receiver<Anonymous> {
118    fn from(value: Receiver<Grouped>) -> Self {
119        Receiver {
120            group: 0,
121            anon_count: 0,
122            manager: value.manager,
123            datapage_count: value.datapage_count,
124            datapage: value.datapage,
125            _type: PhantomData,
126        }
127    }
128}
129
130#[derive(Clone)]
131pub struct Sender {
132    manager: DataPagesManager,
133    datapage_count: usize,
134    datapage: Arc<MmapCell<DataPage>>,
135}
136
137impl Sender {
138    pub fn new(manager: DataPagesManager) -> Result<Self, std::io::Error> {
139        let (datapage_count, datapage) = manager.get_or_create_datapage(0)?;
140        //let datapage_count = RefCell::new(datapage_count);
141        //let datapage = RefCell::new(datapage);
142
143        Ok(Sender {
144            manager,
145            datapage_count,
146            datapage,
147        })
148    }
149
150    pub fn push<T: AsRef<[u8]>>(&mut self, data: T) -> Result<(), std::io::Error> {
151        loop {
152            match self.datapage.get_mut().push(&data) {
153                Ok(()) => return Ok(()),
154                Err(_e) => {}
155            }
156
157            let (dp_count, datapage) = self
158                .manager
159                .get_or_create_datapage(self.datapage_count.wrapping_add(1))?;
160
161            self.datapage_count = dp_count;
162            self.datapage = datapage;
163        }
164    }
165}
166
// TODO: Move these tests into a dedicated `tests/` directory;
// they take up a lot of space in lib.rs.
#[cfg(test)]
mod test {
    use std::{
        path::PathBuf,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Barrier,
        },
        thread::{self, JoinHandle},
        time::{Duration, Instant},
    };

    use rand::random;
    use tracing::info;

    use super::*;

    /// Payload pushed by every test: 100 ASCII `a` bytes.
    const TEST_MESSAGE: &str = const_str::repeat!("a", 100);
    /// Number of messages pushed per topic in each throughput test.
    const TOTAL_MESSAGES: usize = 50_000_000;

    /// A uniquely named scratch directory under the system temp dir.
    ///
    /// The directory is removed (best-effort) when the value is dropped. Test data is
    /// therefore cleaned up even when an assertion fails, and every directory a test
    /// creates is removed (previously `two_topics` leaked its second one).
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        fn new() -> Self {
            let num: u64 = random();
            let path = std::env::temp_dir().join(format!("disk-mpmc-test-{:X}", num));
            std::fs::create_dir_all(&path).unwrap();
            TempDir { path }
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Ignore errors: failed cleanup must not mask the test's own result.
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    /// Installs a global `tracing` subscriber, tolerating one already being set.
    ///
    /// `tracing_subscriber::fmt::init()` panics if a global subscriber exists, which
    /// happens as soon as a second test in the same process calls it.
    fn init_tracing() {
        let _ = tracing_subscriber::fmt::try_init();
    }

    /// Logs how long moving `messages` copies of `TEST_MESSAGE` took, and the rate in MB/s.
    fn log_throughput(action: &str, messages: usize, elapsed: Duration) {
        let megabytes = (messages * TEST_MESSAGE.len()) as f64 / 1_000_000.0;
        info!(
            "{} {} messages ({:.2} MB) in {} ms [{:.2}MB/s]",
            action,
            messages,
            megabytes,
            elapsed.as_millis(),
            megabytes / elapsed.as_secs_f64()
        );
    }

    /// Spawns `num_threads` producers. Once `barrier` is released, each one pushes
    /// `TOTAL_MESSAGES / num_threads` copies of `TEST_MESSAGE` through a clone of `tx`.
    fn spawn_producers(
        tx: &Sender,
        num_threads: usize,
        barrier: &Arc<Barrier>,
    ) -> Vec<JoinHandle<()>> {
        // An uneven split would push fewer than TOTAL_MESSAGES and starve consumers.
        assert_eq!(
            TOTAL_MESSAGES % num_threads,
            0,
            "TOTAL_MESSAGES must split evenly across producer threads"
        );
        let per_thread = TOTAL_MESSAGES / num_threads;

        (0..num_threads)
            .map(|_| {
                let mut tx = tx.clone();
                let barrier = Arc::clone(barrier);
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..per_thread {
                        tx.push(TEST_MESSAGE).unwrap();
                    }
                })
            })
            .collect()
    }

    /// Runs `num_threads` producers against `num_threads` consumers in group 0 and
    /// checks that every popped payload equals `TEST_MESSAGE`.
    fn run_grouped_mpmc(num_threads: usize) {
        init_tracing();

        let dir = TempDir::new();
        let manager = DataPagesManager::new(&dir.path).unwrap();
        let rx = Receiver::<Grouped>::new(0, manager.clone()).unwrap();
        let tx = Sender::new(manager).unwrap();

        // One party per consumer, one per producer, plus this thread.
        let barrier = Arc::new(Barrier::new(num_threads * 2 + 1));
        // A consumer pops only after claiming a ticket below TOTAL_MESSAGES. Exactly
        // TOTAL_MESSAGES pops therefore happen across all consumers, and no consumer
        // waits for a message that is never pushed.
        let tickets = Arc::new(AtomicUsize::new(0));

        let mut handles: Vec<JoinHandle<()>> = (0..num_threads)
            .map(|_| {
                let mut rx = rx.clone();
                let tickets = Arc::clone(&tickets);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    while tickets.fetch_add(1, Ordering::Relaxed) < TOTAL_MESSAGES {
                        let msg = rx.pop().unwrap(); // blocking
                        assert_eq!(msg, TEST_MESSAGE.as_bytes());
                    }
                })
            })
            .collect();
        handles.extend(spawn_producers(&tx, num_threads, &barrier));

        barrier.wait();
        let now = Instant::now();
        // Join every thread (not just the first consumer), so that a panic in any
        // producer or consumer fails the test instead of hanging or being lost.
        for handle in handles {
            handle.join().unwrap();
        }
        log_throughput("pushed & popped", TOTAL_MESSAGES, now.elapsed());
    }

    #[test]
    fn sequential_test() {
        init_tracing();

        let dir = TempDir::new();
        let manager = DataPagesManager::new(&dir.path).unwrap();

        let mut tx = Sender::new(manager.clone()).unwrap();
        let now = Instant::now();
        for _ in 0..TOTAL_MESSAGES {
            tx.push(TEST_MESSAGE).unwrap();
        }
        log_throughput("pushed", TOTAL_MESSAGES, now.elapsed());

        let mut rx = Receiver::<Grouped>::new(0, manager).unwrap();
        let now = Instant::now();
        for _ in 0..TOTAL_MESSAGES {
            assert_eq!(rx.pop().unwrap(), TEST_MESSAGE.as_bytes());
        }
        log_throughput("popped", TOTAL_MESSAGES, now.elapsed());
    }

    #[test]
    fn spsc_test() {
        run_grouped_mpmc(1);
    }

    #[test]
    fn mpmc_test() {
        run_grouped_mpmc(8);
    }

    #[test]
    fn two_topics() {
        const NUM_THREADS: usize = 1;

        init_tracing();

        // Both directories are removed when `dirs` is dropped (after the senders).
        let dirs = [TempDir::new(), TempDir::new()];
        let senders: Vec<Sender> = dirs
            .iter()
            .map(|dir| Sender::new(DataPagesManager::new(&dir.path).unwrap()).unwrap())
            .collect();

        // One party per producer across all topics, plus this thread.
        let barrier = Arc::new(Barrier::new(NUM_THREADS * senders.len() + 1));
        let handles: Vec<JoinHandle<()>> = senders
            .iter()
            .flat_map(|tx| spawn_producers(tx, NUM_THREADS, &barrier))
            .collect();

        barrier.wait();
        let now = Instant::now();
        for handle in handles {
            handle.join().unwrap();
        }
        log_throughput("pushed", TOTAL_MESSAGES * senders.len(), now.elapsed());
    }
}