p2panda_sync/protocols/
topic_handshake.rs1use std::fmt::{Debug, Display};
5use std::marker::PhantomData;
6
7use futures::channel::mpsc;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
12use crate::traits::Protocol;
13
14pub struct TopicHandshakeInitiator<T, Evt> {
18 pub topic: T,
19 pub event_tx: mpsc::Sender<Evt>,
20}
21
22impl<T, Evt> TopicHandshakeInitiator<T, Evt>
23where
24 T: Clone + for<'de> Deserialize<'de> + Serialize,
25 Evt: From<TopicHandshakeEvent<T>>,
26{
27 pub fn new(topic: T, event_tx: mpsc::Sender<Evt>) -> Self {
28 Self { topic, event_tx }
29 }
30}
31
32impl<T, Evt> Protocol for TopicHandshakeInitiator<T, Evt>
33where
34 T: Clone + Debug + for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
35 Evt: From<TopicHandshakeEvent<T>>,
36{
37 type Error = TopicHandshakeError<T>;
38 type Output = ();
39 type Message = TopicHandshakeMessage<T>;
40
41 async fn run(
42 mut self,
43 sink: &mut (impl Sink<Self::Message, Error = impl Debug> + Unpin),
44 stream: &mut (impl Stream<Item = Result<Self::Message, impl Debug>> + Unpin),
45 ) -> Result<Self::Output, Self::Error> {
46 self.event_tx
48 .send(TopicHandshakeEvent::Initiate(self.topic.clone()).into())
49 .await?;
50
51 sink.send(TopicHandshakeMessage::Topic(self.topic.clone()))
53 .await
54 .map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
55
56 let Some(message) = stream.next().await else {
58 return Err(TopicHandshakeError::UnexpectedStreamClosure);
59 };
60 let message =
61 message.map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
62 let TopicHandshakeMessage::Done = message else {
63 return Err(TopicHandshakeError::UnexpectedMessage(message));
64 };
65
66 sink.send(TopicHandshakeMessage::Done)
68 .await
69 .map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
70
71 self.event_tx
73 .send(TopicHandshakeEvent::Done(self.topic).into())
74 .await?;
75
76 sink.flush()
77 .await
78 .map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
79 self.event_tx.flush().await?;
80
81 Ok(())
82 }
83}
84
85pub struct TopicHandshakeAcceptor<T, Evt> {
89 pub event_tx: mpsc::Sender<Evt>,
90 _phantom: PhantomData<T>,
91}
92
93impl<T, Evt> TopicHandshakeAcceptor<T, Evt>
94where
95 T: Clone + for<'de> Deserialize<'de> + Serialize,
96 Evt: From<TopicHandshakeEvent<T>>,
97{
98 pub fn new(event_tx: mpsc::Sender<Evt>) -> Self {
99 Self {
100 event_tx,
101 _phantom: PhantomData,
102 }
103 }
104}
105
106impl<T, Evt> Protocol for TopicHandshakeAcceptor<T, Evt>
107where
108 T: Clone + Debug + for<'de> Deserialize<'de> + Serialize + Send + Sync + 'static,
109 Evt: From<TopicHandshakeEvent<T>>,
110{
111 type Error = TopicHandshakeError<T>;
112 type Output = T;
113 type Message = TopicHandshakeMessage<T>;
114
115 async fn run(
116 mut self,
117 sink: &mut (impl Sink<Self::Message, Error = impl Debug> + Unpin),
118 stream: &mut (impl Stream<Item = Result<Self::Message, impl Debug>> + Unpin),
119 ) -> Result<Self::Output, Self::Error> {
120 self.event_tx
122 .send(TopicHandshakeEvent::Accept.into())
123 .await?;
124
125 let Some(message) = stream.next().await else {
127 return Err(TopicHandshakeError::UnexpectedStreamClosure);
128 };
129 let message =
130 message.map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
131 let TopicHandshakeMessage::Topic(topic) = message else {
132 return Err(TopicHandshakeError::UnexpectedMessage(message));
133 };
134
135 self.event_tx
137 .send(TopicHandshakeEvent::TopicReceived(topic.clone()).into())
138 .await?;
139
140 sink.send(TopicHandshakeMessage::Done)
142 .await
143 .map_err(|err| TopicHandshakeError::MessageStream(format!("{err:?}")))?;
144
145 let Some(message) = stream.next().await else {
147 return Err(TopicHandshakeError::UnexpectedStreamClosure);
148 };
149 let message =
150 message.map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
151 let TopicHandshakeMessage::Done = message else {
152 return Err(TopicHandshakeError::UnexpectedMessage(message));
153 };
154
155 self.event_tx
157 .send(TopicHandshakeEvent::Done(topic.clone()).into())
158 .await?;
159
160 sink.flush()
161 .await
162 .map_err(|err| TopicHandshakeError::MessageSink(format!("{err:?}")))?;
163 self.event_tx.flush().await?;
164
165 Ok(topic)
166 }
167}
168
169#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
171#[serde(tag = "type", content = "value")]
172pub enum TopicHandshakeMessage<T> {
173 Topic(T),
174 Done,
175}
176
177#[derive(Clone, Debug, Error)]
179pub enum TopicHandshakeError<T> {
180 #[error("unexpected protocol message: {0}")]
181 UnexpectedMessage(TopicHandshakeMessage<T>),
182
183 #[error("stream ended before protocol completion")]
184 UnexpectedStreamClosure,
185
186 #[error("error sending on message sink: {0}")]
187 MessageSink(String),
188
189 #[error("error receiving from message stream: {0}")]
190 MessageStream(String),
191
192 #[error(transparent)]
193 MpscSend(#[from] mpsc::SendError),
194}
195
196impl<T: std::fmt::Debug> Display for TopicHandshakeError<T> {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 write!(f, "{self:?}")
199 }
200}
201
/// Progress events emitted on a handshake participant's `event_tx` channel.
#[derive(Debug, Clone, PartialEq)]
pub enum TopicHandshakeEvent<T> {
    /// Emitted by the initiator before it sends its topic.
    Initiate(T),
    /// Emitted by the acceptor before it starts waiting for a topic.
    Accept,
    /// Emitted by the acceptor once the initiator's topic has been received.
    TopicReceived(T),
    /// Emitted by either side after the final `Done` exchange for the topic.
    Done(T),
}