use std::{
collections::{BTreeMap, HashMap},
task::Poll,
};
use futures::{
future::{self, Ready},
task::Context,
};
use tower::Service;
use tracing::{debug, trace, warn};
use crate::services::messages::{multipart::buffer::MultipartMessageBuffer, ServiceError};
use xaynet_core::{
crypto::{PublicEncryptKey, PublicSigningKey},
message::{Chunk, DecodeError, FromBytes, Message, Payload, Sum, Sum2, Tag, Update},
};
/// Accumulates the chunks of a single multipart message until it can be reassembled.
///
/// Chunks may arrive in any order. They are stored keyed by chunk id, so the payloads
/// handed to the parser are ordered by ascending id.
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct MessageBuilder {
    /// Signing key of the participant that sent the chunks; copied into the final message.
    participant_pk: PublicSigningKey,
    /// Coordinator key copied into the final message.
    coordinator_pk: PublicEncryptKey,
    /// Tag of the message being reassembled; selects which payload parser is used.
    tag: Tag,
    /// Id of the chunk flagged as `last`, once that chunk has been received.
    ///
    /// Invariant (maintained by `add_chunk`): while this is `Some(last)`, `data` holds
    /// no chunk whose id is greater than `last`.
    last_chunk_id: Option<u16>,
    /// Received chunk payloads, keyed (and therefore ordered) by chunk id.
    data: BTreeMap<u16, Vec<u8>>,
}

impl MessageBuilder {
    /// Creates an empty builder for a message with the given tag and keys.
    fn new(tag: Tag, participant_pk: PublicSigningKey, coordinator_pk: PublicEncryptKey) -> Self {
        MessageBuilder {
            tag,
            participant_pk,
            coordinator_pk,
            data: BTreeMap::new(),
            last_chunk_id: None,
        }
    }

    /// Returns `true` once every chunk with an id in `0..=last_chunk_id` has been received.
    ///
    /// Returns `false` as long as the chunk flagged `last` has not been seen.
    fn has_all_chunks(&self) -> bool {
        match self.last_chunk_id {
            // `add_chunk` guarantees every stored id is `<= last`. `BTreeMap` keys are
            // distinct, so holding exactly `last + 1` entries means none of `0..=last` is
            // missing. A plain `len() >= last + 1` check would accept a gap that happens
            // to be "filled" by a stray chunk with a larger id.
            Some(last) => self.data.len() == usize::from(last) + 1,
            None => false,
        }
    }

    /// Stores `chunk`, recording its id as the final one when it is flagged `last`.
    ///
    /// A chunk repeating an already stored id replaces the earlier payload.
    ///
    /// Chunks cannot belong to the message if their id lies beyond the last chunk:
    /// - such a chunk arriving after the last one is ignored;
    /// - such chunks stored before the last one arrived are dropped when it does.
    ///
    /// This keeps stray chunks out of the reassembled bytes and keeps
    /// `has_all_chunks` accurate.
    fn add_chunk(&mut self, chunk: Chunk) {
        let Chunk { id, last, data, .. } = chunk;
        if let Some(last_id) = self.last_chunk_id {
            if id > last_id {
                // Not part of this message; ignore it instead of corrupting the buffer.
                return;
            }
        }
        if last {
            // Discard anything stored beyond the final chunk. `checked_add` handles
            // `id == u16::MAX`, where no larger id can exist.
            if let Some(first_excess) = id.checked_add(1) {
                drop(self.data.split_off(&first_excess));
            }
            self.last_chunk_id = Some(id);
        }
        self.data.insert(id, data);
    }

    /// Consumes the builder and parses the collected chunk payloads into a [`Message`].
    ///
    /// The chunks are wrapped in a `MultipartMessageBuffer` (presumably read back-to-back
    /// in ascending id order) and decoded with the parser matching `tag`. The resulting
    /// message has no signature and `is_multipart` set to `false`.
    ///
    /// # Errors
    ///
    /// Returns the `DecodeError` produced by the payload parser if the bytes are invalid.
    fn into_message(self) -> Result<Message, DecodeError> {
        let mut bytes = MultipartMessageBuffer::from(self.data);
        let payload = match self.tag {
            Tag::Sum => Sum::from_byte_stream(&mut bytes).map(Into::into)?,
            Tag::Update => Update::from_byte_stream(&mut bytes).map(Into::into)?,
            Tag::Sum2 => Sum2::from_byte_stream(&mut bytes).map(Into::into)?,
        };
        let message = Message {
            signature: None,
            participant_pk: self.participant_pk,
            coordinator_pk: self.coordinator_pk,
            tag: self.tag,
            is_multipart: false,
            payload,
        };
        Ok(message)
    }
}
/// Key under which a partially received multipart message is tracked.
///
/// The participant-chosen message id is scoped by the participant's signing key,
/// so two participants may use the same message id without their chunks mixing.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MessageId {
    /// Id shared by every chunk of the message (taken from `Chunk::message_id`).
    message_id: u16,
    /// Signing key of the participant that sent the chunks.
    participant_pk: PublicSigningKey,
}
/// Tower service that reassembles multipart messages from their individual chunks.
///
/// Partially received messages are kept in memory, keyed by [`MessageId`], until
/// all of their chunks have arrived.
// NOTE(review): nothing in this module evicts builders for messages that never
// complete, so abandoned messages stay in memory — consider a timeout/size bound.
#[derive(Default)]
pub struct MultipartHandler {
    /// Builders for multipart messages that are still missing chunks.
    message_builders: HashMap<MessageId, MessageBuilder>,
}

