use futures::AsyncWriteExt;
use prost::Message;
use std::{error::Error, fmt, io};
use crate::{rpc_proto, Topic, FLOOD_SUB_ID};
use async_trait::async_trait;
use futures::{channel::mpsc, SinkExt};
use libp2prs_core::upgrade::UpgradeInfo;
use libp2prs_core::{PeerId, ProtocolId, ReadEx};
use libp2prs_swarm::protocol_handler::Notifiee;
use libp2prs_swarm::{
connection::Connection,
protocol_handler::{IProtocolHandler, ProtocolHandler},
substream::Substream,
};
/// Connection lifecycle notification that [`Handler`] sends on its `peer_tx` channel.
pub(crate) enum PeerEvent {
    /// Sent from `Notifiee::connected` with the remote peer of the new connection.
    NewPeer(PeerId),
    /// Sent from `Notifiee::disconnected` with the remote peer of the closed connection.
    DeadPeer(PeerId),
}
/// Floodsub protocol handler.
///
/// Decodes RPCs read from inbound floodsub substreams and forwards them, together
/// with connection open/close notifications, over unbounded channels to the
/// floodsub message-processing loop. `Clone` is required by `box_clone`.
#[derive(Clone)]
pub struct Handler {
    // Sender for every RPC decoded from an inbound substream.
    incoming_tx: mpsc::UnboundedSender<RPC>,
    // Sender for `PeerEvent`s emitted when connections are opened or closed.
    peer_tx: mpsc::UnboundedSender<PeerEvent>,
}
impl Handler {
    /// Creates a handler.
    ///
    /// Decoded RPCs go to `incoming_tx`; connection lifecycle events go to `peer_tx`.
    pub(crate) fn new(incoming_tx: mpsc::UnboundedSender<RPC>, peer_tx: mpsc::UnboundedSender<PeerEvent>) -> Self {
        Self { peer_tx, incoming_tx }
    }
}
impl UpgradeInfo for Handler {
    type Info = ProtocolId;

    /// Advertises exactly one protocol: `FLOOD_SUB_ID`.
    fn protocol_info(&self) -> Vec<Self::Info> {
        let floodsub: Self::Info = FLOOD_SUB_ID.into();
        vec![floodsub]
    }
}
impl Notifiee for Handler {
    /// Publishes `PeerEvent::NewPeer` for the remote side of a new connection.
    ///
    /// A failed send is intentionally ignored.
    fn connected(&mut self, conn: &mut Connection) {
        let _ = self.peer_tx.unbounded_send(PeerEvent::NewPeer(conn.remote_peer()));
    }

