1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::net::SocketAddr;
4use std::path::PathBuf;
5use std::str::FromStr;
6use std::time::Duration;
7
8use clap::Parser;
9use config::{Config, ConfigError, Environment, File, Value, ValueKind};
10use directories::ProjectDirs;
11use libp2p::Multiaddr;
12use serde::Deserialize;
13use sierradb::bucket::{BucketId, PartitionId};
14use sierradb::cache::BLOCK_SIZE;
15use thiserror::Error;
16
// Command-line arguments. Every option is optional; values supplied here
// take the highest precedence when the configuration is assembled in
// `AppConfig::load`.
//
// NOTE(review): plain `//` comments are used deliberately — clap turns
// `///` doc comments on the struct/fields into --help and about text,
// which would change CLI output.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Args {
    // Data directory (overrides the platform-default data dir).
    #[arg(long, short = 'd')]
    pub dir: Option<String>,

    // Client-facing listen address, "host:port".
    #[arg(long)]
    pub client_address: Option<String>,

    // Cluster listen address (libp2p multiaddr).
    #[arg(long)]
    pub cluster_address: Option<String>,

    // Explicit config file path; when absent, default locations are tried.
    #[arg(short = 'c', long)]
    pub config: Option<PathBuf>,

    // Log filter/level string (consumed elsewhere; not part of AppConfig).
    #[arg(short = 'l', long)]
    pub log: Option<String>,

    // Total number of nodes in the cluster.
    #[arg(short = 'n', long)]
    pub node_count: Option<u32>,

    // This node's zero-based index; required when node_count > 1.
    #[arg(short = 'i', long)]
    pub node_index: Option<u32>,

    // Enable/disable mDNS peer discovery.
    #[arg(long)]
    pub mdns: Option<bool>,
}
53
/// Fully-resolved application configuration, deserialized from the merged
/// defaults / config file / per-node overrides / environment / CLI layers
/// assembled in [`AppConfig::load`].
#[derive(Debug, Deserialize)]
pub struct AppConfig {
    pub bucket: BucketConfig,
    pub cache: CacheConfig,
    /// Data directory.
    pub dir: PathBuf,
    pub heartbeat: HeartbeatConfig,
    pub network: NetworkConfig,
    pub node: NodeConfig,
    pub partition: PartitionConfig,
    pub replication: ReplicationConfig,
    pub segment: SegmentConfig,
    pub sync: SyncConfig,
    /// Optional thread-pool sizing; both pools default to `None`.
    #[serde(default)]
    pub threads: Threads,

