Skip to main content

alien_bindings/grpc/
wait_until_service.rs

1#![cfg(feature = "grpc")]
2
3use alien_error::AlienError;
4use async_trait::async_trait;
5use std::{collections::HashMap, sync::Arc};
6use tokio::sync::{oneshot, Mutex};
7use tonic::{Request, Response, Status};
8use tracing::{debug, info, warn};
9
10// Module for the generated gRPC code.
11pub mod alien_bindings {
12    pub mod wait_until {
13        tonic::include_proto!("alien_bindings.wait_until");
14        pub const FILE_DESCRIPTOR_SET: &[u8] =
15            tonic::include_file_descriptor_set!("alien_bindings.wait_until_descriptor");
16    }
17}
18
19use alien_bindings::wait_until::{
20    wait_until_service_server::{WaitUntilService, WaitUntilServiceServer},
21    GetTaskCountRequest, GetTaskCountResponse, NotifyDrainCompleteRequest,
22    NotifyDrainCompleteResponse, NotifyTaskRegisteredRequest, NotifyTaskRegisteredResponse,
23    WaitForDrainSignalRequest, WaitForDrainSignalResponse,
24};
25
26/// Represents an application that has registered tasks and may be waiting for drain signals.
27#[derive(Debug)]
28struct ApplicationState {
29    /// Number of currently registered tasks.
30    task_count: u32,
31    /// Channel to send drain signal to the application when it's time to drain.
32    drain_signal_sender: Option<oneshot::Sender<WaitForDrainSignalResponse>>,
33}
34
35#[derive(Clone)]
36pub struct WaitUntilGrpcServer {
37    /// Track application states by application_id.
38    applications: Arc<Mutex<HashMap<String, ApplicationState>>>,
39}
40
41impl WaitUntilGrpcServer {
42    pub fn new() -> Self {
43        Self {
44            applications: Arc::new(Mutex::new(HashMap::new())),
45        }
46    }
47
48    pub fn into_service(self) -> WaitUntilServiceServer<Self> {
49        WaitUntilServiceServer::new(self)
50    }
51
52    /// Trigger drain for all registered applications.
53    /// This is called by the runtime when it's time to drain (e.g., on SIGTERM or Lambda INVOKE end).
54    pub async fn trigger_drain_all(
55        &self,
56        reason: &str,
57        timeout_secs: u64,
58    ) -> Result<(), AlienError> {
59        let mut applications = self.applications.lock().await;
60
61        info!("Triggering drain for {} applications", applications.len());
62
63        for (app_id, app_state) in applications.iter_mut() {
64            if let Some(sender) = app_state.drain_signal_sender.take() {
65                let response = WaitForDrainSignalResponse {
66                    should_drain: true,
67                    drain_timeout: Some(prost_types::Duration {
68                        seconds: timeout_secs as i64,
69                        nanos: 0,
70                    }),
71                    drain_reason: reason.to_string(),
72                };
73
74                if let Err(_) = sender.send(response) {
75                    warn!("Failed to send drain signal to application {}", app_id);
76                }
77            }
78        }
79
80        Ok(())
81    }
82
83    /// Get total number of tasks across all applications.
84    pub async fn get_total_task_count(&self) -> u32 {
85        let applications = self.applications.lock().await;
86        applications.values().map(|app| app.task_count).sum()
87    }
88}
89
90#[async_trait]
91impl WaitUntilService for WaitUntilGrpcServer {
92    async fn notify_task_registered(
93        &self,
94        request: Request<NotifyTaskRegisteredRequest>,
95    ) -> Result<Response<NotifyTaskRegisteredResponse>, Status> {
96        let req = request.into_inner();
97        let app_id = req.application_id;
98
99        debug!(
100            app_id = %app_id,
101            task_description = %req.task_description.as_deref().unwrap_or_default(),
102            "Task registered"
103        );
104
105        let mut applications = self.applications.lock().await;
106        let app_state = applications
107            .entry(app_id.clone())
108            .or_insert_with(|| ApplicationState {
109                task_count: 0,
110                drain_signal_sender: None,
111            });
112
113        app_state.task_count += 1;
114
115        debug!(app_id = %app_id, task_count = app_state.task_count, "Updated task count");
116
117        Ok(Response::new(NotifyTaskRegisteredResponse {
118            success: true,
119        }))
120    }
121
122    async fn wait_for_drain_signal(
123        &self,
124        request: Request<WaitForDrainSignalRequest>,
125    ) -> Result<Response<WaitForDrainSignalResponse>, Status> {
126        let req = request.into_inner();
127        let app_id = req.application_id;
128
129        debug!(app_id = %app_id, "Application waiting for drain signal");
130
131        let (sender, receiver) = oneshot::channel();
132
133        // Store the sender in the application state
134        {
135            let mut applications = self.applications.lock().await;
136            let app_state =
137                applications
138                    .entry(app_id.clone())
139                    .or_insert_with(|| ApplicationState {
140                        task_count: 0,
141                        drain_signal_sender: None,
142                    });
143            app_state.drain_signal_sender = Some(sender);
144        }
145
146        // Wait for the drain signal
147        match receiver.await {
148            Ok(response) => {
149                debug!(app_id = %app_id, reason = %response.drain_reason, "Sending drain signal to application");
150                Ok(Response::new(response))
151            }
152            Err(_) => {
153                // Channel was dropped, likely due to shutdown
154                warn!(app_id = %app_id, "Drain signal channel dropped");
155                Ok(Response::new(WaitForDrainSignalResponse {
156                    should_drain: true,
157                    drain_timeout: Some(prost_types::Duration {
158                        seconds: 10,
159                        nanos: 0,
160                    }),
161                    drain_reason: "runtime_shutdown".to_string(),
162                }))
163            }
164        }
165    }
166
167    async fn notify_drain_complete(
168        &self,
169        request: Request<NotifyDrainCompleteRequest>,
170    ) -> Result<Response<NotifyDrainCompleteResponse>, Status> {
171        let req = request.into_inner();
172        let app_id = req.application_id;
173
174        info!(
175            app_id = %app_id,
176            tasks_drained = req.tasks_drained,
177            success = req.success,
178            error = %req.error_message.as_deref().unwrap_or_default(),
179            "Application completed draining"
180        );
181
182        // Update the application state to reflect that tasks have been drained
183        {
184            let mut applications = self.applications.lock().await;
185            if let Some(app_state) = applications.get_mut(&app_id) {
186                app_state.task_count = app_state.task_count.saturating_sub(req.tasks_drained);
187            }
188        }
189
190        Ok(Response::new(NotifyDrainCompleteResponse {
191            acknowledged: true,
192        }))
193    }
194
195    async fn get_task_count(
196        &self,
197        request: Request<GetTaskCountRequest>,
198    ) -> Result<Response<GetTaskCountResponse>, Status> {
199        let req = request.into_inner();
200        let app_id = req.application_id;
201
202        let applications = self.applications.lock().await;
203        let task_count = applications
204            .get(&app_id)
205            .map(|app| app.task_count)
206            .unwrap_or(0);
207
208        Ok(Response::new(GetTaskCountResponse { task_count }))
209    }
210}