use amaters_core::{
CipherBlob, Key,
error::{AmateRSError, ErrorContext, Result as CoreResult},
storage::MemoryStorage,
traits::StorageEngine,
};
use amaters_net::server::AqlServerBuilder;
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::RwLock;
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tracing::warn;
/// In-memory storage for tests that can be told to fail on specific keys.
///
/// Cloning is cheap and clones share state: both the backing store and the
/// injected-error table live behind `Arc`.
#[derive(Debug, Clone)]
pub struct MockStorage {
    /// Backing store that actually holds the data.
    inner: Arc<MemoryStorage>,
    /// Per-key errors that the `StorageEngine` impl may return instead of
    /// delegating to `inner`.
    errors: Arc<DashMap<Key, AmateRSError>>,
}

impl MockStorage {
    /// Creates an empty store with no injected errors.
    pub fn new() -> Self {
        let inner = Arc::new(MemoryStorage::new());
        let errors = Arc::new(DashMap::new());
        Self { inner, errors }
    }

    /// Writes `value` under `key` straight into the backing store.
    ///
    /// Error injection is bypassed, so this can seed keys that are
    /// configured to fail.
    pub async fn insert(&self, key: impl Into<Key>, value: CipherBlob) -> CoreResult<()> {
        let key: Key = key.into();
        self.inner.put(&key, &value).await
    }

    /// Registers `err` for `key`, replacing any error previously injected
    /// for that key.
    pub fn inject_error(&self, key: impl Into<Key>, err: AmateRSError) {
        let key: Key = key.into();
        self.errors.insert(key, err);
    }

    /// Removes the error injected for `key`, if any.
    pub fn clear_error(&self, key: impl Into<Key>) {
        let key: Key = key.into();
        self.errors.remove(&key);
    }

    /// Returns every key/value pair in the backing store.
    ///
    /// Reads go straight to the backing store, so injected errors are
    /// ignored. A key whose lookup yields `None` (e.g. deleted after the key
    /// listing) is skipped.
    pub async fn get_all(&self) -> CoreResult<Vec<(Key, CipherBlob)>> {
        let all_keys = self.inner.keys().await?;
        let mut pairs = Vec::with_capacity(all_keys.len());
        for key in all_keys {
            let Some(blob) = self.inner.get(&key).await? else {
                continue;
            };
            pairs.push((key, blob));
        }
        Ok(pairs)
    }

    /// Returns a copy of the error injected for `key`, if one is set.
    fn check_error(&self, key: &Key) -> Option<AmateRSError> {
        let entry = self.errors.get(key)?;
        Some(entry.value().clone())
    }
}

impl Default for MockStorage {
    /// Equivalent to [`MockStorage::new`].
    fn default() -> Self {
        MockStorage::new()
    }
}
/// Delegates every operation to the in-memory backing store.
///
/// The key-targeted operations (`put`, `get`, `delete`, `atomic_update`)
/// first return any error injected for that key. Scans (`range`, `keys`) and
/// lifecycle calls (`flush`, `close`) are never intercepted.
#[async_trait]
impl StorageEngine for MockStorage {
    async fn put(&self, key: &Key, value: &CipherBlob) -> CoreResult<()> {
        if let Some(err) = self.check_error(key) {
            return Err(err);
        }
        self.inner.put(key, value).await
    }

    async fn get(&self, key: &Key) -> CoreResult<Option<CipherBlob>> {
        if let Some(err) = self.check_error(key) {
            return Err(err);
        }
        self.inner.get(key).await
    }

    async fn atomic_update<F>(&self, key: &Key, f: F) -> CoreResult<()>
    where
        F: Fn(&CipherBlob) -> CoreResult<CipherBlob> + Send + Sync,
    {
        // An update writes to `key`, so an injected error must block it just
        // like it blocks `put`; otherwise a "failing" key stays mutable.
        if let Some(err) = self.check_error(key) {
            return Err(err);
        }
        self.inner.atomic_update(key, f).await
    }

    async fn delete(&self, key: &Key) -> CoreResult<()> {
        if let Some(err) = self.check_error(key) {
            return Err(err);
        }
        self.inner.delete(key).await
    }

    async fn range(&self, start: &Key, end: &Key) -> CoreResult<Vec<(Key, CipherBlob)>> {
        self.inner.range(start, end).await
    }

