Skip to main content

cloudiful_server/core/
config.rs

1use http::{HeaderValue, Method};
2use std::path::{Path, PathBuf};
3
4use crate::core::error::ServerConfigError;
5
/// Listen address used by `ServerConfig::default()` when none is configured.
pub const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:3000";
7
/// Origin policy of a [`CorsConfig`].
///
/// `Permissive` carries no origin list at all; `Restricted` carries the
/// explicit origins supplied to `CorsConfig::restricted`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CorsMode {
    /// No origin list is configured.
    Permissive,
    /// Only the listed origins are configured as allowed.
    Restricted {
        /// Origins exactly as supplied by the caller (checked later by `validate`).
        allowed_origins: Vec<String>,
    },
}
13
14#[derive(Clone, Debug, Eq, PartialEq)]
15pub struct CorsConfig {
16    mode: CorsMode,
17    allowed_methods: Vec<String>,
18    max_age: Option<usize>,
19}
20
21impl Default for CorsConfig {
22    fn default() -> Self {
23        Self::permissive()
24    }
25}
26
27impl CorsConfig {
28    pub fn permissive() -> Self {
29        Self {
30            mode: CorsMode::Permissive,
31            allowed_methods: Vec::new(),
32            max_age: None,
33        }
34    }
35
36    pub fn restricted<I, S>(allowed_origins: I) -> Self
37    where
38        I: IntoIterator<Item = S>,
39        S: Into<String>,
40    {
41        Self {
42            mode: CorsMode::Restricted {
43                allowed_origins: allowed_origins.into_iter().map(Into::into).collect(),
44            },
45            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
46            max_age: Some(3600),
47        }
48    }
49
50    pub fn with_allowed_methods<I, S>(mut self, allowed_methods: I) -> Self
51    where
52        I: IntoIterator<Item = S>,
53        S: Into<String>,
54    {
55        self.allowed_methods = allowed_methods.into_iter().map(Into::into).collect();
56        self
57    }
58
59    pub fn with_max_age(mut self, max_age: usize) -> Self {
60        self.max_age = Some(max_age);
61        self
62    }
63
64    pub fn is_permissive(&self) -> bool {
65        matches!(self.mode, CorsMode::Permissive)
66    }
67
68    pub fn allowed_origins(&self) -> &[String] {
69        match &self.mode {
70            CorsMode::Permissive => &[],
71            CorsMode::Restricted { allowed_origins } => allowed_origins.as_slice(),
72        }
73    }
74
75    pub fn allowed_methods(&self) -> &[String] {
76        self.allowed_methods.as_slice()
77    }
78
79    pub fn max_age(&self) -> Option<usize> {
80        self.max_age
81    }
82
83    pub(crate) fn validate(&self) -> Result<(), ServerConfigError> {
84        match &self.mode {
85            CorsMode::Permissive => Ok(()),
86            CorsMode::Restricted { allowed_origins } => {
87                if allowed_origins.is_empty() {
88                    return Err(ServerConfigError::MissingCorsOrigins);
89                }
90
91                for origin in allowed_origins {
92                    HeaderValue::from_str(origin)
93                        .map_err(|_| ServerConfigError::InvalidCorsOrigin(origin.clone()))?;
94                }
95
96                if self.allowed_methods.is_empty() {
97                    return Err(ServerConfigError::MissingCorsMethods);
98                }
99
100                for method in &self.allowed_methods {
101                    Method::from_bytes(method.as_bytes())
102                        .map_err(|_| ServerConfigError::InvalidCorsMethod(method.clone()))?;
103                }
104
105                Ok(())
106            }
107        }
108    }
109}
110
/// TLS settings gathered by the builder before validation.
///
/// Every path is optional here so the value can be filled in step by step;
/// the certificate and key paths are required when it is validated.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// Path to the server certificate file.
    cert_path: Option<PathBuf>,
    /// Path to the certificate's key file.
    cert_key_path: Option<PathBuf>,
    /// Optional client CA file (presumably used for client-certificate verification — confirm).
    client_ca: Option<PathBuf>,
}
117
118impl TlsConfig {
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    pub fn with_cert_path<P>(mut self, cert_path: P) -> Self
124    where
125        P: Into<PathBuf>,
126    {
127        self.cert_path = Some(cert_path.into());
128        self
129    }
130
131    pub fn with_cert_key_path<P>(mut self, cert_key_path: P) -> Self
132    where
133        P: Into<PathBuf>,
134    {
135        self.cert_key_path = Some(cert_key_path.into());
136        self
137    }
138
139    pub fn with_client_ca<P>(mut self, client_ca: P) -> Self
140    where
141        P: Into<PathBuf>,
142    {
143        self.client_ca = Some(client_ca.into());
144        self
145    }
146
147    pub fn with_client_ca_path<P>(self, client_ca_path: P) -> Self
148    where
149        P: Into<PathBuf>,
150    {
151        self.with_client_ca(client_ca_path)
152    }
153
154    pub fn cert_path(&self) -> Option<&Path> {
155        self.cert_path.as_deref()
156    }
157
158    pub fn cert_key_path(&self) -> Option<&Path> {
159        self.cert_key_path.as_deref()
160    }
161
162    pub fn client_ca(&self) -> Option<&Path> {
163        self.client_ca.as_deref()
164    }
165
166    pub fn client_ca_path(&self) -> Option<&Path> {
167        self.client_ca()
168    }
169
170    pub(crate) fn validate(self) -> Result<ValidatedTlsConfig, ServerConfigError> {
171        let cert_path = self
172            .cert_path
173            .ok_or(ServerConfigError::MissingTlsCertPath)?;
174        let cert_key_path = self
175            .cert_key_path
176            .ok_or(ServerConfigError::MissingTlsKeyPath)?;
177
178        Ok(ValidatedTlsConfig {
179            cert_path,
180            cert_key_path,
181            client_ca: self.client_ca,
182        })
183    }
184}
185
/// TLS settings after `TlsConfig::validate` has confirmed the required paths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ValidatedTlsConfig {
    /// Server certificate path (always present).
    pub(crate) cert_path: PathBuf,
    /// Certificate key path (always present).
    pub(crate) cert_key_path: PathBuf,
    /// Optional client CA path, carried over as configured.
    pub(crate) client_ca: Option<PathBuf>,
}
192
193#[derive(Clone, Debug)]
194pub struct ServerConfig<U = ()> {
195    listen_addr: String,
196    app_data: Option<U>,
197    cors: CorsConfig,
198    tls: Option<TlsConfig>,
199}
200
201impl Default for ServerConfig<()> {
202    fn default() -> Self {
203        Self {
204            listen_addr: DEFAULT_LISTEN_ADDR.to_string(),
205            app_data: None,
206            cors: CorsConfig::default(),
207            tls: None,
208        }
209    }
210}
211
212impl ServerConfig<()> {
213    pub fn new() -> Self {
214        Self::default()
215    }
216}
217
218impl<U> ServerConfig<U> {
219    pub fn with_listen_addr(mut self, listen_addr: impl Into<String>) -> Self {
220        self.listen_addr = normalize_listen_addr(listen_addr);
221        self
222    }
223
224    pub fn with_app_data<T>(self, app_data: T) -> ServerConfig<T> {
225        ServerConfig {
226            listen_addr: self.listen_addr,
227            app_data: Some(app_data),
228            cors: self.cors,
229            tls: self.tls,
230        }
231    }
232
233    pub fn with_cors(mut self, cors: CorsConfig) -> Self {
234        self.cors = cors;
235        self
236    }
237
238    pub fn with_tls(mut self, tls: TlsConfig) -> Self {
239        self.tls = Some(tls);
240        self
241    }
242
243    pub fn build(self) -> Result<ValidatedServerConfig<U>, ServerConfigError> {
244        if self.listen_addr.trim().is_empty() {
245            return Err(ServerConfigError::MissingListenAddr);
246        }
247
248        self.cors.validate()?;
249        let tls = self.tls.map(TlsConfig::validate).transpose()?;
250
251        Ok(ValidatedServerConfig {
252            listen_addr: self.listen_addr,
253            app_data: self.app_data,
254            cors: self.cors,
255            tls,
256        })
257    }
258}
259
/// Normalizes a listen address.
///
/// Surrounding whitespace is removed. A bare `:port` form (leading colon) is
/// expanded to `0.0.0.0:port`; every other value is returned trimmed but
/// otherwise unchanged.
pub fn normalize_listen_addr(listen_addr: impl Into<String>) -> String {
    let raw: String = listen_addr.into();
    match raw.trim().strip_prefix(':') {
        Some(port) => format!("0.0.0.0:{port}"),
        None => raw.trim().to_owned(),
    }
}
269
270#[derive(Clone, Debug)]
271pub struct ValidatedServerConfig<U = ()> {
272    pub(crate) listen_addr: String,
273    pub(crate) app_data: Option<U>,
274    pub(crate) cors: CorsConfig,
275    pub(crate) tls: Option<ValidatedTlsConfig>,
276}
277
278impl<U> ValidatedServerConfig<U> {
279    pub fn listen_addr(&self) -> &str {
280        self.listen_addr.as_str()
281    }
282
283    pub fn app_data(&self) -> Option<&U> {
284        self.app_data.as_ref()
285    }
286
287    pub fn cors(&self) -> &CorsConfig {
288        &self.cors
289    }
290
291    pub fn tls_enabled(&self) -> bool {
292        self.tls.is_some()
293    }
294
295    pub fn tls_paths(&self) -> Option<(&Path, &Path)> {
296        self.tls
297            .as_ref()
298            .map(|tls| (tls.cert_path.as_path(), tls.cert_key_path.as_path()))
299    }
300}