retrom_service/
lib.rs

1use diesel_async::pooled_connection::AsyncDieselConnectionManager;
2use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, CONTENT_TYPE};
3use hyper::{body::Incoming, Request};
4use hyper_util::rt::{TokioExecutor, TokioIo};
5use opentelemetry_otlp::OTEL_EXPORTER_OTLP_ENDPOINT;
6use retrom_db::run_migrations;
7use retrom_grpc_service::grpc_service;
8use retrom_rest_service::rest_service;
9use retrom_service_common::{config::ServerConfigManager, emulator_js};
10use retry::retry;
11use std::{net::SocketAddr, process::exit, sync::Arc};
12use tokio::{net::TcpListener, task::JoinHandle};
13use tower::ServiceExt;
14use tracing::{info_span, Instrument};
15
16#[cfg(feature = "embedded_db")]
17use retrom_db::embedded::DB_NAME;
18
/// Port the server listens on when the configuration does not provide one.
pub const DEFAULT_PORT: i32 = 5_101;
/// PostgreSQL connection URL used when the configuration does not provide one.
pub const DEFAULT_DB_URL: &str = "postgres://postgres:postgres@localhost/retrom";
21const CARGO_VERSION: &str = env!("CARGO_PKG_VERSION");
22
23#[tracing::instrument(name = "root_span")]
24pub async fn get_server(
25    db_params: Option<&str>,
26) -> (JoinHandle<Result<(), std::io::Error>>, SocketAddr) {
27    let _ = emulator_js::EmulatorJs::new().await;
28    let config_manager = match ServerConfigManager::new() {
29        Ok(config) => Arc::new(config),
30        Err(err) => {
31            tracing::error!("Could not load configuration: {:#?}", err);
32            exit(1)
33        }
34    };
35
36    if config_manager
37        .get_config()
38        .await
39        .telemetry
40        .is_some_and(|t| t.enabled)
41    {
42        tracing::info!(
43            "OpenTelemetry Tracing enabled: {:#?}",
44            std::env::var(OTEL_EXPORTER_OTLP_ENDPOINT).unwrap_or("endpoint unset".into())
45        );
46    } else {
47        tracing::warn!("OpenTelemetry Tracing is disabled, no telemetry data will be collected.");
48    }
49
50    let (mut port, mut db_url) = (DEFAULT_PORT, DEFAULT_DB_URL.to_string());
51
52    let conn_config = config_manager.get_config().await.connection;
53    if let Some(config_port) = conn_config.as_ref().and_then(|conn| conn.port) {
54        port = config_port;
55        tracing::info!("Using port from configuration: {}", port);
56    }
57
58    if let Some(config_db_url) = conn_config.as_ref().and_then(|conn| conn.db_url.clone()) {
59        db_url = config_db_url;
60        tracing::info!("Using database url from configuration: {}", db_url);
61    }
62
63    let mut addr: SocketAddr = format!("0.0.0.0:{port}").parse().unwrap();
64
65    #[cfg(feature = "embedded_db")]
66    let mut psql = None;
67
68    #[cfg(feature = "embedded_db")]
69    {
70        use core::panic;
71        let config_db_url = conn_config.and_then(|conn| conn.db_url);
72
73        if config_db_url.is_none() {
74            let mut db_url_with_params = db_url.clone();
75            if let Some(db_params) = db_params {
76                db_url_with_params.push_str(db_params);
77            }
78
79            psql.replace(
80                match retrom_db::embedded::start_embedded_db(&db_url_with_params).await {
81                    Ok(psql) => psql,
82                    Err(why) => {
83                        tracing::error!("Could not start embedded db: {:#?}", why);
84                        panic!("Could not start embedded db");
85                    }
86                },
87            );
88
89            // Port may be random, so get new db_url from running instance
90            if let Some(psql) = &psql {
91                db_url = psql.settings().url(DB_NAME);
92            }
93        } else {
94            tracing::debug!("Opting out of embedded db");
95        }
96    }
97
98    let pool_config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(&db_url);
99
100    // pool used for REST service endpoints
101    let pool = deadpool::managed::Pool::builder(pool_config)
102        .max_size(
103            std::thread::available_parallelism()
104                .unwrap_or(std::num::NonZero::new(2_usize).unwrap())
105                .into(),
106        )
107        .build()
108        .expect("Could not create pool");
109
110    let db_url_clone = db_url.clone();
111    tokio::task::spawn_blocking(move || {
112        let mut conn = retry(retry::delay::Exponential::from_millis(100), || {
113            match retrom_db::get_db_connection_sync(&db_url_clone) {
114                Ok(conn) => retry::OperationResult::Ok(conn),
115                Err(diesel::ConnectionError::BadConnection(err)) => {
116                    tracing::info!("Error connecting to database, is the server running and accessible? Retrying...");
117                    retry::OperationResult::Retry(err)
118                },
119                _ => retry::OperationResult::Err("Could not connect to database".to_string())
120            }
121        }).expect("Could not connect to database");
122
123        let migrations = run_migrations(&mut conn).expect("Could not run migrations");
124
125        migrations
126            .into_iter()
127            .for_each(|migration| tracing::info!("Ran migration: {}", migration));
128    })
129    .await
130    .expect("Could not run migrations");
131
132    let pool_state = Arc::new(pool);
133
134    let rest_service = rest_service(pool_state.clone());
135    let grpc_service = grpc_service(&db_url, config_manager);
136
137    tracing::info!(
138        "Starting Retrom {} service at: {}",
139        CARGO_VERSION,
140        addr.to_string()
141    );
142
143    check_version_announcements().await;
144
145    let mut listener = TcpListener::bind(&addr).await;
146    while listener.is_err() {
147        let port = addr.port();
148
149        tracing::warn!("Could not bind to port {}, trying port {}", port, port + 1);
150        let new_port = port + 1;
151        addr.set_port(new_port);
152        listener = TcpListener::bind(&addr).await;
153    }
154
155    let listener = listener.expect("Could not bind to address");
156    let port = listener.local_addr().expect("Could not get local address");
157
158    let handle: JoinHandle<_> = tokio::spawn(
159        async move {
160            let server = async {
161                loop {
162                    let (socket, addr) = listener
163                        .accept()
164                        .await
165                        .expect("Could not accept connection");
166
167                    let grpc_service = grpc_service.clone();
168                    let rest_service = rest_service.clone();
169
170                    tokio::spawn(
171                        async move {
172                            let socket = TokioIo::new(socket);
173
174                            let hyper_service =
175                                hyper::service::service_fn(move |req: Request<Incoming>| {
176                                    let is_grpc = req
177                                        .headers()
178                                        .get(CONTENT_TYPE)
179                                        .map(|content_type| content_type.as_bytes())
180                                        .filter(|content_type| {
181                                            content_type.starts_with(b"application/grpc")
182                                        })
183                                        .is_some();
184
185                                    let is_grpc_preflight = req.method() == hyper::Method::OPTIONS
186                                        && req
187                                            .headers()
188                                            .get(ACCESS_CONTROL_REQUEST_HEADERS)
189                                            .map(|headers| {
190                                                headers.to_str().ok().map(|headers| {
191                                                    headers.contains("content-type")
192                                                        && headers.contains("grpc")
193                                                })
194                                            })
195                                            .is_some();
196
197                                    if is_grpc || is_grpc_preflight {
198                                        tracing::debug!(
199                                            "Routing request to gRPC service: {} {}",
200                                            req.method(),
201                                            req.uri().path()
202                                        );
203                                        grpc_service.clone().oneshot(req)
204                                    } else {
205                                        tracing::debug!(
206                                            "Routing request to REST service: {} {}",
207                                            req.method(),
208                                            req.uri().path()
209                                        );
210
211                                        rest_service.clone().oneshot(req)
212                                    }
213                                });
214
215                            if let Err(err) =
216                                hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
217                                    .serve_connection(socket, hyper_service)
218                                    .await
219                            {
220                                tracing::error!("Error serving connection for {}: {}", addr, err);
221                            }
222                        }
223                        .instrument(info_span!("connection", %addr)),
224                    );
225                }
226            }
227            .instrument(info_span!("server_loop"));
228
229            tokio::select! {
230                _ = server => {
231                    tracing::info!("Server exited");
232                }
233                _ = shutdown_signal() => {
234                    tracing::info!("Shutdown signal received");
235                }
236            }
237
238            #[cfg(feature = "embedded_db")]
239            if let Some(psql_running) = psql {
240                if let Err(why) = psql_running.stop().await {
241                    tracing::error!("Could not stop embedded db: {}", why);
242                }
243
244                tracing::info!("Embedded db stopped");
245            }
246
247            tracing::info!("Server stopped");
248
249            Ok::<(), std::io::Error>(())
250        }
251        .instrument(tracing::info_span!("server_task")),
252    );
253
254    (handle, port)
255}
256
257async fn check_version_announcements() {
258    let url = "https://raw.githubusercontent.com/JMBeresford/retrom/refs/heads/main/version-announcements.json";
259
260    let res = match reqwest::get(url).await {
261        Ok(res) => res,
262        Err(err) => {
263            tracing::error!("Could not fetch version announcements: {}", err);
264            return;
265        }
266    };
267
268    if !res.status().is_success() {
269        tracing::error!("Could not fetch version announcements: {}", res.status());
270        return;
271    }
272
273    let json = match res
274        .json::<retrom_codegen::retrom::VersionAnnouncementsPayload>()
275        .await
276    {
277        Ok(json) => json,
278        Err(err) => {
279            tracing::error!("Could not parse version announcements: {}", err);
280            return;
281        }
282    };
283
284    json.announcements.iter().for_each(|announcement| {
285        announcement.versions.iter().for_each(|version| {
286            if version == CARGO_VERSION {
287                match announcement.level.as_str() {
288                    "info" => tracing::info!("Version announcement: {}", announcement.message),
289                    "warn" | "warning" => {
290                        tracing::warn!("Version announcement: {}", announcement.message)
291                    }
292                    "error" => tracing::error!("Version announcement: {}", announcement.message),
293                    _ => tracing::debug!("Skipping version announcement: {}", announcement.message),
294                }
295            }
296        });
297    });
298}
299
300async fn shutdown_signal() {
301    #[cfg(windows)]
302    {
303        let _ = tokio::signal::ctrl_c().await;
304        tracing::info!("Received Ctrl+C, shutting down...");
305    }
306
307    #[cfg(not(windows))]
308    {
309        use futures::stream::StreamExt;
310        use signal_hook::consts::signal::*;
311        use signal_hook_tokio::Signals;
312
313        let mut signals =
314            Signals::new([SIGTERM, SIGINT, SIGQUIT]).expect("Could not create signal handler");
315
316        let handle = signals.handle();
317        let handle_signals = async move {
318            while let Some(signal) = signals.next().await {
319                match signal {
320                    SIGTERM | SIGINT | SIGQUIT => {
321                        break;
322                    }
323                    _ => {}
324                }
325            }
326        };
327
328        tokio::select! {
329             _ = handle_signals => {
330                tracing::info!("Received termination signal, shutting down...");
331            }
332            _ = tokio::signal::ctrl_c() => {
333                tracing::info!("Received Ctrl+C, shutting down...");
334            }
335        }
336
337        handle.close();
338    }
339}