use std::{io, time::Duration};
use bytes::BytesMut;
use futures::{SinkExt, StreamExt};
use log::info;
use prost::{DecodeError, Message};
use tokio::{
io::{AsyncRead, AsyncWrite},
time,
};
use tracing::{Instrument, Level, debug, error, span, warn};
use crate::{framing::CanonicalFraming, message::MessageExt, proto, protocol::rpc::error::HandshakeRejectReason};
/// Log target attached to the handshake log messages emitted by this module.
const LOG_TARGET: &str = "comms::rpc::handshake";
/// RPC protocol versions this node supports.
///
/// The client advertises all of them. The server accepts the first entry of this slice that
/// the client also lists, so the slice order is the server's order of preference.
pub(super) const SUPPORTED_RPC_VERSIONS: &[u32] = &[0_u32];
#[derive(Debug, thiserror::Error)]
pub enum RpcHandshakeError {
#[error("Failed to decode message: {0}")]
DecodeError(#[from] DecodeError),
#[error("IO Error: {0}")]
Io(#[from] io::Error),
#[error("The client does not support any RPC protocol version supported by this node")]
ClientNoSupportedVersion,
#[error("Remote peer unexpectedly closed the RPC connection")]
ServerClosedRequest,
#[error("RPC handshake timed out")]
TimedOut,
#[error("RPC handshake was explicitly rejected: {0}")]
Rejected(#[from] HandshakeRejectReason),
#[error("The client connection is closed")]
ClientClosed,
}
/// Drives the RPC session handshake over a borrowed [`CanonicalFraming`] transport.
///
/// Both ends of a connection use this type. The server calls
/// [`Handshake::perform_server_handshake`] and the client calls
/// [`Handshake::perform_client_handshake`]. An optional timeout bounds how long either side
/// waits to receive the peer's next frame. Sends and flushes are not covered by it.
pub struct Handshake<'a, T> {
    /// Framed transport the handshake messages are written to and read from.
    framed: &'a mut CanonicalFraming<T>,
    /// Maximum time to wait for the next incoming frame; `None` waits indefinitely.
    timeout: Option<Duration>,
}

impl<'a, T> Handshake<'a, T>
where T: AsyncRead + AsyncWrite + Unpin
{
    /// Creates a handshake over `framed` with no receive timeout.
    pub fn new(framed: &'a mut CanonicalFraming<T>) -> Self {
        Self { timeout: None, framed }
    }

    /// Returns this handshake with a receive timeout of `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Server side of the handshake.
    ///
    /// Waits for the client's `RpcSession` message. It then picks the first entry of
    /// `SUPPORTED_RPC_VERSIONS` that the client also lists.
    ///
    /// - On a match, it sends an `RpcSessionReply` carrying the accepted version and returns
    ///   that version.
    /// - Otherwise it rejects the session with `HandshakeRejectReason::UnsupportedVersion`,
    ///   which also closes the transport, and returns `ClientNoSupportedVersion`.
    ///
    /// # Errors
    /// - `TimedOut` if the configured timeout elapses before a frame arrives.
    /// - `ClientClosed` if the stream ends before a frame arrives.
    /// - `Io` if reading the frame, sending the accept reply, or rejecting fails.
    /// - `DecodeError` if the received frame is not a valid `RpcSession`.
    /// - `ClientNoSupportedVersion` if there is no common version.
    pub async fn perform_server_handshake(&mut self) -> Result<u32, RpcHandshakeError> {
        let frame = match self.recv_next_frame().await {
            Err(_) => {
                info!(target: LOG_TARGET, "Error during handshake, timed out");
                return Err(RpcHandshakeError::TimedOut);
            },
            Ok(None) => {
                info!(target: LOG_TARGET, "Error during handshake, client closed connection");
                return Err(RpcHandshakeError::ClientClosed);
            },
            Ok(Some(Err(err))) => {
                info!(target: LOG_TARGET, "Error during handshake: {err}");
                return Err(err.into());
            },
            Ok(Some(Ok(frame))) => frame,
        };

        let session = proto::rpc::RpcSession::decode(&mut frame.freeze())?;
        // Iterate our own list so that our ordering decides which common version wins.
        let common_version = SUPPORTED_RPC_VERSIONS
            .iter()
            .copied()
            .find(|candidate| session.supported_versions.contains(candidate));

        match common_version {
            None => {
                let span = span!(Level::INFO, "rpc::server::handshake::send_rejection");
                self.reject_with_reason(HandshakeRejectReason::UnsupportedVersion)
                    .instrument(span)
                    .await?;
                Err(RpcHandshakeError::ClientNoSupportedVersion)
            },
            Some(accepted) => {
                debug!(target: LOG_TARGET, "Local server accepted version: {}", accepted);
                let reply = proto::rpc::RpcSessionReply {
                    session_result: Some(proto::rpc::rpc_session_reply::SessionResult::AcceptedVersion(accepted)),
                    ..Default::default()
                };
                let span = span!(Level::INFO, "rpc::server::handshake::send_accept_version_reply");
                self.framed
                    .send(reply.to_encoded_bytes().into())
                    .instrument(span)
                    .await?;
                Ok(accepted)
            },
        }
    }

    /// Rejects the session by sending an `RpcSessionReply` carrying `reject_reason`, then
    /// closes the framed transport.
    ///
    /// A warning is logged before anything is sent. If the send fails, the close is not
    /// attempted.
    ///
    /// # Errors
    /// `Io` if sending the reply or closing the transport fails.
    pub async fn reject_with_reason(&mut self, reject_reason: HandshakeRejectReason) -> Result<(), RpcHandshakeError> {
        warn!(target: LOG_TARGET, "Rejecting handshake because {}", reject_reason);
        let rejection = proto::rpc::RpcSessionReply {
            reject_reason: reject_reason.as_i32(),
            session_result: Some(proto::rpc::rpc_session_reply::SessionResult::Rejected(true)),
        };
        let encoded = rejection.to_encoded_bytes();
        self.framed.send(encoded.into()).await?;
        self.framed.close().await?;
        Ok(())
    }

    /// Client side of the handshake.
    ///
    /// Sends an `RpcSession` listing every entry of `SUPPORTED_RPC_VERSIONS`, flushes the
    /// transport, then waits for the server's `RpcSessionReply`. Succeeds when the reply's
    /// `result()` is `Ok`.
    ///
    /// # Errors
    /// - `Io` if flushing or reading the reply fails. A failed send is only logged; see the
    ///   comment in the body.
    /// - `DecodeError` if the reply is not a valid `RpcSessionReply`.
    /// - Whatever error `RpcSessionReply::result` yields, converted via `?`.
    /// - `ServerClosedRequest` if the stream ends before a reply arrives.
    /// - `TimedOut` if the configured timeout elapses before a reply arrives.
    pub async fn perform_client_handshake(&mut self) -> Result<(), RpcHandshakeError> {
        let session = proto::rpc::RpcSession {
            supported_versions: SUPPORTED_RPC_VERSIONS.to_vec(),
        };
        let encoded = session.to_encoded_bytes();
        debug!(target: LOG_TARGET, "Sending client handshake ({} bytes)", encoded.len());

        // A send failure is deliberately logged rather than returned. Execution continues to
        // the flush and the receive below, so a reply the server has already written can
        // still be read. NOTE(review): presumably this covers a server that rejects and closes
        // the stream before our request is written — confirm.
        let send_result = self.framed.send(encoded.into()).await;
        if let Err(err) = send_result {
            warn!(
                target: LOG_TARGET,
                "IO error when sending new session handshake to peer: {}", err
            );
        }
        self.framed.flush().await?;

        let frame = match self.recv_next_frame().await {
            Err(_) => {
                error!(target: LOG_TARGET, "Error during handshake, timed out");
                return Err(RpcHandshakeError::TimedOut);
            },
            Ok(None) => {
                warn!(target: LOG_TARGET, "Error during handshake, server closed connection");
                return Err(RpcHandshakeError::ServerClosedRequest);
            },
            Ok(Some(Err(err))) => {
                error!(target: LOG_TARGET, "Error during handshake: {}", err);
                return Err(err.into());
            },
            Ok(Some(Ok(frame))) => frame,
        };

        let reply = proto::rpc::RpcSessionReply::decode(&mut frame.freeze())?;
        let version = reply.result()?;
        debug!(target: LOG_TARGET, "Remote server accepted version {}", version);
        Ok(())
    }

    /// Reads the next frame from the transport, bounded by the configured timeout if one is set.
    ///
    /// Return values:
    /// - The outer `Err` means the timeout elapsed.
    /// - `Ok(None)` means the stream has ended.
    /// - `Ok(Some(Err(_)))` is a read error from the transport.
    async fn recv_next_frame(&mut self) -> Result<Option<Result<BytesMut, io::Error>>, time::error::Elapsed> {
        let next_frame = self.framed.next();
        if let Some(limit) = self.timeout {
            time::timeout(limit, next_frame).await
        } else {
            Ok(next_frame.await)
        }
    }
}