richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
5    serde::{
6        de::{self, Deserializer},
7        Deserialize,
8    },
9    solana_sdk::{pubkey::Pubkey, signature::Signature},
10    std::{
11        collections::HashSet,
12        fmt::Display,
13        fs, io,
14        net::{IpAddr, Ipv4Addr, SocketAddr},
15        path::PathBuf,
16        str::FromStr,
17        sync::atomic::{AtomicU64, Ordering},
18    },
19    thiserror::Error,
20};
21
22#[derive(Debug, Clone, Default, Deserialize)]
23#[serde(deny_unknown_fields, default)]
24pub struct ConfigTokio {
25    /// Number of worker threads in Tokio runtime
26    pub worker_threads: Option<usize>,
27    /// Threads affinity
28    #[serde(deserialize_with = "deserialize_affinity")]
29    pub affinity: Option<Vec<usize>>,
30}
31
32impl ConfigTokio {
33    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
34    where
35        T: AsRef<str> + Send + Sync + 'static,
36    {
37        let mut builder = tokio::runtime::Builder::new_multi_thread();
38        if let Some(worker_threads) = self.worker_threads {
39            builder.worker_threads(worker_threads);
40        }
41        if let Some(cpus) = self.affinity.clone() {
42            builder.on_thread_start(move || {
43                affinity_linux::set_thread_affinity(cpus.iter().copied())
44                    .expect("failed to set affinity")
45            });
46        }
47        builder
48            .thread_name_fn(move || {
49                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
50                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
51                format!("{}{id:02}", thread_name_prefix.as_ref())
52            })
53            .enable_all()
54            .build()
55    }
56}
57
58#[derive(Debug, Clone, Copy, Deserialize)]
59#[serde(deny_unknown_fields, default)]
60pub struct ConfigMetrics {
61    /// Endpoint of Prometheus service
62    pub endpoint: SocketAddr,
63}
64
65impl Default for ConfigMetrics {
66    fn default() -> Self {
67        Self {
68            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10123),
69        }
70    }
71}
72
73#[derive(Deserialize)]
74#[serde(untagged)]
75enum ValueNumStr<'a, T> {
76    Num(T),
77    Str(&'a str),
78}
79
80pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
81where
82    D: Deserializer<'de>,
83    T: Deserialize<'de> + FromStr,
84    <T as FromStr>::Err: Display,
85{
86    match ValueNumStr::deserialize(deserializer)? {
87        ValueNumStr::Num(value) => Ok(value),
88        ValueNumStr::Str(value) => value
89            .replace('_', "")
90            .parse::<T>()
91            .map_err(de::Error::custom),
92    }
93}
94
95pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
96where
97    D: Deserializer<'de>,
98    T: Deserialize<'de> + FromStr,
99    <T as FromStr>::Err: Display,
100{
101    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
102        Some(ValueNumStr::Num(value)) => Ok(Some(value)),
103        Some(ValueNumStr::Str(value)) => value
104            .replace('_', "")
105            .parse::<T>()
106            .map_err(de::Error::custom)
107            .map(Some),
108        None => Ok(None),
109    }
110}
111
112#[derive(Debug, Error)]
113enum DecodeXTokenError {
114    #[error(transparent)]
115    Base64(#[from] base64::DecodeError),
116    #[error(transparent)]
117    Base58(#[from] bs58::decode::Error),
118}
119
120fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
121    Ok(match &x_token[0..7] {
122        "base64:" => base64_engine.decode(x_token)?,
123        "base58:" => bs58::decode(x_token).into_vec()?,
124        _ => x_token.as_bytes().to_vec(),
125    })
126}
127
128pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
129where
130    D: Deserializer<'de>,
131{
132    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
133    x_token
134        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
135        .transpose()
136}
137
138pub fn deserialize_x_token_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
139where
140    D: Deserializer<'de>,
141{
142    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
143        vec.into_iter()
144            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
145            .collect::<Result<_, _>>()
146    })
147}
148
149pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
150where
151    D: Deserializer<'de>,
152{
153    Vec::<&str>::deserialize(deserializer)?
154        .into_iter()
155        .map(|value| {
156            pubkey_decode(value)
157                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
158        })
159        .collect::<Result<_, _>>()
160}
161
162pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
163where
164    D: Deserializer<'de>,
165{
166    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
167}
168
169pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
170where
171    D: Deserializer<'de>,
172{
173    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
174    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
175        .transpose()
176}
177
178#[derive(Debug, Deserialize)]
179#[serde(deny_unknown_fields, untagged)]
180enum RustlsServerConfigSignedSelfSigned<'a> {
181    Signed { cert: &'a str, key: &'a str },
182    SelfSigned { self_signed_alt_names: Vec<String> },
183}
184
185impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
186    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
187    where
188        D: Deserializer<'a>,
189    {
190        let (certs, key) = match self {
191            Self::Signed { cert, key } => {
192                let cert_path = PathBuf::from(cert);
193                let cert_bytes = fs::read(&cert_path).map_err(|error| {
194                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
195                })?;
196                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
197                    vec![CertificateDer::from(cert_bytes)]
198                } else {
199                    rustls_pemfile::certs(&mut &*cert_bytes)
200                        .collect::<Result<_, _>>()
201                        .map_err(|error| {
202                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
203                        })?
204                };
205
206                let key_path = PathBuf::from(key);
207                let key_bytes = fs::read(&key_path).map_err(|error| {
208                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
209                })?;
210                let key = if key_path.extension().is_some_and(|x| x == "der") {
211                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
212                } else {
213                    rustls_pemfile::private_key(&mut &*key_bytes)
214                        .map_err(|error| {
215                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
216                        })?
217                        .ok_or_else(|| de::Error::custom("no private keys found"))?
218                };
219
220                (cert_chain, key)
221            }
222            Self::SelfSigned {
223                self_signed_alt_names,
224            } => {
225                let cert =
226                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
227                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
228                    })?;
229                let cert_der = CertificateDer::from(cert.cert);
230                let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
231                (vec![cert_der], priv_key.into())
232            }
233        };
234
235        rustls::ServerConfig::builder()
236            .with_no_client_auth()
237            .with_single_cert(certs, key)
238            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
239    }
240}
241
242pub fn deserialize_maybe_rustls_server_config<'de, D>(
243    deserializer: D,
244) -> Result<Option<rustls::ServerConfig>, D::Error>
245where
246    D: Deserializer<'de>,
247{
248    let config: Option<RustlsServerConfigSignedSelfSigned> =
249        Deserialize::deserialize(deserializer)?;
250    if let Some(config) = config {
251        config.parse::<D>().map(Some)
252    } else {
253        Ok(None)
254    }
255}
256
257pub fn deserialize_rustls_server_config<'de, D>(
258    deserializer: D,
259) -> Result<rustls::ServerConfig, D::Error>
260where
261    D: Deserializer<'de>,
262{
263    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
264    config.parse::<D>()
265}
266
267pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
268where
269    D: Deserializer<'de>,
270{
271    match Option::<&str>::deserialize(deserializer)? {
272        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
273        None => Ok(None),
274    }
275}
276
277pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
278    let mut set = HashSet::new();
279    for taskset2 in taskset.split(',') {
280        match taskset2.split_once('-') {
281            Some((start, end)) => {
282                let start: usize = start
283                    .parse()
284                    .map_err(|_error| format!("failed to parse {start:?} from {taskset:?}"))?;
285                let end: usize = end
286                    .parse()
287                    .map_err(|_error| format!("failed to parse {end:?} from {taskset:?}"))?;
288                if start > end {
289                    return Err(format!("invalid interval {taskset2:?} in {taskset:?}"));
290                }
291                for idx in start..=end {
292                    set.insert(idx);
293                }
294            }
295            None => {
296                set.insert(
297                    taskset2.parse().map_err(|_error| {
298                        format!("failed to parse {taskset2:?} from {taskset:?}")
299                    })?,
300                );
301            }
302        }
303    }
304
305    if !set.is_empty() {
306        if let Some(core_ids) = affinity_linux::get_thread_affinity()
307            .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
308        {
309            if !set.is_subset(&core_ids) {
310                return Err(format!(
311                    "not allowed core index, should be set of: {core_ids:?}"
312                ));
313            }
314        }
315    }
316
317    let mut vec = set.into_iter().collect::<Vec<usize>>();
318    vec.sort();
319
320    Ok(vec)
321}