use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use hashgraph_like_consensus::{
events::BroadcastEventBus, signing::EthereumConsensusSigner, storage::InMemoryConsensusStorage,
};
use openmls_rust_crypto::MemoryStorage;
use crate::core::{
ConsensusPlugin, ConversationPluginsFactory, DeterministicStewardList, PeerScoreStorage,
PeerScoringService, ScoringConfig, StewardListConfig, default_score_deltas,
};
use crate::mls_crypto::{DeMlsStorage, KeyPackageBytes, MlsCredentials, MlsError, OpenMlsService};
#[derive(Default)]
pub struct MemoryDeMlsStorage {
key_package_refs: RwLock<HashSet<Vec<u8>>>,
mls: MemoryStorage,
}
impl MemoryDeMlsStorage {
pub fn new() -> Self {
Self::default()
}
}
/// [`DeMlsStorage`] backed by the in-memory reference set and `MemoryStorage`.
///
/// Every method that touches the reference set propagates a failed lock
/// acquisition (a poisoned `RwLock`) to the caller as an `MlsError` via `?`.
impl DeMlsStorage for MemoryDeMlsStorage {
    type MlsStorage = MemoryStorage;
    type StorageError = openmls_rust_crypto::MemoryStorageError;

    /// Records `hash_ref` in the reference set; storing the same bytes twice keeps one entry.
    fn store_key_package_ref(&self, hash_ref: &[u8]) -> Result<(), MlsError> {
        let mut refs = self.key_package_refs.write()?;
        refs.insert(hash_ref.to_vec());
        Ok(())
    }

    /// Reports whether `hash_ref` is currently present in the reference set.
    fn is_our_key_package(&self, hash_ref: &[u8]) -> Result<bool, MlsError> {
        let refs = self.key_package_refs.read()?;
        Ok(refs.contains(hash_ref))
    }

    /// Drops `hash_ref` from the reference set; an unknown reference is not an error.
    fn remove_key_package_ref(&self, hash_ref: &[u8]) -> Result<(), MlsError> {
        let mut refs = self.key_package_refs.write()?;
        refs.remove(hash_ref);
        Ok(())
    }

    /// Borrows the embedded OpenMLS storage.
    fn mls_storage(&self) -> &Self::MlsStorage {
        &self.mls
    }
}
/// Peer-score table held entirely in memory.
///
/// Maps a member id (raw bytes) to its signed score.
#[derive(Debug, Clone, Default)]
pub struct InMemoryPeerScoreStorage {
    /// Score per member id.
    scores: HashMap<Vec<u8>, i64>,
}

impl InMemoryPeerScoreStorage {
    /// Creates a table that contains no scores.
    pub fn new() -> Self {
        Self {
            scores: HashMap::new(),
        }
    }
}
/// [`PeerScoreStorage`] implemented directly on top of the inner `HashMap`.
impl PeerScoreStorage for InMemoryPeerScoreStorage {
    /// Returns the score stored for `member_id`, or `None` when there is no entry.
    fn get(&self, member_id: &[u8]) -> Option<i64> {
        let score = self.scores.get(member_id)?;
        Some(*score)
    }

    /// Stores `score` for `member_id`, replacing any previous value.
    fn set(&mut self, member_id: &[u8], score: i64) {
        let key = member_id.to_vec();
        self.scores.insert(key, score);
    }

    /// Deletes the entry for `member_id`; does nothing when it is absent.
    fn remove(&mut self, member_id: &[u8]) {
        self.scores.remove(member_id);
    }

    /// Copies every `(member_id, score)` pair out of the map, in the map's iteration order.
    fn all_scores(&self) -> Vec<(Vec<u8>, i64)> {
        let mut snapshot = Vec::with_capacity(self.scores.len());
        for (member_id, &score) in &self.scores {
            snapshot.push((member_id.clone(), score));
        }
        snapshot
    }
}
/// Consensus plugin that selects the `hashgraph_like_consensus` in-memory storage,
/// broadcast event bus and Ethereum signer, all scoped by `String`.
pub struct DefaultConsensusPlugin;

impl ConsensusPlugin for DefaultConsensusPlugin {
    type Scope = String;
    type ConsensusStorage = InMemoryConsensusStorage<String>;
    type EventBus = BroadcastEventBus<String>;
    type Signer = EthereumConsensusSigner;

