mod common;
use std::{cmp, collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
use bytes::{Bytes, BytesMut};
use common::{noise, yamux};
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use libp2p::swarm::{keep_alive, NetworkBehaviour, Swarm, SwarmEvent};
use libp2p::{core::multiaddr::Protocol, identity, ping, PeerId, Transport};
use parking_lot::{Mutex, RwLock};
use pea2pea::{
protocols::{Disconnect, Handshake, Reading, Writing},
Connection, ConnectionSide, Node, Pea2Pea,
};
use prost::Message;
use tokio::{sync::oneshot, time::sleep};
use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
use tracing::*;
use tracing_subscriber::filter::LevelFilter;
use unsigned_varint::codec::UviBytes;
/// Opening payload of a yamux stream that negotiates the ping protocol via multistream-select:
/// the length-prefixed `/multistream/1.0.0\n` header (0x13 = 19 bytes) followed by the
/// length-prefixed `/ipfs/ping/1.0.0\n` protocol id (0x11 = 17 bytes). Streams whose SYN
/// payload equals this value are treated as ping streams by `process_message`.
const PROTOCOL_PING: &[u8] = b"\x13/multistream/1.0.0\n\x11/ipfs/ping/1.0.0\n";
/// A pea2pea node speaking enough of the libp2p stack (multistream-select, noise, yamux) to
/// answer pings coming from a libp2p peer.
#[derive(Clone)]
struct Libp2pNode {
/// The underlying pea2pea node.
node: Node,
/// The libp2p identity; its public key goes into the noise handshake payload and it signs
/// the local noise static key.
keypair: identity::Keypair,
// never read after construction (`new` logs its local copy), hence the `dead_code` allowance
#[allow(dead_code)]
peer_id: PeerId,
/// Noise states produced by the handshake, keyed by peer address, until the reading and
/// writing codecs pick them up.
noise_states: Arc<Mutex<HashMap<SocketAddr, noise::State>>>,
/// Per-connection libp2p state (peer id and open yamux streams), keyed by peer address.
peer_states: Arc<RwLock<HashMap<SocketAddr, PeerState>>>,
}
impl Pea2Pea for Libp2pNode {
    /// Gives pea2pea access to the wrapped [`Node`].
    fn node(&self) -> &Node {
        let Self { node, .. } = self;
        node
    }
}
impl Libp2pNode {
fn new() -> Self {
let keypair = identity::Keypair::generate_ed25519();
let peer_id = keypair.public().to_peer_id();
let node = Node::new(Default::default());
info!(parent: node.span(), "started a node with PeerId {}", peer_id);
Self {
node,
keypair,
peer_id,
noise_states: Default::default(),
peer_states: Default::default(),
}
}
async fn process_event(&self, event: Event, source: SocketAddr) -> io::Result<()> {
let reply = match event {
Event::NewStream(stream_id, protocol_info) => {
Some(yamux::Frame::data(
stream_id,
vec![yamux::Flag::Ack],
Some(protocol_info),
))
}
Event::StreamHalfClosed(stream_id) => {
Some(yamux::Frame::data(stream_id, vec![yamux::Flag::Fin], None))
}
Event::ReceivedPing(stream_id, payload) => {
Some(yamux::Frame::data(stream_id, vec![], Some(payload)))
}
_ => None,
};
if let Some(reply_msg) = reply {
info!(parent: self.node().span(), " sending a {:?}", &reply_msg);
let _ = self.unicast(source, reply_msg)?.await;
}
Ok(())
}
}
/// The open yamux streams of one connection, each mapped to the payload of the SYN frame that
/// opened it (compared against `PROTOCOL_PING` to recognize ping streams).
pub type Streams = HashMap<yamux::StreamId, Bytes>;
/// libp2p-level state kept for every connection that completed the handshake.
struct PeerState {
// recorded during the handshake but not read afterwards, hence the `dead_code` allowance
#[allow(dead_code)]
id: PeerId,
/// The connection's currently open yamux streams.
streams: Streams,
}
/// Events derived from incoming yamux frames by `process_message`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Event {
/// The peer opened a stream (SYN); carries the stream id and the opening payload.
NewStream(yamux::StreamId, Bytes),
/// The peer half-closed a known stream (FIN).
StreamHalfClosed(yamux::StreamId),
/// The peer reset a known stream (RST).
StreamTerminated(yamux::StreamId),
/// A flagless frame arrived on a ping stream; carries the ping payload.
ReceivedPing(yamux::StreamId, Bytes),
/// A flagless frame arrived on a stream whose protocol isn't recognized.
Unknown(yamux::Frame),
}
struct Codec {
noise: noise::Codec,
yamux: yamux::Codec,
}
impl Codec {
fn new(noise: noise::Codec, yamux: yamux::Codec) -> Self {
Self { noise, yamux }
}
}
impl Decoder for Codec {
type Item = yamux::Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut bytes = if let Some(bytes) = self.noise.decode(src)? {
bytes
} else {
return Ok(None);
};
self.yamux.decode(&mut bytes)
}
}
impl Encoder<yamux::Frame> for Codec {
type Error = io::Error;
fn encode(&mut self, msg: yamux::Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
self.yamux.encode(msg, dst)?;
let mut bytes = dst.split().freeze();
while !bytes.is_empty() {
let chunk = bytes.split_to(cmp::min(bytes.len(), noise::MAX_MESSAGE_LEN));
self.noise.encode(chunk, dst)?;
}
Ok(())
}
}
/// The protobuf-encoded payload exchanged during the libp2p noise handshake.
#[derive(Clone, PartialEq, Eq, ::prost::Message)]
pub struct NoiseHandshakePayload {
/// The protobuf encoding of the sender's libp2p identity public key.
#[prost(bytes = "vec", tag = "1")]
pub identity_key: ::prost::alloc::vec::Vec<u8>,
/// The identity key's signature over `"noise-libp2p-static-key:"` followed by the sender's
/// noise static public key.
#[prost(bytes = "vec", tag = "2")]
pub identity_sig: ::prost::alloc::vec::Vec<u8>,
/// Additional data; this node sends it empty.
#[prost(bytes = "vec", tag = "3")]
pub data: ::prost::alloc::vec::Vec<u8>,
}
#[async_trait::async_trait]
impl Handshake for Libp2pNode {
    const TIMEOUT_MS: u64 = 5_000;

