1use std::convert::Infallible;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::{net::SocketAddr, str::FromStr, time::Duration};
9
10use duration_str::deserialize_duration;
11use futures::{FutureExt, TryStreamExt};
12use serde::Deserialize;
13use tonic::transport::server::TcpIncoming;
14
15use super::errors::ConfigError;
16use crate::auth::ServerAuthenticator;
17use crate::auth::basic::Config as BasicAuthenticationConfig;
18use crate::auth::bearer::Config as BearerAuthenticationConfig;
19use crate::component::configuration::{Configuration, ConfigurationError};
20use crate::tls::{common::RustlsConfigLoader, server::TlsServerConfig as TLSSetting};
21
22#[derive(Debug, Deserialize, PartialEq, Clone)]
23pub struct KeepaliveServerParameters {
24 #[serde(
26 default = "default_max_connection_idle",
27 deserialize_with = "deserialize_duration"
28 )]
29 max_connection_idle: Duration,
30
31 #[serde(
33 default = "default_max_connection_age",
34 deserialize_with = "deserialize_duration"
35 )]
36 max_connection_age: Duration,
37
38 #[serde(
40 default = "default_max_connection_age_grace",
41 deserialize_with = "deserialize_duration"
42 )]
43 max_connection_age_grace: Duration,
44
45 #[serde(default = "default_time", deserialize_with = "deserialize_duration")]
47 time: Duration,
48
49 #[serde(default = "default_timeout", deserialize_with = "deserialize_duration")]
51 timeout: Duration,
52}
53
54#[derive(Debug, Deserialize, Clone, PartialEq)]
56#[serde(rename_all = "snake_case")]
57pub enum AuthenticationConfig {
58 Basic(BasicAuthenticationConfig),
60 Bearer(BearerAuthenticationConfig),
62 None,
64}
65
66impl Default for AuthenticationConfig {
67 fn default() -> Self {
68 Self::None
69 }
70}
71
72#[derive(Debug, Deserialize, PartialEq, Clone)]
73pub struct ServerConfig {
74 pub endpoint: String,
76
77 #[serde(default, rename = "tls")]
79 pub tls_setting: TLSSetting,
80
81 #[serde(default = "default_http2_only")]
83 pub http2_only: bool,
84
85 pub max_frame_size: Option<u32>,
87
88 pub max_concurrent_streams: Option<u32>,
90
91 pub max_header_list_size: Option<u32>,
93
94 pub read_buffer_size: Option<usize>,
97
98 pub write_buffer_size: Option<usize>,
101
102 #[serde(default)]
104 pub keepalive: KeepaliveServerParameters,
105
106 #[serde(default)]
108 #[serde(with = "serde_yaml::with::singleton_map")]
109 pub auth: AuthenticationConfig,
110}
111
112impl Default for KeepaliveServerParameters {
114 fn default() -> Self {
115 Self {
116 max_connection_idle: default_max_connection_idle(),
117 max_connection_age: default_max_connection_age(),
118 max_connection_age_grace: default_max_connection_age_grace(),
119 time: default_time(),
120 timeout: default_timeout(),
121 }
122 }
123}
124
/// Default `max_connection_idle`: one hour.
fn default_max_connection_idle() -> Duration {
    Duration::from_secs(60 * 60)
}
128
/// Default `max_connection_age`: two hours.
fn default_max_connection_age() -> Duration {
    Duration::from_secs(2 * 60 * 60)
}
132
/// Default `max_connection_age_grace`: five minutes.
fn default_max_connection_age_grace() -> Duration {
    Duration::from_secs(300)
}
136
/// Default keepalive ping interval (`time`): two minutes.
fn default_time() -> Duration {
    Duration::from_secs(120)
}
140
/// Default keepalive `timeout`: twenty seconds.
fn default_timeout() -> Duration {
    Duration::from_millis(20_000)
}
144
145impl Default for ServerConfig {
147 fn default() -> Self {
148 Self {
149 endpoint: String::new(),
150 tls_setting: TLSSetting::default(),
151 http2_only: default_http2_only(),
152 max_frame_size: Some(4),
153 max_concurrent_streams: Some(100),
154 max_header_list_size: None,
155 read_buffer_size: Some(1024 * 1024),
156 write_buffer_size: Some(1024 * 1024),
157 keepalive: KeepaliveServerParameters::default(),
158 auth: AuthenticationConfig::default(),
159 }
160 }
161}
162
/// Serde/`Default` value for `ServerConfig::http2_only`: HTTP/2-only is on
/// unless configured otherwise.
fn default_http2_only() -> bool {
    true
}
166
167impl std::fmt::Display for ServerConfig {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 write!(
172 f,
173 "ServerConfig {{ endpoint: {}, tls_setting: {}, http2_only: {}, max_frame_size: {:?}, max_concurrent_streams: {:?}, max_header_list_size: {:?}, read_buffer_size: {:?}, write_buffer_size: {:?}, keepalive: {:?}, auth: {:?} }}",
174 self.endpoint,
175 self.tls_setting,
176 self.http2_only,
177 self.max_frame_size,
178 self.max_concurrent_streams,
179 self.max_header_list_size,
180 self.read_buffer_size,
181 self.write_buffer_size,
182 self.keepalive,
183 self.auth
184 )
185 }
186}
187
188impl Configuration for ServerConfig {
189 fn validate(&self) -> Result<(), ConfigurationError> {
190 self.tls_setting.validate()
192 }
193}
194
195type ServerFuture = Pin<Box<dyn Future<Output = Result<(), tonic::transport::Error>> + Send>>;
197
198impl ServerConfig {
203 pub fn with_endpoint(endpoint: &str) -> Self {
204 Self {
205 endpoint: endpoint.to_string(),
206 ..Default::default()
207 }
208 }
209
210 pub fn with_tls_settings(self, tls_setting: TLSSetting) -> Self {
211 Self {
212 tls_setting,
213 ..self
214 }
215 }
216
217 pub fn with_http2_only(self, http2_only: bool) -> Self {
218 Self { http2_only, ..self }
219 }
220
221 pub fn with_max_frame_size(self, max_frame_size: Option<u32>) -> Self {
222 Self {
223 max_frame_size,
224 ..self
225 }
226 }
227
228 pub fn with_max_concurrent_streams(self, max_concurrent_streams: Option<u32>) -> Self {
229 Self {
230 max_concurrent_streams,
231 ..self
232 }
233 }
234
235 pub fn with_max_header_list_size(self, max_header_list_size: Option<u32>) -> Self {
236 Self {
237 max_header_list_size,
238 ..self
239 }
240 }
241
242 pub fn with_read_buffer_size(self, read_buffer_size: Option<usize>) -> Self {
243 Self {
244 read_buffer_size,
245 ..self
246 }
247 }
248
249 pub fn with_write_buffer_size(self, write_buffer_size: Option<usize>) -> Self {
250 Self {
251 write_buffer_size,
252 ..self
253 }
254 }
255
256 pub fn with_keepalive(self, keepalive: KeepaliveServerParameters) -> Self {
257 Self { keepalive, ..self }
258 }
259
260 pub fn with_auth(self, auth: AuthenticationConfig) -> Self {
261 Self { auth, ..self }
262 }
263
264 pub fn to_server_future<S>(&self, svc: &[S]) -> Result<ServerFuture, ConfigError>
265 where
266 S: tower_service::Service<
267 http::Request<tonic::body::Body>,
268 Response = http::Response<tonic::body::Body>,
269 Error = Infallible,
270 >
271 + tonic::server::NamedService
272 + Clone
273 + Send
274 + 'static
275 + Sync,
276 S::Future: Send + 'static,
277 {
278 if svc.is_empty() {
280 return Err(ConfigError::MissingServices);
281 }
282
283 if self.endpoint.is_empty() {
285 return Err(ConfigError::MissingEndpoint);
286 }
287
288 let addr = SocketAddr::from_str(self.endpoint.as_str())
290 .map_err(|e| ConfigError::EndpointParseError(e.to_string()))?;
291
292 let incoming =
294 TcpIncoming::bind(addr).map_err(|e| ConfigError::TcpIncomingError(e.to_string()))?;
295
296 let builder: tonic::transport::Server =
298 tonic::transport::Server::builder().accept_http1(false);
299
300 let builder = match self.max_concurrent_streams {
302 Some(max_concurrent_streams) => {
303 builder.concurrency_limit_per_connection(max_concurrent_streams as usize)
304 }
305 None => builder,
306 };
307
308 let builder = match self.max_frame_size {
310 Some(max_frame_size) => builder.max_frame_size(max_frame_size * 1024 * 1024),
311 None => builder,
312 };
313
314 let builder = match self.max_header_list_size {
316 Some(max_header_list_size) => builder.http2_max_header_list_size(max_header_list_size),
317 None => builder,
318 };
319
320 let builder = builder.http2_keepalive_interval(Some(self.keepalive.time));
322 let builder = builder.http2_keepalive_timeout(Some(self.keepalive.timeout));
323
324 let mut builder = builder.max_connection_age(self.keepalive.max_connection_age);
326
327 let tls_config = TLSSetting::load_rustls_config(&self.tls_setting)
329 .map_err(|e| ConfigError::TLSSettingError(e.to_string()))?;
330
331 match &self.auth {
332 AuthenticationConfig::Basic(basic) => {
333 let auth_layer = basic
334 .get_server_layer()
335 .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
336
337 let mut builder = builder.layer(auth_layer);
338
339 let mut router = builder.add_service(svc[0].clone());
340 for s in svc.iter().skip(1) {
341 router = builder.add_service(s.clone());
342 }
343
344 if let Some(tls_config) = tls_config {
345 let incoming = tonic_tls::rustls::incoming(incoming, Arc::new(tls_config))
346 .map_err(|e| ConfigError::TcpIncomingError(e.to_string()));
347
348 return Ok(router.serve_with_incoming(incoming).boxed());
350 };
351
352 Ok(router.serve_with_incoming(incoming).boxed())
353 }
354 AuthenticationConfig::Bearer(bearer) => {
355 let auth_layer = bearer
356 .get_server_layer()
357 .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
358
359 let mut builder = builder.layer(auth_layer);
360
361 let mut router = builder.add_service(svc[0].clone());
362 for s in svc.iter().skip(1) {
363 router = builder.add_service(s.clone());
364 }
365
366 if let Some(tls_config) = tls_config {
367 let incoming = tonic_tls::rustls::incoming(incoming, Arc::new(tls_config))
368 .map_err(|e| ConfigError::TcpIncomingError(e.to_string()));
369
370 return Ok(router.serve_with_incoming(incoming).boxed());
372 };
373
374 Ok(router.serve_with_incoming(incoming).boxed())
375 }
376 AuthenticationConfig::None => {
377 let mut router = builder.add_service(svc[0].clone());
378 for s in svc.iter().skip(1) {
379 router = builder.add_service(s.clone());
380 }
381
382 if let Some(tls_config) = tls_config {
383 let incoming = tonic_tls::rustls::incoming(incoming, Arc::new(tls_config))
384 .map_err(|e| ConfigError::TcpIncomingError(e.to_string()));
385
386 return Ok(router.serve_with_incoming(incoming).boxed());
388 };
389
390 Ok(router.serve_with_incoming(incoming).boxed())
391 }
392 }
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutils::{Empty, helloworld::greeter_server::GreeterServer};

    static TEST_DATA_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata/grpc");

    /// Loopback endpoint with port 0: the OS assigns a free port, so the test
    /// does not fail just because a hard-coded port is already in use.
    const EPHEMERAL_ENDPOINT: &str = "127.0.0.1:0";

    #[test]
    fn test_default_keepalive_server_parameters() {
        let keepalive = KeepaliveServerParameters::default();
        assert_eq!(keepalive.max_connection_idle, default_max_connection_idle());
        assert_eq!(keepalive.max_connection_age, default_max_connection_age());
        assert_eq!(
            keepalive.max_connection_age_grace,
            default_max_connection_age_grace()
        );
        assert_eq!(keepalive.time, default_time());
        assert_eq!(keepalive.timeout, default_timeout());
    }

    #[test]
    fn test_default_server_config() {
        let server_config = ServerConfig::default();
        assert_eq!(server_config.endpoint, String::new());
        assert_eq!(server_config.tls_setting, TLSSetting::default());
        assert_eq!(server_config.http2_only, default_http2_only());
        assert_eq!(server_config.max_frame_size, Some(4));
        assert_eq!(server_config.max_concurrent_streams, Some(100));
        assert_eq!(server_config.max_header_list_size, None);
        assert_eq!(server_config.read_buffer_size, Some(1024 * 1024));
        assert_eq!(server_config.write_buffer_size, Some(1024 * 1024));
        assert_eq!(
            server_config.keepalive,
            KeepaliveServerParameters::default()
        );
        assert_eq!(server_config.auth, AuthenticationConfig::None);
    }

    #[tokio::test]
    async fn test_to_incoming_server_config() {
        let mut server_config = ServerConfig::default();
        let empty_service = Arc::new(Empty::new());
        let services = [GreeterServer::from_arc(empty_service.clone())];

        // An empty service list is rejected before anything else is checked.
        let ret = server_config.to_server_future(&services[..0]);
        assert!(matches!(ret, Err(ConfigError::MissingServices)));

        let ret = server_config.to_server_future(&services);
        assert!(ret.is_err_and(|e| { e.to_string().contains("missing grpc endpoint") }));

        server_config.endpoint = "0.0.0.0:123456".to_string();
        let ret = server_config.to_server_future(&services);
        assert!(ret.is_err_and(|e| { e.to_string().contains("error parsing grpc endpoint") }));

        server_config.endpoint = EPHEMERAL_ENDPOINT.to_string();
        let ret = server_config.to_server_future(&services);
        assert!(ret.is_err_and(|e| { e.to_string().contains("tls setting error") }));

        server_config.tls_setting.insecure = true;
        let ret = server_config.to_server_future(&services);
        assert!(ret.is_ok());

        // Dropping the future releases the bound listener.
        drop(ret.unwrap());

        server_config.tls_setting.insecure = false;
        server_config.tls_setting.config.cert_file = Some(format!("{}/server.crt", TEST_DATA_PATH));
        server_config.tls_setting.config.key_file = Some(format!("{}/server.key", TEST_DATA_PATH));
        let ret = server_config.to_server_future(&services);
        assert!(ret.is_ok());
    }
}