1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
4use std::path::{Path, PathBuf};
5
6use ipnet::{Ipv4Net, Ipv6Net};
7use secrecy::{ExposeSecret, SecretString};
8use serde::{Deserialize, Serialize};
9
10use crate::{ConfigError, Result};
11
/// Top-level server configuration, deserialized from a TOML file by
/// [`ServerConfig::load`] and written back by [`ServerConfig::save`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// `[server]` section (required).
    pub server: ServerSettings,
    /// `[network]` section (required, although every key inside it has a default).
    pub network: NetworkSettings,
    /// `[security]` section (required, although every key inside it has a default).
    pub security: SecuritySettings,
    /// Optional `[oauth]` section; `None` when the section is absent.
    #[serde(default)]
    pub oauth: Option<OAuthSettings>,
    /// `[logging]` section; `LoggingSettings::default()` when absent.
    #[serde(default)]
    pub logging: LoggingSettings,
    /// `[admin]` section; `AdminSettings::default()` when absent.
    #[serde(default)]
    pub admin: AdminSettings,
    /// `[audit]` section; `AuditSettings::default()` when absent.
    #[serde(default)]
    pub audit: AuditSettings,
}
34
/// `[server]` section: listen addresses, public host name, capacity and data directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSettings {
    /// Main listen address (default `0.0.0.0:1194`); `validate` requires a non-zero port.
    #[serde(default = "default_listen_addr")]
    pub listen_addr: SocketAddr,
    /// Optional TCP listen address; when set, `validate` requires a non-zero port.
    #[serde(default)]
    pub tcp_listen_addr: Option<SocketAddr>,
    /// Externally reachable host name or IP address. Required; must be non-empty
    /// and a valid IP or host name according to `ServerConfig::validate`.
    pub public_host: String,
    /// Transport protocol name (default `"udp"`). Not checked by `validate`.
    #[serde(default = "default_protocol")]
    pub protocol: String,
    /// Maximum number of clients (default 1000); `validate` requires 1..=100000.
    #[serde(default = "default_max_clients")]
    pub max_clients: u32,
    /// Directory holding `ca.crt`, `ca.key`, `server.crt`, `server.key`,
    /// `ta.key` and `dh.pem` (default `/var/lib/corevpn`).
    #[serde(default = "default_data_dir")]
    pub data_dir: PathBuf,
}
56
/// Serde default for `server.listen_addr`: every IPv4 interface, port 1194.
fn default_listen_addr() -> SocketAddr {
    SocketAddr::from((Ipv4Addr::UNSPECIFIED, 1194))
}

/// Serde default for `server.protocol`.
fn default_protocol() -> String {
    String::from("udp")
}

/// Serde default for `server.max_clients`.
fn default_max_clients() -> u32 {
    1_000
}

/// Serde default for `server.data_dir`.
fn default_data_dir() -> PathBuf {
    "/var/lib/corevpn".into()
}
72
/// `[network]` section: VPN addressing, DNS and routes pushed to clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkSettings {
    /// IPv4 VPN subnet in CIDR notation (default `10.8.0.0/24`); must parse as `Ipv4Net`.
    #[serde(default = "default_subnet")]
    pub subnet: String,
    /// Optional IPv6 VPN subnet in CIDR notation; must parse as `Ipv6Net` when set.
    #[serde(default)]
    pub subnet_v6: Option<String>,
    /// DNS server addresses (default `1.1.1.1`, `1.0.0.1`); each must parse as `IpAddr`.
    #[serde(default = "default_dns")]
    pub dns: Vec<String>,
    /// DNS search domains; each must be non-empty and at most 253 bytes.
    #[serde(default)]
    pub dns_search: Vec<String>,
    /// Routes in CIDR notation; each must be a valid IPv4 or IPv6 network.
    #[serde(default)]
    pub push_routes: Vec<String>,
    /// Redirect-gateway flag (default `true`).
    #[serde(default = "default_redirect_gateway")]
    pub redirect_gateway: bool,
    /// Tunnel MTU (default 1420); `validate` requires 68..=1500.
    #[serde(default = "default_mtu")]
    pub mtu: u16,
}
98
/// Serde default for `network.subnet`.
fn default_subnet() -> String {
    String::from("10.8.0.0/24")
}

