use crate::batch;
use crate::client::{BatchOperation, CosmosDBClient};
use crate::containers;
use crate::errors;
use crate::leases::{InMemoryLeaseProvider, LeaseProvider};
use crate::models::*;
use crate::outbox;
use crate::outbox::OutboxFaultInjector;
use crate::query;
use duroxide::providers::*;
use duroxide::{Event, EventKind, TagFilter};
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
/// Maximum number of candidate instances `fetch_orchestration_item` tries to lock per fetch.
const MAX_LOCK_RETRIES: usize = 20;

/// Connection and tuning settings for a [`CosmosDBProvider`].
#[derive(Debug, Clone)]
pub struct CosmosDBProviderConfig {
    /// Cosmos DB account endpoint.
    pub endpoint: String,
    /// Account key used to authenticate requests.
    pub key: String,
    /// Database name.
    pub database: String,
    /// Container name.
    pub container: String,
    /// Slot count for the in-memory orchestration lease provider.
    pub orch_concurrency: u32,
    /// Slot count for the in-memory worker lease provider.
    pub worker_concurrency: u32,
    /// Interval handed to the outbox reconciler.
    pub reconciler_interval: Duration,
    /// Age threshold handed to the outbox reconciler (presumably the minimum intent age before
    /// the reconciler acts on it — confirm against `outbox::start_reconciler`).
    pub reconciler_age_threshold: Duration,
}

