use std::error::Error as StdError;
use std::marker::PhantomData;
use futures_channel::mpsc;
use iroh::endpoint::{Connection, VarInt};
use p2panda_core::Topic;
use p2panda_sync::protocols::{
TopicHandshakeEvent, TopicHandshakeInitiator, TopicHandshakeMessage,
};
use p2panda_sync::traits::Protocol;
use ractor::thread_local::ThreadLocalActor;
use ractor::{ActorProcessingErr, ActorRef};
use serde::{Deserialize, Serialize};
use tracing::Instrument;
use crate::cbor::{into_cbor_sink, into_cbor_stream};
use crate::iroh_endpoint::Endpoint;
use crate::utils::ShortFormat;
use crate::{NodeId, ProtocolId};
/// Numeric identifier of a sync session; recorded as the `session_id` field of the session's
/// tracing span.
pub type SyncSessionId = u64;

/// Messages handled by the [`SyncSession`] actor. Each message drives one sync protocol run.
pub enum SyncSessionMessage<P> {
    /// Dial `node_id` using `protocol_id`, run the topic handshake for `topic` on a first
    /// bidirectional stream, then run `protocol` as the initiating side on a second stream.
    Initiate {
        node_id: NodeId,
        topic: Topic,
        session_id: SyncSessionId,
        protocol: P,
        protocol_id: ProtocolId,
    },
    /// Accept the next bidirectional stream opened by the remote on `connection` and run
    /// `protocol` on it as the accepting side.
    ///
    /// NOTE(review): no topic handshake is performed for this variant; presumably the caller has
    /// already completed it before sending `Accept` — confirm.
    Accept {
        connection: Connection,
        topic: Topic,
        session_id: SyncSessionId,
        protocol: P,
    },
}
/// Actor that runs sync sessions (initiating or accepting) for the sync protocol type `P`.
///
/// The actor value itself carries no data: `P` is tracked only at the type level, and the
/// runtime state (the endpoint) lives in the actor state created by `pre_start`.
pub struct SyncSession<P> {
    _marker: PhantomData<P>,
}

impl<P> Default for SyncSession<P> {
    // Written by hand because `#[derive(Default)]` would add an unnecessary `P: Default` bound.
    fn default() -> Self {
        let _marker = PhantomData;
        Self { _marker }
    }
}
impl<P> ThreadLocalActor for SyncSession<P>
where
P: Protocol + Send + 'static,
P::Error: StdError + Send + Sync + 'static,
for<'a> P::Message: Serialize + Deserialize<'a>,
{
type State = (Endpoint,);
type Msg = SyncSessionMessage<P>;
type Arguments = (Endpoint,);
async fn pre_start(
&self,
_myself: ActorRef<Self::Msg>,
args: Self::Arguments,
) -> Result<Self::State, ActorProcessingErr> {
Ok(args)
}
async fn handle(
&self,
_myself: ActorRef<Self::Msg>,
message: Self::Msg,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
SyncSessionMessage::Initiate {
node_id,
topic,
session_id,
protocol,
protocol_id,
} => {
let connection = state.0.connect(node_id, protocol_id).await?;
let (tx, rx) = connection.open_bi().await?;
let mut tx = into_cbor_sink::<TopicHandshakeMessage<Topic>, _>(tx);
let mut rx = into_cbor_stream::<TopicHandshakeMessage<Topic>, _>(rx);
let (event_tx, _event_rx) = mpsc::channel::<TopicHandshakeEvent<Topic>>(128);
let topic_handshake = TopicHandshakeInitiator::new(topic, event_tx);
topic_handshake.run(&mut tx, &mut rx).await?;
let span = tracing::debug_span!("sync", responder = %node_id.fmt_short(), topic = %topic.fmt_short(), session_id);
let (tx, rx) = connection.open_bi().await?;
let mut tx = into_cbor_sink::<P::Message, _>(tx);
let mut rx = into_cbor_stream::<P::Message, _>(rx);
protocol
.run(&mut tx, &mut rx)
.instrument(span.clone())
.await?;
connection.close(VarInt::from_u32(0), b"sync protocol initiate completed");
}
SyncSessionMessage::Accept {
connection,
topic,
session_id,
protocol,
} => {
let (tx, rx) = connection.accept_bi().await?;
let mut tx = into_cbor_sink::<P::Message, _>(tx);
let mut rx = into_cbor_stream::<P::Message, _>(rx);
let remote = connection.remote_id();
let span = tracing::debug_span!(parent: None, "sync", requester = %remote.fmt_short(), topic = %topic.fmt_short(), session_id);
protocol.run(&mut tx, &mut rx).instrument(span).await?;
connection.close(VarInt::from_u32(0), b"sync protocol accept completed");
}
}
Ok(())
}
}