1use crate::{
2 error::{ErrorData, Result},
3 traits::Binding,
4};
5use alien_error::{AlienError, Context, IntoAlienError};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::{
9 collections::HashMap,
10 future::Future,
11 sync::{
12 atomic::{AtomicU32, Ordering},
13 Arc,
14 },
15 time::Duration,
16};
17use tokio::{
18 sync::{oneshot, Mutex},
19 task::JoinHandle,
20 time::timeout,
21};
22#[cfg(feature = "grpc")]
23use tonic::transport::Channel;
24use tracing::{debug, error, info, warn};
25use uuid::Uuid;
26
27#[cfg(feature = "openapi")]
28use utoipa::ToSchema;
29
30#[cfg(feature = "grpc")]
31use crate::grpc::wait_until_service::alien_bindings::wait_until::{
32 wait_until_service_client::WaitUntilServiceClient, GetTaskCountRequest,
33 NotifyDrainCompleteRequest, NotifyTaskRegisteredRequest, WaitForDrainSignalRequest,
34};
35
/// Outcome of a drain operation.
///
/// Produced by [`WaitUntil::drain_all`] and forwarded to the runtime through
/// [`WaitUntil::notify_drain_complete`]. Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct DrainResponse {
    /// Task count reported by the drain operation.
    pub tasks_drained: u32,
    /// `true` when the drain finished without a task failure or a timeout.
    pub success: bool,
    /// Human-readable description of what went wrong; `None` when nothing did.
    pub error_message: Option<String>,
}
48
/// Parameters of a drain request.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct DrainConfig {
    /// Upper bound on how long the drain may wait for tasks to finish.
    pub timeout: Duration,
    /// Why the drain was requested; included in log output.
    pub reason: String,
}
59
/// Binding that tracks background ("wait until") tasks and drains them on request.
///
/// Implementors must also implement [`Binding`].
#[async_trait]
pub trait WaitUntil: Binding {
    /// Waits for a drain signal and returns the parameters of the requested drain.
    ///
    /// `timeout` optionally bounds the wait; `None` leaves the bound to the
    /// implementation.
    async fn wait_for_drain_signal(&self, timeout: Option<Duration>) -> Result<DrainConfig>;

    /// Drains the tracked tasks according to `config` and reports the outcome.
    async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse>;

    /// Returns the number of tasks currently tracked.
    async fn get_task_count(&self) -> Result<u32>;

