1use crate::{
2 error::{ErrorData, Result},
3 traits::Binding,
4};
5use alien_error::{AlienError, Context, IntoAlienError};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::{
9 collections::HashMap,
10 future::Future,
11 sync::{
12 atomic::{AtomicU32, Ordering},
13 Arc,
14 },
15 time::Duration,
16};
17use tokio::{sync::Mutex, task::JoinHandle, time::timeout};
18#[cfg(feature = "grpc")]
19use tonic::transport::Channel;
20use tracing::{debug, error, info, warn};
21use uuid::Uuid;
22
23#[cfg(feature = "openapi")]
24use utoipa::ToSchema;
25
26#[cfg(feature = "grpc")]
27use crate::grpc::wait_until_service::alien_bindings::wait_until::{
28 wait_until_service_client::WaitUntilServiceClient, NotifyDrainCompleteRequest,
29 NotifyTaskRegisteredRequest, WaitForDrainSignalRequest,
30};
31
/// Outcome of a drain operation.
///
/// Serialized with camelCase field names (`tasksDrained`, `success`, `errorMessage`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct DrainResponse {
    /// Task count reported by the drain operation.
    pub tasks_drained: u32,
    /// `true` when the drain finished without a task failure or a timeout.
    pub success: bool,
    /// Human-readable failure description; `None` when nothing went wrong.
    pub error_message: Option<String>,
}
44
/// Parameters that control a drain operation.
///
/// Serialized with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct DrainConfig {
    /// Upper bound on how long the drain waits for tasks to finish.
    pub timeout: Duration,
    /// Why the drain was requested; recorded in the drain logs.
    pub reason: String,
}
55
/// Binding interface for tracking background tasks and draining them on request.
#[async_trait]
pub trait WaitUntil: Binding {
    /// Waits for a drain signal, optionally bounded by `timeout`, and returns the
    /// configuration the drain should run with.
    async fn wait_for_drain_signal(&self, timeout: Option<Duration>) -> Result<DrainConfig>;

    /// Drains the currently tracked tasks according to `config` and reports the outcome.
    async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse>;

    /// Returns the number of tasks currently tracked.
    async fn get_task_count(&self) -> Result<u32>;