impl MultipartHandler {
    /// Creates a handler with no partially received messages.
    ///
    /// Equivalent to [`MultipartHandler::default`].
    // NOTE(review): `dead_code` is presumably allowed because the handler is only
    // constructed from tests so far — confirm and remove once it is wired up.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
}
impl Service<Message> for MultipartHandler {
    type Response = Option<Message>;
    type Error = ServiceError;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    /// Always ready: the handler only works on its own in-memory state.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs one incoming message through the reassembly logic.
    ///
    /// - A message without the multipart flag is passed through as `Some(message)`.
    /// - A multipart chunk is added to the builder for its [`MessageId`]. The response is
    ///   `None` while chunks are still missing.
    /// - Once the builder reports all chunks present, it is removed and parsed. The
    ///   response is `Some(message)` on success, or `ServiceError::Parsing` if decoding
    ///   fails.
    ///
    /// # Panics
    ///
    /// Panics if the multipart flag is set but the payload is not a `Payload::Chunk`.
    fn call(&mut self, message: Message) -> Self::Future {
        if !message.is_multipart {
            trace!("message is not multipart, nothing to do");
            return ready_ok(Some(message));
        }
        debug!("handling multipart message");

        let (tag, participant_pk, coordinator_pk, chunk) = match message {
            Message {
                tag,
                participant_pk,
                coordinator_pk,
                payload: Payload::Chunk(chunk),
                ..
            } => (tag, participant_pk, coordinator_pk, chunk),
            _ => panic!("multipart flag is set but payload is not a multipart message"),
        };

        let key = MessageId {
            message_id: chunk.message_id,
            participant_pk,
        };
        let builder = self
            .message_builders
            .entry(key.clone())
            .or_insert_with(|| {
                debug!("new multipart message (id = {})", key.message_id);
                MessageBuilder::new(tag, participant_pk, coordinator_pk)
            });
        builder.add_chunk(chunk);
        if !builder.has_all_chunks() {
            return ready_ok(None);
        }

        debug!("received the final message chunk, now parsing the full message");
        // The entry was found or inserted just above, so it is guaranteed to be present.
        let complete = self.message_builders.remove(&key).unwrap();
        match complete.into_message() {
            Ok(message) => {
                debug!("multipart message succesfully parsed");
                ready_ok(Some(message))
            }
            Err(e) => {
                warn!("invalid multipart message: {}", e);
                ready_err(ServiceError::Parsing(e))
            }
        }
    }
}
/// Returns an already-completed future that resolves to `Ok(t)`.
fn ready_ok<T, E>(t: T) -> Ready<Result<T, E>> {
    future::ok(t)
}

/// Returns an already-completed future that resolves to `Err(e)`.
fn ready_err<T, E>(e: E) -> Ready<Result<T, E>> {
    future::err(e)
}
// Unit tests for `MessageBuilder` and for the `MultipartHandler` service.
#[cfg(test)]
mod tests {
    use std::iter;
    use tokio_test::assert_ready;
    use tower_test::mock::Spawn;
    use xaynet_core::crypto::{ByteObject, PublicEncryptKey, Signature};
    use super::*;

    // Wraps a fresh `MultipartHandler` in tower-test's `Spawn` so the tests can
    // poll and call it directly.
    fn spawn_svc() -> Spawn<MultipartHandler> {
        Spawn::new(MultipartHandler::new())
    }

