Skip to main content

agp_config/grpc/
server.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::convert::Infallible;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::{net::SocketAddr, str::FromStr, time::Duration};
9
10use duration_str::deserialize_duration;
11use futures::{FutureExt, TryStreamExt};
12use serde::Deserialize;
13use tonic::transport::server::TcpIncoming;
14
15use super::errors::ConfigError;
16use crate::auth::ServerAuthenticator;
17use crate::auth::basic::Config as BasicAuthenticationConfig;
18use crate::auth::bearer::Config as BearerAuthenticationConfig;
19use crate::component::configuration::{Configuration, ConfigurationError};
20use crate::tls::{common::RustlsConfigLoader, server::TlsServerConfig as TLSSetting};
21
22#[derive(Debug, Deserialize, PartialEq, Clone)]
23pub struct KeepaliveServerParameters {
24    /// max_connection_idle sets the time after which an idle connection is closed.
25    #[serde(
26        default = "default_max_connection_idle",
27        deserialize_with = "deserialize_duration"
28    )]
29    max_connection_idle: Duration,
30
31    /// max_connection_age sets the maximum amount of time a connection may exist before it will be closed.
32    #[serde(
33        default = "default_max_connection_age",
34        deserialize_with = "deserialize_duration"
35    )]
36    max_connection_age: Duration,
37
38    /// max_connection_age_grace is an additional time given after MaxConnectionAge before closing the connection.
39    #[serde(
40        default = "default_max_connection_age_grace",
41        deserialize_with = "deserialize_duration"
42    )]
43    max_connection_age_grace: Duration,
44
45    /// Time sets the frequency of the keepalive ping.
46    #[serde(default = "default_time", deserialize_with = "deserialize_duration")]
47    time: Duration,
48
49    /// Timeout sets the amount of time the server waits for a keepalive ping ack.
50    #[serde(default = "default_timeout", deserialize_with = "deserialize_duration")]
51    timeout: Duration,
52}
53
54/// Enum holding one configuration for the client.
55#[derive(Debug, Deserialize, Clone, PartialEq)]
56#[serde(rename_all = "snake_case")]
57pub enum AuthenticationConfig {
58    /// Basic authentication configuration.
59    Basic(BasicAuthenticationConfig),
60    /// Bearer authentication configuration.
61    Bearer(BearerAuthenticationConfig),
62    /// None
63    None,
64}
65
66impl Default for AuthenticationConfig {
67    fn default() -> Self {
68        Self::None
69    }
70}
71
72#[derive(Debug, Deserialize, PartialEq, Clone)]
73pub struct ServerConfig {
74    /// Endpoint is the address to listen on.
75    pub endpoint: String,
76
77    /// Configures the protocol to use TLS.
78    #[serde(default, rename = "tls")]
79    pub tls_setting: TLSSetting,
80
81    /// Use HTTP 2 only.
82    #[serde(default = "default_http2_only")]
83    pub http2_only: bool,
84
85    /// Maximum size (in MiB) of messages accepted by the server.
86    pub max_frame_size: Option<u32>,
87
88    /// MaxConcurrentStreams sets the limit on the number of concurrent streams to each ServerTransport.
89    pub max_concurrent_streams: Option<u32>,
90
91    /// Max header list size
92    pub max_header_list_size: Option<u32>,
93
94    /// ReadBufferSize for gRPC server.
95    // TODO(msardara): not implemented yet
96    pub read_buffer_size: Option<usize>,
97
98    /// WriteBufferSize for gRPC server.
99    // TODO(msardara): not implemented yet
100    pub write_buffer_size: Option<usize>,
101
102    /// Keepalive anchor for all the settings related to keepalive.
103    #[serde(default)]
104    pub keepalive: KeepaliveServerParameters,
105
106    /// Auth for this receiver.
107    #[serde(default)]
108    #[serde(with = "serde_yaml::with::singleton_map")]
109    pub auth: AuthenticationConfig,
110}
111
112/// Default values for KeepaliveServerParameters
113impl Default for KeepaliveServerParameters {
114    fn default() -> Self {
115        Self {
116            max_connection_idle: default_max_connection_idle(),
117            max_connection_age: default_max_connection_age(),
118            max_connection_age_grace: default_max_connection_age_grace(),
119            time: default_time(),
120            timeout: default_timeout(),
121        }
122    }
123}
124
/// Default idle timeout: one hour.
fn default_max_connection_idle() -> Duration {
    Duration::from_secs(60 * 60)
}
128
/// Default maximum connection age: two hours.
fn default_max_connection_age() -> Duration {
    Duration::from_secs(7_200)
}
132
/// Default grace period after the maximum connection age: five minutes.
fn default_max_connection_age_grace() -> Duration {
    Duration::from_secs(300)
}
136
/// Default interval between keepalive pings: two minutes.
fn default_time() -> Duration {
    Duration::from_secs(120)
}
140
/// Default wait for a keepalive ping acknowledgement: twenty seconds.
fn default_timeout() -> Duration {
    Duration::from_millis(20_000)
}
144
145/// Default values for ServerConfig
146impl Default for ServerConfig {
147    fn default() -> Self {
148        Self {
149            endpoint: String::new(),
150            tls_setting: TLSSetting::default(),
151            http2_only: default_http2_only(),
152            max_frame_size: Some(4),
153            max_concurrent_streams: Some(100),
154            max_header_list_size: None,
155            read_buffer_size: Some(1024 * 1024),
156            write_buffer_size: Some(1024 * 1024),
157            keepalive: KeepaliveServerParameters::default(),
158            auth: AuthenticationConfig::default(),
159        }
160    }
161}
162
/// Default for `ServerConfig::http2_only`, shared by serde and `ServerConfig::default()`.
fn default_http2_only() -> bool {
    true
}
166
167/// Display implementation for ServerConfig
168/// This is used to print the ServerConfig in a human-readable format.
169impl std::fmt::Display for ServerConfig {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        write!(
172            f,
173            "ServerConfig {{ endpoint: {}, tls_setting: {}, http2_only: {}, max_frame_size: {:?}, max_concurrent_streams: {:?}, max_header_list_size: {:?}, read_buffer_size: {:?}, write_buffer_size: {:?}, keepalive: {:?}, auth: {:?} }}",
174            self.endpoint,
175            self.tls_setting,
176            self.http2_only,
177            self.max_frame_size,
178            self.max_concurrent_streams,
179            self.max_header_list_size,
180            self.read_buffer_size,
181            self.write_buffer_size,
182            self.keepalive,
183            self.auth
184        )
185    }
186}
187
188impl Configuration for ServerConfig {
189    fn validate(&self) -> Result<(), ConfigurationError> {
190        // Validate the client configuration
191        self.tls_setting.validate()
192    }
193}
194
195/// ServerFuture is a type alias for a boxed future that returns a Result<(), tonic::transport::Error>.
196type ServerFuture = Pin<Box<dyn Future<Output = Result<(), tonic::transport::Error>> + Send>>;
197
198/// Convert ServerConfig to IncomingServerConfig
199/// This function takes a ServerConfig and a service and returns a ServerFuture.
200/// The ServerFuture is a boxed future that returns a Result<(), tonic::transport::Error>.
201/// The ServerFuture is created by creating a new TcpIncoming and then creating a new Server.
202impl ServerConfig {
203    pub fn with_endpoint(endpoint: &str) -> Self {
204        Self {
205            endpoint: endpoint.to_string(),
206            ..Default::default()
207        }
208    }
209
210    pub fn with_tls_settings(self, tls_setting: TLSSetting) -> Self {
211        Self {
212            tls_setting,
213            ..self
214        }
215    }
216
217    pub fn with_http2_only(self, http2_only: bool) -> Self {
218        Self { http2_only, ..self }
219    }
220
221    pub fn with_max_frame_size(self, max_frame_size: Option<u32>) -> Self {
222        Self {
223            max_frame_size,
224            ..self
225        }
226    }
227
228    pub fn with_max_concurrent_streams(self, max_concurrent_streams: Option<u32>) -> Self {
229        Self {
230            max_concurrent_streams,
231            ..self
232        }
233    }
234
235    pub fn with_max_header_list_size(self, max_header_list_size: Option<u32>) -> Self {
236        Self {
237            max_header_list_size,
238            ..self
239        }
240    }
241
242    pub fn with_read_buffer_size(self, read_buffer_size: Option<usize>) -> Self {
243        Self {
244            read_buffer_size,
245            ..self
246        }
247    }
248
249    pub fn with_write_buffer_size(self, write_buffer_size: Option<usize>) -> Self {
250        Self {
251            write_buffer_size,
252            ..self
253        }
254    }
255
256    pub fn with_keepalive(self, keepalive: KeepaliveServerParameters) -> Self {
257        Self { keepalive, ..self }
258    }
259
260    pub fn with_auth(self, auth: AuthenticationConfig) -> Self {
261        Self { auth, ..self }
262    }
263
264    pub fn to_server_future<S>(&self, svc: &[S]) -> Result<ServerFuture, ConfigError>
265    where
266        S: tower_service::Service<
267                http::Request<tonic::body::Body>,
268                Response = http::Response<tonic::body::Body>,
269                Error = Infallible,
270            >
271            + tonic::server::NamedService
272            + Clone
273            + Send
274            + 'static
275            + Sync,
276        S::Future: Send + 'static,
277    {
278        // Make sure at least one service is provided
279        if svc.is_empty() {
280            return Err(ConfigError::MissingServices);
281        }
282
283        // Check if the endpoint is missing
284        if self.endpoint.is_empty() {
285            return Err(ConfigError::MissingEndpoint);
286        }
287
288        // make sure endpoint is valid
289        let addr = SocketAddr::from_str(self.endpoint.as_str())
290            .map_err(|e| ConfigError::EndpointParseError(e.to_string()))?;
291
292        // create a new TcpIncoming
293        let incoming =
294            TcpIncoming::bind(addr).map_err(|e| ConfigError::TcpIncomingError(e.to_string()))?;
295
296        // Create initial server config
297        let builder: tonic::transport::Server =
298            tonic::transport::Server::builder().accept_http1(false);
299
300        // Set max number of concurrent streams per connection
301        let builder = match self.max_concurrent_streams {
302            Some(max_concurrent_streams) => {
303                builder.concurrency_limit_per_connection(max_concurrent_streams as usize)
304            }
305            None => builder,
306        };
307
308        // Set max size of messages accepted by the server
309        let builder = match self.max_frame_size {
310            Some(max_frame_size) => builder.max_frame_size(max_frame_size * 1024 * 1024),
311            None => builder,
312        };
313
314        // Set max header list size
315        let builder = match self.max_header_list_size {
316            Some(max_header_list_size) => builder.http2_max_header_list_size(max_header_list_size),
317            None => builder,
318        };
319
320        // Set keepalive parameters
321        let builder = builder.http2_keepalive_interval(Some(self.keepalive.time));
322        let builder = builder.http2_keepalive_timeout(Some(self.keepalive.timeout));
323
324        // Set max connection age
325        let mut builder = builder.max_connection_age(self.keepalive.max_connection_age);
326
327        // TLS configuration
328        let tls_config = TLSSetting::load_rustls_config(&self.tls_setting)
329            .map_err(|e| ConfigError::TLSSettingError(e.to_string()))?;
330
331        match &self.auth {
332            AuthenticationConfig::Basic(basic) => {
333                let auth_layer = basic
334                    .get_server_layer()
335                    .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
336
337                let mut builder = builder.layer(auth_layer);
338
339                let mut router = builder.add_service(svc[0].clone());
340                for s in svc.iter().skip(1) {
341                    router = builder.add_service(s.clone());
342                }
343
344                if let Some(tls_config) = tls_config {
345                    let incoming = tonic_tls::rustls::incoming(incoming, Arc::new(tls_config))
346                        .map_err(|e| ConfigError::TcpIncomingError(e.to_string()));
347
348                    // Return the server future with the TLS configuration
349                    return Ok(router.serve_with_incoming(incoming).boxed());
350                };
351
352                Ok(router.serve_with_incoming(incoming).boxed())
353            }
354            AuthenticationConfig::Bearer(bearer) => {
355                let auth_layer = bearer
356                    .get_server_layer()
357                    .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
358
359                let mut builder = builder.layer(auth_layer);
360
361                let mut router = builder.add_service(svc[0].clone());
362                for s in svc.iter().skip(1) {
363                    router = builder.add_service(s.clone());
364                }
365
366                if let Some(tls_config) = tls_config {
367                    let incoming = tonic_tls::rustls::incoming(incoming, Arc::new(tls_config))
368                        .map_err(|e| ConfigError::TcpIncomingError(e.to_string()));
369
370                    // Return the server future with the TLS configuration
371                    return Ok(router.serve_with_incoming(incoming).boxed());
372                };
373
374                Ok(router.serve_with_incoming(incoming).boxed())
375            }
376            AuthenticationConfig::None => {
377                let mut router = builder.add_service(svc[0].clone());
378                for s in svc.iter().skip(1) {
379                    router = builder.add_service(s.clone());
380                }
381
382                if let Some(tls_config) = tls_config {
383                    let incoming = tonic_tls::rustls::incoming(incoming, Arc::new(tls_config))
384                        .map_err(|e| ConfigError::TcpIncomingError(e.to_string()));
385
386                    // Return the server future with the TLS configuration
387                    return Ok(router.serve_with_incoming(incoming).boxed());
388                };
389
390                Ok(router.serve_with_incoming(incoming).boxed())
391            }
392        }
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutils::{Empty, helloworld::greeter_server::GreeterServer};

    static TEST_DATA_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata/grpc");

    #[test]
    fn test_default_keepalive_server_parameters() {
        let keepalive = KeepaliveServerParameters::default();
        assert_eq!(keepalive.max_connection_idle, default_max_connection_idle());
        assert_eq!(keepalive.max_connection_age, default_max_connection_age());
        assert_eq!(
            keepalive.max_connection_age_grace,
            default_max_connection_age_grace()
        );
        assert_eq!(keepalive.time, default_time());
        assert_eq!(keepalive.timeout, default_timeout());
    }

    #[test]
    fn test_default_server_config() {
        let server_config = ServerConfig::default();
        assert_eq!(server_config.endpoint, String::new());
        assert_eq!(server_config.tls_setting, TLSSetting::default());
        assert_eq!(server_config.http2_only, default_http2_only());
        assert_eq!(server_config.max_frame_size, Some(4));
        assert_eq!(server_config.max_concurrent_streams, Some(100));
        assert_eq!(server_config.max_header_list_size, None);
        assert_eq!(server_config.read_buffer_size, Some(1024 * 1024));
        assert_eq!(server_config.write_buffer_size, Some(1024 * 1024));
        assert_eq!(
            server_config.keepalive,
            KeepaliveServerParameters::default()
        );
        assert_eq!(server_config.auth, AuthenticationConfig::None);
    }

    #[tokio::test]
    async fn test_to_incoming_server_config() {
        let mut server_config = ServerConfig::default();
        let empty_service = Arc::new(Empty::new());

        // no services - should return ConfigError::MissingServices
        let no_services: &[GreeterServer<Empty>] = &[];
        let ret = server_config.to_server_future(no_services);
        assert!(ret.is_err_and(|e| matches!(e, ConfigError::MissingServices)));

        // no endpoint - should return an error
        let ret = server_config.to_server_future(&[GreeterServer::from_arc(empty_service.clone())]);
        // Make sure the error is a ConfigError::MissingEndpoint
        assert!(ret.is_err_and(|e| { e.to_string().contains("missing grpc endpoint") }));

        // set the endpoint in the config. Now it should fail because of the invalid endpoint
        server_config.endpoint = "0.0.0.0:123456".to_string();
        let ret = server_config.to_server_future(&[GreeterServer::from_arc(empty_service.clone())]);
        assert!(ret.is_err_and(|e| { e.to_string().contains("error parsing grpc endpoint") }));

        // set a valid endpoint in the config. Port 0 lets the OS pick a free port, so the
        // test does not depend on a fixed port being available. Now it should fail because
        // of the missing cert/key files for tls
        server_config.endpoint = "127.0.0.1:0".to_string();
        let ret = server_config.to_server_future(&[GreeterServer::from_arc(empty_service.clone())]);
        assert!(ret.is_err_and(|e| { e.to_string().contains("tls setting error") }));

        // set the tls setting to insecure. Now it should return a server future
        server_config.tls_setting.insecure = true;
        let ret = server_config.to_server_future(&[GreeterServer::from_arc(empty_service.clone())]);
        assert!(ret.is_ok());

        // drop it, as we have a server listening on the port now
        drop(ret.unwrap());

        // Set insecure to false and set the path to the cert and key files
        server_config.tls_setting.insecure = false;
        server_config.tls_setting.config.cert_file = Some(format!("{}/server.crt", TEST_DATA_PATH));
        server_config.tls_setting.config.key_file = Some(format!("{}/server.key", TEST_DATA_PATH));
        let ret = server_config.to_server_future(&[GreeterServer::from_arc(empty_service.clone())]);
        assert!(ret.is_ok());
    }
}