/// Serde default for `network.dns`: the two public resolvers listed below.
fn default_dns() -> Vec<String> {
    ["1.1.1.1", "1.0.0.1"]
        .iter()
        .map(|addr| addr.to_string())
        .collect()
}

/// Serde default for `network.redirect_gateway`.
fn default_redirect_gateway() -> bool {
    true
}

/// Serde default for `network.mtu`.
fn default_mtu() -> u16 {
    1_420
}
114
/// `[security]` section: cipher/TLS options, certificate lifetimes and renegotiation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecuritySettings {
    /// Cipher name (default `"chacha20-poly1305"`). Not checked by `validate`.
    #[serde(default = "default_cipher")]
    pub cipher: String,
    /// Minimum TLS version string (default `"1.3"`). Not checked by `validate`.
    #[serde(default = "default_tls_version")]
    pub tls_min_version: String,
    /// tls-auth flag (default `true`).
    #[serde(default = "default_true")]
    pub tls_auth: bool,
    /// tls-crypt flag (default `false`).
    #[serde(default)]
    pub tls_crypt: bool,
    /// Certificate lifetime in days (default 3650); must be > 0.
    #[serde(default = "default_cert_lifetime")]
    pub cert_lifetime_days: u32,
    /// Client certificate lifetime in days (default 90); must be > 0 and
    /// not exceed `cert_lifetime_days`.
    #[serde(default = "default_client_cert_lifetime")]
    pub client_cert_lifetime_days: u32,
    /// Renegotiation interval in seconds (default 3600); must be > 0.
    #[serde(default = "default_reneg_sec")]
    pub reneg_sec: u32,
    /// Perfect-forward-secrecy flag (default `true`).
    #[serde(default = "default_true")]
    pub pfs: bool,
}
143
/// Serde default for `security.cipher`.
fn default_cipher() -> String {
    String::from("chacha20-poly1305")
}

/// Serde default for `security.tls_min_version`.
fn default_tls_version() -> String {
    String::from("1.3")
}

/// Shared serde default for boolean options that are enabled unless configured off.
fn default_true() -> bool {
    true
}

/// Serde default for `security.cert_lifetime_days`: ten 365-day years.
fn default_cert_lifetime() -> u32 {
    10 * 365
}

/// Serde default for `security.client_cert_lifetime_days`.
fn default_client_cert_lifetime() -> u32 {
    90
}

