1use {
2 crate::five8::{pubkey_decode, signature_decode},
3 base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4 regex::Regex,
5 rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
6 serde::{
7 de::{self, Deserializer},
8 Deserialize,
9 },
10 solana_sdk::{pubkey::Pubkey, signature::Signature},
11 std::{
12 collections::HashSet,
13 fmt::Display,
14 fs, io,
15 path::PathBuf,
16 str::FromStr,
17 sync::atomic::{AtomicU64, Ordering},
18 },
19 thiserror::Error,
20};
21
/// Tokio runtime settings read from configuration.
///
/// Unknown fields are rejected, and any field missing from the input falls back to the
/// `Default` value (`None`).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigTokio {
    /// Number of runtime worker threads; `None` leaves Tokio's own default in place.
    pub worker_threads: Option<usize>,
    /// CPUs every runtime thread is pinned to, written as a taskset-style list (e.g. `"0-3,8"`)
    /// and converted by `deserialize_affinity`; `None` disables pinning.
    #[serde(deserialize_with = "deserialize_affinity")]
    pub affinity: Option<Vec<usize>>,
}
31
32impl ConfigTokio {
33 pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
34 where
35 T: AsRef<str> + Send + Sync + 'static,
36 {
37 let mut builder = tokio::runtime::Builder::new_multi_thread();
38 if let Some(worker_threads) = self.worker_threads {
39 builder.worker_threads(worker_threads);
40 }
41 if let Some(cpus) = self.affinity.clone() {
42 builder.on_thread_start(move || {
43 affinity_linux::set_thread_affinity(cpus.iter().copied())
44 .expect("failed to set affinity")
45 });
46 }
47 builder
48 .thread_name_fn(move || {
49 static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
50 let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
51 format!("{}{id:02}", thread_name_prefix.as_ref())
52 })
53 .enable_all()
54 .build()
55 }
56}
57
58#[derive(Deserialize)]
59#[serde(untagged)]
60enum ValueNumStr<'a, T> {
61 Num(T),
62 Str(&'a str),
63}
64
65pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
66where
67 D: Deserializer<'de>,
68 T: Deserialize<'de> + FromStr,
69 <T as FromStr>::Err: Display,
70{
71 match ValueNumStr::deserialize(deserializer)? {
72 ValueNumStr::Num(value) => Ok(value),
73 ValueNumStr::Str(value) => value
74 .replace('_', "")
75 .parse::<T>()
76 .map_err(de::Error::custom),
77 }
78}
79
80pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
81where
82 D: Deserializer<'de>,
83 T: Deserialize<'de> + FromStr,
84 <T as FromStr>::Err: Display,
85{
86 match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
87 Some(ValueNumStr::Num(value)) => Ok(Some(value)),
88 Some(ValueNumStr::Str(value)) => value
89 .replace('_', "")
90 .parse::<T>()
91 .map_err(de::Error::custom)
92 .map(Some),
93 None => Ok(None),
94 }
95}
96
/// Failure while decoding an x-token value in `decode_x_token`.
#[derive(Debug, Error)]
enum DecodeXTokenError {
    /// Base64 decoding of a `base64:`-prefixed token failed.
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    /// Base58 decoding of a `base58:`-prefixed token failed.
    #[error(transparent)]
    Base58(#[from] bs58::decode::Error),
}
104
105fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
106 Ok(match &x_token[0..7] {
107 "base64:" => base64_engine.decode(x_token)?,
108 "base58:" => bs58::decode(x_token).into_vec()?,
109 _ => x_token.as_bytes().to_vec(),
110 })
111}
112
113pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
114where
115 D: Deserializer<'de>,
116{
117 let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
118 x_token
119 .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
120 .transpose()
121}
122
123pub fn deserialize_x_token_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
124where
125 D: Deserializer<'de>,
126{
127 Vec::<&str>::deserialize(deserializer).and_then(|vec| {
128 vec.into_iter()
129 .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
130 .collect::<Result<_, _>>()
131 })
132}
133
134pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
135where
136 D: Deserializer<'de>,
137{
138 Vec::<&str>::deserialize(deserializer)?
139 .into_iter()
140 .map(|value| {
141 pubkey_decode(value)
142 .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
143 })
144 .collect::<Result<_, _>>()
145}
146
147pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
148where
149 D: Deserializer<'de>,
150{
151 deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
152}
153
154pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
155where
156 D: Deserializer<'de>,
157{
158 let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
159 sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
160 .transpose()
161}
162
163#[derive(Debug, Deserialize)]
164#[serde(deny_unknown_fields, untagged)]
165enum RustlsServerConfigSignedSelfSigned<'a> {
166 Signed { cert: &'a str, key: &'a str },
167 SelfSigned { self_signed_alt_names: Vec<String> },
168}
169
170impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
171 fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
172 where
173 D: Deserializer<'a>,
174 {
175 let (certs, key) = match self {
176 Self::Signed { cert, key } => {
177 let cert_path = PathBuf::from(cert);
178 let cert_bytes = fs::read(&cert_path).map_err(|error| {
179 de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
180 })?;
181 let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
182 vec![CertificateDer::from(cert_bytes)]
183 } else {
184 rustls_pemfile::certs(&mut &*cert_bytes)
185 .collect::<Result<_, _>>()
186 .map_err(|error| {
187 de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
188 })?
189 };
190
191 let key_path = PathBuf::from(key);
192 let key_bytes = fs::read(&key_path).map_err(|error| {
193 de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
194 })?;
195 let key = if key_path.extension().is_some_and(|x| x == "der") {
196 PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
197 } else {
198 rustls_pemfile::private_key(&mut &*key_bytes)
199 .map_err(|error| {
200 de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
201 })?
202 .ok_or_else(|| de::Error::custom("no private keys found"))?
203 };
204
205 (cert_chain, key)
206 }
207 Self::SelfSigned {
208 self_signed_alt_names,
209 } => {
210 let cert =
211 rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
212 de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
213 })?;
214 let cert_der = CertificateDer::from(cert.cert);
215 let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
216 (vec![cert_der], priv_key.into())
217 }
218 };
219
220 rustls::ServerConfig::builder()
221 .with_no_client_auth()
222 .with_single_cert(certs, key)
223 .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
224 }
225}
226
227pub fn deserialize_maybe_rustls_server_config<'de, D>(
228 deserializer: D,
229) -> Result<Option<rustls::ServerConfig>, D::Error>
230where
231 D: Deserializer<'de>,
232{
233 let config: Option<RustlsServerConfigSignedSelfSigned> =
234 Deserialize::deserialize(deserializer)?;
235 if let Some(config) = config {
236 config.parse::<D>().map(Some)
237 } else {
238 Ok(None)
239 }
240}
241
242pub fn deserialize_rustls_server_config<'de, D>(
243 deserializer: D,
244) -> Result<rustls::ServerConfig, D::Error>
245where
246 D: Deserializer<'de>,
247{
248 let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
249 config.parse::<D>()
250}
251
252pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
253where
254 D: Deserializer<'de>,
255{
256 match Option::<&str>::deserialize(deserializer)? {
257 Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
258 None => Ok(None),
259 }
260}
261
262pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
263 let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
264 let mut set = HashSet::new();
265 for cpulist in taskset.split(',') {
266 let Some(caps) = re.captures(cpulist) else {
267 return Err(format!("invalid cpulist: {cpulist}"));
268 };
269
270 let start = caps
271 .get(1)
272 .and_then(|m| m.as_str().parse().ok())
273 .expect("valid regex");
274 let end = caps
275 .get(2)
276 .and_then(|m| m.as_str().parse().ok())
277 .unwrap_or(start);
278 let step = caps
279 .get(3)
280 .and_then(|m| m.as_str().parse().ok())
281 .unwrap_or(1);
282
283 for cpu in (start..=end).step_by(step) {
284 set.insert(cpu);
285 }
286 }
287
288 let mut vec = set.into_iter().collect::<Vec<usize>>();
289 vec.sort();
290
291 if !vec.is_empty() {
292 if let Some(cores) = affinity_linux::get_thread_affinity()
293 .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
294 {
295 let mut cores = cores.into_iter().collect::<Vec<_>>();
296 cores.sort();
297
298 for core in vec.iter_mut() {
299 if let Some(actual_core) = cores.get(*core) {
300 *core = *actual_core;
301 } else {
302 return Err(format!(
303 "we don't have core {core}, available cores: {:?}",
304 (0..cores.len()).collect::<Vec<_>>()
305 ));
306 }
307 }
308 }
309 }
310
311 Ok(vec)
312}