impl Default for CosmosDBProviderConfig {
    /// Returns a config with empty credentials, `"duroxide"` as both database and container,
    /// single-slot concurrency, and a two-second reconciler interval and age threshold.
    fn default() -> Self {
        let two_seconds = Duration::from_millis(2_000);
        Self {
            reconciler_interval: two_seconds,
            reconciler_age_threshold: two_seconds,
            orch_concurrency: 1,
            worker_concurrency: 1,
            database: String::from("duroxide"),
            container: String::from("duroxide"),
            endpoint: String::new(),
            key: String::new(),
        }
    }
}
/// Cosmos DB-backed duroxide provider. Cloning is cheap: every clone shares one
/// `CosmosDBProviderInner` through an `Arc`.
#[derive(Clone)]
pub struct CosmosDBProvider {
inner: Arc<CosmosDBProviderInner>,
}
/// State shared by all clones of a `CosmosDBProvider`.
// NOTE: fields drop in declaration order; keep that in mind before reordering.
struct CosmosDBProviderInner {
// REST client bound to the configured database and container.
client: CosmosDBClient,
// Lease slots acquired by fetch_orchestration_item.
orch_leases: Box<dyn LeaseProvider>,
// Lease slots acquired by fetch_work_item.
worker_leases: Box<dyn LeaseProvider>,
// Shared with the outbox reconciler; cancelled by cleanup() and when this value is dropped.
cancel: CancellationToken,
// Kept only to hold the reconciler task's handle; dropping a tokio JoinHandle detaches the
// task rather than aborting it, so shutdown relies on `cancel`.
_reconciler_handle: Option<tokio::task::JoinHandle<()>>,
// Passed to outbox delivery in ack_orchestration_item; None unless set via
// set_outbox_fault_injector (presumably a test hook).
outbox_fault_injector: Option<OutboxFaultInjector>,
}
impl CosmosDBProvider {
pub async fn new(endpoint: &str, key: &str, database: &str) -> Result<Self, ProviderError> {
let config = CosmosDBProviderConfig {
endpoint: endpoint.to_string(),
key: key.to_string(),
database: database.to_string(),
..Default::default()
};
Self::new_with_config(config).await
}
pub async fn new_with_config(config: CosmosDBProviderConfig) -> Result<Self, ProviderError> {
let client = CosmosDBClient::new(
&config.endpoint,
&config.key,
&config.database,
&config.container,
)
.map_err(|e| ProviderError::permanent("new", e))?;
containers::ensure_infrastructure(&client).await?;
let cancel = CancellationToken::new();
let reconciler_handle = outbox::start_reconciler(
client.clone(),
config.reconciler_interval,
config.reconciler_age_threshold,
cancel.clone(),
);
let orch_leases = Box::new(InMemoryLeaseProvider::new(config.orch_concurrency));
let worker_leases = Box::new(InMemoryLeaseProvider::new(config.worker_concurrency));
Ok(Self {
inner: Arc::new(CosmosDBProviderInner {
client,
orch_leases,
worker_leases,
cancel,
_reconciler_handle: Some(reconciler_handle),
outbox_fault_injector: None,
}),
})
}
pub async fn new_with_container(
endpoint: &str,
key: &str,
database: &str,
container: &str,
) -> Result<Self, ProviderError> {
let config = CosmosDBProviderConfig {
endpoint: endpoint.to_string(),
key: key.to_string(),
database: database.to_string(),
container: container.to_string(),
..Default::default()
};
Self::new_with_config(config).await
}
fn client(&self) -> &CosmosDBClient {
&self.inner.client
}
pub fn set_outbox_fault_injector(&mut self, injector: OutboxFaultInjector) {
let inner = Arc::get_mut(&mut self.inner)
.expect("Cannot set fault injector after provider has been cloned");
inner.outbox_fault_injector = Some(injector);
}
pub async fn cleanup(&self) -> Result<(), ProviderError> {
self.inner.cancel.cancel();
self.client().delete_container().await
}
async fn read_instance(
&self,
instance_id: &str,
) -> Result<Option<InstanceDocument>, ProviderError> {
let doc_id = InstanceDocument::doc_id(instance_id);
let resp = self.client().read_document(&doc_id, instance_id).await?;
if errors::is_not_found(resp.status) {
return Ok(None);
}
if !resp.is_success() {
return Err(errors::map_cosmosdb_error(
"read_instance",
resp.status,
&resp.body,
));
}
let mut inst: InstanceDocument = serde_json::from_str(&resp.body).map_err(|e| {
ProviderError::permanent("read_instance", format!("Deserialize error: {e}"))
})?;
inst.etag = resp.etag;
Ok(Some(inst))
}
async fn try_lock_instance(
&self,
instance_id: &str,
lock_timeout: Duration,
now: u64,
work_item_json: Option<&str>,
) -> Result<Option<(InstanceDocument, String)>, ProviderError> {
let inst = match self.read_instance(instance_id).await? {
Some(i) => i,
None => {
let (orch_name, orch_version) = if let Some(json) = work_item_json {
match serde_json::from_str::<WorkItem>(json) {
Ok(WorkItem::StartOrchestration {
orchestration,
version,
..
}) => (orchestration, version.unwrap_or_default()),
_ => (String::new(), String::new()),
}
} else {
(String::new(), String::new())
};
let lock_token = uuid::Uuid::new_v4().to_string();
let locked_until = now + lock_timeout.as_millis() as u64;
let mut new_inst =
InstanceDocument::new(instance_id, &orch_name, &orch_version, 1, None, now);
new_inst.lock_token = Some(lock_token.clone());
new_inst.locked_until = Some(locked_until);
let doc_json = serde_json::to_value(&new_inst).map_err(|e| {
ProviderError::permanent("try_lock_instance", format!("Serialize error: {e}"))
})?;
let resp = self
.client()
.create_document(instance_id, &doc_json)
.await?;
if resp.is_success() {
new_inst.etag = resp.etag;
return Ok(Some((new_inst, lock_token)));
} else if errors::is_conflict(resp.status) {
match self.read_instance(instance_id).await? {
Some(i) => i,
None => return Ok(None),
}
} else {
return Err(errors::map_cosmosdb_error(
"try_lock_instance",
resp.status,
&resp.body,
));
}
}
};
if let Some(locked_until) = inst.locked_until {
if locked_until > now {
return Ok(None); }
}
let etag = inst.etag.clone().unwrap_or_default();
let lock_token = uuid::Uuid::new_v4().to_string();
let locked_until = now + lock_timeout.as_millis() as u64;
let mut updated = inst.clone();
updated.lock_token = Some(lock_token.clone());
updated.locked_until = Some(locked_until);
updated.updated_at = now;
let doc_json = serde_json::to_value(&updated).map_err(|e| {
ProviderError::permanent("try_lock_instance", format!("Serialize error: {e}"))
})?;
let resp = self
.client()
.replace_document(&updated.id, instance_id, &doc_json, Some(&etag))
.await?;
if resp.is_success() {
Ok(Some((updated, lock_token)))
} else if errors::is_precondition_failed(resp.status) || errors::is_conflict(resp.status) {
Ok(None) } else {
Err(errors::map_cosmosdb_error(
"try_lock_instance",
resp.status,
&resp.body,
))
}
}
async fn unlock_instance(&self, instance_id: &str) -> Result<(), ProviderError> {
let Some(inst) = self.read_instance(instance_id).await? else {
return Ok(());
};
let etag = inst.etag.clone().unwrap_or_default();
let mut updated = inst;
updated.lock_token = None;
updated.locked_until = None;
let doc_json = serde_json::to_value(&updated).map_err(|e| {
ProviderError::permanent("unlock_instance", format!("Serialize error: {e}"))
})?;
let resp = self
.client()
.replace_document(&updated.id, instance_id, &doc_json, Some(&etag))
.await?;
if resp.is_success() || errors::is_precondition_failed(resp.status) {
Ok(())
} else {
Err(errors::map_cosmosdb_error(
"unlock_instance",
resp.status,
&resp.body,
))
}
}
}
impl Drop for CosmosDBProviderInner {
    /// Runs when the last provider clone releases the shared state: cancels the token that was
    /// handed to the outbox reconciler (presumably signalling it to stop).
    fn drop(&mut self) {
        let token = &self.cancel;
        token.cancel();
    }
}
#[async_trait::async_trait]
impl Provider for CosmosDBProvider {
/// Provider identifier: the fixed string `"duroxide-cdb"`.
fn name(&self) -> &str {
    const PROVIDER_NAME: &str = "duroxide-cdb";
    PROVIDER_NAME
}
/// Crate version captured at compile time from `CARGO_PKG_VERSION`.
fn version(&self) -> &str {
    const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
    CRATE_VERSION
}
async fn fetch_orchestration_item(
&self,
lock_timeout: Duration,
_poll_timeout: Duration,
filter: Option<&DispatcherCapabilityFilter>,
) -> Result<Option<(OrchestrationItem, String, u32)>, ProviderError> {
let caller_id = task_id_u64();
let my_slots = self.inner.orch_leases.acquire_slots(caller_id).await;
let now = now_ms();
let (min_packed, max_packed) = if let Some(f) = filter {
if f.supported_duroxide_versions.is_empty() {
return Ok(None);
}
let min = f
.supported_duroxide_versions
.iter()
.map(|r| pack_semver(&r.min))
.min()
.unwrap();
let max = f
.supported_duroxide_versions
.iter()
.map(|r| pack_semver(&r.max))
.max()
.unwrap();
(Some(min), Some(max))
} else {
(None, None)
};
let mut excluded = Vec::new();
for _attempt in 0..MAX_LOCK_RETRIES {
let candidate = query::find_candidate_orch_item(
self.client(),
now,
&my_slots,
None, None,
&excluded,
)
.await?;
let Some(candidate) = candidate else {
return Ok(None);
};
let instance_id = &candidate.instance_id;
let lock_result = self
.try_lock_instance(instance_id, lock_timeout, now, Some(&candidate.work_item))
.await?;
let Some((locked_instance, lock_token)) = lock_result else {
excluded.push(instance_id.to_string());
continue;
};
if let (Some(min_v), Some(max_v)) = (min_packed, max_packed) {
if let Some(pinned) = locked_instance.pinned_duroxide_version_packed {
if pinned < min_v || pinned > max_v {
self.unlock_instance(instance_id).await?;
excluded.push(instance_id.to_string());
continue;
}
}
}
let messages = query::collect_orch_messages(self.client(), instance_id, now).await?;
if messages.is_empty() {
self.unlock_instance(instance_id).await?;
excluded.push(instance_id.to_string());
continue;
}
let max_attempt = messages.iter().map(|m| m.attempt_count).max().unwrap_or(0);
for msg in &messages {
let mut updated_msg = msg.clone();
updated_msg.lock_token = Some(lock_token.clone());
updated_msg.locked_until = Some(now + lock_timeout.as_millis() as u64);
updated_msg.attempt_count += 1;
let doc_json = serde_json::to_value(&updated_msg).map_err(|e| {
ProviderError::permanent(
"fetch_orchestration_item",
format!("Serialize error: {e}"),
)
})?;
let _ = self
.client()
.replace_document(&msg.id, instance_id, &doc_json, msg.etag.as_deref())
.await?;
}
let attempt_count = (max_attempt + 1) as u32;
let work_items: Vec<WorkItem> = messages
.iter()
.map(|m| {
serde_json::from_str(&m.work_item).map_err(|e| {
ProviderError::permanent(
"fetch_orchestration_item",
format!("Failed to deserialize work item: {e}"),
)
})
})
.collect::<Result<Vec<_>, _>>()?;
let execution_id = locked_instance.current_execution_id;
let history_docs =
query::fetch_history(self.client(), instance_id, execution_id).await?;
let (history, history_error) = {
let mut events = Vec::new();
let mut error = None;
for doc in &history_docs {
match serde_json::from_str::<Event>(&doc.event_data) {
Ok(event) => events.push(event),
Err(e) => {
error = Some(format!(
"Failed to deserialize history event {}: {e}",
doc.event_id
));
events.clear();
break;
}
}
}
(events, error)
};
let kv_docs = query::query_by_type_in_partition(
self.client(),
instance_id,
DOC_TYPE_KV,
)
.await?;
let kv_snapshot: std::collections::HashMap<String, duroxide::providers::KvEntry> =
kv_docs
.iter()
.filter_map(|doc| {
let key = doc.get("key")?.as_str()?.to_string();
let value = doc.get("value")?.as_str()?.to_string();
let last_updated_at_ms = doc
.get("lastUpdatedAtMs")
.and_then(|v| v.as_u64())
.unwrap_or(0);
Some((
key,
duroxide::providers::KvEntry {
value,
last_updated_at_ms,
},
))
})
.collect();
if locked_instance.orchestration_name.is_empty()
&& history.is_empty()
&& work_items
.iter()
.all(|m| matches!(m, WorkItem::QueueMessage { .. }))
{
let message_count = work_items.len();
tracing::warn!(
target = "duroxide::providers::cosmosdb",
instance = %instance_id,
message_count,
"Dropping orphan queue messages — events enqueued before orchestration started are not supported"
);
self.ack_orchestration_item(
&lock_token,
execution_id,
vec![],
vec![],
vec![],
ExecutionMetadata::default(),
vec![],
)
.await?;
return Ok(None);
}
let item = OrchestrationItem {
instance: instance_id.to_string(),
orchestration_name: locked_instance.orchestration_name.clone(),
execution_id,
version: locked_instance.orchestration_version.clone(),
history,
messages: work_items,
history_error,
kv_snapshot,
};
return Ok(Some((item, lock_token, attempt_count)));
}
Ok(None)
}
/// Commits an orchestration turn for the instance holding `lock_token`.
///
/// After validating the lock, it assembles one batch (`batch::build_ack_batch`, executed with
/// `batch::execute_batch`) that does the following:
/// - deletes the messages stamped with this lock token;
/// - appends `history_delta`;
/// - enqueues same-partition worker and orchestrator items;
/// - records outbox intents for items bound to other partitions;
/// - applies this turn's KV changes;
/// - rewrites the instance document with the lock cleared and `metadata` applied.
///
/// Worker items for activities listed in `cancelled_activities` are never enqueued. Already
/// queued worker documents for those activities are deleted best-effort, in their own
/// partition, after the batch. Outbox intents are then delivered best-effort.
///
/// KV events are replayed in order to find each key's final state before any operation is
/// emitted. This way a key set and then cleared (individually or via clear-all) in the same turn
/// ends up cleared, and no document gets two operations in one batch.
///
/// # Errors
/// Permanent error if `lock_token` matches no instance, the lock has expired, or a document
/// fails to serialize. Query and batch errors are propagated.
async fn ack_orchestration_item(
    &self,
    lock_token: &str,
    execution_id: u64,
    history_delta: Vec<Event>,
    worker_items: Vec<WorkItem>,
    orchestrator_items: Vec<WorkItem>,
    metadata: ExecutionMetadata,
    cancelled_activities: Vec<ScheduledActivityIdentifier>,
) -> Result<(), ProviderError> {
    // Non-capturing closure, therefore `Copy` and reusable for every serialization below.
    let ser_err = |e: serde_json::Error| {
        ProviderError::permanent("ack_orchestration_item", format!("Serialize error: {e}"))
    };
    let instance = query::find_instance_by_lock_token(self.client(), lock_token)
        .await?
        .ok_or_else(|| {
            ProviderError::permanent(
                "ack_orchestration_item",
                format!("Invalid lock token or lock expired: {lock_token}"),
            )
        })?;
    let instance_id = &instance.instance_id;
    let now = now_ms();
    if let Some(locked_until) = instance.locked_until {
        if locked_until <= now {
            return Err(ProviderError::permanent(
                "ack_orchestration_item",
                format!("Lock has expired for instance {instance_id}"),
            ));
        }
    }

    let cancelled_set: std::collections::HashSet<(String, u64, u64)> = cancelled_activities
        .iter()
        .map(|c| (c.instance.clone(), c.execution_id, c.activity_id))
        .collect();
    let mut same_partition_worker = Vec::new();
    let mut same_partition_orch = Vec::new();
    let mut cross_partition_intents = Vec::new();

    for (seq, item) in worker_items.iter().enumerate() {
        // Activities cancelled in this same turn are never enqueued.
        if let WorkItem::ActivityExecute {
            instance: act_instance,
            execution_id: act_execution,
            id,
            ..
        } = item
        {
            if cancelled_set.contains(&(act_instance.clone(), *act_execution, *id)) {
                continue;
            }
        }
        let target_instance = work_item_instance(item);
        let item_json = serde_json::to_string(item).map_err(ser_err)?;
        // Activity routing fields copied onto the queue document; `None` for other item kinds.
        let (exec_id, activity_id, session_id, tag) = match item {
            WorkItem::ActivityExecute {
                execution_id,
                id,
                session_id,
                tag,
                ..
            } => (Some(*execution_id), Some(*id), session_id.clone(), tag.clone()),
            _ => (None, None, None, None),
        };
        let doc = QueueItemDocument::new_worker_queue(
            target_instance,
            item_json,
            exec_id,
            activity_id,
            session_id,
            tag,
            now,
        );
        if target_instance == instance_id {
            same_partition_worker.push(serde_json::to_value(&doc).map_err(ser_err)?);
        } else {
            // Other partitions cannot join this batch; record an outbox intent instead.
            let target_json = serde_json::to_string(&doc).map_err(ser_err)?;
            let idem_key = idempotency_key(instance_id, execution_id, seq as u64);
            cross_partition_intents.push(OutboxIntentDocument::new(
                instance_id,
                target_instance,
                DOC_TYPE_WORKER_QUEUE,
                target_json,
                idem_key,
                now,
            ));
        }
    }

    for (seq, item) in orchestrator_items.iter().enumerate() {
        let target_instance = work_item_instance(item);
        let item_json = serde_json::to_string(item).map_err(ser_err)?;
        // Timers become visible at their fire time (never in the past); everything else now.
        let visible_at = match item {
            WorkItem::TimerFired { fire_at_ms, .. } => (*fire_at_ms).max(now),
            _ => now,
        };
        let doc = QueueItemDocument::new_orch_queue(target_instance, item_json, visible_at, now);
        if target_instance == instance_id {
            same_partition_orch.push(serde_json::to_value(&doc).map_err(ser_err)?);
        } else {
            let target_json = serde_json::to_string(&doc).map_err(ser_err)?;
            // Offset past the worker items so the two intent sequences never share a key.
            let idem_key =
                idempotency_key(instance_id, execution_id, (worker_items.len() + seq) as u64);
            cross_partition_intents.push(OutboxIntentDocument::new(
                instance_id,
                target_instance,
                DOC_TYPE_ORCH_QUEUE,
                target_json,
                idem_key,
                now,
            ));
        }
    }

    let locked_messages =
        query::find_items_by_lock_token(self.client(), lock_token, DOC_TYPE_ORCH_QUEUE).await?;
    let messages_to_delete: Vec<String> =
        locked_messages.iter().map(|m| m.id.clone()).collect();

    // Already-queued worker documents for cancelled activities, remembered together with the
    // partition they were found in so they are deleted from that same partition.
    let cancelled_sql = format!(
        "SELECT c.id FROM c WHERE c.instanceId = @instanceId AND c.type = '{}' \
         AND c.executionId = @execId AND c.activityId = @activityId",
        DOC_TYPE_WORKER_QUEUE
    );
    let mut cancelled_docs: Vec<(String, String)> = Vec::new();
    for cancelled in &cancelled_activities {
        let params = vec![
            crate::client::QueryParameter::new(
                "@instanceId",
                serde_json::json!(&cancelled.instance),
            ),
            crate::client::QueryParameter::new(
                "@execId",
                serde_json::json!(cancelled.execution_id),
            ),
            crate::client::QueryParameter::new(
                "@activityId",
                serde_json::json!(cancelled.activity_id),
            ),
        ];
        let results = self
            .client()
            .query(&cancelled_sql, params, Some(&cancelled.instance))
            .await?;
        for doc in results {
            if let Some(id) = doc.get("id").and_then(|v| v.as_str()) {
                cancelled_docs.push((id.to_string(), cancelled.instance.clone()));
            }
        }
    }

    let history_entries: Vec<(u64, String)> = history_delta
        .iter()
        .map(|event| Ok((event.event_id, serde_json::to_string(event).map_err(ser_err)?)))
        .collect::<Result<Vec<_>, ProviderError>>()?;

    let mut updated_instance = instance.clone();
    updated_instance.current_execution_id = execution_id;
    updated_instance.updated_at = now;
    updated_instance.lock_token = None;
    updated_instance.locked_until = None;
    if let Some(status) = &metadata.status {
        updated_instance.status = status.clone();
    }
    if let Some(output) = &metadata.output {
        updated_instance.output = Some(output.clone());
    }
    if let Some(name) = &metadata.orchestration_name {
        updated_instance.orchestration_name = name.clone();
    }
    if let Some(version) = &metadata.orchestration_version {
        updated_instance.orchestration_version = version.clone();
    }
    if let Some(parent) = &metadata.parent_instance_id {
        updated_instance.parent_instance_id = Some(parent.clone());
    }
    if let Some(pinned) = &metadata.pinned_duroxide_version {
        updated_instance.pinned_duroxide_version_packed = Some(pack_semver(pinned));
    }
    // Only the last custom-status event of the turn matters; each change bumps the version.
    let custom_status_from_delta = history_delta.iter().rev().find_map(|e| match &e.kind {
        EventKind::CustomStatusUpdated { status } => Some(status.clone()),
        _ => None,
    });
    if let Some(new_status) = custom_status_from_delta {
        updated_instance.custom_status = new_status;
        updated_instance.custom_status_version += 1;
    }
    let instance_json = serde_json::to_value(&updated_instance).map_err(ser_err)?;
    let outbox_json: Vec<serde_json::Value> = cross_partition_intents
        .iter()
        .map(|intent| serde_json::to_value(intent).map_err(ser_err))
        .collect::<Result<Vec<_>, _>>()?;

    // Replay this turn's KV events to get each key's final state: Some(value) = set,
    // None = cleared. A clear-all discards everything recorded before it.
    let mut cleared_all = false;
    let mut kv_final = std::collections::BTreeMap::new();
    for event in &history_delta {
        match &event.kind {
            EventKind::KeyValueSet {
                key,
                value,
                last_updated_at_ms,
            } => {
                kv_final.insert(key, Some((value, *last_updated_at_ms)));
            }
            EventKind::KeyValueCleared { key } => {
                kv_final.insert(key, None);
            }
            EventKind::KeyValuesCleared => {
                cleared_all = true;
                kv_final.clear();
            }
            _ => {}
        }
    }
    let mut kv_ops: Vec<BatchOperation> = Vec::new();
    if cleared_all {
        // Keys written again after the clear are overwritten by their Upsert below; deleting
        // them as well would put two operations on one document in the same batch.
        let rewritten: std::collections::HashSet<String> = kv_final
            .iter()
            .filter(|(_, state)| state.is_some())
            .map(|(&key, _)| KeyValueDocument::doc_id(instance_id, key))
            .collect();
        let existing_kv =
            query::query_by_type_in_partition(self.client(), instance_id, DOC_TYPE_KV).await?;
        for doc in &existing_kv {
            if let Some(doc_id) = doc.get("id").and_then(|v| v.as_str()) {
                if !rewritten.contains(doc_id) {
                    kv_ops.push(BatchOperation::Delete {
                        id: doc_id.to_string(),
                    });
                }
            }
        }
    }
    for (&key, state) in &kv_final {
        match state {
            Some((value, last_updated_at_ms)) => {
                let kv_doc = KeyValueDocument::new(
                    instance_id,
                    key,
                    *value,
                    execution_id,
                    *last_updated_at_ms,
                );
                kv_ops.push(BatchOperation::Upsert {
                    body: serde_json::to_value(&kv_doc).map_err(ser_err)?,
                });
            }
            // Any stored document for this key was already removed by the clear-all above.
            None if cleared_all => {}
            None => {
                // Only delete a document that currently exists (a Delete of a missing
                // document would presumably fail the whole batch).
                let doc_id = KeyValueDocument::doc_id(instance_id, key);
                let resp = self.client().read_document(&doc_id, instance_id).await?;
                if resp.is_success() {
                    kv_ops.push(BatchOperation::Delete { id: doc_id });
                }
            }
        }
    }

    let ops = batch::build_ack_batch(
        instance_id,
        execution_id,
        lock_token,
        &messages_to_delete,
        &history_entries,
        same_partition_worker,
        same_partition_orch,
        outbox_json,
        kv_ops,
        // NOTE(review): always empty from this caller — confirm meaning in build_ack_batch.
        &[],
        instance_json,
    );
    batch::execute_batch(self.client(), instance_id, ops).await?;

    // Best effort: delete from the partition the document was found in.
    for (doc_id, partition) in &cancelled_docs {
        let _ = self.client().delete_document(doc_id, partition).await;
    }
    outbox::deliver_intents_best_effort(
        self.client(),
        &cross_partition_intents,
        self.inner.outbox_fault_injector.as_ref(),
    )
    .await;
    Ok(())
}
/// Releases an orchestration lock without committing the turn.
///
/// The messages stamped with `lock_token` are released first. With `delay` they become visible
/// again only at `now + delay`; with `ignore_attempt` their attempt count is rolled back by one.
/// Only after that is the instance lock cleared. The previous order (instance first) let
/// another dispatcher lock the instance and re-stamp these messages before the delay and
/// attempt adjustments were written, which silently discarded them.
///
/// Individual writes are ETag-conditional and best-effort: failures are ignored. The instance
/// lock is released even if looking up the messages fails, and that lookup error is then
/// returned.
///
/// # Errors
/// Permanent error if no instance holds `lock_token`. Otherwise propagates an error from the
/// message lookup.
async fn abandon_orchestration_item(
    &self,
    lock_token: &str,
    delay: Option<Duration>,
    ignore_attempt: bool,
) -> Result<(), ProviderError> {
    let now = now_ms();
    let instance = query::find_instance_by_lock_token(self.client(), lock_token).await?;
    let inst = instance.ok_or_else(|| {
        ProviderError::permanent("abandon_orchestration_item", "Invalid lock token")
    })?;

    // 1. Release (and optionally delay / un-count) the messages while we still hold the
    //    instance lock.
    let release_messages = match query::find_items_by_lock_token(
        self.client(),
        lock_token,
        DOC_TYPE_ORCH_QUEUE,
    )
    .await
    {
        Ok(messages) => {
            for msg in &messages {
                let mut released = msg.clone();
                released.lock_token = None;
                released.locked_until = None;
                if let Some(d) = delay {
                    released.visible_at = now + d.as_millis() as u64;
                }
                if ignore_attempt && released.attempt_count > 0 {
                    released.attempt_count -= 1;
                }
                let doc_json = serde_json::to_value(&released).unwrap();
                let _ = self
                    .client()
                    .replace_document(&msg.id, &msg.instance_id, &doc_json, msg.etag.as_deref())
                    .await;
            }
            Ok(())
        }
        Err(e) => Err(e),
    };

    // 2. Only now make the instance lockable again.
    let mut updated = inst.clone();
    updated.lock_token = None;
    updated.locked_until = None;
    updated.updated_at = now;
    let doc_json = serde_json::to_value(&updated).unwrap();
    let _ = self
        .client()
        .replace_document(
            &updated.id,
            &inst.instance_id,
            &doc_json,
            inst.etag.as_deref(),
        )
        .await;

    release_messages
}
/// Reads the history of the instance's current execution. If there is no instance document,
/// execution 1 is read.
async fn read(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
    let current_execution = match self.read_instance(instance).await? {
        Some(doc) => doc.current_execution_id,
        None => 1,
    };
    self.read_with_execution(instance, current_execution).await
}
/// Loads and deserializes every stored history event of `execution_id`, in the order returned
/// by `query::fetch_history`.
///
/// # Errors
/// Fails permanently on the first event that does not deserialize.
async fn read_with_execution(
    &self,
    instance: &str,
    execution_id: u64,
) -> Result<Vec<Event>, ProviderError> {
    query::fetch_history(self.client(), instance, execution_id)
        .await?
        .into_iter()
        .map(|doc| {
            serde_json::from_str::<Event>(&doc.event_data).map_err(|e| {
                ProviderError::permanent(
                    "read_with_execution",
                    format!("Failed to deserialize event: {e}"),
                )
            })
        })
        .collect()
}
/// Appends `new_events` after the highest stored event id of this execution, starting at 0
/// when the execution has no history. One history document is written per event.
///
/// Each event is stored under a freshly assigned sequential id; the id carried inside the event
/// itself is not used here. A 409 conflict (document already present) is treated as success.
async fn append_with_execution(
    &self,
    instance: &str,
    execution_id: u64,
    new_events: Vec<Event>,
) -> Result<(), ProviderError> {
    let stored = query::fetch_history(self.client(), instance, execution_id).await?;
    let first_id = match stored.iter().map(|h| h.event_id).max() {
        Some(highest) => highest + 1,
        None => 0,
    };
    for (offset, event) in new_events.iter().enumerate() {
        let serialized = serde_json::to_string(event).map_err(|e| {
            ProviderError::permanent("append_with_execution", format!("Serialize error: {e}"))
        })?;
        let history_doc =
            HistoryDocument::new(instance, execution_id, first_id + offset as u64, serialized);
        let body = serde_json::to_value(&history_doc).unwrap();
        let resp = self.client().create_document(instance, &body).await?;
        let accepted = resp.is_success() || errors::is_conflict(resp.status);
        if !accepted {
            return Err(errors::map_cosmosdb_error(
                "append_with_execution",
                resp.status,
                &resp.body,
            ));
        }
    }
    Ok(())
}
/// Writes `item` to the worker queue in the partition of the instance it belongs to. For an
/// `ActivityExecute`, its execution id, activity id, session id, and tag are copied onto the
/// queue document.
///
/// # Errors
/// Permanent error on serialization failure; a non-success create is mapped via
/// `errors::map_cosmosdb_error`.
async fn enqueue_for_worker(&self, item: WorkItem) -> Result<(), ProviderError> {
    let partition = work_item_instance(&item).to_string();
    let payload = serde_json::to_string(&item).map_err(|e| {
        ProviderError::permanent("enqueue_for_worker", format!("Serialize error: {e}"))
    })?;
    let routing = if let WorkItem::ActivityExecute {
        execution_id,
        id,
        session_id,
        tag,
        ..
    } = &item
    {
        (Some(*execution_id), Some(*id), session_id.clone(), tag.clone())
    } else {
        (None, None, None, None)
    };
    let (exec_id, activity_id, session_id, tag) = routing;
    let queue_doc = QueueItemDocument::new_worker_queue(
        &partition,
        payload,
        exec_id,
        activity_id,
        session_id,
        tag,
        now_ms(),
    );
    let body = serde_json::to_value(&queue_doc).unwrap();
    let resp = self.client().create_document(&partition, &body).await?;
    if resp.is_success() {
        Ok(())
    } else {
        Err(errors::map_cosmosdb_error(
            "enqueue_for_worker",
            resp.status,
            &resp.body,
        ))
    }
}
/// Fetches and locks one worker-queue item matching `tag_filter`.
///
/// Returns `Ok(None)` immediately for `TagFilter::None`. Otherwise it acquires this task's
/// worker lease slots and tries up to 10 candidates from `query::find_candidate_work_item`.
/// Session-bound candidates are skipped when the caller passed no `session` config, or when a
/// different owner holds an unexpired session lock. A candidate is claimed with an
/// ETag-conditional replace that stamps a fresh lock token, sets `locked_until`, and increments
/// `attempt_count`.
///
/// For session items the session document is then claimed: updated if expired or already ours,
/// created if it could not be read. If that claim fails, the item claim is rolled back and the
/// next candidate is tried.
///
/// Returns `(work_item, lock_token, attempt_count)`.
// NOTE(review): the retry bound is a literal 10, whereas fetch_orchestration_item uses
// MAX_LOCK_RETRIES (20) — confirm whether the difference is intentional.
async fn fetch_work_item(
&self,
lock_timeout: Duration,
_poll_timeout: Duration,
session: Option<&SessionFetchConfig>,
tag_filter: &TagFilter,
) -> Result<Option<(WorkItem, String, u32)>, ProviderError> {
// Nothing to fetch for TagFilter::None.
if matches!(tag_filter, TagFilter::None) {
return Ok(None);
}
let caller_id = task_id_u64();
let my_slots = self.inner.worker_leases.acquire_slots(caller_id).await;
let now = now_ms();
// Ids of worker documents already rejected during this call.
let mut excluded = Vec::new();
for _attempt in 0..10 {
let candidate = query::find_candidate_work_item(
self.client(),
now,
&my_slots,
session.map(|s| s.owner_id.as_str()),
&excluded,
tag_filter,
)
.await?;
let Some(candidate) = candidate else {
return Ok(None);
};
// Session pre-check. Without a session config, session items are never taken. With one,
// skip items whose session is held by another owner's unexpired lock. Read/parse failures
// fall through to the claim below.
if let Some(ref sid) = candidate.session_id {
if let Some(config) = session {
let session_doc_id = SessionDocument::doc_id(&candidate.instance_id, sid);
if let Ok(resp) = self
.client()
.read_document(&session_doc_id, &candidate.instance_id)
.await
{
if resp.is_success() {
if let Ok(session_doc) =
serde_json::from_str::<SessionDocument>(&resp.body)
{
if session_doc.locked_until > now
&& session_doc.owner_id != config.owner_id
{
excluded.push(candidate.id.clone());
continue;
}
}
}
}
} else {
excluded.push(candidate.id.clone());
continue;
}
}
// Claim the item: ETag-conditional replace with a new lock token and a bumped attempt count.
let etag = candidate.etag.clone().unwrap_or_default();
let lock_token = uuid::Uuid::new_v4().to_string();
let mut updated = candidate.clone();
updated.lock_token = Some(lock_token.clone());
updated.locked_until = Some(now + lock_timeout.as_millis() as u64);
updated.attempt_count += 1;
let doc_json = serde_json::to_value(&updated).map_err(|e| {
ProviderError::permanent("fetch_work_item", format!("Serialize error: {e}"))
})?;
let resp = self
.client()
.replace_document(
&candidate.id,
&candidate.instance_id,
&doc_json,
Some(&etag),
)
.await?;
if resp.is_success() {
// Session items: claim (or extend) the session document for this owner, using the
// session config's own lock_timeout and a fresh timestamp.
if let (Some(ref sid), Some(config)) = (&candidate.session_id, session) {
let session_now = now_ms();
let session_locked_until = session_now + config.lock_timeout.as_millis() as u64;
let session_doc_id = SessionDocument::doc_id(&candidate.instance_id, sid);
let existing_session = self
.client()
.read_document(&session_doc_id, &candidate.instance_id)
.await;
let session_claimed = match existing_session {
// Existing session: take it over only if it has expired or is already ours
// (ETag-conditional on the read).
Ok(resp) if resp.is_success() => {
match serde_json::from_str::<SessionDocument>(&resp.body) {
Ok(mut existing) => {
if existing.locked_until <= session_now
|| existing.owner_id == config.owner_id
{
existing.owner_id = config.owner_id.clone();
existing.locked_until = session_locked_until;
existing.last_activity = session_now;
let session_json = serde_json::to_value(&existing).unwrap();
let update_resp = self
.client()
.replace_document(
&session_doc_id,
&candidate.instance_id,
&session_json,
resp.etag.as_deref(),
)
.await;
update_resp.map(|r| r.is_success()).unwrap_or(false)
} else {
false
}
}
Err(_) => false,
}
}
// Read errored or did not succeed (e.g. no session document yet): try to create
// one. Any create failure (e.g. a concurrent creator) counts as not claimed.
_ => {
let new_session = SessionDocument {
id: session_doc_id.clone(),
instance_id: candidate.instance_id.clone(),
doc_type: DOC_TYPE_SESSION.to_string(),
session_id: sid.clone(),
owner_id: config.owner_id.clone(),
locked_until: session_locked_until,
last_activity: session_now,
created_at: session_now,
etag: None,
rid: None,
self_link: None,
ts: None,
attachments: None,
};
let session_json = serde_json::to_value(&new_session).unwrap();
let create_resp = self
.client()
.create_document(&candidate.instance_id, &session_json)
.await;
match create_resp {
Ok(r) => r.is_success(),
Err(_) => false,
}
}
};
// Could not claim the session: undo the item claim (best effort) and try another candidate.
if !session_claimed {
let mut rollback = updated.clone();
rollback.lock_token = None;
rollback.locked_until = None;
rollback.attempt_count -= 1;
let rollback_json = serde_json::to_value(&rollback).unwrap();
let _ = self
.client()
.replace_document(
&candidate.id,
&candidate.instance_id,
&rollback_json,
// No ETag: unconditional write over the item this call just claimed.
None, )
.await;
excluded.push(candidate.id.clone());
continue;
}
}
let work_item: WorkItem =
serde_json::from_str(&candidate.work_item).map_err(|e| {
ProviderError::permanent(
"fetch_work_item",
format!("Failed to deserialize work item: {e}"),
)
})?;
return Ok(Some((work_item, lock_token, updated.attempt_count as u32)));
// Lost the race for this item (412/409): exclude it and try the next candidate.
} else if errors::is_precondition_failed(resp.status)
|| errors::is_conflict(resp.status)
{
excluded.push(candidate.id.clone());
continue;
} else {
return Err(errors::map_cosmosdb_error(
"fetch_work_item",
resp.status,
&resp.body,
));
}
}
Ok(None)
}
/// Completes the worker item holding `token`. The locked worker-queue document is deleted and,
/// if a `completion` is given, it is enqueued on the orchestrator queue of its target instance.
///
/// A completion for the item's own partition is written together with the delete in one
/// batch. A completion for a different partition cannot share a batch, so it is created first
/// and the worker item deleted afterwards. A failure in between then leaves the worker item in
/// place (presumably redeliverable once its lock lapses) instead of deleting it and losing the
/// activity's result.
///
/// For a session-bound item, the session's `last_activity` is then refreshed best-effort while
/// its lock is still live.
///
/// # Errors
/// Permanent error if no worker item carries `token` or its lock has expired. A non-success
/// create is mapped via `errors::map_cosmosdb_error`; client and batch errors are propagated.
async fn ack_work_item(
    &self,
    token: &str,
    completion: Option<WorkItem>,
) -> Result<(), ProviderError> {
    let now = now_ms();
    let items =
        query::find_items_by_lock_token(self.client(), token, DOC_TYPE_WORKER_QUEUE).await?;
    let item = items.first().ok_or_else(|| {
        ProviderError::permanent(
            "ack_work_item",
            "Activity was cancelled or lock expired (worker queue row not found or lock invalid)",
        )
    })?;
    if let Some(locked_until) = item.locked_until {
        if locked_until <= now {
            return Err(ProviderError::permanent(
                "ack_work_item",
                "Activity was cancelled or lock expired (worker queue row not found or lock invalid)",
            ));
        }
    }
    let instance_id = &item.instance_id;
    if let Some(completion_item) = completion {
        let target_instance = work_item_instance(&completion_item).to_string();
        let item_json = serde_json::to_string(&completion_item).map_err(|e| {
            ProviderError::permanent("ack_work_item", format!("Serialize error: {e}"))
        })?;
        let orch_doc = QueueItemDocument::new_orch_queue(&target_instance, item_json, now, now);
        let orch_json = serde_json::to_value(&orch_doc).unwrap();
        if target_instance == *instance_id {
            // Same partition: delete and enqueue in a single batch.
            let ops = vec![
                BatchOperation::Delete {
                    id: item.id.clone(),
                },
                BatchOperation::Create { body: orch_json },
            ];
            batch::execute_batch(self.client(), instance_id, ops).await?;
        } else {
            // Cross partition: persist the completion before removing its source.
            let resp = self
                .client()
                .create_document(&target_instance, &orch_json)
                .await?;
            if !resp.is_success() {
                return Err(errors::map_cosmosdb_error(
                    "ack_work_item",
                    resp.status,
                    &resp.body,
                ));
            }
            let _ = self.client().delete_document(&item.id, instance_id).await?;
        }
    } else {
        let _ = self.client().delete_document(&item.id, instance_id).await?;
    }
    // Best-effort session heartbeat: only while the session lock is still live, and
    // ETag-conditional on the read.
    if let Some(sid) = &item.session_id {
        let session_doc_id = SessionDocument::doc_id(instance_id, sid);
        let piggyback_now = now_ms();
        if let Ok(resp) = self
            .client()
            .read_document(&session_doc_id, instance_id)
            .await
        {
            if resp.is_success() {
                if let Ok(mut session_doc) = serde_json::from_str::<SessionDocument>(&resp.body)
                {
                    if session_doc.locked_until > piggyback_now {
                        session_doc.last_activity = piggyback_now;
                        session_doc.etag = resp.etag;
                        let doc_json = serde_json::to_value(&session_doc).unwrap();
                        let _ = self
                            .client()
                            .replace_document(
                                &session_doc_id,
                                instance_id,
                                &doc_json,
                                session_doc.etag.as_deref(),
                            )
                            .await;
                    }
                }
            }
        }
    }
    Ok(())
}
/// Extends the lock on the worker item holding `token` to `now + extend_for` with an
/// ETag-conditional replace. For a session-bound item it also refreshes the session's
/// `last_activity`, best-effort, while the session lock is live.
///
/// # Errors
/// Permanent error if no item carries `token` or its lock has already expired. A non-success
/// replace (including a lost ETag race) is mapped via `errors::map_cosmosdb_error`.
async fn renew_work_item_lock(
    &self,
    token: &str,
    extend_for: Duration,
) -> Result<(), ProviderError> {
    let now = now_ms();
    let holders =
        query::find_items_by_lock_token(self.client(), token, DOC_TYPE_WORKER_QUEUE).await?;
    let Some(item) = holders.first() else {
        return Err(ProviderError::permanent(
            "renew_work_item_lock",
            format!("No worker item found with lock token {token}"),
        ));
    };
    let already_expired = matches!(item.locked_until, Some(until) if until <= now);
    if already_expired {
        return Err(ProviderError::permanent(
            "renew_work_item_lock",
            "Lock has expired".to_string(),
        ));
    }

    let mut extended = item.clone();
    extended.locked_until = Some(now + extend_for.as_millis() as u64);
    let body = serde_json::to_value(&extended).unwrap();
    let resp = self
        .client()
        .replace_document(&item.id, &item.instance_id, &body, item.etag.as_deref())
        .await?;
    if !resp.is_success() {
        return Err(errors::map_cosmosdb_error(
            "renew_work_item_lock",
            resp.status,
            &resp.body,
        ));
    }

    let Some(sid) = &item.session_id else {
        return Ok(());
    };
    let piggyback_now = now_ms();
    let session_doc_id = SessionDocument::doc_id(&item.instance_id, sid);
    let Ok(sess_resp) = self
        .client()
        .read_document(&session_doc_id, &item.instance_id)
        .await
    else {
        return Ok(());
    };
    if !sess_resp.is_success() {
        return Ok(());
    }
    let Ok(mut session_doc) = serde_json::from_str::<SessionDocument>(&sess_resp.body) else {
        return Ok(());
    };
    if session_doc.locked_until > piggyback_now {
        session_doc.last_activity = piggyback_now;
        let sess_json = serde_json::to_value(&session_doc).unwrap();
        let _ = self
            .client()
            .replace_document(
                &session_doc_id,
                &item.instance_id,
                &sess_json,
                sess_resp.etag.as_deref(),
            )
            .await;
    }
    Ok(())
}
async fn renew_session_lock(
&self,
owner_ids: &[&str],
extend_for: Duration,
idle_timeout: Duration,
) -> Result<usize, ProviderError> {
if owner_ids.is_empty() {
return Ok(0);
}
let now = now_ms();
let locked_until = now + extend_for.as_millis() as u64;
let idle_cutoff = now.saturating_sub(idle_timeout.as_millis() as u64);
let sql = format!("SELECT * FROM c WHERE c.type = '{}'", DOC_TYPE_SESSION);
let results = self.client().query(&sql, vec![], None).await?;
let mut count = 0usize;
for doc in results {
if let Ok(session) = serde_json::from_value::<SessionDocument>(doc) {
if owner_ids.contains(&session.owner_id.as_str())
&& session.locked_until > now
&& session.last_activity > idle_cutoff
{
let mut updated = session.clone();
updated.locked_until = locked_until;
let doc_json = serde_json::to_value(&updated).unwrap();
if let Ok(resp) = self
.client()
.replace_document(
&session.id,
&session.instance_id,
&doc_json,
session.etag.as_deref(),
)
.await
{
if resp.is_success() {
count += 1;
}
}
}
}
}
Ok(count)
}
/// Deletes session documents whose lock expired before now and that have no
/// worker-queue items left for the same session and instance.
///
/// Deletes are best-effort: a failed delete leaves the session in place for a
/// later sweep. Only sessions that were actually deleted are counted.
///
/// `_idle_timeout` is currently unused; expiry is judged solely by the stored
/// lock deadline.
async fn cleanup_orphaned_sessions(
    &self,
    _idle_timeout: Duration,
) -> Result<usize, ProviderError> {
    let now = now_ms();
    let sql = format!(
        "SELECT * FROM c WHERE c.type = '{}' AND c.lockedUntil < @now",
        DOC_TYPE_SESSION
    );
    let params = vec![crate::client::QueryParameter::new(
        "@now",
        serde_json::json!(now),
    )];
    let results = self.client().query(&sql, params, None).await?;
    // The pending-work probe is the same for every session; build it once.
    let check_sql = format!(
        "SELECT VALUE COUNT(1) FROM c WHERE c.type = '{}' AND c.sessionId = @sid AND c.instanceId = @iid",
        DOC_TYPE_WORKER_QUEUE
    );
    let mut count = 0usize;
    for doc in results {
        let session = match serde_json::from_value::<SessionDocument>(doc) {
            Ok(session) => session,
            Err(_) => continue,
        };
        let check_params = vec![
            crate::client::QueryParameter::new("@sid", serde_json::json!(&session.session_id)),
            crate::client::QueryParameter::new("@iid", serde_json::json!(&session.instance_id)),
        ];
        let pending = self
            .client()
            .query(&check_sql, check_params, Some(&session.instance_id))
            .await?
            .into_iter()
            .next()
            .and_then(|v| v.as_u64())
            .unwrap_or(0);
        if pending > 0 {
            // Work is still queued for this session; it is not orphaned.
            continue;
        }
        // Only report sessions whose delete actually succeeded.
        if let Ok(resp) = self
            .client()
            .delete_document(&session.id, &session.instance_id)
            .await
        {
            if resp.is_success() {
                count += 1;
            }
        }
    }
    Ok(count)
}
/// Releases the lock held by `token` on worker-queue items.
///
/// Each matching item has its lock token and deadline cleared. When `delay` is
/// given, the item becomes visible only after that delay. When `ignore_attempt`
/// is set, one attempt is given back (never going below zero).
/// Write failures are deliberately ignored, and an unknown token is not an error.
async fn abandon_work_item(
    &self,
    token: &str,
    delay: Option<Duration>,
    ignore_attempt: bool,
) -> Result<(), ProviderError> {
    let now = now_ms();
    let locked_items =
        query::find_items_by_lock_token(self.client(), token, DOC_TYPE_WORKER_QUEUE).await?;
    for original in locked_items.iter() {
        let mut released = original.clone();
        released.lock_token = None;
        released.locked_until = None;
        match delay {
            Some(d) => released.visible_at = now + d.as_millis() as u64,
            None => {}
        }
        let can_refund = released.attempt_count > 0;
        if ignore_attempt && can_refund {
            released.attempt_count -= 1;
        }
        let body = serde_json::to_value(&released).unwrap();
        // Best-effort: the result of the write is intentionally discarded.
        let _ = self
            .client()
            .replace_document(
                &original.id,
                &original.instance_id,
                &body,
                original.etag.as_deref(),
            )
            .await;
    }
    Ok(())
}
/// Extends the orchestration lock identified by `token` by `extend_for`.
///
/// The owning instance document is updated first. Failing to find it, finding
/// its lock already expired, or a non-success replace is an error. The
/// orchestrator-queue messages held under the same token are then extended to
/// the same deadline on a best-effort basis.
async fn renew_orchestration_item_lock(
    &self,
    token: &str,
    extend_for: Duration,
) -> Result<(), ProviderError> {
    const OP: &str = "renew_orchestration_item_lock";
    let now = now_ms();
    let inst = match query::find_instance_by_lock_token(self.client(), token).await? {
        Some(found) => found,
        None => {
            return Err(ProviderError::permanent(
                OP,
                format!("No instance found with lock token {token}"),
            ));
        }
    };
    let already_expired = inst.locked_until.map_or(false, |until| until <= now);
    if already_expired {
        return Err(ProviderError::permanent(OP, "Lock has expired".to_string()));
    }
    // One deadline shared by the instance and all of its locked messages.
    let new_deadline = now + extend_for.as_millis() as u64;
    let mut renewed = inst.clone();
    renewed.locked_until = Some(new_deadline);
    let payload = serde_json::to_value(&renewed).unwrap();
    let resp = self
        .client()
        .replace_document(&inst.id, &inst.instance_id, &payload, inst.etag.as_deref())
        .await?;
    if !resp.is_success() {
        return Err(errors::map_cosmosdb_error(OP, resp.status, &resp.body));
    }
    let locked_msgs =
        query::find_items_by_lock_token(self.client(), token, DOC_TYPE_ORCH_QUEUE).await?;
    for msg in locked_msgs.iter() {
        let mut extended = msg.clone();
        extended.locked_until = Some(new_deadline);
        let msg_payload = serde_json::to_value(&extended).unwrap();
        // Message renewal is best-effort; failures are intentionally ignored.
        let _ = self
            .client()
            .replace_document(&msg.id, &msg.instance_id, &msg_payload, msg.etag.as_deref())
            .await;
    }
    Ok(())
}
/// Appends `item` to the orchestrator queue of the instance it targets.
///
/// With `delay`, the message becomes visible only after that interval;
/// otherwise it is visible immediately. Serialization failures and non-success
/// create responses are returned as errors.
async fn enqueue_for_orchestrator(
    &self,
    item: WorkItem,
    delay: Option<Duration>,
) -> Result<(), ProviderError> {
    let instance_id = work_item_instance(&item).to_string();
    let serialized = match serde_json::to_string(&item) {
        Ok(text) => text,
        Err(e) => {
            return Err(ProviderError::permanent(
                "enqueue_for_orchestrator",
                format!("Serialize error: {e}"),
            ))
        }
    };
    let now = now_ms();
    let visible_at = match delay {
        Some(d) => now + d.as_millis() as u64,
        None => now,
    };
    let queue_doc = QueueItemDocument::new_orch_queue(&instance_id, serialized, visible_at, now);
    let body = serde_json::to_value(&queue_doc).unwrap();
    let resp = self.client().create_document(&instance_id, &body).await?;
    if resp.is_success() {
        Ok(())
    } else {
        Err(errors::map_cosmosdb_error(
            "enqueue_for_orchestrator",
            resp.status,
            &resp.body,
        ))
    }
}
/// Exposes this provider's administrative API; it is always available.
fn as_management_capability(&self) -> Option<&dyn ProviderAdmin> {
    let admin: &dyn ProviderAdmin = self;
    Some(admin)
}
/// Returns the instance's custom status if its version is newer than
/// `last_seen_version`.
///
/// Yields `Ok(None)` when the instance does not exist or has no newer status.
async fn get_custom_status(
    &self,
    instance: &str,
    last_seen_version: u64,
) -> Result<Option<(Option<String>, u64)>, ProviderError> {
    let newer = self
        .read_instance(instance)
        .await?
        .filter(|i| i.custom_status_version > last_seen_version)
        .map(|i| (i.custom_status, i.custom_status_version));
    Ok(newer)
}
/// Reads a single key/value entry stored in `instance_id`'s partition.
///
/// A not-found response yields `Ok(None)`. Any other non-success status, or a
/// body that fails to deserialize, is returned as an error.
async fn get_kv_value(
    &self,
    instance_id: &str,
    key: &str,
) -> Result<Option<String>, ProviderError> {
    let kv_doc_id = KeyValueDocument::doc_id(instance_id, key);
    let resp = self.client().read_document(&kv_doc_id, instance_id).await?;
    if errors::is_not_found(resp.status) {
        Ok(None)
    } else if resp.is_success() {
        serde_json::from_str::<KeyValueDocument>(&resp.body)
            .map(|kv| Some(kv.value))
            .map_err(|e| {
                ProviderError::permanent("get_kv_value", format!("Deserialize error: {e}"))
            })
    } else {
        Err(errors::map_cosmosdb_error(
            "get_kv_value",
            resp.status,
            &resp.body,
        ))
    }
}
/// Returns every key/value entry stored in `instance_id`'s partition.
///
/// Documents whose `key` or `value` field is missing or not a string are
/// skipped. If a key appears more than once, the later document wins.
async fn get_kv_all_values(
    &self,
    instance_id: &str,
) -> Result<std::collections::HashMap<String, String>, ProviderError> {
    let kv_docs =
        query::query_by_type_in_partition(self.client(), instance_id, DOC_TYPE_KV).await?;
    let mut values = std::collections::HashMap::new();
    for doc in &kv_docs {
        let key = doc.get("key").and_then(|k| k.as_str());
        let value = doc.get("value").and_then(|v| v.as_str());
        if let (Some(k), Some(v)) = (key, value) {
            values.insert(k.to_string(), v.to_string());
        }
    }
    Ok(values)
}
}
/// Management / introspection operations for the Cosmos DB provider.
///
/// Status strings ("Running", "Completed", "Failed", "ContinuedAsNew") are
/// compared exactly as stored on the instance documents.
#[async_trait::async_trait]
impl ProviderAdmin for CosmosDBProvider {
    /// Returns the id of every stored orchestration instance.
    async fn list_instances(&self) -> Result<Vec<String>, ProviderError> {
        let instances = query::query_instances(self.client(), None).await?;
        Ok(instances.into_iter().map(|i| i.instance_id).collect())
    }

