Skip to main content

richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{Engine, engine::general_purpose::STANDARD as base64_engine},
4    human_size::Size,
5    regex::Regex,
6    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
7    serde::{
8        Deserialize,
9        de::{self, DeserializeOwned, Deserializer},
10    },
11    solana_pubkey::Pubkey,
12    solana_signature::Signature,
13    std::{
14        collections::HashSet,
15        fmt::Display,
16        fs, io,
17        path::{Path, PathBuf},
18        str::FromStr,
19        sync::atomic::{AtomicU64, Ordering},
20    },
21    thiserror::Error,
22};
23
/// Error returned by [`load_from_file_sync`] and [`load_from_file`].
///
/// Which parse variant can occur depends on the config file extension.
#[derive(Debug, Error)]
pub enum ConfigLoadError {
    /// The config file could not be read from disk.
    #[error("failed to read config: {0}")]
    Read(#[from] io::Error),
    /// `.yml` / `.yaml` file that failed to deserialize as YAML.
    #[error("failed to parse YAML: {0}")]
    Yaml(#[from] serde_yaml::Error),
    /// `.toml` file that failed to deserialize as TOML.
    #[error("failed to parse TOML: {0}")]
    Toml(#[from] toml::de::Error),
    /// Any other (or missing) extension: content is parsed as JSON5.
    #[error("failed to parse JSON: {0}")]
    Json(#[from] json5::Error),
}
35
36pub fn load_from_file_sync<P, C>(file: P) -> Result<C, ConfigLoadError>
37where
38    P: AsRef<Path>,
39    C: DeserializeOwned,
40{
41    let config = fs::read_to_string(&file)?;
42    parse_config(file.as_ref(), &config)
43}
44
45pub async fn load_from_file<P, C>(file: P) -> Result<C, ConfigLoadError>
46where
47    P: AsRef<Path>,
48    C: DeserializeOwned,
49{
50    let config = tokio::fs::read_to_string(&file).await?;
51    parse_config(file.as_ref(), &config)
52}
53
54fn parse_config<C>(file: &Path, config: &str) -> Result<C, ConfigLoadError>
55where
56    C: DeserializeOwned,
57{
58    match file.extension().and_then(|e| e.to_str()) {
59        Some("yml") | Some("yaml") => serde_yaml::from_str(config).map_err(Into::into),
60        Some("toml") => toml::from_str(config).map_err(Into::into),
61        _ => json5::from_str(config).map_err(Into::into),
62    }
63}
64
/// Tokio runtime settings; every field falls back to its default when omitted.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ConfigTokio {
    /// Number of worker threads in Tokio runtime; `None` keeps Tokio's own default.
    pub worker_threads: Option<usize>,
    /// Threads affinity, written in config as a `taskset`-style CPU list string
    /// (e.g. `"0-3,8"` or `"0-10:2"`) and parsed by [`parse_taskset`].
    #[serde(deserialize_with = "deserialize_affinity")]
    pub affinity: Option<Vec<usize>>,
}
74
75impl ConfigTokio {
76    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
77    where
78        T: AsRef<str> + Send + Sync + 'static,
79    {
80        let mut builder = tokio::runtime::Builder::new_multi_thread();
81        if let Some(worker_threads) = self.worker_threads {
82            builder.worker_threads(worker_threads);
83        }
84        if let Some(cpus) = self.affinity.clone() {
85            builder.on_thread_start(move || {
86                affinity_linux::set_thread_affinity(cpus.iter().copied())
87                    .expect("failed to set affinity")
88            });
89        }
90        builder
91            .thread_name_fn(move || {
92                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
93                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
94                format!("{}{id:02}", thread_name_prefix.as_ref())
95            })
96            .enable_all()
97            .build()
98    }
99}
100
/// Config value that may be written either natively as `T` or as a string.
///
/// `#[serde(untagged)]` tries variants in declaration order: whatever `T` itself
/// accepts (typically a number) becomes `Num`; otherwise a string is captured
/// as `Str` for later parsing.
/// NOTE(review): `&'a str` can only be deserialized from borrowed input; a
/// deserializer that yields owned strings will not match `Str` — confirm for
/// every config format in use.
#[derive(Deserialize)]
#[serde(untagged)]
enum ValueNumStr<'a, T> {
    /// Value already deserialized as `T`.
    Num(T),
    /// String form; underscores are stripped before parsing into `T`.
    Str(&'a str),
}
107
108impl<T> ValueNumStr<'_, T>
109where
110    T: FromStr,
111    <T as FromStr>::Err: Display,
112{
113    fn parse(self) -> Result<T, String> {
114        match self {
115            Self::Num(value) => Ok(value),
116            Self::Str(value) => value
117                .replace('_', "")
118                .parse::<T>()
119                .map_err(|x| x.to_string()),
120        }
121    }
122}
123
124pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
125where
126    D: Deserializer<'de>,
127    T: Deserialize<'de> + FromStr,
128    <T as FromStr>::Err: Display,
129{
130    ValueNumStr::deserialize(deserializer)?
131        .parse()
132        .map_err(de::Error::custom)
133}
134
135pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
136where
137    D: Deserializer<'de>,
138    T: Deserialize<'de> + FromStr,
139    <T as FromStr>::Err: Display,
140{
141    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
142        Some(value) => Ok(Some(value.parse().map_err(de::Error::custom)?)),
143        None => Ok(None),
144    }
145}
146
147pub fn deserialize_humansize<'de, D>(deserializer: D) -> Result<u64, D::Error>
148where
149    D: Deserializer<'de>,
150{
151    let size: &str = Deserialize::deserialize(deserializer)?;
152
153    Size::from_str(size)
154        .map(|size| size.to_bytes())
155        .map_err(|error| de::Error::custom(format!("failed to parse size {size:?}: {error}")))
156}
157
158pub fn deserialize_humansize_usize<'de, D>(deserializer: D) -> Result<usize, D::Error>
159where
160    D: Deserializer<'de>,
161{
162    deserialize_humansize(deserializer)?
163        .try_into()
164        .map_err(|_| de::Error::custom("size value exceeds usize maximum"))
165}
166
/// Error produced by `decode_x_token` when a prefixed token fails to decode.
#[derive(Debug, Error)]
enum DecodeXTokenError {
    /// A `base64:`-prefixed token could not be decoded as Base64.
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    /// A `base58:`-prefixed token could not be decoded as Base58.
    #[error(transparent)]
    Base58(#[from] bs58::decode::Error),
}
174
175fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
176    Ok(match &x_token[0..7] {
177        "base64:" => base64_engine.decode(x_token)?,
178        "base58:" => bs58::decode(x_token).into_vec()?,
179        _ => x_token.as_bytes().to_vec(),
180    })
181}
182
183pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
184where
185    D: Deserializer<'de>,
186{
187    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
188    x_token
189        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
190        .transpose()
191}
192
193pub fn deserialize_x_tokens_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
194where
195    D: Deserializer<'de>,
196{
197    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
198        vec.into_iter()
199            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
200            .collect::<Result<_, _>>()
201    })
202}
203
204pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
205where
206    D: Deserializer<'de>,
207{
208    Vec::<&str>::deserialize(deserializer)?
209        .into_iter()
210        .map(|value| {
211            pubkey_decode(value)
212                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
213        })
214        .collect::<Result<_, _>>()
215}
216
217pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
218where
219    D: Deserializer<'de>,
220{
221    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
222}
223
224pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
225where
226    D: Deserializer<'de>,
227{
228    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
229    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
230        .transpose()
231}
232
/// TLS settings as written in config, later turned into a `rustls::ServerConfig`.
///
/// `#[serde(untagged)]` tries the variants in declaration order: an object with
/// `cert` and `key` is `Signed`, otherwise `self_signed_alt_names` selects
/// `SelfSigned`. Keep this order when editing.
/// NOTE(review): `&'a str` fields need a deserializer that can lend borrowed
/// strings — confirm for every config format in use.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, untagged)]
enum RustlsServerConfigSignedSelfSigned<'a> {
    /// Filesystem paths to a certificate chain and a private key.
    Signed { cert: &'a str, key: &'a str },
    /// Subject alternative names for a freshly generated self-signed certificate.
    SelfSigned { self_signed_alt_names: Vec<String> },
}
239
240impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
241    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
242    where
243        D: Deserializer<'a>,
244    {
245        let (certs, key) = match self {
246            Self::Signed { cert, key } => {
247                let cert_path = PathBuf::from(cert);
248                let cert_bytes = fs::read(&cert_path).map_err(|error| {
249                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
250                })?;
251                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
252                    vec![CertificateDer::from(cert_bytes)]
253                } else {
254                    rustls_pemfile::certs(&mut &*cert_bytes)
255                        .collect::<Result<_, _>>()
256                        .map_err(|error| {
257                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
258                        })?
259                };
260
261                let key_path = PathBuf::from(key);
262                let key_bytes = fs::read(&key_path).map_err(|error| {
263                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
264                })?;
265                let key = if key_path.extension().is_some_and(|x| x == "der") {
266                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
267                } else {
268                    rustls_pemfile::private_key(&mut &*key_bytes)
269                        .map_err(|error| {
270                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
271                        })?
272                        .ok_or_else(|| de::Error::custom("no private keys found"))?
273                };
274
275                (cert_chain, key)
276            }
277            Self::SelfSigned {
278                self_signed_alt_names,
279            } => {
280                let cert =
281                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
282                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
283                    })?;
284                let cert_der = CertificateDer::from(cert.cert);
285                let priv_key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
286                (vec![cert_der], priv_key.into())
287            }
288        };
289
290        rustls::ServerConfig::builder()
291            .with_no_client_auth()
292            .with_single_cert(certs, key)
293            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
294    }
295}
296
297pub fn deserialize_maybe_rustls_server_config<'de, D>(
298    deserializer: D,
299) -> Result<Option<rustls::ServerConfig>, D::Error>
300where
301    D: Deserializer<'de>,
302{
303    let config: Option<RustlsServerConfigSignedSelfSigned> =
304        Deserialize::deserialize(deserializer)?;
305    if let Some(config) = config {
306        config.parse::<D>().map(Some)
307    } else {
308        Ok(None)
309    }
310}
311
312pub fn deserialize_rustls_server_config<'de, D>(
313    deserializer: D,
314) -> Result<rustls::ServerConfig, D::Error>
315where
316    D: Deserializer<'de>,
317{
318    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
319    config.parse::<D>()
320}
321
322pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
323where
324    D: Deserializer<'de>,
325{
326    match Option::<&str>::deserialize(deserializer)? {
327        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
328        None => Ok(None),
329    }
330}
331
332pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
333    let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
334    let mut set = HashSet::new();
335    for cpulist in taskset.split(',') {
336        let Some(caps) = re.captures(cpulist) else {
337            return Err(format!("invalid cpulist: {cpulist}"));
338        };
339
340        let start = caps
341            .get(1)
342            .and_then(|m| m.as_str().parse().ok())
343            .expect("valid regex");
344        let end = caps
345            .get(2)
346            .and_then(|m| m.as_str().parse().ok())
347            .unwrap_or(start);
348        let step = caps
349            .get(3)
350            .and_then(|m| m.as_str().parse().ok())
351            .unwrap_or(1);
352
353        for cpu in (start..=end).step_by(step) {
354            set.insert(cpu);
355        }
356    }
357
358    let mut vec = set.into_iter().collect::<Vec<usize>>();
359    vec.sort();
360
361    if !vec.is_empty() {
362        if let Some(cores) = affinity_linux::get_thread_affinity()
363            .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
364        {
365            let mut cores = cores.into_iter().collect::<Vec<_>>();
366            cores.sort();
367
368            for core in vec.iter_mut() {
369                if let Some(actual_core) = cores.get(*core) {
370                    *core = *actual_core;
371                } else {
372                    return Err(format!(
373                        "we don't have core {core}, available cores: {:?}",
374                        (0..cores.len()).collect::<Vec<_>>()
375                    ));
376                }
377            }
378        }
379    }
380
381    Ok(vec)
382}