use super::{
super::{super::read_write::ReadWrite, noise, yamux},
Config, Event, SubstreamId, SubstreamIdInner,
substream::{self, RespondInRequestError},
};
use alloc::{boxed::Box, string::String, vec::Vec};
use core::{
fmt,
num::NonZero,
ops::{Add, Index, IndexMut, Sub},
time::Duration,
};
use rand_chacha::rand_core::{RngCore as _, SeedableRng as _};
pub use substream::InboundTy;
/// State machine of a connection made of a single stream of data, on top of which a Noise
/// cipher and a Yamux multiplexer are layered.
///
/// `TNow` is the type used to represent instants in time. `TSubUd` is an opaque user data
/// attached to substreams and handed back through [`Event`]s and the `Index`/`IndexMut`
/// implementations.
///
/// Values of this type are built with [`ConnectionPrototype::into_connection`].
pub struct SingleStream<TNow, TSubUd> {
/// Noise cipher state. `read_write` passes the socket data through it before handing the
/// decrypted view to Yamux.
encryption: noise::Noise,
/// All the remaining state of the connection, stored on the heap.
// NOTE(review): presumably boxed so that moving `SingleStream` by value (as `read_write` does)
// stays cheap — confirm.
inner: Box<Inner<TNow, TSubUd>>,
}
/// Heap-allocated part of the state of a [`SingleStream`].
struct Inner<TNow, TSubUd> {
/// Yamux multiplexer. The user data of each Yamux substream is the state machine of that
/// substream plus the optional API user data. The outer `Option` is temporarily `None`
/// while `read_write` has extracted the state machine in order to process it, and stays
/// `None` if the substream was reset at that point.
yamux: yamux::Yamux<TNow, Option<(substream::Substream<TNow>, Option<TSubUd>)>>,
/// Substream, opened by `ConnectionPrototype::into_connection`, on which outgoing pings are
/// queued. `read_write` checks with `has_substream` whether it still exists before using it.
outgoing_pings: yamux::SubstreamId,
/// Instant at or after which `read_write` starts the next outgoing ping.
next_ping: TNow,
/// Generator used to fill the 32 bytes payload of every outgoing ping.
ping_payload_randomness: rand_chacha::ChaCha20Rng,
/// An incoming substream is rejected when `yamux.num_inbound()` is greater than or equal to
/// this value.
max_inbound_substreams: usize,
/// Value passed to `Substream::ingoing` when accepting an incoming substream. Can be changed
/// with `set_max_protocol_name_len`.
max_protocol_name_len: usize,
/// Added to the current time in order to compute `next_ping` every time a ping is started.
ping_interval: Duration,
/// Timeout passed to `queue_ping` for every outgoing ping.
ping_timeout: Duration,
}
impl<TNow, TSubUd> SingleStream<TNow, TSubUd>
where
TNow: Clone + Add<Duration, Output = TNow> + Sub<TNow, Output = Duration> + Ord,
{
/// Advances the state of the connection using the socket data and current time found in
/// `read_write`.
///
/// Takes ownership of the state machine and gives it back on success, together with at most
/// one [`Event`]. A single call performs at most one step, and callers are expected to call
/// this function again whenever `read_write` asks for a wake-up.
///
/// # Errors
///
/// Returns [`Error::Noise`] if the Noise layer fails and [`Error::Yamux`] if the Yamux layer
/// fails. The state machine is not returned in that case.
pub fn read_write(
mut self,
read_write: &mut ReadWrite<TNow>,
) -> Result<(SingleStream<TNow, TSubUd>, Option<Event<TSubUd>>), Error> {
// Start a new outgoing ping if the scheduled instant has been reached.
if read_write.now >= self.inner.next_ping {
// The next ping is scheduled relative to the current time, not to the previous deadline.
self.inner.next_ping = read_write.now.clone() + self.inner.ping_interval;
// The ping substream might no longer exist, in which case the ping is immediately
// reported as failed.
if self.inner.yamux.has_substream(self.inner.outgoing_pings) {
let mut payload = [0u8; 32];
self.inner.ping_payload_randomness.fill_bytes(&mut payload);
self.inner.yamux[self.inner.outgoing_pings]
.as_mut()
.unwrap()
.0
.queue_ping(&payload, read_write.now.clone(), self.inner.ping_timeout);
// Tell Yamux that this substream has something to write.
self.inner
.yamux
.mark_substream_write_ready(self.inner.outgoing_pings);
} else {
return Ok((self, Some(Event::PingOutFailed)));
}
}
// Make sure that this function is called again when the next ping must be started.
read_write.wake_up_after(&self.inner.next_ping);
// Close our writing side once a GoAway has been both sent and received and no substream
// other than the outgoing ping substream (if it still exists) remains.
if (self.inner.yamux.len()
== if self.inner.yamux.has_substream(self.inner.outgoing_pings) {
1
} else {
0
})
&& self.inner.yamux.goaway_sent()
&& self.inner.yamux.received_goaway().is_some()
{
read_write.close_write();
}
// Pass the socket data through the Noise cipher. `decrypted_read_write` borrows
// `read_write`, hence the explicit `drop`s below before `self` is returned or `read_write`
// is used again.
let mut decrypted_read_write = self
.encryption
.read_write(read_write)
.map_err(Error::Noise)?;
// Pass the decrypted data through the Yamux state machine. Every outcome gives back
// ownership of the Yamux state machine, which must be stored again in `self.inner.yamux`.
let yamux_rw_outcome = self
.inner
.yamux
.read_write(&mut decrypted_read_write)
.map_err(Error::Yamux)?;
match yamux_rw_outcome {
yamux::ReadWriteOutcome::Idle { yamux } => {
// Nothing happened; nothing more to do.
self.inner.yamux = yamux;
drop(decrypted_read_write);
return Ok((self, None));
}
yamux::ReadWriteOutcome::IncomingSubstream { mut yamux } => {
debug_assert!(!yamux.goaway_queued_or_sent());
// The remote wants to open a substream. Accept it unless the limit to the number of
// inbound substreams has been reached.
if yamux.num_inbound() >= self.inner.max_inbound_substreams {
yamux
.reject_pending_substream()
.unwrap_or_else(|_| panic!());
} else {
// The API user data is `None` until `accept_inbound` is called.
yamux
.accept_pending_substream(Some((
substream::Substream::ingoing(self.inner.max_protocol_name_len),
None,
)))
.unwrap_or_else(|_| panic!());
}
self.inner.yamux = yamux;
drop(decrypted_read_write);
return Ok((self, None));
}
yamux::ReadWriteOutcome::ProcessSubstream {
mut substream_read_write,
} => {
// Temporarily take the state machine and user data out of the Yamux substream.
let (state_machine, mut substream_user_data) =
substream_read_write.user_data_mut().take().unwrap();
let (state_machine_update, event) =
state_machine.read_write(substream_read_write.read_write());
// Convert the event, if any, before `substream_user_data` is potentially moved below.
let event_to_yield = event.map(|ev| {
Self::pass_through_substream_event(
substream_read_write.substream_id(),
&mut substream_user_data,
ev,
)
});
match state_machine_update {
Some(s) => {
// The substream is still alive: put its state back.
*substream_read_write.user_data_mut() = Some((s, substream_user_data));
self.inner.yamux = substream_read_write.finish();
}
None => {
// The state machine has finished: reset the substream. Its Yamux user data
// is left as `None`.
self.inner.yamux = substream_read_write.reset();
}
}
if let Some(event_to_yield) = event_to_yield {
drop(decrypted_read_write);
return Ok((self, Some(event_to_yield)));
}
// Without an event, fall through to the clean-up of dead substreams below.
}
yamux::ReadWriteOutcome::StreamReset { yamux, .. } => {
self.inner.yamux = yamux;
// Ask to be called again as soon as possible, then fall through to the clean-up of
// dead substreams below.
decrypted_read_write.wake_up_asap();
}
yamux::ReadWriteOutcome::GoAway { yamux, .. } => {
self.inner.yamux = yamux;
drop(decrypted_read_write);
return Ok((self, Some(Event::NewOutboundSubstreamsForbidden)));
}
yamux::ReadWriteOutcome::PingResponse { .. } => {
// NOTE(review): presumably unreachable because no Yamux-level ping is ever sent by
// this code (pings go through a dedicated substream instead) — confirm.
unreachable!()
}
}
drop(decrypted_read_write);
// Remove the substreams that Yamux reports as dead. The list is collected first so that
// `self.inner.yamux` can be modified while iterating.
let dead_substream_ids = self
.inner
.yamux
.dead_substreams()
.map(|(id, death_ty, _)| (id, death_ty))
.collect::<Vec<_>>();
for (dead_substream_id, death_ty) in dead_substream_ids {
match death_ty {
yamux::DeadSubstreamTy::Reset => {
// The user data is `None` if the substream was reset by the `ProcessSubstream`
// branch above, and `Some` if the state machine is still present and must be
// notified of the reset.
if let Some((state_machine, mut user_data)) =
self.inner.yamux.remove_dead_substream(dead_substream_id)
{
// Resetting the state machine might generate an event, which is returned
// immediately. Any remaining dead substream is handled by a later call.
if let Some(event) = state_machine.reset() {
return Ok((
self,
Some(Self::pass_through_substream_event(
dead_substream_id,
&mut user_data,
event,
)),
));
}
};
// Ask to be called again as soon as possible.
read_write.wake_up_asap();
}
yamux::DeadSubstreamTy::ClosedGracefully => {
// Nothing to report; the user data is simply discarded.
self.inner.yamux.remove_dead_substream(dead_substream_id);
}
}
}
Ok((self, None))
}
fn pass_through_substream_event(
substream_id: yamux::SubstreamId,
substream_user_data: &mut Option<TSubUd>,
event: substream::Event,
) -> Event<TSubUd> {
match event {
substream::Event::InboundError {
error,
was_accepted: false,
} => Event::InboundError(error),
substream::Event::InboundError {
was_accepted: true, ..
} => Event::InboundAcceptedCancel {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
user_data: substream_user_data.take().unwrap(),
},
substream::Event::InboundNegotiated(protocol_name) => Event::InboundNegotiated {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
protocol_name,
},
substream::Event::InboundNegotiatedCancel => Event::InboundNegotiatedCancel {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
},
substream::Event::RequestIn { request } => Event::RequestIn {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
request,
},
substream::Event::Response { response } => Event::Response {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
response,
user_data: substream_user_data.take().unwrap(),
},
substream::Event::NotificationsInOpen { handshake } => Event::NotificationsInOpen {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
handshake,
},
substream::Event::NotificationsInOpenCancel => Event::NotificationsInOpenCancel {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
},
substream::Event::NotificationIn { notification } => Event::NotificationIn {
notification,
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
},
substream::Event::NotificationsInClose { outcome } => Event::NotificationsInClose {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
outcome,
user_data: substream_user_data.take().unwrap(),
},
substream::Event::NotificationsOutResult { result } => Event::NotificationsOutResult {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
result: match result {
Ok(r) => Ok(r),
Err(err) => Err((err, substream_user_data.take().unwrap())),
},
},
substream::Event::NotificationsOutCloseDemanded => {
Event::NotificationsOutCloseDemanded {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
}
}
substream::Event::NotificationsOutReset => Event::NotificationsOutReset {
id: SubstreamId(SubstreamIdInner::SingleStream(substream_id)),
user_data: substream_user_data.take().unwrap(),
},
substream::Event::PingOutSuccess { ping_time } => Event::PingOutSuccess { ping_time },
substream::Event::PingOutError { .. } => {
Event::PingOutFailed
}
}
}
/// Queues a Yamux GoAway frame with the `NormalTermination` code, denying any new substream
/// coming from the remote.
///
/// # Panics
///
/// Panics if the Yamux state machine refuses to send the GoAway.
// NOTE(review): presumably this happens when a GoAway has already been sent, i.e. when this
// function is called twice — confirm against `yamux::Yamux::send_goaway`.
pub fn deny_new_incoming_substreams(&mut self) {
    let mux = &mut self.inner.yamux;
    let outcome = mux.send_goaway(yamux::GoAwayErrorCode::NormalTermination);
    outcome.unwrap()
}
/// Modifies the maximum protocol name length handed to the state machine of the inbound
/// substreams accepted from now on.
///
/// The substreams that have already been accepted aren't touched by this function.
pub fn set_max_protocol_name_len(&mut self, new_value: usize) {
    let inner = &mut *self.inner;
    inner.max_protocol_name_len = new_value;
}
/// Opens a new outbound substream dedicated to a request-response exchange and returns its
/// identifier.
///
/// `protocol_name`, `timeout`, `request` and `max_response_size` are forwarded to the state
/// machine of the substream. `user_data` is attached to the substream and is handed back by
/// the events that terminate it.
///
/// # Panics
///
/// Panics if the Yamux state machine refuses to open a new substream.
pub fn add_request(
    &mut self,
    protocol_name: String,
    request: Option<Vec<u8>>,
    timeout: TNow,
    max_response_size: usize,
    user_data: TSubUd,
) -> SubstreamId {
    let state_machine =
        substream::Substream::request_out(protocol_name, timeout, request, max_response_size);
    let yamux_id = self
        .inner
        .yamux
        .open_substream(Some((state_machine, Some(user_data))))
        .unwrap();

    // Grow the window of the remote so that it can send the response.
    // NOTE(review): the extra 64 bytes presumably account for framing overhead such as a
    // length prefix, and the subtraction for the window a new substream starts with — confirm.
    let additional_window = u64::try_from(max_response_size)
        .unwrap_or(u64::MAX)
        .saturating_add(64)
        .saturating_sub(yamux::NEW_SUBSTREAMS_FRAME_SIZE);
    self.inner
        .yamux
        .add_remote_window_saturating(yamux_id, additional_window);

    SubstreamId(SubstreamIdInner::SingleStream(yamux_id))
}
/// Opens a new outbound notifications substream and returns its identifier.
///
/// `timeout`, `protocol_name`, `handshake` and `max_handshake_size` are forwarded to the
/// state machine of the substream. `user_data` is attached to the substream.
///
/// # Panics
///
/// Panics if the Yamux state machine refuses to open a new substream.
pub fn open_notifications_substream(
    &mut self,
    protocol_name: String,
    handshake: Vec<u8>,
    max_handshake_size: usize,
    timeout: TNow,
    user_data: TSubUd,
) -> SubstreamId {
    let state_machine = substream::Substream::notifications_out(
        timeout,
        protocol_name,
        handshake,
        max_handshake_size,
    );

    let yamux_id = self
        .inner
        .yamux
        .open_substream(Some((state_machine, Some(user_data))))
        .unwrap();

    SubstreamId(SubstreamIdInner::SingleStream(yamux_id))
}
/// Accepts an inbound substream as being of type `ty`, and attaches `user_data` to it.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier or if the state of the substream
/// is currently unavailable. In debug builds, also panics if the substream already has a
/// user data.
pub fn accept_inbound(&mut self, substream_id: SubstreamId, ty: InboundTy, user_data: TSubUd) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let slot = self.inner.yamux[yamux_id].as_mut().unwrap();
    slot.0.accept_inbound(ty);

    let ud = &mut slot.1;
    debug_assert!(ud.is_none());
    *ud = Some(user_data);

    // The state machine most likely has something to send after this.
    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Rejects an inbound substream.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier or if the state of the substream
/// is currently unavailable. In debug builds, also panics if the substream has a user data.
pub fn reject_inbound(&mut self, substream_id: SubstreamId) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let slot = self.inner.yamux[yamux_id].as_mut().unwrap();
    slot.0.reject_inbound();

    let ud = &slot.1;
    debug_assert!(ud.is_none());

    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Accepts an inbound notifications substream, answering with `handshake`.
///
/// `max_notification_size` is forwarded to the state machine of the substream.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier or if the state of the substream
/// is currently unavailable.
pub fn accept_in_notifications_substream(
    &mut self,
    substream_id: SubstreamId,
    handshake: Vec<u8>,
    max_notification_size: usize,
) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let (state_machine, _) = self.inner.yamux[yamux_id].as_mut().unwrap();
    state_machine.accept_in_notifications_substream(handshake, max_notification_size);

    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Rejects an inbound notifications substream.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier or if the state of the substream
/// is currently unavailable.
pub fn reject_in_notifications_substream(&mut self, substream_id: SubstreamId) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let (state_machine, _) = self.inner.yamux[yamux_id].as_mut().unwrap();
    state_machine.reject_in_notifications_substream();

    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Queues `notification` for sending on the given outbound notifications substream.
///
/// This function itself doesn't enforce any limit to the amount of queued data. See also
/// `notification_substream_queued_bytes`.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier or if the state of the substream
/// is currently unavailable.
pub fn write_notification_unbounded(
    &mut self,
    substream_id: SubstreamId,
    notification: Vec<u8>,
) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let (state_machine, _) = self.inner.yamux[yamux_id].as_mut().unwrap();
    state_machine.write_notification_unbounded(notification);

    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Returns the number of bytes that the state machine of the given notifications substream
/// reports as queued.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier or if the state of the substream
/// is currently unavailable.
pub fn notification_substream_queued_bytes(&self, substream_id: SubstreamId) -> usize {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let (state_machine, _) = self.inner.yamux[yamux_id].as_ref().unwrap();
    state_machine.notification_substream_queued_bytes()
}
/// Starts closing the given outbound notifications substream.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier, if the Yamux state machine
/// doesn't know about this substream, or if the state of the substream is currently
/// unavailable.
pub fn close_out_notifications_substream(&mut self, substream_id: SubstreamId) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let is_known = self.inner.yamux.has_substream(yamux_id);
    if !is_known {
        panic!()
    }

    let (state_machine, _) = self.inner.yamux[yamux_id].as_mut().unwrap();
    state_machine.close_out_notifications_substream();

    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Starts closing the given inbound notifications substream. `timeout` is forwarded to the
/// state machine of the substream.
///
/// # Panics
///
/// Panics if `substream_id` isn't a single-stream identifier, if the Yamux state machine
/// doesn't know about this substream, or if the state of the substream is currently
/// unavailable.
pub fn close_in_notifications_substream(&mut self, substream_id: SubstreamId, timeout: TNow) {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        panic!()
    };

    let is_known = self.inner.yamux.has_substream(yamux_id);
    if !is_known {
        panic!()
    }

    let (state_machine, _) = self.inner.yamux[yamux_id].as_mut().unwrap();
    state_machine.close_in_notifications_substream(timeout);

    self.inner.yamux.mark_substream_write_ready(yamux_id);
}
/// Provides the response to a request previously received on the given substream.
///
/// # Errors
///
/// Returns [`RespondInRequestError::SubstreamClosed`] if `substream_id` isn't a single-stream
/// identifier or if the Yamux state machine no longer knows about the substream. Errors
/// returned by the state machine of the substream are propagated as they are.
///
/// # Panics
///
/// Panics if the substream exists but its state is currently unavailable.
pub fn respond_in_request(
    &mut self,
    substream_id: SubstreamId,
    response: Result<Vec<u8>, ()>,
) -> Result<(), RespondInRequestError> {
    let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
        return Err(RespondInRequestError::SubstreamClosed);
    };

    let is_known = self.inner.yamux.has_substream(yamux_id);
    if !is_known {
        return Err(RespondInRequestError::SubstreamClosed);
    }

    let (state_machine, _) = self.inner.yamux[yamux_id].as_mut().unwrap();
    state_machine.respond_in_request(response)?;

    self.inner.yamux.mark_substream_write_ready(yamux_id);
    Ok(())
}
}
/// Gives access to the user data attached to a substream.
impl<TNow, TSubUd> Index<SubstreamId> for SingleStream<TNow, TSubUd> {
    type Output = TSubUd;

