use {
super::{
CollectionConfig,
CollectionFromDef,
Error,
READER,
SyncConfig,
WRITER,
When,
primitives::{Key, StoreId, Version},
},
crate::{
Group,
GroupId,
Network,
PeerId,
UniqueId,
collections::sync::{
Snapshot,
SnapshotStateMachine,
SnapshotSync,
protocol::SnapshotRequest,
},
groups::{
ApplyContext,
CommandError,
Cursor,
LeadershipPreference,
StateMachine,
},
primitives::{EncodeError, Encoded},
},
core::{any::type_name, borrow::Borrow, hash::Hash, ops::Range},
futures::{FutureExt, TryFutureExt},
serde::{Deserialize, Serialize},
std::hash::BuildHasherDefault,
tokio::sync::watch,
};
/// Backing set type for [`Set`]: an `im::HashSet` hashed with
/// `BuildHasherDefault<DefaultHasher>` instead of a randomly seeded hasher.
///
/// NOTE(review): presumably the non-random hasher was chosen so that iteration
/// order is reproducible across instances (snapshots are transferred by index
/// range in `SetSnapshot::iter_range`) — confirm before changing the hasher.
type HashSet<T> = im::HashSet<T, BuildHasherDefault<std::hash::DefaultHasher>>;
/// A [`Set`] handle that can submit mutations (`insert`, `extend`, `remove`,
/// `remove_many`, `clear`) in addition to the read-only queries.
pub type SetWriter<T> = Set<T, WRITER>;
/// A [`Set`] handle that only exposes the read-only queries shared by all
/// `Set` handles.
pub type SetReader<T> = Set<T, READER>;
/// A set of `T` values whose contents are maintained by a [`Group`] running a
/// `SetStateMachine`.
///
/// `IS_WRITER` selects the handle's capabilities at the type level (see
/// [`SetWriter`] / [`SetReader`]); it defaults to `WRITER`.
pub struct Set<T: Key, const IS_WRITER: bool = WRITER> {
    // Wraps the group's `when()` handle (constructed in `create`).
    when: When,
    // The group whose state machine holds the set; writes go through
    // `Group::execute` and the committed position backs `version()`.
    group: Group<SetStateMachine<T>>,
    // Latest local copy of the set, published by the state machine via
    // `send_replace` after every applied batch and installed snapshot.
    data: watch::Receiver<HashSet<T>>,
}
impl<T: Key, const IS_WRITER: bool> Set<T, IS_WRITER> {
    /// Returns the number of elements in the locally observed copy of the set.
    pub fn len(&self) -> usize {
        self.data.borrow().len()
    }

    /// Returns `true` if the locally observed copy of the set has no elements.
    pub fn is_empty(&self) -> bool {
        self.data.borrow().is_empty()
    }

    /// Returns `true` if `value` is present in the locally observed set.
    ///
    /// The lookup key may be any borrowed form of `T` (e.g. `&str` for a
    /// `Set<String>`), mirroring the standard `HashSet::contains`.
    pub fn contains<Q>(&self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        // Query through the borrow guard directly; no copy of the set is needed.
        self.data.borrow().contains(value)
    }

    /// Returns `true` if every element of this set is also in `other`.
    ///
    /// `other` may be either a reader or a writer handle.
    pub fn is_subset<const W: bool>(&self, other: &Set<T, W>) -> bool {
        // Take an owned copy of `other` first so that at most one
        // `watch::Ref` read guard is alive at any time. Holding two guards at
        // once can deadlock when both refer to the same channel (e.g.
        // `s.is_subset(&s)`) while the state machine is publishing an update.
        let theirs = other.data.borrow().clone();
        self.data.borrow().is_subset(theirs)
    }

    /// Returns an iterator over a point-in-time copy of the set.
    ///
    /// The copy is taken when `iter` is called and the borrow of the shared
    /// value is released immediately, so the iterator neither observes later
    /// updates nor keeps the watch channel borrowed while it is alive.
    pub fn iter(&self) -> impl Iterator<Item = T> {
        let iter_clone = self.data.borrow().clone();
        iter_clone.into_iter()
    }

