use std::sync::Arc;
use bytes::Bytes;
use sayiir_core::codec::sealed;
use sayiir_core::codec::{Codec, EnvelopeCodec};
use sayiir_core::snapshot::{SignalKind, SignalRequest};
use sayiir_core::task::TaskIdentifier;
use sayiir_core::workflow::{ConflictPolicy, Workflow, WorkflowStatus};
use sayiir_persistence::{SignalStore, SnapshotStore, TaskResultStore};
use crate::error::RuntimeError;
use crate::{PrepareRunOutcome, check_existing_instance, prepare_run};
/// Client handle that pairs a shared backend `B` with a [`ConflictPolicy`].
///
/// The struct itself places no bounds on `B`; the `impl` blocks that talk to
/// the backend add the store-trait bounds they need.
pub struct WorkflowClient<B> {
/// Reference-counted handle to the backend; exposed via `backend()`.
backend: Arc<B>,
/// Policy handed to `check_existing_instance` and `prepare_run` by `submit`.
conflict_policy: ConflictPolicy,
}
impl<B> WorkflowClient<B> {
    /// Builds a client that takes ownership of `backend`, wrapping it in a new [`Arc`].
    ///
    /// The conflict policy starts out as [`ConflictPolicy::default`].
    pub fn new(backend: B) -> Self {
        Self::from_shared(Arc::new(backend))
    }

    /// Builds a client around a backend handle that is already behind an [`Arc`].
    ///
    /// The conflict policy starts out as [`ConflictPolicy::default`].
    pub fn from_shared(backend: Arc<B>) -> Self {
        let conflict_policy = ConflictPolicy::default();
        Self {
            backend,
            conflict_policy,
        }
    }

    /// Consumes the client and returns it with its conflict policy replaced by `policy`.
    #[must_use]
    pub fn with_conflict_policy(self, policy: ConflictPolicy) -> Self {
        Self {
            conflict_policy: policy,
            ..self
        }
    }

