retrom_service/
lib.rs

1use diesel_async::pooled_connection::AsyncDieselConnectionManager;
2use either::Either;
3use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, CONTENT_TYPE};
4use hyper::{service::make_service_fn, Server};
5use retrom_db::run_migrations;
6use retry::retry;
7use std::{
8    convert::Infallible,
9    net::SocketAddr,
10    pin::Pin,
11    process::exit,
12    sync::Arc,
13    task::{Context, Poll},
14};
15use tokio::task::JoinHandle;
16use tower::Service;
17use tracing::Instrument;
18
19#[cfg(feature = "embedded_db")]
20use retrom_db::embedded::DB_NAME;
21
22pub mod config;
23mod grpc;
24mod providers;
25mod rest;
26
/// Port the service listens on when the configuration does not provide one.
pub const DEFAULT_PORT: i32 = 5101;
/// Database URL used when the configuration does not provide one.
pub const DEFAULT_DB_URL: &str = "postgres://postgres:postgres@localhost/retrom";
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
const CARGO_VERSION: &str = env!("CARGO_PKG_VERSION");
30
31#[tracing::instrument]
32pub async fn get_server(db_params: Option<&str>) -> (JoinHandle<()>, SocketAddr) {
33    let config_manager = match crate::config::ServerConfigManager::new() {
34        Ok(config) => Arc::new(config),
35        Err(err) => {
36            tracing::error!("Could not load configuration: {:#?}", err);
37            exit(1)
38        }
39    };
40
41    let (mut port, mut db_url) = (DEFAULT_PORT, DEFAULT_DB_URL.to_string());
42
43    let conn_config = config_manager.get_config().await.connection;
44    if let Some(config_port) = conn_config.as_ref().and_then(|conn| conn.port) {
45        port = config_port;
46        tracing::info!("Using port from configuration: {}", port);
47    }
48
49    if let Some(config_db_url) = conn_config.as_ref().and_then(|conn| conn.db_url.clone()) {
50        db_url = config_db_url;
51        tracing::info!("Using database url from configuration: {}", db_url);
52    }
53
54    let mut addr: SocketAddr = format!("0.0.0.0:{port}").parse().unwrap();
55
56    #[cfg(feature = "embedded_db")]
57    let mut psql = None;
58
59    #[cfg(feature = "embedded_db")]
60    {
61        use core::panic;
62        let config_db_url = conn_config.and_then(|conn| conn.db_url);
63
64        if config_db_url.is_none() {
65            use retrom_db::embedded::PgCtlFailsafeOperations;
66
67            let mut db_url_with_params = db_url.clone();
68            if let Some(db_params) = db_params {
69                db_url_with_params.push_str(db_params);
70            }
71
72            psql.replace(
73                match retrom_db::embedded::start_embedded_db(&db_url_with_params).await {
74                    Ok(psql) => psql,
75                    Err(why) => {
76                        tracing::error!("Could not start embedded db: {:#?}", why);
77                        panic!("Could not start embedded db");
78                    }
79                },
80            );
81
82            // Port may be random, so get new db_url from running instance
83            if let Some(psql) = &psql {
84                db_url = psql.settings().url(DB_NAME);
85            }
86        } else {
87            tracing::debug!("Opting out of embedded db");
88        }
89    }
90
91    let pool_config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(&db_url);
92
93    // pool used for REST service endpoints
94    let pool = deadpool::managed::Pool::builder(pool_config)
95        .max_size(
96            std::thread::available_parallelism()
97                .unwrap_or(std::num::NonZero::new(2_usize).unwrap())
98                .into(),
99        )
100        .build()
101        .expect("Could not create pool");
102
103    let db_url_clone = db_url.clone();
104    tokio::task::spawn_blocking(move || {
105        let mut conn = retry(retry::delay::Exponential::from_millis(100), || {
106            match retrom_db::get_db_connection_sync(&db_url_clone) {
107                Ok(conn) => retry::OperationResult::Ok(conn),
108                Err(diesel::ConnectionError::BadConnection(err)) => {
109                    tracing::info!("Error connecting to database, is the server running and accessible? Retrying...");
110                    retry::OperationResult::Retry(err)
111                },
112                _ => retry::OperationResult::Err("Could not connect to database".to_string())
113            }
114        }).expect("Could not connect to database");
115
116        let migrations = run_migrations(&mut conn).expect("Could not run migrations");
117
118        migrations
119            .into_iter()
120            .for_each(|migration| tracing::info!("Ran migration: {}", migration));
121    })
122    .instrument(tracing::info_span!("run_migrations"))
123    .await
124    .expect("Could not run migrations");
125
126    let pool_state = Arc::new(pool);
127
128    let rest_service = warp::service(rest::rest_service(pool_state.clone()));
129    let grpc_service = grpc::grpc_service(&db_url, config_manager);
130
131    tracing::info!(
132        "Starting Retrom {} service at: {}",
133        CARGO_VERSION,
134        addr.to_string()
135    );
136
137    check_version_announcements().await;
138
139    let mut binding = Server::try_bind(&addr);
140
141    while binding.is_err() {
142        let port = addr.port();
143
144        tracing::warn!("Could not bind to port {}, trying port {}", port, port + 1);
145        let new_port = port + 1;
146        addr.set_port(new_port);
147        binding = Server::try_bind(&addr);
148    }
149
150    let server = binding.unwrap().serve(make_service_fn(move |_| {
151        let mut rest_service = rest_service.clone();
152        let mut grpc_service = grpc_service.clone();
153        std::future::ready(Ok::<_, Infallible>(tower::service_fn(
154            move |req: hyper::Request<hyper::Body>| match is_grpc_request(&req) {
155                false => Either::Left({
156                    let res = rest_service.call(req);
157                    Box::pin(async move {
158                        let res = res.await.map(|res| res.map(EitherBody::Left))?;
159                        Ok::<_, Error>(res)
160                    })
161                }),
162                true => Either::Right({
163                    let res = grpc_service.call(req);
164                    Box::pin(async move {
165                        let res = res.await.map(|res| res.map(EitherBody::Right))?;
166                        Ok::<_, Error>(res)
167                    })
168                }),
169            },
170        )))
171    }));
172
173    let port = server.local_addr();
174
175    let handle: JoinHandle<()> = tokio::spawn(
176        async move {
177            if let Err(why) = server.await {
178                tracing::error!("Server error: {}", why);
179            }
180
181            #[cfg(feature = "embedded_db")]
182            if let Some(psql_running) = psql {
183                use retrom_db::embedded::PgCtlFailsafeOperations;
184
185                if let Err(why) = psql_running.failsafe_stop().await {
186                    tracing::error!("Could not stop embedded db: {}", why);
187                }
188
189                tracing::info!("Embedded db stopped");
190            }
191
192            tracing::info!("Server stopped");
193        }
194        .instrument(tracing::info_span!("server_handle")),
195    );
196
197    (handle, port)
198}
199
200fn is_grpc_request(req: &hyper::Request<hyper::Body>) -> bool {
201    let is_grpc = req
202        .headers()
203        .get(CONTENT_TYPE)
204        .map(|content_type| content_type.as_bytes())
205        .filter(|content_type| content_type.starts_with(b"application/grpc"))
206        .is_some();
207
208    let is_grpc_preflight = req.method() == hyper::Method::OPTIONS
209        && req
210            .headers()
211            .get(ACCESS_CONTROL_REQUEST_HEADERS)
212            .map(|headers| {
213                headers
214                    .to_str()
215                    .ok()
216                    .map(|headers| headers.contains("content-type") && headers.contains("grpc"))
217            })
218            .is_some();
219
220    is_grpc || is_grpc_preflight
221}
222
223async fn check_version_announcements() {
224    let url = "https://raw.githubusercontent.com/JMBeresford/retrom/refs/heads/main/version-announcements.json";
225
226    let res = match reqwest::get(url).await {
227        Ok(res) => res,
228        Err(err) => {
229            tracing::error!("Could not fetch version announcements: {}", err);
230            return;
231        }
232    };
233
234    if !res.status().is_success() {
235        tracing::error!("Could not fetch version announcements: {}", res.status());
236        return;
237    }
238
239    let json = match res
240        .json::<retrom_codegen::retrom::VersionAnnouncementsPayload>()
241        .await
242    {
243        Ok(json) => json,
244        Err(err) => {
245            tracing::error!("Could not parse version announcements: {}", err);
246            return;
247        }
248    };
249
250    json.announcements.iter().for_each(|announcement| {
251        announcement.versions.iter().for_each(|version| {
252            if version == CARGO_VERSION {
253                match announcement.level.as_str() {
254                    "info" => tracing::info!("Version announcement: {}", announcement.message),
255                    "warn" | "warning" => {
256                        tracing::warn!("Version announcement: {}", announcement.message)
257                    }
258                    "error" => tracing::error!("Version announcement: {}", announcement.message),
259                    _ => tracing::debug!("Skipping version announcement: {}", announcement.message),
260                }
261            }
262        });
263    });
264}
265
/// Boxed, thread-safe error type used wherever errors from different sources must
/// share a single type.
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// A value that is exactly one of two body types.
///
/// Wrapping two different body types in this enum gives them a single common type.
enum EitherBody<L, R> {
    /// The first of the two possible bodies.
    Left(L),
    /// The second of the two possible bodies.
    Right(R),
}
272
273impl<A, B> http_body::Body for EitherBody<A, B>
274where
275    A: http_body::Body + Send + Unpin,
276    B: http_body::Body<Data = A::Data> + Send + Unpin,
277    A::Error: Into<Error>,
278    B::Error: Into<Error>,
279{
280    type Data = A::Data;
281    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
282
283    fn is_end_stream(&self) -> bool {
284        match self {
285            EitherBody::Left(b) => b.is_end_stream(),
286            EitherBody::Right(b) => b.is_end_stream(),
287        }
288    }
289
290    fn poll_data(
291        self: Pin<&mut Self>,
292        cx: &mut Context<'_>,
293    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
294        match self.get_mut() {
295            EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err),
296            EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err),
297        }
298    }
299
300    fn poll_trailers(
301        self: Pin<&mut Self>,
302        cx: &mut Context<'_>,
303    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
304        match self.get_mut() {
305            EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
306            EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
307        }
308    }
309}
310
311fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result<T, Error>> {
312    err.map(|e| e.map_err(Into::into))
313}