1use duration_str::deserialize_duration;
5use serde::{Deserialize, Serialize};
6use std::time;
7use tokio::runtime::{Builder, Runtime};
8use tracing::{info, warn};
9
10use agp_config::component::configuration::ConfigurationError;
11
/// Configuration for the Tokio runtime built by [`build`].
///
/// Every field has a serde default, so any subset of keys (including none) can be
/// deserialized; the defaults match [`RuntimeConfiguration::default`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RuntimeConfiguration {
    /// Requested number of cores.
    ///
    /// `0` (the default) means "use every available core". A value larger than the
    /// number of available cores is clamped to that number (with a warning) by [`build`].
    /// An effective value of `1` produces a current-thread runtime; anything larger
    /// produces a multi-threaded runtime.
    #[serde(default = "default_n_cores")]
    n_cores: usize,

    /// Name passed to the Tokio builder for the threads the runtime spawns
    /// (default `"gateway"`).
    #[serde(default = "default_thread_name")]
    thread_name: String,

    /// Drain timeout (default 10 seconds), exposed through `drain_timeout()`.
    ///
    /// It is deserialized with `duration_str::deserialize_duration` (presumably a
    /// human-readable string such as `"10s"`; confirm against duration_str docs).
    /// It is not used when building the runtime in this file; its consumer lives elsewhere.
    ///
    /// NOTE(review): serialization uses serde's default `Duration` encoding, not a
    /// duration string, so a serialize -> deserialize round-trip may fail. Verify.
    #[serde(
        default = "default_drain_timeout",
        deserialize_with = "deserialize_duration"
    )]
    drain_timeout: time::Duration,
}
29
30impl Default for RuntimeConfiguration {
31 fn default() -> Self {
32 RuntimeConfiguration {
33 n_cores: default_n_cores(),
34 thread_name: default_thread_name(),
35 drain_timeout: default_drain_timeout(),
36 }
37 }
38}
39
/// Serde default for `n_cores`.
///
/// Zero is the sentinel that [`build`] interprets as "use every available core".
fn default_n_cores() -> usize {
    const ALL_AVAILABLE_CORES: usize = 0;
    ALL_AVAILABLE_CORES
}
44
/// Serde default for `thread_name`: the runtime's threads are named `"gateway"`.
fn default_thread_name() -> String {
    String::from("gateway")
}
48
/// Serde default for `drain_timeout`: ten seconds.
fn default_drain_timeout() -> time::Duration {
    time::Duration::new(10, 0)
}
52
53impl RuntimeConfiguration {
54 pub fn new() -> Self {
55 RuntimeConfiguration::default()
56 }
57
58 pub fn with_cores(n_cores: usize) -> Self {
59 RuntimeConfiguration {
60 n_cores,
61 ..RuntimeConfiguration::default()
62 }
63 }
64
65 pub fn with_thread_name(thread_name: &str) -> Self {
66 RuntimeConfiguration {
67 thread_name: thread_name.to_string(),
68 ..RuntimeConfiguration::default()
69 }
70 }
71
72 pub fn with_drain_timeout(drain_timeout: time::Duration) -> Self {
73 RuntimeConfiguration {
74 drain_timeout,
75 ..RuntimeConfiguration::default()
76 }
77 }
78
79 pub fn n_cores(&self) -> usize {
80 self.n_cores
81 }
82
83 pub fn thread_name(&self) -> &str {
84 &self.thread_name
85 }
86
87 pub fn drain_timeout(&self) -> time::Duration {
88 self.drain_timeout
89 }
90}
91
/// A Tokio runtime produced by [`build`], paired with the configuration it came from.
pub struct GatewayRuntime {
    /// A clone of the configuration passed to [`build`], stored unchanged.
    ///
    /// `n_cores` here is the requested value (possibly `0`, or larger than the core
    /// count), not the number of cores the runtime actually uses.
    pub config: RuntimeConfiguration,