/// Serde default for `security.reneg_sec`: one hour, in seconds.
fn default_reneg_sec() -> u32 {
    60 * 60
}
167
/// `[oauth]` section: OAuth client registration and access restrictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthSettings {
    /// Whether OAuth is enabled (default `false`). When `true`, `validate`
    /// requires non-empty `client_id` and `client_secret`.
    #[serde(default)]
    pub enabled: bool,
    /// Provider name (required, free-form string).
    pub provider: String,
    /// OAuth client ID (required key; must be non-empty when enabled).
    pub client_id: String,
    /// OAuth client secret, wrapped in `SecretString` on deserialization.
    /// Never serialized.
    ///
    /// NOTE(review): because this is `skip_serializing` but has no serde
    /// default, a file written by `ServerConfig::save` with an `[oauth]`
    /// section cannot be read back by `ServerConfig::load` ("missing field")
    /// — confirm this is intended.
    #[serde(skip_serializing, deserialize_with = "deserialize_secret_string")]
    pub client_secret: SecretString,
    /// Optional issuer URL; presumably only needed by some providers — confirm.
    #[serde(default)]
    pub issuer_url: Option<String>,
    /// Optional tenant ID; presumably provider-specific (e.g. Azure AD) — confirm.
    #[serde(default)]
    pub tenant_id: Option<String>,
    /// Optional domain; presumably provider-specific — confirm.
    #[serde(default)]
    pub domain: Option<String>,
    /// Allowed domains (default empty). Enforcement is not in this file.
    #[serde(default)]
    pub allowed_domains: Vec<String>,
    /// Required groups (default empty). Enforcement is not in this file.
    #[serde(default)]
    pub required_groups: Vec<String>,
}
197
/// Where connection logs are kept; written in configuration as the snake_case
/// variant name (`"none"`, `"memory"`, `"file"`, `"database"`, `"both"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ConnectionLogMode {
    /// No connection logging.
    None,
    /// In-memory logging; the default mode.
    #[default]
    Memory,
    /// Log to a file (see `LoggingSettings::connection_log_file`).
    File,
    /// Log to a database (see `LoggingSettings::connection_log_db`).
    Database,
    /// Presumably file and database together — confirm against the logger implementation.
    Both,
}
216
217#[derive(Debug, Clone, Serialize, Deserialize, Default)]
219pub struct ConnectionLogEvents {
220 #[serde(default)]
222 pub attempts: bool,
223 #[serde(default = "default_true")]
225 pub connects: bool,
226 #[serde(default = "default_true")]
228 pub disconnects: bool,
229 #[serde(default)]
231 pub auth_events: bool,
232 #[serde(default)]
234 pub transfer_stats: bool,
235 #[serde(default)]
237 pub ip_changes: bool,
238 #[serde(default)]
240 pub renegotiations: bool,
241}
242
/// Privacy options for connection logs; every option defaults to `false`
/// (derived `Default` and per-field serde defaults agree).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ConnectionLogAnonymization {
    /// Hash client IP addresses.
    #[serde(default)]
    pub hash_client_ips: bool,
    /// Truncate client IP addresses.
    #[serde(default)]
    pub truncate_client_ips: bool,
    /// Hash usernames.
    #[serde(default)]
    pub hash_usernames: bool,
    /// Round timestamps (granularity is decided by the logger, not this file).
    #[serde(default)]
    pub round_timestamps: bool,
    /// Aggregate transfer statistics.
    #[serde(default)]
    pub aggregate_transfer_stats: bool,
}
263
/// Retention and rotation limits for connection logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionLogRetention {
    /// Retention period in days (default 7).
    #[serde(default = "default_retention_days")]
    pub days: u32,
    /// Maximum size of one log file in MB (default 100).
    #[serde(default = "default_max_log_size_mb")]
    pub max_file_size_mb: u32,
    /// Maximum number of log files kept (default 5).
    #[serde(default = "default_max_log_files")]
    pub max_files: u32,
    /// Purge expired entries automatically (default `true`).
    #[serde(default = "default_true")]
    pub auto_purge: bool,
    /// Securely delete purged data (default `false`).
    #[serde(default)]
    pub secure_delete: bool,
}
283
/// Serde default for `retention.days`: one week.
fn default_retention_days() -> u32 {
    7
}

/// Serde default for `retention.max_file_size_mb`.
fn default_max_log_size_mb() -> u32 {
    100
}

