1use std::io;
5use std::time;
6
7use thiserror::Error;
8use tracing::*;
9
10use ecksport_core::frame::{FrameBody, FrameType, GoodbyeData, ServerDumpData};
11use ecksport_core::peer::PeerData;
12use ecksport_core::state_mach::{self, Action, ClientMeta, ReadyData, ServerMeta, ServerState};
13use ecksport_core::traits::AuthConfig;
14use ecksport_core::traits::{AsyncRecvFrame, AsyncSendFrame};
15use ecksport_core::{errors::*, topic};
16
/// Tunables shared by the handshake and query helpers in this module.
#[derive(Clone, Debug)]
pub struct HandshakeOptions {
    /// Maximum time to wait for each individual frame received from the peer.
    timeout: time::Duration,
}

impl HandshakeOptions {
    /// Creates options that wait at most `timeout` for each received frame.
    pub fn new(timeout: time::Duration) -> Self {
        Self { timeout }
    }

    /// Returns the per-frame receive timeout these options were built with.
    pub fn timeout(&self) -> time::Duration {
        self.timeout
    }
}
27
/// Errors that can end a handshake or a server query.
#[derive(Debug, Error)]
pub enum HandshakeError {
    /// No frame arrived from the peer within the configured
    /// `HandshakeOptions` timeout (which is applied per received frame).
    #[error("timed out")]
    Timeout,

    /// Our state machine aborted the handshake; the contained goodbye data
    /// was sent to the peer as a `Goodbye` frame before this was returned.
    #[error("we closed connection: {0:?}")]
    Abort(GoodbyeData),

    /// The state machine reported an exit carrying this goodbye data.
    /// NOTE(review): presumably the goodbye came from the peer — confirm
    /// against `state_mach`.
    #[error("connection closed on us: {0:?}")]
    Closed(GoodbyeData),

    /// The client handshake ended in a clean close (`Action::HappyClose`)
    /// instead of a ready connection, so there is no reason to report.
    #[error("we're happy and we don't know why")]
    ClosedUnclear,

    /// The peer replied with a frame of an unexpected type, e.g.
    /// `do_query_async` receiving anything other than a `ServerDump`.
    #[error("unexpected frame type {0:?}")]
    UnexpectedFrameType(FrameType),

