1use diesel_async::pooled_connection::AsyncDieselConnectionManager;
2use either::Either;
3use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, CONTENT_TYPE};
4use hyper::{service::make_service_fn, Server};
5use retrom_db::run_migrations;
6use retry::retry;
7use std::{
8 convert::Infallible,
9 net::SocketAddr,
10 pin::Pin,
11 process::exit,
12 sync::Arc,
13 task::{Context, Poll},
14};
15use tokio::task::JoinHandle;
16use tower::Service;
17use tracing::Instrument;
18
19#[cfg(feature = "embedded_db")]
20use retrom_db::embedded::DB_NAME;
21
22pub mod config;
23mod grpc;
24mod providers;
25mod rest;
26
/// Port the service listens on when the loaded configuration does not specify one.
pub const DEFAULT_PORT: i32 = 5101;
/// Database URL used when the loaded configuration does not provide one.
pub const DEFAULT_DB_URL: &str = "postgres://postgres:postgres@localhost/retrom";
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
const CARGO_VERSION: &str = env!("CARGO_PKG_VERSION");
30
31#[tracing::instrument]
32pub async fn get_server(db_params: Option<&str>) -> (JoinHandle<()>, SocketAddr) {
33 let config_manager = match crate::config::ServerConfigManager::new() {
34 Ok(config) => Arc::new(config),
35 Err(err) => {
36 tracing::error!("Could not load configuration: {:#?}", err);
37 exit(1)
38 }
39 };
40
41 let (mut port, mut db_url) = (DEFAULT_PORT, DEFAULT_DB_URL.to_string());
42
43 let conn_config = config_manager.get_config().await.connection;
44 if let Some(config_port) = conn_config.as_ref().and_then(|conn| conn.port) {
45 port = config_port;
46 tracing::info!("Using port from configuration: {}", port);
47 }
48
49 if let Some(config_db_url) = conn_config.as_ref().and_then(|conn| conn.db_url.clone()) {
50 db_url = config_db_url;
51 tracing::info!("Using database url from configuration: {}", db_url);
52 }
53
54 let mut addr: SocketAddr = format!("0.0.0.0:{port}").parse().unwrap();
55
56 #[cfg(feature = "embedded_db")]
57 let mut psql = None;
58
59 #[cfg(feature = "embedded_db")]
60 {
61 use core::panic;
62 let config_db_url = conn_config.and_then(|conn| conn.db_url);
63
64 if config_db_url.is_none() {
65 use retrom_db::embedded::PgCtlFailsafeOperations;
66
67 let mut db_url_with_params = db_url.clone();
68 if let Some(db_params) = db_params {
69 db_url_with_params.push_str(db_params);
70 }
71
72 psql.replace(
73 match retrom_db::embedded::start_embedded_db(&db_url_with_params).await {
74 Ok(psql) => psql,
75 Err(why) => {
76 tracing::error!("Could not start embedded db: {:#?}", why);
77 panic!("Could not start embedded db");
78 }
79 },
80 );
81
82 if let Some(psql) = &psql {
84 db_url = psql.settings().url(DB_NAME);
85 }
86 } else {
87 tracing::debug!("Opting out of embedded db");
88 }
89 }
90
91 let pool_config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(&db_url);
92
93 let pool = deadpool::managed::Pool::builder(pool_config)
95 .max_size(
96 std::thread::available_parallelism()
97 .unwrap_or(std::num::NonZero::new(2_usize).unwrap())
98 .into(),
99 )
100 .build()
101 .expect("Could not create pool");
102
103 let db_url_clone = db_url.clone();
104 tokio::task::spawn_blocking(move || {
105 let mut conn = retry(retry::delay::Exponential::from_millis(100), || {
106 match retrom_db::get_db_connection_sync(&db_url_clone) {
107 Ok(conn) => retry::OperationResult::Ok(conn),
108 Err(diesel::ConnectionError::BadConnection(err)) => {
109 tracing::info!("Error connecting to database, is the server running and accessible? Retrying...");
110 retry::OperationResult::Retry(err)
111 },
112 _ => retry::OperationResult::Err("Could not connect to database".to_string())
113 }
114 }).expect("Could not connect to database");
115
116 let migrations = run_migrations(&mut conn).expect("Could not run migrations");
117
118 migrations
119 .into_iter()
120 .for_each(|migration| tracing::info!("Ran migration: {}", migration));
121 })
122 .instrument(tracing::info_span!("run_migrations"))
123 .await
124 .expect("Could not run migrations");
125
126 let pool_state = Arc::new(pool);
127
128 let rest_service = warp::service(rest::rest_service(pool_state.clone()));
129 let grpc_service = grpc::grpc_service(&db_url, config_manager);
130
131 tracing::info!(
132 "Starting Retrom {} service at: {}",
133 CARGO_VERSION,
134 addr.to_string()
135 );
136
137 check_version_announcements().await;
138
139 let mut binding = Server::try_bind(&addr);
140
141 while binding.is_err() {
142 let port = addr.port();
143
144 tracing::warn!("Could not bind to port {}, trying port {}", port, port + 1);
145 let new_port = port + 1;
146 addr.set_port(new_port);
147 binding = Server::try_bind(&addr);
148 }
149
150 let server = binding.unwrap().serve(make_service_fn(move |_| {
151 let mut rest_service = rest_service.clone();
152 let mut grpc_service = grpc_service.clone();
153 std::future::ready(Ok::<_, Infallible>(tower::service_fn(
154 move |req: hyper::Request<hyper::Body>| match is_grpc_request(&req) {
155 false => Either::Left({
156 let res = rest_service.call(req);
157 Box::pin(async move {
158 let res = res.await.map(|res| res.map(EitherBody::Left))?;
159 Ok::<_, Error>(res)
160 })
161 }),
162 true => Either::Right({
163 let res = grpc_service.call(req);
164 Box::pin(async move {
165 let res = res.await.map(|res| res.map(EitherBody::Right))?;
166 Ok::<_, Error>(res)
167 })
168 }),
169 },
170 )))
171 }));
172
173 let port = server.local_addr();
174
175 let handle: JoinHandle<()> = tokio::spawn(
176 async move {
177 if let Err(why) = server.await {
178 tracing::error!("Server error: {}", why);
179 }
180
181 #[cfg(feature = "embedded_db")]
182 if let Some(psql_running) = psql {
183 use retrom_db::embedded::PgCtlFailsafeOperations;
184
185 if let Err(why) = psql_running.failsafe_stop().await {
186 tracing::error!("Could not stop embedded db: {}", why);
187 }
188
189 tracing::info!("Embedded db stopped");
190 }
191
192 tracing::info!("Server stopped");
193 }
194 .instrument(tracing::info_span!("server_handle")),
195 );
196
197 (handle, port)
198}
199
200fn is_grpc_request(req: &hyper::Request<hyper::Body>) -> bool {
201 let is_grpc = req
202 .headers()
203 .get(CONTENT_TYPE)
204 .map(|content_type| content_type.as_bytes())
205 .filter(|content_type| content_type.starts_with(b"application/grpc"))
206 .is_some();
207
208 let is_grpc_preflight = req.method() == hyper::Method::OPTIONS
209 && req
210 .headers()
211 .get(ACCESS_CONTROL_REQUEST_HEADERS)
212 .map(|headers| {
213 headers
214 .to_str()
215 .ok()
216 .map(|headers| headers.contains("content-type") && headers.contains("grpc"))
217 })
218 .is_some();
219
220 is_grpc || is_grpc_preflight
221}
222
223async fn check_version_announcements() {
224 let url = "https://raw.githubusercontent.com/JMBeresford/retrom/refs/heads/main/version-announcements.json";
225
226 let res = match reqwest::get(url).await {
227 Ok(res) => res,
228 Err(err) => {
229 tracing::error!("Could not fetch version announcements: {}", err);
230 return;
231 }
232 };
233
234 if !res.status().is_success() {
235 tracing::error!("Could not fetch version announcements: {}", res.status());
236 return;
237 }
238
239 let json = match res
240 .json::<retrom_codegen::retrom::VersionAnnouncementsPayload>()
241 .await
242 {
243 Ok(json) => json,
244 Err(err) => {
245 tracing::error!("Could not parse version announcements: {}", err);
246 return;
247 }
248 };
249
250 json.announcements.iter().for_each(|announcement| {
251 announcement.versions.iter().for_each(|version| {
252 if version == CARGO_VERSION {
253 match announcement.level.as_str() {
254 "info" => tracing::info!("Version announcement: {}", announcement.message),
255 "warn" | "warning" => {
256 tracing::warn!("Version announcement: {}", announcement.message)
257 }
258 "error" => tracing::error!("Version announcement: {}", announcement.message),
259 _ => tracing::debug!("Skipping version announcement: {}", announcement.message),
260 }
261 }
262 });
263 });
264}
265
/// Boxed, thread-safe error type shared by the combined service and its response bodies.
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
267
/// A response body that is one of two concrete body types.
///
/// Lets the REST and gRPC branches of the combined service return a single body type.
enum EitherBody<A, B> {
    /// Body of type `A` (wraps the REST service's responses in this file).
    Left(A),
    /// Body of type `B` (wraps the gRPC service's responses in this file).
    Right(B),
}
272
273impl<A, B> http_body::Body for EitherBody<A, B>
274where
275 A: http_body::Body + Send + Unpin,
276 B: http_body::Body<Data = A::Data> + Send + Unpin,
277 A::Error: Into<Error>,
278 B::Error: Into<Error>,
279{
280 type Data = A::Data;
281 type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
282
283 fn is_end_stream(&self) -> bool {
284 match self {
285 EitherBody::Left(b) => b.is_end_stream(),
286 EitherBody::Right(b) => b.is_end_stream(),
287 }
288 }
289
290 fn poll_data(
291 self: Pin<&mut Self>,
292 cx: &mut Context<'_>,
293 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
294 match self.get_mut() {
295 EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err),
296 EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err),
297 }
298 }
299
300 fn poll_trailers(
301 self: Pin<&mut Self>,
302 cx: &mut Context<'_>,
303 ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
304 match self.get_mut() {
305 EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
306 EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
307 }
308 }
309}
310
311fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result<T, Error>> {
312 err.map(|e| e.map_err(Into::into))
313}