use crate::{
crdt::{Mergeable, ReplicaId},
storage::{LocalStorage, Storage, StorageError},
sync::{SyncEngine, SyncState},
transport::{SyncTransport, TransportError},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::marker::PhantomData;
use tokio::sync::RwLock;
use thiserror::Error;
/// Errors returned by [`LocalFirstCollection`] and [`CollectionIterator`] operations.
///
/// The first four variants wrap lower-layer errors and are produced automatically
/// by `?` through the `#[from]` conversions.
#[derive(Error, Debug)]
pub enum CollectionError {
    /// The underlying storage backend failed.
    #[error("Storage error: {0}")]
    Storage(#[from] StorageError),
    /// The sync transport failed.
    #[error("Transport error: {0}")]
    Transport(#[from] TransportError),
    /// The sync engine failed (e.g. in `start_sync`, `stop_sync`, or `force_sync`).
    #[error("Sync error: {0}")]
    Sync(#[from] crate::sync::SyncEngineError),
    /// A value could not be (de)serialized as JSON.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// A key that was required to exist is absent; carries the missing key
    /// (produced by `update_batch`).
    #[error("Item not found: {0}")]
    NotFound(String),
    /// An operation was rejected; carries a human-readable description.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),
}
/// A typed key/value collection persisted in local [`Storage`] and optionally
/// replicated through a [`SyncEngine`].
///
/// `T` is the stored value type; `Tr` is the transport the sync engine uses.
pub struct LocalFirstCollection<T, Tr>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + Mergeable + Default,
    Tr: SyncTransport + Clone + 'static,
{
    // Local backing store; the constructors also hand a clone of it to the sync engine.
    storage: Storage,
    // Shared engine: writes (sync, start/stop) take the write lock, status queries the read lock.
    sync_engine: Arc<RwLock<SyncEngine<Tr>>>,
    // When true, mutating operations push their changes to the sync engine immediately.
    auto_sync: bool,
    // Ties the collection to `T` without storing a `T`.
    _phantom: PhantomData<T>,
}
/// Builder for [`LocalFirstCollection`] that lets callers opt into auto-sync and
/// an explicit replica id before choosing the value type in `build`.
pub struct CollectionBuilder<Tr>
where
    Tr: SyncTransport + Clone + 'static,
{
    // Storage the built collection (and its sync engine) will use.
    storage: Storage,
    // Transport handed to the sync engine.
    transport: Tr,
    // Initial auto-sync setting for the built collection.
    auto_sync: bool,
    // Explicit replica id; `None` lets `SyncEngine::new` decide.
    replica_id: Option<ReplicaId>,
}
impl<Tr> CollectionBuilder<Tr>
where
    Tr: SyncTransport + Clone + 'static,
{
    /// Starts a builder over `storage` and `transport`.
    ///
    /// Defaults: auto-sync disabled, no explicit replica id.
    pub fn new(storage: Storage, transport: Tr) -> Self {
        Self {
            storage,
            transport,
            auto_sync: false,
            replica_id: None,
        }
    }

    /// Sets whether the built collection pushes writes to the sync engine immediately.
    #[must_use]
    pub fn with_auto_sync(mut self, enabled: bool) -> Self {
        self.auto_sync = enabled;
        self
    }

    /// Makes the sync engine use `replica_id` instead of choosing its own.
    #[must_use]
    pub fn with_replica_id(mut self, replica_id: ReplicaId) -> Self {
        self.replica_id = Some(replica_id);
        self
    }

    /// Consumes the builder and creates a collection storing values of type `T`.
    #[must_use]
    pub fn build<T>(self) -> LocalFirstCollection<T, Tr>
    where
        T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + Mergeable + Default,
    {
        let Self {
            storage,
            transport,
            auto_sync,
            replica_id,
        } = self;
        // The builder is consumed, so the transport is moved rather than cloned;
        // only `storage` is needed twice (engine + collection).
        let sync_engine = match replica_id {
            Some(replica_id) => SyncEngine::with_replica_id(storage.clone(), transport, replica_id),
            None => SyncEngine::new(storage.clone(), transport),
        };
        LocalFirstCollection::<T, Tr> {
            storage,
            sync_engine: Arc::new(RwLock::new(sync_engine)),
            auto_sync,
            _phantom: PhantomData,
        }
    }
}
impl<T, Tr> LocalFirstCollection<T, Tr>
where
T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + Mergeable + Default,
Tr: SyncTransport + Clone + 'static,
{
pub fn new(storage: Storage, transport: Tr) -> Self {
let sync_engine = SyncEngine::new(storage.clone(), transport);
Self {
storage,
sync_engine: Arc::new(RwLock::new(sync_engine)),
auto_sync: false,
_phantom: PhantomData,
}
}
pub fn with_replica_id(storage: Storage, transport: Tr, replica_id: ReplicaId) -> Self {
let sync_engine = SyncEngine::with_replica_id(storage.clone(), transport, replica_id);
Self {
storage,
sync_engine: Arc::new(RwLock::new(sync_engine)),
auto_sync: false,
_phantom: PhantomData,
}
}
pub fn replica_id(&self) -> ReplicaId {
ReplicaId::default()
}
pub async fn insert(&self, key: &str, value: &T) -> Result<(), CollectionError> {
self.storage.set(key, value).await?;
if self.auto_sync {
let mut engine = self.sync_engine.write().await;
engine.sync(key, value).await?;
}
Ok(())
}
pub async fn get(&self, key: &str) -> Result<Option<T>, CollectionError> {
self.storage.get(key).await.map_err(Into::into)
}
pub async fn remove(&self, key: &str) -> Result<(), CollectionError> {
self.storage.remove(key).await.map_err(Into::into)
}
pub async fn keys(&self) -> Result<Vec<String>, CollectionError> {
self.storage.keys().await.map_err(Into::into)
}
pub async fn values(&self) -> Result<Vec<T>, CollectionError> {
let keys = self.storage.keys().await.map_err(|e| CollectionError::Storage(e))?;
let mut values = Vec::new();
for key in keys {
if let Some(value) = self.get(&key).await? {
values.push(value);
}
}
Ok(values)
}
pub async fn contains_key(&self, key: &str) -> Result<bool, CollectionError> {
self.storage.contains_key(key).await.map_err(Into::into)
}
pub async fn len(&self) -> Result<usize, CollectionError> {
self.storage.len().await.map_err(Into::into)
}
pub async fn is_empty(&self) -> Result<bool, CollectionError> {
self.storage.is_empty().await.map_err(Into::into)
}
pub async fn start_sync(&self) -> Result<(), CollectionError> {
let mut engine = self.sync_engine.write().await;
engine.start_sync().await.map_err(Into::into)
}
pub async fn stop_sync(&self) -> Result<(), CollectionError> {
let mut engine = self.sync_engine.write().await;
engine.stop_sync().await.map_err(Into::into)
}
pub async fn sync_state(&self) -> Result<SyncState, CollectionError> {
let engine = self.sync_engine.read().await;
Ok(engine.state().await)
}
pub async fn is_online(&self) -> Result<bool, CollectionError> {
let engine = self.sync_engine.read().await;
Ok(engine.is_online().await)
}
pub async fn peer_count(&self) -> Result<usize, CollectionError> {
let engine = self.sync_engine.read().await;
Ok(engine.peer_count().await)
}
pub fn set_auto_sync(&mut self, enabled: bool) {
self.auto_sync = enabled;
}
pub async fn force_sync(&self) -> Result<(), CollectionError> {
let mut engine = self.sync_engine.write().await;
engine.process_messages().await.map_err(|e| CollectionError::Sync(e))?;
Ok(())
}
pub async fn insert_batch(&self, items: impl IntoIterator<Item = (String, T)>) -> Result<(), CollectionError> {
let items: Vec<_> = items.into_iter().collect();
for (key, value) in &items {
self.storage.set(key, value).await?;
}
if self.auto_sync {
let mut engine = self.sync_engine.write().await;
for (key, value) in items {
engine.sync(&key, &value).await?;
}
}
Ok(())
}
pub async fn update_batch(&self, updates: impl IntoIterator<Item = (String, T)>) -> Result<(), CollectionError> {
let updates: Vec<_> = updates.into_iter().collect();
for (key, _) in &updates {
if !self.storage.contains_key(key).await? {
return Err(CollectionError::NotFound(key.clone()));
}
}
for (key, value) in &updates {
self.storage.set(key, value).await?;
}
if self.auto_sync {
let mut engine = self.sync_engine.write().await;
for (key, value) in updates {
engine.sync(&key, &value).await?;
}
}
Ok(())
}
pub async fn remove_batch(&self, keys: impl IntoIterator<Item = String>) -> Result<(), CollectionError> {
let keys: Vec<_> = keys.into_iter().collect();
for key in &keys {
self.storage.remove(key).await?;
}
if self.auto_sync {
let mut engine = self.sync_engine.write().await;
for key in keys {
engine.sync(&key, &T::default()).await?;
}
}
Ok(())
}
pub async fn get_batch(&self, keys: impl IntoIterator<Item = String>) -> Result<Vec<(String, Option<T>)>, CollectionError> {
let keys: Vec<_> = keys.into_iter().collect();
let mut results = Vec::new();
for key in keys {
let value = self.storage.get(&key).await?;
results.push((key, value));
}
Ok(results)
}
pub async fn contains_keys(&self, keys: impl IntoIterator<Item = String>) -> Result<Vec<(String, bool)>, CollectionError> {
let keys: Vec<_> = keys.into_iter().collect();
let mut results = Vec::new();
for key in keys {
let exists = self.storage.contains_key(&key).await?;
results.push((key, exists));
}
Ok(results)
}
pub async fn peers(&self) -> Result<impl Iterator<Item = (ReplicaId, crate::sync::PeerInfo)>, CollectionError> {
let engine = self.sync_engine.read().await;
Ok(engine.peers().await)
}
pub async fn sync_info(&self) -> Result<SyncInfo, CollectionError> {
let engine = self.sync_engine.read().await;
Ok(SyncInfo {
sync_state: engine.state().await,
peer_count: engine.peer_count().await,
is_online: engine.is_online().await,
})
}
}
/// Point-in-time snapshot of the sync engine's status, as returned by
/// `LocalFirstCollection::sync_info` (all three fields are read under one read lock).
#[derive(Debug, Clone)]
pub struct SyncInfo {
    /// The engine's reported [`SyncState`].
    pub sync_state: SyncState,
    /// Number of peers the engine reports.
    pub peer_count: usize,
    /// The engine's online flag.
    pub is_online: bool,
}
/// A synchronous iterator over a snapshot of a [`LocalFirstCollection`]'s entries.
///
/// Storage access is async but [`Iterator::next`] is not, so the entries are
/// fetched up front by [`CollectionIterator::load_keys`]. An iterator that has
/// not been loaded yields nothing.
pub struct CollectionIterator<T, Tr>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + Mergeable + Default,
    Tr: SyncTransport + Clone + 'static,
{
    collection: Arc<LocalFirstCollection<T, Tr>>,
    // Remaining `(key, value)` pairs from the most recent `load_keys` snapshot.
    entries: std::vec::IntoIter<(String, T)>,
}

impl<T, Tr> CollectionIterator<T, Tr>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + Mergeable + Default,
    Tr: SyncTransport + Clone + 'static,
{
    /// Creates an empty iterator over `collection`; call `load_keys` before iterating.
    pub fn new(collection: Arc<LocalFirstCollection<T, Tr>>) -> Self {
        Self {
            collection,
            entries: Vec::new().into_iter(),
        }
    }

    /// Snapshots the collection's current `(key, value)` pairs and restarts iteration.
    ///
    /// The name is kept for compatibility. Despite it, this loads values as well as
    /// keys, so `next` can yield the stored data rather than placeholders.
    ///
    /// # Errors
    /// Fails if listing the keys or reading any value fails. The previous snapshot
    /// is kept in that case.
    pub async fn load_keys(&mut self) -> Result<(), CollectionError> {
        let keys = self.collection.storage.keys().await?;
        let mut entries = Vec::with_capacity(keys.len());
        for key in keys {
            // A key can be removed between listing and reading; skip it rather than yield a fake value.
            if let Some(value) = self.collection.get(&key).await? {
                entries.push((key, value));
            }
        }
        self.entries = entries.into_iter();
        Ok(())
    }
}

impl<T, Tr> Iterator for CollectionIterator<T, Tr>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + Mergeable + Default,
    Tr: SyncTransport + Clone + 'static,
{
    type Item = (String, T);

    /// Yields the next loaded `(key, value)` pair, or `None` once the snapshot is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.entries.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.entries.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Storage;
    use crate::transport::InMemoryTransport;
    use crate::crdt::{LwwRegister, ReplicaId};

    /// Insert / get / remove round-trip on a single key.
    #[tokio::test]
    async fn test_collection_basic_operations() {
        let storage = Storage::memory();
        let transport = InMemoryTransport::new();
        let collection = LocalFirstCollection::<LwwRegister<String>, _>::new(storage, transport);
        let value1 = LwwRegister::new("value1".to_string(), ReplicaId::default());
        assert!(collection.insert("key1", &value1).await.is_ok());
        let value = collection.get("key1").await.unwrap();
        assert_eq!(value, Some(value1));
        assert!(collection.remove("key1").await.is_ok());
        let value = collection.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    /// The builder forwards the auto-sync flag and defaults it to off.
    #[tokio::test]
    async fn test_collection_builder() {
        let storage = Storage::memory();
        let transport = InMemoryTransport::new();
        let collection = CollectionBuilder::new(storage, transport)
            .with_auto_sync(true)
            .build::<LwwRegister<String>>();
        assert!(collection.auto_sync);

        let collection = CollectionBuilder::new(Storage::memory(), InMemoryTransport::new())
            .build::<LwwRegister<String>>();
        assert!(!collection.auto_sync);
    }

    /// Batch insert / get / contains / update / remove, checking the stored values.
    #[tokio::test]
    async fn test_collection_batch_operations() {
        let storage = Storage::memory();
        let transport = InMemoryTransport::new();
        let collection = LocalFirstCollection::<LwwRegister<String>, _>::new(storage, transport);
        let items = vec![
            ("key1".to_string(), LwwRegister::new("value1".to_string(), ReplicaId::default())),
            ("key2".to_string(), LwwRegister::new("value2".to_string(), ReplicaId::default())),
            ("key3".to_string(), LwwRegister::new("value3".to_string(), ReplicaId::default())),
        ];
        assert!(collection.insert_batch(items.clone()).await.is_ok());

        // Every inserted key must come back with exactly the stored value, in request order.
        let keys = vec!["key1".to_string(), "key2".to_string(), "key3".to_string()];
        let results = collection.get_batch(keys).await.unwrap();
        assert_eq!(results.len(), 3);
        for ((key, value), (expected_key, expected_value)) in results.iter().zip(&items) {
            assert_eq!(key, expected_key);
            assert_eq!(value.as_ref(), Some(expected_value));
        }

        let exists_results = collection
            .contains_keys(vec!["key1".to_string(), "key2".to_string(), "key4".to_string()])
            .await
            .unwrap();
        assert_eq!(exists_results, vec![
            ("key1".to_string(), true),
            ("key2".to_string(), true),
            ("key4".to_string(), false),
        ]);

        let updates = vec![
            ("key1".to_string(), LwwRegister::new("updated1".to_string(), ReplicaId::default())),
            ("key2".to_string(), LwwRegister::new("updated2".to_string(), ReplicaId::default())),
        ];
        assert!(collection.update_batch(updates.clone()).await.is_ok());
        let updated = collection
            .get_batch(vec!["key1".to_string(), "key2".to_string()])
            .await
            .unwrap();
        for ((key, value), (expected_key, expected_value)) in updated.iter().zip(&updates) {
            assert_eq!(key, expected_key);
            assert_eq!(value.as_ref(), Some(expected_value));
        }

        // update_batch must reject keys that do not exist and must not create them.
        let missing = vec![(
            "missing".to_string(),
            LwwRegister::new("nope".to_string(), ReplicaId::default()),
        )];
        match collection.update_batch(missing).await {
            Err(CollectionError::NotFound(key)) => assert_eq!(key, "missing"),
            other => panic!("expected NotFound, got {:?}", other),
        }
        assert!(!collection.contains_key("missing").await.unwrap());

        let keys_to_remove = vec!["key1".to_string(), "key2".to_string()];
        assert!(collection.remove_batch(keys_to_remove).await.is_ok());
        let exists_after_remove = collection
            .contains_keys(vec!["key1".to_string(), "key2".to_string(), "key3".to_string()])
            .await
            .unwrap();
        assert_eq!(exists_after_remove, vec![
            ("key1".to_string(), false),
            ("key2".to_string(), false),
            ("key3".to_string(), true),
        ]);
        let remaining = collection.get_batch(vec!["key3".to_string()]).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].0, "key3");
        assert_eq!(remaining[0].1.as_ref(), Some(&items[2].1));
    }

    /// Smoke-compares batch vs. individual inserts; only the final count is asserted.
    #[tokio::test]
    async fn test_collection_batch_performance() {
        let storage = Storage::memory();
        let transport = InMemoryTransport::new();
        let collection = LocalFirstCollection::<LwwRegister<String>, _>::new(storage, transport);
        let items: Vec<_> = (0..1000)
            .map(|i| (
                format!("key{}", i),
                LwwRegister::new(format!("value{}", i), ReplicaId::default())
            ))
            .collect();
        let start = std::time::Instant::now();
        assert!(collection.insert_batch(items).await.is_ok());
        let batch_duration = start.elapsed();
        let individual_items: Vec<_> = (1000..2000)
            .map(|i| (
                format!("key{}", i),
                LwwRegister::new(format!("value{}", i), ReplicaId::default())
            ))
            .collect();
        let start = std::time::Instant::now();
        for (key, value) in &individual_items {
            assert!(collection.insert(key, value).await.is_ok());
        }
        let individual_duration = start.elapsed();
        println!("Batch insert (1000 items): {:?}", batch_duration);
        println!("Individual insert (1000 items): {:?}", individual_duration);
        let total_count = collection.len().await.unwrap();
        assert_eq!(total_count, 2000);
    }
}