richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4    human_size::Size,
5    regex::Regex,
6    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
7    serde::{
8        de::{self, Deserializer},
9        Deserialize,
10    },
11    solana_sdk::{pubkey::Pubkey, signature::Signature},
12    std::{
13        collections::HashSet,
14        fmt::Display,
15        fs, io,
16        path::PathBuf,
17        str::FromStr,
18        sync::atomic::{AtomicU64, Ordering},
19    },
20    thiserror::Error,
21};
22
23#[derive(Debug, Clone, Default, Deserialize)]
24#[serde(deny_unknown_fields, default)]
25pub struct ConfigTokio {
26    /// Number of worker threads in Tokio runtime
27    pub worker_threads: Option<usize>,
28    /// Threads affinity
29    #[serde(deserialize_with = "deserialize_affinity")]
30    pub affinity: Option<Vec<usize>>,
31}
32
33impl ConfigTokio {
34    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
35    where
36        T: AsRef<str> + Send + Sync + 'static,
37    {
38        let mut builder = tokio::runtime::Builder::new_multi_thread();
39        if let Some(worker_threads) = self.worker_threads {
40            builder.worker_threads(worker_threads);
41        }
42        if let Some(cpus) = self.affinity.clone() {
43            builder.on_thread_start(move || {
44                affinity_linux::set_thread_affinity(cpus.iter().copied())
45                    .expect("failed to set affinity")
46            });
47        }
48        builder
49            .thread_name_fn(move || {
50                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
51                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
52                format!("{}{id:02}", thread_name_prefix.as_ref())
53            })
54            .enable_all()
55            .build()
56    }
57}
58
59#[derive(Deserialize)]
60#[serde(untagged)]
61enum ValueNumStr<'a, T> {
62    Num(T),
63    Str(&'a str),
64}
65
66impl<T> ValueNumStr<'_, T>
67where
68    T: FromStr,
69    <T as FromStr>::Err: Display,
70{
71    fn parse(self) -> Result<T, String> {
72        match self {
73            Self::Num(value) => Ok(value),
74            Self::Str(value) => value
75                .replace('_', "")
76                .parse::<T>()
77                .map_err(|x| x.to_string()),
78        }
79    }
80}
81
82pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
83where
84    D: Deserializer<'de>,
85    T: Deserialize<'de> + FromStr,
86    <T as FromStr>::Err: Display,
87{
88    ValueNumStr::deserialize(deserializer)?
89        .parse()
90        .map_err(de::Error::custom)
91}
92
93pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
94where
95    D: Deserializer<'de>,
96    T: Deserialize<'de> + FromStr,
97    <T as FromStr>::Err: Display,
98{
99    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
100        Some(value) => Ok(Some(value.parse().map_err(de::Error::custom)?)),
101        None => Ok(None),
102    }
103}
104
105pub fn deserialize_humansize<'de, D>(deserializer: D) -> Result<u64, D::Error>
106where
107    D: Deserializer<'de>,
108{
109    let size: &str = Deserialize::deserialize(deserializer)?;
110
111    Size::from_str(size)
112        .map(|size| size.to_bytes())
113        .map_err(|error| de::Error::custom(format!("failed to parse size {size:?}: {error}")))
114}
115
116pub fn deserialize_humansize_usize<'de, D>(deserializer: D) -> Result<usize, D::Error>
117where
118    D: Deserializer<'de>,
119{
120    deserialize_humansize(deserializer)?
121        .try_into()
122        .map_err(|_| de::Error::custom("size value exceeds usize maximum"))
123}
124
125#[derive(Debug, Error)]
126enum DecodeXTokenError {
127    #[error(transparent)]
128    Base64(#[from] base64::DecodeError),
129    #[error(transparent)]
130    Base58(#[from] bs58::decode::Error),
131}
132
133fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
134    Ok(match &x_token[0..7] {
135        "base64:" => base64_engine.decode(x_token)?,
136        "base58:" => bs58::decode(x_token).into_vec()?,
137        _ => x_token.as_bytes().to_vec(),
138    })
139}
140
141pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
142where
143    D: Deserializer<'de>,
144{
145    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
146    x_token
147        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
148        .transpose()
149}
150
151pub fn deserialize_x_tokens_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
152where
153    D: Deserializer<'de>,
154{
155    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
156        vec.into_iter()
157            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
158            .collect::<Result<_, _>>()
159    })
160}
161
162pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
163where
164    D: Deserializer<'de>,
165{
166    Vec::<&str>::deserialize(deserializer)?
167        .into_iter()
168        .map(|value| {
169            pubkey_decode(value)
170                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
171        })
172        .collect::<Result<_, _>>()
173}
174
175pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
176where
177    D: Deserializer<'de>,
178{
179    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
180}
181
182pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
183where
184    D: Deserializer<'de>,
185{
186    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
187    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
188        .transpose()
189}
190
191#[derive(Debug, Deserialize)]
192#[serde(deny_unknown_fields, untagged)]
193enum RustlsServerConfigSignedSelfSigned<'a> {
194    Signed { cert: &'a str, key: &'a str },
195    SelfSigned { self_signed_alt_names: Vec<String> },
196}
197
198impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
199    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
200    where
201        D: Deserializer<'a>,
202    {
203        let (certs, key) = match self {
204            Self::Signed { cert, key } => {
205                let cert_path = PathBuf::from(cert);
206                let cert_bytes = fs::read(&cert_path).map_err(|error| {
207                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
208                })?;
209                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
210                    vec![CertificateDer::from(cert_bytes)]
211                } else {
212                    rustls_pemfile::certs(&mut &*cert_bytes)
213                        .collect::<Result<_, _>>()
214                        .map_err(|error| {
215                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
216                        })?
217                };
218
219                let key_path = PathBuf::from(key);
220                let key_bytes = fs::read(&key_path).map_err(|error| {
221                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
222                })?;
223                let key = if key_path.extension().is_some_and(|x| x == "der") {
224                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
225                } else {
226                    rustls_pemfile::private_key(&mut &*key_bytes)
227                        .map_err(|error| {
228                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
229                        })?
230                        .ok_or_else(|| de::Error::custom("no private keys found"))?
231                };
232
233                (cert_chain, key)
234            }
235            Self::SelfSigned {
236                self_signed_alt_names,
237            } => {
238                let cert =
239                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
240                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
241                    })?;
242                let cert_der = CertificateDer::from(cert.cert);
243                let priv_key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
244                (vec![cert_der], priv_key.into())
245            }
246        };
247
248        rustls::ServerConfig::builder()
249            .with_no_client_auth()
250            .with_single_cert(certs, key)
251            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
252    }
253}
254
255pub fn deserialize_maybe_rustls_server_config<'de, D>(
256    deserializer: D,
257) -> Result<Option<rustls::ServerConfig>, D::Error>
258where
259    D: Deserializer<'de>,
260{
261    let config: Option<RustlsServerConfigSignedSelfSigned> =
262        Deserialize::deserialize(deserializer)?;
263    if let Some(config) = config {
264        config.parse::<D>().map(Some)
265    } else {
266        Ok(None)
267    }
268}
269
270pub fn deserialize_rustls_server_config<'de, D>(
271    deserializer: D,
272) -> Result<rustls::ServerConfig, D::Error>
273where
274    D: Deserializer<'de>,
275{
276    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
277    config.parse::<D>()
278}
279
280pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
281where
282    D: Deserializer<'de>,
283{
284    match Option::<&str>::deserialize(deserializer)? {
285        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
286        None => Ok(None),
287    }
288}
289
290pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
291    let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
292    let mut set = HashSet::new();
293    for cpulist in taskset.split(',') {
294        let Some(caps) = re.captures(cpulist) else {
295            return Err(format!("invalid cpulist: {cpulist}"));
296        };
297
298        let start = caps
299            .get(1)
300            .and_then(|m| m.as_str().parse().ok())
301            .expect("valid regex");
302        let end = caps
303            .get(2)
304            .and_then(|m| m.as_str().parse().ok())
305            .unwrap_or(start);
306        let step = caps
307            .get(3)
308            .and_then(|m| m.as_str().parse().ok())
309            .unwrap_or(1);
310
311        for cpu in (start..=end).step_by(step) {
312            set.insert(cpu);
313        }
314    }
315
316    let mut vec = set.into_iter().collect::<Vec<usize>>();
317    vec.sort();
318
319    if !vec.is_empty() {
320        if let Some(cores) = affinity_linux::get_thread_affinity()
321            .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
322        {
323            let mut cores = cores.into_iter().collect::<Vec<_>>();
324            cores.sort();
325
326            for core in vec.iter_mut() {
327                if let Some(actual_core) = cores.get(*core) {
328                    *core = *actual_core;
329                } else {
330                    return Err(format!(
331                        "we don't have core {core}, available cores: {:?}",
332                        (0..cores.len()).collect::<Vec<_>>()
333                    ));
334                }
335            }
336        }
337    }
338
339    Ok(vec)
340}