    /// Raw per-node override tables from the config file's `nodes` array;
    /// also used as a node-count fallback by [`AppConfig::node_count`].
    pub nodes: Option<Vec<Value>>,
}
71
/// Bucket layout configuration.
#[derive(Debug, Deserialize)]
pub struct BucketConfig {
    /// Total number of buckets (default 4; must be non-zero).
    pub count: u16,
    /// Optional explicit bucket assignment for this node; when set, must
    /// have exactly `count` unique entries (checked in `validate`).
    pub ids: Option<Vec<BucketId>>,
}
77
/// Block-cache configuration.
#[derive(Debug, Deserialize)]
pub struct CacheConfig {
    /// Cache capacity in bytes (default 256 MiB).
    pub capacity_bytes: usize,
}
82
/// Cluster heartbeat timing.
#[derive(Debug, Deserialize)]
pub struct HeartbeatConfig {
    /// Interval between heartbeats, in milliseconds (default 1000).
    pub interval_ms: u64,
    /// Failure timeout in milliseconds (default 6000); must be strictly
    /// greater than `interval_ms` (checked in `validate`).
    pub timeout_ms: u64,
}
88
/// Network listeners and discovery settings.
#[derive(Debug, Deserialize)]
pub struct NetworkConfig {
    /// Whether cluster (peer-to-peer) networking is enabled.
    pub cluster_enabled: bool,
    /// libp2p multiaddress the cluster transport listens on.
    pub cluster_address: Multiaddr,
    /// Client-facing listen address, "host:port"; parsed as a
    /// `SocketAddr` in `validate`.
    pub client_address: String,
    /// Enable mDNS peer discovery (default false).
    pub mdns: bool,
}
96
/// This node's identity within the cluster.
#[derive(Debug, Deserialize)]
pub struct NodeConfig {
    /// Total node count; `None` falls back to the `nodes` array length
    /// (see `AppConfig::node_count`).
    pub count: Option<u32>,
    /// Zero-based index of this node; must be `< count`.
    pub index: u32,
}
102
/// Partition layout configuration.
#[derive(Debug, Deserialize)]
pub struct PartitionConfig {
    /// Total number of partitions (default 32; must be non-zero).
    pub count: u16,
    /// Optional explicit partition assignment for this node; when set,
    /// must have exactly `count` unique entries (checked in `validate`).
    pub ids: Option<Vec<PartitionId>>,
}
108
/// Replication pipeline settings.
#[derive(Debug, Deserialize)]
pub struct ReplicationConfig {
    /// Replication buffer capacity in entries (default 1000).
    pub buffer_size: usize,
    /// Buffer flush timeout in milliseconds (default 8000).
    pub buffer_timeout_ms: u64,
    /// Catch-up request timeout in milliseconds (default 2000).
    pub catchup_timeout_ms: u64,
    /// Number of replicas per bucket; defaults to node count clamped to
    /// 1..=3 (set in `AppConfig::load`), and must not exceed node count.
    pub factor: u8,
}
116
/// Storage segment settings.
#[derive(Debug, Deserialize)]
pub struct SegmentConfig {
    /// Segment size in bytes (default 256 MiB); bounded by `validate`
    /// to at least two cache blocks and at most 10 GiB.
    pub size_bytes: usize,
    /// Whether segment compression is enabled (default true).
    pub compression: bool,
}
122
/// Sync loop tuning.
#[derive(Debug, Deserialize)]
pub struct SyncConfig {
    /// Active sync interval in milliseconds (default 5).
    pub interval_ms: u64,
    /// Optional idle interval; when `None`, an effective value is derived
    /// in `AppConfig::effective_idle_interval` (10x active, capped at 1s).
    pub idle_interval_ms: Option<u64>,
    /// Maximum entries per sync batch (default 50).
    pub max_batch_size: usize,
    /// Minimum bytes to accumulate before syncing (default 4096).
    pub min_bytes: usize,
}
134
/// Optional thread-pool sizes; `None` means "use the runtime default".
/// Explicit values must be in 1..=1024 (checked in `validate`).
#[derive(Debug, Default, Deserialize)]
pub struct Threads {
    /// Read-pool size.
    pub read: Option<u16>,
    /// Write-pool size.
    pub write: Option<u16>,
}
140
impl AppConfig {
    /// Assembles the final configuration by layering sources from lowest
    /// to highest precedence:
    ///
    /// 1. hard-coded defaults (plus a platform data dir from `ProjectDirs`),
    /// 2. config file (`--config` path, or optional `sierra.*` in the CWD
    ///    and the platform config dir),
    /// 3. this node's entry from the file's `nodes` array, flattened into
    ///    dotted-key overrides,
    /// 4. `SIERRA`-prefixed environment variables,
    /// 5. command-line arguments.
    pub fn load(args: Args) -> Result<Self, ConfigError> {
        // Platform-specific dirs ("io.sierradb.sierradb"); None on
        // platforms where no home directory can be determined.
        let project_dirs = ProjectDirs::from("io", "sierradb", "sierradb");

        let mut builder = Config::builder();

        if let Some(dirs) = &project_dirs {
            builder = builder.set_default(
                "dir",
                dirs.data_dir().join("db").to_string_lossy().into_owned(),
            )?
        }

        // An explicit --config path is required to exist; the fallback
        // locations are optional.
        if let Some(config_path) = args.config {
            builder = builder.add_source(File::from(config_path));
        } else {
            builder = builder.add_source(File::with_name("sierra").required(false));
            if let Some(dirs) = &project_dirs {
                builder = builder
                    .add_source(File::from(dirs.config_dir().join("sierra")).required(false));
            }
        }

        // Snapshot of the file-provided values, taken before defaults are
        // applied, so the raw `nodes` array can be inspected below.
        let overrides = builder.build_cloned()?;

        builder = builder
            .set_default("bucket.count", 4)?
            .set_default("cache.capacity_bytes", 256 * 1024 * 1024)?
            .set_default("heartbeat.interval_ms", 1000)?
            .set_default("heartbeat.timeout_ms", 6000)?
            .set_default("network.cluster_enabled", true)?
            .set_default("network.cluster_address", "/ip4/0.0.0.0/tcp/0")?
            .set_default("network.client_address", "0.0.0.0:9090")?
            .set_default("network.mdns", false)?
            .set_default("partition.count", 32)?
            .set_default("replication.buffer_size", 1000)?
            .set_default("replication.buffer_timeout_ms", 8000)?
            .set_default("replication.catchup_timeout_ms", 2000)?
            .set_default("segment.size_bytes", 256 * 1024 * 1024)?
            .set_default("segment.compression", true)?
            .set_default("sync.interval_ms", 5)?
            .set_default("sync.max_batch_size", 50)?
            .set_default("sync.min_bytes", 4096)?;

        {
            // If the file declares a `nodes` array, default node.count to
            // its length and splice in the entry for this node's index as
            // flattened overrides.
            let mut nodes = overrides.get_array("nodes").ok().unwrap_or_default();
            let nodes_count = if nodes.is_empty() {
                1
            } else {
                nodes.len() as u32
            };
            builder = builder.set_default("node.count", nodes_count)?;

            // CLI --node-index wins over node.index from the file.
            let node_index = args
                .node_index
                .or_else(|| overrides.get_int("node.index").map(|n| n as u32).ok());
            if let Some(node_index) = node_index
                && (node_index as usize) < nodes.len()
            {
                let overrides = nodes.remove(node_index as usize);
                for (key, value) in flatten_value(overrides) {
                    builder = builder.set_override(key, value)?;
                }
            }
        }

        builder = builder.add_source(Environment::with_prefix("SIERRA"));

        // Highest precedence: explicit command-line flags (None leaves the
        // lower layers untouched).
        builder = builder
            .set_override_option("dir", args.dir)?
            .set_override_option("network.cluster_address", args.cluster_address)?
            .set_override_option("network.client_address", args.client_address)?
            .set_override_option("network.mdns", args.mdns)?
            .set_override_option("node.index", args.node_index)?
            .set_override_option("node.count", args.node_count)?;

        {
            // Derive replication.factor from the resolved node count and
            // enforce that node.index is present for multi-node setups.
            let temp_config = builder.build_cloned()?;
            let node_count = temp_config.get::<u32>("node.count").unwrap_or(1);
            builder = builder.set_default("replication.factor", node_count.clamp(1, 3))?;

            let node_index_set =
                args.node_index.is_some() || temp_config.get::<u32>("node.index").is_ok();

            if node_count == 1 && !node_index_set {
                // Single-node deployments don't need an explicit index.
                builder = builder.set_override("node.index", 0)?;
            } else if node_count > 1 && !node_index_set {
                return Err(ConfigError::Message(
                    "node.index is required when node.count > 1 (use --node-index or set in config file)".to_string(),
                ));
            }
        }

        let config: AppConfig = builder.build()?.try_deserialize()?;

        Ok(config)
    }
263
264 pub fn validate(&self) -> Result<Vec<ValidationError>, ConfigError> {
265 let mut errs = Vec::new();
266
267 if self.bucket.count == 0 {
269 errs.push(ValidationError::BucketCountZero);
270 }
271 if let Some(ids) = &self.bucket.ids {
272 if ids.len() != self.bucket.count as usize {
273 errs.push(ValidationError::BucketIdCountMismatch {
274 expected: self.bucket.count,
275 actual: ids.len(),
276 });
277 }
278 let unique_ids: HashSet<_> = ids.iter().collect();
279 if unique_ids.len() != ids.len() {
280 errs.push(ValidationError::DuplicateBucketIds);
281 }
282 }
283
284 if self.heartbeat.interval_ms == 0 {
286 errs.push(ValidationError::HeartbeatIntervalZero);
287 }
288 if self.heartbeat.timeout_ms == 0 {
289 errs.push(ValidationError::HeartbeatTimeoutZero);
290 }
291 if self.heartbeat.timeout_ms <= self.heartbeat.interval_ms {
292 errs.push(ValidationError::HeartbeatTimeoutTooShort {
293 interval: self.heartbeat.interval_ms,
294 timeout: self.heartbeat.timeout_ms,
295 });
296 }
297
298 if SocketAddr::from_str(&self.network.client_address).is_err() {
300 errs.push(ValidationError::InvalidClientAddress {
301 address: self.network.client_address.clone(),
302 });
303 }
304
305 if let Some(count) = self.node.count {
307 if count == 0 {
308 errs.push(ValidationError::NodeCountZero);
309 }
310 if self.node.index >= count {
311 errs.push(ValidationError::NodeIndexOutOfBounds {
312 index: self.node.index,
313 count,
314 });
315 }
316 }
317
318 if !self.network.cluster_enabled {
320 if let Some(count) = self.node.count
322 && count > 1
323 {
324 errs.push(ValidationError::MultipleNodesWithoutCluster { count });
325 }
326 }
327
328 if self.partition.count == 0 {
330 errs.push(ValidationError::PartitionCountZero);
331 }
332 if let Some(ids) = &self.partition.ids {
333 if ids.len() != self.partition.count as usize {
334 errs.push(ValidationError::PartitionIdCountMismatch {
335 expected: self.partition.count,
336 actual: ids.len(),
337 });
338 }
339 let unique_ids: HashSet<_> = ids.iter().collect();
340 if unique_ids.len() != ids.len() {
341 errs.push(ValidationError::DuplicatePartitionIds);
342 }
343 }
344
345 if self.replication.factor == 0 {
347 errs.push(ValidationError::ReplicationFactorZero);
348 }
349 if self.replication.buffer_size == 0 {
350 errs.push(ValidationError::ReplicationBufferSizeZero);
351 }
352 if self.replication.buffer_timeout_ms == 0 {
353 errs.push(ValidationError::ReplicationBufferTimeoutZero);
354 }
355 if self.replication.catchup_timeout_ms == 0 {
356 errs.push(ValidationError::ReplicationCatchupTimeoutZero);
357 }
358
359 let node_count = self.node_count()?;
361 if self.replication.factor as usize > node_count {
362 errs.push(ValidationError::ReplicationFactorExceedsNodeCount {
363 factor: self.replication.factor,
364 node_count,
365 });
366 }
367
368 if self.segment.size_bytes == 0 {
370 errs.push(ValidationError::SegmentSizeZero);
371 }
372 const MIN_SEGMENT_SIZE: usize = BLOCK_SIZE * 2; const MAX_SEGMENT_SIZE: usize = 1024 * 1024 * 1024 * 10; if self.segment.size_bytes < MIN_SEGMENT_SIZE {
375 errs.push(ValidationError::SegmentSizeTooSmall {
376 size: self.segment.size_bytes,
377 min: MIN_SEGMENT_SIZE,
378 });
379 }
380 if self.segment.size_bytes > MAX_SEGMENT_SIZE {
381 errs.push(ValidationError::SegmentSizeTooLarge {
382 size: self.segment.size_bytes,
383 max: MAX_SEGMENT_SIZE,
384 });
385 }
386
387 if let Some(idle_interval_ms) = self.sync.idle_interval_ms
389 && self.sync.interval_ms <= idle_interval_ms
390 {
391 errs.push(ValidationError::SyncIdleIntervalTooSmall {
392 interval_ms: self.sync.interval_ms,
393 idle_interval_ms,
394 });
395 }
396
397 if let Some(read_threads) = self.threads.read {
399 if read_threads == 0 {
400 errs.push(ValidationError::ReadThreadsZero);
401 }
402 const MAX_THREADS: u16 = 1024;
403 if read_threads > MAX_THREADS {
404 errs.push(ValidationError::TooManyReadThreads {
405 count: read_threads,
406 max: MAX_THREADS,
407 });
408 }
409 }
410 if let Some(write_threads) = self.threads.write {
411 if write_threads == 0 {
412 errs.push(ValidationError::WriteThreadsZero);
413 }
414 const MAX_THREADS: u16 = 1024;
415 if write_threads > MAX_THREADS {
416 errs.push(ValidationError::TooManyWriteThreads {
417 count: write_threads,
418 max: MAX_THREADS,
419 });
420 }
421 }
422
423 if (self.partition.count as usize) < node_count {
425 errs.push(ValidationError::TooFewPartitionsForNodes {
426 partitions: self.partition.count,
427 nodes: node_count,
428 });
429 }
430 if self.partition.count < self.bucket.count {
431 errs.push(ValidationError::TooFewPartitionsForBuckets {
432 buckets: self.bucket.count,
433 partitions: self.partition.count,
434 });
435 }
436
437 Ok(errs)
438 }
439
    /// Returns the set of bucket IDs this node holds, as primary or replica.
    ///
    /// With explicit `bucket.ids` configured those are returned verbatim.
    /// Otherwise buckets are split into contiguous per-node ranges (the
    /// first `bucket.count % node_count` nodes get one extra) and this node
    /// takes its own range plus the ranges of the nodes preceding it in
    /// ring order, up to the replication factor.
    ///
    /// # Errors
    ///
    /// Fails when the node count cannot be determined (see
    /// [`Self::node_count`]).
    pub fn assigned_buckets(&self) -> Result<HashSet<BucketId>, ConfigError> {
        match &self.bucket.ids {
            Some(ids) => Ok(ids.iter().copied().collect()),
            None => {
                let node_count = self.node_count()?;
                // Never replicate across more nodes than exist.
                let effective_replication_factor =
                    (self.replication.factor as usize).min(node_count);
                let mut assigned = HashSet::new();

                let buckets_per_node = self.bucket.count as usize / node_count;
                let extra_buckets = self.bucket.count as usize % node_count;

                for replica_offset in 0..effective_replication_factor {
                    // Walk backwards around the node ring: offset 0 is this
                    // node's own range, offset k the range of the node k
                    // places before it (adding node_count avoids underflow
                    // before the modulo).
                    let primary_node =
                        (self.node.index as usize + node_count - replica_offset) % node_count;

                    // Start of primary_node's contiguous range; each of the
                    // first `extra_buckets` nodes holds one extra bucket.
                    let start = primary_node * buckets_per_node + primary_node.min(extra_buckets);
                    let extra = if primary_node < extra_buckets { 1 } else { 0 };
                    let count = buckets_per_node + extra;

                    for bucket_id in start..(start + count) {
                        // bucket_id < bucket.count (a u16), so the
                        // conversion into BucketId cannot fail here.
                        assigned.insert(bucket_id.try_into().unwrap());
                    }
                }

                Ok(assigned)
            }
        }
    }
473
474 pub fn assigned_partitions(&self, bucket_ids: &HashSet<BucketId>) -> HashSet<PartitionId> {
475 match &self.partition.ids {
476 Some(ids) => ids.iter().copied().collect(),
477 None => (0..self.partition.count)
478 .filter(|p| bucket_ids.contains(&(p % self.bucket.count)))
479 .collect(),
480 }
481 }
482
483 pub fn node_count(&self) -> Result<usize, ConfigError> {
484 self.node
485 .count
486 .map(|node_count| node_count as usize)
487 .or(self.nodes.as_ref().map(|nodes| nodes.len()))
488 .ok_or_else(|| ConfigError::Message("node.count not specified".to_string()))
489 }
490
491 pub fn effective_idle_interval(&self) -> Duration {
492 let active_interval = Duration::from_millis(self.sync.interval_ms);
493 self.sync
494 .idle_interval_ms
495 .map(Duration::from_millis)
496 .unwrap_or_else(|| {
497 (active_interval * 10).min(Duration::from_secs(1))
499 })
500 }
501}
502
503impl fmt::Display for AppConfig {
504 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
505 writeln!(f, "bucket.count = {}", self.bucket.count)?;
506 match &self.bucket.ids {
507 Some(ids) => writeln!(
508 f,
509 "bucket.ids = [{}]",
510 ids.iter()
511 .map(|id| id.to_string())
512 .collect::<Vec<_>>()
513 .join(", ")
514 )?,
515 None => writeln!(f, "bucket.ids = <none>")?,
516 }
517
518 writeln!(f, "cache.capacity_bytes = {}", self.cache.capacity_bytes)?;
519
520 writeln!(f, "dir = {}", self.dir.to_string_lossy())?;
521
522 writeln!(f, "heartbeat.interval_ms = {}", self.heartbeat.interval_ms)?;
523 writeln!(f, "heartbeat.timeout_ms = {}", self.heartbeat.timeout_ms)?;
524
525 writeln!(
526 f,
527 "network.cluster_enabled = {}",
528 self.network.cluster_enabled
529 )?;
530 writeln!(
531 f,
532 "network.cluster_address = {}",
533 self.network.cluster_address
534 )?;
535 writeln!(
536 f,
537 "network.client_address = {}",
538 self.network.client_address
539 )?;
540 writeln!(f, "network.mdns = {}", self.network.mdns)?;
541
542 match self.node_count() {
543 Ok(count) => writeln!(f, "node.count = {count}")?,
544 Err(_) => writeln!(f, "node.count = <none>")?,
545 }
546 writeln!(f, "node.index = {}", self.node.index)?;
547
548 writeln!(f, "partition.count = {}", self.partition.count)?;
549 match &self.partition.ids {
550 Some(ids) => writeln!(
551 f,
552 "partition.ids = [{}]",
553 ids.iter()
554 .map(|id| id.to_string())
555 .collect::<Vec<_>>()
556 .join(", ")
557 )?,
558 None => writeln!(f, "partition.ids = <none>")?,
559 }
560
561 writeln!(
562 f,
563 "replication.buffer_size = {}",
564 self.replication.buffer_size
565 )?;
566 writeln!(
567 f,
568 "replication.buffer_timeout_ms = {}",
569 self.replication.buffer_timeout_ms
570 )?;
571 writeln!(f, "replication.factor = {}", self.replication.factor)?;
572
573 writeln!(f, "segment.size_bytes = {}", self.segment.size_bytes)?;
574 writeln!(f, "segment.compression = {}", self.segment.compression)?;
575
576 writeln!(f, "sync.interval_ms = {}", self.sync.interval_ms)?;
577 writeln!(
578 f,
579 "sync.idle_interval_ms = {}",
580 self.effective_idle_interval().as_millis()
581 )?;
582 writeln!(f, "sync.max_batch_size = {}", self.sync.max_batch_size)?;
583 writeln!(f, "sync.min_bytes = {}", self.sync.min_bytes)?;
584
585 match self.threads.read {
586 Some(count) => writeln!(f, "threads.read = {count}")?,
587 None => writeln!(f, "threads.read = <none>")?,
588 }
589 match self.threads.write {
590 Some(count) => write!(f, "threads.write = {count}")?,
591 None => write!(f, "threads.write = <none>")?,
592 }
593
594 Ok(())
595 }
596}
597
598fn flatten_value(value: Value) -> HashMap<String, Value> {
599 let mut result = HashMap::new();
600 flatten_value_recursive(value, "", &mut result);
601 result
602}
603
604fn flatten_value_recursive(value: Value, prefix: &str, result: &mut HashMap<String, Value>) {
605 match value.kind {
606 ValueKind::Table(table) => {
607 for (key, val) in table {
608 let new_prefix = if prefix.is_empty() {
609 key
610 } else {
611 format!("{prefix}.{key}")
612 };
613
614 match val.kind {
615 ValueKind::Table(_) => {
616 flatten_value_recursive(val, &new_prefix, result);
617 }
618 _ => {
619 result.insert(new_prefix, val);
620 }
621 }
622 }
623 }
624 _ => {
625 if !prefix.is_empty() {
626 result.insert(prefix.to_string(), value);
627 }
628 }
629 }
630}
631
/// Individual problems reported by [`AppConfig::validate`]; the variants
/// mirror the validation checks (grouped by config section) and their
/// `Display` text comes from the `#[error]` attributes.
#[derive(Clone, Debug, Error)]
pub enum ValidationError {
    // Bucket checks.
    #[error("bucket count cannot be zero")]
    BucketCountZero,
    #[error("bucket ID count mismatch: expected {expected}, got {actual}")]
    BucketIdCountMismatch { expected: u16, actual: usize },
    #[error("duplicate bucket IDs found")]
    DuplicateBucketIds,

