// Source file: radicle_node/reactor/session.rs

1use std::error;
2use std::fmt::{Debug, Display};
3use std::io;
4use std::io::{Read, Write};
5use std::net::{Shutdown, SocketAddr};
6
7use cyphernet::encrypt::noise::NoiseState;
8use cyphernet::proxy::socks5;
9
10use mio::event::Source;
11use mio::net::TcpStream;
12use mio::{Interest, Registry, Token};
13
14pub type NoiseSession<E, D, S> = Protocol<NoiseState<E, D>, S>;
15pub type Socks5Session<S> = Protocol<socks5::Socks5, S>;
16
17pub trait Session: Send + Read + Write {
18    type Inner: Session;
19    type Artifact: Display;
20
21    fn is_established(&self) -> bool {
22        self.artifact().is_some()
23    }
24
25    fn run_handshake(&mut self) -> io::Result<()> {
26        Ok(())
27    }
28
29    fn display(&self) -> String {
30        self.artifact()
31            .map(|artifact| artifact.to_string())
32            .unwrap_or_else(|| "<no-id>".to_string())
33    }
34
35    fn artifact(&self) -> Option<Self::Artifact>;
36
37    fn stream(&mut self) -> &mut TcpStream;
38
39    fn disconnect(self) -> io::Result<()>;
40}
41
42pub trait StateMachine: Sized + Send {
43    const NAME: &'static str;
44
45    type Artifact;
46
47    type Error: error::Error + Send + Sync + 'static;
48
49    fn next_read_len(&self) -> usize;
50
51    fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, Self::Error>;
52
53    fn artifact(&self) -> Option<Self::Artifact>;
54
55    // Blocking
56    fn run_handshake<RW>(&mut self, stream: &mut RW) -> io::Result<()>
57    where
58        RW: Read + Write,
59    {
60        let mut input = vec![];
61        while !self.is_complete() {
62            let act = self.advance(&input).map_err(|err| {
63                log::error!(target: Self::NAME, "Handshake failure: {err}");
64                io::Error::other(err)
65            })?;
66            if !act.is_empty() {
67                log::trace!(target: Self::NAME, "Sending handshake act {act:02x?}");
68
69                stream.write_all(&act)?;
70            }
71            if !self.is_complete() {
72                input = vec![0u8; self.next_read_len()];
73                stream.read_exact(&mut input)?;
74
75                log::trace!(target: Self::NAME, "Receiving handshake act {input:02x?}");
76            }
77        }
78
79        log::debug!(target: Self::NAME, "Handshake protocol {} successfully completed", Self::NAME);
80        Ok(())
81    }
82
83    fn is_complete(&self) -> bool {
84        self.artifact().is_some()
85    }
86}
87
88#[derive(Clone, Eq, PartialEq, Hash, Debug)]
89pub struct ProtocolArtifact<M: StateMachine, S: Session> {
90    pub(crate) session: S::Artifact,
91    pub(crate) state: M::Artifact,
92}
93
94impl<M: StateMachine, S: Session> Display for ProtocolArtifact<M, S> {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("ProtocolArtifact")
97            .field("session", &"<omitted>")
98            .field("state", &"<omitted>")
99            .finish()
100    }
101}
102
103#[derive(Copy, Clone, Eq, PartialEq, Debug)]
104pub struct Protocol<M: StateMachine, S: Session> {
105    pub(crate) state: M,
106    pub(crate) session: S,
107}
108
109impl<M: StateMachine, S: Session> Protocol<M, S> {
110    pub fn new(session: S, state_machine: M) -> Self {
111        Self {
112            state: state_machine,
113            session,
114        }
115    }
116}
117
118impl<M: StateMachine, S: Session> io::Read for Protocol<M, S> {
119    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
120        log::trace!(target: M::NAME, "Reading event");
121
122        if self.state.is_complete() || !self.session.is_established() {
123            log::trace!(target: M::NAME, "Passing reading to inner not yet established session");
124            return self.session.read(buf);
125        }
126
127        let len = self.state.next_read_len();
128        let mut input = vec![0u8; len];
129        self.session.read_exact(&mut input)?;
130
131        log::trace!(target: M::NAME, "Received handshake act: {input:02x?}");
132
133        if !input.is_empty() {
134            let output = self.state.advance(&input).map_err(|err| {
135                log::error!(target: M::NAME, "Handshake failure: {err}");
136                io::Error::other(err)
137            })?;
138
139            if !output.is_empty() {
140                log::trace!(target: M::NAME, "Sending handshake act on read: {output:02x?}");
141                self.session.write_all(&output)?;
142            }
143        }
144
145        Ok(0)
146    }
147}
148
149impl<M: StateMachine, S: Session> Write for Protocol<M, S> {
150    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
151        log::trace!(target: M::NAME, "Writing event (state_complete={}, session_established={})", self.state.is_complete(), self.session.is_established());
152
153        if self.state.is_complete() || !self.session.is_established() {
154            log::trace!(target: M::NAME, "Passing writing to inner session");
155            return self.session.write(buf);
156        }
157
158        if self.state.next_read_len() == 0 {
159            log::trace!(target: M::NAME, "Starting handshake protocol");
160
161            let act = self.state.advance(&[]).map_err(|err| {
162                log::error!(target: M::NAME, "Handshake failure: {err}");
163                io::Error::other(err)
164            })?;
165
166            if !act.is_empty() {
167                log::trace!(target: M::NAME, "Sending handshake act on write: {act:02x?}");
168                self.session.write_all(&act)?;
169            } else {
170                log::trace!(target: M::NAME, "Handshake complete, passing data to inner session");
171                return self.session.write(buf);
172            }
173        }
174
175        if buf.is_empty() {
176            Ok(0)
177        } else {
178            Err(io::ErrorKind::Interrupted.into())
179        }
180    }
181
182    fn flush(&mut self) -> io::Result<()> {
183        self.session.flush()
184    }
185}
186
187impl<M: StateMachine, S: Session> Session for Protocol<M, S> {
188    type Inner = S;
189    type Artifact = ProtocolArtifact<M, S>;
190
191    fn run_handshake(&mut self) -> io::Result<()> {
192        log::debug!(target: M::NAME, "Starting handshake protocol {}", M::NAME);
193
194        if !self.session.is_established() {
195            self.session.run_handshake()?;
196        }
197
198        self.state.run_handshake(self.session.stream())
199    }
200
201    fn artifact(&self) -> Option<Self::Artifact> {
202        Some(ProtocolArtifact {
203            session: self.session.artifact()?,
204            state: self.state.artifact()?,
205        })
206    }
207
208    fn stream(&mut self) -> &mut TcpStream {
209        self.session.stream()
210    }
211
212    fn disconnect(self) -> io::Result<()> {
213        self.session.disconnect()
214    }
215}
216
217impl<M: StateMachine, S: Session + Source> Source for Protocol<M, S> {
218    fn register(
219        &mut self,
220        registry: &Registry,
221        token: Token,
222        interests: Interest,
223    ) -> io::Result<()> {
224        self.session.register(registry, token, interests)
225    }
226
227    fn reregister(
228        &mut self,
229        registry: &Registry,
230        token: Token,
231        interests: Interest,
232    ) -> io::Result<()> {
233        self.session.reregister(registry, token, interests)
234    }
235
236    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
237        self.session.deregister(registry)
238    }
239}
240
241impl Session for TcpStream {
242    type Inner = Self;
243    type Artifact = SocketAddr;
244
245    fn artifact(&self) -> Option<Self::Artifact> {
246        self.peer_addr().ok()
247    }
248
249    fn stream(&mut self) -> &mut TcpStream {
250        self
251    }
252
253    fn disconnect(self) -> io::Result<()> {
254        self.shutdown(Shutdown::Both)
255    }
256}
257
258mod impl_noise {
259    use cyphernet::encrypt::noise::{error::NoiseError as Error, NoiseState as Noise};
260    use cyphernet::{Digest, Ecdh};
261
262    use super::*;
263
264    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
265    pub struct NoiseArtifact<E: Ecdh, D: Digest> {
266        pub handshake_hash: D::Output,
267        pub remote_static_key: Option<E::Pk>,
268    }
269
270    impl<E: Ecdh, D: Digest> StateMachine for Noise<E, D> {
271        const NAME: &'static str = "noise";
272        type Artifact = NoiseArtifact<E, D>;
273        type Error = Error;
274
275        fn next_read_len(&self) -> usize {
276            self.next_read_len()
277        }
278
279        fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
280            self.advance(input)
281        }
282
283        fn artifact(&self) -> Option<Self::Artifact> {
284            self.get_handshake_hash().map(|hh| NoiseArtifact {
285                handshake_hash: hh,
286                remote_static_key: self.get_remote_static_key(),
287            })
288        }
289    }
290}
291
292mod impl_socks5 {
293    use cyphernet::addr::{Host as _, HostName, NetAddr};
294    use cyphernet::proxy::socks5::{Error, Socks5};
295
296    use super::*;
297
298    impl StateMachine for Socks5 {
299        const NAME: &'static str = "socks5";
300
301        type Artifact = NetAddr<HostName>;
302        type Error = Error;
303
304        fn next_read_len(&self) -> usize {
305            self.next_read_len()
306        }
307
308        fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
309            self.advance(input)
310        }
311
312        fn artifact(&self) -> Option<Self::Artifact> {
313            match self {
314                Socks5::Initial(addr, false) if !addr.requires_proxy() => Some(addr.clone()),
315                Socks5::Active(addr) => Some(addr.clone()),
316                _ => None,
317            }
318        }
319    }
320}