1use std::fmt;
2use std::ops::{Deref, DerefMut};
3use std::sync::mpsc::{self, SyncSender};
4use std::sync::{Arc, Mutex};
5use std::thread;
6
7use log::*;
8
9use crate::auth_service::AuthService;
10use crate::byte_stream::ByteStream;
11use crate::connector::Connector;
12use crate::model::dao::*;
13use crate::model::model::*;
14use crate::model::Error;
15use crate::relay::{self, RelayHandle};
16use crate::rw_socks_stream::ReadWriteStream;
17use crate::server_command::ServerCommand;
18
/// Numeric identifier of a client session.
///
/// A thin newtype over `u32`. It derives `Hash` and `Ord`, so it can be used
/// as a hash-map key or sorted.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SessionId(pub u32);

impl From<u32> for SessionId {
    /// Wraps a raw `u32` as a `SessionId`.
    fn from(raw: u32) -> SessionId {
        SessionId(raw)
    }
}

impl fmt::Display for SessionId {
    /// Renders the id as `SessionId(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let SessionId(raw) = *self;
        write!(f, "SessionId({})", raw)
    }
}
33
34#[derive(Debug)]
35pub struct SessionHandle {
36 addr: SocketAddr,
38 handle: thread::JoinHandle<Result<RelayHandle, Error>>,
40 tx: SyncSender<()>,
42}
43
44impl SessionHandle {
45 pub fn new(
46 addr: SocketAddr,
47 handle: thread::JoinHandle<Result<RelayHandle, Error>>,
48 tx: SyncSender<()>,
49 ) -> Self {
50 Self { addr, handle, tx }
51 }
52
53 pub fn client_addr(&self) -> SocketAddr {
54 self.addr
55 }
56
57 pub fn stop(&self) {
58 trace!("stop session: {}", self.addr);
59 if self.tx.send(()).is_ok() {
62 self.tx.send(()).ok();
64 }
65 }
66
67 pub fn join(self) -> thread::Result<Result<(), Error>> {
68 trace!("join session: {}", self.addr);
69 match self.handle.join()? {
70 Ok(relay) => relay.join(),
71 Err(err) => Ok(Err(err)),
72 }
73 }
74}
75
76#[derive(Debug)]
77pub struct Session<D, A, S> {
78 pub id: SessionId,
79 pub version: ProtocolVersion,
80 pub dst_connector: D,
81 pub authorizer: A,
82 pub server_addr: SocketAddr,
83 pub conn_rule: ConnectRule,
84 rx: Arc<Mutex<mpsc::Receiver<()>>>,
86 guard: Arc<Mutex<DisconnectGuard<S>>>,
89}
90
91impl<D, A, S> Session<D, A, S>
92where
93 D: Connector,
94 A: AuthService,
95 S: Send + 'static,
96{
97 pub fn new(
99 id: SessionId,
100 version: ProtocolVersion,
101 dst_connector: D,
102 authorizer: A,
103 server_addr: SocketAddr,
104 conn_rule: ConnectRule,
105 tx_cmd: mpsc::Sender<ServerCommand<S>>,
106 ) -> (Self, mpsc::SyncSender<()>) {
107 let (tx, rx) = mpsc::sync_channel(2);
108 (
109 Self {
110 id,
111 version,
112 dst_connector,
113 authorizer,
114 server_addr,
115 conn_rule,
116 rx: Arc::new(Mutex::new(rx)),
117 guard: Arc::new(Mutex::new(DisconnectGuard::new(id, tx_cmd))),
118 },
119 tx,
120 )
121 }
122
123 fn connect_reply(&self, connect_result: Result<(), ConnectError>) -> ConnectReply {
124 ConnectReply {
125 version: self.version,
126 connect_result,
127 server_addr: self.server_addr.into(),
128 }
129 }
130
131 fn make_session<'a>(
132 &self,
133 src_addr: SocketAddr,
134 mut src_conn: impl ByteStream + 'a,
135 ) -> Result<RelayHandle, Error> {
136 let mut socks = ReadWriteStream::new(&mut src_conn);
137
138 let select = negotiate_auth_method(self.version, &self.authorizer, &mut socks)?;
139 debug!("auth method: {:?}", select);
140 let mut socks = ReadWriteStream::new(self.authorizer.authorize(select.method, src_conn)?);
141
142 let req = socks.recv_connect_request()?;
143 debug!("connect request: {:?}", req);
144
145 let (conn, dst_addr) = match perform_command(
146 req.command,
147 &self.dst_connector,
148 &self.conn_rule,
149 req.connect_to.clone(),
150 ) {
151 Ok((conn, dst_addr)) => {
152 info!("connected: {}: {}", req.connect_to, dst_addr);
153 socks.send_connect_reply(self.connect_reply(Ok(())))?;
154 (conn, dst_addr)
155 }
156 Err(err) => {
157 error!("command error: {}", err);
158 trace!("command error: {:?}", err);
159 socks.send_connect_reply(self.connect_reply(Err(err.cerr())))?;
161 return Err(err);
162 }
163 };
164
165 relay::spawn_relay(
166 src_addr,
167 dst_addr,
168 socks.into_inner(),
169 conn,
170 self.rx.clone(),
171 self.guard.clone(),
172 )
173 }
174
175 pub fn start<'a>(
176 self,
177 src_addr: SocketAddr,
178 src_conn: impl ByteStream + 'a,
179 ) -> Result<RelayHandle, Error> {
180 self.make_session(src_addr, src_conn)
181 }
182}
183
184fn perform_command(
185 cmd: Command,
186 connector: impl Deref<Target = impl Connector>,
187 rule: &ConnectRule,
188 connect_to: Address,
189) -> Result<(impl ByteStream, SocketAddr), Error> {
190 match cmd {
191 Command::Connect => {}
192 cmd @ Command::Bind | cmd @ Command::UdpAssociate => {
193 return Err(Error::command_not_supported(cmd));
194 }
195 };
196 check_rule(rule, connect_to.clone(), L4Protocol::Tcp)?;
198 connector.connect_byte_stream(connect_to)
199}
200
201fn negotiate_auth_method(
202 version: ProtocolVersion,
203 auth: impl Deref<Target = impl AuthService>,
204 mut socks: impl DerefMut<Target = impl SocksStream>,
205) -> Result<MethodSelection, Error> {
206 let candidates = socks.recv_method_candidates()?;
207 trace!("candidates: {:?}", candidates);
208
209 let selection = auth.select(&candidates.method)?;
210 trace!("selection: {:?}", selection);
211
212 let method_sel = MethodSelection {
213 version,
214 method: selection.unwrap_or(Method::NoMethods),
215 };
216 socks.send_method_selection(method_sel)?;
217 match method_sel.method {
218 Method::NoMethods => Err(Error::NoAcceptableMethod),
219 _ => Ok(method_sel),
220 }
221}
222
223fn check_rule(rule: &ConnectRule, addr: Address, proto: L4Protocol) -> Result<(), Error> {
224 if rule.check(addr.clone(), proto) {
225 Ok(())
226 } else {
227 Err(Error::connection_not_allowed(addr, proto))
228 }
229}
230
231#[derive(Debug, Clone)]
232pub struct DisconnectGuard<S> {
233 id: SessionId,
234 tx: mpsc::Sender<ServerCommand<S>>,
235}
236
237impl<S> DisconnectGuard<S> {
238 pub fn new(id: SessionId, tx: mpsc::Sender<ServerCommand<S>>) -> Self {
239 Self { id, tx }
240 }
241}
242
243impl<S> Drop for DisconnectGuard<S> {
244 fn drop(&mut self) {
245 debug!("DisconnectGuard: {}", self.id);
246 self.tx.send(ServerCommand::Disconnect(self.id)).unwrap()
247 }
248}
249
/// Unit tests for session setup. They drive `make_session` with in-memory
/// `BufferStream`s and a `BufferConnector` instead of real sockets.
#[cfg(test)]
mod test {
    use super::*;
    use crate::auth_service::test::RejectService;
    use crate::byte_stream::test::BufferStream;
    use crate::connector::test::BufferConnector;
    use crate::rw_socks_stream as socks;
    use std::io;
    use std::iter::FromIterator;
    use std::str::FromStr;