    /// The built Tokio runtime.
    pub runtime: Runtime,
}
99
100pub fn build(config: &RuntimeConfiguration) -> Result<GatewayRuntime, ConfigurationError> {
101 let n_cpu = num_cpus::get();
102 debug_assert!(n_cpu > 0, "failed to get number of CPUs");
103
104 let cores = if config.n_cores > n_cpu {
105 warn!(
106 "Requested number of cores ({}) is greater than available cores ({}). Using all available cores",
107 config.n_cores, n_cpu
108 );
109 n_cpu
110 } else if config.n_cores == 0 {
111 info!(
112 %n_cpu,
113 "Using all available cores",
114 );
115 n_cpu
116 } else {
117 config.n_cores
118 };
119
120 let runtime = match cores {
121 1 => {
122 info!("Using single-threaded runtime");
123 Builder::new_current_thread()
124 .enable_all()
125 .thread_name(config.thread_name.as_str())
126 .build()
127 .expect("failed to build single-thread runtime!")
128 }
129 _ => {
130 info!(%cores, "Using multi-threaded runtime");
131 Builder::new_multi_thread()
132 .enable_all()
133 .thread_name(config.thread_name.as_str())
134 .worker_threads(cores)
135 .max_blocking_threads(cores)
136 .build()
137 .expect("failed to build threaded runtime!")
138 }
139 };
140
141 Ok(GatewayRuntime {
142 config: config.clone(),
143 runtime,
144 })
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a trivial future to completion, proving the built runtime is usable.
    fn assert_runtime_runs(runtime: &GatewayRuntime) {
        assert_eq!(runtime.runtime.block_on(async { 40 + 2 }), 42);
    }

    #[test]
    fn test_runtime_configuration() {
        let config = RuntimeConfiguration::default();
        assert_eq!(config.n_cores, 0);
        assert_eq!(config.thread_name, "gateway");
        assert_eq!(config.drain_timeout, time::Duration::from_secs(10));

        let config = RuntimeConfiguration {
            n_cores: 1,
            thread_name: "test".to_string(),
            drain_timeout: time::Duration::from_secs(5),
        };
        assert_eq!(config.n_cores, 1);
        assert_eq!(config.thread_name, "test");
        assert_eq!(config.drain_timeout, time::Duration::from_secs(5));
    }

    #[test]
    fn test_runtime_configuration_constructors() {
        let config = RuntimeConfiguration::new();
        assert_eq!(config.n_cores(), 0);
        assert_eq!(config.thread_name(), "gateway");
        assert_eq!(config.drain_timeout(), time::Duration::from_secs(10));

        // Each `with_*` constructor overrides one field and keeps the others at default.
        let config = RuntimeConfiguration::with_cores(2);
        assert_eq!(config.n_cores(), 2);
        assert_eq!(config.thread_name(), "gateway");
        assert_eq!(config.drain_timeout(), time::Duration::from_secs(10));

        let config = RuntimeConfiguration::with_thread_name("worker");
        assert_eq!(config.n_cores(), 0);
        assert_eq!(config.thread_name(), "worker");
        assert_eq!(config.drain_timeout(), time::Duration::from_secs(10));

        let config = RuntimeConfiguration::with_drain_timeout(time::Duration::from_secs(3));
        assert_eq!(config.n_cores(), 0);
        assert_eq!(config.thread_name(), "gateway");
        assert_eq!(config.drain_timeout(), time::Duration::from_secs(3));
    }

    #[test]
    fn test_runtime_builder() {
        let config = RuntimeConfiguration::default();
        let runtime = build(&config).unwrap();
        assert_eq!(runtime.config.n_cores, 0);
        assert_runtime_runs(&runtime);
    }

    #[test]
    fn test_runtime_builder_single_core() {
        let config = RuntimeConfiguration::with_cores(1);
        let runtime = build(&config).unwrap();
        assert_eq!(runtime.config.n_cores, 1);
        assert_runtime_runs(&runtime);
    }

    #[test]
    fn test_runtime_builder_with_cores() {
        let config = RuntimeConfiguration {
            n_cores: 3,
            thread_name: "test".to_string(),
            drain_timeout: time::Duration::from_secs(10),
        };
        let runtime = build(&config).unwrap();
        assert_eq!(runtime.config.n_cores, 3);
        // Check the configuration stored in the built runtime, not the local copy.
        assert_eq!(runtime.config.thread_name, "test");
        assert_eq!(runtime.config.drain_timeout, time::Duration::from_secs(10));
        assert_runtime_runs(&runtime);
    }

    #[test]
    fn test_runtime_builder_with_invalid_cores() {
        let config = RuntimeConfiguration {
            n_cores: 100,
            thread_name: "test".to_string(),
            drain_timeout: time::Duration::from_secs(10),
        };
        let runtime = build(&config).unwrap();
        // The stored config keeps the requested value even when it was clamped.
        assert_eq!(runtime.config.n_cores, 100);
        assert_eq!(runtime.config.drain_timeout, time::Duration::from_secs(10));
        assert_runtime_runs(&runtime);
    }
}