    /// Returns a reference to the user data of the given substream.
    ///
    /// # Panics
    ///
    /// Panics if `substream_id` isn't a single-stream identifier, if the state of the
    /// substream is currently unavailable, or if no user data is attached to the substream.
    fn index(&self, substream_id: SubstreamId) -> &Self::Output {
        let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
            panic!()
        };

        let (_, user_data) = self.inner.yamux[yamux_id].as_ref().unwrap();
        user_data.as_ref().unwrap()
    }
}
/// Gives mutable access to the user data attached to a substream.
impl<TNow, TSubUd> IndexMut<SubstreamId> for SingleStream<TNow, TSubUd> {
    /// Returns a mutable reference to the user data of the given substream.
    ///
    /// # Panics
    ///
    /// Panics if `substream_id` isn't a single-stream identifier, if the state of the
    /// substream is currently unavailable, or if no user data is attached to the substream.
    fn index_mut(&mut self, substream_id: SubstreamId) -> &mut Self::Output {
        let SubstreamIdInner::SingleStream(yamux_id) = substream_id.0 else {
            panic!()
        };

        let (_, user_data) = self.inner.yamux[yamux_id].as_mut().unwrap();
        user_data.as_mut().unwrap()
    }
}
/// Formats the connection as a map built from the entries yielded by the `user_datas()`
/// method of the Yamux state machine.
impl<TNow, TSubUd> fmt::Debug for SingleStream<TNow, TSubUd>
where
    TSubUd: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut map = f.debug_map();
        map.entries(self.inner.yamux.user_datas());
        map.finish()
    }
}
/// Error that can happen while running a [`SingleStream`]. Returned by
/// `SingleStream::read_write`.
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum Error {
/// The Noise cipher has reported an error while processing the data of the socket.
#[display("Noise error: {_0}")]
Noise(noise::CipherError),
/// Error while encrypting data with the Noise cipher.
// NOTE(review): this variant isn't constructed anywhere in the visible part of this file —
// confirm whether it is still in use.
#[display("{_0}")]
NoiseEncrypt(noise::EncryptError),
/// The Yamux multiplexer has reported an error while processing the decrypted data.
#[display("Yamux error: {_0}")]
Yamux(yamux::Error),
}
/// Connection whose Noise encryption is available but which hasn't been turned into a
/// [`SingleStream`] yet.
///
/// Use [`ConnectionPrototype::into_connection`] to obtain the [`SingleStream`], or
/// [`ConnectionPrototype::into_noise_state_machine`] to get the Noise state machine back.
pub struct ConnectionPrototype {
/// Noise state machine that will encrypt and decrypt the data of the connection.
encryption: noise::Noise,
}
impl ConnectionPrototype {
    /// Builds a prototype from the given Noise state machine.
    pub(crate) fn from_noise_yamux(encryption: noise::Noise) -> Self {
        Self { encryption }
    }

