use std::sync::Arc;
use std::time::Duration;
use reqwest::Method;
use serde::Deserialize;
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
use crate::client::Client;
use crate::error::{Error, Result};
use crate::types::{Action, ActionStatus};
/// Entry point for the environment's action endpoints.
///
/// Every request is issued under `/c/{env_id}/actions`, where `env_id` is taken
/// from the wrapped [`Client`].
#[derive(Clone)]
pub struct Actions {
/// Shared client providing the HTTP transport and the environment id.
client: Client,
}
impl Actions {
pub(crate) fn new(client: Client) -> Self {
Self { client }
}
pub async fn list(&self) -> Result<Vec<Action>> {
#[derive(Deserialize)]
struct Wrapper {
actions: Vec<Action>,
}
let env_id = &self.client.inner.env_id;
let resp: Wrapper = self
.client
.inner
.http
.request(Method::GET, &format!("/c/{env_id}/actions"), None::<&()>)
.await?;
Ok(resp.actions)
}
pub async fn ack(&self, action_id: &str) -> Result<Action> {
let env_id = &self.client.inner.env_id;
self.client
.inner
.http
.request(
Method::POST,
&format!("/c/{env_id}/actions/{action_id}/ack"),
None::<&()>,
)
.await
}
pub async fn update(
&self,
action_id: &str,
message: impl Into<String>,
data: Option<serde_json::Value>,
) -> Result<Action> {
let env_id = &self.client.inner.env_id;
let mut body = serde_json::json!({ "message": message.into() });
if let Some(d) = data {
body["data"] = d;
}
self.client
.inner
.http
.request(
Method::POST,
&format!("/c/{env_id}/actions/{action_id}/update"),
Some(&body),
)
.await
}
pub async fn complete(
&self,
action_id: &str,
result: Option<serde_json::Value>,
) -> Result<Action> {
self.resolve(action_id, "complete", result).await
}
pub async fn fail(&self, action_id: &str, result: Option<serde_json::Value>) -> Result<Action> {
self.resolve(action_id, "fail", result).await
}
async fn resolve(
&self,
action_id: &str,
suffix: &str,
result: Option<serde_json::Value>,
) -> Result<Action> {
let env_id = &self.client.inner.env_id;
let body = match result {
Some(r) => serde_json::json!({ "result": r }),
None => serde_json::json!({}),
};
self.client
.inner
.http
.request(
Method::POST,
&format!("/c/{env_id}/actions/{action_id}/{suffix}"),
Some(&body),
)
.await
}
pub async fn consume<F, Fut>(&self, handler: F, options: ConsumeOptions) -> Result<()>
where
F: Fn(Action, ActionContext) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<
Output = std::result::Result<
Option<serde_json::Value>,
Box<dyn std::error::Error + Send + Sync>,
>,
> + Send
+ 'static,
{
let handler = Arc::new(handler);
let cancel = options.cancel_token.unwrap_or_default();
let semaphore = Arc::new(Semaphore::new(options.concurrency.max(1)));
let mut empty_polls: u32 = 0;
let mut tasks = Vec::new();
loop {
if cancel.is_cancelled() {
break;
}
let actions = match self.list().await {
Ok(a) => a,
Err(err) => {
if let Some(on_err) = &options.on_error {
on_err(err, None);
}
if sleep_or_cancel(
backoff_delay(
empty_polls,
options.poll_interval,
options.max_poll_interval,
),
&cancel,
)
.await
{
break;
}
continue;
}
};
let pending: Vec<Action> = actions
.into_iter()
.filter(|a| a.status == ActionStatus::Pending)
.collect();
if pending.is_empty() {
empty_polls = empty_polls.saturating_add(1);
if sleep_or_cancel(
backoff_delay(
empty_polls,
options.poll_interval,
options.max_poll_interval,
),
&cancel,
)
.await
{
break;
}
continue;
}
empty_polls = 0;
for action in pending {
if cancel.is_cancelled() {
break;
}
let permit = semaphore.clone().acquire_owned().await.expect("semaphore");
let actions = self.clone();
let handler = handler.clone();
let cancel = cancel.clone();
let on_error = options.on_error.clone();
tasks.push(tokio::spawn(async move {
process_action(actions, action, handler, cancel, on_error).await;
drop(permit);
}));
}
}
for task in tasks {
let _ = task.await;
}
Ok(())
}
}
/// Per-invocation context handed to a `consume` handler alongside the action.
pub struct ActionContext {
/// Used to post updates for this context's action.
actions: Actions,
/// Id of the action this context was created for.
action_id: String,
/// The consumer's cancellation token: either `ConsumeOptions::cancel_token` or,
/// when that was `None`, the token `consume` created for itself.
pub cancel_token: CancellationToken,
}
impl ActionContext {
    /// Sends an update (a message plus optional `data`) for the action this
    /// context was created for.
    ///
    /// This is a thin wrapper over [`Actions::update`] using the stored action id.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Actions::update`].
    pub async fn update(
        &self,
        message: impl Into<String>,
        data: Option<serde_json::Value>,
    ) -> Result<Action> {
        let Self {
            actions, action_id, ..
        } = self;
        actions.update(action_id, message, data).await
    }
}
/// Callback that `Actions::consume` invokes for errors it handles internally
/// instead of returning.
///
/// The second argument is the action involved. It is `None` when the error came
/// from listing actions, and `Some` for ack, handler, complete and fail errors.
pub type ErrorCallback = Arc<dyn Fn(Error, Option<Action>) + Send + Sync>;
/// Tuning knobs for `Actions::consume`.
#[derive(Clone)]
pub struct ConsumeOptions {
/// Base delay between polls; the delay grows from here while polls keep
/// finding nothing pending (see `backoff_delay`). Default: 15 s.
pub poll_interval: Duration,
/// Cap for the growing delay once backoff starts (after more than three
/// consecutive empty polls). Default: 60 s.
pub max_poll_interval: Duration,
/// Maximum number of actions processed at the same time; 0 is treated as 1.
/// Default: 1.
pub concurrency: usize,
/// Token that stops the consumer when cancelled. If `None`, `consume` creates
/// its own token, which handlers still see via `ActionContext::cancel_token`.
pub cancel_token: Option<CancellationToken>,
/// Receives errors that `consume` handles internally instead of returning.
pub on_error: Option<ErrorCallback>,
}
impl Default for ConsumeOptions {
fn default() -> Self {
Self {
poll_interval: Duration::from_secs(15),
max_poll_interval: Duration::from_secs(60),
concurrency: 1,
cancel_token: None,
on_error: None,
}
}
}
async fn process_action<F, Fut>(
actions: Actions,
action: Action,
handler: Arc<F>,
cancel: CancellationToken,
on_error: Option<ErrorCallback>,
) where
F: Fn(Action, ActionContext) -> Fut + Send + Sync,
Fut: std::future::Future<
Output = std::result::Result<
Option<serde_json::Value>,
Box<dyn std::error::Error + Send + Sync>,
>,
> + Send,
{
if let Err(err) = actions.ack(&action.id).await {
if err.is_conflict() {
return;
}
if let Some(cb) = &on_error {
cb(err, Some(action));
}
return;
}
let ctx = ActionContext {
actions: actions.clone(),
action_id: action.id.clone(),
cancel_token: cancel.clone(),
};
let result = handler(action.clone(), ctx).await;
let _ = cancel;
match result {
Ok(value) => {
if let Err(err) = actions.complete(&action.id, value).await {
if let Some(cb) = &on_error {
cb(err, Some(action));
}
}
}
Err(handler_err) => {
let message = handler_err.to_string();
if let Some(cb) = &on_error {
cb(
Error::Api {
status: 0,
message: message.clone(),
body: None,
},
Some(action.clone()),
);
}
if let Err(fail_err) = actions
.fail(&action.id, Some(serde_json::json!({ "error": message })))
.await
{
if let Some(cb) = &on_error {
cb(fail_err, Some(action));
}
}
}
}
}
/// Waits for `d` unless `cancel` fires first.
///
/// Returns `true` if the wait ended because the token was cancelled, and
/// `false` if the full duration elapsed.
async fn sleep_or_cancel(d: Duration, cancel: &CancellationToken) -> bool {
    let timer = tokio::time::sleep(d);
    let cancelled = cancel.cancelled();
    tokio::select! {
        () = cancelled => true,
        () = timer => false,
    }
}
/// Delay before the next poll after `empty_polls` consecutive polls that found
/// no pending work.
///
/// Up to three empty polls wait exactly `base`. From the fourth on, the delay
/// doubles each time (`base * 2^(empty_polls - 3)`, with the exponent capped at
/// 20) and is clamped to `max`.
///
/// The result is never shorter than `base`. If `max < base`, the cap is raised
/// to `base`, so backing off can never make polling faster than the configured
/// base interval.
pub(crate) fn backoff_delay(empty_polls: u32, base: Duration, max: Duration) -> Duration {
    if empty_polls <= 3 {
        return base;
    }
    // Exponent capped at 20, so the multiplier (at most 2^20) always fits in a u32.
    let shift = (empty_polls - 3).min(20);
    let factor = 1u32 << shift;
    let ceiling = max.max(base);
    base.saturating_mul(factor).min(ceiling)
}