    /// Returns the ids of instances whose stored status equals `status`.
    async fn list_instances_by_status(&self, status: &str) -> Result<Vec<String>, ProviderError> {
        let instances = query::query_instances(self.client(), Some(status)).await?;
        Ok(instances.into_iter().map(|i| i.instance_id).collect())
    }

    /// Lists, in ascending order, the execution ids that have history documents.
    ///
    /// If no history exists yet, falls back to the instance's current execution
    /// id. Returns an empty list for an unknown instance.
    async fn list_executions(&self, instance: &str) -> Result<Vec<u64>, ProviderError> {
        let sql = format!(
            "SELECT DISTINCT VALUE c.executionId FROM c \
             WHERE c.instanceId = @instanceId AND c.type = '{}'",
            DOC_TYPE_HISTORY
        );
        let params = vec![crate::client::QueryParameter::new(
            "@instanceId",
            serde_json::json!(instance),
        )];
        let results = self.client().query(&sql, params, Some(instance)).await?;
        let mut exec_ids: Vec<u64> = results.into_iter().filter_map(|v| v.as_u64()).collect();
        exec_ids.sort_unstable();
        if exec_ids.is_empty() {
            if let Some(i) = self.read_instance(instance).await? {
                exec_ids.push(i.current_execution_id);
            }
        }
        Ok(exec_ids)
    }