    /// Returns the [`When`] handle derived from the underlying group.
    pub const fn when(&self) -> &When {
        &self.when
    }

    /// Returns the group's committed position wrapped as a [`Version`].
    pub fn version(&self) -> Version {
        Version(self.group.committed())
    }

    /// Returns the identifier of the underlying group.
    pub fn group_id(&self) -> &GroupId {
        self.group.id()
    }
}
impl<T: Key> SetWriter<T> {
pub fn writer(network: &Network, store_id: impl Into<StoreId>) -> Self {
Self::writer_with_config(network, store_id, CollectionConfig::default())
}
pub fn writer_with_config(
network: &Network,
store_id: impl Into<StoreId>,
config: impl Into<CollectionConfig>,
) -> Self {
Self::create::<WRITER>(network, store_id, config.into())
}
pub fn new(network: &Network, store_id: impl Into<StoreId>) -> Self {
Self::writer(network, store_id)
}
pub fn new_with_config(
network: &Network,
store_id: impl Into<StoreId>,
config: impl Into<CollectionConfig>,
) -> Self {
Self::writer_with_config(network, store_id, config)
}
pub fn clear(
&self,
) -> impl Future<Output = Result<Version, Error<()>>> + Send + Sync + 'static
{
self.execute(
SetCommand::Clear,
|_| Error::Offline(()),
|_, _| unreachable!(),
)
}
pub fn insert(
&self,
value: T,
) -> impl Future<Output = Result<Version, Error<T>>> + Send + Sync + 'static
{
let value = Encoded(value);
self.execute(
SetCommand::Insert { value },
|cmd| match cmd {
SetCommand::Insert { value } => Error::Offline(value.0),
_ => unreachable!(),
},
|cmd, e| match cmd {
SetCommand::Insert { value } => Error::Encoding(value.0, e),
_ => unreachable!(),
},
)
}
pub fn extend(
&self,
values: impl IntoIterator<Item = T>,
) -> impl Future<Output = Result<Version, Error<Vec<T>>>> + Send + Sync + 'static
{
let entries: Vec<Encoded<T>> = values.into_iter().map(Encoded).collect();
let is_empty = entries.is_empty();
let current_version = self.group.committed();
let fut = self.execute(
SetCommand::Extend { entries },
|cmd| match cmd {
SetCommand::Extend { entries } => {
Error::Offline(entries.into_iter().map(|e| e.0).collect())
}
_ => unreachable!(),
},
|cmd, e| match cmd {
SetCommand::Extend { entries } => {
Error::Encoding(entries.into_iter().map(|e| e.0).collect(), e)
}
_ => unreachable!(),
},
);
async move {
if is_empty {
Ok(Version(current_version))
} else {
fut.await
}
}
}
pub fn remove<Q: Borrow<T>>(
&self,
value: &Q,
) -> impl Future<Output = Result<Version, Error<T>>> + Send + Sync + 'static
{
let value = Encoded(value.borrow().clone());
self.execute(
SetCommand::RemoveMany {
values: vec![value],
},
|cmd| match cmd {
SetCommand::RemoveMany { mut values } => {
Error::Offline(values.remove(0).0)
}
_ => unreachable!(),
},
|cmd, e| match cmd {
SetCommand::RemoveMany { mut values } => {
Error::Encoding(values.remove(0).0, e)
}
_ => unreachable!(),
},
)
}
pub fn remove_many(
&self,
values: impl IntoIterator<Item = T>,
) -> impl Future<Output = Result<Version, Error<Vec<T>>>> + Send + Sync + 'static
{
let values: Vec<Encoded<T>> = values.into_iter().map(Encoded).collect();
let is_empty = values.is_empty();
let current_version = self.group.committed();
let fut = self.execute(
SetCommand::RemoveMany { values },
|cmd| match cmd {
SetCommand::RemoveMany { values } => {
Error::Offline(values.into_iter().map(|v| v.0).collect())
}
_ => unreachable!(),
},
|cmd, e| match cmd {
SetCommand::RemoveMany { values } => {
Error::Encoding(values.into_iter().map(|v| v.0).collect(), e)
}
_ => unreachable!(),
},
);
async move {
if is_empty {
Ok(Version(current_version))
} else {
fut.await
}
}
}
}
impl<T: Key, const IS_WRITER: bool> Set<T, IS_WRITER> {
    /// Creates a read-only handle for the set stored under `store_id`, using
    /// the default [`CollectionConfig`].
    pub fn reader(
        network: &Network,
        store_id: impl Into<StoreId>,
    ) -> SetReader<T> {
        let defaults = CollectionConfig::default();
        Self::reader_with_config(network, store_id, defaults)
    }

