gatekeeper/
session.rs

1use std::fmt;
2use std::ops::{Deref, DerefMut};
3use std::sync::mpsc::{self, SyncSender};
4use std::sync::{Arc, Mutex};
5use std::thread;
6
7use log::*;
8
9use crate::auth_service::AuthService;
10use crate::byte_stream::ByteStream;
11use crate::connector::Connector;
12use crate::model::dao::*;
13use crate::model::model::*;
14use crate::model::Error;
15use crate::relay::{self, RelayHandle};
16use crate::rw_socks_stream::ReadWriteStream;
17use crate::server_command::ServerCommand;
18
/// Numeric identifier assigned to a session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SessionId(pub u32);

impl From<u32> for SessionId {
    /// Wraps a raw `u32` as a `SessionId`.
    fn from(raw: u32) -> SessionId {
        SessionId(raw)
    }
}

impl fmt::Display for SessionId {
    /// Renders the id as `SessionId(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let SessionId(raw) = *self;
        write!(f, "SessionId({})", raw)
    }
}
33
34#[derive(Debug)]
35pub struct SessionHandle {
36    /// client address
37    addr: SocketAddr,
38    /// thread performs relay bytes
39    handle: thread::JoinHandle<Result<RelayHandle, Error>>,
40    /// Sender to send termination messages to relay threads
41    tx: SyncSender<()>,
42}
43
44impl SessionHandle {
45    pub fn new(
46        addr: SocketAddr,
47        handle: thread::JoinHandle<Result<RelayHandle, Error>>,
48        tx: SyncSender<()>,
49    ) -> Self {
50        Self { addr, handle, tx }
51    }
52
53    pub fn client_addr(&self) -> SocketAddr {
54        self.addr
55    }
56
57    pub fn stop(&self) {
58        trace!("stop session: {}", self.addr);
59        // ignore disconnected error. if the receiver is deallocated,
60        // relay threads should have been terminated.
61        if self.tx.send(()).is_ok() {
62            // send a message to another side relay
63            self.tx.send(()).ok();
64        }
65    }
66
67    pub fn join(self) -> thread::Result<Result<(), Error>> {
68        trace!("join session: {}", self.addr);
69        match self.handle.join()? {
70            Ok(relay) => relay.join(),
71            Err(err) => Ok(Err(err)),
72        }
73    }
74}
75
76#[derive(Debug)]
77pub struct Session<D, A, S> {
78    pub id: SessionId,
79    pub version: ProtocolVersion,
80    pub dst_connector: D,
81    pub authorizer: A,
82    pub server_addr: SocketAddr,
83    pub conn_rule: ConnectRule,
84    /// termination message receiver
85    rx: Arc<Mutex<mpsc::Receiver<()>>>,
86    /// Send `Disconnect` command to the main thread.
87    /// This guard is shared with 2 relays.
88    guard: Arc<Mutex<DisconnectGuard<S>>>,
89}
90
91impl<D, A, S> Session<D, A, S>
92where
93    D: Connector,
94    A: AuthService,
95    S: Send + 'static,
96{
97    /// Returns Self and termination message sender.
98    pub fn new(
99        id: SessionId,
100        version: ProtocolVersion,
101        dst_connector: D,
102        authorizer: A,
103        server_addr: SocketAddr,
104        conn_rule: ConnectRule,
105        tx_cmd: mpsc::Sender<ServerCommand<S>>,
106    ) -> (Self, mpsc::SyncSender<()>) {
107        let (tx, rx) = mpsc::sync_channel(2);
108        (
109            Self {
110                id,
111                version,
112                dst_connector,
113                authorizer,
114                server_addr,
115                conn_rule,
116                rx: Arc::new(Mutex::new(rx)),
117                guard: Arc::new(Mutex::new(DisconnectGuard::new(id, tx_cmd))),
118            },
119            tx,
120        )
121    }
122
123    fn connect_reply(&self, connect_result: Result<(), ConnectError>) -> ConnectReply {
124        ConnectReply {
125            version: self.version,
126            connect_result,
127            server_addr: self.server_addr.into(),
128        }
129    }
130
131    fn make_session<'a>(
132        &self,
133        src_addr: SocketAddr,
134        mut src_conn: impl ByteStream + 'a,
135    ) -> Result<RelayHandle, Error> {
136        let mut socks = ReadWriteStream::new(&mut src_conn);
137
138        let select = negotiate_auth_method(self.version, &self.authorizer, &mut socks)?;
139        debug!("auth method: {:?}", select);
140        let mut socks = ReadWriteStream::new(self.authorizer.authorize(select.method, src_conn)?);
141
142        let req = socks.recv_connect_request()?;
143        debug!("connect request: {:?}", req);
144
145        let (conn, dst_addr) = match perform_command(
146            req.command,
147            &self.dst_connector,
148            &self.conn_rule,
149            req.connect_to.clone(),
150        ) {
151            Ok((conn, dst_addr)) => {
152                info!("connected: {}: {}", req.connect_to, dst_addr);
153                socks.send_connect_reply(self.connect_reply(Ok(())))?;
154                (conn, dst_addr)
155            }
156            Err(err) => {
157                error!("command error: {}", err);
158                trace!("command error: {:?}", err);
159                // reply error
160                socks.send_connect_reply(self.connect_reply(Err(err.cerr())))?;
161                return Err(err);
162            }
163        };
164
165        relay::spawn_relay(
166            src_addr,
167            dst_addr,
168            socks.into_inner(),
169            conn,
170            self.rx.clone(),
171            self.guard.clone(),
172        )
173    }
174
175    pub fn start<'a>(
176        self,
177        src_addr: SocketAddr,
178        src_conn: impl ByteStream + 'a,
179    ) -> Result<RelayHandle, Error> {
180        self.make_session(src_addr, src_conn)
181    }
182}
183
184fn perform_command(
185    cmd: Command,
186    connector: impl Deref<Target = impl Connector>,
187    rule: &ConnectRule,
188    connect_to: Address,
189) -> Result<(impl ByteStream, SocketAddr), Error> {
190    match cmd {
191        Command::Connect => {}
192        cmd @ Command::Bind | cmd @ Command::UdpAssociate => {
193            return Err(Error::command_not_supported(cmd));
194        }
195    };
196    // filter out request not sufficies the connection rule
197    check_rule(rule, connect_to.clone(), L4Protocol::Tcp)?;
198    connector.connect_byte_stream(connect_to)
199}
200
201fn negotiate_auth_method(
202    version: ProtocolVersion,
203    auth: impl Deref<Target = impl AuthService>,
204    mut socks: impl DerefMut<Target = impl SocksStream>,
205) -> Result<MethodSelection, Error> {
206    let candidates = socks.recv_method_candidates()?;
207    trace!("candidates: {:?}", candidates);
208
209    let selection = auth.select(&candidates.method)?;
210    trace!("selection: {:?}", selection);
211
212    let method_sel = MethodSelection {
213        version,
214        method: selection.unwrap_or(Method::NoMethods),
215    };
216    socks.send_method_selection(method_sel)?;
217    match method_sel.method {
218        Method::NoMethods => Err(Error::NoAcceptableMethod),
219        _ => Ok(method_sel),
220    }
221}
222
223fn check_rule(rule: &ConnectRule, addr: Address, proto: L4Protocol) -> Result<(), Error> {
224    if rule.check(addr.clone(), proto) {
225        Ok(())
226    } else {
227        Err(Error::connection_not_allowed(addr, proto))
228    }
229}
230
231#[derive(Debug, Clone)]
232pub struct DisconnectGuard<S> {
233    id: SessionId,
234    tx: mpsc::Sender<ServerCommand<S>>,
235}
236
237impl<S> DisconnectGuard<S> {
238    pub fn new(id: SessionId, tx: mpsc::Sender<ServerCommand<S>>) -> Self {
239        Self { id, tx }
240    }
241}
242
243impl<S> Drop for DisconnectGuard<S> {
244    fn drop(&mut self) {
245        debug!("DisconnectGuard: {}", self.id);
246        self.tx.send(ServerCommand::Disconnect(self.id)).unwrap()
247    }
248}
249
#[cfg(test)]
mod test {
    use super::*;
    use crate::auth_service::test::RejectService;
    use crate::byte_stream::test::BufferStream;
    use crate::connector::test::BufferConnector;
    use crate::rw_socks_stream as socks;
    use std::io;
    use std::iter::FromIterator;
    use std::str::FromStr;