    /// Upgrades a fresh connection the way a libp2p transport stack does:
    ///
    /// 1. multistream-select negotiation of `/noise` over the plaintext stream;
    /// 2. a `Noise_XX_25519_ChaChaPoly_SHA256` handshake whose payloads carry the libp2p
    ///    identity keys;
    /// 3. multistream-select negotiation of `/yamux/1.0.0` over the encrypted stream.
    ///
    /// On success the noise state is saved in `noise_states` for the reading and writing
    /// codecs, and an empty [`PeerState`] is registered for the connection.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if the peer closes the stream during a negotiation or sends an
    /// undecodable handshake payload or identity key. I/O and noise errors are propagated.
    async fn perform_handshake(&self, mut conn: Connection) -> io::Result<Connection> {
        // this node's own side is the opposite of `conn.side()`
        let node_conn_side = !conn.side();
        let addr = conn.addr();

        // phase 1: agree on `/noise` using varint length-prefixed multistream-select messages
        let mut negotiation_codec = Framed::new(self.borrow_stream(&mut conn), UviBytes::default());
        match node_conn_side {
            ConnectionSide::Initiator => {
                negotiation_codec
                    .send(Bytes::from("/multistream/1.0.0\n"))
                    .await?;
                debug!(parent: self.node().span(), "sent protocol params (1/2)");
                let protocol_info = negotiation_codec
                    .try_next()
                    .await?
                    .ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params (1/2): {:?}", protocol_info);
                negotiation_codec.send(Bytes::from("/noise\n")).await?;
                debug!(parent: self.node().span(), "sent protocol params (2/2)");
                let protocol_info = negotiation_codec
                    .try_next()
                    .await?
                    .ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params (2/2): {:?}", protocol_info);
            }
            ConnectionSide::Responder => {
                // the proposals aren't inspected, but a stream closed mid-negotiation is
                // rejected just like on the initiator side
                let _protocol_info = negotiation_codec
                    .try_next()
                    .await?
                    .ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params (1/2)");
                negotiation_codec
                    .send(Bytes::from("/multistream/1.0.0\n"))
                    .await?;
                debug!(parent: self.node().span(), "sent protocol params (1/2)");
                let _protocol_info = negotiation_codec
                    .try_next()
                    .await?
                    .ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params (2/2)");
                negotiation_codec.send(Bytes::from("/noise\n")).await?;
                debug!(parent: self.node().span(), "sent protocol params (2/2)");
            }
        };

        // phase 2: the noise handshake, with a fresh static key that is authenticated by
        // signing it with the libp2p identity key
        let noise_builder = snow::Builder::new("Noise_XX_25519_ChaChaPoly_SHA256".parse().unwrap());
        let noise_keypair = noise_builder.generate_keypair().unwrap();
        let noise_builder = noise_builder.local_private_key(&noise_keypair.private);
        let noise_payload = {
            let protobuf_payload = NoiseHandshakePayload {
                identity_key: self.keypair.public().to_protobuf_encoding(),
                identity_sig: self
                    .keypair
                    .sign(&[&b"noise-libp2p-static-key:"[..], &noise_keypair.public].concat())
                    .unwrap(),
                data: vec![],
            };
            let mut bytes = Vec::with_capacity(protobuf_payload.encoded_len());
            protobuf_payload.encode(&mut bytes).unwrap();
            bytes
        };
        let (noise_state, secure_payload) =
            noise::handshake_xx(self, &mut conn, noise_builder, noise_payload.into()).await?;
        let secure_payload = NoiseHandshakePayload::decode(&secure_payload[..])?;
        // NOTE(review): the peer's `identity_sig` is never checked against its noise static
        // key, so the identity below is taken on trust; the libp2p noise spec requires that
        // verification
        let peer_key = identity::PublicKey::from_protobuf_encoding(&secure_payload.identity_key)
            .map_err(|_| io::ErrorKind::InvalidData)?;
        let peer_id = PeerId::from(peer_key);
        info!(parent: self.node().span(), "the PeerId of {} is {}", addr, &peer_id);

        // phase 3: agree on `/yamux/1.0.0` over the encrypted channel; these messages go
        // straight through the noise codec, so their length prefixes are written by hand
        // (0x13 = 19 and '\r' = 13 bytes)
        let mut framed = Framed::new(
            self.borrow_stream(&mut conn),
            noise::Codec::new(
                2,
                u16::MAX as usize,
                noise_state,
                self.node().span().clone(),
            ),
        );
        match node_conn_side {
            ConnectionSide::Initiator => {
                framed
                    .send(Bytes::from(&b"\x13/multistream/1.0.0\n"[..]))
                    .await?;
                debug!(parent: self.node().span(), "sent protocol params (1/2)");
                let protocol_info = framed.try_next().await?.ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params (1/2): {:?}", protocol_info);
                framed.send(Bytes::from(&b"\r/yamux/1.0.0\n"[..])).await?;
                debug!(parent: self.node().span(), "sent protocol params (2/2)");
                let protocol_info = framed.try_next().await?.ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params (2/2): {:?}", protocol_info);
            }
            ConnectionSide::Responder => {
                let protocol_info = framed.try_next().await?.ok_or(io::ErrorKind::InvalidData)?;
                debug!(parent: self.node().span(), "received protocol params: {:?}", protocol_info);
                framed.send(protocol_info.freeze()).await?;
                debug!(parent: self.node().span(), "echoed the protocol params back to the sender");
            }
        }

        // hand any bytes already read past the negotiation to the noise state via
        // `save_buffer` (presumably so the reading codec can still consume them)
        let FramedParts {
            codec, read_buf, ..
        } = framed.into_parts();
        let noise::Codec { mut noise, .. } = codec;
        noise.save_buffer(read_buf);
        self.noise_states.lock().insert(addr, noise);
        self.peer_states.write().insert(
            addr,
            PeerState {
                id: peer_id,
                streams: Default::default(),
            },
        );

        Ok(conn)
    }
}
/// Evaluates to the (mutable) yamux streams of the peer at `$addr`.
///
/// The `peer_states` write guard is a temporary, so the lock is held only until the end of
/// the enclosing statement (or `if` condition). If the peer is unknown, the macro returns
/// early from the calling function with `BrokenPipe` via `?`, so the caller must return a
/// `Result` whose error converts from `io::ErrorKind` (e.g. `io::Result`).
macro_rules! get_streams_mut {
($self:expr, $addr:expr) => {
$self
.peer_states
.write()
.get_mut(&$addr)
.ok_or(io::ErrorKind::BrokenPipe)?
.streams
};
}
#[async_trait::async_trait]
impl Reading for Libp2pNode {
    type Message = yamux::Frame;
    type Codec = Codec;

