Skip to main content

alien_bindings/
wait_until.rs

1use crate::{
2    error::{ErrorData, Result},
3    traits::Binding,
4};
5use alien_error::{AlienError, Context, IntoAlienError};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::{
9    collections::HashMap,
10    future::Future,
11    sync::{
12        atomic::{AtomicU32, Ordering},
13        Arc,
14    },
15    time::Duration,
16};
17use tokio::{
18    sync::{oneshot, Mutex},
19    task::JoinHandle,
20    time::timeout,
21};
22#[cfg(feature = "grpc")]
23use tonic::transport::Channel;
24use tracing::{debug, error, info, warn};
25use uuid::Uuid;
26
27#[cfg(feature = "openapi")]
28use utoipa::ToSchema;
29
30#[cfg(feature = "grpc")]
31use crate::grpc::wait_until_service::alien_bindings::wait_until::{
32    wait_until_service_client::WaitUntilServiceClient, GetTaskCountRequest,
33    NotifyDrainCompleteRequest, NotifyTaskRegisteredRequest, WaitForDrainSignalRequest,
34};
35
36/// Response from drain operations.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(rename_all = "camelCase")]
39#[cfg_attr(feature = "openapi", derive(ToSchema))]
40pub struct DrainResponse {
41    /// Number of tasks that were drained.
42    pub tasks_drained: u32,
43    /// Whether all tasks completed successfully.
44    pub success: bool,
45    /// Optional error message if draining failed.
46    pub error_message: Option<String>,
47}
48
49/// Configuration for wait_until drain behavior.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52#[cfg_attr(feature = "openapi", derive(ToSchema))]
53pub struct DrainConfig {
54    /// Maximum time to wait for all tasks to complete.
55    pub timeout: Duration,
56    /// Reason for the drain request.
57    pub reason: String,
58}
59
60/// A trait for wait_until bindings that provide task coordination capabilities.
61/// Note: This trait is not object-safe due to generic methods, so we use concrete types in providers.
62#[async_trait]
63pub trait WaitUntil: Binding {
64    /// Waits for a drain signal from the runtime.
65    /// This is a blocking call that returns when the runtime decides it's time to drain.
66    async fn wait_for_drain_signal(&self, timeout: Option<Duration>) -> Result<DrainConfig>;
67
68    /// Drains all currently registered tasks.
69    /// This waits for all tasks to complete or timeout.
70    async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse>;
71
72    /// Gets the current number of registered tasks.
73    async fn get_task_count(&self) -> Result<u32>;
74
75    /// Notifies the runtime that draining is complete.
76    async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()>;
77}
78
79/// A context for managing wait_until tasks within an application.
80/// This handles local task execution and coordinates with the runtime via gRPC.
81#[derive(Debug)]
82pub struct WaitUntilContext {
83    /// Unique identifier for this application instance.
84    application_id: String,
85    /// Currently running tasks.
86    tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
87    /// Task counter for generating unique task IDs.
88    task_counter: AtomicU32,
89    /// gRPC client for communicating with the runtime.
90    #[cfg(feature = "grpc")]
91    grpc_client: Option<WaitUntilServiceClient<Channel>>,
92    /// Whether we're currently draining tasks.
93    draining: Arc<Mutex<bool>>,
94}
95
96impl WaitUntilContext {
97    /// Creates a new WaitUntilContext.
98    pub fn new(application_id: Option<String>) -> Self {
99        let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
100
101        Self {
102            application_id: app_id,
103            tasks: Arc::new(Mutex::new(HashMap::new())),
104            task_counter: AtomicU32::new(0),
105            #[cfg(feature = "grpc")]
106            grpc_client: None,
107            draining: Arc::new(Mutex::new(false)),
108        }
109    }
110
111    /// Creates a new WaitUntilContext and connects to gRPC endpoint from environment variables.
112    /// This is the recommended way to create a WaitUntilContext in production.
113    pub async fn from_env(application_id: Option<String>) -> Result<Self> {
114        let env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
115        Self::from_env_with_vars(application_id, &env_vars).await
116    }
117
118    /// Creates a new WaitUntilContext and connects to gRPC endpoint from provided environment variables.
119    pub async fn from_env_with_vars(
120        application_id: Option<String>,
121        env_vars: &std::collections::HashMap<String, String>,
122    ) -> Result<Self> {
123        let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
124
125        #[cfg(feature = "grpc")]
126        {
127            let bindings_mode = crate::get_bindings_mode_from_env(env_vars)?;
128
129            match bindings_mode {
130                crate::BindingsMode::Direct => {
131                    // No gRPC needed - run in-process
132                    return Ok(Self::new(Some(app_id)));
133                }
134                crate::BindingsMode::Grpc => {
135                    // Require gRPC connection
136                    let grpc_address =
137                        env_vars.get("ALIEN_BINDINGS_GRPC_ADDRESS").ok_or_else(|| {
138                            AlienError::new(ErrorData::EnvironmentVariableMissing {
139                                variable_name: "ALIEN_BINDINGS_GRPC_ADDRESS".to_string(),
140                            })
141                        })?;
142
143                    // Create gRPC client
144                    let channel = Self::create_grpc_channel(grpc_address.clone()).await?;
145                    let grpc_client = WaitUntilServiceClient::new(channel);
146
147                    return Ok(Self {
148                        application_id: app_id,
149                        tasks: Arc::new(Mutex::new(HashMap::new())),
150                        task_counter: AtomicU32::new(0),
151                        grpc_client: Some(grpc_client),
152                        draining: Arc::new(Mutex::new(false)),
153                    });
154                }
155            }
156        }
157
158        #[cfg(not(feature = "grpc"))]
159        {
160            Ok(Self::new(Some(app_id)))
161        }
162    }
163
164    /// Creates a gRPC channel from an address string.
165    /// This creates a dedicated channel for wait_until with proper timeout and keep-alive configuration.
166    #[cfg(feature = "grpc")]
167    async fn create_grpc_channel(grpc_address: String) -> Result<Channel> {
168        use std::time::Duration;
169
170        // Ensure the address has a scheme, default to http if not present
171        let endpoint_uri = if grpc_address.contains("://") {
172            grpc_address.clone()
173        } else {
174            format!("http://{}", grpc_address)
175        };
176
177        let endpoint = Channel::from_shared(endpoint_uri.clone())
178            .into_alien_error()
179            .context(ErrorData::GrpcConnectionFailed {
180                endpoint: endpoint_uri.clone(),
181                reason: "Invalid gRPC endpoint URI format".to_string(),
182            })?
183            .timeout(Duration::from_secs(300)) // 5 min timeout for long-lived drain signal RPC
184            .connect_timeout(Duration::from_secs(5)) // Connection establishment timeout
185            .http2_keep_alive_interval(Duration::from_secs(30)) // Keep connection alive
186            .keep_alive_timeout(Duration::from_secs(10))
187            .keep_alive_while_idle(true); // Keep alive even when idle (important for drain listener)
188
189        let channel = endpoint.connect().await.into_alien_error().context(
190            ErrorData::GrpcConnectionFailed {
191                endpoint: grpc_address.clone(),
192                reason: "Failed to establish gRPC connection".to_string(),
193            },
194        )?;
195
196        Ok(channel)
197    }
198
199    /// Creates a new WaitUntilContext with a gRPC client.
200    #[cfg(feature = "grpc")]
201    pub fn new_with_grpc_client(
202        application_id: Option<String>,
203        grpc_client: WaitUntilServiceClient<Channel>,
204    ) -> Self {
205        let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
206
207        Self {
208            application_id: app_id,
209            tasks: Arc::new(Mutex::new(HashMap::new())),
210            task_counter: AtomicU32::new(0),
211            grpc_client: Some(grpc_client),
212            draining: Arc::new(Mutex::new(false)),
213        }
214    }
215
216    /// Sets the gRPC client for communicating with the runtime.
217    #[cfg(feature = "grpc")]
218    pub fn set_grpc_client(&mut self, client: WaitUntilServiceClient<Channel>) {
219        self.grpc_client = Some(client);
220    }
221
222    /// Gets the application ID.
223    pub fn application_id(&self) -> &str {
224        &self.application_id
225    }
226
227    /// Starts a background task that waits for drain signals from the runtime.
228    /// This should be called once when the application starts.
229    pub async fn start_drain_listener(&self) -> Result<()> {
230        #[cfg(feature = "grpc")]
231        {
232            if let Some(mut client) = self.grpc_client.clone() {
233                let app_id = self.application_id.clone();
234                let context = self.clone_for_background();
235
236                tokio::spawn(async move {
237                    loop {
238                        debug!(app_id = %app_id, "Waiting for drain signal from runtime");
239
240                        let request = WaitForDrainSignalRequest {
241                            application_id: app_id.clone(),
242                            timeout: Some(prost_types::Duration {
243                                seconds: 300, // 5 minute timeout
244                                nanos: 0,
245                            }),
246                        };
247
248                        match client.wait_for_drain_signal(request).await {
249                            Ok(response) => {
250                                let resp = response.into_inner();
251                                if resp.should_drain {
252                                    info!(
253                                        app_id = %app_id,
254                                        reason = %resp.drain_reason,
255                                        "Received drain signal from runtime"
256                                    );
257
258                                    let drain_timeout = resp
259                                        .drain_timeout
260                                        .map(|d| Duration::from_secs(d.seconds as u64))
261                                        .unwrap_or(Duration::from_secs(10));
262
263                                    let config = DrainConfig {
264                                        timeout: drain_timeout,
265                                        reason: resp.drain_reason,
266                                    };
267
268                                    // Drain all tasks
269                                    match context.drain_all(config).await {
270                                        Ok(drain_response) => {
271                                            // Notify runtime that draining is complete
272                                            let complete_request = NotifyDrainCompleteRequest {
273                                                application_id: app_id.clone(),
274                                                tasks_drained: drain_response.tasks_drained,
275                                                success: drain_response.success,
276                                                error_message: drain_response.error_message,
277                                            };
278
279                                            if let Err(e) =
280                                                client.notify_drain_complete(complete_request).await
281                                            {
282                                                error!(app_id = %app_id, error = %e, "Failed to notify runtime of drain completion");
283                                            } else {
284                                                info!(app_id = %app_id, "Successfully notified runtime of drain completion");
285                                            }
286                                        }
287                                        Err(e) => {
288                                            error!(app_id = %app_id, error = %e, "Failed to drain tasks");
289                                            // Still notify runtime of the failure
290                                            let complete_request = NotifyDrainCompleteRequest {
291                                                application_id: app_id.clone(),
292                                                tasks_drained: 0,
293                                                success: false,
294                                                error_message: Some(e.to_string()),
295                                            };
296                                            let _ = client
297                                                .notify_drain_complete(complete_request)
298                                                .await;
299                                        }
300                                    }
301                                }
302                            }
303                            Err(e) => {
304                                warn!(app_id = %app_id, error = %e, "Failed to wait for drain signal, retrying in 5 seconds");
305                                tokio::time::sleep(Duration::from_secs(5)).await;
306                            }
307                        }
308                    }
309                });
310            }
311        }
312
313        Ok(())
314    }
315
316    /// Creates a clone suitable for background tasks.
317    fn clone_for_background(&self) -> Self {
318        Self {
319            application_id: self.application_id.clone(),
320            tasks: Arc::clone(&self.tasks),
321            task_counter: AtomicU32::new(self.task_counter.load(Ordering::Relaxed)),
322            #[cfg(feature = "grpc")]
323            grpc_client: self.grpc_client.clone(),
324            draining: Arc::clone(&self.draining),
325        }
326    }
327
328    /// Notifies the runtime that a task has been registered (if gRPC client is available).
329    async fn notify_task_registered(&self, task_description: String) -> Result<()> {
330        #[cfg(feature = "grpc")]
331        {
332            if let Some(mut client) = self.grpc_client.clone() {
333                let request = NotifyTaskRegisteredRequest {
334                    application_id: self.application_id.clone(),
335                    task_description: Some(task_description),
336                };
337
338                client
339                    .notify_task_registered(request)
340                    .await
341                    .into_alien_error()
342                    .context(ErrorData::HttpRequestFailed {
343                        url: "grpc://wait_until_service".to_string(),
344                        method: "notify_task_registered".to_string(),
345                    })?;
346            }
347        }
348
349        Ok(())
350    }
351}
352
353impl WaitUntilContext {
354    /// Registers a new wait_until task that will be executed immediately.
355    /// The task runs in the application process but is tracked by the runtime.
356    pub fn wait_until<F, Fut>(&self, task_fn: F) -> Result<()>
357    where
358        F: FnOnce() -> Fut + Send + 'static,
359        Fut: Future<Output = ()> + Send + 'static,
360    {
361        let task_id = self.task_counter.fetch_add(1, Ordering::Relaxed);
362        let task_key = format!("task_{}", task_id);
363        let task_description = format!("wait_until_task_{}", task_id);
364
365        // Check if we're currently draining - if so, reject new tasks
366        let draining = self.draining.clone();
367        let tasks = self.tasks.clone();
368        let app_id = self.application_id.clone();
369        let task_key_clone = task_key.clone();
370
371        // Start the task immediately
372        let handle = tokio::spawn(async move {
373            // Double-check if we're draining
374            if *draining.lock().await {
375                warn!(app_id = %app_id, task_id = %task_key_clone, "Rejecting new task - currently draining");
376                return;
377            }
378
379            debug!(app_id = %app_id, task_id = %task_key_clone, "Starting wait_until task");
380
381            let future = task_fn();
382            future.await;
383
384            debug!(app_id = %app_id, task_id = %task_key_clone, "Completed wait_until task");
385
386            // Remove ourselves from the tasks map when done
387            tasks.lock().await.remove(&task_key_clone);
388        });
389
390        // Store the task handle
391        {
392            let mut tasks_guard = futures::executor::block_on(self.tasks.lock());
393            tasks_guard.insert(task_key.clone(), handle);
394        }
395
396        // Notify the runtime in a background task (non-blocking)
397        let context_clone = self.clone_for_background();
398        tokio::spawn(async move {
399            if let Err(e) = context_clone.notify_task_registered(task_description).await {
400                warn!(app_id = %context_clone.application_id, task_id = %task_key, error = %e, "Failed to notify runtime of task registration");
401            }
402        });
403
404        Ok(())
405    }
406}
407
408impl Binding for WaitUntilContext {}
409
410#[async_trait]
411impl WaitUntil for WaitUntilContext {
412    async fn wait_for_drain_signal(
413        &self,
414        timeout_duration: Option<Duration>,
415    ) -> Result<DrainConfig> {
416        #[cfg(feature = "grpc")]
417        {
418            if let Some(mut client) = self.grpc_client.clone() {
419                let timeout_proto = timeout_duration.map(|d| prost_types::Duration {
420                    seconds: d.as_secs() as i64,
421                    nanos: d.subsec_nanos() as i32,
422                });
423
424                let request = WaitForDrainSignalRequest {
425                    application_id: self.application_id.clone(),
426                    timeout: timeout_proto,
427                };
428
429                let response = client
430                    .wait_for_drain_signal(request)
431                    .await
432                    .into_alien_error()
433                    .context(ErrorData::HttpRequestFailed {
434                        url: "grpc://wait_until_service".to_string(),
435                        method: "wait_for_drain_signal".to_string(),
436                    })?;
437
438                let resp = response.into_inner();
439                if resp.should_drain {
440                    let drain_timeout = resp
441                        .drain_timeout
442                        .map(|d| Duration::from_secs(d.seconds as u64))
443                        .unwrap_or(Duration::from_secs(10));
444
445                    return Ok(DrainConfig {
446                        timeout: drain_timeout,
447                        reason: resp.drain_reason,
448                    });
449                }
450            }
451        }
452
453        // If no gRPC client or no drain signal, return a default config
454        Err(AlienError::new(ErrorData::Other {
455            message: "No drain signal received or gRPC client not available".to_string(),
456        }))
457    }
458
459    async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse> {
460        info!(
461            app_id = %self.application_id,
462            reason = %config.reason,
463            timeout_secs = config.timeout.as_secs(),
464            "Starting to drain all wait_until tasks"
465        );
466
467        // Mark that we're draining to prevent new tasks
468        {
469            let mut draining_guard = self.draining.lock().await;
470            *draining_guard = true;
471        }
472
473        let tasks_to_drain = {
474            let mut tasks_guard = self.tasks.lock().await;
475            std::mem::take(&mut *tasks_guard) // Take all tasks out of the map
476        };
477
478        let task_count = tasks_to_drain.len() as u32;
479        info!(app_id = %self.application_id, task_count = task_count, "Draining tasks");
480
481        let mut success = true;
482        let mut error_messages = Vec::new();
483
484        // Wait for all tasks to complete or timeout
485        let drain_result = timeout(config.timeout, async {
486            for (task_id, handle) in tasks_to_drain {
487                match handle.await {
488                    Ok(_) => {
489                        debug!(app_id = %self.application_id, task_id = %task_id, "Task completed successfully");
490                    }
491                    Err(e) => {
492                        warn!(app_id = %self.application_id, task_id = %task_id, error = %e, "Task failed");
493                        success = false;
494                        error_messages.push(format!("Task {} failed: {}", task_id, e));
495                    }
496                }
497            }
498        })
499        .await;
500
501        match drain_result {
502            Ok(_) => {
503                info!(app_id = %self.application_id, "All tasks drained successfully");
504            }
505            Err(_) => {
506                warn!(app_id = %self.application_id, "Drain timeout exceeded");
507                success = false;
508                error_messages.push("Drain timeout exceeded".to_string());
509            }
510        }
511
512        // Reset draining flag
513        {
514            let mut draining_guard = self.draining.lock().await;
515            *draining_guard = false;
516        }
517
518        let error_message = if error_messages.is_empty() {
519            None
520        } else {
521            Some(error_messages.join("; "))
522        };
523
524        Ok(DrainResponse {
525            tasks_drained: task_count,
526            success,
527            error_message,
528        })
529    }
530
531    async fn get_task_count(&self) -> Result<u32> {
532        let tasks_guard = self.tasks.lock().await;
533        Ok(tasks_guard.len() as u32)
534    }
535
536    async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()> {
537        #[cfg(feature = "grpc")]
538        {
539            if let Some(mut client) = self.grpc_client.clone() {
540                let request = NotifyDrainCompleteRequest {
541                    application_id: self.application_id.clone(),
542                    tasks_drained: response.tasks_drained,
543                    success: response.success,
544                    error_message: response.error_message,
545                };
546
547                client
548                    .notify_drain_complete(request)
549                    .await
550                    .into_alien_error()
551                    .context(ErrorData::HttpRequestFailed {
552                        url: "grpc://wait_until_service".to_string(),
553                        method: "notify_drain_complete".to_string(),
554                    })?;
555            }
556        }
557
558        Ok(())
559    }
560}