    /// With `RejectService` as the authorizer (presumably it accepts no
    /// method), negotiation must fail with `Error::NoAcceptableMethod`.
    #[test]
    fn no_acceptable_method() {
        // `_rx` is bound (not `_`) so the receiver outlives the session:
        // `DisconnectGuard` sends on this channel when the session drops.
        let (tx, _rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            0.into(),
            5.into(),
            BufferConnector::from_iter(vec![(
                "192.168.0.1:5123".parse().unwrap(),
                Ok(BufferStream::new()),
            )]),
            RejectService,
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            tx,
        );
        println!("session: {:?}", session);
        // Raw bytes; presumably a SOCKS5 method-candidates message:
        // VER=5, NMETHODS=1, METHODS=[0x00].
        let src = BufferStream::with_buffer(vec![5, 1, 0].into(), vec![].into());
        assert!(matches!(
            session
                .make_session("192.168.0.2:12345".parse().unwrap(), src)
                .unwrap_err(),
            Error::NoAcceptableMethod
        ));
    }

    /// A UDP ASSOCIATE request is rejected with `CommandNotSupported`.
    #[test]
    fn command_not_supported() {
        use crate::auth_service::NoAuthService;
        let mcand = MethodCandidates::new(&[Method::NoAuth]);
        let req = ConnectRequest::udp_associate(Address::from_str("192.168.0.1:5123").unwrap());
        let (tx, _rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            1.into(),
            5.into(),
            BufferConnector::from_iter(vec![(req.connect_to.clone(), Ok(BufferStream::new()))]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            tx,
        );
        println!("session: {:?}", session);

        // Client input: method candidates followed by the connect request.
        let buff = {
            let mut cursor = io::Cursor::new(vec![]);
            socks::test::write_method_candidates(&mut cursor, mcand).unwrap();
            socks::test::write_connect_request(&mut cursor, req).unwrap();
            cursor.into_inner()
        };
        let src = BufferStream::with_buffer(buff.into(), vec![].into());
        assert!(matches!(
            session
                .make_session("192.168.1.1:34567".parse().unwrap(), src)
                .unwrap_err(),
            Error::CommandNotSupported {
                cmd: Command::UdpAssociate
            }
        ));
    }

    /// With `ConnectRule::none()`, a CONNECT request must fail with
    /// `ConnectionNotAllowed` for the requested TCP address.
    #[test]
    fn connect_not_allowed() {
        use crate::auth_service::NoAuthService;
        let version: ProtocolVersion = 5.into();
        let connect_to = Address::from_str("192.168.0.1:5123").unwrap();
        let (tx, _rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            2.into(),
            version,
            BufferConnector::from_iter(vec![(connect_to.clone(), Ok(BufferStream::new()))]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::none(),
            tx,
        );
        println!("session: {:?}", session);

        // Client input: method candidates followed by a CONNECT request.
        let buff = {
            let mut cursor = io::Cursor::new(vec![]);
            socks::test::write_method_candidates(
                &mut cursor,
                MethodCandidates::new(&[Method::NoAuth]),
            )
            .unwrap();
            socks::test::write_connect_request(
                &mut cursor,
                ConnectRequest::connect_to(connect_to.clone()),
            )
            .unwrap();
            cursor.into_inner()
        };
        let src = BufferStream::with_buffer(buff.into(), vec![].into());
        assert!(matches!(
            session
                .make_session("192.168.1.1:34567".parse().unwrap(), src)
                .unwrap_err(),
            Error::ConnectionNotAllowed { addr, protocol: L4Protocol::Tcp } if addr == connect_to
        ));
    }

    /// A connector that yields `ConnectError::ConnectionRefused` must surface
    /// as `Error::ConnectionRefused` for the requested TCP address.
    #[test]
    fn connection_refused() {
        use crate::auth_service::NoAuthService;
        let version: ProtocolVersion = 5.into();
        let connect_to = Address::from_str("192.168.0.1:5123").unwrap();
        let (tx, _rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            3.into(),
            version,
            BufferConnector::<BufferStream>::from_iter(vec![(
                connect_to.clone(),
                Err(ConnectError::ConnectionRefused),
            )]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            tx,
        );
        println!("session: {:?}", session);

        // Client input: method candidates followed by a CONNECT request.
        let buff = {
            let mut cursor = io::Cursor::new(vec![]);
            socks::test::write_method_candidates(
                &mut cursor,
                MethodCandidates::new(&[Method::NoAuth]),
            )
            .unwrap();
            socks::test::write_connect_request(
                &mut cursor,
                ConnectRequest::connect_to(connect_to.clone()),
            )
            .unwrap();
            cursor.into_inner()
        };
        let src = BufferStream::with_buffer(buff.into(), vec![].into());
        assert!(matches!(
            session
                .make_session("192.168.1.1:34567".parse().unwrap(), src)
                .unwrap_err(),
            Error::ConnectionRefused { addr, protocol: L4Protocol::Tcp } if addr == connect_to
        ));
    }

    /// Returns `size` random bytes drawn from `rand`'s `Standard` distribution.
    fn gen_random_vec(size: usize) -> Vec<u8> {
        use rand::distributions::Standard;
        use rand::{thread_rng, Rng};
        let rng = thread_rng();
        rng.sample_iter(Standard).take(size).collect()
    }

    /// Reads everything remaining in `reader` into a `Vec`, panicking on I/O error.
    fn vec_from_read<T: io::Read>(mut reader: T) -> Vec<u8> {
        let mut buff = vec![];
        reader.read_to_end(&mut buff).unwrap();
        buff
    }

    /// End-to-end check: after a successful handshake, bytes flow unchanged in
    /// both directions between the client stream and the destination stream.
    #[test]
    fn relay_contents() {
        use crate::auth_service::NoAuthService;
        use io::Write;

        let version: ProtocolVersion = 5.into();
        let connect_to = Address::Domain("example.com".into(), 5123);
        let (tx, _rx) = mpsc::channel::<ServerCommand<()>>();
        // The stop sender is bound to a named variable so it stays alive
        // until the end of the test.
        let (session, _tx_session_term) = Session::new(
            4.into(),
            version,
            // The destination's read buffer holds 8200 random bytes to send
            // back to the client.
            BufferConnector::from_iter(vec![(
                connect_to.clone(),
                Ok(BufferStream::with_buffer(
                    gen_random_vec(8200).into(),
                    vec![].into(),
                )),
            )]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            tx,
        );

        // Offset in the client input where the relayed payload starts,
        // i.e. just after the handshake messages.
        let input_stream_pos;
        let src = {
            let mut cursor = io::Cursor::new(vec![]);
            // Handshake: method candidates, then a CONNECT request.
            socks::test::write_method_candidates(
                &mut cursor,
                MethodCandidates::new(&[Method::NoAuth]),
            )
            .unwrap();
            socks::test::write_connect_request(
                &mut cursor,
                ConnectRequest::connect_to(connect_to.clone()),
            )
            .unwrap();
            input_stream_pos = cursor.position();
            // Payload to be relayed from the client to the destination.
            cursor.write_all(&gen_random_vec(8200)).unwrap();
            BufferStream::with_buffer(cursor.into_inner().into(), vec![].into())
        };
        let dst_connector = session.dst_connector.clone();
        // `src` is inspected after the relay ran on a clone of it, so the
        // clones presumably share their underlying buffers.
        let relay = session
            .make_session("192.168.1.2:33333".parse().unwrap(), src.clone())
            .unwrap();
        assert!(relay.join().is_ok());

        {
            // Rewind what was written to the client and check the two
            // handshake replies: method selection, then connect reply.
            src.wr_buff().set_position(0);
            assert_eq!(
                socks::test::read_method_selection(&mut *src.wr_buff()).unwrap(),
                MethodSelection {
                    version,
                    method: Method::NoAuth
                }
            );
            assert_eq!(
                socks::test::read_connect_reply(&mut *src.wr_buff()).unwrap(),
                ConnectReply {
                    version,
                    connect_result: Ok(()),
                    server_addr: Address::IpAddr("0.0.0.0".parse().unwrap(), 1080),
                }
            );
        }

        // Destination -> client: everything written to the client after the
        // replies equals the destination's full read buffer.
        assert_eq!(vec_from_read(&mut *src.wr_buff()), {
            let mut rd_buff = dst_connector.stream(&connect_to).rd_buff();
            rd_buff.set_position(0);
            vec_from_read(&mut *rd_buff)
        });
        // Client -> destination: the client payload after the handshake
        // equals everything written to the destination.
        assert_eq!(
            {
                let mut rd_buff = src.rd_buff();
                rd_buff.set_position(input_stream_pos);
                vec_from_read(&mut *rd_buff)
            },
            {
                let mut wr_buff = dst_connector.stream(&connect_to).wr_buff();
                wr_buff.set_position(0);
                vec_from_read(&mut *wr_buff)
            }
        );
    }
}