enigma_node_registry/
config.rs

1use std::fs;
2use std::path::Path;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use serde::Deserialize;
6
7use crate::error::{RegistryError, RegistryResult};
8
9#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
10#[serde(rename_all = "lowercase")]
11pub enum ServerMode {
12    Http,
13    Tls,
14}
15
16impl Default for ServerMode {
17    fn default() -> Self {
18        ServerMode::Tls
19    }
20}
21
22#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
23#[serde(deny_unknown_fields)]
24pub struct RegistryConfig {
25    #[serde(default = "default_address")]
26    pub address: String,
27    #[serde(default)]
28    pub mode: ServerMode,
29    #[serde(default)]
30    pub trusted_proxies: Vec<String>,
31    #[serde(default)]
32    pub rate_limit: RateLimitConfig,
33    pub envelope: EnvelopeConfig,
34    #[serde(default)]
35    pub tls: Option<TlsConfig>,
36    #[serde(default)]
37    pub storage: StorageConfig,
38    #[serde(default)]
39    pub presence: PresenceConfig,
40    #[serde(default)]
41    pub pow: PowConfig,
42    #[serde(default = "default_allow_sync")]
43    pub allow_sync: bool,
44    #[serde(default = "default_max_nodes")]
45    pub max_nodes: usize,
46}
47
48impl RegistryConfig {
49    pub fn load_from_path(path: impl AsRef<Path>) -> RegistryResult<Self> {
50        let content = fs::read_to_string(path)?;
51        let parsed: RegistryConfig = toml::from_str(&content)
52            .map_err(|err| RegistryError::Config(format!("failed to parse config: {}", err)))?;
53        parsed.validate()?;
54        Ok(parsed)
55    }
56
57    pub fn validate(&self) -> RegistryResult<()> {
58        match self.mode {
59            ServerMode::Http => {
60                if !cfg!(feature = "http") {
61                    return Err(RegistryError::FeatureDisabled("http".to_string()));
62                }
63            }
64            ServerMode::Tls => {
65                if !cfg!(feature = "tls") {
66                    return Err(RegistryError::FeatureDisabled("tls".to_string()));
67                }
68                if self.tls.is_none() {
69                    return Err(RegistryError::Config(
70                        "tls configuration is required for tls mode".to_string(),
71                    ));
72                }
73            }
74        }
75        if self.address.trim().is_empty() {
76            return Err(RegistryError::Config("address cannot be empty".to_string()));
77        }
78        self.rate_limit.validate()?;
79        self.envelope.validate()?;
80        if let Some(tls) = &self.tls {
81            tls.validate(self.mode.clone())?;
82        }
83        self.storage.validate()?;
84        self.presence.validate()?;
85        self.pow.validate()?;
86        Ok(())
87    }
88
89    pub fn pepper_bytes(&self) -> [u8; 32] {
90        self.envelope.pepper_bytes()
91    }
92}
93
/// Serde default for `RegistryConfig::address`.
fn default_address() -> String {
    String::from("0.0.0.0:8443")
}

/// Serde default for `RegistryConfig::allow_sync` (enabled).
fn default_allow_sync() -> bool {
    true
}

/// Serde default for `RegistryConfig::max_nodes`.
fn default_max_nodes() -> usize {
    2_048
}
105
106#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
107#[serde(deny_unknown_fields)]
108pub struct RateLimitConfig {
109    #[serde(default = "default_enabled")]
110    pub enabled: bool,
111    #[serde(default = "default_per_ip_rps")]
112    pub per_ip_rps: u32,
113    #[serde(default = "default_burst")]
114    pub burst: u32,
115    #[serde(default = "default_ban_seconds")]
116    pub ban_seconds: u64,
117    #[serde(default)]
118    pub endpoints: RateLimitEndpoints,
119}
120
121impl RateLimitConfig {
122    pub fn validate(&self) -> RegistryResult<()> {
123        if self.per_ip_rps == 0 {
124            return Err(RegistryError::Config(
125                "per_ip_rps must be positive".to_string(),
126            ));
127        }
128        if self.burst == 0 {
129            return Err(RegistryError::Config("burst must be positive".to_string()));
130        }
131        self.endpoints.validate()
132    }
133}
134
135impl Default for RateLimitConfig {
136    fn default() -> Self {
137        RateLimitConfig {
138            enabled: true,
139            per_ip_rps: 5,
140            burst: 10,
141            ban_seconds: 300,
142            endpoints: RateLimitEndpoints::default(),
143        }
144    }
145}
146
/// Serde default for `RateLimitConfig::enabled`: rate limiting is on.
fn default_enabled() -> bool {
    true
}

/// Serde default for `RateLimitConfig::per_ip_rps`.
fn default_per_ip_rps() -> u32 {
    5u32
}

/// Serde default for `RateLimitConfig::burst`.
fn default_burst() -> u32 {
    10u32
}

