1use diesel_async::pooled_connection::AsyncDieselConnectionManager;
2use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, CONTENT_TYPE};
3use hyper::{body::Incoming, Request};
4use hyper_util::rt::{TokioExecutor, TokioIo};
5use opentelemetry_otlp::OTEL_EXPORTER_OTLP_ENDPOINT;
6use retrom_db::run_migrations;
7use retrom_grpc_service::grpc_service;
8use retrom_rest_service::rest_service;
9use retrom_service_common::{config::ServerConfigManager, emulator_js};
10use retry::retry;
11use std::{net::SocketAddr, process::exit, sync::Arc};
12use tokio::{net::TcpListener, task::JoinHandle};
13use tower::ServiceExt;
14use tracing::{info_span, Instrument};
15
16#[cfg(feature = "embedded_db")]
17use retrom_db::embedded::DB_NAME;
18
/// Port the server listens on when the configuration does not provide one.
pub const DEFAULT_PORT: i32 = 5101;
/// Database URL used when the configuration does not provide one.
///
/// With the `embedded_db` feature enabled and no URL configured, this is also the base URL
/// handed to the embedded database when it is started.
pub const DEFAULT_DB_URL: &str = "postgres://postgres:postgres@localhost/retrom";
/// Version of this crate, captured from Cargo at compile time.
const CARGO_VERSION: &str = env!("CARGO_PKG_VERSION");
22
23#[tracing::instrument(name = "root_span")]
24pub async fn get_server(
25 db_params: Option<&str>,
26) -> (JoinHandle<Result<(), std::io::Error>>, SocketAddr) {
27 let _ = emulator_js::EmulatorJs::new().await;
28 let config_manager = match ServerConfigManager::new() {
29 Ok(config) => Arc::new(config),
30 Err(err) => {
31 tracing::error!("Could not load configuration: {:#?}", err);
32 exit(1)
33 }
34 };
35
36 if config_manager
37 .get_config()
38 .await
39 .telemetry
40 .is_some_and(|t| t.enabled)
41 {
42 tracing::info!(
43 "OpenTelemetry Tracing enabled: {:#?}",
44 std::env::var(OTEL_EXPORTER_OTLP_ENDPOINT).unwrap_or("endpoint unset".into())
45 );
46 } else {
47 tracing::warn!("OpenTelemetry Tracing is disabled, no telemetry data will be collected.");
48 }
49
50 let (mut port, mut db_url) = (DEFAULT_PORT, DEFAULT_DB_URL.to_string());
51
52 let conn_config = config_manager.get_config().await.connection;
53 if let Some(config_port) = conn_config.as_ref().and_then(|conn| conn.port) {
54 port = config_port;
55 tracing::info!("Using port from configuration: {}", port);
56 }
57
58 if let Some(config_db_url) = conn_config.as_ref().and_then(|conn| conn.db_url.clone()) {
59 db_url = config_db_url;
60 tracing::info!("Using database url from configuration: {}", db_url);
61 }
62
63 let mut addr: SocketAddr = format!("0.0.0.0:{port}").parse().unwrap();
64
65 #[cfg(feature = "embedded_db")]
66 let mut psql = None;
67
68 #[cfg(feature = "embedded_db")]
69 {
70 use core::panic;
71 let config_db_url = conn_config.and_then(|conn| conn.db_url);
72
73 if config_db_url.is_none() {
74 let mut db_url_with_params = db_url.clone();
75 if let Some(db_params) = db_params {
76 db_url_with_params.push_str(db_params);
77 }
78
79 psql.replace(
80 match retrom_db::embedded::start_embedded_db(&db_url_with_params).await {
81 Ok(psql) => psql,
82 Err(why) => {
83 tracing::error!("Could not start embedded db: {:#?}", why);
84 panic!("Could not start embedded db");
85 }
86 },
87 );
88
89 if let Some(psql) = &psql {
91 db_url = psql.settings().url(DB_NAME);
92 }
93 } else {
94 tracing::debug!("Opting out of embedded db");
95 }
96 }
97
98 let pool_config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(&db_url);
99
100 let pool = deadpool::managed::Pool::builder(pool_config)
102 .max_size(
103 std::thread::available_parallelism()
104 .unwrap_or(std::num::NonZero::new(2_usize).unwrap())
105 .into(),
106 )
107 .build()
108 .expect("Could not create pool");
109
110 let db_url_clone = db_url.clone();
111 tokio::task::spawn_blocking(move || {
112 let mut conn = retry(retry::delay::Exponential::from_millis(100), || {
113 match retrom_db::get_db_connection_sync(&db_url_clone) {
114 Ok(conn) => retry::OperationResult::Ok(conn),
115 Err(diesel::ConnectionError::BadConnection(err)) => {
116 tracing::info!("Error connecting to database, is the server running and accessible? Retrying...");
117 retry::OperationResult::Retry(err)
118 },
119 _ => retry::OperationResult::Err("Could not connect to database".to_string())
120 }
121 }).expect("Could not connect to database");
122
123 let migrations = run_migrations(&mut conn).expect("Could not run migrations");
124
125 migrations
126 .into_iter()
127 .for_each(|migration| tracing::info!("Ran migration: {}", migration));
128 })
129 .await
130 .expect("Could not run migrations");
131
132 let pool_state = Arc::new(pool);
133
134 let rest_service = rest_service(pool_state.clone());
135 let grpc_service = grpc_service(&db_url, config_manager);
136
137 tracing::info!(
138 "Starting Retrom {} service at: {}",
139 CARGO_VERSION,
140 addr.to_string()
141 );
142
143 check_version_announcements().await;
144
145 let mut listener = TcpListener::bind(&addr).await;
146 while listener.is_err() {
147 let port = addr.port();
148
149 tracing::warn!("Could not bind to port {}, trying port {}", port, port + 1);
150 let new_port = port + 1;
151 addr.set_port(new_port);
152 listener = TcpListener::bind(&addr).await;
153 }
154
155 let listener = listener.expect("Could not bind to address");
156 let port = listener.local_addr().expect("Could not get local address");
157
158 let handle: JoinHandle<_> = tokio::spawn(
159 async move {
160 let server = async {
161 loop {
162 let (socket, addr) = listener
163 .accept()
164 .await
165 .expect("Could not accept connection");
166
167 let grpc_service = grpc_service.clone();
168 let rest_service = rest_service.clone();
169
170 tokio::spawn(
171 async move {
172 let socket = TokioIo::new(socket);
173
174 let hyper_service =
175 hyper::service::service_fn(move |req: Request<Incoming>| {
176 let is_grpc = req
177 .headers()
178 .get(CONTENT_TYPE)
179 .map(|content_type| content_type.as_bytes())
180 .filter(|content_type| {
181 content_type.starts_with(b"application/grpc")
182 })
183 .is_some();
184
185 let is_grpc_preflight = req.method() == hyper::Method::OPTIONS
186 && req
187 .headers()
188 .get(ACCESS_CONTROL_REQUEST_HEADERS)
189 .map(|headers| {
190 headers.to_str().ok().map(|headers| {
191 headers.contains("content-type")
192 && headers.contains("grpc")
193 })
194 })
195 .is_some();
196
197 if is_grpc || is_grpc_preflight {
198 tracing::debug!(
199 "Routing request to gRPC service: {} {}",
200 req.method(),
201 req.uri().path()
202 );
203 grpc_service.clone().oneshot(req)
204 } else {
205 tracing::debug!(
206 "Routing request to REST service: {} {}",
207 req.method(),
208 req.uri().path()
209 );
210
211 rest_service.clone().oneshot(req)
212 }
213 });
214
215 if let Err(err) =
216 hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
217 .serve_connection(socket, hyper_service)
218 .await
219 {
220 tracing::error!("Error serving connection for {}: {}", addr, err);
221 }
222 }
223 .instrument(info_span!("connection", %addr)),
224 );
225 }
226 }
227 .instrument(info_span!("server_loop"));
228
229 tokio::select! {
230 _ = server => {
231 tracing::info!("Server exited");
232 }
233 _ = shutdown_signal() => {
234 tracing::info!("Shutdown signal received");
235 }
236 }
237
238 #[cfg(feature = "embedded_db")]
239 if let Some(psql_running) = psql {
240 if let Err(why) = psql_running.stop().await {
241 tracing::error!("Could not stop embedded db: {}", why);
242 }
243
244 tracing::info!("Embedded db stopped");
245 }
246
247 tracing::info!("Server stopped");
248
249 Ok::<(), std::io::Error>(())
250 }
251 .instrument(tracing::info_span!("server_task")),
252 );
253
254 (handle, port)
255}
256
257async fn check_version_announcements() {
258 let url = "https://raw.githubusercontent.com/JMBeresford/retrom/refs/heads/main/version-announcements.json";
259
260 let res = match reqwest::get(url).await {
261 Ok(res) => res,
262 Err(err) => {
263 tracing::error!("Could not fetch version announcements: {}", err);
264 return;
265 }
266 };
267
268 if !res.status().is_success() {
269 tracing::error!("Could not fetch version announcements: {}", res.status());
270 return;
271 }
272
273 let json = match res
274 .json::<retrom_codegen::retrom::VersionAnnouncementsPayload>()
275 .await
276 {
277 Ok(json) => json,
278 Err(err) => {
279 tracing::error!("Could not parse version announcements: {}", err);
280 return;
281 }
282 };
283
284 json.announcements.iter().for_each(|announcement| {
285 announcement.versions.iter().for_each(|version| {
286 if version == CARGO_VERSION {
287 match announcement.level.as_str() {
288 "info" => tracing::info!("Version announcement: {}", announcement.message),
289 "warn" | "warning" => {
290 tracing::warn!("Version announcement: {}", announcement.message)
291 }
292 "error" => tracing::error!("Version announcement: {}", announcement.message),
293 _ => tracing::debug!("Skipping version announcement: {}", announcement.message),
294 }
295 }
296 });
297 });
298}
299
300async fn shutdown_signal() {
301 #[cfg(windows)]
302 {
303 let _ = tokio::signal::ctrl_c().await;
304 tracing::info!("Received Ctrl+C, shutting down...");
305 }
306
307 #[cfg(not(windows))]
308 {
309 use futures::stream::StreamExt;
310 use signal_hook::consts::signal::*;
311 use signal_hook_tokio::Signals;
312
313 let mut signals =
314 Signals::new([SIGTERM, SIGINT, SIGQUIT]).expect("Could not create signal handler");
315
316 let handle = signals.handle();
317 let handle_signals = async move {
318 while let Some(signal) = signals.next().await {
319 match signal {
320 SIGTERM | SIGINT | SIGQUIT => {
321 break;
322 }
323 _ => {}
324 }
325 }
326 };
327
328 tokio::select! {
329 _ = handle_signals => {
330 tracing::info!("Received termination signal, shutting down...");
331 }
332 _ = tokio::signal::ctrl_c() => {
333 tracing::info!("Received Ctrl+C, shutting down...");
334 }
335 }
336
337 handle.close();
338 }
339}