    /// Creates a read-only handle for the set stored under `store_id` with an
    /// explicit configuration.
    pub fn reader_with_config(
        network: &Network,
        store_id: impl Into<StoreId>,
        config: impl Into<CollectionConfig>,
    ) -> SetReader<T> {
        let resolved: CollectionConfig = config.into();
        Self::create::<READER>(network, store_id, resolved)
    }

    /// Builds a `Set` handle with capability `W`.
    ///
    /// Constructs the `SetStateMachine`, keys the group by `store_id`,
    /// registers every ticket validator from `config.auth`, and joins.
    fn create<const W: bool>(
        network: &Network,
        store_id: impl Into<StoreId>,
        config: CollectionConfig,
    ) -> Set<T, W> {
        let key: StoreId = store_id.into();
        let state_machine =
            SetStateMachine::new(key, W, config.sync, network.local().id());
        let data = state_machine.data();

        let base_builder = network
            .groups()
            .with_key(key)
            .with_state_machine(state_machine);
        let group = config
            .auth
            .into_iter()
            .fold(base_builder, |builder, validator| {
                builder.require_ticket(validator)
            })
            .join();

        // Derive `when` before `group` is moved into the struct.
        let when = When::new(group.when().clone());
        Set::<T, W> { when, group, data }
    }
}
// The const parameter is named `IS_WRITER` (as in the other `Set` impls) so it
// does not shadow the imported `WRITER` constant inside this impl.
impl<T: Key, const IS_WRITER: bool> CollectionFromDef for Set<T, IS_WRITER> {
    type Reader = SetReader<T>;
    type Writer = SetWriter<T>;

    /// Builds a read-only handle; delegates to the inherent
    /// `Set::reader_with_config`.
    fn reader_with_config(
        network: &Network,
        store_id: StoreId,
        config: CollectionConfig,
    ) -> Self::Reader {
        Self::Reader::reader_with_config(network, store_id, config)
    }