    /// Writes the client side of a handshake into `cursor`: a method offer
    /// containing only `NoAuth`, followed by `req`.
    fn write_handshake(cursor: &mut io::Cursor<Vec<u8>>, req: ConnectRequest) {
        socks::test::write_method_candidates(
            &mut *cursor,
            MethodCandidates::new(&[Method::NoAuth]),
        )
        .unwrap();
        socks::test::write_connect_request(&mut *cursor, req).unwrap();
    }

    /// Returns the encoded client handshake for `req` (see `write_handshake`).
    fn handshake_bytes(req: ConnectRequest) -> Vec<u8> {
        let mut cursor = io::Cursor::new(vec![]);
        write_handshake(&mut cursor, req);
        cursor.into_inner()
    }

    #[test]
    fn no_acceptable_method() {
        let (cmd_tx, _cmd_rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            0.into(),
            5.into(),
            BufferConnector::from_iter(vec![(
                "192.168.0.1:5123".parse().unwrap(),
                Ok(BufferStream::new()),
            )]),
            RejectService,
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            cmd_tx,
        );
        println!("session: {:?}", session);
        // raw SOCKS5 greeting: VER = 5, NMETHODS = 1, METHODS = [0x00]
        let client = BufferStream::with_buffer(vec![5, 1, 0].into(), vec![].into());
        let err = session
            .make_session("192.168.0.2:12345".parse().unwrap(), client)
            .unwrap_err();
        assert!(matches!(err, Error::NoAcceptableMethod));
    }