    /// Destroys the prototype and returns the Noise state machine that it contains.
    pub fn into_noise_state_machine(self) -> noise::Noise {
        let Self { encryption } = self;
        encryption
    }

    /// Turns the prototype into an actual connection, configured according to `config`.
    ///
    /// A Yamux state machine is created and the substream used for the outgoing pings is
    /// immediately opened on it.
    ///
    /// # Panics
    ///
    /// Panics if the Yamux state machine refuses to open the outgoing ping substream.
    pub fn into_connection<TNow, TSubUd>(self, config: Config<TNow>) -> SingleStream<TNow, TSubUd>
    where
        TNow: Clone + Add<Duration, Output = TNow> + Sub<TNow, Output = Duration> + Ord,
    {
        let mut randomness = rand_chacha::ChaCha20Rng::from_seed(config.randomness_seed);

        let is_initiator = self.encryption.is_initiator();

        // The seed of the Yamux state machine is the first thing drawn from the generator,
        // which is later reused for the payload of the pings.
        let mut yamux_seed = [0u8; 32];
        randomness.fill_bytes(&mut yamux_seed);

        let yamux_config = yamux::Config {
            is_initiator,
            capacity: config.substreams_capacity,
            randomness_seed: yamux_seed,
            max_out_data_frame_size: NonZero::<u32>::new(8192).unwrap(),
            max_simultaneous_queued_pongs: NonZero::<usize>::new(4).unwrap(),
            max_simultaneous_rst_substreams: NonZero::<usize>::new(1024).unwrap(),
        };
        let mut yamux = yamux::Yamux::new(yamux_config);

        // The ping substream has no API user data attached to it.
        let ping_state_machine = substream::Substream::ping_out(config.ping_protocol.clone());
        let outgoing_pings = yamux
            .open_substream(Some((ping_state_machine, None)))
            .unwrap_or_else(|_| panic!());

        let inner = Inner {
            yamux,
            outgoing_pings,
            next_ping: config.first_out_ping,
            ping_payload_randomness: randomness,
            max_inbound_substreams: config.max_inbound_substreams,
            max_protocol_name_len: config.max_protocol_name_len,
            ping_interval: config.ping_interval,
            ping_timeout: config.ping_timeout,
        };

        SingleStream {
            encryption: self.encryption,
            inner: Box::new(inner),
        }
    }
}
/// Prints only the name of the type; the Noise state isn't displayed.
impl fmt::Debug for ConnectionPrototype {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("ConnectionPrototype");
        builder.finish()
    }
}