/// Serde default for `retention.max_files`.
fn default_max_log_files() -> u32 {
    5
}
295
296impl Default for ConnectionLogRetention {
297 fn default() -> Self {
298 Self {
299 days: default_retention_days(),
300 max_file_size_mb: default_max_log_size_mb(),
301 max_files: default_max_log_files(),
302 auto_purge: true,
303 secure_delete: false,
304 }
305 }
306}
307
/// `[logging]` section: application log output plus connection-log options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingSettings {
    /// Log level string (default `"info"`).
    #[serde(default = "default_log_level")]
    pub level: String,
    /// Log format string (default `"pretty"`).
    #[serde(default = "default_log_format")]
    pub format: String,
    /// Optional log file path.
    #[serde(default)]
    pub file: Option<PathBuf>,

    // ---- Connection logging ----
    /// Connection log destination (default `ConnectionLogMode::Memory`).
    #[serde(default)]
    pub connection_mode: ConnectionLogMode,
    /// Connection log file path, relevant to the file-based modes.
    #[serde(default)]
    pub connection_log_file: Option<PathBuf>,
    /// Connection log database path, relevant to the database-based modes.
    #[serde(default)]
    pub connection_log_db: Option<PathBuf>,
    /// Which connection events to record.
    #[serde(default)]
    pub connection_events: ConnectionLogEvents,
    /// Anonymization options for connection logs.
    #[serde(default)]
    pub anonymization: ConnectionLogAnonymization,
    /// Retention/rotation limits for connection logs.
    #[serde(default)]
    pub retention: ConnectionLogRetention,

    /// Legacy on/off switch (default `true`), kept for old configuration files.
    #[serde(default = "default_true")]
    #[deprecated(note = "Use connection_mode instead")]
    pub log_connections: bool,
}
349
350impl Default for LoggingSettings {
351 fn default() -> Self {
352 #[allow(deprecated)]
353 Self {
354 level: default_log_level(),
355 format: default_log_format(),
356 file: None,
357 connection_mode: ConnectionLogMode::default(),
358 connection_log_file: None,
359 connection_log_db: None,
360 connection_events: ConnectionLogEvents::default(),
361 anonymization: ConnectionLogAnonymization::default(),
362 retention: ConnectionLogRetention::default(),
363 log_connections: true,
364 }
365 }
366}
367
/// Serde default for `logging.level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `logging.format`.
fn default_log_format() -> String {
    String::from("pretty")
}
375
376#[derive(Debug, Clone, Serialize, Deserialize)]
378pub struct AdminSettings {
379 #[serde(default)]
381 pub enabled: bool,
382 #[serde(default = "default_admin_addr")]
384 pub listen_addr: SocketAddr,
385 #[serde(skip_serializing, deserialize_with = "deserialize_optional_secret_string")]
387 pub api_key: Option<SecretString>,
388 #[serde(default)]
390 pub allowed_ips: Vec<String>,
391}
392
/// Serde default for `admin.listen_addr`: IPv4 loopback, port 8443.
fn default_admin_addr() -> SocketAddr {
    SocketAddr::from((Ipv4Addr::LOCALHOST, 8443))
}
396
397impl Default for AdminSettings {
398 fn default() -> Self {
399 Self {
400 enabled: false,
401 listen_addr: default_admin_addr(),
402 api_key: None,
403 allowed_ips: vec!["127.0.0.1".to_string()],
404 }
405 }
406}
407
/// `[audit]` section: audit event pipeline and its output sinks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditSettings {
    /// Whether audit logging is enabled (default `false`).
    #[serde(default)]
    pub enabled: bool,

    /// Audit event buffer size (default 10000).
    #[serde(default = "default_audit_buffer")]
    pub buffer_size: usize,

    /// Include the source IP in audit events (default `true`).
    #[serde(default = "default_true")]
    pub include_source_ip: bool,

    /// Include the user identity in audit events (default `true`).
    #[serde(default = "default_true")]
    pub include_user_identity: bool,

    /// Hash sensitive fields in audit events (default `false`).
    #[serde(default)]
    pub hash_sensitive_fields: bool,

    /// Configured sinks (default none).
    #[serde(default)]
    pub sinks: Vec<AuditSinkConfig>,
}
435
/// Serde default for `audit.buffer_size`.
fn default_audit_buffer() -> usize {
    10_000
}
437
438impl Default for AuditSettings {
439 fn default() -> Self {
440 Self {
441 enabled: false,
442 buffer_size: 10000,
443 include_source_ip: true,
444 include_user_identity: true,
445 hash_sensitive_fields: false,
446 sinks: Vec::new(),
447 }
448 }
449}
450
/// One audit output destination. In configuration the variant is selected by
/// the `type` key, whose value is the snake_case variant name
/// (e.g. `type = "aws_cloudwatch"`, `type = "syslog"`).
///
/// NOTE(review): the credential-like fields below (`shared_key`, `policy_key`,
/// `private_key`, `password`, `api_key`, `token`, `sasl_password`,
/// `bearer_token`, `api_key_value`) are plain `String`s. The derived `Debug`
/// therefore prints them, and `Serialize` (used by `ServerConfig::save`)
/// writes them. Unlike the `SecretString` fields elsewhere in this file, they
/// are not protected — consider wrapping them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuditSinkConfig {
    /// AWS CloudWatch sink.
    AwsCloudwatch {
        region: String,
        log_group: String,
        /// Default `"corevpn-audit-{date}"`; `{date}` is presumably expanded by the sink — confirm.
        #[serde(default = "default_cloudwatch_stream")]
        log_stream: String,
        #[serde(default)]
        profile: Option<String>,
        #[serde(default)]
        role_arn: Option<String>,
    },

    /// AWS S3 sink.
    AwsS3 {
        region: String,
        bucket: String,
        /// Default `"audit-logs/{date}/{hour}/"`.
        #[serde(default = "default_s3_prefix")]
        prefix: String,
        #[serde(default)]
        profile: Option<String>,
    },

    /// AWS Security Hub sink.
    AwsSecurityHub {
        region: String,
        account_id: String,
        #[serde(default)]
        profile: Option<String>,
    },

    /// AWS EventBridge sink.
    AwsEventBridge {
        region: String,
        /// Default `"default"`.
        #[serde(default = "default_event_bus")]
        event_bus: String,
        #[serde(default)]
        profile: Option<String>,
    },

    /// Azure Monitor sink.
    AzureMonitor {
        workspace_id: String,
        shared_key: String,
        /// Default `"CoreVPNAudit"`.
        #[serde(default = "default_azure_log_type")]
        log_type: String,
    },

    /// Azure Event Hub sink.
    AzureEventHub {
        namespace: String,
        event_hub: String,
        policy_name: String,
        policy_key: String,
    },

    /// Azure Sentinel sink.
    AzureSentinel {
        workspace_id: String,
        shared_key: String,
        /// Default `"CoreVPNSecurity"`.
        #[serde(default = "default_sentinel_log_type")]
        log_type: String,
    },

    /// Oracle Cloud Logging sink.
    OracleLogging {
        region: String,
        log_id: String,
        tenancy_id: String,
        user_id: String,
        fingerprint: String,
        private_key: String,
    },

    /// Oracle Cloud Streaming sink.
    OracleStreaming {
        region: String,
        stream_id: String,
        tenancy_id: String,
        user_id: String,
        fingerprint: String,
        private_key: String,
    },

    /// Elasticsearch sink; credentials are optional.
    Elasticsearch {
        urls: Vec<String>,
        /// Default `"corevpn-audit-{date}"`.
        #[serde(default = "default_es_index")]
        index: String,
        #[serde(default)]
        username: Option<String>,
        #[serde(default)]
        password: Option<String>,
        #[serde(default)]
        api_key: Option<String>,
    },

    /// Splunk sink.
    Splunk {
        url: String,
        token: String,
        /// Default `"corevpn:audit"`.
        #[serde(default = "default_splunk_sourcetype")]
        sourcetype: String,
        #[serde(default)]
        index: Option<String>,
    },

    /// Kafka sink; SASL credentials are optional.
    Kafka {
        brokers: Vec<String>,
        topic: String,
        #[serde(default)]
        sasl_username: Option<String>,
        #[serde(default)]
        sasl_password: Option<String>,
    },

    /// Syslog sink.
    Syslog {
        address: String,
        /// Default 514.
        #[serde(default = "default_syslog_port")]
        port: u16,
        /// Default `"udp"`.
        #[serde(default = "default_syslog_protocol")]
        protocol: String,
        /// CEF output flag (default `false`).
        #[serde(default)]
        use_cef: bool,
        /// LEEF output flag (default `false`).
        #[serde(default)]
        use_leef: bool,
    },

    /// Generic HTTP webhook sink.
    Webhook {
        url: String,
        /// Extra headers (default none).
        #[serde(default)]
        headers: std::collections::HashMap<String, String>,
        #[serde(default)]
        bearer_token: Option<String>,
        #[serde(default)]
        api_key_header: Option<String>,
        #[serde(default)]
        api_key_value: Option<String>,
    },

    /// Local file sink.
    File {
        path: String,
        /// Default `"json"`.
        #[serde(default = "default_audit_format")]
        format: String,
        /// Default 100.
        #[serde(default = "default_max_size")]
        max_size_mb: u64,
        /// Default 10.
        #[serde(default = "default_max_files")]
        max_files: u32,
    },
}
608
// Serde defaults for `AuditSinkConfig` fields. The `{date}`/`{hour}` text is
// stored literally here; nothing in this function set expands it.

