1use amaters_core::{
33 CipherBlob, Key,
34 error::{AmateRSError, ErrorContext, Result as CoreResult},
35 storage::MemoryStorage,
36 traits::StorageEngine,
37};
38use amaters_net::server::AqlServerBuilder;
39use async_trait::async_trait;
40use dashmap::DashMap;
41use parking_lot::RwLock;
42use std::{collections::HashMap, net::SocketAddr, sync::Arc};
43use tokio::net::TcpListener;
44use tokio::sync::oneshot;
45use tracing::warn;
46
47#[derive(Debug, Clone)]
61pub struct MockStorage {
62 inner: Arc<MemoryStorage>,
63 errors: Arc<DashMap<Key, AmateRSError>>,
65}
66
67impl MockStorage {
68 pub fn new() -> Self {
70 Self {
71 inner: Arc::new(MemoryStorage::new()),
72 errors: Arc::new(DashMap::new()),
73 }
74 }
75
76 pub async fn insert(&self, key: impl Into<Key>, value: CipherBlob) -> CoreResult<()> {
85 self.inner.put(&key.into(), &value).await
86 }
87
88 pub fn inject_error(&self, key: impl Into<Key>, err: AmateRSError) {
93 self.errors.insert(key.into(), err);
94 }
95
96 pub fn clear_error(&self, key: impl Into<Key>) {
98 self.errors.remove(&key.into());
99 }
100
101 pub async fn get_all(&self) -> CoreResult<Vec<(Key, CipherBlob)>> {
103 let keys = self.inner.keys().await?;
104 let mut out = Vec::with_capacity(keys.len());
105 for k in keys {
106 if let Some(v) = self.inner.get(&k).await? {
107 out.push((k, v));
108 }
109 }
110 Ok(out)
111 }
112
113 fn check_error(&self, key: &Key) -> Option<AmateRSError> {
115 self.errors.get(key).map(|e| e.clone())
116 }
117}
118
119impl Default for MockStorage {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125#[async_trait]
126impl StorageEngine for MockStorage {
127 async fn put(&self, key: &Key, value: &CipherBlob) -> CoreResult<()> {
128 if let Some(err) = self.check_error(key) {
129 return Err(err);
130 }
131 self.inner.put(key, value).await
132 }
133
134 async fn get(&self, key: &Key) -> CoreResult<Option<CipherBlob>> {
135 if let Some(err) = self.check_error(key) {
136 return Err(err);
137 }
138 self.inner.get(key).await
139 }
140
141 async fn atomic_update<F>(&self, key: &Key, f: F) -> CoreResult<()>
142 where
143 F: Fn(&CipherBlob) -> CoreResult<CipherBlob> + Send + Sync,
144 {
145 self.inner.atomic_update(key, f).await
147 }
148
149 async fn delete(&self, key: &Key) -> CoreResult<()> {
150 if let Some(err) = self.check_error(key) {
151 return Err(err);
152 }
153 self.inner.delete(key).await
154 }
155
156 async fn range(&self, start: &Key, end: &Key) -> CoreResult<Vec<(Key, CipherBlob)>> {
157 self.inner.range(start, end).await
158 }
159
160 async fn keys(&self) -> CoreResult<Vec<Key>> {
161 self.inner.keys().await
162 }
163
164 async fn flush(&self) -> CoreResult<()> {
165 self.inner.flush().await
166 }
167
168 async fn close(&self) -> CoreResult<()> {
169 self.inner.close().await
170 }
171}
172
173pub struct MockServerBuilder {
183 initial_values: HashMap<Key, CipherBlob>,
185 initial_errors: HashMap<Key, AmateRSError>,
187}
188
189impl MockServerBuilder {
190 pub fn new() -> Self {
192 Self {
193 initial_values: HashMap::new(),
194 initial_errors: HashMap::new(),
195 }
196 }
197
198 #[must_use]
203 pub fn with_value(mut self, key: impl Into<Key>, value: CipherBlob) -> Self {
204 self.initial_values.insert(key.into(), value);
205 self
206 }
207
208 #[must_use]
213 pub fn with_error(mut self, key: impl Into<Key>, err: AmateRSError) -> Self {
214 self.initial_errors.insert(key.into(), err);
215 self
216 }
217
218 pub async fn start(self) -> anyhow::Result<MockServerHandle> {
228 let storage = Arc::new(MockStorage::new());
229
230 for (key, value) in self.initial_values {
232 storage.inner.put(&key, &value).await?;
233 }
234 for (key, err) in self.initial_errors {
236 storage.errors.insert(key, err);
237 }
238
239 let listener = TcpListener::bind("127.0.0.1:0").await?;
241 let addr = listener.local_addr()?;
242
243 let grpc_service = AqlServerBuilder::new(Arc::clone(&storage)).build_grpc_service();
244 let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
245
246 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
247
248 tokio::spawn(async move {
249 let result = tonic::transport::Server::builder()
250 .add_service(grpc_service)
251 .serve_with_incoming_shutdown(incoming, async {
252 let _ = shutdown_rx.await;
253 })
254 .await;
255
256 if let Err(e) = result {
257 warn!("[mock_server] tonic serve error: {e}");
258 }
259 });
260
261 Ok(MockServerHandle {
262 addr,
263 storage,
264 shutdown_tx: RwLock::new(Some(shutdown_tx)),
265 })
266 }
267}
268
269impl Default for MockServerBuilder {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275pub struct MockServerHandle {
284 addr: SocketAddr,
285 storage: Arc<MockStorage>,
287 shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
289}
290
291impl MockServerHandle {
292 pub fn addr(&self) -> SocketAddr {
294 self.addr
295 }
296
297 pub fn endpoint(&self) -> String {
299 format!("http://{}", self.addr)
300 }
301
302 pub async fn insert(&self, key: impl Into<Key>, value: CipherBlob) -> CoreResult<()> {
311 self.storage.insert(key, value).await
312 }
313
314 pub async fn get_all(&self) -> CoreResult<Vec<(Key, CipherBlob)>> {
320 self.storage.get_all().await
321 }
322
323 pub fn inject_error(&self, key: impl Into<Key>, err: AmateRSError) {
328 self.storage.inject_error(key, err);
329 }
330
331 pub fn clear_error(&self, key: impl Into<Key>) {
333 self.storage.clear_error(key);
334 }
335
336 pub async fn shutdown(self) {
341 let maybe_tx = self.shutdown_tx.write().take();
342 if let Some(tx) = maybe_tx {
343 let _ = tx.send(());
344 }
345 tokio::task::yield_now().await;
347 }
348}
349
350impl Drop for MockServerHandle {
351 fn drop(&mut self) {
352 let maybe_tx = self.shutdown_tx.write().take();
354 if let Some(tx) = maybe_tx {
355 let _ = tx.send(());
356 }
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// put, get, delete, get round-trip with no errors injected.
    #[tokio::test]
    async fn test_mock_storage_basic_operations() -> CoreResult<()> {
        let store = MockStorage::new();
        let key = Key::from_str("hello");
        let blob = CipherBlob::new(vec![1, 2, 3]);

        store.put(&key, &blob).await?;
        assert_eq!(store.get(&key).await?, Some(blob.clone()));

        store.delete(&key).await?;
        assert!(store.get(&key).await?.is_none());

        Ok(())
    }

    /// `get` on a key with an injected error surfaces that error's message.
    #[tokio::test]
    async fn test_mock_storage_error_injection_get() {
        let store = MockStorage::new();
        let injected = AmateRSError::IoError(ErrorContext::new("simulated I/O failure"));
        store.inject_error("bad_key", injected);

        let outcome = store.get(&Key::from_str("bad_key")).await;
        assert!(outcome.is_err());
        let msg = outcome.expect_err("should be Err").to_string();
        assert!(msg.contains("simulated I/O failure"), "got: {msg}");
    }

    /// `put` on a key with an injected error fails.
    #[tokio::test]
    async fn test_mock_storage_error_injection_put() {
        let store = MockStorage::new();
        let injected = AmateRSError::ValidationError(ErrorContext::new("write denied"));
        store.inject_error("readonly_key", injected);

        let target = Key::from_str("readonly_key");
        let payload = CipherBlob::new(vec![9]);
        let outcome = store.put(&target, &payload).await;
        assert!(outcome.is_err());
    }

    /// `delete` on a key with an injected error fails.
    #[tokio::test]
    async fn test_mock_storage_error_injection_delete() {
        let store = MockStorage::new();
        let injected = AmateRSError::ValidationError(ErrorContext::new("delete denied"));
        store.inject_error("nodelete_key", injected);

        let target = Key::from_str("nodelete_key");
        let outcome = store.delete(&target).await;
        assert!(outcome.is_err());
    }

    /// After `clear_error`, the key behaves like any other key again.
    #[tokio::test]
    async fn test_mock_storage_clear_error_restores_normal() -> CoreResult<()> {
        let store = MockStorage::new();
        let key = Key::from_str("transient");
        let blob = CipherBlob::new(vec![7]);

        let injected = AmateRSError::IoError(ErrorContext::new("transient failure"));
        store.inject_error("transient", injected);
        assert!(store.get(&key).await.is_err());

        // With the error cleared the key is simply absent.
        store.clear_error("transient");
        assert!(store.get(&key).await?.is_none());

        store.put(&key, &blob).await?;
        assert_eq!(store.get(&key).await?, Some(blob));

        Ok(())
    }

    /// An error injected for one key leaves other keys untouched.
    #[tokio::test]
    async fn test_mock_storage_unaffected_key_works() -> CoreResult<()> {
        let store = MockStorage::new();
        store.inject_error("bad", AmateRSError::IoError(ErrorContext::new("fail")));

        let good_key = Key::from_str("good");
        let blob = CipherBlob::new(vec![1]);
        store.put(&good_key, &blob).await?;
        assert_eq!(store.get(&good_key).await?, Some(blob));

        Ok(())
    }

    /// A started server reports a loopback `http://` endpoint.
    #[tokio::test]
    async fn test_mock_server_builder_start_and_endpoint() -> anyhow::Result<()> {
        let server = MockServerBuilder::new().start().await?;
        let ep = server.endpoint();
        assert!(ep.starts_with("http://127.0.0.1:"), "endpoint: {ep}");
        server.shutdown().await;
        Ok(())
    }

    /// Values registered on the builder are present once the server is up.
    #[tokio::test]
    async fn test_mock_server_with_value_preload() -> anyhow::Result<()> {
        let key = Key::from_str("preloaded");
        let blob = CipherBlob::new(vec![10, 20, 30]);

        let builder = MockServerBuilder::new().with_value(key.clone(), blob.clone());
        let server = builder.start().await?;

        let contents = server.get_all().await?;
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].0, key);
        assert_eq!(contents[0].1, blob);

        server.shutdown().await;
        Ok(())
    }

    /// Values inserted through the handle at runtime show up in `get_all`.
    #[tokio::test]
    async fn test_mock_server_runtime_insert() -> anyhow::Result<()> {
        let server = MockServerBuilder::new().start().await?;

        let first = (Key::from_str("k1"), CipherBlob::new(vec![1]));
        server.insert(first.0, first.1).await?;
        let second = (Key::from_str("k2"), CipherBlob::new(vec![2]));
        server.insert(second.0, second.1).await?;

        let contents = server.get_all().await?;
        assert_eq!(contents.len(), 2);

        server.shutdown().await;
        Ok(())
    }

    /// After `shutdown` (followed by the handle's own drop) the port no longer
    /// accepts connections.
    #[tokio::test]
    async fn test_mock_server_double_shutdown_noop() -> anyhow::Result<()> {
        let server = MockServerBuilder::new().start().await?;
        let addr = server.addr();
        server.shutdown().await;

        let probe = tokio::net::TcpStream::connect(addr);
        let result = tokio::time::timeout(std::time::Duration::from_millis(200), probe).await;
        // A timeout and a connection error both count as "not connected".
        let connected = matches!(result, Ok(Ok(_)));
        assert!(!connected, "server should be shut down");
        Ok(())
    }
}