1use std::env;
2#[cfg(unix)]
3use std::future::Future;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8#[cfg(unix)]
9use tokio::net::UnixListener;
10#[cfg(unix)]
11use tokio::signal;
12#[cfg(unix)]
13use tokio::time::sleep;
14#[cfg(unix)]
15use tokio_stream::wrappers::UnixListenerStream;
16#[cfg(unix)]
17use tonic::transport::Server;
18
19use crate::catalog::write_catalog;
20use crate::env::{
21 ENV_PROVIDER_NAME, ENV_PROVIDER_PARENT_PID, ENV_PROVIDER_SOCKET, ENV_WRITE_CATALOG,
22};
23use crate::error::{Error, Result};
24#[cfg(unix)]
25use crate::generated::v1::agent_server::AgentServer as AgentRpcServer;
26#[cfg(unix)]
27use crate::generated::v1::app_provider_server::AppProviderServer;
28#[cfg(unix)]
29use crate::generated::v1::authentication_server::AuthenticationServer as AuthenticationRpcServer;
30#[cfg(unix)]
31use crate::generated::v1::cache_server::CacheServer;
32#[cfg(unix)]
33use crate::generated::v1::provider_lifecycle_server::ProviderLifecycleServer;
34#[cfg(unix)]
35use crate::generated::v1::runtime_server::RuntimeServer as RuntimeProviderRpcServiceServer;
36#[cfg(unix)]
37use crate::generated::v1::s3_server::S3Server;
38#[cfg(unix)]
39use crate::generated::v1::secrets_server::SecretsServer as SecretsRpcServer;
40#[cfg(unix)]
41use crate::generated::v1::workflow_server::WorkflowServer as WorkflowRpcServer;
42use crate::provider_server::ProviderServer;
43use crate::{
44 AgentProvider, AuthenticationProvider, CacheProvider, Provider, Router, RuntimeProvider,
45 S3Provider, SecretsProvider, WorkflowProvider,
46};
47#[cfg(unix)]
48use crate::{
49 agent_provider::AgentServer, auth_server::AuthenticationServer, cache_server::CacheRpcServer,
50 runtime_provider_impl::RuntimeServer as RuntimeProviderRpcServer,
51 runtime_server::RuntimeServer, s3_provider::S3RpcServer, secrets_server::SecretsServer,
52 workflow_provider::WorkflowServer,
53};
54
55fn build_runtime_and_block_on<F, Fut>(f: F) -> Result<()>
56where
57 F: FnOnce() -> Fut,
58 Fut: std::future::Future<Output = Result<()>>,
59{
60 let runtime = tokio::runtime::Builder::new_multi_thread()
61 .enable_all()
62 .build()
63 .map_err(|error| Error::internal(error.to_string()))?;
64 runtime.block_on(f())
65}
66
67pub fn run_provider<P: Provider>(provider: Arc<P>, router: Router<P>) -> Result<()> {
69 build_runtime_and_block_on(|| serve_provider(provider, router))
70}
71
72pub fn run_authentication_provider<P: AuthenticationProvider>(provider: Arc<P>) -> Result<()> {
74 build_runtime_and_block_on(|| serve_authentication_provider(provider))
75}
76
77pub fn run_cache_provider<P: CacheProvider>(provider: Arc<P>) -> Result<()> {
79 build_runtime_and_block_on(|| serve_cache_provider(provider))
80}
81
82pub fn run_secrets_provider<P: SecretsProvider>(provider: Arc<P>) -> Result<()> {
84 build_runtime_and_block_on(|| serve_secrets_provider(provider))
85}
86
87pub fn run_s3_provider<P: S3Provider>(provider: Arc<P>) -> Result<()> {
89 build_runtime_and_block_on(|| serve_s3_provider(provider))
90}
91
92pub fn run_runtime_provider<P: RuntimeProvider>(provider: Arc<P>) -> Result<()> {
94 build_runtime_and_block_on(|| serve_runtime_provider(provider))
95}
96
97pub fn run_workflow_provider<P: WorkflowProvider>(provider: Arc<P>) -> Result<()> {
99 build_runtime_and_block_on(|| serve_workflow_provider(provider))
100}
101
102pub fn run_agent_provider<P: AgentProvider>(provider: Arc<P>) -> Result<()> {
104 build_runtime_and_block_on(|| serve_agent_provider(provider))
105}
106
107pub fn write_catalog_path<P>(router: &Router<P>, path: impl AsRef<Path>) -> Result<()> {
109 write_catalog(router.catalog(), path)
110}
111
112pub fn maybe_write_catalog<P>(router: &Router<P>) -> Result<bool> {
115 let Some(path) = env::var_os(ENV_WRITE_CATALOG) else {
116 return Ok(false);
117 };
118
119 let catalog = if let Ok(name) = env::var(ENV_PROVIDER_NAME) {
120 router.catalog().clone().with_name(name)
121 } else {
122 router.catalog().clone()
123 };
124
125 write_catalog(&catalog, PathBuf::from(path))?;
126 Ok(true)
127}
128
129#[cfg(unix)]
130pub async fn serve_provider<P>(provider: Arc<P>, router: Router<P>) -> Result<()>
132where
133 P: Provider,
134{
135 if maybe_write_catalog(&router)? {
136 return Ok(());
137 }
138 let server = ProviderServer::new(Arc::clone(&provider), router);
139 serve_unix_provider(
140 provider,
141 move |incoming, provider| {
142 Server::builder()
143 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_provider(
144 Arc::clone(&provider),
145 )))
146 .add_service(AppProviderServer::new(server))
147 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
148 },
149 |provider| async move { provider.close().await },
150 )
151 .await
152}
153
154#[cfg(unix)]
155pub async fn serve_authentication_provider<P>(provider: Arc<P>) -> Result<()>
157where
158 P: AuthenticationProvider,
159{
160 serve_unix_provider(
161 provider,
162 move |incoming, provider| {
163 let auth_server = AuthenticationServer::new(Arc::clone(&provider));
164 Server::builder()
165 .add_service(ProviderLifecycleServer::new(
166 RuntimeServer::for_authentication(Arc::clone(&provider)),
167 ))
168 .add_service(AuthenticationRpcServer::new(auth_server))
169 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
170 },
171 |provider| async move { provider.close().await },
172 )
173 .await
174}
175
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_authentication_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: AuthenticationProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
185
186#[cfg(unix)]
187pub async fn serve_cache_provider<P>(provider: Arc<P>) -> Result<()>
189where
190 P: CacheProvider,
191{
192 serve_unix_provider(
193 provider,
194 move |incoming, provider| {
195 Server::builder()
196 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_cache(
197 Arc::clone(&provider),
198 )))
199 .add_service(CacheServer::new(CacheRpcServer::new(Arc::clone(&provider))))
200 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
201 },
202 |provider| async move { provider.close().await },
203 )
204 .await
205}
206
207#[cfg(unix)]
208pub async fn serve_secrets_provider<P>(provider: Arc<P>) -> Result<()>
210where
211 P: SecretsProvider,
212{
213 serve_unix_provider(
214 provider,
215 move |incoming, provider| {
216 Server::builder()
217 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_secrets(
218 Arc::clone(&provider),
219 )))
220 .add_service(SecretsRpcServer::new(SecretsServer::new(Arc::clone(
221 &provider,
222 ))))
223 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
224 },
225 |provider| async move { provider.close().await },
226 )
227 .await
228}
229
230#[cfg(unix)]
231pub async fn serve_s3_provider<P>(provider: Arc<P>) -> Result<()>
233where
234 P: S3Provider,
235{
236 serve_unix_provider(
237 provider,
238 move |incoming, provider| {
239 Server::builder()
240 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_s3(
241 Arc::clone(&provider),
242 )))
243 .add_service(S3Server::new(S3RpcServer::new(Arc::clone(&provider))))
244 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
245 },
246 |provider| async move { provider.close().await },
247 )
248 .await
249}
250
251#[cfg(unix)]
252pub async fn serve_runtime_provider<P>(provider: Arc<P>) -> Result<()>
254where
255 P: RuntimeProvider,
256{
257 serve_unix_provider(
258 provider,
259 move |incoming, provider| {
260 Server::builder()
261 .add_service(ProviderLifecycleServer::new(
262 RuntimeServer::for_runtime_provider(Arc::clone(&provider)),
263 ))
264 .add_service(RuntimeProviderRpcServiceServer::new(
265 RuntimeProviderRpcServer::new(Arc::clone(&provider)),
266 ))
267 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
268 },
269 |provider| async move { provider.close().await },
270 )
271 .await
272}
273
274#[cfg(unix)]
275pub async fn serve_workflow_provider<P>(provider: Arc<P>) -> Result<()>
277where
278 P: WorkflowProvider,
279{
280 serve_unix_provider(
281 provider,
282 move |incoming, provider| {
283 Server::builder()
284 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_workflow(
285 Arc::clone(&provider),
286 )))
287 .add_service(WorkflowRpcServer::new(WorkflowServer::new(Arc::clone(
288 &provider,
289 ))))
290 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
291 },
292 |provider| async move { provider.close().await },
293 )
294 .await
295}
296
297#[cfg(unix)]
298pub async fn serve_agent_provider<P>(provider: Arc<P>) -> Result<()>
300where
301 P: AgentProvider,
302{
303 serve_unix_provider(
304 provider,
305 move |incoming, provider| {
306 Server::builder()
307 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_agent(
308 Arc::clone(&provider),
309 )))
310 .add_service(AgentRpcServer::new(AgentServer::new(Arc::clone(&provider))))
311 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
312 },
313 |provider| async move { provider.close().await },
314 )
315 .await
316}
317
/// Non-Unix fallback for [`Provider`] serving.
///
/// Catalog export via `ENV_WRITE_CATALOG` still works and yields `Ok(())`;
/// actual serving is unavailable.
///
/// # Errors
///
/// Propagates catalog-writing errors. If no catalog was requested, returns an
/// internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_provider<P>(_provider: Arc<P>, router: Router<P>) -> Result<()>
where
    P: Provider,
{
    if maybe_write_catalog(&router)? {
        Ok(())
    } else {
        Err(Error::internal("unix sockets are unsupported on this platform"))
    }
}
330
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_cache_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: CacheProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
340
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_secrets_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: SecretsProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
350
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_s3_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: S3Provider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
360
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_runtime_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: RuntimeProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
370
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_workflow_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: WorkflowProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
380
/// Stub for platforms without Unix domain sockets.
///
/// # Errors
///
/// Always returns an internal error, because serving requires Unix sockets.
#[cfg(not(unix))]
pub async fn serve_agent_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: AgentProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
390
391#[cfg(unix)]
392async fn shutdown_signal(parent_pid: Option<u32>) {
393 let ctrl_c = async {
394 let _ = signal::ctrl_c().await;
395 };
396
397 tokio::pin!(ctrl_c);
398
399 if let Some(parent_pid) = parent_pid {
400 tokio::select! {
401 _ = &mut ctrl_c => {}
402 _ = watch_parent(parent_pid) => {}
403 }
404 return;
405 }
406
407 ctrl_c.await;
408}
409
410#[cfg(unix)]
411async fn serve_unix_provider<P, F, S, C, CF>(provider: Arc<P>, serve: F, close: C) -> Result<()>
412where
413 P: Send + Sync,
414 F: FnOnce(UnixListenerStream, Arc<P>) -> S,
415 S: Future<Output = std::result::Result<(), tonic::transport::Error>>,
416 C: FnOnce(Arc<P>) -> CF,
417 CF: Future<Output = Result<()>>,
418{
419 let socket = env::var_os(ENV_PROVIDER_SOCKET)
420 .ok_or_else(|| Error::internal(format!("{ENV_PROVIDER_SOCKET} is required")))?;
421 let socket = PathBuf::from(socket);
422 if socket.exists() {
423 std::fs::remove_file(&socket)?;
424 }
425 if let Some(parent) = socket.parent()
426 && !parent.as_os_str().is_empty()
427 {
428 std::fs::create_dir_all(parent)?;
429 }
430
431 let listener = UnixListener::bind(&socket)?;
432 let incoming = UnixListenerStream::new(listener);
433 let serve_result = serve(incoming, Arc::clone(&provider))
434 .await
435 .map_err(Error::from);
436
437 let close_result = close(provider).await;
438 let _ = remove_socket(&socket);
439
440 serve_result?;
441 close_result
442}
443
444#[cfg(unix)]
445fn parent_pid() -> Option<u32> {
446 env::var(ENV_PROVIDER_PARENT_PID)
447 .ok()
448 .and_then(|value| value.parse::<u32>().ok())
449 .filter(|pid| *pid > 0)
450}
451
452#[cfg(unix)]
453async fn watch_parent(parent_pid: u32) {
454 loop {
455 if current_parent_pid() != parent_pid {
456 break;
457 }
458 sleep(Duration::from_millis(500)).await;
459 }
460}
461
/// Returns the PID of this process's current parent.
///
/// Uses the safe standard-library wrapper around `getppid`, so no `unsafe`
/// FFI call is needed here.
#[cfg(unix)]
fn current_parent_pid() -> u32 {
    std::os::unix::process::parent_id()
}
466
/// Deletes the file at `path`, treating an already-missing file as success.
///
/// Any other I/O error from `std::fs::remove_file` is returned unchanged.
#[cfg(unix)]
fn remove_socket(path: &Path) -> std::io::Result<()> {
    std::fs::remove_file(path).or_else(|err| match err.kind() {
        std::io::ErrorKind::NotFound => Ok(()),
        _ => Err(err),
    })
}