richat_shared/
config.rs

1use {
2    crate::five8::{pubkey_decode, signature_decode},
3    base64::{engine::general_purpose::STANDARD as base64_engine, Engine},
4    regex::Regex,
5    rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
6    serde::{
7        de::{self, Deserializer},
8        Deserialize,
9    },
10    solana_sdk::{pubkey::Pubkey, signature::Signature},
11    std::{
12        collections::HashSet,
13        fmt::Display,
14        fs, io,
15        path::PathBuf,
16        str::FromStr,
17        sync::atomic::{AtomicU64, Ordering},
18    },
19    thiserror::Error,
20};
21
22#[derive(Debug, Clone, Default, Deserialize)]
23#[serde(deny_unknown_fields, default)]
24pub struct ConfigTokio {
25    /// Number of worker threads in Tokio runtime
26    pub worker_threads: Option<usize>,
27    /// Threads affinity
28    #[serde(deserialize_with = "deserialize_affinity")]
29    pub affinity: Option<Vec<usize>>,
30}
31
32impl ConfigTokio {
33    pub fn build_runtime<T>(self, thread_name_prefix: T) -> io::Result<tokio::runtime::Runtime>
34    where
35        T: AsRef<str> + Send + Sync + 'static,
36    {
37        let mut builder = tokio::runtime::Builder::new_multi_thread();
38        if let Some(worker_threads) = self.worker_threads {
39            builder.worker_threads(worker_threads);
40        }
41        if let Some(cpus) = self.affinity.clone() {
42            builder.on_thread_start(move || {
43                affinity_linux::set_thread_affinity(cpus.iter().copied())
44                    .expect("failed to set affinity")
45            });
46        }
47        builder
48            .thread_name_fn(move || {
49                static ATOMIC_ID: AtomicU64 = AtomicU64::new(0);
50                let id = ATOMIC_ID.fetch_add(1, Ordering::Relaxed);
51                format!("{}{id:02}", thread_name_prefix.as_ref())
52            })
53            .enable_all()
54            .build()
55    }
56}
57
58#[derive(Deserialize)]
59#[serde(untagged)]
60enum ValueNumStr<'a, T> {
61    Num(T),
62    Str(&'a str),
63}
64
65pub fn deserialize_num_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
66where
67    D: Deserializer<'de>,
68    T: Deserialize<'de> + FromStr,
69    <T as FromStr>::Err: Display,
70{
71    match ValueNumStr::deserialize(deserializer)? {
72        ValueNumStr::Num(value) => Ok(value),
73        ValueNumStr::Str(value) => value
74            .replace('_', "")
75            .parse::<T>()
76            .map_err(de::Error::custom),
77    }
78}
79
80pub fn deserialize_maybe_num_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
81where
82    D: Deserializer<'de>,
83    T: Deserialize<'de> + FromStr,
84    <T as FromStr>::Err: Display,
85{
86    match Option::<ValueNumStr<T>>::deserialize(deserializer)? {
87        Some(ValueNumStr::Num(value)) => Ok(Some(value)),
88        Some(ValueNumStr::Str(value)) => value
89            .replace('_', "")
90            .parse::<T>()
91            .map_err(de::Error::custom)
92            .map(Some),
93        None => Ok(None),
94    }
95}
96
97#[derive(Debug, Error)]
98enum DecodeXTokenError {
99    #[error(transparent)]
100    Base64(#[from] base64::DecodeError),
101    #[error(transparent)]
102    Base58(#[from] bs58::decode::Error),
103}
104
105fn decode_x_token(x_token: &str) -> Result<Vec<u8>, DecodeXTokenError> {
106    Ok(match &x_token[0..7] {
107        "base64:" => base64_engine.decode(x_token)?,
108        "base58:" => bs58::decode(x_token).into_vec()?,
109        _ => x_token.as_bytes().to_vec(),
110    })
111}
112
113pub fn deserialize_maybe_x_token<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
114where
115    D: Deserializer<'de>,
116{
117    let x_token: Option<&str> = Deserialize::deserialize(deserializer)?;
118    x_token
119        .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
120        .transpose()
121}
122
123pub fn deserialize_x_token_set<'de, D>(deserializer: D) -> Result<HashSet<Vec<u8>>, D::Error>
124where
125    D: Deserializer<'de>,
126{
127    Vec::<&str>::deserialize(deserializer).and_then(|vec| {
128        vec.into_iter()
129            .map(|x_token| decode_x_token(x_token).map_err(de::Error::custom))
130            .collect::<Result<_, _>>()
131    })
132}
133
134pub fn deserialize_pubkey_set<'de, D>(deserializer: D) -> Result<HashSet<Pubkey>, D::Error>
135where
136    D: Deserializer<'de>,
137{
138    Vec::<&str>::deserialize(deserializer)?
139        .into_iter()
140        .map(|value| {
141            pubkey_decode(value)
142                .map_err(|error| de::Error::custom(format!("Invalid pubkey: {value} ({error:?})")))
143        })
144        .collect::<Result<_, _>>()
145}
146
147pub fn deserialize_pubkey_vec<'de, D>(deserializer: D) -> Result<Vec<Pubkey>, D::Error>
148where
149    D: Deserializer<'de>,
150{
151    deserialize_pubkey_set(deserializer).map(|set| set.into_iter().collect())
152}
153
154pub fn deserialize_maybe_signature<'de, D>(deserializer: D) -> Result<Option<Signature>, D::Error>
155where
156    D: Deserializer<'de>,
157{
158    let sig: Option<&str> = Deserialize::deserialize(deserializer)?;
159    sig.map(|sig| signature_decode(sig).map_err(de::Error::custom))
160        .transpose()
161}
162
163#[derive(Debug, Deserialize)]
164#[serde(deny_unknown_fields, untagged)]
165enum RustlsServerConfigSignedSelfSigned<'a> {
166    Signed { cert: &'a str, key: &'a str },
167    SelfSigned { self_signed_alt_names: Vec<String> },
168}
169
170impl<'a> RustlsServerConfigSignedSelfSigned<'a> {
171    fn parse<D>(self) -> Result<rustls::ServerConfig, D::Error>
172    where
173        D: Deserializer<'a>,
174    {
175        let (certs, key) = match self {
176            Self::Signed { cert, key } => {
177                let cert_path = PathBuf::from(cert);
178                let cert_bytes = fs::read(&cert_path).map_err(|error| {
179                    de::Error::custom(format!("failed to read cert {cert_path:?}: {error:?}"))
180                })?;
181                let cert_chain = if cert_path.extension().is_some_and(|x| x == "der") {
182                    vec![CertificateDer::from(cert_bytes)]
183                } else {
184                    rustls_pemfile::certs(&mut &*cert_bytes)
185                        .collect::<Result<_, _>>()
186                        .map_err(|error| {
187                            de::Error::custom(format!("invalid PEM-encoded certificate: {error:?}"))
188                        })?
189                };
190
191                let key_path = PathBuf::from(key);
192                let key_bytes = fs::read(&key_path).map_err(|error| {
193                    de::Error::custom(format!("failed to read key {key_path:?}: {error:?}"))
194                })?;
195                let key = if key_path.extension().is_some_and(|x| x == "der") {
196                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_bytes))
197                } else {
198                    rustls_pemfile::private_key(&mut &*key_bytes)
199                        .map_err(|error| {
200                            de::Error::custom(format!("malformed PKCS #1 private key: {error:?}"))
201                        })?
202                        .ok_or_else(|| de::Error::custom("no private keys found"))?
203                };
204
205                (cert_chain, key)
206            }
207            Self::SelfSigned {
208                self_signed_alt_names,
209            } => {
210                let cert =
211                    rcgen::generate_simple_self_signed(self_signed_alt_names).map_err(|error| {
212                        de::Error::custom(format!("failed to generate self-signed cert: {error:?}"))
213                    })?;
214                let cert_der = CertificateDer::from(cert.cert);
215                let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
216                (vec![cert_der], priv_key.into())
217            }
218        };
219
220        rustls::ServerConfig::builder()
221            .with_no_client_auth()
222            .with_single_cert(certs, key)
223            .map_err(|error| de::Error::custom(format!("failed to use cert: {error:?}")))
224    }
225}
226
227pub fn deserialize_maybe_rustls_server_config<'de, D>(
228    deserializer: D,
229) -> Result<Option<rustls::ServerConfig>, D::Error>
230where
231    D: Deserializer<'de>,
232{
233    let config: Option<RustlsServerConfigSignedSelfSigned> =
234        Deserialize::deserialize(deserializer)?;
235    if let Some(config) = config {
236        config.parse::<D>().map(Some)
237    } else {
238        Ok(None)
239    }
240}
241
242pub fn deserialize_rustls_server_config<'de, D>(
243    deserializer: D,
244) -> Result<rustls::ServerConfig, D::Error>
245where
246    D: Deserializer<'de>,
247{
248    let config: RustlsServerConfigSignedSelfSigned = Deserialize::deserialize(deserializer)?;
249    config.parse::<D>()
250}
251
252pub fn deserialize_affinity<'de, D>(deserializer: D) -> Result<Option<Vec<usize>>, D::Error>
253where
254    D: Deserializer<'de>,
255{
256    match Option::<&str>::deserialize(deserializer)? {
257        Some(taskset) => parse_taskset(taskset).map(Some).map_err(de::Error::custom),
258        None => Ok(None),
259    }
260}
261
262pub fn parse_taskset(taskset: &str) -> Result<Vec<usize>, String> {
263    let re = Regex::new(r"^(\d+)(?:-(\d+)(?::(\d+))?)?$").expect("valid regex");
264    let mut set = HashSet::new();
265    for cpulist in taskset.split(',') {
266        let Some(caps) = re.captures(cpulist) else {
267            return Err(format!("invalid cpulist: {cpulist}"));
268        };
269
270        let start = caps
271            .get(1)
272            .and_then(|m| m.as_str().parse().ok())
273            .expect("valid regex");
274        let end = caps
275            .get(2)
276            .and_then(|m| m.as_str().parse().ok())
277            .unwrap_or(start);
278        let step = caps
279            .get(3)
280            .and_then(|m| m.as_str().parse().ok())
281            .unwrap_or(1);
282
283        for cpu in (start..=end).step_by(step) {
284            set.insert(cpu);
285        }
286    }
287
288    let mut vec = set.into_iter().collect::<Vec<usize>>();
289    vec.sort();
290
291    if !vec.is_empty() {
292        if let Some(cores) = affinity_linux::get_thread_affinity()
293            .map_err(|error| format!("failed to get allowed cpus: {error:?}"))?
294        {
295            let mut cores = cores.into_iter().collect::<Vec<_>>();
296            cores.sort();
297
298            for core in vec.iter_mut() {
299                if let Some(actual_core) = cores.get(*core) {
300                    *core = *actual_core;
301                } else {
302                    return Err(format!(
303                        "we don't have core {core}, available cores: {:?}",
304                        (0..cores.len()).collect::<Vec<_>>()
305                    ));
306                }
307            }
308        }
309    }
310
311    Ok(vec)
312}