use std::str::FromStr;
use std::sync::Arc;
use async_trait::async_trait;
use futures::lock::Mutex as FuturesMutex;
use rings_transport::core::callback::TransportCallback;
use rings_transport::core::transport::WebrtcConnectionState;
use crate::chunk::ChunkList;
use crate::chunk::ChunkManager;
use crate::consts::TRANSPORT_MTU;
use crate::dht::Did;
use crate::message::HandleMsg;
use crate::message::Message;
use crate::message::MessageHandler;
use crate::message::MessagePayload;
use crate::message::MessageVerificationExt;
use crate::swarm::transport::SwarmTransport;
/// Boxed, type-erased error returned by every callback in this module.
///
/// The `'static` object bound is written out explicitly; it is the same bound the compiler
/// already infers by default for `Box<dyn Error>`, so the type is unchanged.
type CallbackError = Box<dyn std::error::Error + 'static>;
/// Shared, reference-counted handle to a user-supplied [`SwarmCallback`].
///
/// Under the `wasm` feature the callback is not required to be thread-safe, matching the
/// `?Send` futures that `async_trait` is configured to generate for that target.
#[cfg(feature = "wasm")]
pub type SharedSwarmCallback = Arc<dyn SwarmCallback>;

/// Shared, reference-counted handle to a user-supplied [`SwarmCallback`].
///
/// On non-`wasm` builds the callback must be both `Sync` and `Send` so the handle can be
/// shared across threads.
#[cfg(not(feature = "wasm"))]
pub type SharedSwarmCallback = Arc<dyn SwarmCallback + Sync + Send>;
/// Events the swarm reports to the user through [`SwarmCallback::on_event`].
#[non_exhaustive]
#[derive(Debug)]
pub enum SwarmEvent {
    /// The WebRTC connection with a peer moved to a new state.
    ConnectionStateChange {
        /// The peer whose connection changed state.
        peer: Did,
        /// The state the connection moved to.
        state: WebrtcConnectionState,
    },
}
/// User hooks into the swarm's message and connection lifecycle.
///
/// Every hook has a no-op default that returns `Ok(())`, so implementors only override the
/// ones they need.
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
pub trait SwarmCallback {
    /// Invoked for each inbound message after its signature and transaction have been
    /// verified. Returning `Err` rejects the message before it is dispatched.
    async fn on_validate(&self, _payload: &MessagePayload) -> Result<(), CallbackError> {
        Ok(())
    }

    /// Invoked for a message addressed to this node, after it has been dispatched to the
    /// internal message handler.
    async fn on_inbound(&self, _payload: &MessagePayload) -> Result<(), CallbackError> {
        Ok(())
    }

    /// Invoked whenever the swarm emits a [`SwarmEvent`], such as a connection state change.
    async fn on_event(&self, _event: &SwarmEvent) -> Result<(), CallbackError> {
        Ok(())
    }
}
/// Transport-level callback that verifies inbound messages, dispatches them to the
/// [`MessageHandler`], reassembles chunked messages, and forwards hooks to the user's
/// [`SwarmCallback`].
pub struct InnerSwarmCallback {
    /// Transport this callback is attached to; its DHT identity decides which inbound
    /// messages are addressed to this node.
    transport: Arc<SwarmTransport>,
    /// Dispatches decoded protocol messages and joins/leaves peers in the DHT.
    message_handler: MessageHandler,
    /// User-supplied hooks (validation, inbound delivery, events).
    callback: SharedSwarmCallback,
    /// Reassembly buffer for `Message::Chunk` fragments. It uses an async mutex because it
    /// is locked inside async callbacks.
    chunk_list: FuturesMutex<ChunkList<TRANSPORT_MTU>>,
}
impl InnerSwarmCallback {
    /// Builds the callback for `transport`, wiring a [`MessageHandler`] that shares the same
    /// transport and user callback.
    pub fn new(transport: Arc<SwarmTransport>, callback: SharedSwarmCallback) -> Self {
        let message_handler = MessageHandler::new(transport.clone(), callback.clone());
        Self {
            transport,
            message_handler,
            callback,
            chunk_list: Default::default(),
        }
    }

