1use http::{HeaderValue, Method};
2use std::path::{Path, PathBuf};
3
4use crate::core::error::ServerConfigError;
5
/// Listen address used by `ServerConfig::default()` when no other address is configured.
pub const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:3000";
7
/// Origin policy stored inside a `CorsConfig`.
#[derive(Clone, Debug, Eq, PartialEq)]
enum CorsMode {
    /// No origin list is kept; `CorsConfig::validate` accepts this mode without further checks.
    // NOTE(review): presumably this means "any origin is accepted" — enforcement lives outside
    // this file, confirm against the consumer.
    Permissive,
    /// An explicit origin list is kept; `CorsConfig::validate` requires it to be non-empty.
    Restricted { allowed_origins: Vec<String> },
}
13
/// CORS settings: an origin policy, a list of method names and an optional max-age.
///
/// Construct one with `CorsConfig::permissive` or `CorsConfig::restricted` and refine it with
/// the `with_*` builder methods.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CorsConfig {
    // Whether an explicit origin list is kept, and that list when it is.
    mode: CorsMode,
    // Method names as plain strings; parsed only during validation of a restricted config.
    allowed_methods: Vec<String>,
    // NOTE(review): unit is not stated in this file — presumably seconds, as in the
    // `Access-Control-Max-Age` header; confirm against the consumer.
    max_age: Option<usize>,
}
20
21impl Default for CorsConfig {
22 fn default() -> Self {
23 Self::permissive()
24 }
25}
26
27impl CorsConfig {
28 pub fn permissive() -> Self {
29 Self {
30 mode: CorsMode::Permissive,
31 allowed_methods: Vec::new(),
32 max_age: None,
33 }
34 }
35
36 pub fn restricted<I, S>(allowed_origins: I) -> Self
37 where
38 I: IntoIterator<Item = S>,
39 S: Into<String>,
40 {
41 Self {
42 mode: CorsMode::Restricted {
43 allowed_origins: allowed_origins.into_iter().map(Into::into).collect(),
44 },
45 allowed_methods: vec!["GET".to_string(), "POST".to_string()],
46 max_age: Some(3600),
47 }
48 }
49
50 pub fn with_allowed_methods<I, S>(mut self, allowed_methods: I) -> Self
51 where
52 I: IntoIterator<Item = S>,
53 S: Into<String>,
54 {
55 self.allowed_methods = allowed_methods.into_iter().map(Into::into).collect();
56 self
57 }
58
59 pub fn with_max_age(mut self, max_age: usize) -> Self {
60 self.max_age = Some(max_age);
61 self
62 }
63
64 pub fn is_permissive(&self) -> bool {
65 matches!(self.mode, CorsMode::Permissive)
66 }
67
68 pub fn allowed_origins(&self) -> &[String] {
69 match &self.mode {
70 CorsMode::Permissive => &[],
71 CorsMode::Restricted { allowed_origins } => allowed_origins.as_slice(),
72 }
73 }
74
75 pub fn allowed_methods(&self) -> &[String] {
76 self.allowed_methods.as_slice()
77 }
78
79 pub fn max_age(&self) -> Option<usize> {
80 self.max_age
81 }
82
83 pub(crate) fn validate(&self) -> Result<(), ServerConfigError> {
84 match &self.mode {
85 CorsMode::Permissive => Ok(()),
86 CorsMode::Restricted { allowed_origins } => {
87 if allowed_origins.is_empty() {
88 return Err(ServerConfigError::MissingCorsOrigins);
89 }
90
91 for origin in allowed_origins {
92 HeaderValue::from_str(origin)
93 .map_err(|_| ServerConfigError::InvalidCorsOrigin(origin.clone()))?;
94 }
95
96 if self.allowed_methods.is_empty() {
97 return Err(ServerConfigError::MissingCorsMethods);
98 }
99
100 for method in &self.allowed_methods {
101 Method::from_bytes(method.as_bytes())
102 .map_err(|_| ServerConfigError::InvalidCorsMethod(method.clone()))?;
103 }
104
105 Ok(())
106 }
107 }
108 }
109}
110
/// Unvalidated TLS settings; every path is optional until `TlsConfig::validate` is called.
///
/// The derived `Default` leaves all three paths unset.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TlsConfig {
    // Path to the certificate; required by `validate`.
    cert_path: Option<PathBuf>,
    // Path to the certificate's key; required by `validate`.
    cert_key_path: Option<PathBuf>,
    // Optional client CA path; passed through `validate` unchanged.
    // NOTE(review): presumably used for client-certificate verification — confirm with the consumer.
    client_ca: Option<PathBuf>,
}
117
118impl TlsConfig {
119 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn with_cert_path<P>(mut self, cert_path: P) -> Self
124 where
125 P: Into<PathBuf>,
126 {
127 self.cert_path = Some(cert_path.into());
128 self
129 }
130
131 pub fn with_cert_key_path<P>(mut self, cert_key_path: P) -> Self
132 where
133 P: Into<PathBuf>,
134 {
135 self.cert_key_path = Some(cert_key_path.into());
136 self
137 }
138
139 pub fn with_client_ca<P>(mut self, client_ca: P) -> Self
140 where
141 P: Into<PathBuf>,
142 {
143 self.client_ca = Some(client_ca.into());
144 self
145 }
146
147 pub fn with_client_ca_path<P>(self, client_ca_path: P) -> Self
148 where
149 P: Into<PathBuf>,
150 {
151 self.with_client_ca(client_ca_path)
152 }
153
154 pub fn cert_path(&self) -> Option<&Path> {
155 self.cert_path.as_deref()
156 }
157
158 pub fn cert_key_path(&self) -> Option<&Path> {
159 self.cert_key_path.as_deref()
160 }
161
162 pub fn client_ca(&self) -> Option<&Path> {
163 self.client_ca.as_deref()
164 }
165
166 pub fn client_ca_path(&self) -> Option<&Path> {
167 self.client_ca()
168 }
169
170 pub(crate) fn validate(self) -> Result<ValidatedTlsConfig, ServerConfigError> {
171 let cert_path = self
172 .cert_path
173 .ok_or(ServerConfigError::MissingTlsCertPath)?;
174 let cert_key_path = self
175 .cert_key_path
176 .ok_or(ServerConfigError::MissingTlsKeyPath)?;
177
178 Ok(ValidatedTlsConfig {
179 cert_path,
180 cert_key_path,
181 client_ca: self.client_ca,
182 })
183 }
184}
185
/// TLS settings after validation: the certificate and key paths are guaranteed to be present.
///
/// Produced by `TlsConfig::validate`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ValidatedTlsConfig {
    pub(crate) cert_path: PathBuf,
    pub(crate) cert_key_path: PathBuf,
    // Carried over from `TlsConfig` as-is; remains optional.
    pub(crate) client_ca: Option<PathBuf>,
}
192
/// Builder for the server configuration, generic over the application data type `U`.
///
/// Call `build` to validate the settings and obtain a `ValidatedServerConfig`.
#[derive(Clone, Debug)]
pub struct ServerConfig<U = ()> {
    // Address to listen on; `with_listen_addr` stores it in normalized form.
    listen_addr: String,
    // Caller-supplied application data; `None` until `with_app_data` is called.
    app_data: Option<U>,
    cors: CorsConfig,
    // `None` means no TLS settings were supplied.
    tls: Option<TlsConfig>,
}
200
201impl Default for ServerConfig<()> {
202 fn default() -> Self {
203 Self {
204 listen_addr: DEFAULT_LISTEN_ADDR.to_string(),
205 app_data: None,
206 cors: CorsConfig::default(),
207 tls: None,
208 }
209 }
210}
211
212impl ServerConfig<()> {
213 pub fn new() -> Self {
214 Self::default()
215 }
216}
217
218impl<U> ServerConfig<U> {
219 pub fn with_listen_addr(mut self, listen_addr: impl Into<String>) -> Self {
220 self.listen_addr = normalize_listen_addr(listen_addr);
221 self
222 }
223
224 pub fn with_app_data<T>(self, app_data: T) -> ServerConfig<T> {
225 ServerConfig {
226 listen_addr: self.listen_addr,
227 app_data: Some(app_data),
228 cors: self.cors,
229 tls: self.tls,
230 }
231 }
232
233 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
234 self.cors = cors;
235 self
236 }
237
238 pub fn with_tls(mut self, tls: TlsConfig) -> Self {
239 self.tls = Some(tls);
240 self
241 }
242
243 pub fn build(self) -> Result<ValidatedServerConfig<U>, ServerConfigError> {
244 if self.listen_addr.trim().is_empty() {
245 return Err(ServerConfigError::MissingListenAddr);
246 }
247
248 self.cors.validate()?;
249 let tls = self.tls.map(TlsConfig::validate).transpose()?;
250
251 Ok(ValidatedServerConfig {
252 listen_addr: self.listen_addr,
253 app_data: self.app_data,
254 cors: self.cors,
255 tls,
256 })
257 }
258}
259
/// Normalizes a listen address.
///
/// Surrounding whitespace is removed. An address that begins with `:` (a port with no host,
/// such as `":3000"`) gets `0.0.0.0` prepended; any other address is returned trimmed but
/// otherwise unchanged.
pub fn normalize_listen_addr(listen_addr: impl Into<String>) -> String {
    let raw: String = listen_addr.into();
    let addr = raw.trim();

    match addr.strip_prefix(':') {
        // Only the first colon is stripped, so re-adding it reproduces the trimmed input.
        Some(rest) => format!("0.0.0.0:{rest}"),
        None => addr.to_owned(),
    }
}
269
/// Server configuration that has passed validation.
///
/// Produced by `ServerConfig::build`; read it through the accessor methods.
#[derive(Clone, Debug)]
pub struct ValidatedServerConfig<U = ()> {
    // Checked by `build` to be non-empty after trimming.
    pub(crate) listen_addr: String,
    pub(crate) app_data: Option<U>,
    // Already accepted by `CorsConfig::validate`.
    pub(crate) cors: CorsConfig,
    // `Some` only when TLS settings were supplied and passed validation.
    pub(crate) tls: Option<ValidatedTlsConfig>,
}
277
278impl<U> ValidatedServerConfig<U> {
279 pub fn listen_addr(&self) -> &str {
280 self.listen_addr.as_str()
281 }
282
283 pub fn app_data(&self) -> Option<&U> {
284 self.app_data.as_ref()
285 }
286
287 pub fn cors(&self) -> &CorsConfig {
288 &self.cors
289 }
290
291 pub fn tls_enabled(&self) -> bool {
292 self.tls.is_some()
293 }
294
295 pub fn tls_paths(&self) -> Option<(&Path, &Path)> {
296 self.tls
297 .as_ref()
298 .map(|tls| (tls.cert_path.as_path(), tls.cert_key_path.as_path()))
299 }
300}