fn default_cloudwatch_stream() -> String {
    String::from("corevpn-audit-{date}")
}

fn default_s3_prefix() -> String {
    String::from("audit-logs/{date}/{hour}/")
}

fn default_event_bus() -> String {
    String::from("default")
}

fn default_azure_log_type() -> String {
    String::from("CoreVPNAudit")
}

fn default_sentinel_log_type() -> String {
    String::from("CoreVPNSecurity")
}

fn default_es_index() -> String {
    String::from("corevpn-audit-{date}")
}

fn default_splunk_sourcetype() -> String {
    String::from("corevpn:audit")
}

fn default_syslog_port() -> u16 {
    514
}

fn default_syslog_protocol() -> String {
    String::from("udp")
}

fn default_audit_format() -> String {
    String::from("json")
}

fn default_max_size() -> u64 {
    100
}

fn default_max_files() -> u32 {
    10
}
621
622impl ServerConfig {
623 pub fn default_config(public_host: &str) -> Self {
625 Self {
626 server: ServerSettings {
627 listen_addr: default_listen_addr(),
628 tcp_listen_addr: None,
629 public_host: public_host.to_string(),
630 protocol: default_protocol(),
631 max_clients: default_max_clients(),
632 data_dir: default_data_dir(),
633 },
634 network: NetworkSettings {
635 subnet: default_subnet(),
636 subnet_v6: None,
637 dns: default_dns(),
638 dns_search: vec![],
639 push_routes: vec![],
640 redirect_gateway: default_redirect_gateway(),
641 mtu: default_mtu(),
642 },
643 security: SecuritySettings {
644 cipher: default_cipher(),
645 tls_min_version: default_tls_version(),
646 tls_auth: true,
647 tls_crypt: false,
648 cert_lifetime_days: default_cert_lifetime(),
649 client_cert_lifetime_days: default_client_cert_lifetime(),
650 reneg_sec: default_reneg_sec(),
651 pfs: true,
652 },
653 oauth: None,
654 logging: LoggingSettings::default(),
655 admin: AdminSettings::default(),
656 audit: AuditSettings::default(),
657 }
658 }
659
660 pub fn load(path: &Path) -> Result<Self> {
662 let content = std::fs::read_to_string(path)?;
663 let config: Self = toml::from_str(&content)?;
664 config.validate()?;
665 Ok(config)
666 }
667
668 pub fn save(&self, path: &Path) -> Result<()> {
670 let content = toml::to_string_pretty(self)?;
671 std::fs::write(path, content)?;
672 Ok(())
673 }
674
675 fn validate_hostname_or_ip(host: &str) -> Result<()> {
677 if host.parse::<IpAddr>().is_ok() {
679 return Ok(());
680 }
681
682 if host.is_empty() {
685 return Err(ConfigError::ValidationError("Hostname cannot be empty".into()));
686 }
687
688 if host.len() > 253 {
689 return Err(ConfigError::ValidationError(
690 "Hostname exceeds maximum length (253 characters)".into(),
691 ));
692 }
693
694 if !host.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '.') {
697 return Err(ConfigError::ValidationError(
698 "Hostname contains invalid characters".into(),
699 ));
700 }
701
702 if host.starts_with('.') || host.ends_with('.') {
703 return Err(ConfigError::ValidationError(
704 "Hostname cannot start or end with a dot".into(),
705 ));
706 }
707
708 if host.starts_with('-') || host.ends_with('-') {
709 return Err(ConfigError::ValidationError(
710 "Hostname cannot start or end with a hyphen".into(),
711 ));
712 }
713
714 if host.contains("..") {
716 return Err(ConfigError::ValidationError(
717 "Hostname cannot contain consecutive dots".into(),
718 ));
719 }
720
721 for label in host.split('.') {
723 if label.is_empty() {
724 return Err(ConfigError::ValidationError(
725 "Hostname labels cannot be empty".into(),
726 ));
727 }
728 if label.len() > 63 {
729 return Err(ConfigError::ValidationError(format!(
730 "Hostname label '{}' exceeds maximum length (63 characters)",
731 label
732 )));
733 }
734 }
735
736 Ok(())
737 }
738
739 pub fn validate(&self) -> Result<()> {
741 if self.server.public_host.is_empty() {
742 return Err(ConfigError::MissingField("server.public_host".into()));
743 }
744
745 Self::validate_hostname_or_ip(&self.server.public_host)?;
747
748 self.network.subnet.parse::<Ipv4Net>()
750 .map_err(|e| ConfigError::ValidationError(format!("invalid subnet: {}", e)))?;
751
752 if let Some(ref subnet_v6) = self.network.subnet_v6 {
754 subnet_v6.parse::<ipnet::Ipv6Net>()
755 .map_err(|e| ConfigError::ValidationError(format!("invalid IPv6 subnet: {}", e)))?;
756 }
757
758 for (idx, dns) in self.network.dns.iter().enumerate() {
760 dns.parse::<IpAddr>()
761 .map_err(|e| ConfigError::ValidationError(format!(
762 "invalid DNS server #{} '{}': {}",
763 idx + 1, dns, e
764 )))?;
765 }
766
767 for (idx, domain) in self.network.dns_search.iter().enumerate() {
769 if domain.is_empty() {
770 return Err(ConfigError::ValidationError(format!(
771 "DNS search domain #{} cannot be empty",
772 idx + 1
773 )));
774 }
775 if domain.len() > 253 {
777 return Err(ConfigError::ValidationError(format!(
778 "DNS search domain #{} exceeds maximum length (253 characters)",
779 idx + 1
780 )));
781 }
782 }
783
784 for (idx, route) in self.network.push_routes.iter().enumerate() {
786 let is_valid_ipv4 = route.parse::<Ipv4Net>().is_ok();
787 let is_valid_ipv6 = route.parse::<Ipv6Net>().is_ok();
788 if !is_valid_ipv4 && !is_valid_ipv6 {
789 return Err(ConfigError::ValidationError(format!(
790 "invalid push route #{} '{}': must be valid IPv4 or IPv6 network",
791 idx + 1, route
792 )));
793 }
794 }
795
796 if let Some(oauth) = &self.oauth {
798 if oauth.enabled {
799 if oauth.client_id.is_empty() {
800 return Err(ConfigError::MissingField("oauth.client_id".into()));
801 }
802 if oauth.client_secret.expose_secret().is_empty() {
803 return Err(ConfigError::MissingField("oauth.client_secret".into()));
804 }
805 }
806 }
807
808 if self.server.listen_addr.port() == 0 {
810 return Err(ConfigError::ValidationError(
811 "Server listen port cannot be 0".into(),
812 ));
813 }
814 if let Some(tcp_addr) = &self.server.tcp_listen_addr {
815 if tcp_addr.port() == 0 {
816 return Err(ConfigError::ValidationError(
817 "TCP listen port cannot be 0".into(),
818 ));
819 }
820 }
821 if self.admin.listen_addr.port() == 0 {
822 return Err(ConfigError::ValidationError(
823 "Admin API listen port cannot be 0".into(),
824 ));
825 }
826
827 if self.server.max_clients == 0 {
829 return Err(ConfigError::ValidationError(
830 "max_clients must be greater than 0".into(),
831 ));
832 }
833 if self.server.max_clients > 100000 {
834 return Err(ConfigError::ValidationError(
835 "max_clients exceeds maximum allowed value (100000)".into(),
836 ));
837 }
838
839 if self.security.cert_lifetime_days == 0 {
841 return Err(ConfigError::ValidationError(
842 "cert_lifetime_days must be greater than 0".into(),
843 ));
844 }
845 if self.security.client_cert_lifetime_days == 0 {
846 return Err(ConfigError::ValidationError(
847 "client_cert_lifetime_days must be greater than 0".into(),
848 ));
849 }
850 if self.security.client_cert_lifetime_days > self.security.cert_lifetime_days {
851 return Err(ConfigError::ValidationError(
852 "client_cert_lifetime_days cannot exceed cert_lifetime_days".into(),
853 ));
854 }
855
856 if self.security.reneg_sec == 0 {
858 return Err(ConfigError::ValidationError(
859 "reneg_sec must be greater than 0".into(),
860 ));
861 }
862
863 if self.network.mtu < 68 || self.network.mtu > 1500 {
865 return Err(ConfigError::ValidationError(
866 "MTU must be between 68 and 1500".into(),
867 ));
868 }
869
870 Ok(())
871 }
872
873 pub fn data_dir(&self) -> &Path {
875 &self.server.data_dir
876 }
877
878 pub fn ca_cert_path(&self) -> PathBuf {
880 self.server.data_dir.join("ca.crt")
881 }
882
883 pub fn ca_key_path(&self) -> PathBuf {
885 self.server.data_dir.join("ca.key")
886 }
887
888 pub fn server_cert_path(&self) -> PathBuf {
890 self.server.data_dir.join("server.crt")
891 }
892
893 pub fn server_key_path(&self) -> PathBuf {
895 self.server.data_dir.join("server.key")
896 }
897
898 pub fn ta_key_path(&self) -> PathBuf {
900 self.server.data_dir.join("ta.key")
901 }
902
903 pub fn dh_path(&self) -> PathBuf {
905 self.server.data_dir.join("dh.pem")
906 }
907}
908
909fn deserialize_secret_string<'de, D>(deserializer: D) -> std::result::Result<SecretString, D::Error>
911where
912 D: serde::Deserializer<'de>,
913{
914 let s = String::deserialize(deserializer)?;
915 Ok(SecretString::new(s))
916}
917
918fn deserialize_optional_secret_string<'de, D>(
920 deserializer: D,
921) -> std::result::Result<Option<SecretString>, D::Error>
922where
923 D: serde::Deserializer<'de>,
924{
925 let opt = Option::<String>::deserialize(deserializer)?;
926 Ok(opt.map(SecretString::new))
927}
928
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ServerConfig::default_config("vpn.example.com");

        assert_eq!(config.server.public_host, "vpn.example.com");
        assert_eq!(config.network.subnet, "10.8.0.0/24");
        assert!(config.security.tls_auth);
    }

    #[test]
    fn test_config_validation() {
        let mut config = ServerConfig::default_config("vpn.example.com");
        assert!(config.validate().is_ok());

        config.server.public_host = String::new();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_ip_public_host_is_accepted() {
        for host in ["203.0.113.10", "2001:db8::1"] {
            let config = ServerConfig::default_config(host);
            assert!(config.validate().is_ok(), "{} should be accepted", host);
        }
    }

    #[test]
    fn test_malformed_hostnames_are_rejected() {
        for host in [
            "-vpn.example.com",
            "vpn.example.com-",
            ".vpn.example.com",
            "vpn.example.com.",
            "vpn..example.com",
            "vpn_example.com",
        ] {
            let config = ServerConfig::default_config(host);
            assert!(config.validate().is_err(), "{} should be rejected", host);
        }
    }

    #[test]
    fn test_invalid_subnet_is_rejected() {
        let mut config = ServerConfig::default_config("vpn.example.com");
        config.network.subnet = "10.8.0.0/33".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_mtu_bounds() {
        let mut config = ServerConfig::default_config("vpn.example.com");
        for (mtu, ok) in [(67u16, false), (68, true), (1500, true), (1501, false)] {
            config.network.mtu = mtu;
            assert_eq!(config.validate().is_ok(), ok, "mtu {}", mtu);
        }
    }

    #[test]
    fn test_client_cert_lifetime_cannot_exceed_cert_lifetime() {
        let mut config = ServerConfig::default_config("vpn.example.com");
        config.security.client_cert_lifetime_days = config.security.cert_lifetime_days + 1;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_minimal_toml_uses_defaults() {
        let src = r#"
[server]
public_host = "vpn.example.com"

[network]

[security]
"#;
        let config: ServerConfig = toml::from_str(src).expect("minimal config should parse");
        assert_eq!(config.server.listen_addr, default_listen_addr());
        assert_eq!(config.network.mtu, default_mtu());
        assert!(config.oauth.is_none());
        assert!(config.validate().is_ok());
    }
}