    async fn keys(&self) -> CoreResult<Vec<Key>> {
        self.inner.keys().await
    }

    async fn flush(&self) -> CoreResult<()> {
        self.inner.flush().await
    }

    async fn close(&self) -> CoreResult<()> {
        self.inner.close().await
    }
}
pub struct MockServerBuilder {
initial_values: HashMap<Key, CipherBlob>,
initial_errors: HashMap<Key, AmateRSError>,
}
impl MockServerBuilder {
pub fn new() -> Self {
Self {
initial_values: HashMap::new(),
initial_errors: HashMap::new(),
}
}
#[must_use]
pub fn with_value(mut self, key: impl Into<Key>, value: CipherBlob) -> Self {
self.initial_values.insert(key.into(), value);
self
}
#[must_use]
pub fn with_error(mut self, key: impl Into<Key>, err: AmateRSError) -> Self {
self.initial_errors.insert(key.into(), err);
self
}
pub async fn start(self) -> anyhow::Result<MockServerHandle> {
let storage = Arc::new(MockStorage::new());
for (key, value) in self.initial_values {
storage.inner.put(&key, &value).await?;
}
for (key, err) in self.initial_errors {
storage.errors.insert(key, err);
}
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let grpc_service = AqlServerBuilder::new(Arc::clone(&storage)).build_grpc_service();
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
tokio::spawn(async move {
let result = tonic::transport::Server::builder()
.add_service(grpc_service)
.serve_with_incoming_shutdown(incoming, async {
let _ = shutdown_rx.await;
})
.await;
if let Err(e) = result {
warn!("[mock_server] tonic serve error: {e}");
}
});
Ok(MockServerHandle {
addr,
storage,
shutdown_tx: RwLock::new(Some(shutdown_tx)),
})
}
}
impl Default for MockServerBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct MockServerHandle {
addr: SocketAddr,
storage: Arc<MockStorage>,
shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
}
impl MockServerHandle {
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub fn endpoint(&self) -> String {
format!("http://{}", self.addr)
}
pub async fn insert(&self, key: impl Into<Key>, value: CipherBlob) -> CoreResult<()> {
self.storage.insert(key, value).await
}
pub async fn get_all(&self) -> CoreResult<Vec<(Key, CipherBlob)>> {
self.storage.get_all().await
}
pub fn inject_error(&self, key: impl Into<Key>, err: AmateRSError) {
self.storage.inject_error(key, err);
}
pub fn clear_error(&self, key: impl Into<Key>) {
self.storage.clear_error(key);
}
pub async fn shutdown(self) {
let maybe_tx = self.shutdown_tx.write().take();
if let Some(tx) = maybe_tx {
let _ = tx.send(());
}
tokio::task::yield_now().await;
}
}
impl Drop for MockServerHandle {
fn drop(&mut self) {
let maybe_tx = self.shutdown_tx.write().take();
if let Some(tx) = maybe_tx {
let _ = tx.send(());
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_mock_storage_basic_operations() -> CoreResult<()> {
        let storage = MockStorage::new();
        let key = Key::from_str("hello");
        let value = CipherBlob::new(vec![1, 2, 3]);
        storage.put(&key, &value).await?;
        let got = storage.get(&key).await?;
        assert_eq!(got, Some(value.clone()));
        storage.delete(&key).await?;
        let got2 = storage.get(&key).await?;
        assert!(got2.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_storage_error_injection_get() {
        let storage = MockStorage::new();
        let key = Key::from_str("bad_key");
        storage.inject_error(
            "bad_key",
            AmateRSError::IoError(ErrorContext::new("simulated I/O failure")),
        );
        let err = storage
            .get(&key)
            .await
            .expect_err("injected error should be returned");
        let msg = err.to_string();
        assert!(msg.contains("simulated I/O failure"), "got: {msg}");
    }

    #[tokio::test]
    async fn test_mock_storage_error_injection_put() {
        let storage = MockStorage::new();
        storage.inject_error(
            "readonly_key",
            AmateRSError::ValidationError(ErrorContext::new("write denied")),
        );
        let result = storage
            .put(&Key::from_str("readonly_key"), &CipherBlob::new(vec![9]))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_mock_storage_error_injection_delete() {
        let storage = MockStorage::new();
        storage.inject_error(
            "nodelete_key",
            AmateRSError::ValidationError(ErrorContext::new("delete denied")),
        );
        let result = storage.delete(&Key::from_str("nodelete_key")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_mock_storage_clear_error_restores_normal() -> CoreResult<()> {
        let storage = MockStorage::new();
        let key = Key::from_str("transient");
        let value = CipherBlob::new(vec![7]);
        storage.inject_error(
            "transient",
            AmateRSError::IoError(ErrorContext::new("transient failure")),
        );
        assert!(storage.get(&key).await.is_err());
        storage.clear_error("transient");
        let result = storage.get(&key).await?;
        assert!(result.is_none());
        storage.put(&key, &value).await?;
        let result2 = storage.get(&key).await?;
        assert_eq!(result2, Some(value));
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_storage_unaffected_key_works() -> CoreResult<()> {
        let storage = MockStorage::new();
        storage.inject_error("bad", AmateRSError::IoError(ErrorContext::new("fail")));
        let good_key = Key::from_str("good");
        let value = CipherBlob::new(vec![1]);
        storage.put(&good_key, &value).await?;
        let got = storage.get(&good_key).await?;
        assert_eq!(got, Some(value));
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_server_builder_start_and_endpoint() -> anyhow::Result<()> {
        let handle = MockServerBuilder::new().start().await?;
        let ep = handle.endpoint();
        assert!(ep.starts_with("http://127.0.0.1:"), "endpoint: {ep}");
        handle.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_server_with_value_preload() -> anyhow::Result<()> {
        let key = Key::from_str("preloaded");
        let value = CipherBlob::new(vec![10, 20, 30]);
        let handle = MockServerBuilder::new()
            .with_value(key.clone(), value.clone())
            .start()
            .await?;
        let all = handle.get_all().await?;
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].0, key);
        assert_eq!(all[0].1, value);
        handle.shutdown().await;
        Ok(())
    }

    /// Errors registered on the builder must be active in the started
    /// server's storage.
    #[tokio::test]
    async fn test_mock_server_with_error_preload() -> anyhow::Result<()> {
        let handle = MockServerBuilder::new()
            .with_error(
                "broken",
                AmateRSError::IoError(ErrorContext::new("preloaded failure")),
            )
            .start()
            .await?;
        let result = handle.storage.get(&Key::from_str("broken")).await;
        assert!(result.is_err());
        handle.shutdown().await;
        Ok(())
    }

    /// `inject_error` / `clear_error` on the handle must reach the shared
    /// storage used by the server.
    #[tokio::test]
    async fn test_mock_server_handle_inject_and_clear_error() -> anyhow::Result<()> {
        let handle = MockServerBuilder::new().start().await?;
        let key = Key::from_str("flaky");
        handle.inject_error("flaky", AmateRSError::IoError(ErrorContext::new("flaky")));
        assert!(handle.storage.get(&key).await.is_err());
        handle.clear_error("flaky");
        assert!(handle.storage.get(&key).await?.is_none());
        handle.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_server_runtime_insert() -> anyhow::Result<()> {
        let handle = MockServerBuilder::new().start().await?;
        handle
            .insert(Key::from_str("k1"), CipherBlob::new(vec![1]))
            .await?;
        handle
            .insert(Key::from_str("k2"), CipherBlob::new(vec![2]))
            .await?;
        let all = handle.get_all().await?;
        assert_eq!(all.len(), 2);
        handle.shutdown().await;
        Ok(())
    }

    /// After `shutdown` returns, the server must no longer accept TCP
    /// connections. `shutdown` consumes the handle, so the `Drop` impl also
    /// runs afterwards and must treat the already-sent signal as a no-op.
    #[tokio::test]
    async fn test_mock_server_shutdown_stops_listener() -> anyhow::Result<()> {
        let handle = MockServerBuilder::new().start().await?;
        let addr = handle.addr();
        handle.shutdown().await;
        let attempt = tokio::time::timeout(
            Duration::from_millis(200),
            tokio::net::TcpStream::connect(addr),
        )
        .await;
        let connected = matches!(attempt, Ok(Ok(_)));
        assert!(!connected, "server should be shut down");
        Ok(())
    }
}