use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use asynchronous_codec::FramedRead;
use cid::CidGeneric;
use fnv::FnvHashMap;
use futures_util::future::{BoxFuture, Fuse, FusedFuture, FutureExt};
use futures_util::stream::StreamExt;
use tracing::debug;
use crate::cid_prefix::CidPrefix;
use crate::message::Codec;
use crate::multihasher::{MultihasherError, MultihasherTable};
use crate::proto::message::mod_Message::BlockPresenceType;
use crate::proto::message::mod_Message::Wantlist;
use crate::proto::message::Message;
pub(crate) struct IncomingStream<const S: usize> {
multihasher: Arc<MultihasherTable<S>>,
stream: FramedRead<libp2p_swarm::Stream, Codec>,
processing: Fuse<BoxFuture<'static, Option<IncomingMessage<S>>>>,
}
#[derive(Debug, Default)]
pub(crate) struct ClientMessage<const S: usize> {
pub(crate) block_presences: FnvHashMap<CidGeneric<S>, BlockPresenceType>,
pub(crate) blocks: FnvHashMap<CidGeneric<S>, Vec<u8>>,
}
#[derive(Debug, Default)]
pub(crate) struct ServerMessage {
pub(crate) wantlist: Wantlist,
}
#[derive(Debug, Default)]
#[doc(hidden)]
pub struct IncomingMessage<const S: usize> {
pub(crate) client: Option<ClientMessage<S>>,
pub(crate) server: Option<ServerMessage>,
}
impl<const S: usize> IncomingStream<S> {
pub(crate) fn new(stream: libp2p_swarm::Stream, multihasher: Arc<MultihasherTable<S>>) -> Self {
IncomingStream {
multihasher,
stream: FramedRead::new(stream, Codec),
processing: Fuse::terminated(),
}
}
}
impl<const S: usize> fmt::Debug for IncomingStream<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("IncomingStream { .. }")
}
}
impl<const S: usize> futures_core::stream::Stream for IncomingStream<S> {
type Item = IncomingMessage<S>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<IncomingMessage<S>>> {
loop {
if !self.processing.is_terminated() {
match self.processing.poll_unpin(cx) {
Poll::Ready(Some(msg)) => {
if msg.client.is_some() || msg.server.is_some() {
return Poll::Ready(Some(msg));
}
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
}
}
let msg = match self.stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(msg))) => msg,
Poll::Ready(Some(Err(e))) => {
debug!("Message decoding failed: {e}");
return Poll::Ready(None);
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
};
self.processing = process_message(self.multihasher.clone(), msg)
.boxed()
.fuse();
}
}
}
async fn process_message<const S: usize>(
multihasher: Arc<MultihasherTable<S>>,
mut msg: Message,
) -> Option<IncomingMessage<S>> {
let mut incoming_msg = IncomingMessage::default();
for block_presence in msg.blockPresences {
let cid = match CidGeneric::try_from(&block_presence.cid[..]) {
Ok(cid) => cid,
Err(e) => {
debug!("Invalid CID bytes: {}: {:?}", e, block_presence.cid);
return None;
}
};
incoming_msg
.client
.get_or_insert_with(ClientMessage::default)
.block_presences
.insert(cid, block_presence.type_pb);
}
for payload in msg.payload {
let Some(cid_prefix) = CidPrefix::from_bytes(&payload.prefix) else {
debug!("block.prefix not decodable: {:?}", payload.prefix);
return None;
};
let cid = match cid_prefix.to_cid(&multihasher, &payload.data).await {
Ok(cid) => cid,
Err(MultihasherError::UnknownMultihashCode) => {
debug!(
"Multihasher non-fatal error: Unknown multihash code: {}",
cid_prefix.multihash_code()
);
continue;
}
Err(MultihasherError::Custom(e)) => {
debug!("Multihasher non-fatal error: {e}");
continue;
}
Err(e) => {
debug!("Multihasher fatal error: {e}");
return None;
}
};
incoming_msg
.client
.get_or_insert_with(ClientMessage::default)
.blocks
.insert(cid, payload.data);
}
if let Some(wantlist) = msg.wantlist.take() {
if wantlist.full || !wantlist.entries.is_empty() {
incoming_msg
.server
.get_or_insert_with(ServerMessage::default)
.wantlist = wantlist;
}
}
Some(incoming_msg)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::proto::message::mod_Message::mod_Wantlist::{Entry, WantType};
use crate::proto::message::mod_Message::{Block, BlockPresence};
use crate::test_utils::cid_of_data;
#[tokio::test]
async fn parse_client_message() {
let cid1 = cid_of_data(b"x1");
let cid2 = cid_of_data(b"x2");
let cid3 = cid_of_data(b"x3");
let cid4 = cid_of_data(b"x4");
let proto_msg = Message {
payload: vec![
Block {
prefix: CidPrefix::from_cid(&cid1).to_bytes(),
data: b"x1".to_vec(),
},
Block {
prefix: CidPrefix::from_cid(&cid2).to_bytes(),
data: b"x2".to_vec(),
},
],
blockPresences: vec![
BlockPresence {
cid: cid3.to_bytes(),
type_pb: BlockPresenceType::Have,
},
BlockPresence {
cid: cid4.to_bytes(),
type_pb: BlockPresenceType::DontHave,
},
],
..Default::default()
};
let multihasher = Arc::new(MultihasherTable::new());
let msg = process_message(multihasher.clone(), proto_msg)
.await
.unwrap();
let client_msg = msg.client.unwrap();
assert!(msg.server.is_none());
assert_eq!(client_msg.blocks.get(&cid1).unwrap(), b"x1");
assert_eq!(client_msg.blocks.get(&cid2).unwrap(), b"x2");
assert_eq!(
client_msg.block_presences.get(&cid3).unwrap(),
&BlockPresenceType::Have
);
assert_eq!(
client_msg.block_presences.get(&cid4).unwrap(),
&BlockPresenceType::DontHave
);
}
#[tokio::test]
async fn parse_server_message() {
let cid1 = cid_of_data(b"x1");
let cid2 = cid_of_data(b"x2");
let proto_msg = Message {
wantlist: Some(Wantlist {
entries: vec![
Entry {
block: cid1.to_bytes(),
priority: 1,
cancel: false,
wantType: WantType::Have,
sendDontHave: true,
},
Entry {
block: cid2.to_bytes(),
priority: 1,
cancel: false,
wantType: WantType::Have,
sendDontHave: true,
},
],
full: false,
}),
..Default::default()
};
let multihasher = Arc::new(MultihasherTable::<64>::new());
let msg = process_message(multihasher, proto_msg.clone())
.await
.unwrap();
let server_msg = msg.server.unwrap();
assert!(msg.client.is_none());
assert_eq!(proto_msg.wantlist.unwrap(), server_msg.wantlist);
}
#[tokio::test]
async fn parse_server_message_without_entries() {
let proto_msg = Message {
wantlist: Some(Wantlist {
entries: Vec::new(),
full: false,
}),
..Default::default()
};
let multihasher = Arc::new(MultihasherTable::<64>::new());
let msg = process_message(multihasher, proto_msg).await.unwrap();
assert!(msg.server.is_none());
assert!(msg.client.is_none());
}
#[tokio::test]
async fn parse_server_message_full_without_entries() {
let proto_msg = Message {
wantlist: Some(Wantlist {
entries: Vec::new(),
full: true,
}),
..Default::default()
};
let multihasher = Arc::new(MultihasherTable::<64>::new());
let msg = process_message(multihasher, proto_msg.clone())
.await
.unwrap();
let server_msg = msg.server.unwrap();
assert!(msg.client.is_none());
assert_eq!(proto_msg.wantlist.unwrap(), server_msg.wantlist);
}
}