Skip to main content

gestalt/
runtime.rs

1use std::env;
2#[cfg(unix)]
3use std::future::Future;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8#[cfg(unix)]
9use tokio::net::UnixListener;
10#[cfg(unix)]
11use tokio::signal;
12#[cfg(unix)]
13use tokio::time::sleep;
14#[cfg(unix)]
15use tokio_stream::wrappers::UnixListenerStream;
16#[cfg(unix)]
17use tonic::transport::Server;
18
19use crate::catalog::write_catalog;
20use crate::env::{
21    ENV_PROVIDER_NAME, ENV_PROVIDER_PARENT_PID, ENV_PROVIDER_SOCKET, ENV_WRITE_CATALOG,
22};
23use crate::error::{Error, Result};
24#[cfg(unix)]
25use crate::generated::v1::auth_provider_server::AuthProviderServer;
26#[cfg(unix)]
27use crate::generated::v1::cache_server::CacheServer;
28#[cfg(unix)]
29use crate::generated::v1::integration_provider_server::IntegrationProviderServer;
30#[cfg(unix)]
31use crate::generated::v1::provider_lifecycle_server::ProviderLifecycleServer;
32#[cfg(unix)]
33use crate::generated::v1::s3_server::S3Server;
34#[cfg(unix)]
35use crate::generated::v1::secrets_provider_server::SecretsProviderServer;
36use crate::provider_server::ProviderServer;
37use crate::{AuthProvider, CacheProvider, S3Provider, SecretsProvider};
38use crate::{Provider, Router};
39#[cfg(unix)]
40use crate::{
41    auth_server::AuthServer, cache_server::CacheRpcServer, runtime_server::RuntimeServer,
42    secrets_server::SecretsServer,
43};
44
45fn build_runtime_and_block_on<F, Fut>(f: F) -> Result<()>
46where
47    F: FnOnce() -> Fut,
48    Fut: std::future::Future<Output = Result<()>>,
49{
50    let runtime = tokio::runtime::Builder::new_multi_thread()
51        .enable_all()
52        .build()
53        .map_err(|error| Error::internal(error.to_string()))?;
54    runtime.block_on(f())
55}
56
57pub fn run_provider<P: Provider>(provider: Arc<P>, router: Router<P>) -> Result<()> {
58    build_runtime_and_block_on(|| serve_provider(provider, router))
59}
60
61pub fn run_auth_provider<P: AuthProvider>(provider: Arc<P>) -> Result<()> {
62    build_runtime_and_block_on(|| serve_auth_provider(provider))
63}
64
65pub fn run_cache_provider<P: CacheProvider>(provider: Arc<P>) -> Result<()> {
66    build_runtime_and_block_on(|| serve_cache_provider(provider))
67}
68
69pub fn run_secrets_provider<P: SecretsProvider>(provider: Arc<P>) -> Result<()> {
70    build_runtime_and_block_on(|| serve_secrets_provider(provider))
71}
72
73pub fn run_s3_provider<P: S3Provider>(provider: Arc<P>) -> Result<()> {
74    build_runtime_and_block_on(|| serve_s3_provider(provider))
75}
76
77pub fn write_catalog_path<P>(router: &Router<P>, path: impl AsRef<Path>) -> Result<()> {
78    write_catalog(router.catalog(), path)
79}
80
81pub fn maybe_write_catalog<P>(router: &Router<P>) -> Result<bool> {
82    let Some(path) = env::var_os(ENV_WRITE_CATALOG) else {
83        return Ok(false);
84    };
85
86    let catalog = if let Ok(name) = env::var(ENV_PROVIDER_NAME) {
87        router.catalog().clone().with_name(name)
88    } else {
89        router.catalog().clone()
90    };
91
92    write_catalog(&catalog, PathBuf::from(path))?;
93    Ok(true)
94}
95
96#[cfg(unix)]
97pub async fn serve_provider<P>(provider: Arc<P>, router: Router<P>) -> Result<()>
98where
99    P: Provider,
100{
101    if maybe_write_catalog(&router)? {
102        return Ok(());
103    }
104    let server = ProviderServer::new(Arc::clone(&provider), router);
105    serve_unix_provider(
106        provider,
107        move |incoming, provider| {
108            Server::builder()
109                .add_service(ProviderLifecycleServer::new(RuntimeServer::for_provider(
110                    Arc::clone(&provider),
111                )))
112                .add_service(IntegrationProviderServer::new(server))
113                .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
114        },
115        |provider| async move { provider.close().await },
116    )
117    .await
118}
119
120#[cfg(unix)]
121pub async fn serve_auth_provider<P>(provider: Arc<P>) -> Result<()>
122where
123    P: AuthProvider,
124{
125    serve_unix_provider(
126        provider,
127        move |incoming, provider| {
128            Server::builder()
129                .add_service(ProviderLifecycleServer::new(RuntimeServer::for_auth(
130                    Arc::clone(&provider),
131                )))
132                .add_service(AuthProviderServer::new(AuthServer::new(Arc::clone(
133                    &provider,
134                ))))
135                .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
136        },
137        |provider| async move { provider.close().await },
138    )
139    .await
140}
141
142#[cfg(unix)]
143pub async fn serve_cache_provider<P>(provider: Arc<P>) -> Result<()>
144where
145    P: CacheProvider,
146{
147    serve_unix_provider(
148        provider,
149        move |incoming, provider| {
150            Server::builder()
151                .add_service(ProviderLifecycleServer::new(RuntimeServer::for_cache(
152                    Arc::clone(&provider),
153                )))
154                .add_service(CacheServer::new(CacheRpcServer::new(Arc::clone(&provider))))
155                .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
156        },
157        |provider| async move { provider.close().await },
158    )
159    .await
160}
161
162#[cfg(unix)]
163pub async fn serve_secrets_provider<P>(provider: Arc<P>) -> Result<()>
164where
165    P: SecretsProvider,
166{
167    serve_unix_provider(
168        provider,
169        move |incoming, provider| {
170            Server::builder()
171                .add_service(ProviderLifecycleServer::new(RuntimeServer::for_secrets(
172                    Arc::clone(&provider),
173                )))
174                .add_service(SecretsProviderServer::new(SecretsServer::new(Arc::clone(
175                    &provider,
176                ))))
177                .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
178        },
179        |provider| async move { provider.close().await },
180    )
181    .await
182}
183
184#[cfg(unix)]
185pub async fn serve_s3_provider<P>(provider: Arc<P>) -> Result<()>
186where
187    P: S3Provider,
188{
189    serve_unix_provider(
190        provider,
191        move |incoming, provider| {
192            Server::builder()
193                .add_service(ProviderLifecycleServer::new(RuntimeServer::for_s3(
194                    Arc::clone(&provider),
195                )))
196                .add_service(S3Server::new(Arc::clone(&provider)))
197                .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
198        },
199        |provider| async move { provider.close().await },
200    )
201    .await
202}
203
/// Non-Unix fallback for [`serve_provider`].
///
/// Catalog export via the environment still works (see [`maybe_write_catalog`]) and
/// yields `Ok(())`; serving itself is impossible without Unix-domain sockets, so every
/// other call returns an internal error.
#[cfg(not(unix))]
pub async fn serve_provider<P>(_provider: Arc<P>, router: Router<P>) -> Result<()>
where
    P: Provider,
{
    if maybe_write_catalog(&router)? {
        Ok(())
    } else {
        Err(Error::internal("unix sockets are unsupported on this platform"))
    }
}
216
/// Non-Unix fallback for [`serve_auth_provider`]: always returns an internal error,
/// because the provider transport requires Unix-domain sockets.
#[cfg(not(unix))]
pub async fn serve_auth_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: AuthProvider,
{
    let message = "unix sockets are unsupported on this platform";
    Err(Error::internal(message))
}
226
/// Non-Unix fallback for [`serve_cache_provider`]: always returns an internal error,
/// because the provider transport requires Unix-domain sockets.
#[cfg(not(unix))]
pub async fn serve_cache_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: CacheProvider,
{
    let message = "unix sockets are unsupported on this platform";
    Err(Error::internal(message))
}
236
/// Non-Unix fallback for [`serve_secrets_provider`]: always returns an internal error,
/// because the provider transport requires Unix-domain sockets.
#[cfg(not(unix))]
pub async fn serve_secrets_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: SecretsProvider,
{
    let message = "unix sockets are unsupported on this platform";
    Err(Error::internal(message))
}
246
/// Non-Unix fallback for [`serve_s3_provider`]: always returns an internal error,
/// because the provider transport requires Unix-domain sockets.
#[cfg(not(unix))]
pub async fn serve_s3_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: S3Provider,
{
    let message = "unix sockets are unsupported on this platform";
    Err(Error::internal(message))
}
256
257#[cfg(unix)]
258async fn shutdown_signal(parent_pid: Option<u32>) {
259    let ctrl_c = async {
260        let _ = signal::ctrl_c().await;
261    };
262
263    tokio::pin!(ctrl_c);
264
265    if let Some(parent_pid) = parent_pid {
266        tokio::select! {
267            _ = &mut ctrl_c => {}
268            _ = watch_parent(parent_pid) => {}
269        }
270        return;
271    }
272
273    ctrl_c.await;
274}
275
276#[cfg(unix)]
277async fn serve_unix_provider<P, F, S, C, CF>(provider: Arc<P>, serve: F, close: C) -> Result<()>
278where
279    P: Send + Sync,
280    F: FnOnce(UnixListenerStream, Arc<P>) -> S,
281    S: Future<Output = std::result::Result<(), tonic::transport::Error>>,
282    C: FnOnce(Arc<P>) -> CF,
283    CF: Future<Output = Result<()>>,
284{
285    let socket = env::var_os(ENV_PROVIDER_SOCKET)
286        .ok_or_else(|| Error::internal(format!("{ENV_PROVIDER_SOCKET} is required")))?;
287    let socket = PathBuf::from(socket);
288    if socket.exists() {
289        std::fs::remove_file(&socket)?;
290    }
291    if let Some(parent) = socket.parent()
292        && !parent.as_os_str().is_empty()
293    {
294        std::fs::create_dir_all(parent)?;
295    }
296
297    let listener = UnixListener::bind(&socket)?;
298    let incoming = UnixListenerStream::new(listener);
299    let serve_result = serve(incoming, Arc::clone(&provider))
300        .await
301        .map_err(Error::from);
302
303    let close_result = close(provider).await;
304    let _ = remove_socket(&socket);
305
306    serve_result?;
307    close_result
308}
309
310#[cfg(unix)]
311fn parent_pid() -> Option<u32> {
312    env::var(ENV_PROVIDER_PARENT_PID)
313        .ok()
314        .and_then(|value| value.parse::<u32>().ok())
315        .filter(|pid| *pid > 0)
316}
317
318#[cfg(unix)]
319async fn watch_parent(parent_pid: u32) {
320    loop {
321        if current_parent_pid() != parent_pid {
322            break;
323        }
324        sleep(Duration::from_millis(500)).await;
325    }
326}
327
/// Returns the process ID of this process's current parent.
///
/// Uses the standard library's safe `parent_id` wrapper instead of calling
/// `libc::getppid` through an `unsafe` block that had no safety justification.
#[cfg(unix)]
fn current_parent_pid() -> u32 {
    std::os::unix::process::parent_id()
}
332
/// Removes the file at `path`, treating an already-missing file as success.
///
/// # Errors
///
/// Returns any error from `std::fs::remove_file` other than `NotFound` (for example,
/// when `path` is a directory or permission is denied).
#[cfg(unix)]
fn remove_socket(path: &Path) -> std::io::Result<()> {
    std::fs::remove_file(path).or_else(|err| {
        if err.kind() == std::io::ErrorKind::NotFound {
            Ok(())
        } else {
            Err(err)
        }
    })
}