    /// Dispatches a verified payload to the matching message handler, then notifies the user
    /// callback if the payload is addressed to this node.
    ///
    /// Handler errors are logged and swallowed (best-effort), so one bad message does not
    /// fail the whole callback.
    ///
    /// A `Message::Chunk` is fed to the reassembly buffer. When it completes a message, the
    /// reassembled bytes are processed as a brand-new message via `on_message`. This call
    /// then returns that result without invoking `on_inbound` for the fragment itself.
    ///
    /// # Errors
    /// Returns an error in three cases:
    /// - the transaction data cannot be decoded;
    /// - `on_inbound` fails;
    /// - processing a reassembled chunked message fails.
    async fn handle_payload(
        &self,
        cid: &str,
        payload: &MessagePayload,
    ) -> Result<(), CallbackError> {
        let message: Message = payload.transaction.data()?;

        // Matching on `&message` already binds each `msg` by reference, so explicit `ref`
        // is unnecessary (and is rejected by the 2024 edition's match-ergonomics rules).
        match &message {
            Message::ConnectNodeSend(msg) => self.message_handler.handle(payload, msg).await,
            Message::ConnectNodeReport(msg) => self.message_handler.handle(payload, msg).await,
            Message::FindSuccessorSend(msg) => self.message_handler.handle(payload, msg).await,
            Message::FindSuccessorReport(msg) => {
                self.message_handler.handle(payload, msg).await
            }
            Message::NotifyPredecessorSend(msg) => {
                self.message_handler.handle(payload, msg).await
            }
            Message::NotifyPredecessorReport(msg) => {
                self.message_handler.handle(payload, msg).await
            }
            Message::SearchVNode(msg) => self.message_handler.handle(payload, msg).await,
            Message::FoundVNode(msg) => self.message_handler.handle(payload, msg).await,
            Message::SyncVNodeWithSuccessor(msg) => {
                self.message_handler.handle(payload, msg).await
            }
            Message::OperateVNode(msg) => self.message_handler.handle(payload, msg).await,
            Message::CustomMessage(msg) => self.message_handler.handle(payload, msg).await,
            Message::QueryForTopoInfoSend(msg) => {
                self.message_handler.handle(payload, msg).await
            }
            Message::QueryForTopoInfoReport(msg) => {
                self.message_handler.handle(payload, msg).await
            }
            Message::Chunk(msg) => {
                // Take the reassembly result in its own statement so the mutex guard is
                // dropped before recursing. As an `if let` scrutinee the guard would live
                // through the whole body, holding the lock across the nested `on_message`
                // call. That would serialize all chunk traffic, and it would deadlock if the
                // reassembled message were itself a chunk, because the async mutex is not
                // reentrant.
                let reassembled = self.chunk_list.lock().await.handle(msg.clone());
                if let Some(data) = reassembled {
                    return self.on_message(cid, &data).await;
                }
                Ok(())
            }
        }
        .unwrap_or_else(|e| {
            tracing::error!("Failed to handle_payload: {:?}", e);
        });

        // Only messages whose final destination is this node are surfaced to the user.
        if payload.transaction.destination == self.transport.dht.did {
            self.callback.on_inbound(payload).await?;
        }
        Ok(())
    }
}
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl TransportCallback for InnerSwarmCallback {
    /// Decodes, verifies, validates, and handles a raw inbound message from connection `cid`.
    ///
    /// # Errors
    /// Returns an error in four cases:
    /// - the bytes are not a valid bincode `MessagePayload`;
    /// - the payload or its transaction fails verification (including expiry);
    /// - the user's `on_validate` rejects the payload;
    /// - handling the payload fails.
    async fn on_message(&self, cid: &str, msg: &[u8]) -> Result<(), CallbackError> {
        let payload = MessagePayload::from_bincode(msg)?;
        if !(payload.verify() && payload.transaction.verify()) {
            tracing::error!("Cannot verify msg or it's expired: {:?}", payload);
            return Err("Cannot verify msg or it's expired".into());
        }
        self.callback.on_validate(&payload).await?;
        self.handle_payload(cid, &payload).await
    }

    /// Reacts to a WebRTC connection state change for the peer identified by `cid`.
    ///
    /// - Failed, disconnected, or closed peers are removed from the DHT.
    /// - Every state except `Connected` is forwarded to the user as a [`SwarmEvent`];
    ///   `Connected` is reported from `on_data_channel_open` instead.
    /// - A `cid` that does not parse as a DID is logged and ignored.
    ///
    /// # Errors
    /// Returns an error if leaving the DHT or the user's `on_event` fails.
    async fn on_peer_connection_state_change(
        &self,
        cid: &str,
        s: WebrtcConnectionState,
    ) -> Result<(), CallbackError> {
        let Ok(did) = Did::from_str(cid) else {
            tracing::warn!("on_peer_connection_state_change parse did failed: {}", cid);
            return Ok(());
        };

        if matches!(
            s,
            WebrtcConnectionState::Failed
                | WebrtcConnectionState::Disconnected
                | WebrtcConnectionState::Closed
        ) {
            self.message_handler.leave_dht(did).await?;
        }

        if s != WebrtcConnectionState::Connected {
            self.callback
                .on_event(&SwarmEvent::ConnectionStateChange {
                    peer: did,
                    state: s,
                })
                .await?;
        }
        Ok(())
    }

    /// Adds the peer identified by `cid` to the DHT once its data channel is open, then
    /// reports that peer to the user as `Connected`. A `cid` that does not parse as a DID
    /// is logged and ignored.
    ///
    /// # Errors
    /// Returns an error if joining the DHT or the user's `on_event` fails.
    async fn on_data_channel_open(&self, cid: &str) -> Result<(), CallbackError> {
        let Ok(did) = Did::from_str(cid) else {
            tracing::warn!("on_data_channel_open parse did failed: {}", cid);
            return Ok(());
        };
        self.message_handler.join_dht(did).await?;
        // Report the remote peer that just connected, not this node's own DID. This keeps
        // `ConnectionStateChange.peer` consistent with `on_peer_connection_state_change`.
        self.callback
            .on_event(&SwarmEvent::ConnectionStateChange {
                peer: did,
                state: WebrtcConnectionState::Connected,
            })
            .await
    }
}