richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{Engine, engine::general_purpose::STANDARD as base64_engine},
4    human_size::Size,
5    regex::Regex,
6    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
7    serde::{
8        Deserialize,
9        de::{self, DeserializeOwned, Deserializer},
10    },
11    solana_sdk::{pubkey::Pubkey, signature::Signature},
12    std::{
13        collections::HashSet,
14        fmt::Display,
15        fs, io,
16        path::{Path, PathBuf},
17        str::FromStr,
18        sync::atomic::{AtomicU64, Ordering},
19    },
20    thiserror::Error,
21};
22
/// Error returned by [`load_from_file`] when a config file cannot be read or parsed.
#[derive(Debug, Error)]
pub enum ConfigLoadError {
    /// Reading the file into a string failed.
    #[error("failed to read config: {0}")]
    Read(#[from] io::Error),
    /// The file was handed to `serde_yaml` and it rejected the content.
    #[error("failed to parse YAML: {0}")]
    Yaml(#[from] serde_yaml::Error),
    /// The file was handed to `toml` and it rejected the content.
    #[error("failed to parse TOML: {0}")]
    Toml(#[from] toml::de::Error),
    /// The file was handed to the JSON5 parser (the fallback format) and it rejected the content.
    #[error("failed to parse JSON: {0}")]
    Json(#[from] json5::Error),
}
34
35pub fn load_from_file<P, C>(file: P) -> Result<C, ConfigLoadError>
36where
37    P: AsRef<Path>,
38    C: DeserializeOwned,
39{
40    let config = fs::read_to_string(&file)?;
41    match file.as_ref().extension().and_then(|e| e.to_str()) {
42        Some("yml") | Some("yaml") => serde_yaml::from_str(&config).map_err(Into::into),
43        Some("toml") => toml::from_str(&config).map_err(Into::into),
44        _ => json5::from_str(&config).map_err(Into::into),
45    }
46}
47
48#[derive(Debug, Clone, Default, Deserialize)]
49#[serde(deny_unknown_fields, default)]
50pub struct ConfigTokio {
51    /// Number of worker threads in Tokio runtime
52    pub worker_threads: Option<usize>,
53    /// Threads affinity
54    #[serde(deserialize_with = "deserialize_affinity")]
55    pub affinity: Option<Vec<usize>>,
56}
57
58impl ConfigTokio {
59    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
60    where
61        T: AsRef<str> + Send + Sync + 'static,
62    {
63        let mut builder = tokio::runtime::Builder::new_multi_thread();
64        if let Some(worker_threads) = self.worker_threads {
65            builder.worker_threads(worker_threads);
66        }
67        if let Some(cpus) = self.affinity.clone() {
68            builder.on_thread_start(move || {
69                affinity_linux::set_thread_affinity(cpus.iter().copied())
70                    .expect("failed to set affinity")
71            });
72        }
73        builder
74            .thread_name_fn(move || {
75                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
76                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
77                format!("{}{id:02}", thread_name_prefix.as_ref())
78            })
79            .enable_all()
80            .build()
81    }
82}
83
84#[derive(Deserialize)]
85#[serde(untagged)]
86enum ValueNumStr<'a, T> {
87    Num(T),
88    Str(&'a str),
89}
90
91impl<T> ValueNumStr<'_, T>
92where
93    T: FromStr,
94    <T as FromStr>::Err: Display,
95{
96    fn parse(self) -> Result<T, String> {
97        match self {
98            Self::Num(value) => Ok(value),
99            Self::Str(value) => value
100                .replace('_', "")
101                .parse::<T>()
102                .map_err(|x| x.to_string()),
103        }
104    }
105}
106
107pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
108where
109    D: Deserializer<'de>,
110    T: Deserialize<'de> + FromStr,
111    <T as FromStr>::Err: Display,
112{
113    ValueNumStr::deserialize(deserializer)?
114        .parse()
115        .map_err(de::Error::custom)
116}
117
118pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
119where
120    D: Deserializer<'de>,
121    T: Deserialize<'de> + FromStr,
122    <T as FromStr>::Err: Display,
123{
124    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
125        Some(value) => Ok(Some(value.parse().map_err(de::Error::custom)?)),
126        None => Ok(None),
127    }
128}
129
130pub fn deserialize_humansize<'de, D>(deserializer: D) -> Result<u64, D::Error>
131where
132    D: Deserializer<'de>,
133{
134    let size: &str = Deserialize::deserialize(deserializer)?;
135
136    Size::from_str(size)
137        .map(|size| size.to_bytes())
138        .map_err(|error| de::Error::custom(format!("failed to parse size {size:?}: {error}")))
139}
140
141pub fn deserialize_humansize_usize<'de, D>(deserializer: D) -> Result<usize, D::Error>
142where
143    D: Deserializer<'de>,
144{
145    deserialize_humansize(deserializer)?
146        .try_into()
147        .map_err(|_| de::Error::custom("size value exceeds usize maximum"))
148}
149
150#[derive(Debug, Error)]
151enum DecodeXTokenError {
152    #[error(transparent)]
153    Base64(#[from] base64::DecodeError),
154    #[error(transparent)]
155    Base58(#[from] bs58::decode::Error),
156}
157
158fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
159    Ok(match &x_token[0..7] {
160        "base64:" => base64_engine.decode(x_token)?,
161        "base58:" => bs58::decode(x_token).into_vec()?,
162        _ => x_token.as_bytes().to_vec(),
163    })
164}
165
166pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
167where
168    D: Deserializer<'de>,
169{
170    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
171    x_token
172        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
173        .transpose()
174}
175
176pub fn deserialize_x_tokens_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
177where
178    D: Deserializer<'de>,
179{
180    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
181        vec.into_iter()
182            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
183            .collect::<Result<_, _>>()
184    })
185}
186
187pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
188where
189    D: Deserializer<'de>,
190{
191    Vec::<&str>::deserialize(deserializer)?
192        .into_iter()
193        .map(|value| {
194            pubkey_decode(value)
195                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
196        })
197        .collect::<Result<_, _>>()
198}
199
200pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
201where
202    D: Deserializer<'de>,
203{
204    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
205}
206
207pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
208where
209    D: Deserializer<'de>,
210{
211    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
212    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
213        .transpose()
214}
215
216#[derive(Debug, Deserialize)]
217#[serde(deny_unknown_fields, untagged)]
218enum RustlsServerConfigSignedSelfSigned<'a> {
219    Signed { cert: &'a str, key: &'a str },
220    SelfSigned { self_signed_alt_names: Vec<String> },
221}
222
223impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
224    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
225    where
226        D: Deserializer<'a>,
227    {
228        let (certs, key) = match self {
229            Self::Signed { cert, key } => {
230                let cert_path = PathBuf::from(cert);
231                let cert_bytes = fs::read(&cert_path).map_err(|error| {
232                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
233                })?;
234                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
235                    vec![CertificateDer::from(cert_bytes)]
236                } else {
237                    rustls_pemfile::certs(&mut &*cert_bytes)
238                        .collect::<Result<_, _>>()
239                        .map_err(|error| {
240                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
241                        })?
242                };
243
244                let key_path = PathBuf::from(key);
245                let key_bytes = fs::read(&key_path).map_err(|error| {
246                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
247                })?;
248                let key = if key_path.extension().is_some_and(|x| x == "der") {
249                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
250                } else {
251                    rustls_pemfile::private_key(&mut &*key_bytes)
252                        .map_err(|error| {
253                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
254                        })?
255                        .ok_or_else(|| de::Error::custom("no private keys found"))?
256                };
257
258                (cert_chain, key)
259            }
260            Self::SelfSigned {
261                self_signed_alt_names,
262            } => {
263                let cert =
264                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
265                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
266                    })?;
267                let cert_der = CertificateDer::from(cert.cert);
268                let priv_key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
269                (vec![cert_der], priv_key.into())
270            }
271        };
272
273        rustls::ServerConfig::builder()
274            .with_no_client_auth()
275            .with_single_cert(certs, key)
276            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
277    }
278}
279
280pub fn deserialize_maybe_rustls_server_config<'de, D>(
281    deserializer: D,
282) -> Result<Option<rustls::ServerConfig>, D::Error>
283where
284    D: Deserializer<'de>,
285{
286    let config: Option<RustlsServerConfigSignedSelfSigned> =
287        Deserialize::deserialize(deserializer)?;
288    if let Some(config) = config {
289        config.parse::<D>().map(Some)
290    } else {
291        Ok(None)
292    }
293}
294
295pub fn deserialize_rustls_server_config<'de, D>(
296    deserializer: D,
297) -> Result<rustls::ServerConfig, D::Error>
298where
299    D: Deserializer<'de>,
300{
301    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
302    config.parse::<D>()
303}
304
305pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
306where
307    D: Deserializer<'de>,
308{
309    match Option::<&str>::deserialize(deserializer)? {
310        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
311        None => Ok(None),
312    }
313}
314
315pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
316    let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
317    let mut set = HashSet::new();
318    for cpulist in taskset.split(',') {
319        let Some(caps) = re.captures(cpulist) else {
320            return Err(format!("invalid cpulist: {cpulist}"));
321        };
322
323        let start = caps
324            .get(1)
325            .and_then(|m| m.as_str().parse().ok())
326            .expect("valid regex");
327        let end = caps
328            .get(2)
329            .and_then(|m| m.as_str().parse().ok())
330            .unwrap_or(start);
331        let step = caps
332            .get(3)
333            .and_then(|m| m.as_str().parse().ok())
334            .unwrap_or(1);
335
336        for cpu in (start..=end).step_by(step) {
337            set.insert(cpu);
338        }
339    }
340
341    let mut vec = set.into_iter().collect::<Vec<usize>>();
342    vec.sort();
343
344    if !vec.is_empty() {
345        if let Some(cores) = affinity_linux::get_thread_affinity()
346            .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
347        {
348            let mut cores = cores.into_iter().collect::<Vec<_>>();
349            cores.sort();
350
351            for core in vec.iter_mut() {
352                if let Some(actual_core) = cores.get(*core) {
353                    *core = *actual_core;
354                } else {
355                    return Err(format!(
356                        "we don't have core {core}, available cores: {:?}",
357                        (0..cores.len()).collect::<Vec<_>>()
358                    ));
359                }
360            }
361        }
362    }
363
364    Ok(vec)
365}