    /// Builds a writer handle; delegates to the inherent
    /// `SetWriter::writer_with_config`.
    fn writer_with_config(
        network: &Network,
        store_id: StoreId,
        config: CollectionConfig,
    ) -> Self::Writer {
        Self::Writer::writer_with_config(network, store_id, config)
    }
}
impl<T: Key> SetWriter<T> {
    /// Submits `command` to the group and translates the outcome.
    ///
    /// A successful position is wrapped in [`Version`]. On failure, the
    /// rejected command is handed to `offline_err` or `encoding_err` so the
    /// caller can return its payload inside the error. A terminated group maps
    /// to `Error::NetworkDown`.
    ///
    /// # Panics
    ///
    /// Panics on `CommandError::NoCommands`, which is treated as impossible
    /// because exactly one command is always submitted. It also panics if an
    /// `Offline`/`Encoding` error carries no commands (presumably the group
    /// returns the submitted batch — TODO confirm).
    fn execute<TErr>(
        &self,
        command: SetCommand<T>,
        offline_err: impl FnOnce(SetCommand<T>) -> Error<TErr> + Send + Sync + 'static,
        encoding_err: impl FnOnce(SetCommand<T>, EncodeError) -> Error<TErr>
            + Send
            + Sync
            + 'static,
    ) -> impl Future<Output = Result<Version, Error<TErr>>> + Send + Sync + 'static
    {
        let submitted = self.group.execute(command);
        submitted.map_ok(Version).map_err(|failure| match failure {
            CommandError::GroupTerminated => Error::NetworkDown,
            // Only one command was submitted, so it sits at index 0.
            CommandError::Offline(mut rejected) => offline_err(rejected.remove(0)),
            CommandError::Encoding(mut rejected, cause) => {
                encoding_err(rejected.remove(0), cause)
            }
            CommandError::NoCommands => unreachable!(),
        })
    }
}
/// State machine holding the set's contents for one group member.
///
/// It applies `SetCommand`s, publishes each new state to `Set` handles through
/// a watch channel, and serves snapshots on request.
struct SetStateMachine<T: Key> {
    // Current contents of the set on this member.
    data: HashSet<T>,
    // Publishes copies of `data` to subscribed `Set` handles (see `data()`).
    latest: watch::Sender<HashSet<T>>,
    // Mixed into `signature()` so different stores get different signatures.
    store_id: StoreId,
    // This peer's id; snapshot requests made by this peer are ignored.
    local_id: PeerId,
    // Snapshot-sync helper: checks request expiry and serves snapshots.
    state_sync: SnapshotSync<Self>,
    // Selects the leadership preference (`Normal` for writers, `Observer`
    // for readers).
    is_writer: bool,
}
impl<T: Key> SetStateMachine<T> {
    /// Creates a state machine holding an empty set.
    ///
    /// `is_writer` controls the leadership preference reported to the group.
    /// Snapshot requests are wrapped into `SetCommand::TakeSnapshot` commands
    /// by the snapshot-sync helper.
    pub fn new(
        store_id: StoreId,
        is_writer: bool,
        sync_config: SyncConfig,
        local_id: PeerId,
    ) -> Self {
        let empty: HashSet<T> = HashSet::default();
        let publisher = watch::Sender::new(empty.clone());
        let syncer = SnapshotSync::new(sync_config, |req| SetCommand::TakeSnapshot(req));
        Self {
            store_id,
            local_id,
            is_writer,
            data: empty,
            latest: publisher,
            state_sync: syncer,
        }
    }

    /// Returns a new receiver that observes every published copy of the set.
    pub fn data(&self) -> watch::Receiver<HashSet<T>> {
        let sender = &self.latest;
        sender.subscribe()
    }
}
impl<T: Key> StateMachine for SetStateMachine<T> {
type Command = SetCommand<T>;
type Query = ();
type QueryResult = ();
type StateSync = SnapshotSync<Self>;
fn apply(&mut self, command: Self::Command, ctx: &dyn ApplyContext) {
self.apply_batch([command], ctx);
}
fn apply_batch(
&mut self,
commands: impl IntoIterator<Item = Self::Command>,
ctx: &dyn ApplyContext,
) {
let mut commands_len = 0usize;
let mut sync_requests = vec![];
for command in commands {
match command {
SetCommand::Clear => {
self.data.clear();
}
SetCommand::Insert { value } => {
self.data.insert(value.0);
}
SetCommand::RemoveMany { values } => {
for value in values {
self.data.remove(&value.0);
}
}
SetCommand::Extend { entries } => {
for value in entries {
self.data.insert(value.0);
}
}
SetCommand::TakeSnapshot(request) => {
if request.requested_by != self.local_id
&& !self.state_sync.is_expired(&request)
{
sync_requests.push(request);
}
}
}
commands_len += 1;
}
self.latest.send_replace(self.data.clone());
if !sync_requests.is_empty() {
let snapshot = self.create_snapshot();
let position = Cursor::new(
ctx.current_term(),
ctx.committed().index() + commands_len as u64,
);
for request in sync_requests {
self
.state_sync
.serve_snapshot(request, position, snapshot.clone());
}
}
}
fn signature(&self) -> crate::UniqueId {
UniqueId::from("mosaik_collections_set")
.derive(self.store_id)
.derive(type_name::<T>())
}
fn query(&self, (): Self::Query) {}
fn state_sync(&self) -> Self::StateSync {
self.state_sync.clone()
}
fn leadership_preference(&self) -> LeadershipPreference {
if self.is_writer {
LeadershipPreference::Normal
} else {
LeadershipPreference::Observer
}
}
}
impl<T: Key> SnapshotStateMachine for SetStateMachine<T> {
    type Snapshot = SetSnapshot<T>;