    // Heartbeat checks.
    #[error("heartbeat interval cannot be zero")]
    HeartbeatIntervalZero,
    #[error("heartbeat timeout cannot be zero")]
    HeartbeatTimeoutZero,
    #[error("heartbeat timeout ({timeout}ms) must be greater than interval ({interval}ms)")]
    HeartbeatTimeoutTooShort { interval: u64, timeout: u64 },

    // Network checks.
    #[error("invalid client address: {address}")]
    InvalidClientAddress { address: String },

    // Node checks.
    #[error("node count cannot be zero")]
    NodeCountZero,
    #[error("node index {index} is out of bounds for count {count}")]
    NodeIndexOutOfBounds { index: u32, count: u32 },
    #[error("multiple nodes ({count}) configured but cluster is disabled")]
    MultipleNodesWithoutCluster { count: u32 },

    // Partition checks.
    #[error("partition count cannot be zero")]
    PartitionCountZero,
    #[error("partition ID count mismatch: expected {expected}, got {actual}")]
    PartitionIdCountMismatch { expected: u16, actual: usize },
    #[error("duplicate partition IDs found")]
    DuplicatePartitionIds,

    // Replication checks.
    #[error("replication factor cannot be zero")]
    ReplicationFactorZero,
    #[error("replication factor {factor} exceeds node count {node_count}")]
    ReplicationFactorExceedsNodeCount { factor: u8, node_count: usize },
    #[error("replication buffer size cannot be zero")]
    ReplicationBufferSizeZero,
    #[error("replication buffer timeout cannot be zero")]
    ReplicationBufferTimeoutZero,
    #[error("replication catchup timeout cannot be zero")]
    ReplicationCatchupTimeoutZero,