    // Builds a `Sum` payload together with the bytes it is made from.
    // The bytes are 0, 1, 2, ... (starting from 0xff and wrapping on the first
    // increment) with length `PublicEncryptKey::LENGTH + Signature::LENGTH`.
    // The first `Signature::LENGTH` bytes become the signature and the remainder
    // the ephemeral key. The builder tests rely on this presumably matching the
    // serialized layout of `Sum`.
    fn sum() -> (Vec<u8>, Sum) {
        let mut start_byte: u8 = 0xff;
        let f = move || {
            start_byte = start_byte.wrapping_add(1);
            Some(start_byte)
        };
        let bytes: Vec<u8> = iter::from_fn(f)
            .take(PublicEncryptKey::LENGTH + Signature::LENGTH)
            .collect();
        let sum = Sum {
            sum_signature: Signature::from_slice(&bytes[..Signature::LENGTH]).unwrap(),
            ephm_pk: PublicEncryptKey::from_slice(&bytes[Signature::LENGTH..]).unwrap(),
        };
        (bytes, sum)
    }

    // An empty `Sum` builder with all-zero participant and coordinator keys.
    fn message_builder() -> MessageBuilder {
        let participant_pk = PublicSigningKey::zeroed();
        let coordinator_pk = PublicEncryptKey::zeroed();
        let tag = Tag::Sum;
        MessageBuilder::new(tag, participant_pk, coordinator_pk)
    }

    // Splits the 96 bytes produced by `sum()` into five chunks of 1, 2, 3, 4 and
    // 86 bytes, with chunk ids 0..=4, all sharing message id 1234. Only the
    // fifth chunk (id 4) is flagged `last`. Concatenating the chunks in id order
    // yields the input again.
    fn chunks(mut data: Vec<u8>) -> (Chunk, Chunk, Chunk, Chunk, Chunk) {
        assert_eq!(data.len(), 96);
        let data5 = data.split_off(10);
        assert_eq!(data5.len(), 86);
        assert_eq!(data.len(), 10);
        let data4 = data.split_off(6);
        assert_eq!(data4.len(), 4);
        assert_eq!(data.len(), 6);
        let data3 = data.split_off(3);
        assert_eq!(data3.len(), 3);
        assert_eq!(data.len(), 3);
        let data2 = data.split_off(1);
        assert_eq!(data2.len(), 2);
        assert_eq!(data.len(), 1);
        let chunk1 = Chunk {
            id: 0,
            message_id: 1234,
            last: false,
            data,
        };
        let chunk2 = Chunk {
            id: 1,
            message_id: 1234,
            last: false,
            data: data2,
        };
        let chunk3 = Chunk {
            id: 2,
            message_id: 1234,
            last: false,
            data: data3,
        };
        let chunk4 = Chunk {
            id: 3,
            message_id: 1234,
            last: false,
            data: data4,
        };
        let chunk5 = Chunk {
            id: 4,
            message_id: 1234,
            last: true,
            data: data5,
        };
        (chunk1, chunk2, chunk3, chunk4, chunk5)
    }

