1use std::env;
2#[cfg(unix)]
3use std::future::Future;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8#[cfg(unix)]
9use tokio::net::UnixListener;
10#[cfg(unix)]
11use tokio::signal;
12#[cfg(unix)]
13use tokio::time::sleep;
14#[cfg(unix)]
15use tokio_stream::wrappers::UnixListenerStream;
16#[cfg(unix)]
17use tonic::transport::Server;
18
19use crate::catalog::write_catalog;
20use crate::env::{
21 ENV_PROVIDER_NAME, ENV_PROVIDER_PARENT_PID, ENV_PROVIDER_SOCKET, ENV_WRITE_CATALOG,
22};
23use crate::error::{Error, Result};
24#[cfg(unix)]
25use crate::generated::v1::agent_provider_server::AgentProviderServer as AgentRpcServer;
26#[cfg(unix)]
27use crate::generated::v1::authentication_provider_server::AuthenticationProviderServer;
28#[cfg(unix)]
29use crate::generated::v1::cache_server::CacheServer;
30#[cfg(unix)]
31use crate::generated::v1::integration_provider_server::IntegrationProviderServer;
32#[cfg(unix)]
33use crate::generated::v1::plugin_runtime_provider_server::PluginRuntimeProviderServer;
34#[cfg(unix)]
35use crate::generated::v1::provider_lifecycle_server::ProviderLifecycleServer;
36#[cfg(unix)]
37use crate::generated::v1::s3_server::S3Server;
38#[cfg(unix)]
39use crate::generated::v1::secrets_provider_server::SecretsProviderServer;
40#[cfg(unix)]
41use crate::generated::v1::workflow_provider_server::WorkflowProviderServer as WorkflowRpcServer;
42use crate::provider_server::ProviderServer;
43use crate::{
44 AgentProvider, AuthenticationProvider, CacheProvider, PluginRuntimeProvider, Provider, Router,
45 S3Provider, SecretsProvider, WorkflowProvider,
46};
47#[cfg(unix)]
48use crate::{
49 agent::AgentServer, auth_server::AuthenticationServer, cache_server::CacheRpcServer,
50 plugin_runtime::PluginRuntimeServer, runtime_server::RuntimeServer,
51 secrets_server::SecretsServer, workflow::WorkflowServer,
52};
53
54fn build_runtime_and_block_on<F, Fut>(f: F) -> Result<()>
55where
56 F: FnOnce() -> Fut,
57 Fut: std::future::Future<Output = Result<()>>,
58{
59 let runtime = tokio::runtime::Builder::new_multi_thread()
60 .enable_all()
61 .build()
62 .map_err(|error| Error::internal(error.to_string()))?;
63 runtime.block_on(f())
64}
65
66pub fn run_provider<P: Provider>(provider: Arc<P>, router: Router<P>) -> Result<()> {
68 build_runtime_and_block_on(|| serve_provider(provider, router))
69}
70
71pub fn run_authentication_provider<P: AuthenticationProvider>(provider: Arc<P>) -> Result<()> {
73 build_runtime_and_block_on(|| serve_authentication_provider(provider))
74}
75
76pub fn run_cache_provider<P: CacheProvider>(provider: Arc<P>) -> Result<()> {
78 build_runtime_and_block_on(|| serve_cache_provider(provider))
79}
80
81pub fn run_secrets_provider<P: SecretsProvider>(provider: Arc<P>) -> Result<()> {
83 build_runtime_and_block_on(|| serve_secrets_provider(provider))
84}
85
86pub fn run_s3_provider<P: S3Provider>(provider: Arc<P>) -> Result<()> {
88 build_runtime_and_block_on(|| serve_s3_provider(provider))
89}
90
91pub fn run_plugin_runtime_provider<P: PluginRuntimeProvider>(provider: Arc<P>) -> Result<()> {
93 build_runtime_and_block_on(|| serve_plugin_runtime_provider(provider))
94}
95
96pub fn run_workflow_provider<P: WorkflowProvider>(provider: Arc<P>) -> Result<()> {
98 build_runtime_and_block_on(|| serve_workflow_provider(provider))
99}
100
101pub fn run_agent_provider<P: AgentProvider>(provider: Arc<P>) -> Result<()> {
103 build_runtime_and_block_on(|| serve_agent_provider(provider))
104}
105
106pub fn write_catalog_path<P>(router: &Router<P>, path: impl AsRef<Path>) -> Result<()> {
108 write_catalog(router.catalog(), path)
109}
110
111pub fn maybe_write_catalog<P>(router: &Router<P>) -> Result<bool> {
114 let Some(path) = env::var_os(ENV_WRITE_CATALOG) else {
115 return Ok(false);
116 };
117
118 let catalog = if let Ok(name) = env::var(ENV_PROVIDER_NAME) {
119 router.catalog().clone().with_name(name)
120 } else {
121 router.catalog().clone()
122 };
123
124 write_catalog(&catalog, PathBuf::from(path))?;
125 Ok(true)
126}
127
128#[cfg(unix)]
129pub async fn serve_provider<P>(provider: Arc<P>, router: Router<P>) -> Result<()>
131where
132 P: Provider,
133{
134 if maybe_write_catalog(&router)? {
135 return Ok(());
136 }
137 let server = ProviderServer::new(Arc::clone(&provider), router);
138 serve_unix_provider(
139 provider,
140 move |incoming, provider| {
141 Server::builder()
142 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_provider(
143 Arc::clone(&provider),
144 )))
145 .add_service(IntegrationProviderServer::new(server))
146 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
147 },
148 |provider| async move { provider.close().await },
149 )
150 .await
151}
152
153#[cfg(unix)]
154pub async fn serve_authentication_provider<P>(provider: Arc<P>) -> Result<()>
156where
157 P: AuthenticationProvider,
158{
159 serve_unix_provider(
160 provider,
161 move |incoming, provider| {
162 let auth_server = AuthenticationServer::new(Arc::clone(&provider));
163 Server::builder()
164 .add_service(ProviderLifecycleServer::new(
165 RuntimeServer::for_authentication(Arc::clone(&provider)),
166 ))
167 .add_service(AuthenticationProviderServer::new(auth_server))
168 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
169 },
170 |provider| async move { provider.close().await },
171 )
172 .await
173}
174
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_authentication_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: AuthenticationProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
184
185#[cfg(unix)]
186pub async fn serve_cache_provider<P>(provider: Arc<P>) -> Result<()>
188where
189 P: CacheProvider,
190{
191 serve_unix_provider(
192 provider,
193 move |incoming, provider| {
194 Server::builder()
195 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_cache(
196 Arc::clone(&provider),
197 )))
198 .add_service(CacheServer::new(CacheRpcServer::new(Arc::clone(&provider))))
199 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
200 },
201 |provider| async move { provider.close().await },
202 )
203 .await
204}
205
206#[cfg(unix)]
207pub async fn serve_secrets_provider<P>(provider: Arc<P>) -> Result<()>
209where
210 P: SecretsProvider,
211{
212 serve_unix_provider(
213 provider,
214 move |incoming, provider| {
215 Server::builder()
216 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_secrets(
217 Arc::clone(&provider),
218 )))
219 .add_service(SecretsProviderServer::new(SecretsServer::new(Arc::clone(
220 &provider,
221 ))))
222 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
223 },
224 |provider| async move { provider.close().await },
225 )
226 .await
227}
228
229#[cfg(unix)]
230pub async fn serve_s3_provider<P>(provider: Arc<P>) -> Result<()>
232where
233 P: S3Provider,
234{
235 serve_unix_provider(
236 provider,
237 move |incoming, provider| {
238 Server::builder()
239 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_s3(
240 Arc::clone(&provider),
241 )))
242 .add_service(S3Server::new(Arc::clone(&provider)))
243 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
244 },
245 |provider| async move { provider.close().await },
246 )
247 .await
248}
249
250#[cfg(unix)]
251pub async fn serve_plugin_runtime_provider<P>(provider: Arc<P>) -> Result<()>
253where
254 P: PluginRuntimeProvider,
255{
256 serve_unix_provider(
257 provider,
258 move |incoming, provider| {
259 Server::builder()
260 .add_service(ProviderLifecycleServer::new(
261 RuntimeServer::for_plugin_runtime(Arc::clone(&provider)),
262 ))
263 .add_service(PluginRuntimeProviderServer::new(PluginRuntimeServer::new(
264 Arc::clone(&provider),
265 )))
266 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
267 },
268 |provider| async move { provider.close().await },
269 )
270 .await
271}
272
273#[cfg(unix)]
274pub async fn serve_workflow_provider<P>(provider: Arc<P>) -> Result<()>
276where
277 P: WorkflowProvider,
278{
279 serve_unix_provider(
280 provider,
281 move |incoming, provider| {
282 Server::builder()
283 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_workflow(
284 Arc::clone(&provider),
285 )))
286 .add_service(WorkflowRpcServer::new(WorkflowServer::new(Arc::clone(
287 &provider,
288 ))))
289 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
290 },
291 |provider| async move { provider.close().await },
292 )
293 .await
294}
295
296#[cfg(unix)]
297pub async fn serve_agent_provider<P>(provider: Arc<P>) -> Result<()>
299where
300 P: AgentProvider,
301{
302 serve_unix_provider(
303 provider,
304 move |incoming, provider| {
305 Server::builder()
306 .add_service(ProviderLifecycleServer::new(RuntimeServer::for_agent(
307 Arc::clone(&provider),
308 )))
309 .add_service(AgentRpcServer::new(AgentServer::new(Arc::clone(&provider))))
310 .serve_with_incoming_shutdown(incoming, shutdown_signal(parent_pid()))
311 },
312 |provider| async move { provider.close().await },
313 )
314 .await
315}
316
/// Fallback for targets without Unix domain sockets.
///
/// Catalog generation needs no socket, so a requested catalog is still written (returning
/// `Ok(())`). Serving itself is unavailable.
///
/// # Errors
///
/// Returns an internal error when catalog generation was not requested, or propagates a
/// failure from writing the catalog.
#[cfg(not(unix))]
pub async fn serve_provider<P>(_provider: Arc<P>, router: Router<P>) -> Result<()>
where
    P: Provider,
{
    if maybe_write_catalog(&router)? {
        Ok(())
    } else {
        Err(Error::internal(
            "unix sockets are unsupported on this platform",
        ))
    }
}
329
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_cache_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: CacheProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
339
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_secrets_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: SecretsProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
349
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_s3_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: S3Provider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
359
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_plugin_runtime_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: PluginRuntimeProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
369
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_workflow_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: WorkflowProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
379
/// Fallback for targets without Unix domain sockets; the provider transport needs them.
///
/// # Errors
///
/// Always returns an internal error.
#[cfg(not(unix))]
pub async fn serve_agent_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: AgentProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}
389
390#[cfg(unix)]
391async fn shutdown_signal(parent_pid: Option<u32>) {
392 let ctrl_c = async {
393 let _ = signal::ctrl_c().await;
394 };
395
396 tokio::pin!(ctrl_c);
397
398 if let Some(parent_pid) = parent_pid {
399 tokio::select! {
400 _ = &mut ctrl_c => {}
401 _ = watch_parent(parent_pid) => {}
402 }
403 return;
404 }
405
406 ctrl_c.await;
407}
408
409#[cfg(unix)]
410async fn serve_unix_provider<P, F, S, C, CF>(provider: Arc<P>, serve: F, close: C) -> Result<()>
411where
412 P: Send + Sync,
413 F: FnOnce(UnixListenerStream, Arc<P>) -> S,
414 S: Future<Output = std::result::Result<(), tonic::transport::Error>>,
415 C: FnOnce(Arc<P>) -> CF,
416 CF: Future<Output = Result<()>>,
417{
418 let socket = env::var_os(ENV_PROVIDER_SOCKET)
419 .ok_or_else(|| Error::internal(format!("{ENV_PROVIDER_SOCKET} is required")))?;
420 let socket = PathBuf::from(socket);
421 if socket.exists() {
422 std::fs::remove_file(&socket)?;
423 }
424 if let Some(parent) = socket.parent()
425 && !parent.as_os_str().is_empty()
426 {
427 std::fs::create_dir_all(parent)?;
428 }
429
430 let listener = UnixListener::bind(&socket)?;
431 let incoming = UnixListenerStream::new(listener);
432 let serve_result = serve(incoming, Arc::clone(&provider))
433 .await
434 .map_err(Error::from);
435
436 let close_result = close(provider).await;
437 let _ = remove_socket(&socket);
438
439 serve_result?;
440 close_result
441}
442
443#[cfg(unix)]
444fn parent_pid() -> Option<u32> {
445 env::var(ENV_PROVIDER_PARENT_PID)
446 .ok()
447 .and_then(|value| value.parse::<u32>().ok())
448 .filter(|pid| *pid > 0)
449}
450
451#[cfg(unix)]
452async fn watch_parent(parent_pid: u32) {
453 loop {
454 if current_parent_pid() != parent_pid {
455 break;
456 }
457 sleep(Duration::from_millis(500)).await;
458 }
459}
460
/// Returns the PID of this process's current parent.
///
/// Uses the safe standard-library wrapper around `getppid(2)` instead of an undocumented
/// `unsafe` FFI call plus a signed-to-unsigned cast.
#[cfg(unix)]
fn current_parent_pid() -> u32 {
    std::os::unix::process::parent_id()
}
465
/// Deletes the file at `path`, treating an already-missing file as success.
///
/// # Errors
///
/// Returns any I/O error from `std::fs::remove_file` other than `NotFound`.
#[cfg(unix)]
fn remove_socket(path: &Path) -> std::io::Result<()> {
    std::fs::remove_file(path).or_else(|error| {
        if error.kind() == std::io::ErrorKind::NotFound {
            Ok(())
        } else {
            Err(error)
        }
    })
}