/// Serde default for `RateLimitConfig::ban_seconds` (five minutes).
fn default_ban_seconds() -> u64 {
    5 * 60
}
162
163#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
164#[serde(deny_unknown_fields)]
165pub struct RateLimitEndpoints {
166    #[serde(default = "default_register_rps")]
167    pub register_rps: u32,
168    #[serde(default = "default_resolve_rps")]
169    pub resolve_rps: u32,
170    #[serde(default = "default_check_user_rps")]
171    pub check_user_rps: u32,
172}
173
174impl RateLimitEndpoints {
175    pub fn validate(&self) -> RegistryResult<()> {
176        if self.register_rps == 0 {
177            return Err(RegistryError::Config(
178                "register_rps must be positive".to_string(),
179            ));
180        }
181        if self.resolve_rps == 0 {
182            return Err(RegistryError::Config(
183                "resolve_rps must be positive".to_string(),
184            ));
185        }
186        if self.check_user_rps == 0 {
187            return Err(RegistryError::Config(
188                "check_user_rps must be positive".to_string(),
189            ));
190        }
191        Ok(())
192    }
193}
194
195impl Default for RateLimitEndpoints {
196    fn default() -> Self {
197        RateLimitEndpoints {
198            register_rps: 1,
199            resolve_rps: 3,
200            check_user_rps: 10,
201        }
202    }
203}
204
/// Serde default for `RateLimitEndpoints::register_rps`.
fn default_register_rps() -> u32 {
    1u32
}

/// Serde default for `RateLimitEndpoints::resolve_rps`.
fn default_resolve_rps() -> u32 {
    3u32
}

/// Serde default for `RateLimitEndpoints::check_user_rps`.
fn default_check_user_rps() -> u32 {
    10u32
}
216
217#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
218#[serde(deny_unknown_fields)]
219pub struct EnvelopeConfig {
220    pub pepper_hex: String,
221    pub keys: Vec<EnvelopeKeyConfig>,
222}
223
224impl EnvelopeConfig {
225    pub fn validate(&self) -> RegistryResult<()> {
226        if self.keys.is_empty() {
227            return Err(RegistryError::Config(
228                "at least one envelope key required".to_string(),
229            ));
230        }
231        if self.pepper_hex.len() != 64 {
232            return Err(RegistryError::Config(
233                "pepper_hex must be 32 bytes hex".to_string(),
234            ));
235        }
236        if hex::decode(&self.pepper_hex)
237            .map_err(|_| RegistryError::Config("invalid pepper_hex".to_string()))?
238            .len()
239            != 32
240        {
241            return Err(RegistryError::Config(
242                "pepper_hex must decode to 32 bytes".to_string(),
243            ));
244        }
245        let mut seen = std::collections::HashSet::new();
246        let mut active = 0usize;
247        for key in &self.keys {
248            key.validate()?;
249            if !seen.insert(key.kid_hex.clone()) {
250                return Err(RegistryError::Config("duplicate kid_hex".to_string()));
251            }
252            if key.active {
253                active = active.saturating_add(1);
254            }
255        }
256        if active == 0 {
257            return Err(RegistryError::Config(
258                "one active envelope key required".to_string(),
259            ));
260        }
261        Ok(())
262    }
263
264    pub fn pepper_bytes(&self) -> [u8; 32] {
265        let mut out = [0u8; 32];
266        if let Ok(bytes) = hex::decode(&self.pepper_hex) {
267            let len = bytes.len().min(32);
268            out[..len].copy_from_slice(&bytes[..len]);
269        }
270        out
271    }
272}
273
274#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
275#[serde(deny_unknown_fields)]
276pub struct EnvelopeKeyConfig {
277    pub kid_hex: String,
278    pub x25519_private_key_hex: String,
279    #[serde(default)]
280    pub active: bool,
281    #[serde(default)]
282    pub not_after_epoch_ms: Option<u64>,
283}
284
285impl EnvelopeKeyConfig {
286    pub fn validate(&self) -> RegistryResult<()> {
287        if self.kid_hex.len() != 16 {
288            return Err(RegistryError::Config(
289                "kid_hex must be 8 bytes hex".to_string(),
290            ));
291        }
292        if hex::decode(&self.kid_hex)
293            .map_err(|_| RegistryError::Config("invalid kid_hex".to_string()))?
294            .len()
295            != 8
296        {
297            return Err(RegistryError::Config(
298                "kid_hex must decode to 8 bytes".to_string(),
299            ));
300        }
301        if self.x25519_private_key_hex.len() != 64 {
302            return Err(RegistryError::Config(
303                "x25519_private_key_hex must be 32 bytes hex".to_string(),
304            ));
305        }
306        if hex::decode(&self.x25519_private_key_hex)
307            .map_err(|_| RegistryError::Config("invalid x25519_private_key_hex".to_string()))?
308            .len()
309            != 32
310        {
311            return Err(RegistryError::Config(
312                "x25519_private_key_hex must decode to 32 bytes".to_string(),
313            ));
314        }
315        if let Some(not_after) = self.not_after_epoch_ms {
316            if not_after <= current_time_ms() {
317                return Err(RegistryError::Config(
318                    "not_after_epoch_ms is in the past".to_string(),
319                ));
320            }
321        }
322        Ok(())
323    }
324}
325
326#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
327#[serde(deny_unknown_fields)]
328pub struct TlsConfig {
329    pub cert_pem_path: String,
330    pub key_pem_path: String,
331    #[serde(default)]
332    pub client_ca_pem_path: Option<String>,
333}
334
335impl TlsConfig {
336    pub fn validate(&self, mode: ServerMode) -> RegistryResult<()> {
337        if mode == ServerMode::Tls
338            && (self.cert_pem_path.is_empty() || self.key_pem_path.is_empty())
339        {
340            return Err(RegistryError::Config(
341                "cert_pem_path and key_pem_path are required for tls mode".to_string(),
342            ));
343        }
344        if self.client_ca_pem_path.is_some() && !cfg!(feature = "mtls") {
345            return Err(RegistryError::FeatureDisabled("mtls".to_string()));
346        }
347        Ok(())
348    }
349}
350
351#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Default)]
352#[serde(deny_unknown_fields)]
353pub struct StorageConfig {
354    #[serde(default = "default_storage_kind")]
355    pub kind: String,
356    #[serde(default = "default_storage_path")]
357    pub path: String,
358}
359
360impl StorageConfig {
361    pub fn validate(&self) -> RegistryResult<()> {
362        if self.kind != "sled" && self.kind != "memory" {
363            return Err(RegistryError::Config(
364                "storage.kind must be \"sled\" or \"memory\"".to_string(),
365            ));
366        }
367        if self.kind == "sled" {
368            if self.path.trim().is_empty() {
369                return Err(RegistryError::Config(
370                    "storage.path cannot be empty".to_string(),
371                ));
372            }
373            if !cfg!(feature = "persistence") {
374                return Err(RegistryError::FeatureDisabled("persistence".to_string()));
375            }
376        }
377        Ok(())
378    }
379}
380
/// Serde default for `StorageConfig::kind`.
fn default_storage_kind() -> String {
    String::from("sled")
}

