richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4    regex::Regex,
5    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
6    serde::{
7        de::{self, Deserializer},
8        Deserialize,
9    },
10    solana_sdk::{pubkey::Pubkey, signature::Signature},
11    std::{
12        collections::HashSet,
13        fmt::Display,
14        fs, io,
15        net::{IpAddr, Ipv4Addr, SocketAddr},
16        path::PathBuf,
17        str::FromStr,
18        sync::atomic::{AtomicU64, Ordering},
19    },
20    thiserror::Error,
21};
22
23#[derive(Debug, Clone, Default, Deserialize)]
24#[serde(deny_unknown_fields, default)]
25pub struct ConfigTokio {
26    /// Number of worker threads in Tokio runtime
27    pub worker_threads: Option<usize>,
28    /// Threads affinity
29    #[serde(deserialize_with = "deserialize_affinity")]
30    pub affinity: Option<Vec<usize>>,
31}
32
33impl ConfigTokio {
34    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
35    where
36        T: AsRef<str> + Send + Sync + 'static,
37    {
38        let mut builder = tokio::runtime::Builder::new_multi_thread();
39        if let Some(worker_threads) = self.worker_threads {
40            builder.worker_threads(worker_threads);
41        }
42        if let Some(cpus) = self.affinity.clone() {
43            builder.on_thread_start(move || {
44                affinity_linux::set_thread_affinity(cpus.iter().copied())
45                    .expect("failed to set affinity")
46            });
47        }
48        builder
49            .thread_name_fn(move || {
50                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
51                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
52                format!("{}{id:02}", thread_name_prefix.as_ref())
53            })
54            .enable_all()
55            .build()
56    }
57}
58
59#[derive(Debug, Clone, Copy, Deserialize)]
60#[serde(deny_unknown_fields, default)]
61pub struct ConfigMetrics {
62    /// Endpoint of Prometheus service
63    pub endpoint: SocketAddr,
64}
65
66impl Default for ConfigMetrics {
67    fn default() -> Self {
68        Self {
69            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10123),
70        }
71    }
72}
73
74#[derive(Deserialize)]
75#[serde(untagged)]
76enum ValueNumStr<'a, T> {
77    Num(T),
78    Str(&'a str),
79}
80
81pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
82where
83    D: Deserializer<'de>,
84    T: Deserialize<'de> + FromStr,
85    <T as FromStr>::Err: Display,
86{
87    match ValueNumStr::deserialize(deserializer)? {
88        ValueNumStr::Num(value) => Ok(value),
89        ValueNumStr::Str(value) => value
90            .replace('_', "")
91            .parse::<T>()
92            .map_err(de::Error::custom),
93    }
94}
95
96pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
97where
98    D: Deserializer<'de>,
99    T: Deserialize<'de> + FromStr,
100    <T as FromStr>::Err: Display,
101{
102    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
103        Some(ValueNumStr::Num(value)) => Ok(Some(value)),
104        Some(ValueNumStr::Str(value)) => value
105            .replace('_', "")
106            .parse::<T>()
107            .map_err(de::Error::custom)
108            .map(Some),
109        None => Ok(None),
110    }
111}
112
113#[derive(Debug, Error)]
114enum DecodeXTokenError {
115    #[error(transparent)]
116    Base64(#[from] base64::DecodeError),
117    #[error(transparent)]
118    Base58(#[from] bs58::decode::Error),
119}
120
121fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
122    Ok(match &x_token[0..7] {
123        "base64:" => base64_engine.decode(x_token)?,
124        "base58:" => bs58::decode(x_token).into_vec()?,
125        _ => x_token.as_bytes().to_vec(),
126    })
127}
128
129pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
130where
131    D: Deserializer<'de>,
132{
133    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
134    x_token
135        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
136        .transpose()
137}
138
139pub fn deserialize_x_token_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
140where
141    D: Deserializer<'de>,
142{
143    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
144        vec.into_iter()
145            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
146            .collect::<Result<_, _>>()
147    })
148}
149
150pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
151where
152    D: Deserializer<'de>,
153{
154    Vec::<&str>::deserialize(deserializer)?
155        .into_iter()
156        .map(|value| {
157            pubkey_decode(value)
158                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
159        })
160        .collect::<Result<_, _>>()
161}
162
163pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
164where
165    D: Deserializer<'de>,
166{
167    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
168}
169
170pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
171where
172    D: Deserializer<'de>,
173{
174    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
175    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
176        .transpose()
177}
178
179#[derive(Debug, Deserialize)]
180#[serde(deny_unknown_fields, untagged)]
181enum RustlsServerConfigSignedSelfSigned<'a> {
182    Signed { cert: &'a str, key: &'a str },
183    SelfSigned { self_signed_alt_names: Vec<String> },
184}
185
186impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
187    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
188    where
189        D: Deserializer<'a>,
190    {
191        let (certs, key) = match self {
192            Self::Signed { cert, key } => {
193                let cert_path = PathBuf::from(cert);
194                let cert_bytes = fs::read(&cert_path).map_err(|error| {
195                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
196                })?;
197                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
198                    vec![CertificateDer::from(cert_bytes)]
199                } else {
200                    rustls_pemfile::certs(&mut &*cert_bytes)
201                        .collect::<Result<_, _>>()
202                        .map_err(|error| {
203                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
204                        })?
205                };
206
207                let key_path = PathBuf::from(key);
208                let key_bytes = fs::read(&key_path).map_err(|error| {
209                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
210                })?;
211                let key = if key_path.extension().is_some_and(|x| x == "der") {
212                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
213                } else {
214                    rustls_pemfile::private_key(&mut &*key_bytes)
215                        .map_err(|error| {
216                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
217                        })?
218                        .ok_or_else(|| de::Error::custom("no private keys found"))?
219                };
220
221                (cert_chain, key)
222            }
223            Self::SelfSigned {
224                self_signed_alt_names,
225            } => {
226                let cert =
227                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
228                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
229                    })?;
230                let cert_der = CertificateDer::from(cert.cert);
231                let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
232                (vec![cert_der], priv_key.into())
233            }
234        };
235
236        rustls::ServerConfig::builder()
237            .with_no_client_auth()
238            .with_single_cert(certs, key)
239            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
240    }
241}
242
243pub fn deserialize_maybe_rustls_server_config<'de, D>(
244    deserializer: D,
245) -> Result<Option<rustls::ServerConfig>, D::Error>
246where
247    D: Deserializer<'de>,
248{
249    let config: Option<RustlsServerConfigSignedSelfSigned> =
250        Deserialize::deserialize(deserializer)?;
251    if let Some(config) = config {
252        config.parse::<D>().map(Some)
253    } else {
254        Ok(None)
255    }
256}
257
258pub fn deserialize_rustls_server_config<'de, D>(
259    deserializer: D,
260) -> Result<rustls::ServerConfig, D::Error>
261where
262    D: Deserializer<'de>,
263{
264    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
265    config.parse::<D>()
266}
267
268pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
269where
270    D: Deserializer<'de>,
271{
272    match Option::<&str>::deserialize(deserializer)? {
273        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
274        None => Ok(None),
275    }
276}
277
278pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
279    let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
280    let mut set = HashSet::new();
281    for cpulist in taskset.split(',') {
282        let Some(caps) = re.captures(cpulist) else {
283            return Err(format!("invalid cpulist: {cpulist}"));
284        };
285
286        let start = caps
287            .get(1)
288            .and_then(|m| m.as_str().parse().ok())
289            .expect("valid regex");
290        let end = caps
291            .get(2)
292            .and_then(|m| m.as_str().parse().ok())
293            .unwrap_or(start);
294        let step = caps
295            .get(3)
296            .and_then(|m| m.as_str().parse().ok())
297            .unwrap_or(1);
298
299        for cpu in (start..=end).step_by(step) {
300            set.insert(cpu);
301        }
302    }
303
304    let mut vec = set.into_iter().collect::<Vec<usize>>();
305    vec.sort();
306
307    if !vec.is_empty() {
308        if let Some(cores) = affinity_linux::get_thread_affinity()
309            .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
310        {
311            let mut cores = cores.into_iter().collect::<Vec<_>>();
312            cores.sort();
313
314            for core in vec.iter_mut() {
315                if let Some(actual_core) = cores.get(*core) {
316                    *core = *actual_core;
317                } else {
318                    return Err(format!(
319                        "we don't have core {core}, available cores: {:?}",
320                        (0..cores.len()).collect::<Vec<_>>()
321                    ));
322                }
323            }
324        }
325    }
326
327    Ok(vec)
328}