    /// Publishes `PeerEvent::DeadPeer` for the remote side of a closed connection.
    ///
    /// A failed send is intentionally ignored.
    fn disconnected(&mut self, conn: &mut Connection) {
        let _ = self.peer_tx.unbounded_send(PeerEvent::DeadPeer(conn.remote_peer()));
    }
}
#[async_trait]
impl ProtocolHandler for Handler {
async fn handle(&mut self, mut stream: Substream, _info: <Self as UpgradeInfo>::Info) -> Result<(), Box<dyn Error>> {
log::trace!("Handle stream from {}", stream.remote_peer());
loop {
let packet = match stream.read_one(2048).await {
Ok(p) => p,
Err(e) => {
if e.kind() == io::ErrorKind::UnexpectedEof {
stream.close().await?;
}
return Err(Box::new(e));
}
};
let rpc = rpc_proto::Rpc::decode(&packet[..])?;
log::trace!("recv rpc msg: {:?}", rpc);
let mut messages = Vec::with_capacity(rpc.publish.len());
for publish in rpc.publish.into_iter() {
messages.push(FloodsubMessage {
source: PeerId::from_bytes(&publish.from.unwrap_or_default()).map_err(|_| FloodsubDecodeError::InvalidPeerId)?,
data: publish.data.unwrap_or_default(),
sequence_number: publish.seqno.unwrap_or_default(),
topics: publish.topic_ids.into_iter().map(Topic::new).collect(),
});
}
let rpc = RPC {
rpc: FloodsubRpc {
messages,
subscriptions: rpc
.subscriptions
.into_iter()
.map(|sub| FloodsubSubscription {
action: if Some(true) == sub.subscribe {
FloodsubSubscriptionAction::Subscribe
} else {
FloodsubSubscriptionAction::Unsubscribe
},
topic: Topic::new(sub.topic_id.unwrap_or_default()),
})
.collect(),
},
from: stream.remote_peer(),
};
self.incoming_tx.send(rpc).await.map_err(|_| FloodsubDecodeError::ProtocolExit)?;
}
}
fn box_clone(&self) -> IProtocolHandler {
Box::new(self.clone())
}
}
/// Errors that can occur while receiving and forwarding floodsub RPCs.
#[derive(Debug)]
pub enum FloodsubDecodeError {
    /// An I/O error occurred while reading from the socket.
    ReadError(io::Error),
    /// The received bytes could not be decoded as a protobuf RPC.
    ProtobufError(prost::DecodeError),
    /// A published message's `from` field could not be decoded as a `PeerId`.
    InvalidPeerId,
    /// A decoded RPC could not be sent on to the message-processing main loop.
    ProtocolExit,
}
impl From<prost::DecodeError> for FloodsubDecodeError {
    /// Wraps a protobuf decoding failure as [`FloodsubDecodeError::ProtobufError`].
    fn from(source: prost::DecodeError) -> Self {
        Self::ProtobufError(source)
    }
}
impl fmt::Display for FloodsubDecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
FloodsubDecodeError::ReadError(ref err) => write!(f, "Error while reading from socket: {}", err),
FloodsubDecodeError::ProtobufError(ref err) => write!(f, "Error while decoding protobuf: {}", err),
FloodsubDecodeError::InvalidPeerId => write!(f, "Error while decoding PeerId from message"),
FloodsubDecodeError::ProtocolExit => write!(f, "Error while send message to message process mainloop"),
}
}
}
impl Error for FloodsubDecodeError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
FloodsubDecodeError::ReadError(ref err) => Some(err),
FloodsubDecodeError::ProtobufError(ref err) => Some(err),
FloodsubDecodeError::InvalidPeerId => None,
FloodsubDecodeError::ProtocolExit => None,
}
}
}
/// A decoded floodsub RPC paired with the peer it was received from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RPC {
    /// The decoded RPC payload.
    pub rpc: FloodsubRpc,
    /// Remote peer of the substream the RPC was read from.
    pub from: PeerId,
}
/// Floodsub RPC contents: published messages and subscription changes.
///
/// Converted to the protobuf wire form by `FloodsubRpc::into_bytes`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FloodsubRpc {
    /// Messages being published.
    pub messages: Vec<FloodsubMessage>,
    /// Topic subscribe/unsubscribe requests.
    pub subscriptions: Vec<FloodsubSubscription>,
}
impl FloodsubRpc {
pub fn into_bytes(self) -> Vec<u8> {
let rpc = rpc_proto::Rpc {
publish: self
.messages
.into_iter()
.map(|msg| rpc_proto::Message {
from: Some(msg.source.to_bytes()),
data: Some(msg.data),
seqno: Some(msg.sequence_number),
topic_ids: msg.topics.into_iter().map(|topic| topic.into()).collect(),
})
.collect(),
subscriptions: self
.subscriptions
.into_iter()
.map(|topic| rpc_proto::rpc::SubOpts {
subscribe: Some(topic.action == FloodsubSubscriptionAction::Subscribe),
topic_id: Some(topic.topic.into()),
})
.collect(),
};
let mut buf = Vec::with_capacity(rpc.encoded_len());
rpc.encode(&mut buf).expect("Vec<u8> provides capacity as needed");
buf
}
}
/// A single published floodsub message.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FloodsubMessage {
    /// Originating peer; carried on the wire in the protobuf `from` field.
    pub source: PeerId,
    /// Message payload bytes.
    pub data: Vec<u8>,
    /// Raw sequence-number bytes; carried on the wire in the protobuf `seqno` field.
    pub sequence_number: Vec<u8>,
    /// Topics the message is published to.
    pub topics: Vec<Topic>,
}
/// A request to subscribe to or unsubscribe from a topic.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FloodsubSubscription {
    /// Whether this is a subscription or an unsubscription.
    pub action: FloodsubSubscriptionAction,
    /// The topic concerned.
    pub topic: Topic,
}
/// Direction of a [`FloodsubSubscription`].
///
/// On the wire this is the protobuf `subscribe` flag. When decoding, only
/// `Some(true)` maps to `Subscribe`; anything else maps to `Unsubscribe`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FloodsubSubscriptionAction {
    /// Subscribe to the topic.
    Subscribe,
    /// Unsubscribe from the topic.
    Unsubscribe,
}