use crate::shard::ShardId;
use crate::storage::persistent::StorageBackend;
use anyhow::Result;
use async_trait::async_trait;
use oxirs_core::model::Triple;
use oxirs_core::RdfTerm;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// In-memory mock implementation of [`StorageBackend`].
///
/// Each shard is stored as a `Vec<Triple>` keyed by its [`ShardId`]. The map lives
/// behind an `Arc<RwLock<…>>`, so clones of a `MockStorageBackend` share the same
/// underlying shards rather than copying them.
#[derive(Debug, Default, Clone)]
pub struct MockStorageBackend {
    // Shard id -> triples held by that shard, in the order they were pushed/imported.
    shards: Arc<RwLock<HashMap<ShardId, Vec<Triple>>>>,
}
impl MockStorageBackend {
    /// Builds a backend that holds no shards yet.
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of shards currently present in the backend.
    pub async fn shard_count(&self) -> usize {
        let guard = self.shards.read().await;
        guard.len()
    }

    /// Drops every shard together with all the triples it contains.
    pub async fn clear(&self) {
        let mut guard = self.shards.write().await;
        guard.clear();
    }
}
#[async_trait]
impl StorageBackend for MockStorageBackend {
    /// Creates `shard_id` as an empty shard.
    ///
    /// NOTE(review): if a shard with this id already exists it is replaced by an
    /// empty one, discarding its triples — confirm this matches the trait contract.
    async fn create_shard(&self, shard_id: ShardId) -> Result<()> {
        self.shards.write().await.insert(shard_id, Vec::new());
        Ok(())
    }

    /// Removes `shard_id` and all of its triples; removing a missing shard is not an error.
    async fn delete_shard(&self, shard_id: ShardId) -> Result<()> {
        self.shards.write().await.remove(&shard_id);
        Ok(())
    }

    /// Appends `triple` to `shard_id`, creating the shard on demand if it does not exist.
    async fn insert_triple_to_shard(&self, shard_id: ShardId, triple: Triple) -> Result<()> {
        // Entry API: one hash lookup instead of `get_mut` followed by `insert`.
        self.shards
            .write()
            .await
            .entry(shard_id)
            .or_default()
            .push(triple);
        Ok(())
    }

    /// Removes every triple equal to `triple` from `shard_id` (duplicates included).
    /// A missing shard is silently ignored.
    async fn delete_triple_from_shard(&self, shard_id: ShardId, triple: &Triple) -> Result<()> {
        let mut shards = self.shards.write().await;
        if let Some(shard) = shards.get_mut(&shard_id) {
            shard.retain(|t| t != triple);
        }
        Ok(())
    }

    /// Returns clones of the triples in `shard_id` that match every supplied filter.
    ///
    /// A `None` filter matches anything. Named-node subjects and objects are compared
    /// by their `as_str()` value; any other subject/object term is compared by its
    /// `to_string()` rendering. Predicates are compared by `as_str()`.
    /// A missing shard yields an empty vector.
    async fn query_shard(
        &self,
        shard_id: ShardId,
        subject: Option<&str>,
        predicate: Option<&str>,
        object: Option<&str>,
    ) -> Result<Vec<Triple>> {
        let shards = self.shards.read().await;
        let results: Vec<Triple> = match shards.get(&shard_id) {
            Some(shard) => shard
                .iter()
                .filter(|triple| {
                    let subject_match = subject.map_or(true, |s| {
                        if let oxirs_core::model::Subject::NamedNode(named_node) =
                            triple.subject()
                        {
                            named_node.as_str() == s
                        } else {
                            triple.subject().to_string() == s
                        }
                    });
                    let predicate_match =
                        predicate.map_or(true, |p| triple.predicate().as_str() == p);
                    let object_match = object.map_or(true, |o| {
                        if let oxirs_core::Object::NamedNode(named_node) = triple.object() {
                            named_node.as_str() == o
                        } else {
                            triple.object().to_string() == o
                        }
                    });
                    subject_match && predicate_match && object_match
                })
                .cloned()
                .collect(),
            None => Vec::new(),
        };
        Ok(results)
    }

    /// Returns an estimated size in bytes for `shard_id` (0 for a missing shard).
    ///
    /// The mock does not track real byte sizes; it charges a fixed amount per triple.
    async fn get_shard_size(&self, shard_id: ShardId) -> Result<u64> {
        // Arbitrary per-triple estimate used by the mock.
        const ESTIMATED_BYTES_PER_TRIPLE: u64 = 100;
        let shards = self.shards.read().await;
        // `usize -> u64` is lossless; widening before the multiply prevents the
        // `usize` overflow the previous `len * 100` could hit on 32-bit targets.
        let triple_count = shards.get(&shard_id).map_or(0, |s| s.len()) as u64;
        Ok(triple_count.saturating_mul(ESTIMATED_BYTES_PER_TRIPLE))
    }

    /// Number of triples stored in `shard_id` (0 for a missing shard).
    async fn get_shard_triple_count(&self, shard_id: ShardId) -> Result<usize> {
        let shards = self.shards.read().await;
        Ok(shards.get(&shard_id).map_or(0, |s| s.len()))
    }