    // Adds the chunks in id order. The builder must not report completion
    // until the fifth (last-flagged) chunk arrives; it must then reassemble the
    // original `Sum` message.
    #[test]
    fn test_message_builder_in_order() {
        let mut msg = message_builder();
        let (data, sum) = sum();
        let (c1, c2, c3, c4, c5) = chunks(data);
        assert!(msg.data.is_empty());
        assert!(msg.last_chunk_id.is_none());
        msg.add_chunk(c1);
        assert_eq!(msg.data.len(), 1);
        assert!(msg.last_chunk_id.is_none());
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c2);
        assert_eq!(msg.data.len(), 2);
        assert!(msg.last_chunk_id.is_none());
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c3);
        assert_eq!(msg.data.len(), 3);
        assert!(msg.last_chunk_id.is_none());
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c4);
        assert_eq!(msg.data.len(), 4);
        assert!(msg.last_chunk_id.is_none());
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c5);
        assert_eq!(msg.data.len(), 5);
        assert_eq!(msg.last_chunk_id, Some(4));
        assert!(msg.has_all_chunks());
        let actual = msg.into_message().unwrap();
        let expected =
            Message::new_sum(PublicSigningKey::zeroed(), PublicEncryptKey::zeroed(), sum);
        assert_eq!(actual, expected);
    }

    // Adds the chunks out of order, with the last-flagged chunk arriving third.
    // Knowing the last id must not be enough: completion is only reported once
    // the remaining chunks (ids 1 and 3) have also arrived.
    #[test]
    fn test_message_builder_out_of_order() {
        let mut msg = message_builder();
        let (data, sum) = sum();
        let (c1, c2, c3, c4, c5) = chunks(data);
        assert!(msg.data.is_empty());
        assert!(msg.last_chunk_id.is_none());
        msg.add_chunk(c3);
        assert_eq!(msg.data.len(), 1);
        assert!(msg.last_chunk_id.is_none());
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c1);
        assert_eq!(msg.data.len(), 2);
        assert!(msg.last_chunk_id.is_none());
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c5);
        assert_eq!(msg.data.len(), 3);
        assert_eq!(msg.last_chunk_id, Some(4));
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c2);
        assert_eq!(msg.data.len(), 4);
        assert_eq!(msg.last_chunk_id, Some(4));
        assert!(!msg.has_all_chunks());
        msg.add_chunk(c4);
        assert_eq!(msg.data.len(), 5);
        assert_eq!(msg.last_chunk_id, Some(4));
        assert!(msg.has_all_chunks());
        let actual = msg.into_message().unwrap();
        let expected =
            Message::new_sum(PublicSigningKey::zeroed(), PublicEncryptKey::zeroed(), sum);
        assert_eq!(actual, expected);
    }

    // Drives the service with interleaved chunks from two participants that use
    // the same message id (1234). Checks that:
    // - their builders are tracked separately;
    // - every call returns `None` until a builder is complete;
    // - each completed builder is removed and yields the reassembled `Sum`
    //   message for the right participant.
    #[tokio::test]
    async fn message_handler() {
        let mut task = spawn_svc();
        assert_ready!(task.poll_ready()).unwrap();
        // NOTE(review): sized with `PublicSigningKey::LENGTH`; presumably fine because
        // both key types have the same length, but `PublicEncryptKey::LENGTH` would
        // state the intent more clearly.
        let coordinator_pk =
            PublicEncryptKey::from_slice(&[0x00; PublicSigningKey::LENGTH]).unwrap();
        let (data, sum) = sum();
        let (c1, c2, c3, c4, c5) = chunks(data.clone());
        let pk1 = PublicSigningKey::from_slice(&[0x11; PublicSigningKey::LENGTH]).unwrap();
        let message_id1 = MessageId {
            message_id: 1234,
            participant_pk: pk1,
        };
        let make_message1 =
            |chunk: &Chunk| Message::new_multipart(pk1, coordinator_pk, chunk.clone(), Tag::Sum);
        let pk2 = PublicSigningKey::from_slice(&[0x22; PublicSigningKey::LENGTH]).unwrap();
        let message_id2 = MessageId {
            message_id: 1234,
            participant_pk: pk2,
        };
        let make_message2 =
            |chunk: &Chunk| Message::new_multipart(pk2, coordinator_pk, chunk.clone(), Tag::Sum);
        // First chunk from each participant creates a separate builder.
        assert!(task.call(make_message1(&c3)).await.unwrap().is_none());
        assert_eq!(task.get_ref().message_builders.len(), 1);
        let builder = task.get_ref().message_builders.get(&message_id1).unwrap();
        assert_eq!(builder.data.len(), 1);
        assert!(task.call(make_message2(&c3)).await.unwrap().is_none());
        assert_eq!(task.get_ref().message_builders.len(), 2);
        let builder = task.get_ref().message_builders.get(&message_id2).unwrap();
        assert_eq!(builder.data.len(), 1);
        // Feed every remaining chunk except id 1 (`c2`); nothing completes yet.
        assert!(task.call(make_message1(&c5)).await.unwrap().is_none());
        assert!(task.call(make_message2(&c5)).await.unwrap().is_none());
        assert!(task.call(make_message1(&c1)).await.unwrap().is_none());
        assert!(task.call(make_message2(&c1)).await.unwrap().is_none());
        assert!(task.call(make_message1(&c4)).await.unwrap().is_none());
        assert!(task.call(make_message2(&c4)).await.unwrap().is_none());
        let builder = task.get_ref().message_builders.get(&message_id1).unwrap();
        assert_eq!(builder.data.len(), 4);
        let builder = task.get_ref().message_builders.get(&message_id2).unwrap();
        assert_eq!(builder.data.len(), 4);
        // The missing chunk completes both messages and removes their builders.
        let res1 = task.call(make_message1(&c2)).await.unwrap().unwrap();
        let res2 = task.call(make_message2(&c2)).await.unwrap().unwrap();
        assert!(task.get_ref().message_builders.get(&message_id1).is_none());
        assert!(task.get_ref().message_builders.get(&message_id2).is_none());
        assert_eq!(res1, Message::new_sum(pk1, coordinator_pk, sum.clone()));
        assert_eq!(res2, Message::new_sum(pk2, coordinator_pk, sum.clone()));
    }
}