Skip to main content

alien_bindings/grpc/
wait_until_service.rs

1#![cfg(feature = "grpc")]
2
3use crate::BindingsProviderApi;
4use alien_error::AlienError;
5use async_trait::async_trait;
6use std::{collections::HashMap, sync::Arc};
7use tokio::sync::{oneshot, Mutex};
8use tonic::{Request, Response, Status};
9use tracing::{debug, info, warn};
10
11// Module for the generated gRPC code.
12pub mod alien_bindings {
13    pub mod wait_until {
14        tonic::include_proto!("alien_bindings.wait_until");
15        pub const FILE_DESCRIPTOR_SET: &[u8] =
16            tonic::include_file_descriptor_set!("alien_bindings.wait_until_descriptor");
17    }
18}
19
20use alien_bindings::wait_until::{
21    wait_until_service_server::{WaitUntilService, WaitUntilServiceServer},
22    GetTaskCountRequest, GetTaskCountResponse, NotifyDrainCompleteRequest,
23    NotifyDrainCompleteResponse, NotifyTaskRegisteredRequest, NotifyTaskRegisteredResponse,
24    WaitForDrainSignalRequest, WaitForDrainSignalResponse,
25};
26
27/// Represents an application that has registered tasks and may be waiting for drain signals.
28#[derive(Debug)]
29struct ApplicationState {
30    /// Number of currently registered tasks.
31    task_count: u32,
32    /// Channel to send drain signal to the application when it's time to drain.
33    drain_signal_sender: Option<oneshot::Sender<WaitForDrainSignalResponse>>,
34}
35
36#[derive(Clone)]
37pub struct WaitUntilGrpcServer {
38    provider: Arc<dyn BindingsProviderApi>,
39    /// Track application states by application_id.
40    applications: Arc<Mutex<HashMap<String, ApplicationState>>>,
41}
42
43impl WaitUntilGrpcServer {
44    pub fn new(provider: Arc<dyn BindingsProviderApi>) -> Self {
45        Self {
46            provider,
47            applications: Arc::new(Mutex::new(HashMap::new())),
48        }
49    }
50
51    pub fn into_service(self) -> WaitUntilServiceServer<Self> {
52        WaitUntilServiceServer::new(self)
53    }
54
55    /// Trigger drain for all registered applications.
56    /// This is called by the runtime when it's time to drain (e.g., on SIGTERM or Lambda INVOKE end).
57    pub async fn trigger_drain_all(
58        &self,
59        reason: &str,
60        timeout_secs: u64,
61    ) -> Result<(), AlienError> {
62        let mut applications = self.applications.lock().await;
63
64        info!("Triggering drain for {} applications", applications.len());
65
66        for (app_id, app_state) in applications.iter_mut() {
67            if let Some(sender) = app_state.drain_signal_sender.take() {
68                let response = WaitForDrainSignalResponse {
69                    should_drain: true,
70                    drain_timeout: Some(prost_types::Duration {
71                        seconds: timeout_secs as i64,
72                        nanos: 0,
73                    }),
74                    drain_reason: reason.to_string(),
75                };
76
77                if let Err(_) = sender.send(response) {
78                    warn!("Failed to send drain signal to application {}", app_id);
79                }
80            }
81        }
82
83        Ok(())
84    }
85
86    /// Get total number of tasks across all applications.
87    pub async fn get_total_task_count(&self) -> u32 {
88        let applications = self.applications.lock().await;
89        applications.values().map(|app| app.task_count).sum()
90    }
91}
92
93#[async_trait]
94impl WaitUntilService for WaitUntilGrpcServer {
95    async fn notify_task_registered(
96        &self,
97        request: Request<NotifyTaskRegisteredRequest>,
98    ) -> Result<Response<NotifyTaskRegisteredResponse>, Status> {
99        let req = request.into_inner();
100        let app_id = req.application_id;
101
102        debug!(
103            app_id = %app_id,
104            task_description = %req.task_description.as_deref().unwrap_or_default(),
105            "Task registered"
106        );
107
108        let mut applications = self.applications.lock().await;
109        let app_state = applications
110            .entry(app_id.clone())
111            .or_insert_with(|| ApplicationState {
112                task_count: 0,
113                drain_signal_sender: None,
114            });
115
116        app_state.task_count += 1;
117
118        debug!(app_id = %app_id, task_count = app_state.task_count, "Updated task count");
119
120        Ok(Response::new(NotifyTaskRegisteredResponse {
121            success: true,
122        }))
123    }
124
125    async fn wait_for_drain_signal(
126        &self,
127        request: Request<WaitForDrainSignalRequest>,
128    ) -> Result<Response<WaitForDrainSignalResponse>, Status> {
129        let req = request.into_inner();
130        let app_id = req.application_id;
131
132        debug!(app_id = %app_id, "Application waiting for drain signal");
133
134        let (sender, receiver) = oneshot::channel();
135
136        // Store the sender in the application state
137        {
138            let mut applications = self.applications.lock().await;
139            let app_state =
140                applications
141                    .entry(app_id.clone())
142                    .or_insert_with(|| ApplicationState {
143                        task_count: 0,
144                        drain_signal_sender: None,
145                    });
146            app_state.drain_signal_sender = Some(sender);
147        }
148
149        // Wait for the drain signal
150        match receiver.await {
151            Ok(response) => {
152                debug!(app_id = %app_id, reason = %response.drain_reason, "Sending drain signal to application");
153                Ok(Response::new(response))
154            }
155            Err(_) => {
156                // Channel was dropped, likely due to shutdown
157                warn!(app_id = %app_id, "Drain signal channel dropped");
158                Ok(Response::new(WaitForDrainSignalResponse {
159                    should_drain: true,
160                    drain_timeout: Some(prost_types::Duration {
161                        seconds: 10,
162                        nanos: 0,
163                    }),
164                    drain_reason: "runtime_shutdown".to_string(),
165                }))
166            }
167        }
168    }
169
170    async fn notify_drain_complete(
171        &self,
172        request: Request<NotifyDrainCompleteRequest>,
173    ) -> Result<Response<NotifyDrainCompleteResponse>, Status> {
174        let req = request.into_inner();
175        let app_id = req.application_id;
176
177        info!(
178            app_id = %app_id,
179            tasks_drained = req.tasks_drained,
180            success = req.success,
181            error = %req.error_message.as_deref().unwrap_or_default(),
182            "Application completed draining"
183        );
184
185        // Update the application state to reflect that tasks have been drained
186        {
187            let mut applications = self.applications.lock().await;
188            if let Some(app_state) = applications.get_mut(&app_id) {
189                app_state.task_count = app_state.task_count.saturating_sub(req.tasks_drained);
190            }
191        }
192
193        Ok(Response::new(NotifyDrainCompleteResponse {
194            acknowledged: true,
195        }))
196    }
197
198    async fn get_task_count(
199        &self,
200        request: Request<GetTaskCountRequest>,
201    ) -> Result<Response<GetTaskCountResponse>, Status> {
202        let req = request.into_inner();
203        let app_id = req.application_id;
204
205        let applications = self.applications.lock().await;
206        let task_count = applications
207            .get(&app_id)
208            .map(|app| app.task_count)
209            .unwrap_or(0);
210
211        Ok(Response::new(GetTaskCountResponse { task_count }))
212    }
213}