// sosistab/client/inner.rs

1use crate::{buffer::Buff, crypt};
2use crate::{protocol, runtime, Backhaul, Session, SessionConfig, StatsGatherer};
3
4use probability::distribution::{Binomial, Distribution};
5use smallvec::SmallVec;
6use smol::{prelude::*, Task};
7use std::{
8    collections::VecDeque,
9    net::SocketAddr,
10    sync::Arc,
11    time::{Duration, Instant},
12};
13
14use super::worker::ClientWorker;
15
16/// Configures the client.
17#[derive(Clone)]
18pub(crate) struct LowlevelClientConfig {
19    pub server_addr: SocketAddr,
20    pub server_pubkey: x25519_dalek::PublicKey,
21    pub backhaul_gen: Arc<dyn Fn() -> Arc<dyn Backhaul> + 'static + Send + Sync>,
22    pub num_shards: usize,
23    pub reset_interval: Option<Duration>,
24    pub gather: Arc<StatsGatherer>,
25}
26
27/// Connects to a remote server, given a closure that generates socket addresses.
28pub(crate) async fn connect_custom(cfg: LowlevelClientConfig) -> std::io::Result<Session> {
29    let my_long_sk = x25519_dalek::StaticSecret::new(rand::thread_rng());
30    let my_eph_sk = x25519_dalek::StaticSecret::new(rand::thread_rng());
31    // do the handshake
32    let cookie = crypt::Cookie::new(cfg.server_pubkey);
33    let init_hello = protocol::HandshakeFrame::ClientHello {
34        long_pk: (&my_long_sk).into(),
35        eph_pk: (&my_eph_sk).into(),
36        version: VERSION,
37    };
38    for timeout_factor in (0u32..).map(|x| 2u64.pow(x.min(10))) {
39        let backhaul = (cfg.backhaul_gen)();
40        // send hello
41        let init_hello = crypt::LegacyAead::new(&cookie.generate_c2s().next().unwrap())
42            .pad_encrypt_v1(std::slice::from_ref(&init_hello), 1000);
43        backhaul.send_to(init_hello, cfg.server_addr).await?;
44        tracing::trace!("sent client hello");
45        // wait for response
46        let res = backhaul
47            .recv_from()
48            .or(async {
49                smol::Timer::after(Duration::from_secs(timeout_factor.min(10))).await;
50                Err(std::io::Error::new(
51                    std::io::ErrorKind::TimedOut,
52                    "timed out",
53                ))
54            })
55            .await;
56        match res {
57            Ok((buf, _)) => {
58                for possible_key in cookie.generate_s2c() {
59                    let decrypter = crypt::LegacyAead::new(&possible_key);
60                    let response = decrypter.pad_decrypt_v1(&buf);
61                    for response in response.unwrap_or_default() {
62                        if let protocol::HandshakeFrame::ServerHello {
63                            long_pk,
64                            eph_pk,
65                            resume_token,
66                        } = response
67                        {
68                            tracing::trace!("obtained response from server");
69                            if long_pk.as_bytes() != cfg.server_pubkey.as_bytes() {
70                                return Err(std::io::Error::new(
71                                    std::io::ErrorKind::ConnectionRefused,
72                                    "bad pubkey",
73                                ));
74                            }
75                            let shared_sec =
76                                crypt::triple_ecdh(&my_long_sk, &my_eph_sk, &long_pk, &eph_pk);
77                            return Ok(init_session(cookie, resume_token, shared_sec, cfg.clone()));
78                        }
79                    }
80                }
81            }
82            Err(err) => {
83                if err.kind() == std::io::ErrorKind::TimedOut {
84                    tracing::trace!(
85                        "timed out to {} with {}s timeout; trying again",
86                        cfg.server_addr,
87                        timeout_factor
88                    );
89                    continue;
90                }
91                return Err(err);
92            }
93        }
94    }
95    unimplemented!()
96}
/// Protocol version advertised in the client hello and used for the resulting session.
const VERSION: u64 = 3;
98
99fn init_session(
100    cookie: crypt::Cookie,
101    resume_token: Buff,
102    shared_sec: blake3::Hash,
103    cfg: LowlevelClientConfig,
104) -> Session {
105    let (mut session, back) = Session::new(SessionConfig {
106        version: VERSION,
107        gather: cfg.gather.clone(),
108        session_key: shared_sec.as_bytes().to_vec(),
109        role: crate::Role::Client,
110    });
111    let back = Arc::new(back);
112    let uploader: Task<anyhow::Result<()>> = runtime::spawn(async move {
113        let mut workers: Vec<ClientWorker> = (0..cfg.num_shards)
114            .map(|shard_id| {
115                ClientWorker::start(
116                    cookie.clone(),
117                    resume_token.clone(),
118                    back.clone(),
119                    shard_id as u8,
120                    cfg.clone(),
121                )
122            })
123            .collect();
124        let mut fired_workers: VecDeque<ClientWorker> = VecDeque::new();
125        let mut last_reset = Instant::now();
126        let mut just_respawned = false;
127        for ctr in (0..).cycle() {
128            let to_upload = back.next_outgoing().await?;
129            let random_worker = ctr % workers.len();
130            workers[random_worker].send_upload(to_upload).await;
131            if cfg
132                .reset_interval
133                .map(|dur| last_reset.elapsed() > dur)
134                .unwrap_or_default()
135            {
136                tracing::debug!("reset timer expired!");
137                last_reset = Instant::now();
138                if just_respawned {
139                    for worker in workers.iter() {
140                        worker.reset_received_count();
141                    }
142                    just_respawned = false;
143                } else {
144                    // check: are we even that bad?
145                    let worker_packet_count: SmallVec<[usize; 16]> =
146                        workers.iter().map(|w| w.get_received_count()).collect();
147                    let p_value = uniform_pvalue(&worker_packet_count);
148                    tracing::debug!("p-value = {}; {:?}", p_value, worker_packet_count);
149                    if p_value < 0.01 {
150                        // find the worst worker and fire it
151                        let worst_worker_id = workers
152                            .iter()
153                            .enumerate()
154                            .min_by_key(|(worker_id, worker)| {
155                                let count = worker.get_received_count();
156                                tracing::debug!("worker {} has {}", worker_id, count);
157                                count
158                            })
159                            .map(|x| x.0)
160                            .expect("must have a worst worker");
161                        tracing::debug!("replacing worst worker {}", worst_worker_id);
162                        let new_worker = ClientWorker::start(
163                            cookie.clone(),
164                            resume_token.clone(),
165                            back.clone(),
166                            worst_worker_id as u8,
167                            cfg.clone(),
168                        );
169                        let worst_worker =
170                            std::mem::replace(&mut workers[worst_worker_id], new_worker);
171                        fired_workers.push_back(worst_worker);
172                        if fired_workers.len() > workers.len() {
173                            fired_workers.pop_front();
174                        }
175                        just_respawned = true;
176                    }
177                }
178            }
179        }
180        unreachable!()
181    });
182    session.on_drop(move || {
183        drop(uploader);
184    });
185    session
186}
187
188// guess whether the given slice is uniformly distributed
189fn uniform_pvalue(vals: &[usize]) -> f64 {
190    if vals.is_empty() {
191        return 0.0;
192    }
193    if vals.iter().all(|a| *a == vals[0]) {
194        return 0.0;
195    }
196    let total_count = vals.iter().sum::<usize>();
197    let min = vals.iter().min().copied().unwrap();
198    let distro = Binomial::new(total_count, 1.0 / (vals.len() as f64));
199    distro.distribution(min as f64)
200}