    /// Error from the connection/framing layer.
    #[error("conn: {0}")]
    Conn(#[from] ConnError),

    /// Underlying I/O error.
    #[error("io: {0}")]
    Io(#[from] io::Error),
}
51
52#[derive(Clone, Debug)]
54pub struct HandshakeStatus {
55 ready_data: ReadyData,
56 peer_data: PeerData,
57}
58
59impl HandshakeStatus {
60 fn new(ready_data: ReadyData, peer_data: PeerData) -> Self {
61 Self {
62 ready_data,
63 peer_data,
64 }
65 }
66
67 pub fn ready(&self) -> &ReadyData {
68 &self.ready_data
69 }
70
71 pub fn peer(&self) -> &PeerData {
72 &self.peer_data
73 }
74
75 pub fn into_peer(self) -> PeerData {
76 self.peer_data
77 }
78}
79
80pub async fn do_client_handshake_async<
81 T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin,
82 A: AuthConfig,
83>(
84 conn: &mut T,
85 protocol: topic::Topic,
86 meta: &ClientMeta,
87 opts: &HandshakeOptions,
88 auth: A,
89 mut peer: PeerData,
90) -> Result<HandshakeStatus, HandshakeError> {
91 let (mut cstate, mut actions) = state_mach::exec_client_init(protocol, auth, meta);
92
93 loop {
94 match perform_actions(&actions, conn, &mut peer).await? {
96 ActionStatus::FinishReady(rd) => {
97 let hs = HandshakeStatus::new(rd, peer);
98 return Ok(hs);
99 }
100 ActionStatus::FinishExit => return Err(HandshakeError::ClosedUnclear),
101 _ => {}
102 }
103
104 let m = match tokio::time::timeout(opts.timeout, conn.recv_frame_async()).await {
106 Ok(v) => v?,
107 Err(_) => return Err(HandshakeError::Timeout),
108 };
109
110 (cstate, actions) = state_mach::exec_client_inp(&cstate, &m, meta);
112 }
113}
114
115pub async fn do_server_handshake_async<
116 T: AsyncSendFrame + AsyncRecvFrame + Sync + Send + Unpin,
117 A: AuthConfig,
118>(
119 conn: &mut T,
120 meta: &ServerMeta,
121 opts: &HandshakeOptions,
122 auth: A,
123 mut peer: PeerData,
124) -> Result<Option<HandshakeStatus>, HandshakeError> {
125 let mut sstate = ServerState::init(auth);
126
127 loop {
128 let m = match tokio::time::timeout(opts.timeout, conn.recv_frame_async()).await {
130 Ok(v) => v?,
131 Err(_) => return Err(HandshakeError::Timeout),
132 };
133
134 let (new_state, actions) = state_mach::exec_server_inp(&sstate, &m, meta);
135 sstate = new_state;
136
137 match perform_actions(&actions, conn, &mut peer).await? {
139 ActionStatus::FinishReady(rd) => {
140 let hs = HandshakeStatus::new(rd, peer);
141 return Ok(Some(hs));
142 }
143 ActionStatus::FinishExit => return Ok(None),
144 _ => {}
145 }
146 }
147}
148
149pub enum ActionStatus {
152 Continue,
153 FinishReady(ReadyData),
154 FinishExit,
155}
156
157async fn perform_actions<T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin>(
158 actions: impl IntoIterator<Item = &Action>,
159 iow: &mut T,
160 peer: &mut PeerData,
161) -> Result<ActionStatus, HandshakeError> {
162 let mut status = ActionStatus::Continue;
163
164 for a in actions {
165 let res = perform_action(a, iow, peer).await?;
166 if !matches!(res, ActionStatus::Continue) {
167 status = res;
168 }
169 }
170
171 Ok(status)
172}
173
174async fn perform_action<T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin>(
175 action: &Action,
176 iow: &mut T,
177 peer: &mut PeerData,
178) -> Result<ActionStatus, HandshakeError> {
179 match action {
180 Action::SendFrame(frame) => {
181 iow.send_frame_async(&frame).await?;
182 Ok(ActionStatus::Continue)
183 }
184
185 Action::Ready(rd) => {
186 trace!("readying connection");
187 return Ok(ActionStatus::FinishReady(rd.clone()));
188 }
189
190 Action::Abort(gd) => {
191 let fb = FrameBody::Goodbye(gd.clone());
192 iow.send_frame_async(&fb).await?;
193 Err(HandshakeError::Abort(gd.clone()))
194 }
195
196 Action::Exit(gd) => Err(HandshakeError::Closed(gd.clone())),
197
198 Action::SetRemoteAgent(agent) => {
199 peer.set_agent(agent.clone());
200 Ok(ActionStatus::Continue)
201 }
202
203 Action::SetRemoteIdent(ident) => {
204 peer.set_identity(ident.clone());
205 Ok(ActionStatus::Continue)
206 }
207
208 Action::HappyClose => Ok(ActionStatus::FinishExit),
209 }
210}
211
212pub async fn do_query_async<T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin>(
213 conn: &mut T,
214 meta: &ClientMeta,
215 opts: &HandshakeOptions,
216) -> Result<ServerDumpData, HandshakeError> {
217 let frame = meta.create_query_frame();
219 conn.send_frame_async(&frame).await?;
220
221 let m = match tokio::time::timeout(opts.timeout, conn.recv_frame_async()).await {
223 Ok(v) => v?,
224 Err(_) => return Err(HandshakeError::Timeout),
225 };
226
227 match m {
229 FrameBody::ServerDump(sdump) => Ok(sdump),
230 _ => Err(HandshakeError::UnexpectedFrameType(m.ty())),
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    use std::net;

    use ecksport_core::{stream_framing, topic};
    use tokio::task::JoinHandle;

    const DUMMY_TOPIC: topic::Topic = topic::Topic::from_const_str("dummytop");

    /// Per-frame timeout used by both ends in these tests.
    const TEST_TIMEOUT: time::Duration = time::Duration::from_millis(1000);

    /// Binds a TCP listener on an OS-assigned loopback port.
    ///
    /// Binding port 0 instead of a hard-coded port keeps the tests from
    /// failing when another process, or a concurrent run of this suite,
    /// already holds the port.
    async fn bind_loopback() -> (tokio::net::TcpListener, net::SocketAddr) {
        let addr = "127.0.0.1:0"
            .parse::<net::SocketAddr>()
            .expect("test: parse socketaddr");

        let sock = tokio::net::TcpListener::bind(addr)
            .await
            .expect("test: tcp listen");

        // Recover the concrete port the OS picked so the client can dial it.
        let addr = sock.local_addr().expect("test: listener local addr");
        (sock, addr)
    }

    #[tokio::test]
    async fn test_handshake() {
        let cmeta = ClientMeta::new("/test/client/".to_owned());
        let smeta = ServerMeta::new("/test/server/".to_owned(), vec![DUMMY_TOPIC]);
        let opts = HandshakeOptions::new(TEST_TIMEOUT);
        let opts2 = opts.clone();

        let (sock, addr) = bind_loopback().await;

        let shj: JoinHandle<Result<(), HandshakeError>> = tokio::spawn(async move {
            let (conn, _) = sock.accept().await?;
            let mut framer = stream_framing::StreamFramer::new(conn);
            let pd = PeerData::default();
            do_server_handshake_async(&mut framer, &smeta, &opts, (), pd).await?;
            Ok(())
        });

        let chj: JoinHandle<Result<(), HandshakeError>> = tokio::spawn(async move {
            let conn = tokio::net::TcpStream::connect(addr).await?;
            let mut framer = stream_framing::StreamFramer::new(conn);
            let pd = PeerData::default();
            do_client_handshake_async(&mut framer, DUMMY_TOPIC, &cmeta, &opts2, (), pd).await?;
            Ok(())
        });

        shj.await
            .expect("test: server wait")
            .expect("test: server connect");

        chj.await
            .expect("test: client wait")
            .expect("test: client connect");
    }

    #[tokio::test]
    async fn test_query() {
        let cmeta = ClientMeta::new("/test/client/".to_owned());
        let smeta = ServerMeta::new("/test/server/".to_owned(), vec![DUMMY_TOPIC]);
        let opts = HandshakeOptions::new(TEST_TIMEOUT);
        let opts2 = opts.clone();

        let (sock, addr) = bind_loopback().await;

        let shj: JoinHandle<Result<(), HandshakeError>> = tokio::spawn(async move {
            let (conn, _) = sock.accept().await?;
            let mut framer = stream_framing::StreamFramer::new(conn);
            let pd = PeerData::default();
            do_server_handshake_async(&mut framer, &smeta, &opts, (), pd).await?;
            Ok(())
        });

        let chj: JoinHandle<Result<ServerDumpData, HandshakeError>> = tokio::spawn(async move {
            let conn = tokio::net::TcpStream::connect(addr).await?;
            let mut framer = stream_framing::StreamFramer::new(conn);

            let sdump = do_query_async(&mut framer, &cmeta, &opts2).await?;
            Ok(sdump)
        });

        // The server's own result is not checked here; only the client's
        // view of the dump is asserted below.
        let _ = shj.await.expect("test: server wait");

        let sdump = chj
            .await
            .expect("test: client wait")
            .expect("test: client connect");

        assert_eq!(sdump.agent(), "/test/server/");
        assert_eq!(sdump.protocols(), &[DUMMY_TOPIC]);
    }
}