alien_bindings/grpc/
wait_until_service.rs1#![cfg(feature = "grpc")]
2
3use alien_error::AlienError;
4use async_trait::async_trait;
5use std::{collections::HashMap, sync::Arc};
6use tokio::sync::{oneshot, Mutex};
7use tonic::{Request, Response, Status};
8use tracing::{debug, info, warn};
9
10pub mod alien_bindings {
12 pub mod wait_until {
13 tonic::include_proto!("alien_bindings.wait_until");
14 pub const FILE_DESCRIPTOR_SET: &[u8] =
15 tonic::include_file_descriptor_set!("alien_bindings.wait_until_descriptor");
16 }
17}
18
19use alien_bindings::wait_until::{
20 wait_until_service_server::{WaitUntilService, WaitUntilServiceServer},
21 GetTaskCountRequest, GetTaskCountResponse, NotifyDrainCompleteRequest,
22 NotifyDrainCompleteResponse, NotifyTaskRegisteredRequest, NotifyTaskRegisteredResponse,
23 WaitForDrainSignalRequest, WaitForDrainSignalResponse,
24};
25
26#[derive(Debug)]
28struct ApplicationState {
29 task_count: u32,
31 drain_signal_sender: Option<oneshot::Sender<WaitForDrainSignalResponse>>,
33}
34
35#[derive(Clone)]
36pub struct WaitUntilGrpcServer {
37 applications: Arc<Mutex<HashMap<String, ApplicationState>>>,
39}
40
41impl WaitUntilGrpcServer {
42 pub fn new() -> Self {
43 Self {
44 applications: Arc::new(Mutex::new(HashMap::new())),
45 }
46 }
47
48 pub fn into_service(self) -> WaitUntilServiceServer<Self> {
49 WaitUntilServiceServer::new(self)
50 }
51
52 pub async fn trigger_drain_all(
55 &self,
56 reason: &str,
57 timeout_secs: u64,
58 ) -> Result<(), AlienError> {
59 let mut applications = self.applications.lock().await;
60
61 info!("Triggering drain for {} applications", applications.len());
62
63 for (app_id, app_state) in applications.iter_mut() {
64 if let Some(sender) = app_state.drain_signal_sender.take() {
65 let response = WaitForDrainSignalResponse {
66 should_drain: true,
67 drain_timeout: Some(prost_types::Duration {
68 seconds: timeout_secs as i64,
69 nanos: 0,
70 }),
71 drain_reason: reason.to_string(),
72 };
73
74 if let Err(_) = sender.send(response) {
75 warn!("Failed to send drain signal to application {}", app_id);
76 }
77 }
78 }
79
80 Ok(())
81 }
82
83 pub async fn get_total_task_count(&self) -> u32 {
85 let applications = self.applications.lock().await;
86 applications.values().map(|app| app.task_count).sum()
87 }
88}
89
90#[async_trait]
91impl WaitUntilService for WaitUntilGrpcServer {
92 async fn notify_task_registered(
93 &self,
94 request: Request<NotifyTaskRegisteredRequest>,
95 ) -> Result<Response<NotifyTaskRegisteredResponse>, Status> {
96 let req = request.into_inner();
97 let app_id = req.application_id;
98
99 debug!(
100 app_id = %app_id,
101 task_description = %req.task_description.as_deref().unwrap_or_default(),
102 "Task registered"
103 );
104
105 let mut applications = self.applications.lock().await;
106 let app_state = applications
107 .entry(app_id.clone())
108 .or_insert_with(|| ApplicationState {
109 task_count: 0,
110 drain_signal_sender: None,
111 });
112
113 app_state.task_count += 1;
114
115 debug!(app_id = %app_id, task_count = app_state.task_count, "Updated task count");
116
117 Ok(Response::new(NotifyTaskRegisteredResponse {
118 success: true,
119 }))
120 }
121
122 async fn wait_for_drain_signal(
123 &self,
124 request: Request<WaitForDrainSignalRequest>,
125 ) -> Result<Response<WaitForDrainSignalResponse>, Status> {
126 let req = request.into_inner();
127 let app_id = req.application_id;
128
129 debug!(app_id = %app_id, "Application waiting for drain signal");
130
131 let (sender, receiver) = oneshot::channel();
132
133 {
135 let mut applications = self.applications.lock().await;
136 let app_state =
137 applications
138 .entry(app_id.clone())
139 .or_insert_with(|| ApplicationState {
140 task_count: 0,
141 drain_signal_sender: None,
142 });
143 app_state.drain_signal_sender = Some(sender);
144 }
145
146 match receiver.await {
148 Ok(response) => {
149 debug!(app_id = %app_id, reason = %response.drain_reason, "Sending drain signal to application");
150 Ok(Response::new(response))
151 }
152 Err(_) => {
153 warn!(app_id = %app_id, "Drain signal channel dropped");
155 Ok(Response::new(WaitForDrainSignalResponse {
156 should_drain: true,
157 drain_timeout: Some(prost_types::Duration {
158 seconds: 10,
159 nanos: 0,
160 }),
161 drain_reason: "runtime_shutdown".to_string(),
162 }))
163 }
164 }
165 }
166
167 async fn notify_drain_complete(
168 &self,
169 request: Request<NotifyDrainCompleteRequest>,
170 ) -> Result<Response<NotifyDrainCompleteResponse>, Status> {
171 let req = request.into_inner();
172 let app_id = req.application_id;
173
174 info!(
175 app_id = %app_id,
176 tasks_drained = req.tasks_drained,
177 success = req.success,
178 error = %req.error_message.as_deref().unwrap_or_default(),
179 "Application completed draining"
180 );
181
182 {
184 let mut applications = self.applications.lock().await;
185 if let Some(app_state) = applications.get_mut(&app_id) {
186 app_state.task_count = app_state.task_count.saturating_sub(req.tasks_drained);
187 }
188 }
189
190 Ok(Response::new(NotifyDrainCompleteResponse {
191 acknowledged: true,
192 }))
193 }
194
195 async fn get_task_count(
196 &self,
197 request: Request<GetTaskCountRequest>,
198 ) -> Result<Response<GetTaskCountResponse>, Status> {
199 let req = request.into_inner();
200 let app_id = req.application_id;
201
202 let applications = self.applications.lock().await;
203 let task_count = applications
204 .get(&app_id)
205 .map(|app| app.task_count)
206 .unwrap_or(0);
207
208 Ok(Response::new(GetTaskCountResponse { task_count }))
209 }
210}