1use crate::{buffer::Buff, crypt};
2use crate::{protocol, runtime, Backhaul, Session, SessionConfig, StatsGatherer};
3
4use probability::distribution::{Binomial, Distribution};
5use smallvec::SmallVec;
6use smol::{prelude::*, Task};
7use std::{
8 collections::VecDeque,
9 net::SocketAddr,
10 sync::Arc,
11 time::{Duration, Instant},
12};
13
14use super::worker::ClientWorker;
15
16#[derive(Clone)]
18pub(crate) struct LowlevelClientConfig {
19 pub server_addr: SocketAddr,
20 pub server_pubkey: x25519_dalek::PublicKey,
21 pub backhaul_gen: Arc<dyn Fn() -> Arc<dyn Backhaul> + 'static + Send + Sync>,
22 pub num_shards: usize,
23 pub reset_interval: Option<Duration>,
24 pub gather: Arc<StatsGatherer>,
25}
26
27pub(crate) async fn connect_custom(cfg: LowlevelClientConfig) -> std::io::Result<Session> {
29 let my_long_sk = x25519_dalek::StaticSecret::new(rand::thread_rng());
30 let my_eph_sk = x25519_dalek::StaticSecret::new(rand::thread_rng());
31 let cookie = crypt::Cookie::new(cfg.server_pubkey);
33 let init_hello = protocol::HandshakeFrame::ClientHello {
34 long_pk: (&my_long_sk).into(),
35 eph_pk: (&my_eph_sk).into(),
36 version: VERSION,
37 };
38 for timeout_factor in (0u32..).map(|x| 2u64.pow(x.min(10))) {
39 let backhaul = (cfg.backhaul_gen)();
40 let init_hello = crypt::LegacyAead::new(&cookie.generate_c2s().next().unwrap())
42 .pad_encrypt_v1(std::slice::from_ref(&init_hello), 1000);
43 backhaul.send_to(init_hello, cfg.server_addr).await?;
44 tracing::trace!("sent client hello");
45 let res = backhaul
47 .recv_from()
48 .or(async {
49 smol::Timer::after(Duration::from_secs(timeout_factor.min(10))).await;
50 Err(std::io::Error::new(
51 std::io::ErrorKind::TimedOut,
52 "timed out",
53 ))
54 })
55 .await;
56 match res {
57 Ok((buf, _)) => {
58 for possible_key in cookie.generate_s2c() {
59 let decrypter = crypt::LegacyAead::new(&possible_key);
60 let response = decrypter.pad_decrypt_v1(&buf);
61 for response in response.unwrap_or_default() {
62 if let protocol::HandshakeFrame::ServerHello {
63 long_pk,
64 eph_pk,
65 resume_token,
66 } = response
67 {
68 tracing::trace!("obtained response from server");
69 if long_pk.as_bytes() != cfg.server_pubkey.as_bytes() {
70 return Err(std::io::Error::new(
71 std::io::ErrorKind::ConnectionRefused,
72 "bad pubkey",
73 ));
74 }
75 let shared_sec =
76 crypt::triple_ecdh(&my_long_sk, &my_eph_sk, &long_pk, &eph_pk);
77 return Ok(init_session(cookie, resume_token, shared_sec, cfg.clone()));
78 }
79 }
80 }
81 }
82 Err(err) => {
83 if err.kind() == std::io::ErrorKind::TimedOut {
84 tracing::trace!(
85 "timed out to {} with {}s timeout; trying again",
86 cfg.server_addr,
87 timeout_factor
88 );
89 continue;
90 }
91 return Err(err);
92 }
93 }
94 }
95 unimplemented!()
96}
/// Protocol version announced in the `ClientHello` and recorded in the session configuration.
const VERSION: u64 = 3;
98
99fn init_session(
100 cookie: crypt::Cookie,
101 resume_token: Buff,
102 shared_sec: blake3::Hash,
103 cfg: LowlevelClientConfig,
104) -> Session {
105 let (mut session, back) = Session::new(SessionConfig {
106 version: VERSION,
107 gather: cfg.gather.clone(),
108 session_key: shared_sec.as_bytes().to_vec(),
109 role: crate::Role::Client,
110 });
111 let back = Arc::new(back);
112 let uploader: Task<anyhow::Result<()>> = runtime::spawn(async move {
113 let mut workers: Vec<ClientWorker> = (0..cfg.num_shards)
114 .map(|shard_id| {
115 ClientWorker::start(
116 cookie.clone(),
117 resume_token.clone(),
118 back.clone(),
119 shard_id as u8,
120 cfg.clone(),
121 )
122 })
123 .collect();
124 let mut fired_workers: VecDeque<ClientWorker> = VecDeque::new();
125 let mut last_reset = Instant::now();
126 let mut just_respawned = false;
127 for ctr in (0..).cycle() {
128 let to_upload = back.next_outgoing().await?;
129 let random_worker = ctr % workers.len();
130 workers[random_worker].send_upload(to_upload).await;
131 if cfg
132 .reset_interval
133 .map(|dur| last_reset.elapsed() > dur)
134 .unwrap_or_default()
135 {
136 tracing::debug!("reset timer expired!");
137 last_reset = Instant::now();
138 if just_respawned {
139 for worker in workers.iter() {
140 worker.reset_received_count();
141 }
142 just_respawned = false;
143 } else {
144 let worker_packet_count: SmallVec<[usize; 16]> =
146 workers.iter().map(|w| w.get_received_count()).collect();
147 let p_value = uniform_pvalue(&worker_packet_count);
148 tracing::debug!("p-value = {}; {:?}", p_value, worker_packet_count);
149 if p_value < 0.01 {
150 let worst_worker_id = workers
152 .iter()
153 .enumerate()
154 .min_by_key(|(worker_id, worker)| {
155 let count = worker.get_received_count();
156 tracing::debug!("worker {} has {}", worker_id, count);
157 count
158 })
159 .map(|x| x.0)
160 .expect("must have a worst worker");
161 tracing::debug!("replacing worst worker {}", worst_worker_id);
162 let new_worker = ClientWorker::start(
163 cookie.clone(),
164 resume_token.clone(),
165 back.clone(),
166 worst_worker_id as u8,
167 cfg.clone(),
168 );
169 let worst_worker =
170 std::mem::replace(&mut workers[worst_worker_id], new_worker);
171 fired_workers.push_back(worst_worker);
172 if fired_workers.len() > workers.len() {
173 fired_workers.pop_front();
174 }
175 just_respawned = true;
176 }
177 }
178 }
179 }
180 unreachable!()
181 });
182 session.on_drop(move || {
183 drop(uploader);
184 });
185 session
186}
187
188fn uniform_pvalue(vals: &[usize]) -> f64 {
190 if vals.is_empty() {
191 return 0.0;
192 }
193 if vals.iter().all(|a| *a == vals[0]) {
194 return 0.0;
195 }
196 let total_count = vals.iter().sum::<usize>();
197 let min = vals.iter().min().copied().unwrap();
198 let distro = Binomial::new(total_count, 1.0 / (vals.len() as f64));
199 distro.distribution(min as f64)
200}