richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
5    serde::{
6        de::{self, Deserializer},
7        Deserialize,
8    },
9    solana_sdk::{pubkey::Pubkey, signature::Signature},
10    std::{
11        collections::HashSet,
12        fmt::Display,
13        fs, io,
14        net::{IpAddr, Ipv4Addr, SocketAddr},
15        path::PathBuf,
16        str::FromStr,
17        sync::atomic::{AtomicU64, Ordering},
18    },
19    thiserror::Error,
20};
21
22#[derive(Debug, Clone, Default, Deserialize)]
23#[serde(deny_unknown_fields, default)]
24pub struct ConfigTokio {
25    /// Number of worker threads in Tokio runtime
26    pub worker_threads: Option<usize>,
27    /// Threads affinity
28    #[serde(deserialize_with = "deserialize_affinity")]
29    pub affinity: Option<Vec<usize>>,
30}
31
32impl ConfigTokio {
33    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
34    where
35        T: AsRef<str> + Send + Sync + 'static,
36    {
37        let mut builder = tokio::runtime::Builder::new_multi_thread();
38        if let Some(worker_threads) = self.worker_threads {
39            builder.worker_threads(worker_threads);
40        }
41        if let Some(cpus) = self.affinity.clone() {
42            builder.on_thread_start(move || {
43                affinity::set_thread_affinity(&cpus).expect("failed to set affinity")
44            });
45        }
46        builder
47            .thread_name_fn(move || {
48                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
49                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
50                format!("{}{id:02}", thread_name_prefix.as_ref())
51            })
52            .enable_all()
53            .build()
54    }
55}
56
57#[derive(Debug, Clone, Copy, Deserialize)]
58#[serde(deny_unknown_fields, default)]
59pub struct ConfigMetrics {
60    /// Endpoint of Prometheus service
61    pub endpoint: SocketAddr,
62}
63
64impl Default for ConfigMetrics {
65    fn default() -> Self {
66        Self {
67            endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10123),
68        }
69    }
70}
71
72#[derive(Deserialize)]
73#[serde(untagged)]
74enum ValueNumStr<'a, T> {
75    Num(T),
76    Str(&'a str),
77}
78
79pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
80where
81    D: Deserializer<'de>,
82    T: Deserialize<'de> + FromStr,
83    <T as FromStr>::Err: Display,
84{
85    match ValueNumStr::deserialize(deserializer)? {
86        ValueNumStr::Num(value) => Ok(value),
87        ValueNumStr::Str(value) => value
88            .replace('_', "")
89            .parse::<T>()
90            .map_err(de::Error::custom),
91    }
92}
93
94pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
95where
96    D: Deserializer<'de>,
97    T: Deserialize<'de> + FromStr,
98    <T as FromStr>::Err: Display,
99{
100    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
101        Some(ValueNumStr::Num(value)) => Ok(Some(value)),
102        Some(ValueNumStr::Str(value)) => value
103            .replace('_', "")
104            .parse::<T>()
105            .map_err(de::Error::custom)
106            .map(Some),
107        None => Ok(None),
108    }
109}
110
111#[derive(Debug, Error)]
112enum DecodeXTokenError {
113    #[error(transparent)]
114    Base64(#[from] base64::DecodeError),
115    #[error(transparent)]
116    Base58(#[from] bs58::decode::Error),
117}
118
119fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
120    Ok(match &x_token[0..7] {
121        "base64:" => base64_engine.decode(x_token)?,
122        "base58:" => bs58::decode(x_token).into_vec()?,
123        _ => x_token.as_bytes().to_vec(),
124    })
125}
126
127pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
128where
129    D: Deserializer<'de>,
130{
131    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
132    x_token
133        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
134        .transpose()
135}
136
137pub fn deserialize_x_token_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
138where
139    D: Deserializer<'de>,
140{
141    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
142        vec.into_iter()
143            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
144            .collect::<Result<_, _>>()
145    })
146}
147
148pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
149where
150    D: Deserializer<'de>,
151{
152    Vec::<&str>::deserialize(deserializer)?
153        .into_iter()
154        .map(|value| {
155            pubkey_decode(value)
156                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
157        })
158        .collect::<Result<_, _>>()
159}
160
161pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
162where
163    D: Deserializer<'de>,
164{
165    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
166}
167
168pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
169where
170    D: Deserializer<'de>,
171{
172    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
173    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
174        .transpose()
175}
176
177#[derive(Debug, Deserialize)]
178#[serde(deny_unknown_fields, untagged)]
179enum RustlsServerConfigSignedSelfSigned<'a> {
180    Signed { cert: &'a str, key: &'a str },
181    SelfSigned { self_signed_alt_names: Vec<String> },
182}
183
184impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
185    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
186    where
187        D: Deserializer<'a>,
188    {
189        let (certs, key) = match self {
190            Self::Signed { cert, key } => {
191                let cert_path = PathBuf::from(cert);
192                let cert_bytes = fs::read(&cert_path).map_err(|error| {
193                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
194                })?;
195                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
196                    vec![CertificateDer::from(cert_bytes)]
197                } else {
198                    rustls_pemfile::certs(&mut &*cert_bytes)
199                        .collect::<Result<_, _>>()
200                        .map_err(|error| {
201                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
202                        })?
203                };
204
205                let key_path = PathBuf::from(key);
206                let key_bytes = fs::read(&key_path).map_err(|error| {
207                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
208                })?;
209                let key = if key_path.extension().is_some_and(|x| x == "der") {
210                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
211                } else {
212                    rustls_pemfile::private_key(&mut &*key_bytes)
213                        .map_err(|error| {
214                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
215                        })?
216                        .ok_or_else(|| de::Error::custom("no private keys found"))?
217                };
218
219                (cert_chain, key)
220            }
221            Self::SelfSigned {
222                self_signed_alt_names,
223            } => {
224                let cert =
225                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
226                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
227                    })?;
228                let cert_der = CertificateDer::from(cert.cert);
229                let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
230                (vec![cert_der], priv_key.into())
231            }
232        };
233
234        rustls::ServerConfig::builder()
235            .with_no_client_auth()
236            .with_single_cert(certs, key)
237            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
238    }
239}
240
241pub fn deserialize_maybe_rustls_server_config<'de, D>(
242    deserializer: D,
243) -> Result<Option<rustls::ServerConfig>, D::Error>
244where
245    D: Deserializer<'de>,
246{
247    let config: Option<RustlsServerConfigSignedSelfSigned> =
248        Deserialize::deserialize(deserializer)?;
249    if let Some(config) = config {
250        config.parse::<D>().map(Some)
251    } else {
252        Ok(None)
253    }
254}
255
256pub fn deserialize_rustls_server_config<'de, D>(
257    deserializer: D,
258) -> Result<rustls::ServerConfig, D::Error>
259where
260    D: Deserializer<'de>,
261{
262    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
263    config.parse::<D>()
264}
265
266pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
267where
268    D: Deserializer<'de>,
269{
270    match Option::<&str>::deserialize(deserializer)? {
271        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
272        None => Ok(None),
273    }
274}
275
276pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
277    let mut set = HashSet::new();
278    for taskset2 in taskset.split(',') {
279        match taskset2.split_once('-') {
280            Some((start, end)) => {
281                let start: usize = start
282                    .parse()
283                    .map_err(|_error| format!("failed to parse {start:?} from {taskset:?}"))?;
284                let end: usize = end
285                    .parse()
286                    .map_err(|_error| format!("failed to parse {end:?} from {taskset:?}"))?;
287                if start > end {
288                    return Err(format!("invalid interval {taskset2:?} in {taskset:?}"));
289                }
290                for idx in start..=end {
291                    set.insert(idx);
292                }
293            }
294            None => {
295                set.insert(
296                    taskset2.parse().map_err(|_error| {
297                        format!("failed to parse {taskset2:?} from {taskset:?}")
298                    })?,
299                );
300            }
301        }
302    }
303
304    let mut vec = set.into_iter().collect::<Vec<usize>>();
305    vec.sort();
306
307    if let Some(set_max_index) = vec.last().copied() {
308        let max_index = affinity::get_thread_affinity()
309            .map_err(|_err| "failed to get affinity".to_owned())?
310            .into_iter()
311            .max()
312            .unwrap_or(0);
313
314        if set_max_index > max_index {
315            return Err(format!("core index must be in the range [0, {max_index}]"));
316        }
317    }
318
319    Ok(vec)
320}