    /// Borrows the shared backend handle held by this client.
    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        &self.backend
    }
}
impl<B> WorkflowClient<B>
where
    B: SnapshotStore + SignalStore,
{
    /// Submits `workflow` under `instance_id` with the given `input`.
    ///
    /// Steps, in order:
    /// 1. `check_existing_instance` is consulted with the instance id, the
    ///    workflow's definition hash, the backend and this client's conflict
    ///    policy; if it yields a value, that value is returned as-is.
    /// 2. `input` is encoded with the codec from the workflow's context.
    /// 3. `prepare_run` is called with the encoded input and the first-task
    ///    hint taken from the workflow's continuation.
    ///
    /// Returns `(WorkflowStatus::InProgress, None)` when `prepare_run` reports
    /// `PrepareRunOutcome::Fresh`, otherwise the status and output carried by
    /// `PrepareRunOutcome::ExistingStatus`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `check_existing_instance`, from encoding the
    /// input, or from `prepare_run`, converted into [`RuntimeError`].
    pub async fn submit<C, Input, M>(
        &self,
        workflow: &Workflow<C, Input, M>,
        instance_id: impl Into<String>,
        input: Input,
    ) -> Result<(WorkflowStatus, Option<Bytes>), RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec + EnvelopeCodec + sealed::EncodeValue<Input> + 'static,
    {
        let id: String = instance_id.into();
        let hash = *workflow.definition_hash();
        let policy = self.conflict_policy;
        let store: &B = &self.backend;

        // Short-circuit before encoding anything if the check already has an answer.
        let existing = check_existing_instance(&id, &hash, store, policy).await?;
        if let Some(early) = existing {
            return Ok(early);
        }

        let encoded = workflow.context().codec.encode(&input)?;
        let first = workflow.continuation().first_task_hint();

        let outcome = prepare_run(&id, hash, encoded, first, store, policy).await?;
        Ok(match outcome {
            PrepareRunOutcome::Fresh(_) => (WorkflowStatus::InProgress, None),
            PrepareRunOutcome::ExistingStatus(status, output) => (status, output),
        })
    }

    /// Stores a [`SignalKind::Cancel`] signal for `instance_id` in the backend.
    ///
    /// `reason` and `cancelled_by` are packed into a [`SignalRequest`] and
    /// stored with the signal.
    ///
    /// # Errors
    ///
    /// Propagates the backend's `store_signal` error as a [`RuntimeError`].
    pub async fn cancel(
        &self,
        instance_id: &str,
        reason: Option<String>,
        cancelled_by: Option<String>,
    ) -> Result<(), RuntimeError> {
        let request = SignalRequest::new(reason, cancelled_by);
        self.backend
            .store_signal(instance_id, SignalKind::Cancel, request)
            .await?;
        Ok(())
    }

    /// Stores a [`SignalKind::Pause`] signal for `instance_id` in the backend.
    ///
    /// `reason` and `paused_by` are packed into a [`SignalRequest`] and stored
    /// with the signal.
    ///
    /// # Errors
    ///
    /// Propagates the backend's `store_signal` error as a [`RuntimeError`].
    pub async fn pause(
        &self,
        instance_id: &str,
        reason: Option<String>,
        paused_by: Option<String>,
    ) -> Result<(), RuntimeError> {
        let request = SignalRequest::new(reason, paused_by);
        self.backend
            .store_signal(instance_id, SignalKind::Pause, request)
            .await?;
        Ok(())
    }

    /// Forwards an unpause request for `instance_id` to the backend's `unpause`.
    ///
    /// # Errors
    ///
    /// Propagates the backend's error as a [`RuntimeError`].
    pub async fn unpause(&self, instance_id: &str) -> Result<(), RuntimeError> {
        self.backend.unpause(instance_id).await?;
        Ok(())
    }

    /// Forwards `payload` under `signal_name` for `instance_id` to the
    /// backend's `send_event`.
    ///
    /// # Errors
    ///
    /// Propagates the backend's error as a [`RuntimeError`].
    pub async fn send_event(
        &self,
        instance_id: &str,
        signal_name: &str,
        payload: Bytes,
    ) -> Result<(), RuntimeError> {
        self.backend.send_event(instance_id, signal_name, payload).await?;
        Ok(())
    }

    /// Loads the snapshot for `instance_id` and reports its state as a
    /// [`WorkflowStatus`].
    ///
    /// # Errors
    ///
    /// Propagates the backend's `load_snapshot` error as a [`RuntimeError`].
    pub async fn status(&self, instance_id: &str) -> Result<WorkflowStatus, RuntimeError> {
        Ok(self
            .backend
            .load_snapshot(instance_id)
            .await?
            .state
            .as_status())
    }
}
impl<B> WorkflowClient<B>
where
    B: SnapshotStore + SignalStore + TaskResultStore,
{
    /// Loads the stored result bytes for task `task_id` of `instance_id`.
    ///
    /// `task_id` is converted into a `sayiir_core::TaskId` and looked up via
    /// the backend's `load_task_result`; its `Option` is returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates the backend's error as a [`RuntimeError`].
    pub async fn get_task_result(
        &self,
        instance_id: &str,
        task_id: &str,
    ) -> Result<Option<Bytes>, RuntimeError> {
        let key = sayiir_core::TaskId::from(task_id);
        let stored = self.backend.load_task_result(instance_id, &key).await?;
        Ok(stored)
    }

    /// Same as [`Self::get_task_result`], with the task id supplied by
    /// `T::task_id()` instead of a string argument.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Self::get_task_result`].
    pub async fn get_task_result_of<T: TaskIdentifier>(
        &self,
        instance_id: &str,
    ) -> Result<Option<Bytes>, RuntimeError> {
        let task_id = T::task_id();
        self.get_task_result(instance_id, task_id).await
    }
}