alien_bindings/grpc/
wait_until_service.rs1#![cfg(feature = "grpc")]
2
3use crate::BindingsProviderApi;
4use alien_error::AlienError;
5use async_trait::async_trait;
6use std::{collections::HashMap, sync::Arc};
7use tokio::sync::{oneshot, Mutex};
8use tonic::{Request, Response, Status};
9use tracing::{debug, info, warn};
10
/// Code generated from the `alien_bindings.wait_until` protobuf package.
pub mod alien_bindings {
    /// Messages and gRPC service definitions for the wait-until service, included from the
    /// build output via `tonic::include_proto!`.
    pub mod wait_until {
        tonic::include_proto!("alien_bindings.wait_until");

        /// Encoded protobuf file descriptor set emitted for this package at build time.
        // NOTE(review): presumably registered with a gRPC reflection service; the consumer is
        // not visible in this file — confirm.
        pub const FILE_DESCRIPTOR_SET: &[u8] =
            tonic::include_file_descriptor_set!("alien_bindings.wait_until_descriptor");
    }
}
19
20use alien_bindings::wait_until::{
21 wait_until_service_server::{WaitUntilService, WaitUntilServiceServer},
22 GetTaskCountRequest, GetTaskCountResponse, NotifyDrainCompleteRequest,
23 NotifyDrainCompleteResponse, NotifyTaskRegisteredRequest, NotifyTaskRegisteredResponse,
24 WaitForDrainSignalRequest, WaitForDrainSignalResponse,
25};
26
/// Per-application bookkeeping held by `WaitUntilGrpcServer`, keyed by the `application_id`
/// carried in each request.
#[derive(Debug)]
struct ApplicationState {
    /// Tasks reported via `notify_task_registered` (one per call) minus the `tasks_drained`
    /// reported via `notify_drain_complete` (never goes below zero).
    task_count: u32,
    /// Sender for this application's pending `wait_for_drain_signal` call. `None` until the
    /// application first waits, and again after `trigger_drain_all` has consumed it.
    drain_signal_sender: Option<oneshot::Sender<WaitForDrainSignalResponse>>,
}
35
/// gRPC implementation of `WaitUntilService`: counts background tasks per application and
/// hands drain signals to applications that are waiting for one.
///
/// Clones share the same application map, because it sits behind an `Arc`.
#[derive(Clone)]
pub struct WaitUntilGrpcServer {
    /// Bindings provider supplied to `new`.
    // NOTE(review): not read anywhere in this file — confirm it is still needed.
    provider: Arc<dyn BindingsProviderApi>,
    /// Per-application state keyed by application id, guarded by a tokio (async) mutex.
    applications: Arc<Mutex<HashMap<String, ApplicationState>>>,
}
42
43impl WaitUntilGrpcServer {
44 pub fn new(provider: Arc<dyn BindingsProviderApi>) -> Self {
45 Self {
46 provider,
47 applications: Arc::new(Mutex::new(HashMap::new())),
48 }
49 }
50
51 pub fn into_service(self) -> WaitUntilServiceServer<Self> {
52 WaitUntilServiceServer::new(self)
53 }
54
55 pub async fn trigger_drain_all(
58 &self,
59 reason: &str,
60 timeout_secs: u64,
61 ) -> Result<(), AlienError> {
62 let mut applications = self.applications.lock().await;
63
64 info!("Triggering drain for {} applications", applications.len());
65
66 for (app_id, app_state) in applications.iter_mut() {
67 if let Some(sender) = app_state.drain_signal_sender.take() {
68 let response = WaitForDrainSignalResponse {
69 should_drain: true,
70 drain_timeout: Some(prost_types::Duration {
71 seconds: timeout_secs as i64,
72 nanos: 0,
73 }),
74 drain_reason: reason.to_string(),
75 };
76
77 if let Err(_) = sender.send(response) {
78 warn!("Failed to send drain signal to application {}", app_id);
79 }
80 }
81 }
82
83 Ok(())
84 }
85
86 pub async fn get_total_task_count(&self) -> u32 {
88 let applications = self.applications.lock().await;
89 applications.values().map(|app| app.task_count).sum()
90 }
91}
92
93#[async_trait]
94impl WaitUntilService for WaitUntilGrpcServer {
95 async fn notify_task_registered(
96 &self,
97 request: Request<NotifyTaskRegisteredRequest>,
98 ) -> Result<Response<NotifyTaskRegisteredResponse>, Status> {
99 let req = request.into_inner();
100 let app_id = req.application_id;
101
102 debug!(
103 app_id = %app_id,
104 task_description = %req.task_description.as_deref().unwrap_or_default(),
105 "Task registered"
106 );
107
108 let mut applications = self.applications.lock().await;
109 let app_state = applications
110 .entry(app_id.clone())
111 .or_insert_with(|| ApplicationState {
112 task_count: 0,
113 drain_signal_sender: None,
114 });
115
116 app_state.task_count += 1;
117
118 debug!(app_id = %app_id, task_count = app_state.task_count, "Updated task count");
119
120 Ok(Response::new(NotifyTaskRegisteredResponse {
121 success: true,
122 }))
123 }
124
125 async fn wait_for_drain_signal(
126 &self,
127 request: Request<WaitForDrainSignalRequest>,
128 ) -> Result<Response<WaitForDrainSignalResponse>, Status> {
129 let req = request.into_inner();
130 let app_id = req.application_id;
131
132 debug!(app_id = %app_id, "Application waiting for drain signal");
133
134 let (sender, receiver) = oneshot::channel();
135
136 {
138 let mut applications = self.applications.lock().await;
139 let app_state =
140 applications
141 .entry(app_id.clone())
142 .or_insert_with(|| ApplicationState {
143 task_count: 0,
144 drain_signal_sender: None,
145 });
146 app_state.drain_signal_sender = Some(sender);
147 }
148
149 match receiver.await {
151 Ok(response) => {
152 debug!(app_id = %app_id, reason = %response.drain_reason, "Sending drain signal to application");
153 Ok(Response::new(response))
154 }
155 Err(_) => {
156 warn!(app_id = %app_id, "Drain signal channel dropped");
158 Ok(Response::new(WaitForDrainSignalResponse {
159 should_drain: true,
160 drain_timeout: Some(prost_types::Duration {
161 seconds: 10,
162 nanos: 0,
163 }),
164 drain_reason: "runtime_shutdown".to_string(),
165 }))
166 }
167 }
168 }
169
170 async fn notify_drain_complete(
171 &self,
172 request: Request<NotifyDrainCompleteRequest>,
173 ) -> Result<Response<NotifyDrainCompleteResponse>, Status> {
174 let req = request.into_inner();
175 let app_id = req.application_id;
176
177 info!(
178 app_id = %app_id,
179 tasks_drained = req.tasks_drained,
180 success = req.success,
181 error = %req.error_message.as_deref().unwrap_or_default(),
182 "Application completed draining"
183 );
184
185 {
187 let mut applications = self.applications.lock().await;
188 if let Some(app_state) = applications.get_mut(&app_id) {
189 app_state.task_count = app_state.task_count.saturating_sub(req.tasks_drained);
190 }
191 }
192
193 Ok(Response::new(NotifyDrainCompleteResponse {
194 acknowledged: true,
195 }))
196 }
197
198 async fn get_task_count(
199 &self,
200 request: Request<GetTaskCountRequest>,
201 ) -> Result<Response<GetTaskCountResponse>, Status> {
202 let req = request.into_inner();
203 let app_id = req.application_id;
204
205 let applications = self.applications.lock().await;
206 let task_count = applications
207 .get(&app_id)
208 .map(|app| app.task_count)
209 .unwrap_or(0);
210
211 Ok(Response::new(GetTaskCountResponse { task_count }))
212 }
213}