use {
super::{Criteria, StreamId, Streams, producer::Sinks},
crate::{
NetworkId,
discovery::Discovery,
network::{
UnknownPeer,
error::DifferentNetwork,
link::{Link, Protocol},
},
primitives::Short,
streams::StreamNotFound,
},
core::fmt,
iroh::{
endpoint::Connection,
protocol::{AcceptError, ProtocolHandler},
},
n0_error::Meta,
serde::{Deserialize, Serialize},
std::sync::Arc,
};
/// Incoming-connection handler for the streams protocol.
///
/// Holds shared handles, taken from a [`Streams`] instance, to the producer-side
/// [`Sinks`] and to [`Discovery`].
pub(super) struct Acceptor {
    /// Sinks that accepted consumer links are handed to (same `Arc` as `Streams`).
    sinks: Arc<Sinks>,
    /// Source of the catalog used to identify the remote peer of each connection.
    discovery: Discovery,
}

impl Acceptor {
    /// Creates an acceptor that shares `streams`' sinks and clones its discovery handle.
    pub(super) fn new(streams: &Streams) -> Self {
        Self {
            sinks: Arc::clone(&streams.sinks),
            discovery: streams.discovery.clone(),
        }
    }
}
impl fmt::Debug for Acceptor {
    /// Formats as `Streams(<alpn>)`, where `<alpn>` is the streams protocol's ALPN identifier.
    ///
    /// The ALPN bytes are decoded with `String::from_utf8_lossy` instead of the previous
    /// `unsafe` `from_utf8_unchecked`. A non-UTF-8 ALPN therefore can never produce an
    /// invalid `&str` (undefined behavior); it is shown with replacement characters instead.
    /// For valid UTF-8 the output is identical and no allocation occurs (a borrowed `Cow`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Streams({})", String::from_utf8_lossy(Streams::ALPN))
    }
}
impl ProtocolHandler for Acceptor {
    /// Accepts an incoming consumer connection and hands it to the sink of the requested stream.
    ///
    /// Steps, in order:
    /// 1. Establishes a [`Link`] over `connection`, passing a clone of the local termination
    ///    token so link setup can be cancelled.
    /// 2. Rejects peers missing from the discovery catalog; the link is closed with
    ///    [`UnknownPeer`].
    /// 3. Receives a [`ConsumerHandshake`] from the consumer.
    /// 4. Rejects handshakes whose network id differs from the local one; the link is closed
    ///    with [`DifferentNetwork`].
    /// 5. Rejects requests for a stream for which `sinks.open` yields no sink; the link is
    ///    closed with [`StreamNotFound`].
    /// 6. Passes the link, the handshake criteria and the peer to the sink.
    ///
    /// # Errors
    ///
    /// - [`AcceptError::NotAllowed`] for the rejections in steps 2, 4 and 5.
    /// - The converted error if establishing the link, receiving the handshake, or closing a
    ///   rejected link fails.
    ///
    /// If the sink hands the link back in step 6, the link is closed with [`StreamNotFound`].
    /// The result of that close is returned, which is `Ok(())` when the close succeeds.
    async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
        // Link setup observes the local termination token so it can be cancelled.
        let cancel = self.sinks.local.termination().clone();
        let mut link = Link::accept_with_cancel(connection, cancel).await?;
        let remote_peer_id = link.remote_id();

        // Only peers present in the discovery catalog may consume streams.
        // NOTE(review): if `get` returns a borrow, `peer` keeps `catalog` alive across the
        // awaits below; confirm `catalog()` is not a lock guard.
        let catalog = self.discovery.catalog();
        let Some(peer) = catalog.get(&remote_peer_id) else {
            tracing::trace!(
                peer_id = %Short(&remote_peer_id),
                "rejecting unidentified consumer",
            );
            link
                .close(UnknownPeer)
                .await
                .map_err(AcceptError::from_err)?;
            return Err(AcceptError::NotAllowed {
                meta: Meta::default(),
            });
        };
        tracing::trace!(
            consumer_id = %Short(peer.id()),
            consumer_info = %Short(peer),
            "new consumer connection",
        );

        // The first message on the link is the consumer's handshake.
        let handshake: ConsumerHandshake = link
            .recv()
            .await
            .inspect_err(|e| {
                tracing::debug!(
                    consumer_id = %Short(peer.id()),
                    error = %e,
                    "Failed to receive consumer handshake",
                );
            })
            .map_err(AcceptError::from_err)?;

        // A consumer from another network must not be attached to our streams.
        if handshake.network_id != self.sinks.local.network_id() {
            tracing::debug!(
                consumer_id = %Short(peer.id()),
                stream_id = %Short(handshake.stream_id),
                expected_network = %Short(self.sinks.local.network_id()),
                received_network = %Short(handshake.network_id),
                "Consumer connected to wrong network",
            );
            link
                .close(DifferentNetwork)
                .await
                .map_err(AcceptError::from_err)?;
            return Err(AcceptError::NotAllowed {
                meta: Meta::default(),
            });
        }

        let Some(sink) = self.sinks.open(handshake.stream_id) else {
            tracing::debug!(
                consumer_id = %Short(peer.id()),
                stream_id = %handshake.stream_id,
                "Consumer requesting unavailable stream",
            );
            link
                .close(StreamNotFound)
                .await
                .map_err(AcceptError::from_err)?;
            return Err(AcceptError::NotAllowed {
                meta: Meta::default(),
            });
        };

        // `sink.accept` hands the link back when the sink can no longer take consumers.
        if let Err(link) = sink.accept(link, handshake.criteria, peer.clone()) {
            tracing::debug!(
                consumer_id = %Short(peer.id()),
                stream_id = %handshake.stream_id,
                "Sink terminated before accepting new consumer",
            );
            // NOTE(review): unlike the other rejections, this path returns `Ok(())` after a
            // successful close rather than `NotAllowed`. Confirm this is intended.
            return link
                .close(StreamNotFound)
                .await
                .map_err(AcceptError::from_err);
        }
        Ok(())
    }
}
/// First message a consumer sends on a newly accepted link, describing what it requests.
///
/// The acceptor receives it with `link.recv()` and checks it before handing the link to a
/// sink.
///
/// NOTE(review): the field names and order presumably form the serde wire format; keep them
/// stable unless both ends change together.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct ConsumerHandshake {
    /// Network the consumer targets; the acceptor rejects it if it differs from the local one.
    network_id: NetworkId,
    /// Stream the consumer requests; the acceptor looks it up with `sinks.open`.
    stream_id: StreamId,
    /// Criteria forwarded unchanged to the sink when the consumer is accepted.
    criteria: Criteria,
}
impl ConsumerHandshake {
    /// Bundles the target network, the requested stream and the consumer's criteria
    /// into a handshake message.
    pub const fn new(network_id: NetworkId, stream_id: StreamId, criteria: Criteria) -> Self {
        Self { network_id, stream_id, criteria }
    }
}
/// Message pairing a network id (field `0`) with a stream id (field `1`).
///
/// NOTE(review): presumably sent to start a stream; its sender and receiver are not
/// visible in this file, so confirm. As a serde tuple struct, the field order is
/// presumably part of the wire format.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct StartStream(pub NetworkId, pub StreamId);
impl StartStream {
pub const fn stream_id(&self) -> &StreamId {
&self.1
}
pub const fn network_id(&self) -> &NetworkId {
&self.0
}
}