1use azoth_core::{
8 error::{AzothError, Result},
9 ReadPoolConfig,
10};
11use lmdb::{Database, Environment, RoTransaction, Transaction};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{Semaphore, SemaphorePermit};
15
/// A read-only LMDB transaction checked out from an [`LmdbReadPool`].
///
/// Holds one permit of the pool's semaphore for as long as it is alive;
/// dropping this value drops the transaction and returns the permit.
pub struct PooledLmdbReadTxn<'a> {
    // Field order matters: Rust drops struct fields in declaration order, so
    // `txn` is dropped before `_permit` hands the slot back to the pool.
    txn: RoTransaction<'a>,
    // Database handle that `get_state` / `exists` read from.
    state_db: Database,
    // Held only for its Drop side effect (releasing one pool permit).
    _permit: SemaphorePermit<'a>,
}
25
26impl<'a> PooledLmdbReadTxn<'a> {
27 pub fn get_state(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
29 match self.txn.get(self.state_db, &key) {
30 Ok(bytes) => Ok(Some(bytes.to_vec())),
31 Err(lmdb::Error::NotFound) => Ok(None),
32 Err(e) => Err(AzothError::Transaction(e.to_string())),
33 }
34 }
35
36 pub fn exists(&self, key: &[u8]) -> Result<bool> {
38 Ok(self.get_state(key)?.is_some())
39 }
40}
41
42pub struct LmdbReadPool {
58 env: Arc<Environment>,
59 state_db: Database,
60 semaphore: Arc<Semaphore>,
61 acquire_timeout: Duration,
62 enabled: bool,
63}
64
65impl LmdbReadPool {
66 pub fn new(env: Arc<Environment>, state_db: Database, config: ReadPoolConfig) -> Self {
68 let pool_size = if config.enabled { config.pool_size } else { 1 };
69 Self {
70 env,
71 state_db,
72 semaphore: Arc::new(Semaphore::new(pool_size)),
73 acquire_timeout: Duration::from_millis(config.acquire_timeout_ms),
74 enabled: config.enabled,
75 }
76 }
77
78 pub async fn acquire(&self) -> Result<PooledLmdbReadTxn<'_>> {
83 let permit = tokio::time::timeout(self.acquire_timeout, self.semaphore.acquire())
84 .await
85 .map_err(|_| {
86 AzothError::Timeout(format!(
87 "Read pool acquire timeout after {:?}",
88 self.acquire_timeout
89 ))
90 })?
91 .map_err(|e| AzothError::Internal(format!("Semaphore closed: {}", e)))?;
92
93 let txn = self
94 .env
95 .begin_ro_txn()
96 .map_err(|e| AzothError::Transaction(e.to_string()))?;
97
98 Ok(PooledLmdbReadTxn {
99 txn,
100 state_db: self.state_db,
101 _permit: permit,
102 })
103 }
104
105 pub fn try_acquire(&self) -> Result<Option<PooledLmdbReadTxn<'_>>> {
109 match self.semaphore.try_acquire() {
110 Ok(permit) => {
111 let txn = self
112 .env
113 .begin_ro_txn()
114 .map_err(|e| AzothError::Transaction(e.to_string()))?;
115
116 Ok(Some(PooledLmdbReadTxn {
117 txn,
118 state_db: self.state_db,
119 _permit: permit,
120 }))
121 }
122 Err(_) => Ok(None),
123 }
124 }
125
126 pub fn acquire_blocking(&self) -> Result<PooledLmdbReadTxn<'_>> {
134 let deadline = std::time::Instant::now() + self.acquire_timeout;
135 let mut backoff_ms = 1u64;
136 const MAX_BACKOFF_MS: u64 = 32;
137
138 let permit = loop {
139 match self.semaphore.try_acquire() {
140 Ok(permit) => break permit,
141 Err(_) => {
142 if std::time::Instant::now() >= deadline {
143 return Err(AzothError::Timeout(format!(
144 "LMDB read pool acquire timeout after {:?}",
145 self.acquire_timeout
146 )));
147 }
148 std::thread::sleep(Duration::from_millis(backoff_ms));
149 backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
150 }
151 }
152 };
153
154 let txn = self
155 .env
156 .begin_ro_txn()
157 .map_err(|e| AzothError::Transaction(e.to_string()))?;
158
159 Ok(PooledLmdbReadTxn {
160 txn,
161 state_db: self.state_db,
162 _permit: permit,
163 })
164 }
165
166 pub fn available_permits(&self) -> usize {
168 self.semaphore.available_permits()
169 }
170
171 pub fn is_enabled(&self) -> bool {
173 self.enabled
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use lmdb::{DatabaseFlags, Environment, EnvironmentFlags, WriteFlags};
    use tempfile::TempDir;

    /// Opens a fresh LMDB environment in a temp dir with one named database.
    /// The returned `TempDir` must outlive every use of the environment.
    fn create_test_env() -> (TempDir, Arc<Environment>, Database) {
        let temp_dir = TempDir::new().unwrap();
        let mut builder = Environment::new();
        builder.set_max_dbs(1);
        builder.set_max_readers(10);
        builder.set_flags(EnvironmentFlags::empty());
        let env = builder.open(temp_dir.path()).unwrap();
        let db = env.create_db(Some("test"), DatabaseFlags::empty()).unwrap();
        (temp_dir, Arc::new(env), db)
    }

    #[tokio::test]
    async fn test_pool_acquire_release() {
        let (_temp_dir, env, db) = create_test_env();
        let config = ReadPoolConfig::enabled(2);
        let pool = LmdbReadPool::new(env, db, config);

        assert_eq!(pool.available_permits(), 2);

        {
            let txn1 = pool.acquire().await.unwrap();
            assert_eq!(pool.available_permits(), 1);
            drop(txn1);
        }
        assert_eq!(pool.available_permits(), 2);

        {
            let txn2 = pool.acquire().await.unwrap();
            assert_eq!(pool.available_permits(), 1);
            drop(txn2);
        }
        assert_eq!(pool.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_acquire_times_out_when_pool_exhausted() {
        let (_temp_dir, env, db) = create_test_env();
        let mut config = ReadPoolConfig::enabled(1);
        config.acquire_timeout_ms = 20;
        let pool = LmdbReadPool::new(env, db, config);

        let held = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 0);

        let result = pool.acquire().await;
        assert!(matches!(result, Err(AzothError::Timeout(_))));

        drop(held);
        assert_eq!(pool.available_permits(), 1);
    }

    #[test]
    fn test_try_acquire() {
        let (_temp_dir, env, db) = create_test_env();
        let config = ReadPoolConfig::enabled(1);
        let pool = LmdbReadPool::new(env, db, config);

        let txn = pool.try_acquire().unwrap();
        assert!(txn.is_some());

        assert!(pool.try_acquire().unwrap().is_none());

        drop(txn);
        assert!(pool.try_acquire().unwrap().is_some());
    }

    #[test]
    fn test_acquire_blocking_times_out_then_succeeds() {
        let (_temp_dir, env, db) = create_test_env();
        let mut config = ReadPoolConfig::enabled(1);
        config.acquire_timeout_ms = 20;
        let pool = LmdbReadPool::new(env, db, config);

        let held = pool.try_acquire().unwrap();
        assert!(held.is_some());

        // The only permit is taken, so the blocking acquire must give up.
        let result = pool.acquire_blocking();
        assert!(matches!(result, Err(AzothError::Timeout(_))));

        drop(held);
        let txn = pool.acquire_blocking().unwrap();
        assert_eq!(pool.available_permits(), 0);
        drop(txn);
        assert_eq!(pool.available_permits(), 1);
    }

    #[test]
    fn test_pool_get_state() {
        let (_temp_dir, env, db) = create_test_env();

        {
            let mut txn = env.begin_rw_txn().unwrap();
            txn.put(db, b"key1", b"value1", WriteFlags::empty())
                .unwrap();
            txn.commit().unwrap();
        }

        let config = ReadPoolConfig::enabled(2);
        let pool = LmdbReadPool::new(env, db, config);

        let txn = pool.try_acquire().unwrap().unwrap();
        let value = txn.get_state(b"key1").unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        let missing = txn.get_state(b"nonexistent").unwrap();
        assert!(missing.is_none());

        assert!(txn.exists(b"key1").unwrap());
        assert!(!txn.exists(b"nonexistent").unwrap());
    }
}