use crate::{protocol::*, *};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, Weak};
/// Shared, mutex-guarded map from a space id to the handler registered for
/// that space.
type SpaceMap = Arc<Mutex<HashMap<SpaceId, DynTxSpaceHandler>>>;
/// Shared, mutex-guarded map from a `(space id, module name)` pair to the
/// handler registered for that module.
type ModMap = Arc<Mutex<HashMap<(SpaceId, String), DynTxModuleHandler>>>;
/// Shared, mutex-guarded counters of blocked messages, keyed first by peer
/// url and then by space id.
type MessageBlocksMap =
Arc<Mutex<HashMap<Url, HashMap<SpaceId, MessageBlockCount>>>>;
/// Handle that dispatches transport events (new listening address, peer
/// connect/disconnect, received data) to the top-level handler and to the
/// registered space and module handlers.
///
/// The maps are `Arc`-shared: `DefaultTransport::create` clones them so that
/// handlers registered through the transport are visible here.
pub struct TxImpHnd {
/// Top-level handler; receives connection events and preflight calls.
handler: DynTxHandler,
/// Space handlers keyed by space id.
space_map: SpaceMap,
/// Module handlers keyed by `(space id, module name)`.
mod_map: ModMap,
/// Per-peer, per-space counters of messages dropped because of blocks.
blocked_message_counts: MessageBlocksMap,
}
impl TxImpHnd {
pub fn new(handler: DynTxHandler) -> Arc<Self> {
Arc::new(Self {
handler,
space_map: Arc::new(Mutex::new(HashMap::new())),
mod_map: Arc::new(Mutex::new(HashMap::new())),
blocked_message_counts: Arc::new(Mutex::new(HashMap::new())),
})
}
/// Announce a new listening address.
///
/// The returned future first notifies the top-level handler and then every
/// space handler that was registered at the time of this call.
pub fn new_listening_address(&self, this_url: Url) -> BoxFut<'static, ()> {
    let top_level = self.handler.clone();
    // Snapshot the space handlers up front so the returned future neither
    // borrows `self` nor holds the map lock.
    let space_handlers: Vec<_> = {
        let spaces = self.space_map.lock().unwrap();
        spaces.values().cloned().collect()
    };
    Box::pin(async move {
        top_level.new_listening_address(this_url.clone()).await;
        for space_handler in space_handlers {
            space_handler.new_listening_address(this_url.clone()).await;
        }
    })
}
/// Handle a newly connected peer and produce the preflight message to send.
///
/// Every module handler, then every space handler, then the top-level
/// handler is notified via `peer_connect`; the first error aborts the call.
/// If at least one space handler is registered, at least one of them must
/// report local agents, otherwise `K2Error::NoLocalAgentsDuringPreflight`
/// is returned.
///
/// On success the preflight data gathered from the top-level handler is
/// returned, encoded as a `K2WireType::Preflight` message.
pub fn peer_connect(
    &self,
    peer: Url,
) -> BoxFut<'_, K2Result<bytes::Bytes>> {
    Box::pin(async move {
        for module in self.mod_map.lock().unwrap().values() {
            module.peer_connect(peer.clone())?;
        }
        for space in self.space_map.lock().unwrap().values() {
            space.peer_connect(peer.clone())?;
        }
        self.handler.peer_connect(peer.clone())?;

        // Work on a snapshot so no lock is held across the awaits below.
        let spaces: Vec<_> =
            self.space_map.lock().unwrap().values().cloned().collect();
        // With no spaces registered there is nothing to check.
        let mut local_agents_found = spaces.is_empty();
        for space in &spaces {
            if space.has_local_agents().await? {
                local_agents_found = true;
                break;
            }
        }
        if !local_agents_found {
            return Err(K2Error::NoLocalAgentsDuringPreflight);
        }

        let data = self.handler.preflight_gather_outgoing(peer).await?;
        let message = K2Proto {
            ty: K2WireType::Preflight as i32,
            data,
            space_id: None,
            module_id: None,
        };
        let encoded = message.encode()?;
        Ok(encoded)
    })
}
/// Notify every module handler, then every space handler, then the
/// top-level handler that `peer` disconnected, passing along the optional
/// `reason`.
pub fn peer_disconnect(&self, peer: Url, reason: Option<String>) {
    self.mod_map
        .lock()
        .unwrap()
        .values()
        .for_each(|module| module.peer_disconnect(peer.clone(), reason.clone()));
    self.space_map
        .lock()
        .unwrap()
        .values()
        .for_each(|space| space.peer_disconnect(peer.clone(), reason.clone()));
    self.handler.peer_disconnect(peer, reason);
}
pub fn recv_data(
&self,
peer: Url,
data: bytes::Bytes,
) -> BoxFut<'_, K2Result<()>> {
Box::pin(async move {
let data = K2Proto::decode(&data)?;
let message_type = data.ty();
let K2Proto {
space_id,
module_id,
data,
..
} = data;
let start = std::time::Instant::now();
if !self.check_message_permitted(
&peer,
&space_id,
&module_id,
&message_type,
)? {
let elapsed = start.elapsed();
tracing::debug!(
?peer,
"Checked message not permitted in {:?}",
elapsed
);
return Ok(());
}
let elapsed = start.elapsed();
tracing::debug!(
?peer,
"Checked message permitted in {:?}",
elapsed
);
match message_type {
K2WireType::Unspecified => Ok(()),
K2WireType::Preflight => {
self.handler.preflight_validate_incoming(peer, data).await
}
K2WireType::Notify => {
if let Some(space_id) = space_id {
let space_id = SpaceId::from(space_id);
if let Some(h) =
self.space_map.lock().unwrap().get(&space_id)
{
h.recv_space_notify(peer, space_id, data)?;
}
}
Ok(())
}
K2WireType::Module => {
if let (Some(space_id), Some(module)) =
(space_id, module_id)
{
let space_id = SpaceId::from(space_id);
if let Some(h) = self
.mod_map
.lock()
.unwrap()
.get(&(space_id.clone(), module.clone()))
{
h.recv_module_msg(peer, space_id, module.clone(), data).inspect_err(|e| {
tracing::warn!(?module, "Error in recv_module_msg, peer connection will be closed: {e}");
})?;
}
}
Ok(())
}
K2WireType::Disconnect => {
let reason = String::from_utf8_lossy(&data).to_string();
Err(K2Error::other(format!("Remote Disconnect: {reason}")))
}
}
})
}
/// Ask every registered space handler to mark `peer` as unresponsive as of
/// `when`.
///
/// A failure in one space is logged and does not stop the remaining spaces
/// from being processed; the returned future always resolves to `Ok(())`.
pub fn set_unresponsive(
    &self,
    peer: Url,
    when: Timestamp,
) -> BoxFut<'_, K2Result<()>> {
    // Copy the map so the lock is not held while awaiting the handlers.
    let spaces = self.space_map.lock().unwrap().clone();
    Box::pin(async move {
        for (space_id, space_handler) in &spaces {
            let outcome =
                space_handler.set_unresponsive(peer.clone(), when).await;
            if let Err(e) = outcome {
                tracing::error!(
                    "Failed to mark peer with url {peer} as unresponsive in space {space_id}: {e}"
                );
            }
        }
        Ok(())
    })
}
/// Decide whether an incoming message should be processed.
///
/// `Preflight`, `Unspecified` and `Disconnect` messages are always
/// permitted. Every other message type must carry a space id; a missing
/// space id is an error. Otherwise the message is permitted exactly when
/// `is_peer_blocked` reports that the peer is not blocked for that space
/// (an unregistered space counts as blocked).
pub fn check_message_permitted(
    &self,
    peer_url: &Url,
    space_id: &Option<Bytes>,
    module_id: &Option<String>,
    message_type: &K2WireType,
) -> K2Result<bool> {
    let always_permitted = matches!(
        message_type,
        K2WireType::Preflight
            | K2WireType::Unspecified
            | K2WireType::Disconnect
    );
    if always_permitted {
        return Ok(true);
    }

    let Some(raw_space_id) = space_id else {
        tracing::warn!(
            ?peer_url,
            "Received a message of type {:?} without space id which is violating the protocol. Dropping the message and closing the connection.",
            message_type
        );
        return Err(K2Error::other(
            "Received a message without space id.",
        ));
    };
    let space_id = SpaceId::from(raw_space_id.clone());

    is_peer_blocked(
        self.space_map.clone(),
        self.blocked_message_counts.clone(),
        peer_url,
        &space_id,
        module_id,
        false,
    )
    .map(|blocked| !blocked)
}
}
/// Debug output shows the top-level handler and only the entry counts of
/// the space and module maps.
impl std::fmt::Debug for TxImpHnd {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let space_entries = self.space_map.lock().unwrap().len();
        let module_entries = self.mod_map.lock().unwrap().len();
        write!(
            f,
            "TxImpHnd {{ handler: {:?}, space_map: [{} entries], mod_map: [{} entries] }}",
            self.handler, space_entries, module_entries
        )
    }
}
/// Determine whether messages to or from `peer_url` in `space_id` must be
/// dropped.
///
/// Returns `Ok(true)` when no handler is registered for the space, or when
/// the space handler reports that any agent at the url is blocked. In the
/// latter case the outgoing or incoming blocked-message counter (selected
/// by `outgoing`) is incremented. An error from the space handler is
/// logged and propagated.
///
/// `module_id` is used only as log context.
fn is_peer_blocked(
    space_map: SpaceMap,
    message_blocks_map: MessageBlocksMap,
    peer_url: &Url,
    space_id: &SpaceId,
    module_id: &Option<String>,
    outgoing: bool,
) -> K2Result<bool> {
    // Clone the handler out so the map lock is released before calling it.
    let registered =
        space_map.lock().expect("poisoned").get(space_id).cloned();
    let Some(handler) = registered else {
        tracing::error!(
            ?space_id,
            ?peer_url,
            ?module_id,
            "No space handler found. Message will be dropped."
        );
        return Ok(true);
    };

    let any_blocked = handler.is_any_agent_at_url_blocked(peer_url).inspect_err(|e| tracing::warn!(?space_id, ?peer_url, ?module_id, "Failed to check whether any agent is blocked, peer connection will be closed: {e}"))?;
    if !any_blocked {
        return Ok(false);
    }

    tracing::debug!(
        ?space_id,
        ?peer_url,
        ?module_id,
        "At least one agent at peer is blocked, message will be dropped."
    );
    if outgoing {
        incr_blocked_message_count_outgoing(
            message_blocks_map,
            peer_url.clone(),
            space_id,
        );
    } else {
        incr_blocked_message_count_incoming(
            message_blocks_map,
            peer_url.clone(),
            space_id,
        );
    }
    Ok(true)
}
/// Increment the count of blocked *incoming* messages for `peer_url` in
/// `space_id`, creating the per-peer and per-space entries on first use.
///
/// The counter saturates at `u32::MAX`. A plain `+= 1` would panic in debug
/// builds (poisoning the shared mutex, since the lock is held here) and
/// silently wrap to zero in release builds.
///
/// # Panics
///
/// Panics if the counter mutex is poisoned.
fn incr_blocked_message_count_incoming(
    message_blocks_map: MessageBlocksMap,
    peer_url: Url,
    space_id: &SpaceId,
) {
    let mut blocked_message_counts =
        message_blocks_map.lock().expect("poisoned");
    let count = blocked_message_counts
        .entry(peer_url)
        .or_default()
        .entry(space_id.clone())
        .or_insert(MessageBlockCount {
            incoming: 0,
            outgoing: 0,
        });
    count.incoming = count.incoming.saturating_add(1);
}
/// Increment the count of blocked *outgoing* messages for `peer_url` in
/// `space_id`, creating the per-peer and per-space entries on first use.
///
/// The counter saturates at `u32::MAX`. A plain `+= 1` would panic in debug
/// builds (poisoning the shared mutex, since the lock is held here) and
/// silently wrap to zero in release builds.
///
/// # Panics
///
/// Panics if the counter mutex is poisoned.
fn incr_blocked_message_count_outgoing(
    message_blocks_map: MessageBlocksMap,
    peer_url: Url,
    space_id: &SpaceId,
) {
    let mut blocked_message_counts =
        message_blocks_map.lock().expect("poisoned");
    let count = blocked_message_counts
        .entry(peer_url)
        .or_default()
        .entry(space_id.clone())
        .or_insert(MessageBlockCount {
            incoming: 0,
            outgoing: 0,
        });
    count.outgoing = count.outgoing.saturating_add(1);
}
/// Backend interface that `DefaultTransport` delegates to for the actual
/// network operations.
pub trait TxImp: 'static + Send + Sync + std::fmt::Debug {
/// The url of this implementation, if it has one.
/// `DefaultTransport::register_space_handler` returns this value.
fn url(&self) -> Option<Url>;
/// Disconnect from `peer`.
///
/// `DefaultTransport::disconnect` passes `Some((reason, encoded))` where
/// `encoded` is the reason wrapped in an encoded `Disconnect` message, or
/// `None` when no reason was given or encoding failed.
fn disconnect(
&self,
peer: Url,
payload: Option<(String, bytes::Bytes)>,
) -> BoxFut<'_, ()>;
/// Send already-encoded `data` to `peer`.
fn send(&self, peer: Url, data: bytes::Bytes) -> BoxFut<'_, K2Result<()>>;
/// List the urls of the peers this implementation considers connected.
fn get_connected_peers(&self) -> BoxFut<'_, K2Result<Vec<Url>>>;
/// Produce backend-level network statistics.
fn dump_network_stats(&self) -> BoxFut<'_, K2Result<TransportStats>>;
}
/// Trait-object form of [`TxImp`].
pub type DynTxImp = Arc<dyn TxImp>;
/// High-level transport interface: handler registration plus space- and
/// module-scoped message sending.
///
/// A mock is generated in tests or when the `mockall` feature is enabled.
#[cfg_attr(any(test, feature = "mockall"), mockall::automock)]
pub trait Transport: 'static + Send + Sync + std::fmt::Debug {
/// Register the handler for `space_id`.
///
/// `DefaultTransport` panics if a handler is already registered for the
/// space, and returns the url reported by the underlying implementation.
fn register_space_handler(
&self,
space_id: SpaceId,
handler: DynTxSpaceHandler,
) -> Option<Url>;
/// Register the handler for `module` within `space_id`.
///
/// `DefaultTransport` panics if a handler is already registered for the
/// `(space_id, module)` pair.
fn register_module_handler(
&self,
space_id: SpaceId,
module: String,
handler: DynTxModuleHandler,
);
/// Disconnect from `peer`, optionally communicating a `reason`.
fn disconnect(&self, peer: Url, reason: Option<String>) -> BoxFut<'_, ()>;
/// Send a space-level notify message carrying `data` to `peer`.
fn send_space_notify(
&self,
peer: Url,
space_id: SpaceId,
data: bytes::Bytes,
) -> BoxFut<'_, K2Result<()>>;
/// Send a module-level message carrying `data` to `peer`, addressed to
/// `module` within `space_id`.
fn send_module(
&self,
peer: Url,
space_id: SpaceId,
module: String,
data: bytes::Bytes,
) -> BoxFut<'_, K2Result<()>>;
/// List the urls of currently connected peers.
fn get_connected_peers(&self) -> BoxFut<'_, K2Result<Vec<Url>>>;
/// Remove the handler for `space_id`; `DefaultTransport` also removes
/// every module handler registered under that space.
fn unregister_space(&self, space_id: SpaceId) -> BoxFut<'_, ()>;
/// Produce transport statistics, including blocked-message counts.
fn dump_network_stats(&self) -> BoxFut<'_, K2Result<ApiTransportStats>>;
}
/// Trait-object form of [`Transport`].
pub type DynTransport = Arc<dyn Transport>;
/// Weak counterpart of [`DynTransport`].
pub type WeakDynTransport = Weak<dyn Transport>;
/// [`Transport`] implementation that layers handler registration, block
/// checks and message encoding on top of a [`TxImp`] backend.
///
/// The maps are shared with the `TxImpHnd` this transport was created from.
#[derive(Clone, Debug)]
pub struct DefaultTransport {
/// Backend that performs the actual network operations.
imp: DynTxImp,
/// Space handlers keyed by space id.
space_map: SpaceMap,
/// Module handlers keyed by `(space id, module name)`.
mod_map: ModMap,
/// Per-peer, per-space counters of messages dropped because of blocks.
blocked_message_counts: MessageBlocksMap,
}
impl DefaultTransport {
pub fn create(hnd: &TxImpHnd, imp: DynTxImp) -> DynTransport {
let out: DynTransport = Arc::new(DefaultTransport {
imp,
space_map: hnd.space_map.clone(),
mod_map: hnd.mod_map.clone(),
blocked_message_counts: hnd.blocked_message_counts.clone(),
});
out
}
/// Fail with `K2Error::NoLocalAgentsDuringPreflight` when the handler
/// registered for `space_id` reports that it has no local agents.
///
/// Succeeds when no handler is registered for the space. An error from the
/// handler's `has_local_agents` check is propagated.
async fn error_if_no_local_agents(
    &self,
    space_id: SpaceId,
) -> K2Result<()> {
    // Clone the handler out so the lock is not held across the await.
    let registered = self.space_map.lock().unwrap().get(&space_id).cloned();
    let Some(handler) = registered else {
        return Ok(());
    };
    if handler.has_local_agents().await? {
        Ok(())
    } else {
        Err(K2Error::NoLocalAgentsDuringPreflight)
    }
}
}
impl Transport for DefaultTransport {
/// Register `handler` for `space_id` and return the backend's url.
///
/// # Panics
///
/// Panics if a handler was already registered for `space_id`. The new
/// handler has replaced the old one in the map by the time of the panic.
fn register_space_handler(
    &self,
    space_id: SpaceId,
    handler: DynTxSpaceHandler,
) -> Option<Url> {
    let mut spaces = self.space_map.lock().unwrap();
    let is_duplicate = spaces.insert(space_id.clone(), handler).is_some();
    assert!(
        !is_duplicate,
        "Attempted to register duplicate space handler! {space_id}"
    );
    self.imp.url()
}
/// Register `handler` for `module` within `space_id`.
///
/// # Panics
///
/// Panics if a handler was already registered for the `(space_id, module)`
/// pair. The new handler has replaced the old one in the map by the time of
/// the panic.
fn register_module_handler(
    &self,
    space_id: SpaceId,
    module: String,
    handler: DynTxModuleHandler,
) {
    let key = (space_id.clone(), module.clone());
    let is_duplicate =
        self.mod_map.lock().unwrap().insert(key, handler).is_some();
    assert!(
        !is_duplicate,
        "Attempted to register duplicate module handler! {space_id} {module}"
    );
}
/// Disconnect from `peer`.
///
/// When a `reason` is given it is wrapped in an encoded `Disconnect`
/// message and handed to the backend together with the reason text. If
/// encoding fails the backend is asked to disconnect without a payload.
fn disconnect(&self, peer: Url, reason: Option<String>) -> BoxFut<'_, ()> {
    Box::pin(async move {
        let payload = reason.and_then(|reason| {
            let message = K2Proto {
                ty: K2WireType::Disconnect as i32,
                data: bytes::Bytes::copy_from_slice(reason.as_bytes()),
                space_id: None,
                module_id: None,
            };
            // Best effort: an encode failure just means no payload is sent.
            message.encode().ok().map(|encoded| (reason, encoded))
        });
        self.imp.disconnect(peer, payload).await;
    })
}
fn send_space_notify(
&self,
peer_url: Url,
space_id: SpaceId,
data: bytes::Bytes,
) -> BoxFut<'_, K2Result<()>> {
Box::pin(async move {
self.error_if_no_local_agents(space_id.clone()).await?;
if is_peer_blocked(
self.space_map.clone(),
self.blocked_message_counts.clone(),
&peer_url,
&space_id,
&None,
true,
)? {
tracing::warn!(
?peer_url,
?space_id,
"Attempted to send space notify message to a peer that is blocked in that space. Dropping message."
);
return Ok(());
}
let enc = (K2Proto {
ty: K2WireType::Notify as i32,
data,
space_id: Some(space_id.into()),
module_id: None,
})
.encode()?;
self.imp.send(peer_url, enc).await
})
}
/// Send a module message carrying `data` to `peer_url`, addressed to
/// `module` within `space_id`.
///
/// Fails if the space's handler reports no local agents. If the peer is
/// blocked in the space (or the space has no registered handler) the
/// message is dropped and `Ok(())` is returned. Otherwise the data is
/// encoded as a `Module` message and sent through the backend.
fn send_module(
    &self,
    peer_url: Url,
    space_id: SpaceId,
    module: String,
    data: bytes::Bytes,
) -> BoxFut<'_, K2Result<()>> {
    Box::pin(async move {
        self.error_if_no_local_agents(space_id.clone()).await?;

        // Hand the module id to the block check so that its log output
        // identifies which module's message was dropped.
        let module_id = Some(module.clone());
        if is_peer_blocked(
            self.space_map.clone(),
            self.blocked_message_counts.clone(),
            &peer_url,
            &space_id,
            &module_id,
            true,
        )? {
            tracing::warn!(
                ?peer_url,
                ?space_id,
                ?module,
                "Attempted to send module message to a peer that is blocked in the associated space. Dropping message."
            );
            return Ok(());
        }

        let enc = (K2Proto {
            ty: K2WireType::Module as i32,
            data,
            space_id: Some(space_id.into()),
            module_id,
        })
        .encode()?;
        self.imp.send(peer_url, enc).await
    })
}
/// List the connected peers by delegating directly to the backend.
fn get_connected_peers(&self) -> BoxFut<'_, K2Result<Vec<Url>>> {
self.imp.get_connected_peers()
}
/// Remove the handler registered for `space_id` together with every module
/// handler registered under that space.
fn unregister_space(&self, space_id: SpaceId) -> BoxFut<'_, ()> {
    Box::pin(async move {
        {
            let mut spaces = self.space_map.lock().unwrap();
            spaces.remove(&space_id);
        }
        let mut modules = self.mod_map.lock().unwrap();
        modules.retain(|(owner, _), _| *owner != space_id);
    })
}
/// Combine the backend's network statistics with a snapshot of the
/// blocked-message counters.
///
/// # Panics
///
/// Panics if the counter mutex is poisoned.
fn dump_network_stats(&self) -> BoxFut<'_, K2Result<ApiTransportStats>> {
    Box::pin(async move {
        let transport_stats = self.imp.dump_network_stats().await?;
        // The lock is taken only after the await and released right after
        // the copy is made.
        let blocked_message_counts = self
            .blocked_message_counts
            .lock()
            .expect("poisoned")
            .clone();
        Ok(ApiTransportStats {
            transport_stats,
            blocked_message_counts,
        })
    })
}
}
/// Connection-level callbacks shared by all handler kinds.
///
/// Every method has a default implementation that ignores its arguments.
pub trait TxBaseHandler: 'static + Send + Sync + std::fmt::Debug {
    /// Called with a new listening address. The default does nothing.
    fn new_listening_address(&self, this_url: Url) -> BoxFut<'static, ()> {
        let _ = this_url;
        Box::pin(async {})
    }

    /// Called when `peer` connects. `TxImpHnd::peer_connect` propagates an
    /// error returned here to its caller. The default accepts every peer.
    fn peer_connect(&self, peer: Url) -> K2Result<()> {
        let _ = peer;
        Ok(())
    }

    /// Called when `peer` disconnects, with an optional `reason`. The
    /// default does nothing.
    fn peer_disconnect(&self, peer: Url, reason: Option<String>) {
        let _ = (peer, reason);
    }
}
/// Top-level handler: connection callbacks plus preflight hooks.
///
/// Every method has a default implementation.
pub trait TxHandler: TxBaseHandler {
    /// Gather the preflight data to send to `peer_url`. The default returns
    /// empty bytes.
    fn preflight_gather_outgoing(
        &self,
        peer_url: Url,
    ) -> BoxFut<'_, K2Result<bytes::Bytes>> {
        let _ = peer_url;
        Box::pin(async { Ok(bytes::Bytes::new()) })
    }

    /// Validate preflight `data` received from `peer_url`. The default
    /// accepts anything.
    fn preflight_validate_incoming(
        &self,
        peer_url: Url,
        data: bytes::Bytes,
    ) -> BoxFut<'_, K2Result<()>> {
        let _ = (peer_url, data);
        Box::pin(async { Ok(()) })
    }

    /// Report whether there are local agents. The default reports `true`.
    fn has_local_agents(&self) -> BoxFut<'_, K2Result<bool>> {
        Box::pin(async { Ok(true) })
    }
}
/// Trait-object form of [`TxHandler`].
pub type DynTxHandler = Arc<dyn TxHandler>;
/// Handler for events scoped to a single space.
///
/// Only `is_any_agent_at_url_blocked` must be implemented; the other
/// methods have defaults.
pub trait TxSpaceHandler: TxBaseHandler {
    /// Receive a space notify message from `peer`. The default ignores it.
    fn recv_space_notify(
        &self,
        peer: Url,
        space_id: SpaceId,
        data: bytes::Bytes,
    ) -> K2Result<()> {
        let _ = (peer, space_id, data);
        Ok(())
    }