    // Segment checks.
    #[error("segment size cannot be zero")]
    SegmentSizeZero,
    #[error("segment size {size} is too small (minimum: {min} bytes)")]
    SegmentSizeTooSmall { size: usize, min: usize },
    #[error("segment size {size} is too large (maximum: {max} bytes)")]
    SegmentSizeTooLarge { size: usize, max: usize },

    // Sync checks.
    #[error(
        "idle sync interval {idle_interval_ms} cannot be less than sync interval {interval_ms}"
    )]
    SyncIdleIntervalTooSmall {
        interval_ms: u64,
        idle_interval_ms: u64,
    },

    // Thread-pool checks.
    #[error("read thread count cannot be zero")]
    ReadThreadsZero,
    #[error("write thread count cannot be zero")]
    WriteThreadsZero,
    #[error("too many read threads: {count} (maximum: {max})")]
    TooManyReadThreads { count: u16, max: u16 },
    #[error("too many write threads: {count} (maximum: {max})")]
    TooManyWriteThreads { count: u16, max: u16 },

    // Cross-cutting layout checks.
    #[error("too few partitions ({partitions}) for {nodes} nodes")]
    TooFewPartitionsForNodes { partitions: u16, nodes: usize },
    #[error("too few partitions ({partitions}) for {buckets} buckets")]
    TooFewPartitionsForBuckets { buckets: u16, partitions: u16 },
}