    /// Reads the history events of one specific execution.
    async fn read_history_with_execution_id(
        &self,
        instance: &str,
        execution_id: u64,
    ) -> Result<Vec<Event>, ProviderError> {
        self.read_with_execution(instance, execution_id).await
    }

    /// Reads the instance's history via the provider's `read`.
    async fn read_history(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
        self.read(instance).await
    }

    /// Returns the instance's current execution id; errors if it does not exist.
    async fn latest_execution_id(&self, instance: &str) -> Result<u64, ProviderError> {
        let inst = self.read_instance(instance).await?.ok_or_else(|| {
            ProviderError::permanent(
                "latest_execution_id",
                format!("Instance {instance} not found"),
            )
        })?;
        Ok(inst.current_execution_id)
    }

    /// Returns summary metadata for an instance; errors if it does not exist.
    async fn get_instance_info(&self, instance: &str) -> Result<InstanceInfo, ProviderError> {
        let inst = self.read_instance(instance).await?.ok_or_else(|| {
            ProviderError::permanent(
                "get_instance_info",
                format!("Instance {instance} not found"),
            )
        })?;
        Ok(InstanceInfo {
            instance_id: inst.instance_id,
            orchestration_name: inst.orchestration_name,
            orchestration_version: inst.orchestration_version,
            current_execution_id: inst.current_execution_id,
            status: inst.status,
            output: inst.output,
            created_at: inst.created_at,
            updated_at: inst.updated_at,
            parent_instance_id: inst.parent_instance_id,
        })
    }

