ecksport_net/
handshake.rs

1//! Subroutines for performing the connection setup, including managing any
2//! challenges.
3
4use std::io;
5use std::time;
6
7use thiserror::Error;
8use tracing::*;
9
10use ecksport_core::frame::{FrameBody, FrameType, GoodbyeData, ServerDumpData};
11use ecksport_core::peer::PeerData;
12use ecksport_core::state_mach::{self, Action, ClientMeta, ReadyData, ServerMeta, ServerState};
13use ecksport_core::traits::AuthConfig;
14use ecksport_core::traits::{AsyncRecvFrame, AsyncSendFrame};
15use ecksport_core::{errors::*, topic};
16
/// Tunables shared by the handshake and query routines in this module.
#[derive(Clone, Debug)]
pub struct HandshakeOptions {
    /// Longest time a single frame receive may take before the routine gives
    /// up with a timeout error.
    timeout: time::Duration,
}

impl HandshakeOptions {
    /// Creates options whose per-receive wait limit is `timeout`.
    pub fn new(timeout: time::Duration) -> Self {
        HandshakeOptions { timeout }
    }
}
27
28#[derive(Debug, Error)]
29pub enum HandshakeError {
30    #[error("timed out")]
31    Timeout,
32
33    #[error("we closed connection: {0:?}")]
34    Abort(GoodbyeData),
35
36    #[error("connection closed on us: {0:?}")]
37    Closed(GoodbyeData),
38
39    #[error("we're happy and we don't know why")]
40    ClosedUnclear,
41
42    #[error("unexpected frame type {0:?}")]
43    UnexpectedFrameType(FrameType),
44
45    #[error("conn: {0}")]
46    Conn(#[from] ConnError),
47
48    #[error("io: {0}")]
49    Io(#[from] io::Error),
50}
51
52/// Contains data about how a handshake completed, which resulted in a ready connection.
53#[derive(Clone, Debug)]
54pub struct HandshakeStatus {
55    ready_data: ReadyData,
56    peer_data: PeerData,
57}
58
59impl HandshakeStatus {
60    fn new(ready_data: ReadyData, peer_data: PeerData) -> Self {
61        Self {
62            ready_data,
63            peer_data,
64        }
65    }
66
67    pub fn ready(&self) -> &ReadyData {
68        &self.ready_data
69    }
70
71    pub fn peer(&self) -> &PeerData {
72        &self.peer_data
73    }
74
75    pub fn into_peer(self) -> PeerData {
76        self.peer_data
77    }
78}
79
80pub async fn do_client_handshake_async<
81    T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin,
82    A: AuthConfig,
83>(
84    conn: &mut T,
85    protocol: topic::Topic,
86    meta: &ClientMeta,
87    opts: &HandshakeOptions,
88    auth: A,
89    mut peer: PeerData,
90) -> Result<HandshakeStatus, HandshakeError> {
91    let (mut cstate, mut actions) = state_mach::exec_client_init(protocol, auth, meta);
92
93    loop {
94        // Perform the last action, then wait for an input.
95        match perform_actions(&actions, conn, &mut peer).await? {
96            ActionStatus::FinishReady(rd) => {
97                let hs = HandshakeStatus::new(rd, peer);
98                return Ok(hs);
99            }
100            ActionStatus::FinishExit => return Err(HandshakeError::ClosedUnclear),
101            _ => {}
102        }
103
104        // Recv message or timeout.
105        let m = match tokio::time::timeout(opts.timeout, conn.recv_frame_async()).await {
106            Ok(v) => v?,
107            Err(_) => return Err(HandshakeError::Timeout),
108        };
109
110        // Update the state and decide what to do with it.
111        (cstate, actions) = state_mach::exec_client_inp(&cstate, &m, meta);
112    }
113}
114
115pub async fn do_server_handshake_async<
116    T: AsyncSendFrame + AsyncRecvFrame + Sync + Send + Unpin,
117    A: AuthConfig,
118>(
119    conn: &mut T,
120    meta: &ServerMeta,
121    opts: &HandshakeOptions,
122    auth: A,
123    mut peer: PeerData,
124) -> Result<Option<HandshakeStatus>, HandshakeError> {
125    let mut sstate = ServerState::init(auth);
126
127    loop {
128        // Recv message or timeout.
129        let m = match tokio::time::timeout(opts.timeout, conn.recv_frame_async()).await {
130            Ok(v) => v?,
131            Err(_) => return Err(HandshakeError::Timeout),
132        };
133
134        let (new_state, actions) = state_mach::exec_server_inp(&sstate, &m, meta);
135        sstate = new_state;
136
137        // Perform the action we've decided to do.
138        match perform_actions(&actions, conn, &mut peer).await? {
139            ActionStatus::FinishReady(rd) => {
140                let hs = HandshakeStatus::new(rd, peer);
141                return Ok(Some(hs));
142            }
143            ActionStatus::FinishExit => return Ok(None),
144            _ => {}
145        }
146    }
147}
148
149/// Describes how we should proceed from an action being performed.  This is
150/// interpreted differently depending on if we're the client or the server.
151pub enum ActionStatus {
152    Continue,
153    FinishReady(ReadyData),
154    FinishExit,
155}
156
157async fn perform_actions<T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin>(
158    actions: impl IntoIterator<Item = &Action>,
159    iow: &mut T,
160    peer: &mut PeerData,
161) -> Result<ActionStatus, HandshakeError> {
162    let mut status = ActionStatus::Continue;
163
164    for a in actions {
165        let res = perform_action(a, iow, peer).await?;
166        if !matches!(res, ActionStatus::Continue) {
167            status = res;
168        }
169    }
170
171    Ok(status)
172}
173
174async fn perform_action<T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin>(
175    action: &Action,
176    iow: &mut T,
177    peer: &mut PeerData,
178) -> Result<ActionStatus, HandshakeError> {
179    match action {
180        Action::SendFrame(frame) => {
181            iow.send_frame_async(&frame).await?;
182            Ok(ActionStatus::Continue)
183        }
184
185        Action::Ready(rd) => {
186            trace!("readying connection");
187            return Ok(ActionStatus::FinishReady(rd.clone()));
188        }
189
190        Action::Abort(gd) => {
191            let fb = FrameBody::Goodbye(gd.clone());
192            iow.send_frame_async(&fb).await?;
193            Err(HandshakeError::Abort(gd.clone()))
194        }
195
196        Action::Exit(gd) => Err(HandshakeError::Closed(gd.clone())),
197
198        Action::SetRemoteAgent(agent) => {
199            peer.set_agent(agent.clone());
200            Ok(ActionStatus::Continue)
201        }
202
203        Action::SetRemoteIdent(ident) => {
204            peer.set_identity(ident.clone());
205            Ok(ActionStatus::Continue)
206        }
207
208        Action::HappyClose => Ok(ActionStatus::FinishExit),
209    }
210}
211
212pub async fn do_query_async<T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin>(
213    conn: &mut T,
214    meta: &ClientMeta,
215    opts: &HandshakeOptions,
216) -> Result<ServerDumpData, HandshakeError> {
217    // First send a client query frame.
218    let frame = meta.create_query_frame();
219    conn.send_frame_async(&frame).await?;
220
221    // Recv message or timeout.
222    let m = match tokio::time::timeout(opts.timeout, conn.recv_frame_async()).await {
223        Ok(v) => v?,
224        Err(_) => return Err(HandshakeError::Timeout),
225    };
226
227    // If it's what we expected then we can return it, otherwise error.
228    match m {
229        FrameBody::ServerDump(sdump) => Ok(sdump),
230        _ => Err(HandshakeError::UnexpectedFrameType(m.ty())),
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    use std::net;

    use ecksport_core::{stream_framing, topic};
    use tokio::task::JoinHandle;

    const DUMMY_TOPIC: topic::Topic = topic::Topic::from_const_str("dummytop");

    /// Binds a loopback listener on a port chosen by the OS and returns it
    /// together with the address it ended up on.
    ///
    /// Asking for port 0 avoids failures when a fixed port happens to be in
    /// use, e.g. by another test run on the same machine.
    async fn bind_loopback() -> (tokio::net::TcpListener, net::SocketAddr) {
        let sock = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("test: tcp listen");
        let addr = sock.local_addr().expect("test: local addr");
        (sock, addr)
    }

    /// Runs a full client/server handshake over a loopback TCP connection and
    /// expects both sides to finish without error.
    #[tokio::test]
    async fn test_handshake() {
        let cmeta = ClientMeta::new("/test/client/".to_owned());
        let smeta = ServerMeta::new("/test/server/".to_owned(), vec![DUMMY_TOPIC]);
        let opts = HandshakeOptions::new(time::Duration::from_millis(1000));
        let opts2 = opts.clone();

        let (sock, addr) = bind_loopback().await;

        let shj: JoinHandle<Result<(), HandshakeError>> = tokio::spawn(async move {
            let (conn, _) = sock.accept().await?;
            let mut framer = stream_framing::StreamFramer::new(conn);
            let pd = PeerData::default();
            do_server_handshake_async(&mut framer, &smeta, &opts, (), pd).await?;
            Ok(())
        });

        let chj: JoinHandle<Result<(), HandshakeError>> = tokio::spawn(async move {
            let conn = tokio::net::TcpStream::connect(addr).await?;
            let mut framer = stream_framing::StreamFramer::new(conn);
            let pd = PeerData::default();
            do_client_handshake_async(&mut framer, DUMMY_TOPIC, &cmeta, &opts2, (), pd).await?;
            Ok(())
        });

        shj.await
            .expect("test: server wait")
            .expect("test: server connect");

        chj.await
            .expect("test: client wait")
            .expect("test: client connect");
    }

    /// Sends a query to a server over loopback TCP and checks the returned
    /// dump reports the server's agent string and protocol list.
    #[tokio::test]
    async fn test_query() {
        let cmeta = ClientMeta::new("/test/client/".to_owned());
        let smeta = ServerMeta::new("/test/server/".to_owned(), vec![DUMMY_TOPIC]);
        let opts = HandshakeOptions::new(time::Duration::from_millis(1000));
        let opts2 = opts.clone();

        let (sock, addr) = bind_loopback().await;

        let shj: JoinHandle<Result<(), HandshakeError>> = tokio::spawn(async move {
            let (conn, _) = sock.accept().await?;
            let mut framer = stream_framing::StreamFramer::new(conn);
            let pd = PeerData::default();
            do_server_handshake_async(&mut framer, &smeta, &opts, (), pd).await?;
            Ok(())
        });

        let chj: JoinHandle<Result<ServerDumpData, HandshakeError>> = tokio::spawn(async move {
            let conn = tokio::net::TcpStream::connect(addr).await?;
            let mut framer = stream_framing::StreamFramer::new(conn);

            let sdump = do_query_async(&mut framer, &cmeta, &opts2).await?;
            Ok(sdump)
        });

        // Only require that the server task did not panic; its handshake
        // result is intentionally not checked here.
        let _ = shj.await.expect("test: server wait");

        let sdump = chj
            .await
            .expect("test: client wait")
            .expect("test: client connect");

        assert_eq!(sdump.agent(), "/test/server/");
        assert_eq!(sdump.protocols(), &[DUMMY_TOPIC]);
    }
}