Skip to main content

alien_bindings/
wait_until.rs

1use crate::{
2    error::{ErrorData, Result},
3    traits::Binding,
4};
5use alien_error::{AlienError, Context, IntoAlienError};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::{
9    collections::HashMap,
10    future::Future,
11    sync::{
12        atomic::{AtomicU32, Ordering},
13        Arc,
14    },
15    time::Duration,
16};
17use tokio::{sync::Mutex, task::JoinHandle, time::timeout};
18#[cfg(feature = "grpc")]
19use tonic::transport::Channel;
20use tracing::{debug, error, info, warn};
21use uuid::Uuid;
22
23#[cfg(feature = "openapi")]
24use utoipa::ToSchema;
25
26#[cfg(feature = "grpc")]
27use crate::grpc::wait_until_service::alien_bindings::wait_until::{
28    wait_until_service_client::WaitUntilServiceClient, NotifyDrainCompleteRequest,
29    NotifyTaskRegisteredRequest, WaitForDrainSignalRequest,
30};
31
32/// Response from drain operations.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(rename_all = "camelCase")]
35#[cfg_attr(feature = "openapi", derive(ToSchema))]
36pub struct DrainResponse {
37    /// Number of tasks that were drained.
38    pub tasks_drained: u32,
39    /// Whether all tasks completed successfully.
40    pub success: bool,
41    /// Optional error message if draining failed.
42    pub error_message: Option<String>,
43}
44
45/// Configuration for wait_until drain behavior.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47#[serde(rename_all = "camelCase")]
48#[cfg_attr(feature = "openapi", derive(ToSchema))]
49pub struct DrainConfig {
50    /// Maximum time to wait for all tasks to complete.
51    pub timeout: Duration,
52    /// Reason for the drain request.
53    pub reason: String,
54}
55
56/// A trait for wait_until bindings that provide task coordination capabilities.
57/// Note: This trait is not object-safe due to generic methods, so we use concrete types in providers.
58#[async_trait]
59pub trait WaitUntil: Binding {
60    /// Waits for a drain signal from the runtime.
61    /// This is a blocking call that returns when the runtime decides it's time to drain.
62    async fn wait_for_drain_signal(&self, timeout: Option<Duration>) -> Result<DrainConfig>;
63
64    /// Drains all currently registered tasks.
65    /// This waits for all tasks to complete or timeout.
66    async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse>;
67
68    /// Gets the current number of registered tasks.
69    async fn get_task_count(&self) -> Result<u32>;
70
71    /// Notifies the runtime that draining is complete.
72    async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()>;
73}
74
75/// A context for managing wait_until tasks within an application.
76/// This handles local task execution and coordinates with the runtime via gRPC.
77#[derive(Debug)]
78pub struct WaitUntilContext {
79    /// Unique identifier for this application instance.
80    application_id: String,
81    /// Currently running tasks.
82    tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
83    /// Task counter for generating unique task IDs.
84    task_counter: AtomicU32,
85    /// gRPC client for communicating with the runtime.
86    #[cfg(feature = "grpc")]
87    grpc_client: Option<WaitUntilServiceClient<Channel>>,
88    /// Whether we're currently draining tasks.
89    draining: Arc<Mutex<bool>>,
90}
91
92impl WaitUntilContext {
93    /// Creates a new WaitUntilContext.
94    pub fn new(application_id: Option<String>) -> Self {
95        let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
96
97        Self {
98            application_id: app_id,
99            tasks: Arc::new(Mutex::new(HashMap::new())),
100            task_counter: AtomicU32::new(0),
101            #[cfg(feature = "grpc")]
102            grpc_client: None,
103            draining: Arc::new(Mutex::new(false)),
104        }
105    }
106
107    /// Creates a new WaitUntilContext and connects to gRPC endpoint from environment variables.
108    /// This is the recommended way to create a WaitUntilContext in production.
109    pub async fn from_env(application_id: Option<String>) -> Result<Self> {
110        let env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
111        Self::from_env_with_vars(application_id, &env_vars).await
112    }
113
114    /// Creates a new WaitUntilContext and connects to gRPC endpoint from provided environment variables.
115    pub async fn from_env_with_vars(
116        application_id: Option<String>,
117        env_vars: &std::collections::HashMap<String, String>,
118    ) -> Result<Self> {
119        let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
120
121        #[cfg(feature = "grpc")]
122        {
123            let bindings_mode = crate::get_bindings_mode_from_env(env_vars)?;
124
125            match bindings_mode {
126                crate::BindingsMode::Direct => {
127                    // No gRPC needed - run in-process
128                    return Ok(Self::new(Some(app_id)));
129                }
130                crate::BindingsMode::Grpc => {
131                    // Require gRPC connection
132                    let grpc_address =
133                        env_vars.get("ALIEN_BINDINGS_GRPC_ADDRESS").ok_or_else(|| {
134                            AlienError::new(ErrorData::EnvironmentVariableMissing {
135                                variable_name: "ALIEN_BINDINGS_GRPC_ADDRESS".to_string(),
136                            })
137                        })?;
138
139                    // Create gRPC client
140                    let channel = Self::create_grpc_channel(grpc_address.clone()).await?;
141                    let grpc_client = WaitUntilServiceClient::new(channel);
142
143                    return Ok(Self {
144                        application_id: app_id,
145                        tasks: Arc::new(Mutex::new(HashMap::new())),
146                        task_counter: AtomicU32::new(0),
147                        grpc_client: Some(grpc_client),
148                        draining: Arc::new(Mutex::new(false)),
149                    });
150                }
151            }
152        }
153
154        #[cfg(not(feature = "grpc"))]
155        {
156            Ok(Self::new(Some(app_id)))
157        }
158    }
159
160    /// Creates a gRPC channel from an address string.
161    /// This creates a dedicated channel for wait_until with proper timeout and keep-alive configuration.
162    #[cfg(feature = "grpc")]
163    async fn create_grpc_channel(grpc_address: String) -> Result<Channel> {
164        use std::time::Duration;
165
166        // Ensure the address has a scheme, default to http if not present
167        let endpoint_uri = if grpc_address.contains("://") {
168            grpc_address.clone()
169        } else {
170            format!("http://{}", grpc_address)
171        };
172
173        let endpoint = Channel::from_shared(endpoint_uri.clone())
174            .into_alien_error()
175            .context(ErrorData::GrpcConnectionFailed {
176                endpoint: endpoint_uri.clone(),
177                reason: "Invalid gRPC endpoint URI format".to_string(),
178            })?
179            .timeout(Duration::from_secs(300)) // 5 min timeout for long-lived drain signal RPC
180            .connect_timeout(Duration::from_secs(5)) // Connection establishment timeout
181            .http2_keep_alive_interval(Duration::from_secs(30)) // Keep connection alive
182            .keep_alive_timeout(Duration::from_secs(10))
183            .keep_alive_while_idle(true); // Keep alive even when idle (important for drain listener)
184
185        let channel = endpoint.connect().await.into_alien_error().context(
186            ErrorData::GrpcConnectionFailed {
187                endpoint: grpc_address.clone(),
188                reason: "Failed to establish gRPC connection".to_string(),
189            },
190        )?;
191
192        Ok(channel)
193    }
194
195    /// Creates a new WaitUntilContext with a gRPC client.
196    #[cfg(feature = "grpc")]
197    pub fn new_with_grpc_client(
198        application_id: Option<String>,
199        grpc_client: WaitUntilServiceClient<Channel>,
200    ) -> Self {
201        let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
202
203        Self {
204            application_id: app_id,
205            tasks: Arc::new(Mutex::new(HashMap::new())),
206            task_counter: AtomicU32::new(0),
207            grpc_client: Some(grpc_client),
208            draining: Arc::new(Mutex::new(false)),
209        }
210    }
211
212    /// Sets the gRPC client for communicating with the runtime.
213    #[cfg(feature = "grpc")]
214    pub fn set_grpc_client(&mut self, client: WaitUntilServiceClient<Channel>) {
215        self.grpc_client = Some(client);
216    }
217
218    /// Gets the application ID.
219    pub fn application_id(&self) -> &str {
220        &self.application_id
221    }
222
223    /// Starts a background task that waits for drain signals from the runtime.
224    /// This should be called once when the application starts.
225    pub async fn start_drain_listener(&self) -> Result<()> {
226        #[cfg(feature = "grpc")]
227        {
228            if let Some(mut client) = self.grpc_client.clone() {
229                let app_id = self.application_id.clone();
230                let context = self.clone_for_background();
231
232                tokio::spawn(async move {
233                    loop {
234                        debug!(app_id = %app_id, "Waiting for drain signal from runtime");
235
236                        let request = WaitForDrainSignalRequest {
237                            application_id: app_id.clone(),
238                            timeout: Some(prost_types::Duration {
239                                seconds: 300, // 5 minute timeout
240                                nanos: 0,
241                            }),
242                        };
243
244                        match client.wait_for_drain_signal(request).await {
245                            Ok(response) => {
246                                let resp = response.into_inner();
247                                if resp.should_drain {
248                                    info!(
249                                        app_id = %app_id,
250                                        reason = %resp.drain_reason,
251                                        "Received drain signal from runtime"
252                                    );
253
254                                    let drain_timeout = resp
255                                        .drain_timeout
256                                        .map(|d| Duration::from_secs(d.seconds as u64))
257                                        .unwrap_or(Duration::from_secs(10));
258
259                                    let config = DrainConfig {
260                                        timeout: drain_timeout,
261                                        reason: resp.drain_reason,
262                                    };
263
264                                    // Drain all tasks
265                                    match context.drain_all(config).await {
266                                        Ok(drain_response) => {
267                                            // Notify runtime that draining is complete
268                                            let complete_request = NotifyDrainCompleteRequest {
269                                                application_id: app_id.clone(),
270                                                tasks_drained: drain_response.tasks_drained,
271                                                success: drain_response.success,
272                                                error_message: drain_response.error_message,
273                                            };
274
275                                            if let Err(e) =
276                                                client.notify_drain_complete(complete_request).await
277                                            {
278                                                error!(app_id = %app_id, error = %e, "Failed to notify runtime of drain completion");
279                                            } else {
280                                                info!(app_id = %app_id, "Successfully notified runtime of drain completion");
281                                            }
282                                        }
283                                        Err(e) => {
284                                            error!(app_id = %app_id, error = %e, "Failed to drain tasks");
285                                            // Still notify runtime of the failure
286                                            let complete_request = NotifyDrainCompleteRequest {
287                                                application_id: app_id.clone(),
288                                                tasks_drained: 0,
289                                                success: false,
290                                                error_message: Some(e.to_string()),
291                                            };
292                                            let _ = client
293                                                .notify_drain_complete(complete_request)
294                                                .await;
295                                        }
296                                    }
297                                }
298                            }
299                            Err(e) => {
300                                warn!(app_id = %app_id, error = %e, "Failed to wait for drain signal, retrying in 5 seconds");
301                                tokio::time::sleep(Duration::from_secs(5)).await;
302                            }
303                        }
304                    }
305                });
306            }
307        }
308
309        Ok(())
310    }
311
312    /// Creates a clone suitable for background tasks.
313    fn clone_for_background(&self) -> Self {
314        Self {
315            application_id: self.application_id.clone(),
316            tasks: Arc::clone(&self.tasks),
317            task_counter: AtomicU32::new(self.task_counter.load(Ordering::Relaxed)),
318            #[cfg(feature = "grpc")]
319            grpc_client: self.grpc_client.clone(),
320            draining: Arc::clone(&self.draining),
321        }
322    }
323
324    /// Notifies the runtime that a task has been registered (if gRPC client is available).
325    async fn notify_task_registered(&self, task_description: String) -> Result<()> {
326        #[cfg(feature = "grpc")]
327        {
328            if let Some(mut client) = self.grpc_client.clone() {
329                let request = NotifyTaskRegisteredRequest {
330                    application_id: self.application_id.clone(),
331                    task_description: Some(task_description),
332                };
333
334                client
335                    .notify_task_registered(request)
336                    .await
337                    .into_alien_error()
338                    .context(ErrorData::HttpRequestFailed {
339                        url: "grpc://wait_until_service".to_string(),
340                        method: "notify_task_registered".to_string(),
341                    })?;
342            }
343        }
344
345        Ok(())
346    }
347}
348
349impl WaitUntilContext {
350    /// Registers a new wait_until task that will be executed immediately.
351    /// The task runs in the application process but is tracked by the runtime.
352    pub fn wait_until<F, Fut>(&self, task_fn: F) -> Result<()>
353    where
354        F: FnOnce() -> Fut + Send + 'static,
355        Fut: Future<Output = ()> + Send + 'static,
356    {
357        let task_id = self.task_counter.fetch_add(1, Ordering::Relaxed);
358        let task_key = format!("task_{}", task_id);
359        let task_description = format!("wait_until_task_{}", task_id);
360
361        // Check if we're currently draining - if so, reject new tasks
362        let draining = self.draining.clone();
363        let tasks = self.tasks.clone();
364        let app_id = self.application_id.clone();
365        let task_key_clone = task_key.clone();
366
367        // Start the task immediately
368        let handle = tokio::spawn(async move {
369            // Double-check if we're draining
370            if *draining.lock().await {
371                warn!(app_id = %app_id, task_id = %task_key_clone, "Rejecting new task - currently draining");
372                return;
373            }
374
375            debug!(app_id = %app_id, task_id = %task_key_clone, "Starting wait_until task");
376
377            let future = task_fn();
378            future.await;
379
380            debug!(app_id = %app_id, task_id = %task_key_clone, "Completed wait_until task");
381
382            // Remove ourselves from the tasks map when done
383            tasks.lock().await.remove(&task_key_clone);
384        });
385
386        // Store the task handle
387        {
388            let mut tasks_guard = futures::executor::block_on(self.tasks.lock());
389            tasks_guard.insert(task_key.clone(), handle);
390        }
391
392        // Notify the runtime in a background task (non-blocking)
393        let context_clone = self.clone_for_background();
394        tokio::spawn(async move {
395            if let Err(e) = context_clone.notify_task_registered(task_description).await {
396                warn!(app_id = %context_clone.application_id, task_id = %task_key, error = %e, "Failed to notify runtime of task registration");
397            }
398        });
399
400        Ok(())
401    }
402}
403
404impl Binding for WaitUntilContext {}
405
406#[async_trait]
407impl WaitUntil for WaitUntilContext {
408    async fn wait_for_drain_signal(
409        &self,
410        timeout_duration: Option<Duration>,
411    ) -> Result<DrainConfig> {
412        #[cfg(feature = "grpc")]
413        {
414            if let Some(mut client) = self.grpc_client.clone() {
415                let timeout_proto = timeout_duration.map(|d| prost_types::Duration {
416                    seconds: d.as_secs() as i64,
417                    nanos: d.subsec_nanos() as i32,
418                });
419
420                let request = WaitForDrainSignalRequest {
421                    application_id: self.application_id.clone(),
422                    timeout: timeout_proto,
423                };
424
425                let response = client
426                    .wait_for_drain_signal(request)
427                    .await
428                    .into_alien_error()
429                    .context(ErrorData::HttpRequestFailed {
430                        url: "grpc://wait_until_service".to_string(),
431                        method: "wait_for_drain_signal".to_string(),
432                    })?;
433
434                let resp = response.into_inner();
435                if resp.should_drain {
436                    let drain_timeout = resp
437                        .drain_timeout
438                        .map(|d| Duration::from_secs(d.seconds as u64))
439                        .unwrap_or(Duration::from_secs(10));
440
441                    return Ok(DrainConfig {
442                        timeout: drain_timeout,
443                        reason: resp.drain_reason,
444                    });
445                }
446            }
447        }
448
449        // If no gRPC client or no drain signal, return a default config
450        Err(AlienError::new(ErrorData::Other {
451            message: "No drain signal received or gRPC client not available".to_string(),
452        }))
453    }
454
455    async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse> {
456        info!(
457            app_id = %self.application_id,
458            reason = %config.reason,
459            timeout_secs = config.timeout.as_secs(),
460            "Starting to drain all wait_until tasks"
461        );
462
463        // Mark that we're draining to prevent new tasks
464        {
465            let mut draining_guard = self.draining.lock().await;
466            *draining_guard = true;
467        }
468
469        let tasks_to_drain = {
470            let mut tasks_guard = self.tasks.lock().await;
471            std::mem::take(&mut *tasks_guard) // Take all tasks out of the map
472        };
473
474        let task_count = tasks_to_drain.len() as u32;
475        info!(app_id = %self.application_id, task_count = task_count, "Draining tasks");
476
477        let mut success = true;
478        let mut error_messages = Vec::new();
479
480        // Wait for all tasks to complete or timeout
481        let drain_result = timeout(config.timeout, async {
482            for (task_id, handle) in tasks_to_drain {
483                match handle.await {
484                    Ok(_) => {
485                        debug!(app_id = %self.application_id, task_id = %task_id, "Task completed successfully");
486                    }
487                    Err(e) => {
488                        warn!(app_id = %self.application_id, task_id = %task_id, error = %e, "Task failed");
489                        success = false;
490                        error_messages.push(format!("Task {} failed: {}", task_id, e));
491                    }
492                }
493            }
494        })
495        .await;
496
497        match drain_result {
498            Ok(_) => {
499                info!(app_id = %self.application_id, "All tasks drained successfully");
500            }
501            Err(_) => {
502                warn!(app_id = %self.application_id, "Drain timeout exceeded");
503                success = false;
504                error_messages.push("Drain timeout exceeded".to_string());
505            }
506        }
507
508        // Reset draining flag
509        {
510            let mut draining_guard = self.draining.lock().await;
511            *draining_guard = false;
512        }
513
514        let error_message = if error_messages.is_empty() {
515            None
516        } else {
517            Some(error_messages.join("; "))
518        };
519
520        Ok(DrainResponse {
521            tasks_drained: task_count,
522            success,
523            error_message,
524        })
525    }
526
527    async fn get_task_count(&self) -> Result<u32> {
528        let tasks_guard = self.tasks.lock().await;
529        Ok(tasks_guard.len() as u32)
530    }
531
532    async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()> {
533        #[cfg(feature = "grpc")]
534        {
535            if let Some(mut client) = self.grpc_client.clone() {
536                let request = NotifyDrainCompleteRequest {
537                    application_id: self.application_id.clone(),
538                    tasks_drained: response.tasks_drained,
539                    success: response.success,
540                    error_message: response.error_message,
541                };
542
543                client
544                    .notify_drain_complete(request)
545                    .await
546                    .into_alien_error()
547                    .context(ErrorData::HttpRequestFailed {
548                        url: "grpc://wait_until_service".to_string(),
549                        method: "notify_drain_complete".to_string(),
550                    })?;
551            }
552        }
553
554        Ok(())
555    }
556}