    /// Describes one execution of an instance.
    ///
    /// The current execution reports the instance's stored status. Older
    /// executions report "ContinuedAsNew". A missing instance reports "Unknown".
    ///
    /// NOTE(review): `started_at`, `completed_at` and `output` are all taken from
    /// the instance document, i.e. from the latest execution. For non-current
    /// executions these values describe the latest execution, not the requested one.
    async fn get_execution_info(
        &self,
        instance: &str,
        execution_id: u64,
    ) -> Result<ExecutionInfo, ProviderError> {
        let events = self.read_with_execution(instance, execution_id).await?;
        let inst = self.read_instance(instance).await?;
        let event_count = events.len();
        let started_at = inst.as_ref().map(|i| i.created_at).unwrap_or(0);
        let status = match &inst {
            Some(i) if i.current_execution_id == execution_id => i.status.clone(),
            Some(_) => "ContinuedAsNew".to_string(),
            None => "Unknown".to_string(),
        };
        let completed_at = if status == "Running" {
            None
        } else {
            inst.as_ref().map(|i| i.updated_at)
        };
        Ok(ExecutionInfo {
            execution_id,
            status,
            output: inst.and_then(|i| i.output),
            started_at,
            completed_at,
            event_count,
        })
    }

    /// Aggregates instance counts by status plus the total number of history events.
    async fn get_system_metrics(&self) -> Result<SystemMetrics, ProviderError> {
        let instances = query::query_instances(self.client(), None).await?;
        let total = instances.len() as u64;
        let count_status =
            |s: &str| instances.iter().filter(|i| i.status == s).count() as u64;
        let running = count_status("Running");
        let completed = count_status("Completed");
        let failed = count_status("Failed");
        let total_events =
            query::count_by_type(self.client(), DOC_TYPE_HISTORY, None).await? as u64;
        Ok(SystemMetrics {
            total_instances: total,
            // NOTE(review): this counts one execution per instance. Executions
            // created by continue-as-new are not counted separately.
            total_executions: total,
            running_instances: running,
            completed_instances: completed,
            failed_instances: failed,
            total_events,
        })
    }

