sosistab2 0.7.0

An obfuscated datagram transport for horrible networks
Documentation
use bytes::Bytes;
use dashmap::DashMap;

use rand::prelude::*;

use smol::channel::{Receiver, Sender};
use smol::prelude::*;
use smol_str::SmolStr;
use std::time::{SystemTime, UNIX_EPOCH};
use std::{sync::Arc, time::Duration};

use crate::multiplex::structs::Message;

use crate::{
    crypt::{triple_ecdh, NonObfsAead},
    utilities::ReplayFilter,
};

use super::{
    stream::{StreamBack, StreamState},
    structs::{OuterMessage, PipePool, RelKind},
    MuxStream,
};

pub async fn multiplex(
    pipe_pool: Arc<PipePool>,
    conn_open_recv: Receiver<(SmolStr, Sender<MuxStream>)>,
    conn_accept_send: Sender<MuxStream>,
    my_long_sk: x25519_dalek::StaticSecret,
    real_their_long_pk: Option<x25519_dalek::PublicKey>,
) -> anyhow::Result<()> {
    // encryption parameters
    let my_eph_sk_send = x25519_dalek::StaticSecret::new(rand::thread_rng());
    let my_eph_sk_recv = x25519_dalek::StaticSecret::new(rand::thread_rng());
    let mut send_aead: Option<NonObfsAead> = None;
    let mut recv_aead: Option<NonObfsAead> = None;
    let mut replay_filter = ReplayFilter::default();

    let conn_tab = Arc::new(ConnTable::default());
    let (glob_send, glob_recv) = smol::channel::unbounded();
    let (dead_send, dead_recv) = smol::channel::unbounded();

    // Reap death
    let reap_dead = {
        let dead_send = dead_send.clone();
        move |id: u16| {
            log::debug!("reaper received {}", id);
            smolscale::spawn(async move {
                smol::Timer::after(Duration::from_secs(30)).await;
                log::debug!("reaper executed {}", id);
                let _ = dead_send.try_send(id);
            })
            .detach()
        }
    };

    // enum of possible events
    enum Event {
        RecvMsg(OuterMessage),
        SendMsg(Message),
        ConnOpen(SmolStr, Sender<MuxStream>),
        Dead(u16),
    }

    loop {
        // fires on receiving messages
        let recv_msg = async {
            let raw_msg = pipe_pool.recv().await?;
            // decrypt!
            let msg = stdcode::deserialize(&raw_msg);
            if let Ok(msg) = msg {
                Ok::<_, anyhow::Error>(Event::RecvMsg(msg))
            } else {
                log::trace!("unrecognizable message from sess: {:?}", raw_msg);
                smol::future::pending().await
            }
        };
        // fires on sending messages
        let send_msg = async {
            let to_send = glob_recv.recv().await?;
            Ok::<_, anyhow::Error>(Event::SendMsg(to_send))
        };
        // fires on stream open events
        let conn_open = async {
            let (additional_data, result_chan) = conn_open_recv.recv().await?;
            Ok::<_, anyhow::Error>(Event::ConnOpen(additional_data, result_chan))
        };
        // fires on death
        let death = async {
            let res = dead_recv.recv().await?;
            Ok::<_, anyhow::Error>(Event::Dead(res))
        };
        // match on the event
        match recv_msg.or(send_msg.or(conn_open.or(death))).await? {
            Event::Dead(id) => conn_tab.del_stream(id),
            Event::ConnOpen(additional_data, result_chan) => {
                let conn_tab = conn_tab.clone();
                let glob_send = glob_send.clone();
                let reap_dead = reap_dead.clone();
                smolscale::spawn(async move {
                    let stream_id = {
                        let stream_id = conn_tab.find_id();
                        if let Some(stream_id) = stream_id {
                            let (send_sig, recv_sig) = smol::channel::bounded(1);
                            let (conn, conn_back) = MuxStream::new(
                                StreamState::SynSent {
                                    stream_id,
                                    tries: 0,
                                    result: send_sig,
                                },
                                glob_send.clone(),
                                move || reap_dead(stream_id),
                                additional_data.clone(),
                            );
                            smolscale::spawn(async move {
                                recv_sig.recv().await.ok()?;
                                result_chan.send(conn).await.ok()?;
                                Some(())
                            })
                            .detach();
                            conn_tab.set_stream(stream_id, conn_back);
                            stream_id
                        } else {
                            return;
                        }
                    };
                    log::trace!("conn open send {}", stream_id);
                    drop({
                        glob_send
                            .send(Message::Rel {
                                kind: RelKind::Syn,
                                stream_id,
                                seqno: 0,
                                payload: Bytes::copy_from_slice(additional_data.as_bytes()),
                            })
                            .await
                    });
                })
                .detach();
            }
            Event::SendMsg(msg) => {
                // if outgoing_key is available, encrypt and send off; else drop msg & send ClientHello
                if let Some(send_aead) = send_aead.as_mut() {
                    let (_, msg) = send_aead.encrypt(&stdcode::serialize(&msg).unwrap());
                    let outer = OuterMessage::EncryptedMsg { inner: msg };
                    pipe_pool
                        .send(stdcode::serialize(&outer).unwrap().into())
                        .await;
                } else {
                    log::debug!("no send_aead available, so we send a client hello");
                    let to_send = OuterMessage::ClientHello {
                        long_pk: (&my_long_sk).into(),
                        eph_pk: (&my_eph_sk_send).into(),
                        version: 1,
                        timestamp: SystemTime::now()
                            .duration_since(UNIX_EPOCH)
                            .unwrap()
                            .as_secs(),
                    };
                    pipe_pool
                        .send(stdcode::serialize(&to_send).unwrap().into())
                        .await;
                }
            }
            Event::RecvMsg(msg) => {
                match msg {
                    OuterMessage::ClientHello {
                        long_pk: their_long_pk,
                        eph_pk: their_eph_pk,
                        version: _,
                        timestamp: _,
                    } => {
                        if let Some(real) = real_their_long_pk {
                            if real != their_long_pk {
                                log::warn!("dropping invalid ClientHello");
                                continue;
                            }
                        }
                        let recv_secret = triple_ecdh(
                            &my_long_sk,
                            &my_eph_sk_recv,
                            &their_long_pk,
                            &their_eph_pk,
                        );
                        log::debug!("recv_secret {:?}", recv_secret);
                        recv_aead = Some(NonObfsAead::new(recv_secret.as_bytes()));
                        log::debug!("recv_aead registered since we received a clienthello");
                        // respond with a serverhello
                        let our_hello = OuterMessage::ServerHello {
                            long_pk: (&my_long_sk).into(),
                            eph_pk: (&my_eph_sk_recv).into(),
                        };
                        pipe_pool
                            .send(stdcode::serialize(&our_hello).unwrap().into())
                            .await;
                    }
                    OuterMessage::ServerHello {
                        long_pk: their_long_pk,
                        eph_pk: their_eph_pk,
                    } => {
                        if let Some(real) = real_their_long_pk {
                            if real != their_long_pk {
                                log::warn!("dropping invalid ServerHello");
                                continue;
                            }
                        }
                        let send_secret = triple_ecdh(
                            &my_long_sk,
                            &my_eph_sk_send,
                            &their_long_pk,
                            &their_eph_pk,
                        );
                        log::debug!("send_secret {:?}", send_secret);
                        send_aead = Some(NonObfsAead::new(send_secret.as_bytes()));
                        log::debug!("send_aead registered since we received a clienthello");
                    }

                    OuterMessage::EncryptedMsg { inner } => {
                        if let Some(recv_aead) = recv_aead.as_ref() {
                            match recv_aead.decrypt(&inner) {
                                Err(err) => {
                                    log::warn!("failed decrypting {} bytes: {:?}", inner.len(), err)
                                }
                                Ok((nonce, plain)) => {
                                    log::trace!("decrypted {} bytes, nonce {nonce}", plain.len());
                                    if replay_filter.add(nonce) {
                                        if let Ok(msg) = stdcode::deserialize::<Message>(&plain) {
                                            match msg {
                                                Message::Urel { stream_id, payload } => {
                                                    if let Some(val) =
                                                        conn_tab.get_stream(stream_id)
                                                    {
                                                        val.process(Message::Urel {
                                                            stream_id,
                                                            payload,
                                                        })
                                                        .await;
                                                    }
                                                }
                                                Message::Rel {
                                                    kind: RelKind::Syn,
                                                    stream_id,
                                                    payload,
                                                    ..
                                                } => {
                                                    if conn_tab.get_stream(stream_id).is_some() {
                                                        log::trace!(
                                                            "syn recv {} REACCEPT",
                                                            stream_id
                                                        );
                                                        let msg = Message::Rel {
                                                            kind: RelKind::SynAck,
                                                            stream_id,
                                                            seqno: 0,
                                                            payload: Bytes::copy_from_slice(&[]),
                                                        };
                                                        let _ = glob_send.try_send(msg);
                                                    } else {
                                                        log::trace!(
                                                            "syn recv {} ACCEPT",
                                                            stream_id
                                                        );
                                                        let lala =
                                                            String::from_utf8_lossy(&payload)
                                                                .to_string();
                                                        let additional_info = lala.into();
                                                        let reap_dead = reap_dead.clone();
                                                        let (new_conn, new_conn_back) =
                                                            MuxStream::new(
                                                                StreamState::SynReceived {
                                                                    stream_id,
                                                                },
                                                                glob_send.clone(),
                                                                move || {
                                                                    reap_dead(stream_id);
                                                                },
                                                                additional_info,
                                                            );
                                                        // the Stream itself is responsible for sending the SynAck. Here we just store the connection into the table, accept it, and be done with it.
                                                        conn_tab
                                                            .set_stream(stream_id, new_conn_back);
                                                        drop(conn_accept_send.send(new_conn).await);
                                                    }
                                                }
                                                // associated with existing connection
                                                Message::Rel {
                                                    stream_id, kind, ..
                                                } => {
                                                    if let Some(handle) =
                                                        conn_tab.get_stream(stream_id)
                                                    {
                                                        // log::trace!("handing over {:?} to {}", kind, stream_id);
                                                        handle.process(msg).await;
                                                    } else {
                                                        log::trace!(
                                                            "discarding {:?} to nonexistent {}",
                                                            kind,
                                                            stream_id
                                                        );
                                                        if kind != RelKind::Rst {
                                                            let msg = Message::Rel {
                                                                kind: RelKind::Rst,
                                                                stream_id,
                                                                seqno: 0,
                                                                payload: Bytes::copy_from_slice(&[]),
                                                            };
                                                            let _ = glob_send.send(msg).await;
                                                        }
                                                    }
                                                }
                                                Message::Empty => {}
                                            }
                                        }
                                    }
                                }
                            }
                        } else {
                            log::debug!(
                                "received {} bytes but we don't have recv_aead yet",
                                inner.len()
                            );
                        }
                    } // // connection opening
                }
            }
        }
    }
}