    #[test]
    fn command_not_supported() {
        use crate::auth_service::NoAuthService;
        // UDP ASSOCIATE is not supported
        let req = ConnectRequest::udp_associate(Address::from_str("192.168.0.1:5123").unwrap());
        let (cmd_tx, _cmd_rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            1.into(),
            5.into(),
            BufferConnector::from_iter(vec![(req.connect_to.clone(), Ok(BufferStream::new()))]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            cmd_tx,
        );
        println!("session: {:?}", session);

        let client = BufferStream::with_buffer(handshake_bytes(req).into(), vec![].into());
        let err = session
            .make_session("192.168.1.1:34567".parse().unwrap(), client)
            .unwrap_err();
        assert!(matches!(
            err,
            Error::CommandNotSupported {
                cmd: Command::UdpAssociate
            }
        ));
    }

    #[test]
    fn connect_not_allowed() {
        use crate::auth_service::NoAuthService;
        let version: ProtocolVersion = 5.into();
        let connect_to = Address::from_str("192.168.0.1:5123").unwrap();
        let (cmd_tx, _cmd_rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            2.into(),
            version,
            BufferConnector::from_iter(vec![(connect_to.clone(), Ok(BufferStream::new()))]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::none(),
            cmd_tx,
        );
        println!("session: {:?}", session);

        let input = handshake_bytes(ConnectRequest::connect_to(connect_to.clone()));
        let client = BufferStream::with_buffer(input.into(), vec![].into());
        let err = session
            .make_session("192.168.1.1:34567".parse().unwrap(), client)
            .unwrap_err();
        assert!(matches!(
            err,
            Error::ConnectionNotAllowed { addr, protocol: L4Protocol::Tcp } if addr == connect_to
        ));
    }

    #[test]
    fn connection_refused() {
        use crate::auth_service::NoAuthService;
        let version: ProtocolVersion = 5.into();
        let connect_to = Address::from_str("192.168.0.1:5123").unwrap();
        let (cmd_tx, _cmd_rx) = mpsc::channel::<ServerCommand<()>>();
        let (session, _) = Session::new(
            3.into(),
            version,
            BufferConnector::<BufferStream>::from_iter(vec![(
                connect_to.clone(),
                Err(ConnectError::ConnectionRefused),
            )]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            cmd_tx,
        );
        println!("session: {:?}", session);

        let input = handshake_bytes(ConnectRequest::connect_to(connect_to.clone()));
        let client = BufferStream::with_buffer(input.into(), vec![].into());
        let err = session
            .make_session("192.168.1.1:34567".parse().unwrap(), client)
            .unwrap_err();
        assert!(matches!(
            err,
            Error::ConnectionRefused { addr, protocol: L4Protocol::Tcp } if addr == connect_to
        ));
    }

    /// Returns `size` random bytes.
    fn gen_random_vec(size: usize) -> Vec<u8> {
        use rand::distributions::Standard;
        use rand::{thread_rng, Rng};
        thread_rng().sample_iter(Standard).take(size).collect()
    }

    /// Reads `reader` to the end and returns everything it produced.
    fn vec_from_read<T: io::Read>(mut reader: T) -> Vec<u8> {
        let mut out = Vec::new();
        reader.read_to_end(&mut out).unwrap();
        out
    }

    #[test]
    fn relay_contents() {
        use crate::auth_service::NoAuthService;
        use io::Write;

        let version: ProtocolVersion = 5.into();
        let connect_to = Address::Domain("example.com".into(), 5123);
        let (cmd_tx, _cmd_rx) = mpsc::channel::<ServerCommand<()>>();
        // the termination sender is bound (not `_`) so it lives for the whole test
        let (session, _term_tx) = Session::new(
            4.into(),
            version,
            BufferConnector::from_iter(vec![(
                connect_to.clone(),
                Ok(BufferStream::with_buffer(
                    gen_random_vec(8200).into(),
                    vec![].into(),
                )),
            )]),
            NoAuthService::new(),
            "0.0.0.0:1080".parse().unwrap(),
            ConnectRule::any(),
            cmd_tx,
        );

        // client input: SOCKS handshake followed by the payload to relay
        let mut cursor = io::Cursor::new(vec![]);
        write_handshake(&mut cursor, ConnectRequest::connect_to(connect_to.clone()));
        // offset at which the relayed payload starts in the client's input
        let payload_start = cursor.position();
        cursor.write_all(&gen_random_vec(8200)).unwrap();
        let client = BufferStream::with_buffer(cursor.into_inner().into(), vec![].into());

        let dst_connector = session.dst_connector.clone();
        let relay = session
            .make_session("192.168.1.2:33333".parse().unwrap(), client.clone())
            .unwrap();
        assert!(relay.join().is_ok());

        // the client first receives the method selection, then the connect reply
        client.wr_buff().set_position(0);
        assert_eq!(
            socks::test::read_method_selection(&mut *client.wr_buff()).unwrap(),
            MethodSelection {
                version,
                method: Method::NoAuth
            }
        );
        assert_eq!(
            socks::test::read_connect_reply(&mut *client.wr_buff()).unwrap(),
            ConnectReply {
                version,
                connect_result: Ok(()),
                server_addr: Address::IpAddr("0.0.0.0".parse().unwrap(), 1080),
            }
        );

        // target --> client: the rest of the client's output is the target's data
        let to_client = vec_from_read(&mut *client.wr_buff());
        let from_target = {
            let mut rd_buff = dst_connector.stream(&connect_to).rd_buff();
            rd_buff.set_position(0);
            vec_from_read(&mut *rd_buff)
        };
        assert_eq!(to_client, from_target);

        // client --> target: the payload after the handshake reaches the target
        let from_client = {
            let mut rd_buff = client.rd_buff();
            rd_buff.set_position(payload_start);
            vec_from_read(&mut *rd_buff)
        };
        let to_target = {
            let mut wr_buff = dst_connector.stream(&connect_to).wr_buff();
            wr_buff.set_position(0);
            vec_from_read(&mut *wr_buff)
        };
        assert_eq!(from_client, to_target);
    }
}
508        );
509    }
510}