    /// Counts queue messages that are ready now (visible and unlocked) in the
    /// orchestrator and worker queues. Orchestrator messages scheduled for the
    /// future are reported as the timer queue.
    async fn get_queue_depths(&self) -> Result<QueueDepths, ProviderError> {
        let now = now_ms();
        let now_filter = format!("c.visibleAt <= {now} AND (NOT IS_DEFINED(c.lockedUntil) OR c.lockedUntil = null OR c.lockedUntil <= {now})");
        let orch =
            query::count_by_type(self.client(), DOC_TYPE_ORCH_QUEUE, Some(&now_filter)).await?;
        let worker =
            query::count_by_type(self.client(), DOC_TYPE_WORKER_QUEUE, Some(&now_filter)).await?;
        let timer_filter = format!("c.visibleAt > {now}");
        let timer =
            query::count_by_type(self.client(), DOC_TYPE_ORCH_QUEUE, Some(&timer_filter)).await?;
        Ok(QueueDepths {
            orchestrator_queue: orch,
            worker_queue: worker,
            timer_queue: timer,
        })
    }

    /// Returns the ids of instances whose `parentInstanceId` is `instance_id`.
    async fn list_children(&self, instance_id: &str) -> Result<Vec<String>, ProviderError> {
        let sql = format!(
            "SELECT c.instanceId FROM c WHERE c.type = '{}' AND c.parentInstanceId = @parentId",
            DOC_TYPE_INSTANCE
        );
        let params = vec![crate::client::QueryParameter::new(
            "@parentId",
            serde_json::json!(instance_id),
        )];
        let results = self.client().query(&sql, params, None).await?;
        Ok(results
            .into_iter()
            .filter_map(|v| {
                v.get("instanceId")
                    .and_then(|id| id.as_str())
                    .map(|s| s.to_string())
            })
            .collect())
    }

