use nostr::{EventId, Kind, Tag, TagKind, Timestamp, UnsignedEvent};
use openmls::prelude::*;
use tls_codec::Deserialize as TlsDeserialize;
use mdk_storage_traits::MdkStorageProvider;
use mdk_storage_traits::groups::types as group_types;
use mdk_storage_traits::welcomes::Pagination;
use mdk_storage_traits::welcomes::types as welcome_types;
use crate::MDK;
use crate::error::Error;
use crate::extension::NostrGroupDataExtension;
use crate::util::{ContentEncoding, decode_content};
/// Outcome of previewing a welcome message: the staged (not yet joined) MLS
/// welcome together with the Nostr group data read from its group context.
#[derive(Debug)]
pub struct WelcomePreview {
/// Staged MLS welcome built from the decoded welcome message.
pub staged_welcome: StagedWelcome,
/// Nostr group data extension extracted from the staged welcome's group context.
pub nostr_group_data: NostrGroupDataExtension,
}
/// Pairs an MLS group with its Nostr group data extension.
///
/// NOTE(review): this type is not constructed anywhere in the visible part of
/// this file; presumably it describes a group after joining — confirm against callers.
#[derive(Debug)]
pub struct JoinedGroupResult {
/// The MLS group.
pub mls_group: MlsGroup,
/// Nostr group data extension associated with `mls_group`.
pub nostr_group_data: NostrGroupDataExtension,
}
impl<Storage> MDK<Storage>
where
Storage: MdkStorageProvider,
{
pub fn get_welcome(&self, event_id: &EventId) -> Result<Option<welcome_types::Welcome>, Error> {
let welcome = self
.storage()
.find_welcome_by_event_id(event_id)
.map_err(|e| Error::Welcome(e.to_string()))?;
Ok(welcome)
}
pub fn get_pending_welcomes(
&self,
pagination: Option<Pagination>,
) -> Result<Vec<welcome_types::Welcome>, Error> {
let welcomes = self
.storage()
.pending_welcomes(pagination)
.map_err(|e| Error::Welcome(e.to_string()))?;
Ok(welcomes)
}
fn validate_welcome_event(event: &UnsignedEvent) -> Result<(), Error> {
if event.kind != Kind::MlsWelcome {
return Err(Error::InvalidWelcomeMessage);
}
let tags: Vec<&Tag> = event.tags.iter().collect();
if tags.len() < 3 {
return Err(Error::InvalidWelcomeMessage);
}
let mut has_relays = false;
let mut has_event_ref = false;
let mut has_encoding = false;
for tag in &tags {
match tag.kind() {
TagKind::Relays => {
let relay_slice = tag.as_slice();
if relay_slice.len() > 1 {
for relay_url in relay_slice.iter().skip(1) {
if nostr::RelayUrl::parse(relay_url).is_err() {
return Err(Error::InvalidWelcomeMessage);
}
}
has_relays = true;
}
}
kind if kind == TagKind::e() && tag.content().is_some_and(|c| !c.is_empty()) => {
has_event_ref = true;
}
TagKind::Client => {
match tag.content() {
Some(value) if !value.is_empty() => {}
_ => return Err(Error::InvalidWelcomeMessage),
}
}
TagKind::Custom(name) if name.as_ref() == "encoding" => {
if let Some(encoding_value) = tag.content() {
if encoding_value == "base64" {
has_encoding = true;
} else {
return Err(Error::InvalidWelcomeMessage);
}
} else {
return Err(Error::InvalidWelcomeMessage);
}
}
_ => {}
}
}
if !has_relays {
return Err(Error::InvalidWelcomeMessage);
}
if !has_event_ref {
return Err(Error::InvalidWelcomeMessage);
}
if !has_encoding {
return Err(Error::InvalidWelcomeMessage);
}
Ok(())
}
/// Processes a welcome rumor and records it, together with its group, as pending.
///
/// Steps:
/// 1. Validate the rumor's structure with `validate_welcome_event`.
/// 2. If `wrapper_event_id` already has a processed-welcome record, return the
///    stored welcome (or the recorded failure) instead of processing again.
/// 3. Preview the welcome to obtain the staged MLS welcome and Nostr group data.
/// 4. Save a `Pending` group and its relays, then the processed-welcome record
///    and the `Pending` welcome.
///
/// # Errors
/// - [`Error::InvalidWelcomeMessage`] when structural validation fails.
/// - [`Error::WelcomePreviouslyFailed`] when the wrapper event was already
///   recorded as failed.
/// - [`Error::MissingRumorEventId`] when `rumor_event.id` is `None`. This is
///   checked before any group data is written, so no group is left behind.
/// - [`Error::Welcome`] / [`Error::Group`] for storage failures, preview
///   failures, or an inconsistent processed-welcome record.
pub fn process_welcome(
    &self,
    wrapper_event_id: &EventId,
    rumor_event: &UnsignedEvent,
) -> Result<welcome_types::Welcome, Error> {
    Self::validate_welcome_event(rumor_event)?;

    // A wrapper event that was handled before is answered from storage.
    if let Some(processed_welcome) = self
        .storage()
        .find_processed_welcome_by_event_id(wrapper_event_id)
        .map_err(|e| Error::Welcome(e.to_string()))?
    {
        if processed_welcome.state == welcome_types::ProcessedWelcomeState::Failed {
            let reason = processed_welcome
                .failure_reason
                .unwrap_or_else(|| "unknown reason".to_string());
            return Err(Error::WelcomePreviouslyFailed(reason));
        }
        return match processed_welcome.welcome_event_id {
            Some(welcome_event_id) => self
                .storage()
                .find_welcome_by_event_id(&welcome_event_id)
                .map_err(|e| Error::Welcome(e.to_string()))?
                .ok_or_else(|| {
                    Error::Welcome("welcome record missing for processed welcome".to_string())
                }),
            None => Err(Error::Welcome(
                "processed welcome missing welcome_event_id".to_string(),
            )),
        };
    }

    let welcome_preview = self.preview_welcome(wrapper_event_id, rumor_event)?;

    // Resolve the rumor ID before writing anything: without it the welcome
    // cannot be stored, and a group saved first would be left orphaned.
    let rumor_event_id = rumor_event.id.ok_or(Error::MissingRumorEventId)?;

    let group = group_types::Group {
        mls_group_id: welcome_preview
            .staged_welcome
            .group_context()
            .group_id()
            .clone()
            .into(),
        nostr_group_id: welcome_preview.nostr_group_data.nostr_group_id,
        name: welcome_preview.nostr_group_data.name.clone(),
        description: welcome_preview.nostr_group_data.description.clone(),
        image_hash: welcome_preview.nostr_group_data.image_hash,
        image_key: welcome_preview
            .nostr_group_data
            .image_key
            .map(mdk_storage_traits::Secret::new),
        image_nonce: welcome_preview
            .nostr_group_data
            .image_nonce
            .map(mdk_storage_traits::Secret::new),
        admin_pubkeys: welcome_preview.nostr_group_data.admins.clone(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: welcome_preview
            .staged_welcome
            .group_context()
            .epoch()
            .as_u64(),
        state: group_types::GroupState::Pending,
        self_update_state: group_types::SelfUpdateState::Required,
    };
    let mls_group_id = group.mls_group_id.clone();

    self.storage()
        .save_group(group)
        .map_err(|e| Error::Group(e.to_string()))?;
    self.storage()
        .replace_group_relays(
            &mls_group_id,
            welcome_preview.nostr_group_data.relays.clone(),
        )
        .map_err(|e| Error::Group(e.to_string()))?;

    let processed_welcome = welcome_types::ProcessedWelcome {
        wrapper_event_id: *wrapper_event_id,
        welcome_event_id: rumor_event.id,
        processed_at: Timestamp::now(),
        state: welcome_types::ProcessedWelcomeState::Processed,
        failure_reason: None,
    };

    let welcome = welcome_types::Welcome {
        id: rumor_event_id,
        event: rumor_event.clone(),
        mls_group_id: welcome_preview
            .staged_welcome
            .group_context()
            .group_id()
            .clone()
            .into(),
        nostr_group_id: welcome_preview.nostr_group_data.nostr_group_id,
        group_name: welcome_preview.nostr_group_data.name,
        group_description: welcome_preview.nostr_group_data.description,
        group_image_hash: welcome_preview.nostr_group_data.image_hash,
        group_image_key: welcome_preview
            .nostr_group_data
            .image_key
            .map(mdk_storage_traits::Secret::new),
        group_image_nonce: welcome_preview
            .nostr_group_data
            .image_nonce
            .map(mdk_storage_traits::Secret::new),
        group_admin_pubkeys: welcome_preview.nostr_group_data.admins,
        group_relays: welcome_preview.nostr_group_data.relays,
        welcomer: rumor_event.pubkey,
        member_count: welcome_preview.staged_welcome.members().count() as u32,
        state: welcome_types::WelcomeState::Pending,
        wrapper_event_id: *wrapper_event_id,
    };

    self.storage()
        .save_processed_welcome(processed_welcome)
        .map_err(|e| Error::Welcome(e.to_string()))?;
    self.storage()
        .save_welcome(welcome.clone())
        .map_err(|e| Error::Welcome(e.to_string()))?;

    Ok(welcome)
}
pub fn accept_welcome(&self, welcome: &welcome_types::Welcome) -> Result<(), Error> {
let welcome_preview = self.preview_welcome(&welcome.wrapper_event_id, &welcome.event)?;
let mls_group = welcome_preview.staged_welcome.into_group(&self.provider)?;
let mut welcome = welcome.clone();
welcome.state = welcome_types::WelcomeState::Accepted;
self.storage()
.save_welcome(welcome)
.map_err(|e| Error::Welcome(e.to_string()))?;
if let Some(mut group) = self.get_group(&mls_group.group_id().into())? {
let mls_group_id = group.mls_group_id.clone();
group.state = group_types::GroupState::Active;
group.self_update_state = group_types::SelfUpdateState::Required;
self.storage().save_group(group).map_err(
|e: mdk_storage_traits::groups::error::GroupError| Error::Group(e.to_string()),
)?;
self.storage()
.replace_group_relays(&mls_group_id, welcome_preview.nostr_group_data.relays)
.map_err(|e| Error::Group(e.to_string()))?;
}
Ok(())
}
pub fn decline_welcome(&self, welcome: &welcome_types::Welcome) -> Result<(), Error> {
let welcome_preview = self.preview_welcome(&welcome.wrapper_event_id, &welcome.event)?;
let mls_group_id = welcome_preview.staged_welcome.group_context().group_id();
let mut welcome = welcome.clone();
welcome.state = welcome_types::WelcomeState::Declined;
self.storage()
.save_welcome(welcome)
.map_err(|e| Error::Welcome(e.to_string()))?;
if let Some(mut group) = self.get_group(&mls_group_id.into())? {
group.state = group_types::GroupState::Inactive;
self.storage()
.save_group(group)
.map_err(|e| Error::Group(e.to_string()))?;
}
Ok(())
}
fn parse_serialized_welcome(
&self,
mut welcome_message: &[u8],
) -> Result<(StagedWelcome, NostrGroupDataExtension), Error> {
let welcome_message_in = MlsMessageIn::tls_deserialize(&mut welcome_message)?;
let welcome: Welcome = match welcome_message_in.extract() {
MlsMessageBodyIn::Welcome(welcome) => welcome,
_ => return Err(Error::InvalidWelcomeMessage),
};
let sender_ratchet_config = SenderRatchetConfiguration::new(
self.config.out_of_order_tolerance,
self.config.maximum_forward_distance,
);
let mls_group_config = MlsGroupJoinConfig::builder()
.wire_format_policy(MIXED_CIPHERTEXT_WIRE_FORMAT_POLICY)
.use_ratchet_tree_extension(true)
.sender_ratchet_configuration(sender_ratchet_config)
.max_past_epochs(self.config.max_past_epochs)
.build();
let staged_welcome =
StagedWelcome::build_from_welcome(&self.provider, &mls_group_config, welcome)?
.replace_old_group()
.build()?;
let nostr_group_data =
NostrGroupDataExtension::from_group_context(staged_welcome.group_context())?;
Ok((staged_welcome, nostr_group_data))
}
fn preview_welcome(
&self,
wrapper_event_id: &EventId,
welcome_event: &UnsignedEvent,
) -> Result<WelcomePreview, Error> {
let encoding = match ContentEncoding::from_tags(welcome_event.tags.iter()) {
Some(enc) => enc,
None => {
let error_string = "Missing required encoding tag".to_string();
let processed_welcome = welcome_types::ProcessedWelcome {
wrapper_event_id: *wrapper_event_id,
welcome_event_id: welcome_event.id,
processed_at: Timestamp::now(),
state: welcome_types::ProcessedWelcomeState::Failed,
failure_reason: Some(error_string.clone()),
};
self.storage()
.save_processed_welcome(processed_welcome)
.map_err(|e| Error::Welcome(e.to_string()))?;
tracing::error!(
target: "mdk_core::welcomes::process_welcome",
"Error processing welcome: {}",
error_string
);
return Err(Error::Welcome(error_string));
}
};
let decoded_content = match decode_content(&welcome_event.content, encoding, "welcome") {
Ok((content, format)) => {
tracing::debug!(
target: "mdk_core::welcomes",
"Decoded welcome using {}", format
);
content
}
Err(e) => {
let error_string = format!(
"Error decoding welcome event content ({}): {:?}",
encoding.as_tag_value(),
e
);
let processed_welcome = welcome_types::ProcessedWelcome {
wrapper_event_id: *wrapper_event_id,
welcome_event_id: welcome_event.id,
processed_at: Timestamp::now(),
state: welcome_types::ProcessedWelcomeState::Failed,
failure_reason: Some(error_string.clone()),
};
self.storage()
.save_processed_welcome(processed_welcome)
.map_err(|e| Error::Welcome(e.to_string()))?;
tracing::error!(target: "mdk_core::welcomes::process_welcome", "Error processing welcome: {}", error_string);
return Err(Error::Welcome(error_string));
}
};
let welcome_preview = match self.parse_serialized_welcome(&decoded_content) {
Ok((staged_welcome, nostr_group_data)) => WelcomePreview {
staged_welcome,
nostr_group_data,
},
Err(e) => {
let error_string = format!("Error previewing welcome: {:?}", e);
let processed_welcome = welcome_types::ProcessedWelcome {
wrapper_event_id: *wrapper_event_id,
welcome_event_id: welcome_event.id,
processed_at: Timestamp::now(),
state: welcome_types::ProcessedWelcomeState::Failed,
failure_reason: Some(error_string.clone()),
};
self.storage()
.save_processed_welcome(processed_welcome)
.map_err(|e| Error::Welcome(e.to_string()))?;
tracing::error!(target: "mdk_core::welcomes::process_welcome", "Error processing welcome: {}", error_string);
return Err(Error::Welcome(error_string));
}
};
Ok(welcome_preview)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_util::*;
use crate::tests::create_test_mdk;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use nostr::{Keys, Kind, TagKind};
/// Welcome rumors produced by `create_group` must have the expected structure:
/// kind `MlsWelcome`, base64 content, a `relays` tag first, an `e` tag second,
/// and an `encoding` tag with value `base64`.
#[test]
fn test_welcome_event_structure_mip02_compliance() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            vec![
                create_key_package_event(&mdk, &members[0]),
                create_key_package_event(&mdk, &members[1]),
            ],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");
    assert_eq!(
        create_result.welcome_rumors.len(),
        2,
        "Should have welcome rumors for both members"
    );

    for welcome_rumor in &create_result.welcome_rumors {
        assert_eq!(
            welcome_rumor.kind,
            Kind::MlsWelcome,
            "Welcome event must have kind 444 (MlsWelcome)"
        );

        let decoded_content = BASE64
            .decode(&welcome_rumor.content)
            .expect("Welcome content must be valid base64-encoded data");
        assert!(
            decoded_content.len() > 50,
            "Welcome content should be substantial (typically > 50 bytes), got {} bytes",
            decoded_content.len()
        );

        assert!(
            welcome_rumor.tags.len() >= 3,
            "Welcome event must have at least 3 tags (relays, e, encoding)"
        );
        let tags_vec: Vec<&nostr::Tag> = welcome_rumor.tags.iter().collect();

        let relays_tag = tags_vec[0];
        assert_eq!(
            relays_tag.kind(),
            TagKind::Relays,
            "First tag must be 'relays' tag"
        );
        // The slice always starts with the tag name, so "non-empty" would be
        // vacuous; require at least one relay URL after the name.
        assert!(
            relays_tag.as_slice().len() > 1,
            "Relays tag should contain relay URLs"
        );

        let event_ref_tag = tags_vec[1];
        assert_eq!(
            event_ref_tag.kind(),
            TagKind::e(),
            "Second tag must be 'e' (event reference) tag"
        );
        // Validation rejects an empty `e` value, so the generated tag must carry one.
        assert!(
            event_ref_tag.content().is_some_and(|c| !c.is_empty()),
            "Event reference tag must have content (KeyPackage event ID)"
        );

        let encoding_tag = tags_vec
            .iter()
            .find(|t| matches!(t.kind(), TagKind::Custom(name) if name.as_ref() == "encoding"))
            .expect("Welcome event must have an 'encoding' tag");
        let encoding_value = encoding_tag
            .content()
            .expect("Encoding tag must have a value");
        assert!(
            !encoding_value.is_empty(),
            "Encoding tag value must be non-empty"
        );
        assert_eq!(
            encoding_value, "base64",
            "Encoding tag value must be 'base64'"
        );

        assert!(
            welcome_rumor.id.is_some(),
            "Welcome rumor should have ID computed"
        );
    }
}
/// Structurally invalid welcome events must be rejected by `process_welcome`
/// with `Error::InvalidWelcomeMessage`.
#[test]
fn test_welcome_validation_rejects_invalid_events() {
    use nostr::RelayUrl;

    let mdk = create_test_mdk();
    let wrapper_event_id = EventId::all_zeros();

    // Builders for the tags the individual cases are assembled from.
    let good_relays = || {
        nostr::Tag::relays(vec![
            RelayUrl::parse("wss://relay.example.com").unwrap(),
        ])
    };
    let event_ref = || nostr::Tag::event(EventId::all_zeros());
    let raw = |name: &str, value: &str| {
        nostr::Tag::parse(&[name.to_string(), value.to_string()]).unwrap()
    };

    // Builds an event from `kind` and `tag_list` and asserts that it is rejected.
    let assert_rejected = |kind: Kind, tag_list: Vec<nostr::Tag>, message: &str| {
        let mut tags = nostr::Tags::new();
        for tag in tag_list {
            tags.push(tag);
        }
        let event = UnsignedEvent {
            id: None,
            pubkey: Keys::generate().public_key(),
            created_at: Timestamp::now(),
            kind,
            tags,
            content: "test".to_string(),
        };
        let result = mdk.process_welcome(&wrapper_event_id, &event);
        assert!(result.is_err(), "{}", message);
        assert!(matches!(result.unwrap_err(), Error::InvalidWelcomeMessage));
    };

    assert_rejected(
        Kind::TextNote,
        vec![good_relays(), event_ref(), raw("encoding", "base64")],
        "Should reject wrong kind",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![good_relays()],
        "Should reject missing tags",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![
            good_relays(),
            event_ref(),
            nostr::Tag::client("mdk".to_string()),
        ],
        "Should reject missing encoding tag",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![
            nostr::Tag::relays(vec![]),
            event_ref(),
            raw("encoding", "base64"),
        ],
        "Should reject empty relays tag",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![
            raw("relays", "http://invalid.com"),
            event_ref(),
            raw("encoding", "base64"),
        ],
        "Should reject invalid relay URL format",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![
            raw("relays", "wss://"),
            event_ref(),
            raw("encoding", "base64"),
        ],
        "Should reject incomplete relay URL",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![good_relays(), raw("e", ""), raw("encoding", "base64")],
        "Should reject empty e tag content",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![good_relays(), event_ref(), raw("encoding", "invalid")],
        "Should reject invalid encoding value",
    );
    assert_rejected(
        Kind::MlsWelcome,
        vec![
            good_relays(),
            event_ref(),
            raw("client", ""),
            raw("encoding", "base64"),
        ],
        "Should reject empty client tag content when present",
    );
}
/// A welcome whose `client` tag has been stripped must still be processed
/// successfully and end up in the `Pending` state.
#[test]
fn test_welcome_validation_accepts_missing_client_tag() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let key_package_event = create_key_package_event(&mdk, &members[0]);
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            vec![key_package_event],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");
    let original_welcome = &create_result.welcome_rumors[0];

    // Copy every tag except the client tag.
    let mut tags_without_client = nostr::Tags::new();
    for tag in original_welcome
        .tags
        .iter()
        .filter(|tag| tag.kind() != TagKind::Client)
    {
        tags_without_client.push(tag.clone());
    }
    assert!(
        tags_without_client.len() < original_welcome.tags.len(),
        "Should have fewer tags after removing client tag"
    );

    let welcome_without_client = UnsignedEvent {
        id: original_welcome.id,
        pubkey: original_welcome.pubkey,
        created_at: original_welcome.created_at,
        kind: original_welcome.kind,
        tags: tags_without_client,
        content: original_welcome.content.clone(),
    };

    let wrapper_event_id = EventId::all_zeros();
    let welcome = match mdk.process_welcome(&wrapper_event_id, &welcome_without_client) {
        Ok(welcome) => welcome,
        Err(err) => panic!(
            "Welcome without client tag should be accepted, got: {:?}",
            err
        ),
    };
    assert_eq!(welcome.state, welcome_types::WelcomeState::Pending);
}
/// The content of a generated welcome rumor must be base64 that decodes to a
/// payload larger than 50 bytes.
#[test]
fn test_welcome_content_validation_mip02() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let key_packages = vec![create_key_package_event(&mdk, &members[0])];
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            key_packages,
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");

    let rumor = &create_result.welcome_rumors[0];
    let payload = BASE64
        .decode(&rumor.content)
        .expect("Welcome content should be valid base64");

    assert!(
        payload.len() > 50,
        "MLS Welcome messages should be substantial in size"
    );
    assert!(!payload.is_empty(), "Decoded welcome should not be empty");
}
/// Each welcome rumor's `e` tag must reference one of the key package events
/// the group was created with.
#[test]
fn test_welcome_references_correct_keypackage() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();

    let first_kp_event = create_key_package_event(&mdk, &members[0]);
    let second_kp_event = create_key_package_event(&mdk, &members[1]);
    let first_kp_id = first_kp_event.id;
    let second_kp_id = second_kp_event.id;

    let create_result = mdk
        .create_group(
            &creator.public_key(),
            vec![first_kp_event, second_kp_event],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");
    assert_eq!(
        create_result.welcome_rumors.len(),
        2,
        "Should have 2 welcome rumors"
    );

    // Collect the event ID (hex) referenced by each welcome's `e` tag.
    let referenced_ids: Vec<String> = create_result
        .welcome_rumors
        .iter()
        .map(|rumor| {
            rumor
                .tags
                .iter()
                .find(|t| t.kind() == TagKind::e())
                .expect("Welcome should have e tag")
                .content()
                .expect("e tag should have content")
                .to_string()
        })
        .collect();

    assert!(
        referenced_ids.contains(&first_kp_id.to_hex()),
        "Welcome should reference first KeyPackage event"
    );
    assert!(
        referenced_ids.contains(&second_kp_id.to_hex()),
        "Welcome should reference second KeyPackage event"
    );
}
/// Creating a group with three invitees must yield three welcome rumors, each
/// of kind `MlsWelcome`, with at least three tags and base64 content.
#[test]
fn test_multiple_welcomes_for_multiple_members() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let third_member = Keys::generate();

    let key_package_events = vec![
        create_key_package_event(&mdk, &members[0]),
        create_key_package_event(&mdk, &members[1]),
        create_key_package_event(&mdk, &third_member),
    ];
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            key_package_events,
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");

    assert_eq!(
        create_result.welcome_rumors.len(),
        3,
        "Should have welcome rumors for all 3 members"
    );
    for rumor in create_result.welcome_rumors.iter() {
        assert_eq!(rumor.kind, Kind::MlsWelcome);
        assert!(rumor.tags.len() >= 3);
        assert!(
            BASE64.decode(&rumor.content).is_ok(),
            "Welcome content should be valid base64"
        );
    }
}
/// The `relays` tag of a generated welcome must hold the tag name followed by
/// at least one `ws://` or `wss://` URL.
#[test]
fn test_welcome_relays_tag_content() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let key_packages = vec![create_key_package_event(&mdk, &members[0])];
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            key_packages,
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");

    let rumor = &create_result.welcome_rumors[0];
    let relay_parts = rumor
        .tags
        .iter()
        .find(|t| t.kind() == TagKind::Relays)
        .expect("Welcome should have relays tag")
        .as_slice();

    assert!(
        relay_parts.len() > 1,
        "Relays tag should have at least tag name and one relay"
    );
    assert_eq!(
        relay_parts[0], "relays",
        "First element should be 'relays'"
    );
    for relay in &relay_parts[1..] {
        assert!(
            relay.starts_with("wss://") || relay.starts_with("ws://"),
            "Relay URLs should start with wss:// or ws://, got: {}",
            relay
        );
    }
}
/// Processing a freshly generated welcome must yield a `Pending` welcome tied
/// to the given wrapper event ID, with at least two members.
#[test]
fn test_welcome_processing_flow() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let key_package_event = create_key_package_event(&mdk, &members[0]);
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            vec![key_package_event],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");

    let rumor = &create_result.welcome_rumors[0];
    let wrapper_event_id = EventId::all_zeros();
    let processed = mdk
        .process_welcome(&wrapper_event_id, rumor)
        .expect("Failed to process welcome");

    assert_eq!(processed.state, welcome_types::WelcomeState::Pending);
    assert_eq!(processed.wrapper_event_id, wrapper_event_id);
    assert!(
        processed.member_count >= 2,
        "Group should have at least 2 members (creator + member)"
    );

    // Sanity checks on the rumor that was processed.
    assert_eq!(rumor.kind, Kind::MlsWelcome, "Welcome should be kind 444");
    assert!(
        rumor.tags.len() >= 3,
        "Welcome should have at least 3 required tags"
    );
}
/// A welcome produced by `create_group` and one produced later by
/// `add_members` must share the same kind, tag count, leading tag kinds, and
/// base64-encoded content.
#[test]
fn test_welcome_structure_consistency() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            vec![create_key_package_event(&mdk, &members[0])],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");
    let group_id = create_result.group.mls_group_id.clone();
    let creation_welcome = &create_result.welcome_rumors[0];

    // The creation commit is merged before another member is added.
    mdk.merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");

    let late_member = Keys::generate();
    let add_result = mdk
        .add_members(&group_id, &[create_key_package_event(&mdk, &late_member)])
        .expect("Failed to add member");
    let addition_welcome = &add_result
        .welcome_rumors
        .as_ref()
        .expect("Should have welcome rumors")[0];

    assert_eq!(creation_welcome.kind, addition_welcome.kind);
    assert_eq!(creation_welcome.tags.len(), addition_welcome.tags.len());

    let creation_tags: Vec<&nostr::Tag> = creation_welcome.tags.iter().collect();
    let addition_tags: Vec<&nostr::Tag> = addition_welcome.tags.iter().collect();
    assert_eq!(creation_tags[0].kind(), addition_tags[0].kind());
    assert_eq!(creation_tags[1].kind(), addition_tags[1].kind());

    assert!(
        BASE64.decode(&creation_welcome.content).is_ok(),
        "First welcome should be valid base64"
    );
    assert!(
        BASE64.decode(&addition_welcome.content).is_ok(),
        "Second welcome should be valid base64"
    );
}
/// A welcome that fails on an instance without the matching key material must
/// still be processable, and acceptable, on the instance that created the key
/// package.
#[test]
fn test_welcome_processing_error_recovery() {
    let alice_keys = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_keys = Keys::generate();
    // Two separate instances for Bob; only `bob_device_a` creates the key package.
    let bob_device_a = create_test_mdk();
    let bob_device_b = create_test_mdk();

    let bob_key_package_event = create_key_package_event(&bob_device_a, &bob_keys);
    let group_config = create_nostr_group_config_data(vec![alice_keys.public_key()]);
    let group_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package_event],
            group_config,
        )
        .expect("Failed to create group");
    alice_mdk
        .merge_pending_commit(&group_result.group.mls_group_id)
        .expect("Failed to merge pending commit");
    let welcome = &group_result.welcome_rumors[0];

    // Step 1: the instance that did not create the key package must fail.
    let wrong_device_error = bob_device_b
        .process_welcome(&nostr::EventId::all_zeros(), welcome)
        .expect_err("Processing welcome without signing key should fail")
        .to_string();
    assert!(
        wrong_device_error.contains("key")
            || wrong_device_error.contains("Key")
            || wrong_device_error.contains("storage"),
        "Error message should mention key/storage issue: {}",
        wrong_device_error
    );

    // Step 2: the same welcome with replaced tags; either outcome is tolerated,
    // but an error must carry a non-empty message.
    let mut replacement_tags = nostr::Tags::new();
    replacement_tags.push(nostr::Tag::relays(vec![
        nostr::RelayUrl::parse("wss://test.relay").unwrap(),
    ]));
    replacement_tags.push(nostr::Tag::event(nostr::EventId::all_zeros()));
    replacement_tags.push(nostr::Tag::custom(
        nostr::TagKind::Custom("encoding".into()),
        ["base64"],
    ));
    let mut modified_welcome = welcome.clone();
    modified_welcome.tags = replacement_tags;

    if let Err(error) =
        bob_device_a.process_welcome(&nostr::EventId::all_zeros(), &modified_welcome)
    {
        assert!(
            !error.to_string().is_empty(),
            "Error message should not be empty"
        );
    }

    // Step 3: the unmodified welcome on the correct instance must succeed.
    let correct_device_result =
        bob_device_a.process_welcome(&nostr::EventId::all_zeros(), welcome);
    assert!(
        correct_device_result.is_ok(),
        "Processing welcome with correct signing key should succeed"
    );

    let pending_welcomes = bob_device_a
        .get_pending_welcomes(None)
        .expect("Failed to get pending welcomes");
    assert!(
        !pending_welcomes.is_empty(),
        "Should have pending welcomes after successful processing"
    );

    bob_device_a
        .accept_welcome(&pending_welcomes[0])
        .expect("Failed to accept welcome");
    let bob_groups = bob_device_a
        .get_groups()
        .expect("Failed to get Bob's groups");
    assert_eq!(
        bob_groups.len(),
        1,
        "Bob should have joined the group after successful welcome processing"
    );
}
/// For groups of 5, 10, and 20 invitees, each invitee gets a welcome and the
/// decoded welcome stays under 100 KB.
#[test]
fn test_large_group_welcome_size_limits() {
    let alice_keys = Keys::generate();
    let alice_mdk = create_test_mdk();

    for &group_size in &[5usize, 10, 20] {
        // Generate one fresh identity and key package per invitee.
        let mut members = Vec::new();
        let mut key_package_events = Vec::new();
        for _ in 0..group_size {
            let member_keys = Keys::generate();
            key_package_events.push(create_key_package_event(&alice_mdk, &member_keys));
            members.push(member_keys);
        }

        let group_config = create_nostr_group_config_data(vec![alice_keys.public_key()]);
        let group_result = alice_mdk
            .create_group(&alice_keys.public_key(), key_package_events, group_config)
            .unwrap_or_else(|_| panic!("Failed to create group with {} members", group_size));
        assert_eq!(
            group_result.welcome_rumors.len(),
            group_size,
            "Should have one welcome per member"
        );

        let welcome = &group_result.welcome_rumors[0];
        let decoded_bytes: Vec<u8> = BASE64
            .decode(&welcome.content)
            .expect("Welcome content should be valid base64");
        let binary_size = decoded_bytes.len();
        let size_kb = binary_size as f64 / 1024.0;
        println!(
            "Group size: {} members, Welcome size: {} bytes ({:.2} KB)",
            group_size, binary_size, size_kb
        );

        assert!(
            BASE64.decode(&welcome.content).is_ok(),
            "Welcome content should be valid base64"
        );
        if group_size <= 20 {
            assert!(
                size_kb < 100.0,
                "Welcome for {} members should be under 100KB, got {:.2} KB",
                group_size,
                size_kb
            );
        }

        assert_eq!(welcome.kind, Kind::MlsWelcome);
        assert!(
            welcome.tags.len() >= 3,
            "Welcome should have at least 3 required tags"
        );
    }
}
/// A structurally valid welcome whose content is not a usable welcome payload
/// must fail while decoding or previewing, not during tag validation.
#[test]
fn test_process_welcome_invalid_message() {
    let mdk = create_test_mdk();

    // Supply the required tags so structural validation passes and the content
    // itself is what gets rejected. With no tags the event would be turned away
    // with `InvalidWelcomeMessage` before the content is ever looked at.
    let mut tags = nostr::Tags::new();
    tags.push(nostr::Tag::relays(vec![
        nostr::RelayUrl::parse("wss://relay.example.com").unwrap(),
    ]));
    tags.push(nostr::Tag::event(nostr::EventId::all_zeros()));
    tags.push(nostr::Tag::custom(
        nostr::TagKind::Custom("encoding".into()),
        ["base64"],
    ));

    let invalid_welcome = nostr::UnsignedEvent {
        id: Some(nostr::EventId::all_zeros()),
        pubkey: Keys::generate().public_key(),
        created_at: nostr::Timestamp::now(),
        kind: Kind::MlsWelcome,
        tags,
        content: "invalid_base64_content!!!".to_string(),
    };

    let result = mdk.process_welcome(&nostr::EventId::all_zeros(), &invalid_welcome);
    assert!(
        result.is_err(),
        "Should fail when welcome content is invalid base64"
    );
    // Decode and preview failures are both reported as `Error::Welcome`.
    assert!(
        matches!(result, Err(Error::Welcome(_))),
        "Failure should come from decoding/previewing the content, not tag validation"
    );
}
/// Processing an otherwise valid welcome whose rumor ID is `None` must fail
/// with `Error::MissingRumorEventId`.
#[test]
fn test_process_welcome_missing_rumor_id() {
    let mdk = create_test_mdk();
    let (creator, members, admins) = create_test_group_members();
    let key_package_event = create_key_package_event(&mdk, &members[0]);
    let create_result = mdk
        .create_group(
            &creator.public_key(),
            vec![key_package_event],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");

    // Strip the ID from an otherwise valid rumor.
    let mut rumor_without_id = create_result.welcome_rumors[0].clone();
    rumor_without_id.id = None;

    let outcome = mdk.process_welcome(&nostr::EventId::all_zeros(), &rumor_without_id);
    assert!(
        outcome.is_err(),
        "Should return error when rumor event ID is missing"
    );
    assert_eq!(
        outcome.unwrap_err(),
        crate::error::Error::MissingRumorEventId,
        "Error should be MissingRumorEventId"
    );
}
/// A freshly created test MDK instance reports no pending welcomes.
#[test]
fn test_get_pending_welcomes_empty() {
    let mdk = create_test_mdk();

    // No pagination: ask for whatever the storage returns by default.
    let pending_count = mdk
        .get_pending_welcomes(None)
        .expect("Should succeed")
        .len();

    assert_eq!(
        pending_count,
        0,
        "Should have no pending welcomes initially"
    );
}
/// Accepting a hand-built welcome that this MDK instance never processed
/// must return an error.
#[test]
fn test_accept_nonexistent_welcome() {
    use std::collections::BTreeSet;

    let mdk = create_test_mdk();

    // Placeholder rumor; its content is not a real MLS welcome message.
    let placeholder_event = nostr::UnsignedEvent {
        id: Some(nostr::EventId::all_zeros()),
        pubkey: Keys::generate().public_key(),
        created_at: nostr::Timestamp::now(),
        kind: Kind::MlsWelcome,
        tags: nostr::Tags::new(),
        content: "fake".to_string(),
    };

    // Welcome record assembled by hand rather than via `process_welcome`.
    let fabricated_welcome = welcome_types::Welcome {
        id: nostr::EventId::all_zeros(),
        event: placeholder_event,
        mls_group_id: crate::GroupId::from_slice(&[1, 2, 3, 4]),
        nostr_group_id: [0u8; 32],
        group_name: "Fake Group".to_string(),
        group_description: "Fake Description".to_string(),
        group_image_hash: None,
        group_image_key: None,
        group_image_nonce: None,
        group_admin_pubkeys: BTreeSet::new(),
        group_relays: BTreeSet::new(),
        welcomer: Keys::generate().public_key(),
        member_count: 2,
        state: welcome_types::WelcomeState::Pending,
        wrapper_event_id: nostr::EventId::all_zeros(),
    };

    let outcome = mdk.accept_welcome(&fabricated_welcome);
    assert!(
        outcome.is_err(),
        "Should fail when accepting non-existent welcome"
    );
}
/// Calling `leave_group` from an MDK instance that did not create or join
/// the group must return an error.
#[test]
fn test_leave_group() {
    use crate::test_util::{create_test_group, create_test_group_members};

    // The group exists only in the creator's MDK instance.
    let (creator, members, admins) = create_test_group_members();
    let owner_mdk = create_test_mdk();
    let group_id = create_test_group(&owner_mdk, &creator, &members, &admins);

    // A separate, fresh instance tries to leave that group.
    let outsider_mdk = create_test_mdk();
    assert!(
        outsider_mdk.leave_group(&group_id).is_err(),
        "Should fail when leaving a group you haven't joined"
    );
}
/// Exercises `get_pending_welcomes` with and without explicit pagination
/// after a single welcome has been processed.
///
/// Covers: default pagination, a generous limit from offset 0, an offset past
/// the end of the data, and a limit of exactly one.
#[test]
fn test_get_pending_welcomes_with_pagination() {
    use crate::test_util::{create_key_package_event, create_nostr_group_config_data};
    use nostr::Keys;

    let mdk = create_test_mdk();
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();

    // Alice creates a group that invites Bob, which yields one welcome rumor.
    let bob_key_package = create_key_package_event(&mdk, &bob_keys);
    let config = create_nostr_group_config_data(vec![alice_keys.public_key()]);
    let created = mdk
        .create_group(&alice_keys.public_key(), vec![bob_key_package], config)
        .expect("Failed to create group");
    mdk.merge_pending_commit(&created.group.mls_group_id)
        .expect("Failed to merge pending commit");

    // Process the welcome so that it is stored as pending.
    mdk.process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[0])
        .expect("Failed to process welcome");

    // Default pagination.
    let with_defaults = mdk
        .get_pending_welcomes(None)
        .expect("Failed to get welcomes");
    assert_eq!(with_defaults.len(), 1, "Should have 1 pending welcome");

    // Explicit limit and offset that include the single welcome.
    let first_page = mdk
        .get_pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
        .expect("Failed to get paginated welcomes");
    assert_eq!(
        first_page.len(),
        1,
        "Should have 1 welcome with pagination"
    );

    // Offset beyond the available data.
    let past_the_end = mdk
        .get_pending_welcomes(Some(Pagination::new(Some(10), Some(100))))
        .expect("Failed to get empty page");
    assert_eq!(
        past_the_end.len(),
        0,
        "Should return empty when offset is beyond available welcomes"
    );

    // Limit of exactly one.
    let single_item_page = mdk
        .get_pending_welcomes(Some(Pagination::new(Some(1), Some(0))))
        .expect("Failed to get limited welcomes");
    assert_eq!(
        single_item_page.len(),
        1,
        "Should return exactly 1 welcome with limit 1"
    );
}
/// Processing the same invalid welcome twice under one wrapper event ID:
/// the first attempt fails with a decoding-related `Error::Welcome`, and the
/// second fails with `Error::WelcomePreviouslyFailed` whose reason still
/// mentions the original decoding problem.
#[test]
fn test_failed_welcome_retry_returns_original_error() {
    use nostr::RelayUrl;

    let mdk = create_test_mdk();
    let wrapper_event_id = EventId::from_slice(&[1u8; 32]).unwrap();

    // Relays, event reference, client and encoding tags, so the rumor carries
    // more than a bare payload; only the content itself is broken.
    let mut rumor_tags = nostr::Tags::new();
    rumor_tags.push(nostr::Tag::relays(vec![
        RelayUrl::parse("wss://relay.example.com").unwrap(),
    ]));
    rumor_tags.push(nostr::Tag::event(EventId::all_zeros()));
    rumor_tags.push(nostr::Tag::client("mdk".to_string()));
    rumor_tags.push(nostr::Tag::custom(
        nostr::TagKind::Custom("encoding".into()),
        ["base64"],
    ));

    let broken_rumor = UnsignedEvent {
        id: Some(EventId::all_zeros()),
        pubkey: Keys::generate().public_key(),
        created_at: Timestamp::now(),
        kind: Kind::MlsWelcome,
        tags: rumor_tags,
        content: "not_valid_base64!!!".to_string(),
    };

    // First attempt: expected to fail while decoding the content.
    let first_result = mdk.process_welcome(&wrapper_event_id, &broken_rumor);
    assert!(first_result.is_err(), "First attempt should fail");
    let first_error = first_result.unwrap_err();
    assert!(
        matches!(first_error, Error::Welcome(ref msg) if msg.contains("decoding")),
        "First error should be about decoding, got: {:?}",
        first_error
    );

    // Second attempt with the same wrapper ID: expected to report the earlier failure.
    let second_result = mdk.process_welcome(&wrapper_event_id, &broken_rumor);
    assert!(second_result.is_err(), "Second attempt should also fail");
    let second_error = second_result.unwrap_err();

    let Error::WelcomePreviouslyFailed(reason) = &second_error else {
        panic!(
            "Expected WelcomePreviouslyFailed error, got: {:?}",
            second_error
        );
    };
    assert!(
        reason.contains("decoding"),
        "Failure reason should contain original error about decoding, got: {}",
        reason
    );
}
}