    /// Mark `peer` as unresponsive as of `when`. The default does nothing.
    fn set_unresponsive(
        &self,
        peer: Url,
        when: Timestamp,
    ) -> BoxFut<'_, K2Result<()>> {
        let _ = (peer, when);
        Box::pin(async { Ok(()) })
    }

    /// Report whether any agent at `peer_url` is blocked. A `true` result
    /// causes messages to and from that url to be dropped for this space.
    fn is_any_agent_at_url_blocked(&self, peer_url: &Url) -> K2Result<bool>;

    /// Report whether this space has local agents. The default reports
    /// `true`.
    fn has_local_agents(&self) -> BoxFut<'_, K2Result<bool>> {
        Box::pin(async { Ok(true) })
    }
}
/// Trait-object form of [`TxSpaceHandler`].
pub type DynTxSpaceHandler = Arc<dyn TxSpaceHandler>;
/// Handler for messages addressed to a module within a space.
pub trait TxModuleHandler: TxBaseHandler {
    /// Receive a module message from `peer`. When invoked from
    /// `TxImpHnd::recv_data`, an error returned here is logged and
    /// propagated to that method's caller. The default ignores the message.
    fn recv_module_msg(
        &self,
        peer: Url,
        space_id: SpaceId,
        module: String,
        data: bytes::Bytes,
    ) -> K2Result<()> {
        let _ = (peer, space_id, module, data);
        Ok(())
    }
}
/// Trait-object form of [`TxModuleHandler`].
pub type DynTxModuleHandler = Arc<dyn TxModuleHandler>;
/// Factory that configures and constructs a [`Transport`].
pub trait TransportFactory: 'static + Send + Sync + std::fmt::Debug {
/// Write defaults into `config` (presumably this factory's own settings;
/// verify against implementors).
fn default_config(&self, config: &mut config::Config) -> K2Result<()>;
/// Check `config`, returning an error if it is not acceptable.
fn validate_config(&self, config: &config::Config) -> K2Result<()>;
/// Construct a transport that reports to the given top-level `handler`.
fn create(
&self,
builder: Arc<builder::Builder>,
handler: DynTxHandler,
) -> BoxFut<'static, K2Result<DynTransport>>;
}
/// Trait-object form of [`TransportFactory`].
pub type DynTransportFactory = Arc<dyn TransportFactory>;
/// Number of messages dropped because of blocks, split by direction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MessageBlockCount {
/// Count of received messages that were dropped.
pub incoming: u32,
/// Count of messages that were dropped instead of being sent.
pub outgoing: u32,
}
/// Statistics returned by `Transport::dump_network_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiTransportStats {
/// Statistics reported by the transport backend.
pub transport_stats: TransportStats,
/// Blocked-message counts keyed by peer url, then by space id.
pub blocked_message_counts:
HashMap<Url, HashMap<SpaceId, MessageBlockCount>>,
}
/// Statistics reported by a transport backend via
/// `TxImp::dump_network_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportStats {
/// Backend identifier (presumably the backend's name; verify against
/// implementors).
pub backend: String,
/// Peer urls reported by the backend. NOTE(review): whether these are this
/// node's own urls or remote peers' urls is not visible here; confirm.
pub peer_urls: Vec<Url>,
/// Per-connection statistics.
pub connections: Vec<TransportConnectionStats>,
}
/// Statistics for a single connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConnectionStats {
/// Public key associated with the connection (presumably the remote
/// peer's; encoding not visible here, confirm).
pub pub_key: String,
/// Number of messages sent on this connection.
pub send_message_count: u64,
/// Number of bytes sent on this connection.
pub send_bytes: u64,
/// Number of messages received on this connection.
pub recv_message_count: u64,
/// Number of bytes received on this connection.
pub recv_bytes: u64,
/// When the connection was opened, in seconds (presumably since the Unix
/// epoch; confirm).
pub opened_at_s: u64,
/// Whether the connection is direct (presumably as opposed to relayed;
/// confirm).
pub is_direct: bool,
}