1use std::error::Error;
2use std::path::Path;
3use libc::{c_uint, size_t};
4
5use heed3::{Database, EnvFlags, EnvOpenOptions, RoTxn, RwTxn, WithTls};
6
7use super::topic::{Consumer, Producer};
8
9#[cfg(test)]
10use super::topic::Topic;
11
12pub struct Env {
13 pub lmdb_env: heed3::Env,
14 pub root: String,
15}
16
17impl Env {
18 pub fn new<P: AsRef<Path>>(root: P, max_topics: Option<c_uint>, map_size: Option<size_t>) -> Result<Env, Box<dyn Error>> {
19 let env = unsafe {
20 EnvOpenOptions::new()
21 .map_size(map_size.unwrap_or(256 * 1024 * 1024))
22 .max_dbs(max_topics.unwrap_or(256) * 2)
23 .flags(EnvFlags::NO_SYNC | EnvFlags::NO_SUB_DIR)
24 .open(root.as_ref())
25 .map(|lmdb_env| Env { lmdb_env, root: root.as_ref().to_str().unwrap().to_string() })
26 };
27
28 Ok(env?)
29 }
30
31 pub fn db<K, V>(&self, wtxn: &mut RwTxn, name: &str) -> Result<Database<K, V>, Box<dyn Error>>
32 where K: 'static, V: 'static
33 {
34 Ok(self.lmdb_env.create_database::<K, V>(wtxn, Some(name))?)
35 }
36
37 pub fn producer(&self, name: &str, chunk_size: Option<u64>) -> Result<Producer, Box<dyn Error>> {
38 Producer::new(&self, name, chunk_size)
39 }
40
41 pub fn consumer(&self, name: &str, chunks_to_keep: Option<u64>) -> Result<Consumer, Box<dyn Error>> {
42 Consumer::new(&self, name, chunks_to_keep)
43 }
44
45 pub fn write_txn(&self) -> Result<RwTxn, Box<dyn Error>> {
46 Ok(self.lmdb_env.write_txn()?)
47 }
48
49 pub fn read_txn(&self) -> Result<RoTxn<WithTls>, Box<dyn Error>> {
50 Ok(self.lmdb_env.read_txn()?)
51 }
52}
53
/// Pushes 1 Mi messages one at a time and then drains the topic. It checks
/// periodically that `lag()` drops by exactly the number of messages consumed,
/// and at the end that the whole initial lag was consumed.
#[test]
fn test_single() -> Result<(), Box<dyn Error>> {
    // A dedicated file per test: `cargo test` runs tests in parallel. When both
    // tests shared one path and topic, they could consume each other's messages
    // (and heed may refuse to open the same path twice at once).
    let path = std::env::temp_dir().join("foo_env_single");
    // Best-effort removal of a previous run's data, so the lag checks start from
    // an empty topic and the map does not keep growing across runs.
    let _ = std::fs::remove_file(&path);

    let env = Env::new(&path, None, None)?;
    let mut producer = env.producer("test", Some(16 * 1024 * 1024))?;
    for i in 0..1024 * 1024 {
        producer.push_back(i.to_string().as_bytes())?;
    }

    let mut consumer = env.consumer("test", None)?;
    let lag = consumer.lag()?;
    println!("Current lag is: {}", lag);

    let mut message_count = 0;
    while let Some(item) = consumer.pop_front()? {
        message_count += 1;
        if message_count % (1024 * 100) == 0 {
            println!("Got message: {}", String::from_utf8(item.data)?);
            let cur_lag = consumer.lag()?;
            assert_eq!(lag, cur_lag + message_count as u64);
        }
    }
    println!("Read {} messages.", message_count);
    // Starting from a fresh environment, draining the topic must consume exactly
    // the lag observed before reading began.
    assert_eq!(lag, message_count as u64);

    Ok(())
}
84
/// Pushes 100 Ki batches of 10 messages and drains the topic 10 at a time. It
/// checks periodically that `lag()` drops by exactly the number of messages
/// consumed, and at the end that the whole initial lag was consumed.
#[test]
fn test_batch() -> Result<(), Box<dyn Error>> {
    // A dedicated file per test: `cargo test` runs tests in parallel. When both
    // tests shared one path and topic, they could consume each other's messages
    // (and heed may refuse to open the same path twice at once).
    let path = std::env::temp_dir().join("foo_env_batch");
    // Best-effort removal of a previous run's data, so the lag checks start from
    // an empty topic and the map does not keep growing across runs.
    let _ = std::fs::remove_file(&path);

    let env = Env::new(&path, None, None)?;
    let mut producer = env.producer("test", Some(16 * 1024 * 1024))?;
    for i in 0..1024 * 100 {
        let messages: Vec<String> = (0..10).map(|v| format!("{}_{}", i, v)).collect();
        let batch: Vec<&[u8]> = messages.iter().map(|m| m.as_bytes()).collect();

        producer.push_back_batch(&batch)?;
    }

    let mut consumer = env.consumer("test", None)?;
    let lag = consumer.lag()?;
    println!("Current lag is: {}", lag);

    let mut message_count = 0;
    loop {
        let items = consumer.pop_front_n(10)?;
        if items.is_empty() {
            break;
        }
        message_count += items.len();
        if message_count % (1024 * 100) == 0 {
            // Borrow the payload for printing instead of cloning it.
            println!("Got message: {}", std::str::from_utf8(&items[0].data)?);
            let cur_lag = consumer.lag()?;
            assert_eq!(lag, cur_lag + message_count as u64);
        }
    }
    println!("Read {} messages.", message_count);
    // Starting from a fresh environment, draining the topic must consume exactly
    // the lag observed before reading began.
    assert_eq!(lag, message_count as u64);

    Ok(())
}