/// Registry of the streams that are currently live, keyed by 16-bit stream id.
#[derive(Default)]
struct ConnTable {
    /// Back handle of each registered stream, indexed by its stream id.
    sid_to_stream: DashMap<u16, StreamBack>,
}

impl ConnTable {
    /// Returns a clone of the back handle registered under `sid`, or `None`.
    fn get_stream(&self, sid: u16) -> Option<StreamBack> {
        self.sid_to_stream
            .get(&sid)
            .map(|entry| entry.value().clone())
    }

    /// Registers `handle` under `id`, replacing any handle previously stored there.
    fn set_stream(&self, id: u16, handle: StreamBack) {
        self.sid_to_stream.insert(id, handle);
    }

    /// Forgets the stream registered under `id`. Does nothing if there is none.
    fn del_stream(&self, id: u16) {
        self.sid_to_stream.remove(&id);
    }

    /// Draws random ids until it finds one that is not in the table.
    ///
    /// The id is NOT reserved; the caller must `set_stream` it. Returns `None`, after
    /// logging a warning, once `MAX_STREAMS` or more streams are registered.
    fn find_id(&self) -> Option<u16> {
        // Kept well below the 65536 possible ids, presumably so each random draw has
        // a good chance of hitting a free slot.
        const MAX_STREAMS: usize = 50000;
        loop {
            if self.sid_to_stream.len() >= MAX_STREAMS {
                log::warn!("ran out of descriptors ({})", self.sid_to_stream.len());
                return None;
            }
            let candidate: u16 = rand::thread_rng().gen();
            if !self.sid_to_stream.contains_key(&candidate) {
                log::debug!(
                    "found id {} out of {}",
                    candidate,
                    self.sid_to_stream.len()
                );
                return Some(candidate);
            }
        }
    }
}