mod application;
mod commit;
mod create;
pub use create::EventTag;
pub(crate) mod crypto;
mod decryption;
mod error_handling;
mod process;
mod proposal;
mod validation;
use mdk_storage_traits::groups::types as group_types;
use mdk_storage_traits::groups::{MessageSortOrder, Pagination};
use mdk_storage_traits::messages::types as message_types;
use mdk_storage_traits::{GroupId, MdkStorageProvider};
use nostr::{EventId, Timestamp};
use sha2::{Digest, Sha256};
use crate::MDK;
use crate::error::Error;
use crate::groups::UpdateGroupResult;
/// Crate-internal result alias whose error type is this crate's [`Error`].
pub(crate) type Result<T> = std::result::Result<T, Error>;
/// Computes the SHA-256 digest of `content`'s UTF-8 bytes.
///
/// Returns the raw 32-byte hash (not hex-encoded).
pub(crate) fn content_hash(content: &str) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(content.as_bytes());
    hasher.finalize().into()
}
/// Builds a [`message_types::ProcessedMessage`] record stamped with the current time.
///
/// Every field except `processed_at` is copied verbatim from the arguments;
/// `processed_at` is set from [`Timestamp::now`] at the moment of the call.
pub(crate) fn create_processed_message_record(
    wrapper_event_id: EventId,
    message_event_id: Option<EventId>,
    epoch: Option<u64>,
    mls_group_id: Option<GroupId>,
    state: message_types::ProcessedMessageState,
    failure_reason: Option<String>,
) -> message_types::ProcessedMessage {
    let processed_at = Timestamp::now();
    message_types::ProcessedMessage {
        processed_at,
        wrapper_event_id,
        message_event_id,
        mls_group_id,
        epoch,
        state,
        failure_reason,
    }
}
/// Default epoch look-back count (5).
///
/// NOTE(review): presumably the number of past MLS epochs considered when
/// handling messages from older epochs — confirm against the code that uses it.
pub(crate) const DEFAULT_EPOCH_LOOKBACK: u64 = 5;
/// Metadata returned alongside a [`MessageProcessingResult`] as part of a
/// [`MessageProcessingOutcome`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MessageProcessingContext {
    /// Leaf index of the message sender, if known.
    ///
    /// NOTE(review): presumably the sender's MLS leaf index in the group tree — confirm.
    /// `None` when no sender information was attached (for example, outcomes built
    /// with `MessageProcessingOutcome::without_context`).
    pub sender_leaf_index: Option<u32>,
}
/// The full outcome of processing a message: the result plus any sender context.
#[derive(Debug)]
pub struct MessageProcessingOutcome {
    /// What processing produced.
    pub result: MessageProcessingResult,
    /// Additional metadata about the processed message (currently the sender leaf index).
    pub context: MessageProcessingContext,
}
impl MessageProcessingOutcome {
pub(crate) fn new(result: MessageProcessingResult, sender_leaf_index: Option<u32>) -> Self {
Self {
result,
context: MessageProcessingContext { sender_leaf_index },
}
}
pub(crate) fn without_context(result: MessageProcessingResult) -> Self {
Self::new(result, None)
}
}
/// The result of processing a single incoming message.
///
/// `Debug` is implemented manually for this type so that `mls_group_id` values
/// are printed as `"[REDACTED]"` instead of their real contents.
pub enum MessageProcessingResult {
    /// An application message.
    ApplicationMessage(message_types::Message),
    /// A proposal whose handling produced the contained [`UpdateGroupResult`].
    Proposal(UpdateGroupResult),
    /// A proposal for `mls_group_id` that was left pending.
    ///
    /// NOTE(review): presumably stored until a later commit — confirm in the `proposal` module.
    PendingProposal {
        /// Group the proposal belongs to.
        mls_group_id: GroupId,
    },
    /// A proposal for `mls_group_id` that was ignored.
    IgnoredProposal {
        /// Group the proposal belongs to.
        mls_group_id: GroupId,
        /// Explanation of why the proposal was ignored.
        reason: String,
    },
    /// An external join proposal for `mls_group_id`.
    ExternalJoinProposal {
        /// Group the proposal targets.
        mls_group_id: GroupId,
    },
    /// A commit for `mls_group_id`.
    Commit {
        /// Group the commit belongs to.
        mls_group_id: GroupId,
    },
    /// A message for `mls_group_id` that could not be processed.
    Unprocessable {
        /// Group the message belongs to.
        mls_group_id: GroupId,
    },
    /// The message corresponds to an earlier failed processing attempt.
    ///
    /// NOTE(review): presumably detected via a stored `ProcessedMessage` record — confirm
    /// in the `process` module.
    PreviouslyFailed,
}
impl std::fmt::Debug for MessageProcessingResult {
    /// Formats the variant for diagnostics without exposing group identifiers.
    ///
    /// Every `mls_group_id` is printed as the placeholder string `"[REDACTED]"`.
    /// For application messages only metadata (id, pubkey, kind, creation time,
    /// state) is shown — never the message content.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of any MLS group id.
        const REDACTED: &str = "[REDACTED]";

