1use std::error;
2use std::fmt::{Debug, Display};
3use std::io;
4use std::io::{Read, Write};
5use std::net::{Shutdown, SocketAddr};
6
7use cyphernet::encrypt::noise::NoiseState;
8use cyphernet::proxy::socks5;
9
10use mio::event::Source;
11use mio::net::TcpStream;
12use mio::{Interest, Registry, Token};
13
14pub type NoiseSession<E, D, S> = Protocol<NoiseState<E, D>, S>;
15pub type Socks5Session<S> = Protocol<socks5::Socks5, S>;
16
17pub trait Session: Send + Read + Write {
18 type Inner: Session;
19 type Artifact: Display;
20
21 fn is_established(&self) -> bool {
22 self.artifact().is_some()
23 }
24
25 fn run_handshake(&mut self) -> io::Result<()> {
26 Ok(())
27 }
28
29 fn display(&self) -> String {
30 self.artifact()
31 .map(|artifact| artifact.to_string())
32 .unwrap_or_else(|| "<no-id>".to_string())
33 }
34
35 fn artifact(&self) -> Option<Self::Artifact>;
36
37 fn stream(&mut self) -> &mut TcpStream;
38
39 fn disconnect(self) -> io::Result<()>;
40}
41
42pub trait StateMachine: Sized + Send {
43 const NAME: &'static str;
44
45 type Artifact;
46
47 type Error: error::Error + Send + Sync + 'static;
48
49 fn next_read_len(&self) -> usize;
50
51 fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, Self::Error>;
52
53 fn artifact(&self) -> Option<Self::Artifact>;
54
55 fn run_handshake<RW>(&mut self, stream: &mut RW) -> io::Result<()>
57 where
58 RW: Read + Write,
59 {
60 let mut input = vec![];
61 while !self.is_complete() {
62 let act = self.advance(&input).map_err(|err| {
63 log::error!(target: Self::NAME, "Handshake failure: {err}");
64 io::Error::other(err)
65 })?;
66 if !act.is_empty() {
67 log::trace!(target: Self::NAME, "Sending handshake act {act:02x?}");
68
69 stream.write_all(&act)?;
70 }
71 if !self.is_complete() {
72 input = vec![0u8; self.next_read_len()];
73 stream.read_exact(&mut input)?;
74
75 log::trace!(target: Self::NAME, "Receiving handshake act {input:02x?}");
76 }
77 }
78
79 log::debug!(target: Self::NAME, "Handshake protocol {} successfully completed", Self::NAME);
80 Ok(())
81 }
82
83 fn is_complete(&self) -> bool {
84 self.artifact().is_some()
85 }
86}
87
88#[derive(Clone, Eq, PartialEq, Hash, Debug)]
89pub struct ProtocolArtifact<M: StateMachine, S: Session> {
90 pub(crate) session: S::Artifact,
91 pub(crate) state: M::Artifact,
92}
93
94impl<M: StateMachine, S: Session> Display for ProtocolArtifact<M, S> {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("ProtocolArtifact")
97 .field("session", &"<omitted>")
98 .field("state", &"<omitted>")
99 .finish()
100 }
101}
102
103#[derive(Copy, Clone, Eq, PartialEq, Debug)]
104pub struct Protocol<M: StateMachine, S: Session> {
105 pub(crate) state: M,
106 pub(crate) session: S,
107}
108
109impl<M: StateMachine, S: Session> Protocol<M, S> {
110 pub fn new(session: S, state_machine: M) -> Self {
111 Self {
112 state: state_machine,
113 session,
114 }
115 }
116}
117
118impl<M: StateMachine, S: Session> io::Read for Protocol<M, S> {
119 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
120 log::trace!(target: M::NAME, "Reading event");
121
122 if self.state.is_complete() || !self.session.is_established() {
123 log::trace!(target: M::NAME, "Passing reading to inner not yet established session");
124 return self.session.read(buf);
125 }
126
127 let len = self.state.next_read_len();
128 let mut input = vec![0u8; len];
129 self.session.read_exact(&mut input)?;
130
131 log::trace!(target: M::NAME, "Received handshake act: {input:02x?}");
132
133 if !input.is_empty() {
134 let output = self.state.advance(&input).map_err(|err| {
135 log::error!(target: M::NAME, "Handshake failure: {err}");
136 io::Error::other(err)
137 })?;
138
139 if !output.is_empty() {
140 log::trace!(target: M::NAME, "Sending handshake act on read: {output:02x?}");
141 self.session.write_all(&output)?;
142 }
143 }
144
145 Ok(0)
146 }
147}
148
149impl<M: StateMachine, S: Session> Write for Protocol<M, S> {
150 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
151 log::trace!(target: M::NAME, "Writing event (state_complete={}, session_established={})", self.state.is_complete(), self.session.is_established());
152
153 if self.state.is_complete() || !self.session.is_established() {
154 log::trace!(target: M::NAME, "Passing writing to inner session");
155 return self.session.write(buf);
156 }
157
158 if self.state.next_read_len() == 0 {
159 log::trace!(target: M::NAME, "Starting handshake protocol");
160
161 let act = self.state.advance(&[]).map_err(|err| {
162 log::error!(target: M::NAME, "Handshake failure: {err}");
163 io::Error::other(err)
164 })?;
165
166 if !act.is_empty() {
167 log::trace!(target: M::NAME, "Sending handshake act on write: {act:02x?}");
168 self.session.write_all(&act)?;
169 } else {
170 log::trace!(target: M::NAME, "Handshake complete, passing data to inner session");
171 return self.session.write(buf);
172 }
173 }
174
175 if buf.is_empty() {
176 Ok(0)
177 } else {
178 Err(io::ErrorKind::Interrupted.into())
179 }
180 }
181
182 fn flush(&mut self) -> io::Result<()> {
183 self.session.flush()
184 }
185}
186
187impl<M: StateMachine, S: Session> Session for Protocol<M, S> {
188 type Inner = S;
189 type Artifact = ProtocolArtifact<M, S>;
190
191 fn run_handshake(&mut self) -> io::Result<()> {
192 log::debug!(target: M::NAME, "Starting handshake protocol {}", M::NAME);
193
194 if !self.session.is_established() {
195 self.session.run_handshake()?;
196 }
197
198 self.state.run_handshake(self.session.stream())
199 }
200
201 fn artifact(&self) -> Option<Self::Artifact> {
202 Some(ProtocolArtifact {
203 session: self.session.artifact()?,
204 state: self.state.artifact()?,
205 })
206 }
207
208 fn stream(&mut self) -> &mut TcpStream {
209 self.session.stream()
210 }
211
212 fn disconnect(self) -> io::Result<()> {
213 self.session.disconnect()
214 }
215}
216
217impl<M: StateMachine, S: Session + Source> Source for Protocol<M, S> {
218 fn register(
219 &mut self,
220 registry: &Registry,
221 token: Token,
222 interests: Interest,
223 ) -> io::Result<()> {
224 self.session.register(registry, token, interests)
225 }
226
227 fn reregister(
228 &mut self,
229 registry: &Registry,
230 token: Token,
231 interests: Interest,
232 ) -> io::Result<()> {
233 self.session.reregister(registry, token, interests)
234 }
235
236 fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
237 self.session.deregister(registry)
238 }
239}
240
241impl Session for TcpStream {
242 type Inner = Self;
243 type Artifact = SocketAddr;
244
245 fn artifact(&self) -> Option<Self::Artifact> {
246 self.peer_addr().ok()
247 }
248
249 fn stream(&mut self) -> &mut TcpStream {
250 self
251 }
252
253 fn disconnect(self) -> io::Result<()> {
254 self.shutdown(Shutdown::Both)
255 }
256}
257
258mod impl_noise {
259 use cyphernet::encrypt::noise::{error::NoiseError as Error, NoiseState as Noise};
260 use cyphernet::{Digest, Ecdh};
261
262 use super::*;
263
264 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
265 pub struct NoiseArtifact<E: Ecdh, D: Digest> {
266 pub handshake_hash: D::Output,
267 pub remote_static_key: Option<E::Pk>,
268 }
269
270 impl<E: Ecdh, D: Digest> StateMachine for Noise<E, D> {
271 const NAME: &'static str = "noise";
272 type Artifact = NoiseArtifact<E, D>;
273 type Error = Error;
274
275 fn next_read_len(&self) -> usize {
276 self.next_read_len()
277 }
278
279 fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
280 self.advance(input)
281 }
282
283 fn artifact(&self) -> Option<Self::Artifact> {
284 self.get_handshake_hash().map(|hh| NoiseArtifact {
285 handshake_hash: hh,
286 remote_static_key: self.get_remote_static_key(),
287 })
288 }
289 }
290}
291
292mod impl_socks5 {
293 use cyphernet::addr::{Host as _, HostName, NetAddr};
294 use cyphernet::proxy::socks5::{Error, Socks5};
295
296 use super::*;
297
298 impl StateMachine for Socks5 {
299 const NAME: &'static str = "socks5";
300
301 type Artifact = NetAddr<HostName>;
302 type Error = Error;
303
304 fn next_read_len(&self) -> usize {
305 self.next_read_len()
306 }
307
308 fn advance(&mut self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
309 self.advance(input)
310 }
311
312 fn artifact(&self) -> Option<Self::Artifact> {
313 match self {
314 Socks5::Initial(addr, false) if !addr.requires_proxy() => Some(addr.clone()),
315 Socks5::Active(addr) => Some(addr.clone()),
316 _ => None,
317 }
318 }
319 }
320}