    /// Captures the current contents of the set as a snapshot.
    fn create_snapshot(&self) -> Self::Snapshot {
        let data = self.data.clone();
        SetSnapshot { data }
    }

    /// Replaces the local contents with `snapshot` and publishes the new
    /// contents to every subscribed `Set` handle.
    fn install_snapshot(&mut self, snapshot: Self::Snapshot) {
        let SetSnapshot { data } = snapshot;
        self.data = data;
        self.latest.send_replace(self.data.clone());
    }
}
/// Commands replicated through the group and applied by `SetStateMachine`.
///
/// NOTE(review): with serde's default enum representation the variant order
/// likely determines the encoded discriminant — do not reorder variants
/// without checking wire compatibility.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "T: Key")]
enum SetCommand<T> {
    /// Removes every element.
    Clear,
    /// Inserts a single value.
    Insert { value: Encoded<T> },
    /// Removes each listed value (also used for single-value `remove`).
    RemoveMany { values: Vec<Encoded<T>> },
    /// Inserts each listed value.
    Extend { entries: Vec<Encoded<T>> },
    /// Asks group members to serve a snapshot; ignored by the requesting peer
    /// itself and when the request has expired.
    TakeSnapshot(SnapshotRequest),
}
/// A copy of the set's contents produced by `create_snapshot` and consumed by
/// `install_snapshot`; transferred in index ranges via `Snapshot::iter_range`
/// and rebuilt with `Snapshot::append`.
#[derive(Debug, Clone)]
pub struct SetSnapshot<T: Key> {
    data: HashSet<T>,
}
/// An empty snapshot.
///
/// Implemented by hand because `#[derive(Default)]` would add an unnecessary
/// `T: Default` bound.
impl<T: Key> Default for SetSnapshot<T> {
    fn default() -> Self {
        let data = HashSet::default();
        Self { data }
    }
}
impl<T: Key> Snapshot for SetSnapshot<T> {
    type Item = Encoded<T>;

    /// Number of elements in the snapshot.
    fn len(&self) -> u64 {
        self.data.len() as u64
    }

    /// Returns the elements at positions `range.start..range.end` of this
    /// snapshot's iteration order, each wrapped in [`Encoded`].
    ///
    /// Returns `None` when the range is inverted (`start > end`) or extends
    /// past the end of the snapshot. An empty range (`start == end`) within
    /// bounds yields `Some` of an empty iterator.
    ///
    /// NOTE(review): positions index into hash-iteration order; this relies
    /// on a given `im::HashSet` value always iterating in the same order.
    fn iter_range(
        &self,
        range: Range<u64>,
    ) -> Option<impl Iterator<Item = Self::Item>> {
        let len = self.data.len() as u64;
        // Reject inverted ranges up front: `end - start` would otherwise
        // underflow (panic in debug builds, a huge `take` in release builds).
        if range.start > range.end || range.end > len {
            return None;
        }
        // Both casts are lossless: start <= end <= len, and len came from a usize.
        let skip = range.start as usize;
        let take = (range.end - range.start) as usize;
        Some(
            self
                .data
                .clone()
                .into_iter()
                .skip(skip)
                .take(take)
                .map(Encoded),
        )
    }

    /// Adds `items` to the snapshot; presumably called with the chunks that
    /// `iter_range` produced on the serving peer.
    fn append(&mut self, items: impl IntoIterator<Item = Self::Item>) {
        self.data.extend(items.into_iter().map(|e| e.0));
    }
}