        // Renders a variant whose only visible field is the redacted group id.
        fn group_only(f: &mut std::fmt::Formatter<'_>, name: &str) -> std::fmt::Result {
            f.debug_struct(name).field("mls_group_id", &REDACTED).finish()
        }

        match self {
            Self::ApplicationMessage(message) => {
                let mut out = f.debug_struct("ApplicationMessage");
                out.field("id", &message.id);
                out.field("pubkey", &message.pubkey);
                out.field("kind", &message.kind);
                out.field("mls_group_id", &REDACTED);
                out.field("created_at", &message.created_at);
                out.field("state", &message.state);
                out.finish()
            }
            Self::Proposal(update) => f
                .debug_struct("Proposal")
                .field("evolution_event_id", &update.evolution_event.id)
                .field("mls_group_id", &REDACTED)
                .finish(),
            Self::IgnoredProposal { reason, .. } => f
                .debug_struct("IgnoredProposal")
                .field("mls_group_id", &REDACTED)
                .field("reason", reason)
                .finish(),
            Self::PendingProposal { .. } => group_only(f, "PendingProposal"),
            Self::ExternalJoinProposal { .. } => group_only(f, "ExternalJoinProposal"),
            Self::Commit { .. } => group_only(f, "Commit"),
            Self::Unprocessable { .. } => group_only(f, "Unprocessable"),
            Self::PreviouslyFailed => f.debug_struct("PreviouslyFailed").finish(),
        }
    }
}
impl<Storage> MDK<Storage>
where
Storage: MdkStorageProvider,
{
pub fn get_message(
&self,
mls_group_id: &GroupId,
event_id: &EventId,
) -> Result<Option<message_types::Message>> {
self.storage()
.find_message_by_event_id(mls_group_id, event_id)
.map_err(|_e| Error::Message("Storage error while finding message".to_string()))
}
pub fn get_messages(
&self,
mls_group_id: &GroupId,
pagination: Option<Pagination>,
) -> Result<Vec<message_types::Message>> {
self.storage()
.messages(mls_group_id, pagination)
.map_err(|_e| Error::Message("Storage error while getting messages".to_string()))
}
pub fn get_last_message(
&self,
mls_group_id: &GroupId,
sort_order: MessageSortOrder,
) -> Result<Option<message_types::Message>> {
self.storage()
.last_message(mls_group_id, sort_order)
.map_err(|_e| Error::Message("Storage error while getting last message".to_string()))
}
pub(crate) fn save_message_record(&self, message: message_types::Message) -> Result<()> {
self.storage()
.save_message(message)
.map_err(|_e| Error::Message("Storage error while saving message".to_string()))
}
pub(crate) fn save_processed_message_record(
&self,
processed_message: message_types::ProcessedMessage,
) -> Result<()> {
self.storage()
.save_processed_message(processed_message)
.map_err(|_e| {
Error::Message("Storage error while saving processed message".to_string())
})
}
pub(crate) fn save_group_record(&self, group: group_types::Group) -> Result<()> {
self.storage()
.save_group(group)
.map_err(|_e| Error::Group("Storage error while saving group".to_string()))
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use mdk_storage_traits::groups::Pagination;
    use nostr::EventId;

    use crate::test_util::*;
    use crate::tests::create_test_mdk;

    /// Looking up an unknown event id in an existing group is not an error; it yields `None`.
    #[test]
    fn test_get_message_not_found() {
        let mdk = create_test_mdk();
        let (creator, members, admins) = create_test_group_members();
        let group_id = create_test_group(&mdk, &creator, &members, &admins);
        let non_existent_event_id = EventId::all_zeros();

        // `expect` surfaces the actual error if the lookup unexpectedly fails.
        let message = mdk
            .get_message(&group_id, &non_existent_event_id)
            .expect("Looking up a missing message should succeed");
        assert!(
            message.is_none(),
            "Should return None for non-existent message"
        );
    }

    /// A freshly created group has no stored messages.
    #[test]
    fn test_get_messages_empty_group() {
        let mdk = create_test_mdk();
        let (creator, members, admins) = create_test_group_members();
        let group_id = create_test_group(&mdk, &creator, &members, &admins);

        let messages = mdk
            .get_messages(&group_id, None)
            .expect("Failed to get messages");
        assert!(messages.is_empty());
    }

    /// Limit/offset pagination yields disjoint pages, short final pages, and empty
    /// pages past the end.
    #[test]
    fn test_get_messages_with_pagination() {
        let mdk = create_test_mdk();
        let (creator, members, admins) = create_test_group_members();
        let group_id = create_test_group(&mdk, &creator, &members, &admins);

        for i in 0..15 {
            let rumor = create_test_rumor(&creator, &format!("Message {}", i));
            mdk.create_message(&group_id, rumor, None)
                .expect("Failed to create message");
        }

        let page1 = mdk
            .get_messages(&group_id, Some(Pagination::new(Some(10), Some(0))))
            .expect("Failed to get first page");
        assert_eq!(page1.len(), 10, "First page should have 10 messages");

        let page2 = mdk
            .get_messages(&group_id, Some(Pagination::new(Some(10), Some(10))))
            .expect("Failed to get second page");
        assert_eq!(page2.len(), 5, "Second page should have 5 messages");

        let page1_ids: HashSet<_> = page1.iter().map(|m| m.id).collect();
        let page2_ids: HashSet<_> = page2.iter().map(|m| m.id).collect();
        assert!(
            page1_ids.is_disjoint(&page2_ids),
            "Pages should not have duplicate messages"
        );

        let all_messages = mdk
            .get_messages(&group_id, None)
            .expect("Failed to get all messages");
        assert_eq!(
            all_messages.len(),
            15,
            "Should get all 15 messages with default pagination"
        );

        let page3 = mdk
            .get_messages(&group_id, Some(Pagination::new(Some(10), Some(20))))
            .expect("Failed to get third page");
        assert!(
            page3.is_empty(),
            "Should return empty when offset exceeds message count"
        );

        let small_page = mdk
            .get_messages(&group_id, Some(Pagination::new(Some(3), Some(0))))
            .expect("Failed to get small page");
        assert_eq!(small_page.len(), 3, "Should respect small page size");
    }

    /// Messages created in a group are returned with their content and that group's id.
    #[test]
    fn test_get_messages_for_group() {
        let mdk = create_test_mdk();
        let (creator, members, admins) = create_test_group_members();
        let group_id = create_test_group(&mdk, &creator, &members, &admins);

        let rumor1 = create_test_rumor(&creator, "First message");
        let rumor2 = create_test_rumor(&creator, "Second message");
        mdk.create_message(&group_id, rumor1, None)
            .expect("Failed to create first message");
        mdk.create_message(&group_id, rumor2, None)
            .expect("Failed to create second message");

        let messages = mdk
            .get_messages(&group_id, None)
            .expect("Failed to get messages");
        assert_eq!(messages.len(), 2);

        let contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
        assert!(contents.contains(&"First message"));
        assert!(contents.contains(&"Second message"));

        // `assert_eq!` compares by reference, so no clone of `group_id` is needed.
        for message in &messages {
            assert_eq!(message.mls_group_id, group_id);
        }
    }

    /// Querying messages of a group that was never created is reported as an error.
    #[test]
    fn test_get_messages_nonexistent_group() {
        let mdk = create_test_mdk();
        let non_existent_group_id = crate::GroupId::from_slice(&[9, 9, 9, 9]);

        let result = mdk.get_messages(&non_existent_group_id, None);
        assert!(
            result.is_err(),
            "Should return error for non-existent group"
        );
    }
}