    /// Reports the outcome of a finished drain.
    async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()>;
}
74
/// Tracks background tasks spawned through `wait_until` and coordinates draining them.
#[derive(Debug)]
pub struct WaitUntilContext {
    /// Identifier sent with every gRPC request and attached to log events.
    application_id: String,
    /// Join handles of registered tasks, keyed by `task_<n>`; shared with background clones.
    tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
    /// Source of the numeric suffix used to build task keys.
    task_counter: AtomicU32,
    /// Client for the wait-until gRPC service; `None` when no runtime connection is configured.
    #[cfg(feature = "grpc")]
    grpc_client: Option<WaitUntilServiceClient<Channel>>,
    /// Set while `drain_all` is running; tasks that start while it is set are rejected.
    draining: Arc<Mutex<bool>>,
}
91
92impl WaitUntilContext {
93 pub fn new(application_id: Option<String>) -> Self {
95 let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
96
97 Self {
98 application_id: app_id,
99 tasks: Arc::new(Mutex::new(HashMap::new())),
100 task_counter: AtomicU32::new(0),
101 #[cfg(feature = "grpc")]
102 grpc_client: None,
103 draining: Arc::new(Mutex::new(false)),
104 }
105 }
106
107 pub async fn from_env(application_id: Option<String>) -> Result<Self> {
110 let env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
111 Self::from_env_with_vars(application_id, &env_vars).await
112 }
113
114 pub async fn from_env_with_vars(
116 application_id: Option<String>,
117 env_vars: &std::collections::HashMap<String, String>,
118 ) -> Result<Self> {
119 let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
120
121 #[cfg(feature = "grpc")]
122 {
123 let bindings_mode = crate::get_bindings_mode_from_env(env_vars)?;
124
125 match bindings_mode {
126 crate::BindingsMode::Direct => {
127 return Ok(Self::new(Some(app_id)));
129 }
130 crate::BindingsMode::Grpc => {
131 let grpc_address =
133 env_vars.get("ALIEN_BINDINGS_GRPC_ADDRESS").ok_or_else(|| {
134 AlienError::new(ErrorData::EnvironmentVariableMissing {
135 variable_name: "ALIEN_BINDINGS_GRPC_ADDRESS".to_string(),
136 })
137 })?;
138
139 let channel = Self::create_grpc_channel(grpc_address.clone()).await?;
141 let grpc_client = WaitUntilServiceClient::new(channel);
142
143 return Ok(Self {
144 application_id: app_id,
145 tasks: Arc::new(Mutex::new(HashMap::new())),
146 task_counter: AtomicU32::new(0),
147 grpc_client: Some(grpc_client),
148 draining: Arc::new(Mutex::new(false)),
149 });
150 }
151 }
152 }
153
154 #[cfg(not(feature = "grpc"))]
155 {
156 Ok(Self::new(Some(app_id)))
157 }
158 }
159
160 #[cfg(feature = "grpc")]
163 async fn create_grpc_channel(grpc_address: String) -> Result<Channel> {
164 use std::time::Duration;
165
166 let endpoint_uri = if grpc_address.contains("://") {
168 grpc_address.clone()
169 } else {
170 format!("http://{}", grpc_address)
171 };
172
173 let endpoint = Channel::from_shared(endpoint_uri.clone())
174 .into_alien_error()
175 .context(ErrorData::GrpcConnectionFailed {
176 endpoint: endpoint_uri.clone(),
177 reason: "Invalid gRPC endpoint URI format".to_string(),
178 })?
179 .timeout(Duration::from_secs(300)) .connect_timeout(Duration::from_secs(5)) .http2_keep_alive_interval(Duration::from_secs(30)) .keep_alive_timeout(Duration::from_secs(10))
183 .keep_alive_while_idle(true); let channel = endpoint.connect().await.into_alien_error().context(
186 ErrorData::GrpcConnectionFailed {
187 endpoint: grpc_address.clone(),
188 reason: "Failed to establish gRPC connection".to_string(),
189 },
190 )?;
191
192 Ok(channel)
193 }
194
195 #[cfg(feature = "grpc")]
197 pub fn new_with_grpc_client(
198 application_id: Option<String>,
199 grpc_client: WaitUntilServiceClient<Channel>,
200 ) -> Self {
201 let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
202
203 Self {
204 application_id: app_id,
205 tasks: Arc::new(Mutex::new(HashMap::new())),
206 task_counter: AtomicU32::new(0),
207 grpc_client: Some(grpc_client),
208 draining: Arc::new(Mutex::new(false)),
209 }
210 }
211
212 #[cfg(feature = "grpc")]
214 pub fn set_grpc_client(&mut self, client: WaitUntilServiceClient<Channel>) {
215 self.grpc_client = Some(client);
216 }
217
218 pub fn application_id(&self) -> &str {
220 &self.application_id
221 }
222
223 pub async fn start_drain_listener(&self) -> Result<()> {
226 #[cfg(feature = "grpc")]
227 {
228 if let Some(mut client) = self.grpc_client.clone() {
229 let app_id = self.application_id.clone();
230 let context = self.clone_for_background();
231
232 tokio::spawn(async move {
233 loop {
234 debug!(app_id = %app_id, "Waiting for drain signal from runtime");
235
236 let request = WaitForDrainSignalRequest {
237 application_id: app_id.clone(),
238 timeout: Some(prost_types::Duration {
239 seconds: 300, nanos: 0,
241 }),
242 };
243
244 match client.wait_for_drain_signal(request).await {
245 Ok(response) => {
246 let resp = response.into_inner();
247 if resp.should_drain {
248 info!(
249 app_id = %app_id,
250 reason = %resp.drain_reason,
251 "Received drain signal from runtime"
252 );
253
254 let drain_timeout = resp
255 .drain_timeout
256 .map(|d| Duration::from_secs(d.seconds as u64))
257 .unwrap_or(Duration::from_secs(10));
258
259 let config = DrainConfig {
260 timeout: drain_timeout,
261 reason: resp.drain_reason,
262 };
263
264 match context.drain_all(config).await {
266 Ok(drain_response) => {
267 let complete_request = NotifyDrainCompleteRequest {
269 application_id: app_id.clone(),
270 tasks_drained: drain_response.tasks_drained,
271 success: drain_response.success,
272 error_message: drain_response.error_message,
273 };
274
275 if let Err(e) =
276 client.notify_drain_complete(complete_request).await
277 {
278 error!(app_id = %app_id, error = %e, "Failed to notify runtime of drain completion");
279 } else {
280 info!(app_id = %app_id, "Successfully notified runtime of drain completion");
281 }
282 }
283 Err(e) => {
284 error!(app_id = %app_id, error = %e, "Failed to drain tasks");
285 let complete_request = NotifyDrainCompleteRequest {
287 application_id: app_id.clone(),
288 tasks_drained: 0,
289 success: false,
290 error_message: Some(e.to_string()),
291 };
292 let _ = client
293 .notify_drain_complete(complete_request)
294 .await;
295 }
296 }
297 }
298 }
299 Err(e) => {
300 warn!(app_id = %app_id, error = %e, "Failed to wait for drain signal, retrying in 5 seconds");
301 tokio::time::sleep(Duration::from_secs(5)).await;
302 }
303 }
304 }
305 });
306 }
307 }
308
309 Ok(())
310 }
311
312 fn clone_for_background(&self) -> Self {
314 Self {
315 application_id: self.application_id.clone(),
316 tasks: Arc::clone(&self.tasks),
317 task_counter: AtomicU32::new(self.task_counter.load(Ordering::Relaxed)),
318 #[cfg(feature = "grpc")]
319 grpc_client: self.grpc_client.clone(),
320 draining: Arc::clone(&self.draining),
321 }
322 }
323
324 async fn notify_task_registered(&self, task_description: String) -> Result<()> {
326 #[cfg(feature = "grpc")]
327 {
328 if let Some(mut client) = self.grpc_client.clone() {
329 let request = NotifyTaskRegisteredRequest {
330 application_id: self.application_id.clone(),
331 task_description: Some(task_description),
332 };
333
334 client
335 .notify_task_registered(request)
336 .await
337 .into_alien_error()
338 .context(ErrorData::HttpRequestFailed {
339 url: "grpc://wait_until_service".to_string(),
340 method: "notify_task_registered".to_string(),
341 })?;
342 }
343 }
344
345 Ok(())
346 }
347}
348
349impl WaitUntilContext {
350 pub fn wait_until<F, Fut>(&self, task_fn: F) -> Result<()>
353 where
354 F: FnOnce() -> Fut + Send + 'static,
355 Fut: Future<Output = ()> + Send + 'static,
356 {
357 let task_id = self.task_counter.fetch_add(1, Ordering::Relaxed);
358 let task_key = format!("task_{}", task_id);
359 let task_description = format!("wait_until_task_{}", task_id);
360
361 let draining = self.draining.clone();
363 let tasks = self.tasks.clone();
364 let app_id = self.application_id.clone();
365 let task_key_clone = task_key.clone();
366
367 let handle = tokio::spawn(async move {
369 if *draining.lock().await {
371 warn!(app_id = %app_id, task_id = %task_key_clone, "Rejecting new task - currently draining");
372 return;
373 }
374
375 debug!(app_id = %app_id, task_id = %task_key_clone, "Starting wait_until task");
376
377 let future = task_fn();
378 future.await;
379
380 debug!(app_id = %app_id, task_id = %task_key_clone, "Completed wait_until task");
381
382 tasks.lock().await.remove(&task_key_clone);
384 });
385
386 {
388 let mut tasks_guard = futures::executor::block_on(self.tasks.lock());
389 tasks_guard.insert(task_key.clone(), handle);
390 }
391
392 let context_clone = self.clone_for_background();
394 tokio::spawn(async move {
395 if let Err(e) = context_clone.notify_task_registered(task_description).await {
396 warn!(app_id = %context_clone.application_id, task_id = %task_key, error = %e, "Failed to notify runtime of task registration");
397 }
398 });
399
400 Ok(())
401 }
402}
403
// Marker impl: `WaitUntil` has `Binding` as a supertrait, so the context must implement it.
impl Binding for WaitUntilContext {}
405
406#[async_trait]
407impl WaitUntil for WaitUntilContext {
408 async fn wait_for_drain_signal(
409 &self,
410 timeout_duration: Option<Duration>,
411 ) -> Result<DrainConfig> {
412 #[cfg(feature = "grpc")]
413 {
414 if let Some(mut client) = self.grpc_client.clone() {
415 let timeout_proto = timeout_duration.map(|d| prost_types::Duration {
416 seconds: d.as_secs() as i64,
417 nanos: d.subsec_nanos() as i32,
418 });
419
420 let request = WaitForDrainSignalRequest {
421 application_id: self.application_id.clone(),
422 timeout: timeout_proto,
423 };
424
425 let response = client
426 .wait_for_drain_signal(request)
427 .await
428 .into_alien_error()
429 .context(ErrorData::HttpRequestFailed {
430 url: "grpc://wait_until_service".to_string(),
431 method: "wait_for_drain_signal".to_string(),
432 })?;
433
434 let resp = response.into_inner();
435 if resp.should_drain {
436 let drain_timeout = resp
437 .drain_timeout
438 .map(|d| Duration::from_secs(d.seconds as u64))
439 .unwrap_or(Duration::from_secs(10));
440
441 return Ok(DrainConfig {
442 timeout: drain_timeout,
443 reason: resp.drain_reason,
444 });
445 }
446 }
447 }
448
449 Err(AlienError::new(ErrorData::Other {
451 message: "No drain signal received or gRPC client not available".to_string(),
452 }))
453 }
454
455 async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse> {
456 info!(
457 app_id = %self.application_id,
458 reason = %config.reason,
459 timeout_secs = config.timeout.as_secs(),
460 "Starting to drain all wait_until tasks"
461 );
462
463 {
465 let mut draining_guard = self.draining.lock().await;
466 *draining_guard = true;
467 }
468
469 let tasks_to_drain = {
470 let mut tasks_guard = self.tasks.lock().await;
471 std::mem::take(&mut *tasks_guard) };
473
474 let task_count = tasks_to_drain.len() as u32;
475 info!(app_id = %self.application_id, task_count = task_count, "Draining tasks");
476
477 let mut success = true;
478 let mut error_messages = Vec::new();
479
480 let drain_result = timeout(config.timeout, async {
482 for (task_id, handle) in tasks_to_drain {
483 match handle.await {
484 Ok(_) => {
485 debug!(app_id = %self.application_id, task_id = %task_id, "Task completed successfully");
486 }
487 Err(e) => {
488 warn!(app_id = %self.application_id, task_id = %task_id, error = %e, "Task failed");
489 success = false;
490 error_messages.push(format!("Task {} failed: {}", task_id, e));
491 }
492 }
493 }
494 })
495 .await;
496
497 match drain_result {
498 Ok(_) => {
499 info!(app_id = %self.application_id, "All tasks drained successfully");
500 }
501 Err(_) => {
502 warn!(app_id = %self.application_id, "Drain timeout exceeded");
503 success = false;
504 error_messages.push("Drain timeout exceeded".to_string());
505 }
506 }
507
508 {
510 let mut draining_guard = self.draining.lock().await;
511 *draining_guard = false;
512 }
513
514 let error_message = if error_messages.is_empty() {
515 None
516 } else {
517 Some(error_messages.join("; "))
518 };
519
520 Ok(DrainResponse {
521 tasks_drained: task_count,
522 success,
523 error_message,
524 })
525 }
526
527 async fn get_task_count(&self) -> Result<u32> {
528 let tasks_guard = self.tasks.lock().await;
529 Ok(tasks_guard.len() as u32)
530 }
531
532 async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()> {
533 #[cfg(feature = "grpc")]
534 {
535 if let Some(mut client) = self.grpc_client.clone() {
536 let request = NotifyDrainCompleteRequest {
537 application_id: self.application_id.clone(),
538 tasks_drained: response.tasks_drained,
539 success: response.success,
540 error_message: response.error_message,
541 };
542
543 client
544 .notify_drain_complete(request)
545 .await
546 .into_alien_error()
547 .context(ErrorData::HttpRequestFailed {
548 url: "grpc://wait_until_service".to_string(),
549 method: "notify_drain_complete".to_string(),
550 })?;
551 }
552 }
553
554 Ok(())
555 }
556}