    /// Returns a copy of all triples in `shard_id` (empty for a missing shard).
    async fn export_shard(&self, shard_id: ShardId) -> Result<Vec<Triple>> {
        let shards = self.shards.read().await;
        Ok(shards.get(&shard_id).cloned().unwrap_or_default())
    }

    /// Sets the contents of `shard_id` to exactly `triples`, replacing any existing data.
    async fn import_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()> {
        self.shards.write().await.insert(shard_id, triples);
        Ok(())
    }

    /// Returns a copy of all triples in `shard_id` (empty for a missing shard).
    async fn get_shard_triples(&self, shard_id: ShardId) -> Result<Vec<Triple>> {
        let shards = self.shards.read().await;
        Ok(shards.get(&shard_id).cloned().unwrap_or_default())
    }

    /// Appends `triples` to `shard_id`, creating the shard on demand if it does not exist.
    async fn insert_triples_to_shard(&self, shard_id: ShardId, triples: Vec<Triple>) -> Result<()> {
        // Entry API: one hash lookup instead of `get_mut` followed by `insert`.
        self.shards
            .write()
            .await
            .entry(shard_id)
            .or_default()
            .extend(triples);
        Ok(())
    }

    /// In this mock, "marking" a shard for deletion removes it immediately.
    async fn mark_shard_for_deletion(&self, shard_id: ShardId) -> Result<()> {
        self.shards.write().await.remove(&shard_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxirs_core::model::{Literal, NamedNode, Subject};

    /// Builds a triple with named-node subject/predicate and a plain literal object.
    fn sample_triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            Subject::NamedNode(NamedNode::new(subject).unwrap()),
            NamedNode::new(predicate).unwrap(),
            oxirs_core::Object::Literal(Literal::new(object)),
        )
    }

    #[tokio::test]
    async fn test_mock_backend_basic_operations() {
        let backend = MockStorageBackend::new();
        let shard_id = 1;
        backend.create_shard(shard_id).await.unwrap();
        assert_eq!(backend.shard_count().await, 1);
        let triple = Triple::new(
            Subject::NamedNode(NamedNode::new("http://example.org/subject").unwrap()),
            NamedNode::new("http://example.org/predicate").unwrap(),
            oxirs_core::Object::Literal(Literal::new("object")),
        );
        backend
            .insert_triple_to_shard(shard_id, triple.clone())
            .await
            .unwrap();
        let results = backend
            .query_shard(shard_id, None, None, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        backend
            .delete_triple_from_shard(shard_id, &triple)
            .await
            .unwrap();
        let results = backend
            .query_shard(shard_id, None, None, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 0);
        backend.delete_shard(shard_id).await.unwrap();
        assert_eq!(backend.shard_count().await, 0);
    }

    #[tokio::test]
    async fn test_query_filters_and_size_estimate() {
        let backend = MockStorageBackend::new();
        let shard_id = 7;
        let a = sample_triple("http://example.org/a", "http://example.org/p", "x");
        let b = sample_triple("http://example.org/b", "http://example.org/q", "y");

        // Inserting into a shard that was never created creates it.
        backend
            .insert_triples_to_shard(shard_id, vec![a.clone(), b.clone()])
            .await
            .unwrap();
        assert_eq!(backend.shard_count().await, 1);
        assert_eq!(backend.get_shard_triple_count(shard_id).await.unwrap(), 2);
        assert_eq!(backend.get_shard_size(shard_id).await.unwrap(), 200);

        let by_subject = backend
            .query_shard(shard_id, Some("http://example.org/a"), None, None)
            .await
            .unwrap();
        assert!(by_subject == vec![a.clone()]);

        let by_predicate = backend
            .query_shard(shard_id, None, Some("http://example.org/q"), None)
            .await
            .unwrap();
        assert!(by_predicate == vec![b.clone()]);

        let no_match = backend
            .query_shard(
                shard_id,
                Some("http://example.org/a"),
                Some("http://example.org/q"),
                None,
            )
            .await
            .unwrap();
        assert!(no_match.is_empty());

        let missing = backend.query_shard(99, None, None, None).await.unwrap();
        assert!(missing.is_empty());
        assert_eq!(backend.get_shard_size(99).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_removes_duplicates_and_clear() {
        let backend = MockStorageBackend::new();
        let shard_id = 8;
        let a = sample_triple("http://example.org/a", "http://example.org/p", "x");

        backend
            .insert_triple_to_shard(shard_id, a.clone())
            .await
            .unwrap();
        backend
            .insert_triple_to_shard(shard_id, a.clone())
            .await
            .unwrap();
        assert_eq!(backend.get_shard_triple_count(shard_id).await.unwrap(), 2);

        // `delete_triple_from_shard` removes every equal triple, not just one.
        backend.delete_triple_from_shard(shard_id, &a).await.unwrap();
        assert_eq!(backend.get_shard_triple_count(shard_id).await.unwrap(), 0);

        backend.create_shard(9).await.unwrap();
        assert_eq!(backend.shard_count().await, 2);
        backend.clear().await;
        assert_eq!(backend.shard_count().await, 0);
    }
}