    /// Creates a fresh in-memory consensus storage.
    fn new_storage() -> Self::ConsensusStorage {
        InMemoryConsensusStorage::<String>::new()
    }

    /// Creates an event bus in its default configuration.
    fn new_event_bus() -> Self::EventBus {
        BroadcastEventBus::<String>::default()
    }
}
/// MLS service type used by the default factory: `OpenMlsService` over a shared
/// [`MemoryDeMlsStorage`].
pub type DefaultMlsService = OpenMlsService<Arc<MemoryDeMlsStorage>>;
/// Peer-scoring service type used by the default factory, backed by
/// [`InMemoryPeerScoreStorage`].
pub type DefaultPeerScoring = PeerScoringService<InMemoryPeerScoreStorage>;
/// Steward-list type used by the default factory.
pub type DefaultStewardList = DeterministicStewardList;
/// Factory for the default conversation plugins (see its
/// `ConversationPluginsFactory` implementation in this file).
///
/// Both fields are reference-counted handles; the factory methods clone the
/// handles (not the underlying data) into each MLS service they build.
pub struct DefaultConversationPluginsFactory {
/// Storage handle passed to every MLS service and used for key-package generation.
pub(crate) storage: Arc<MemoryDeMlsStorage>,
/// Credentials handle passed to every MLS service and used for key-package generation.
pub(crate) credentials: Arc<MlsCredentials>,
}
impl DefaultConversationPluginsFactory {
/// Creates a factory that uses the given `storage` and `credentials` handles as-is.
pub fn new(storage: Arc<MemoryDeMlsStorage>, credentials: Arc<MlsCredentials>) -> Self {
Self {
storage,
credentials,
}
}
}
impl ConversationPluginsFactory for DefaultConversationPluginsFactory {
type Mls = DefaultMlsService;
type Scoring = DefaultPeerScoring;
type StewardList = DefaultStewardList;
fn create_mls(&self, conversation_id: String) -> Result<Self::Mls, MlsError> {
OpenMlsService::new_as_creator(
conversation_id,
Arc::clone(&self.storage),
Arc::clone(&self.credentials),
)
}
fn welcome_mls(&self, welcome_bytes: &[u8]) -> Result<Option<Self::Mls>, MlsError> {
OpenMlsService::new_from_welcome(
welcome_bytes,
Arc::clone(&self.storage),
Arc::clone(&self.credentials),
)
}
fn make_scoring(&self, config: &ScoringConfig) -> Self::Scoring {
PeerScoringService::new(
InMemoryPeerScoreStorage::new(),
default_score_deltas(),
config.clone(),
)
}
fn make_steward_list(
&self,
conversation_id: &[u8],
config: StewardListConfig,
) -> Self::StewardList {
DeterministicStewardList::empty(conversation_id.to_vec(), config)
}
fn generate_key_package(&self) -> Result<KeyPackageBytes, MlsError> {
OpenMlsService::<Arc<MemoryDeMlsStorage>>::generate_key_package(
&self.storage,
&self.credentials,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `get`/`set`/`remove`/`all_scores` on [`InMemoryPeerScoreStorage`].
    ///
    /// `all_scores` makes no ordering promise (it walks a `HashMap`), so the
    /// snapshots are sorted before their contents are compared.
    #[test]
    fn in_memory_storage_round_trip() {
        let mut storage = InMemoryPeerScoreStorage::new();
        assert_eq!(storage.get(b"alice"), None);
        assert!(storage.all_scores().is_empty());

        storage.set(b"alice", 42);
        assert_eq!(storage.get(b"alice"), Some(42));

        storage.set(b"bob", -3);
        let mut all = storage.all_scores();
        all.sort();
        // Check the stored pairs themselves, not just how many there are.
        assert_eq!(all, vec![(b"alice".to_vec(), 42), (b"bob".to_vec(), -3)]);

        // Setting an existing member overwrites instead of adding a second entry.
        storage.set(b"alice", 7);
        assert_eq!(storage.get(b"alice"), Some(7));
        assert_eq!(storage.all_scores().len(), 2);

        storage.remove(b"alice");
        assert_eq!(storage.get(b"alice"), None);
        assert_eq!(storage.all_scores(), vec![(b"bob".to_vec(), -3)]);

        // Removing a member that is not present leaves the table untouched.
        storage.remove(b"alice");
        assert_eq!(storage.all_scores().len(), 1);
    }
}