    /// Returns the parent id of an instance; errors if the instance does not exist.
    async fn get_parent_id(&self, instance_id: &str) -> Result<Option<String>, ProviderError> {
        let inst = self.read_instance(instance_id).await?.ok_or_else(|| {
            ProviderError::permanent("get_parent_id", format!("Instance {instance_id} not found"))
        })?;
        Ok(inst.parent_instance_id)
    }

    /// Deletes every document in the partition of each id in `ids`.
    ///
    /// Without `force`, refuses if any listed instance is still "Running".
    /// Always refuses if a listed instance has a child that is not in `ids`.
    /// Despite the name, deletion is per-document and not transactional.
    ///
    /// Reported counts reflect documents whose delete call succeeded.
    async fn delete_instances_atomic(
        &self,
        ids: &[String],
        force: bool,
    ) -> Result<DeleteInstanceResult, ProviderError> {
        if ids.is_empty() {
            return Ok(DeleteInstanceResult::default());
        }
        if !force {
            for id in ids {
                if let Some(inst) = self.read_instance(id).await? {
                    if inst.status == "Running" {
                        return Err(ProviderError::permanent(
                            "delete_instances_atomic",
                            format!("Instance {id} is still running. Use force=true to delete anyway, or cancel first."),
                        ));
                    }
                }
            }
        }
        // Refuse to orphan children that the caller did not include. A failed
        // child lookup must abort the delete: treating it as "no children"
        // would let this safety check fail open.
        let id_set: std::collections::HashSet<&str> = ids.iter().map(String::as_str).collect();
        for id in ids {
            let children = self.list_children(id).await?;
            for child in &children {
                if !id_set.contains(child.as_str()) {
                    return Err(ProviderError::permanent(
                        "delete_instances_atomic",
                        format!(
                            "Cannot delete: instance {id} has child {child} that was created after tree traversal. \
                             Re-fetch the tree and retry."
                        ),
                    ));
                }
            }
        }
        let mut result = DeleteInstanceResult::default();
        for id in ids {
            let exec_count = u64::from(self.read_instance(id).await?.is_some());
            // Classify the documents we are about to delete by their `type`.
            // This avoids interpolating the caller-supplied id into SQL text
            // and saves three count queries per instance.
            let docs = query::query_all_in_partition(self.client(), id).await?;
            let mut events_deleted = 0u64;
            let mut queue_messages_deleted = 0u64;
            for doc in &docs {
                let doc_id = match doc.get("id").and_then(|v| v.as_str()) {
                    Some(doc_id) => doc_id,
                    None => continue,
                };
                // Deletes are best-effort; only successful ones are counted.
                let deleted = match self.client().delete_document(doc_id, id).await {
                    Ok(resp) => resp.is_success(),
                    Err(_) => false,
                };
                if !deleted {
                    continue;
                }
                let doc_type = doc.get("type").and_then(|v| v.as_str());
                if doc_type == Some(DOC_TYPE_HISTORY) {
                    events_deleted += 1;
                } else if doc_type == Some(DOC_TYPE_ORCH_QUEUE)
                    || doc_type == Some(DOC_TYPE_WORKER_QUEUE)
                {
                    queue_messages_deleted += 1;
                }
            }
            result.instances_deleted += 1;
            result.executions_deleted += exec_count;
            result.events_deleted += events_deleted;
            result.queue_messages_deleted += queue_messages_deleted;
        }
        Ok(result)
    }

