1use {
2 crate::five8::{pubkey_decode, signature_decode},
3 base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4 human_size::Size,
5 regex::Regex,
6 rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
7 serde::{
8 de::{self, Deserializer},
9 Deserialize,
10 },
11 solana_sdk::{pubkey::Pubkey, signature::Signature},
12 std::{
13 collections::HashSet,
14 fmt::Display,
15 fs, io,
16 path::PathBuf,
17 str::FromStr,
18 sync::atomic::{AtomicU64, Ordering},
19 },
20 thiserror::Error,
21};
22
23#[derive(Debug, Clone, Default, Deserialize)]
24#[serde(deny_unknown_fields, default)]
25pub struct ConfigTokio {
26 pub worker_threads: Option<usize>,
28 #[serde(deserialize_with = "deserialize_affinity")]
30 pub affinity: Option<Vec<usize>>,
31}
32
33impl ConfigTokio {
34 pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
35 where
36 T: AsRef<str> + Send + Sync + 'static,
37 {
38 let mut builder = tokio::runtime::Builder::new_multi_thread();
39 if let Some(worker_threads) = self.worker_threads {
40 builder.worker_threads(worker_threads);
41 }
42 if let Some(cpus) = self.affinity.clone() {
43 builder.on_thread_start(move || {
44 affinity_linux::set_thread_affinity(cpus.iter().copied())
45 .expect("failed to set affinity")
46 });
47 }
48 builder
49 .thread_name_fn(move || {
50 static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
51 let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
52 format!("{}{id:02}", thread_name_prefix.as_ref())
53 })
54 .enable_all()
55 .build()
56 }
57}
58
/// A config value written either natively as `T` (`Num`) or as a string (`Str`) that is
/// parsed into `T` later (see `ValueNumStr::parse`).
///
/// With `untagged`, serde tries the variants in declaration order: input that deserializes
/// as `T` becomes `Num`, and only otherwise is a string accepted as `Str`.
// NOTE(review): `Str` borrows from the input, so deserializers that can only hand out owned
// strings (e.g. escaped JSON strings) are presumably rejected here — confirm if that matters.
#[derive(Deserialize)]
#[serde(untagged)]
enum ValueNumStr<'a, T> {
    Num(T),
    Str(&'a str),
}
65
66impl<T> ValueNumStr<'_, T>
67where
68 T: FromStr,
69 <T as FromStr>::Err: Display,
70{
71 fn parse(self) -> Result<T, String> {
72 match self {
73 Self::Num(value) => Ok(value),
74 Self::Str(value) => value
75 .replace('_', "")
76 .parse::<T>()
77 .map_err(|x| x.to_string()),
78 }
79 }
80}
81
82pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
83where
84 D: Deserializer<'de>,
85 T: Deserialize<'de> + FromStr,
86 <T as FromStr>::Err: Display,
87{
88 ValueNumStr::deserialize(deserializer)?
89 .parse()
90 .map_err(de::Error::custom)
91}
92
93pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
94where
95 D: Deserializer<'de>,
96 T: Deserialize<'de> + FromStr,
97 <T as FromStr>::Err: Display,
98{
99 match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
100 Some(value) => Ok(Some(value.parse().map_err(de::Error::custom)?)),
101 None => Ok(None),
102 }
103}
104
105pub fn deserialize_humansize<'de, D>(deserializer: D) -> Result<u64, D::Error>
106where
107 D: Deserializer<'de>,
108{
109 let size: &str = Deserialize::deserialize(deserializer)?;
110
111 Size::from_str(size)
112 .map(|size| size.to_bytes())
113 .map_err(|error| de::Error::custom(format!("failed to parse size {size:?}: {error}")))
114}
115
116pub fn deserialize_humansize_usize<'de, D>(deserializer: D) -> Result<usize, D::Error>
117where
118 D: Deserializer<'de>,
119{
120 deserialize_humansize(deserializer)?
121 .try_into()
122 .map_err(|_| de::Error::custom("size value exceeds usize maximum"))
123}
124
125#[derive(Debug, Error)]
126enum DecodeXTokenError {
127 #[error(transparent)]
128 Base64(#[from] base64::DecodeError),
129 #[error(transparent)]
130 Base58(#[from] bs58::decode::Error),
131}
132
133fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
134 Ok(match &x_token[0..7] {
135 "base64:" => base64_engine.decode(x_token)?,
136 "base58:" => bs58::decode(x_token).into_vec()?,
137 _ => x_token.as_bytes().to_vec(),
138 })
139}
140
141pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
142where
143 D: Deserializer<'de>,
144{
145 let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
146 x_token
147 .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
148 .transpose()
149}
150
151pub fn deserialize_x_tokens_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
152where
153 D: Deserializer<'de>,
154{
155 Vec::<&str>::deserialize(deserializer).and_then(|vec| {
156 vec.into_iter()
157 .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
158 .collect::<Result<_, _>>()
159 })
160}
161
162pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
163where
164 D: Deserializer<'de>,
165{
166 Vec::<&str>::deserialize(deserializer)?
167 .into_iter()
168 .map(|value| {
169 pubkey_decode(value)
170 .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
171 })
172 .collect::<Result<_, _>>()
173}
174
175pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
176where
177 D: Deserializer<'de>,
178{
179 deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
180}
181
182pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
183where
184 D: Deserializer<'de>,
185{
186 let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
187 sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
188 .transpose()
189}
190
/// TLS settings for a rustls server, in one of two shapes.
///
/// The enum is `untagged`, so the variant is chosen by which keys are present:
/// `cert` + `key` selects `Signed`, while `self_signed_alt_names` selects `SelfSigned`.
/// Unknown keys are rejected.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, untagged)]
enum RustlsServerConfigSignedSelfSigned<'a> {
    /// Paths to a certificate file and a private-key file. A `.der` extension means raw DER;
    /// any other file is parsed as PEM.
    Signed { cert: &'a str, key: &'a str },
    /// Generate a self-signed certificate for these subject alternative names.
    SelfSigned { self_signed_alt_names: Vec<String> },
}
197
198impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
199 fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
200 where
201 D: Deserializer<'a>,
202 {
203 let (certs, key) = match self {
204 Self::Signed { cert, key } => {
205 let cert_path = PathBuf::from(cert);
206 let cert_bytes = fs::read(&cert_path).map_err(|error| {
207 de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
208 })?;
209 let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
210 vec![CertificateDer::from(cert_bytes)]
211 } else {
212 rustls_pemfile::certs(&mut &*cert_bytes)
213 .collect::<Result<_, _>>()
214 .map_err(|error| {
215 de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
216 })?
217 };
218
219 let key_path = PathBuf::from(key);
220 let key_bytes = fs::read(&key_path).map_err(|error| {
221 de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
222 })?;
223 let key = if key_path.extension().is_some_and(|x| x == "der") {
224 PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
225 } else {
226 rustls_pemfile::private_key(&mut &*key_bytes)
227 .map_err(|error| {
228 de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
229 })?
230 .ok_or_else(|| de::Error::custom("no private keys found"))?
231 };
232
233 (cert_chain, key)
234 }
235 Self::SelfSigned {
236 self_signed_alt_names,
237 } => {
238 let cert =
239 rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
240 de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
241 })?;
242 let cert_der = CertificateDer::from(cert.cert);
243 let priv_key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
244 (vec![cert_der], priv_key.into())
245 }
246 };
247
248 rustls::ServerConfig::builder()
249 .with_no_client_auth()
250 .with_single_cert(certs, key)
251 .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
252 }
253}
254
255pub fn deserialize_maybe_rustls_server_config<'de, D>(
256 deserializer: D,
257) -> Result<Option<rustls::ServerConfig>, D::Error>
258where
259 D: Deserializer<'de>,
260{
261 let config: Option<RustlsServerConfigSignedSelfSigned> =
262 Deserialize::deserialize(deserializer)?;
263 if let Some(config) = config {
264 config.parse::<D>().map(Some)
265 } else {
266 Ok(None)
267 }
268}
269
270pub fn deserialize_rustls_server_config<'de, D>(
271 deserializer: D,
272) -> Result<rustls::ServerConfig, D::Error>
273where
274 D: Deserializer<'de>,
275{
276 let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
277 config.parse::<D>()
278}
279
280pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
281where
282 D: Deserializer<'de>,
283{
284 match Option::<&str>::deserialize(deserializer)? {
285 Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
286 None => Ok(None),
287 }
288}
289
290pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
291 let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
292 let mut set = HashSet::new();
293 for cpulist in taskset.split(',') {
294 let Some(caps) = re.captures(cpulist) else {
295 return Err(format!("invalid cpulist: {cpulist}"));
296 };
297
298 let start = caps
299 .get(1)
300 .and_then(|m| m.as_str().parse().ok())
301 .expect("valid regex");
302 let end = caps
303 .get(2)
304 .and_then(|m| m.as_str().parse().ok())
305 .unwrap_or(start);
306 let step = caps
307 .get(3)
308 .and_then(|m| m.as_str().parse().ok())
309 .unwrap_or(1);
310
311 for cpu in (start..=end).step_by(step) {
312 set.insert(cpu);
313 }
314 }
315
316 let mut vec = set.into_iter().collect::<Vec<usize>>();
317 vec.sort();
318
319 if !vec.is_empty() {
320 if let Some(cores) = affinity_linux::get_thread_affinity()
321 .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
322 {
323 let mut cores = cores.into_iter().collect::<Vec<_>>();
324 cores.sort();
325
326 for core in vec.iter_mut() {
327 if let Some(actual_core) = cores.get(*core) {
328 *core = *actual_core;
329 } else {
330 return Err(format!(
331 "we don't have core {core}, available cores: {:?}",
332 (0..cores.len()).collect::<Vec<_>>()
333 ));
334 }
335 }
336 }
337 }
338
339 Ok(vec)
340}