/// Serde default for `StorageConfig::path` (relative to the working directory).
fn default_storage_path() -> String {
    String::from("./registry_db")
}
388
389#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
390#[serde(deny_unknown_fields)]
391pub struct PresenceConfig {
392    #[serde(default = "default_ttl_seconds")]
393    pub ttl_seconds: u64,
394    #[serde(default = "default_gc_interval_seconds")]
395    pub gc_interval_seconds: u64,
396}
397
398impl PresenceConfig {
399    pub fn validate(&self) -> RegistryResult<()> {
400        if self.ttl_seconds == 0 {
401            return Err(RegistryError::Config(
402                "ttl_seconds must be positive".to_string(),
403            ));
404        }
405        if self.gc_interval_seconds == 0 {
406            return Err(RegistryError::Config(
407                "gc_interval_seconds must be positive".to_string(),
408            ));
409        }
410        Ok(())
411    }
412}
413
414impl Default for PresenceConfig {
415    fn default() -> Self {
416        PresenceConfig {
417            ttl_seconds: 300,
418            gc_interval_seconds: 60,
419        }
420    }
421}
422
/// Serde default for `PresenceConfig::ttl_seconds` (five minutes).
fn default_ttl_seconds() -> u64 {
    5 * 60
}

/// Serde default for `PresenceConfig::gc_interval_seconds` (one minute).
fn default_gc_interval_seconds() -> u64 {
    60u64
}
430
431#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
432#[serde(deny_unknown_fields)]
433pub struct PowConfig {
434    #[serde(default)]
435    pub enabled: bool,
436    #[serde(default = "default_pow_difficulty")]
437    pub difficulty: u8,
438    #[serde(default = "default_pow_ttl_seconds")]
439    pub ttl_seconds: u64,
440}
441
442impl PowConfig {
443    pub fn validate(&self) -> RegistryResult<()> {
444        if !self.enabled {
445            return Ok(());
446        }
447        if self.difficulty == 0 || self.difficulty > 30 {
448            return Err(RegistryError::Config(
449                "difficulty must be between 1 and 30".to_string(),
450            ));
451        }
452        if self.ttl_seconds == 0 {
453            return Err(RegistryError::Config(
454                "ttl_seconds must be positive".to_string(),
455            ));
456        }
457        Ok(())
458    }
459}
460
461impl Default for PowConfig {
462    fn default() -> Self {
463        PowConfig {
464            enabled: false,
465            difficulty: default_pow_difficulty(),
466            ttl_seconds: default_pow_ttl_seconds(),
467        }
468    }
469}
470
/// Serde default for `PowConfig::difficulty`.
fn default_pow_difficulty() -> u8 {
    18u8
}

/// Serde default for `PowConfig::ttl_seconds` (two minutes).
fn default_pow_ttl_seconds() -> u64 {
    2 * 60
}
478
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns `0` if the clock reports a time before the epoch. The `u128`
/// millisecond count is clamped to `u64::MAX` rather than silently truncated
/// by an `as` cast.
fn current_time_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}