    /// Deletes root instances in a terminal status, together with their whole
    /// instance trees.
    ///
    /// Terminal here means "Completed", "Failed" or "ContinuedAsNew". Candidates
    /// are optionally narrowed by `filter.instance_ids` and
    /// `filter.completed_before` (compared against `updated_at`), then capped at
    /// `filter.limit` (default 1000).
    async fn delete_instance_bulk(
        &self,
        filter: InstanceFilter,
    ) -> Result<DeleteInstanceResult, ProviderError> {
        let all_instances = query::query_instances(self.client(), None).await?;
        let mut candidates: Vec<InstanceDocument> = all_instances
            .into_iter()
            .filter(|inst| {
                inst.parent_instance_id.is_none()
                    && (inst.status == "Completed"
                        || inst.status == "Failed"
                        || inst.status == "ContinuedAsNew")
            })
            .collect();
        if let Some(ref ids) = filter.instance_ids {
            if ids.is_empty() {
                return Ok(DeleteInstanceResult::default());
            }
            // A set keeps the membership test O(1) per candidate.
            let wanted: std::collections::HashSet<&str> =
                ids.iter().map(String::as_str).collect();
            candidates.retain(|inst| wanted.contains(inst.instance_id.as_str()));
        }
        if let Some(before) = filter.completed_before {
            candidates.retain(|inst| inst.updated_at < before);
        }
        let limit = filter.limit.unwrap_or(1000) as usize;
        candidates.truncate(limit);
        if candidates.is_empty() {
            return Ok(DeleteInstanceResult::default());
        }
        let mut result = DeleteInstanceResult::default();
        for inst in &candidates {
            let tree = self.get_instance_tree(&inst.instance_id).await?;
            let delete_result = self.delete_instances_atomic(&tree.all_ids, true).await?;
            result.instances_deleted += delete_result.instances_deleted;
            result.executions_deleted += delete_result.executions_deleted;
            result.events_deleted += delete_result.events_deleted;
            result.queue_messages_deleted += delete_result.queue_messages_deleted;
        }
        Ok(result)
    }

    /// Deletes the history of old executions of one instance.
    ///
    /// The current execution is always kept. With `options.keep_last`, the
    /// newest N execution ids are kept as well. `events_deleted` counts only
    /// history documents whose delete succeeded.
    async fn prune_executions(
        &self,
        instance_id: &str,
        options: PruneOptions,
    ) -> Result<PruneResult, ProviderError> {
        let inst = self.read_instance(instance_id).await?.ok_or_else(|| {
            ProviderError::permanent(
                "prune_executions",
                format!("Instance {instance_id} not found"),
            )
        })?;
        let current_exec = inst.current_execution_id;
        let mut all_execs = self.list_executions(instance_id).await?;
        all_execs.sort_unstable();
        let mut protected: std::collections::HashSet<u64> = std::collections::HashSet::new();
        protected.insert(current_exec);
        if let Some(keep_last) = options.keep_last {
            let skip = all_execs.len().saturating_sub(keep_last as usize);
            protected.extend(all_execs[skip..].iter().copied());
        }
        let to_prune: Vec<u64> = all_execs
            .into_iter()
            .filter(|e| !protected.contains(e))
            .collect();
        let mut events_deleted = 0u64;
        let mut execs_deleted = 0u64;
        for exec_id in &to_prune {
            let docs = query::fetch_history(self.client(), instance_id, *exec_id).await?;
            for doc in &docs {
                // Best-effort delete; only successful deletes are counted.
                if let Ok(resp) = self.client().delete_document(&doc.id, instance_id).await {
                    if resp.is_success() {
                        events_deleted += 1;
                    }
                }
            }
            execs_deleted += 1;
        }
        Ok(PruneResult {
            instances_processed: 1,
            executions_deleted: execs_deleted,
            events_deleted,
        })
    }

    /// Applies [`prune_executions`] to `filter.instance_ids`, or to every
    /// instance when no ids are given.
    ///
    /// Instances that fail to prune (e.g. not found) are skipped. Their errors
    /// are not surfaced.
    async fn prune_executions_bulk(
        &self,
        filter: InstanceFilter,
        options: PruneOptions,
    ) -> Result<PruneResult, ProviderError> {
        let instances = match &filter.instance_ids {
            Some(ids) => ids.clone(),
            None => self.list_instances().await?,
        };
        let mut total = PruneResult::default();
        for id in &instances {
            if let Ok(r) = self.prune_executions(id, options.clone()).await {
                total.instances_processed += r.instances_processed;
                total.executions_deleted += r.executions_deleted;
                total.events_deleted += r.events_deleted;
            }
        }
        Ok(total)
    }
}