    /// Reports the outcome of a finished drain.
    async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()>;
}
78
/// In-process registry of background tasks that can be drained on demand.
///
/// With the `grpc` feature enabled the context can additionally hold a client
/// for the runtime's wait-until service.
#[derive(Debug)]
pub struct WaitUntilContext {
    /// Identifier sent with every runtime request and attached to log events.
    application_id: String,
    /// Join handles of registered tasks, keyed by their generated task key.
    tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
    /// Source of the numeric suffix used in generated task keys.
    task_counter: AtomicU32,
    /// Client for the runtime's wait-until service; `None` when no runtime is attached.
    #[cfg(feature = "grpc")]
    grpc_client: Option<WaitUntilServiceClient<Channel>>,
    /// Set while a drain is in progress; tasks that start while it is set are rejected.
    draining: Arc<Mutex<bool>>,
}
95
96impl WaitUntilContext {
97 pub fn new(application_id: Option<String>) -> Self {
99 let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
100
101 Self {
102 application_id: app_id,
103 tasks: Arc::new(Mutex::new(HashMap::new())),
104 task_counter: AtomicU32::new(0),
105 #[cfg(feature = "grpc")]
106 grpc_client: None,
107 draining: Arc::new(Mutex::new(false)),
108 }
109 }
110
111 pub async fn from_env(application_id: Option<String>) -> Result<Self> {
114 let env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
115 Self::from_env_with_vars(application_id, &env_vars).await
116 }
117
118 pub async fn from_env_with_vars(
120 application_id: Option<String>,
121 env_vars: &std::collections::HashMap<String, String>,
122 ) -> Result<Self> {
123 let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
124
125 #[cfg(feature = "grpc")]
126 {
127 let bindings_mode = crate::get_bindings_mode_from_env(env_vars)?;
128
129 match bindings_mode {
130 crate::BindingsMode::Direct => {
131 return Ok(Self::new(Some(app_id)));
133 }
134 crate::BindingsMode::Grpc => {
135 let grpc_address =
137 env_vars.get("ALIEN_BINDINGS_GRPC_ADDRESS").ok_or_else(|| {
138 AlienError::new(ErrorData::EnvironmentVariableMissing {
139 variable_name: "ALIEN_BINDINGS_GRPC_ADDRESS".to_string(),
140 })
141 })?;
142
143 let channel = Self::create_grpc_channel(grpc_address.clone()).await?;
145 let grpc_client = WaitUntilServiceClient::new(channel);
146
147 return Ok(Self {
148 application_id: app_id,
149 tasks: Arc::new(Mutex::new(HashMap::new())),
150 task_counter: AtomicU32::new(0),
151 grpc_client: Some(grpc_client),
152 draining: Arc::new(Mutex::new(false)),
153 });
154 }
155 }
156 }
157
158 #[cfg(not(feature = "grpc"))]
159 {
160 Ok(Self::new(Some(app_id)))
161 }
162 }
163
164 #[cfg(feature = "grpc")]
167 async fn create_grpc_channel(grpc_address: String) -> Result<Channel> {
168 use std::time::Duration;
169
170 let endpoint_uri = if grpc_address.contains("://") {
172 grpc_address.clone()
173 } else {
174 format!("http://{}", grpc_address)
175 };
176
177 let endpoint = Channel::from_shared(endpoint_uri.clone())
178 .into_alien_error()
179 .context(ErrorData::GrpcConnectionFailed {
180 endpoint: endpoint_uri.clone(),
181 reason: "Invalid gRPC endpoint URI format".to_string(),
182 })?
183 .timeout(Duration::from_secs(300)) .connect_timeout(Duration::from_secs(5)) .http2_keep_alive_interval(Duration::from_secs(30)) .keep_alive_timeout(Duration::from_secs(10))
187 .keep_alive_while_idle(true); let channel = endpoint.connect().await.into_alien_error().context(
190 ErrorData::GrpcConnectionFailed {
191 endpoint: grpc_address.clone(),
192 reason: "Failed to establish gRPC connection".to_string(),
193 },
194 )?;
195
196 Ok(channel)
197 }
198
199 #[cfg(feature = "grpc")]
201 pub fn new_with_grpc_client(
202 application_id: Option<String>,
203 grpc_client: WaitUntilServiceClient<Channel>,
204 ) -> Self {
205 let app_id = application_id.unwrap_or_else(|| Uuid::new_v4().to_string());
206
207 Self {
208 application_id: app_id,
209 tasks: Arc::new(Mutex::new(HashMap::new())),
210 task_counter: AtomicU32::new(0),
211 grpc_client: Some(grpc_client),
212 draining: Arc::new(Mutex::new(false)),
213 }
214 }
215
216 #[cfg(feature = "grpc")]
218 pub fn set_grpc_client(&mut self, client: WaitUntilServiceClient<Channel>) {
219 self.grpc_client = Some(client);
220 }
221
222 pub fn application_id(&self) -> &str {
224 &self.application_id
225 }
226
227 pub async fn start_drain_listener(&self) -> Result<()> {
230 #[cfg(feature = "grpc")]
231 {
232 if let Some(mut client) = self.grpc_client.clone() {
233 let app_id = self.application_id.clone();
234 let context = self.clone_for_background();
235
236 tokio::spawn(async move {
237 loop {
238 debug!(app_id = %app_id, "Waiting for drain signal from runtime");
239
240 let request = WaitForDrainSignalRequest {
241 application_id: app_id.clone(),
242 timeout: Some(prost_types::Duration {
243 seconds: 300, nanos: 0,
245 }),
246 };
247
248 match client.wait_for_drain_signal(request).await {
249 Ok(response) => {
250 let resp = response.into_inner();
251 if resp.should_drain {
252 info!(
253 app_id = %app_id,
254 reason = %resp.drain_reason,
255 "Received drain signal from runtime"
256 );
257
258 let drain_timeout = resp
259 .drain_timeout
260 .map(|d| Duration::from_secs(d.seconds as u64))
261 .unwrap_or(Duration::from_secs(10));
262
263 let config = DrainConfig {
264 timeout: drain_timeout,
265 reason: resp.drain_reason,
266 };
267
268 match context.drain_all(config).await {
270 Ok(drain_response) => {
271 let complete_request = NotifyDrainCompleteRequest {
273 application_id: app_id.clone(),
274 tasks_drained: drain_response.tasks_drained,
275 success: drain_response.success,
276 error_message: drain_response.error_message,
277 };
278
279 if let Err(e) =
280 client.notify_drain_complete(complete_request).await
281 {
282 error!(app_id = %app_id, error = %e, "Failed to notify runtime of drain completion");
283 } else {
284 info!(app_id = %app_id, "Successfully notified runtime of drain completion");
285 }
286 }
287 Err(e) => {
288 error!(app_id = %app_id, error = %e, "Failed to drain tasks");
289 let complete_request = NotifyDrainCompleteRequest {
291 application_id: app_id.clone(),
292 tasks_drained: 0,
293 success: false,
294 error_message: Some(e.to_string()),
295 };
296 let _ = client
297 .notify_drain_complete(complete_request)
298 .await;
299 }
300 }
301 }
302 }
303 Err(e) => {
304 warn!(app_id = %app_id, error = %e, "Failed to wait for drain signal, retrying in 5 seconds");
305 tokio::time::sleep(Duration::from_secs(5)).await;
306 }
307 }
308 }
309 });
310 }
311 }
312
313 Ok(())
314 }
315
316 fn clone_for_background(&self) -> Self {
318 Self {
319 application_id: self.application_id.clone(),
320 tasks: Arc::clone(&self.tasks),
321 task_counter: AtomicU32::new(self.task_counter.load(Ordering::Relaxed)),
322 #[cfg(feature = "grpc")]
323 grpc_client: self.grpc_client.clone(),
324 draining: Arc::clone(&self.draining),
325 }
326 }
327
328 async fn notify_task_registered(&self, task_description: String) -> Result<()> {
330 #[cfg(feature = "grpc")]
331 {
332 if let Some(mut client) = self.grpc_client.clone() {
333 let request = NotifyTaskRegisteredRequest {
334 application_id: self.application_id.clone(),
335 task_description: Some(task_description),
336 };
337
338 client
339 .notify_task_registered(request)
340 .await
341 .into_alien_error()
342 .context(ErrorData::HttpRequestFailed {
343 url: "grpc://wait_until_service".to_string(),
344 method: "notify_task_registered".to_string(),
345 })?;
346 }
347 }
348
349 Ok(())
350 }
351}
352
353impl WaitUntilContext {
354 pub fn wait_until<F, Fut>(&self, task_fn: F) -> Result<()>
357 where
358 F: FnOnce() -> Fut + Send + 'static,
359 Fut: Future<Output = ()> + Send + 'static,
360 {
361 let task_id = self.task_counter.fetch_add(1, Ordering::Relaxed);
362 let task_key = format!("task_{}", task_id);
363 let task_description = format!("wait_until_task_{}", task_id);
364
365 let draining = self.draining.clone();
367 let tasks = self.tasks.clone();
368 let app_id = self.application_id.clone();
369 let task_key_clone = task_key.clone();
370
371 let handle = tokio::spawn(async move {
373 if *draining.lock().await {
375 warn!(app_id = %app_id, task_id = %task_key_clone, "Rejecting new task - currently draining");
376 return;
377 }
378
379 debug!(app_id = %app_id, task_id = %task_key_clone, "Starting wait_until task");
380
381 let future = task_fn();
382 future.await;
383
384 debug!(app_id = %app_id, task_id = %task_key_clone, "Completed wait_until task");
385
386 tasks.lock().await.remove(&task_key_clone);
388 });
389
390 {
392 let mut tasks_guard = futures::executor::block_on(self.tasks.lock());
393 tasks_guard.insert(task_key.clone(), handle);
394 }
395
396 let context_clone = self.clone_for_background();
398 tokio::spawn(async move {
399 if let Err(e) = context_clone.notify_task_registered(task_description).await {
400 warn!(app_id = %context_clone.application_id, task_id = %task_key, error = %e, "Failed to notify runtime of task registration");
401 }
402 });
403
404 Ok(())
405 }
406}
407
// `WaitUntilContext` implements `Binding` with an empty body; the `WaitUntil`
// trait requires this impl as its supertrait.
impl Binding for WaitUntilContext {}
409
410#[async_trait]
411impl WaitUntil for WaitUntilContext {
412 async fn wait_for_drain_signal(
413 &self,
414 timeout_duration: Option<Duration>,
415 ) -> Result<DrainConfig> {
416 #[cfg(feature = "grpc")]
417 {
418 if let Some(mut client) = self.grpc_client.clone() {
419 let timeout_proto = timeout_duration.map(|d| prost_types::Duration {
420 seconds: d.as_secs() as i64,
421 nanos: d.subsec_nanos() as i32,
422 });
423
424 let request = WaitForDrainSignalRequest {
425 application_id: self.application_id.clone(),
426 timeout: timeout_proto,
427 };
428
429 let response = client
430 .wait_for_drain_signal(request)
431 .await
432 .into_alien_error()
433 .context(ErrorData::HttpRequestFailed {
434 url: "grpc://wait_until_service".to_string(),
435 method: "wait_for_drain_signal".to_string(),
436 })?;
437
438 let resp = response.into_inner();
439 if resp.should_drain {
440 let drain_timeout = resp
441 .drain_timeout
442 .map(|d| Duration::from_secs(d.seconds as u64))
443 .unwrap_or(Duration::from_secs(10));
444
445 return Ok(DrainConfig {
446 timeout: drain_timeout,
447 reason: resp.drain_reason,
448 });
449 }
450 }
451 }
452
453 Err(AlienError::new(ErrorData::Other {
455 message: "No drain signal received or gRPC client not available".to_string(),
456 }))
457 }
458
459 async fn drain_all(&self, config: DrainConfig) -> Result<DrainResponse> {
460 info!(
461 app_id = %self.application_id,
462 reason = %config.reason,
463 timeout_secs = config.timeout.as_secs(),
464 "Starting to drain all wait_until tasks"
465 );
466
467 {
469 let mut draining_guard = self.draining.lock().await;
470 *draining_guard = true;
471 }
472
473 let tasks_to_drain = {
474 let mut tasks_guard = self.tasks.lock().await;
475 std::mem::take(&mut *tasks_guard) };
477
478 let task_count = tasks_to_drain.len() as u32;
479 info!(app_id = %self.application_id, task_count = task_count, "Draining tasks");
480
481 let mut success = true;
482 let mut error_messages = Vec::new();
483
484 let drain_result = timeout(config.timeout, async {
486 for (task_id, handle) in tasks_to_drain {
487 match handle.await {
488 Ok(_) => {
489 debug!(app_id = %self.application_id, task_id = %task_id, "Task completed successfully");
490 }
491 Err(e) => {
492 warn!(app_id = %self.application_id, task_id = %task_id, error = %e, "Task failed");
493 success = false;
494 error_messages.push(format!("Task {} failed: {}", task_id, e));
495 }
496 }
497 }
498 })
499 .await;
500
501 match drain_result {
502 Ok(_) => {
503 info!(app_id = %self.application_id, "All tasks drained successfully");
504 }
505 Err(_) => {
506 warn!(app_id = %self.application_id, "Drain timeout exceeded");
507 success = false;
508 error_messages.push("Drain timeout exceeded".to_string());
509 }
510 }
511
512 {
514 let mut draining_guard = self.draining.lock().await;
515 *draining_guard = false;
516 }
517
518 let error_message = if error_messages.is_empty() {
519 None
520 } else {
521 Some(error_messages.join("; "))
522 };
523
524 Ok(DrainResponse {
525 tasks_drained: task_count,
526 success,
527 error_message,
528 })
529 }
530
531 async fn get_task_count(&self) -> Result<u32> {
532 let tasks_guard = self.tasks.lock().await;
533 Ok(tasks_guard.len() as u32)
534 }
535
536 async fn notify_drain_complete(&self, response: DrainResponse) -> Result<()> {
537 #[cfg(feature = "grpc")]
538 {
539 if let Some(mut client) = self.grpc_client.clone() {
540 let request = NotifyDrainCompleteRequest {
541 application_id: self.application_id.clone(),
542 tasks_drained: response.tasks_drained,
543 success: response.success,
544 error_message: response.error_message,
545 };
546
547 client
548 .notify_drain_complete(request)
549 .await
550 .into_alien_error()
551 .context(ErrorData::HttpRequestFailed {
552 url: "grpc://wait_until_service".to_string(),
553 method: "notify_drain_complete".to_string(),
554 })?;
555 }
556 }
557
558 Ok(())
559 }
560}