    /// Builds the inbound codec from a copy of the noise state saved by the handshake.
    ///
    /// # Panics
    ///
    /// Panics if no noise state was saved for `addr`.
    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec {
        // cloned rather than removed, so `Writing::codec` (which removes the entry) can still
        // find it; assumes the reading codec is built first
        let noise_state = self
            .noise_states
            .lock()
            .get(&addr)
            .cloned()
            .expect("the noise state should have been saved by the handshake");
        Self::Codec::new(
            noise::Codec::new(
                2,
                u16::MAX as usize,
                noise_state,
                self.node().span().clone(),
            ),
            yamux::Codec::new(side, self.node().span().clone()),
        )
    }

    /// Turns an incoming yamux frame into [`Event`]s based on its flags, then handles them.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` for a SYN on an already-registered stream or a RST/FIN/data
    /// frame on an unknown stream, and `BrokenPipe` if the peer has no registered state.
    /// Errors from sending replies are propagated.
    async fn process_message(&self, source: SocketAddr, message: Self::Message) -> io::Result<()> {
        info!(parent: self.node().span(), "received a {:?}", message);

        let yamux::Header {
            stream_id, flags, ..
        } = &message.header;
        let payload = message.payload;
        let mut events = vec![];

        match &flags[..] {
            &[yamux::Flag::Syn] => {
                if get_streams_mut!(self, source)
                    .insert(*stream_id, payload.clone())
                    .is_none()
                {
                    events.push(Event::NewStream(*stream_id, payload.clone()));
                } else {
                    error!(parent: self.node().span(), "yamux stream {} had already been registered", stream_id);
                    return Err(io::ErrorKind::InvalidData.into());
                }
            }
            &[yamux::Flag::Rst] => {
                if get_streams_mut!(self, source).remove(stream_id).is_some() {
                    events.push(Event::StreamTerminated(*stream_id));
                } else {
                    error!(parent: self.node().span(), "yamux stream {} is unknown", stream_id);
                    return Err(io::ErrorKind::InvalidData.into());
                }
            }
            &[yamux::Flag::Ack] => {
                // this node never opens streams of its own, so there is nothing to acknowledge;
                // log it instead of panicking the reading task
                warn!(parent: self.node().span(), "received an unexpected ACK for yamux stream {}", stream_id);
            }
            &[yamux::Flag::Fin] => {
                if get_streams_mut!(self, source).remove(stream_id).is_some() {
                    events.push(Event::StreamHalfClosed(*stream_id));
                } else {
                    error!(parent: self.node().span(), "yamux stream {} is unknown", stream_id);
                    return Err(io::ErrorKind::InvalidData.into());
                }
            }
            &[] => {
                let protocol = if let Some(p) = self
                    .peer_states
                    .read()
                    .get(&source)
                    .ok_or(io::ErrorKind::BrokenPipe)?
                    .streams
                    .get(stream_id)
                {
                    p.clone()
                } else {
                    error!(parent: self.node().span(), "yamux stream {} is unknown", stream_id);
                    return Err(io::ErrorKind::InvalidData.into());
                };
                if protocol == PROTOCOL_PING {
                    events.push(Event::ReceivedPing(*stream_id, payload));
                } else {
                    events.push(Event::Unknown(yamux::Frame {
                        header: message.header,
                        payload,
                    }));
                }
            }
            flags => {
                warn!(parent: self.node().span(), "unexpected combination of yamux flags: {:?}", flags);
            }
        }

        for event in events {
            self.process_event(event, source).await?;
        }

        Ok(())
    }
}
impl Writing for Libp2pNode {
type Message = yamux::Frame;
type Codec = Codec;
fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec {
let noise_state = self.noise_states.lock().remove(&addr).unwrap();
Self::Codec::new(
noise::Codec::new(
2,
u16::MAX as usize,
noise_state,
self.node().span().clone(),
),
yamux::Codec::new(side, self.node().span().clone()),
)
}
}
#[async_trait::async_trait]
impl Disconnect for Libp2pNode {
    /// Drops all per-connection state kept for `addr`.
    async fn handle_disconnect(&self, addr: SocketAddr) {
        self.peer_states.write().remove(&addr);
        // normally taken by `Writing::codec`, but `Reading::codec` only clones it, so remove
        // it here too to make sure it never outlives the connection
        self.noise_states.lock().remove(&addr);
    }
}
/// The libp2p swarm's behaviour: libp2p's `keep_alive` and `ping` behaviours combined via the
/// `NetworkBehaviour` derive.
#[derive(NetworkBehaviour, Default)]
struct Behaviour {
// presumably keeps idle connections open so the ping exchange isn't cut short — confirm
keep_alive: keep_alive::Behaviour,
ping: ping::Behaviour,
}
/// Starts a libp2p swarm (TCP + noise + yamux, with ping and keep-alive) on localhost,
/// connects the pea2pea node to it, and keeps the process alive for 60 seconds so the two
/// can exchange pings.
#[tokio::main(flavor = "multi_thread")]
async fn main() {
    common::start_logger(LevelFilter::DEBUG);

    // the pea2pea side
    let pea2pea_node = Libp2pNode::new();
    pea2pea_node.enable_handshake().await;
    pea2pea_node.enable_reading().await;
    pea2pea_node.enable_writing().await;

    // the libp2p side
    let swarm_keypair = identity::Keypair::generate_ed25519();
    let swarm_peer_id = PeerId::from(swarm_keypair.public());
    let transport = libp2p::tcp::tokio::Transport::new(libp2p::tcp::Config::new().nodelay(true));
    let noise_keys = libp2p::noise::Keypair::<libp2p::noise::X25519Spec>::new()
        .into_authentic(&swarm_keypair)
        .unwrap();
    let transport = transport
        .upgrade(libp2p::core::upgrade::Version::V1)
        .authenticate(libp2p::noise::NoiseConfig::xx(noise_keys).into_authenticated())
        .multiplex(libp2p::yamux::YamuxConfig::default())
        .boxed();
    let mut swarm = Swarm::with_tokio_executor(transport, Behaviour::default(), swarm_peer_id);
    swarm
        .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())
        .unwrap();

    // drive the swarm in the background, reporting its first listening address
    let (tx, rx) = oneshot::channel();
    tokio::spawn(async move {
        let mut tx = Some(tx);
        loop {
            let event = swarm.select_next_some().await;
            debug!(" libp2p node: {:?}", event);
            if let SwarmEvent::NewListenAddr { address, .. } = event {
                // only the first address is reported; a later one must not panic this task,
                // which would stop the swarm
                if let Some(tx) = tx.take() {
                    // the receiver is awaited by `main` right below, so ignoring a failed
                    // send is harmless
                    let _ = tx.send(address);
                }
            }
        }
    });

    // turn the swarm's multiaddr into a socket address pea2pea can connect to
    let mut swarm_addr = rx.await.unwrap();
    let swarm_port = if let Some(Protocol::Tcp(port)) = swarm_addr.pop() {
        port
    } else {
        panic!("the libp2p swarm did not return a listening TCP port");
    };
    let swarm_addr = format!("127.0.0.1:{}", swarm_port).parse().unwrap();
    pea2pea_node.node().connect(swarm_addr).await.unwrap();

    sleep(Duration::from_secs(60)).await;
}