use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use serde::de::DeserializeOwned;
use rustvello_proto::identifiers::InvocationId;
use rustvello_proto::status::InvocationStatus;
use crate::error::{RustvelloError, RustvelloResult};
use crate::orchestrator::Orchestrator;
use crate::state_backend::StateBackend;
/// Handle to an invocation tracked by an [`Orchestrator`], used to query its
/// status and to fetch its outcome.
///
/// `R` is the type the stored JSON result is deserialized into; it defaults
/// to `String`. The `DeserializeOwned` requirement lives on the impl blocks
/// that actually deserialize, not on the type itself.
pub struct InvocationHandle<R = String> {
    /// Identifier of the invocation this handle refers to.
    invocation_id: InvocationId,
    /// Queried for the invocation's current status.
    orchestrator: Arc<dyn Orchestrator>,
    /// Queried for the invocation's stored result or error.
    state_backend: Arc<dyn StateBackend>,
    // `fn() -> R` instead of `R`: the handle never owns an `R`, it only
    // produces one on demand, so it must not inherit `R`'s auto traits
    // (`Send`/`Sync`) or drop-check obligations. Variance stays covariant.
    _result_type: PhantomData<fn() -> R>,
}
impl<R: DeserializeOwned> InvocationHandle<R> {
    /// Creates a handle for `invocation_id`.
    ///
    /// `orchestrator` is queried for the invocation's status and
    /// `state_backend` for its stored result or error.
    pub fn new(
        invocation_id: InvocationId,
        orchestrator: Arc<dyn Orchestrator>,
        state_backend: Arc<dyn StateBackend>,
    ) -> Self {
        Self {
            invocation_id,
            orchestrator,
            state_backend,
            _result_type: PhantomData,
        }
    }

    /// Returns the identifier of the invocation this handle refers to.
    pub fn invocation_id(&self) -> &InvocationId {
        &self.invocation_id
    }

    /// Fetches the invocation's current status from the orchestrator.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the orchestrator lookup.
    pub async fn status(&self) -> RustvelloResult<InvocationStatus> {
        let record = self
            .orchestrator
            .get_invocation_status(&self.invocation_id)
            .await?;
        Ok(record.status)
    }

    /// Returns whether the invocation's current status is terminal.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::status`].
    pub async fn is_done(&self) -> RustvelloResult<bool> {
        Ok(self.status().await?.is_terminal())
    }

    /// Fetches the current status and returns the invocation's outcome.
    ///
    /// # Errors
    ///
    /// - Errors from [`Self::status`] or from the state backend are propagated.
    /// - `Success` with no stored result yields `RustvelloError::Internal`.
    /// - A stored result that does not deserialize as `R` yields
    ///   `RustvelloError::Serialization`.
    /// - `Failed` yields the error built by `RustvelloError::runner_err` from
    ///   the stored error (or `"unknown error"` when none is stored).
    /// - Any other status yields `RustvelloError::Internal` ("not finished").
    pub async fn result(&self) -> RustvelloResult<R> {
        let status = self.status().await?;
        self.outcome_for(status).await
    }

    /// Polls the status every `poll_interval` until it is terminal, then
    /// returns the outcome exactly as [`Self::result`] would.
    ///
    /// This has no upper bound on how long it waits; use
    /// [`Self::wait_timeout`] to bound it.
    ///
    /// # Errors
    ///
    /// Same as [`Self::result`]. A failed status lookup ends the wait early.
    pub async fn wait(&self, poll_interval: Duration) -> RustvelloResult<R> {
        loop {
            // Fetch the status once per iteration and reuse it for the
            // outcome lookup. This avoids a second orchestrator round-trip and
            // guarantees the terminal check and the outcome use the same read.
            let status = self.status().await?;
            if status.is_terminal() {
                return self.outcome_for(status).await;
            }
            tokio::time::sleep(poll_interval).await;
        }
    }

    /// Like [`Self::wait`], but gives up once `timeout` has elapsed.
    ///
    /// # Errors
    ///
    /// If `timeout` elapses first, returns a runner error that names the
    /// invocation id. Otherwise returns the same errors as [`Self::wait`].
    pub async fn wait_timeout(
        &self,
        timeout: Duration,
        poll_interval: Duration,
    ) -> RustvelloResult<R> {
        tokio::time::timeout(timeout, self.wait(poll_interval))
            .await
            .map_err(|_| {
                RustvelloError::runner_err(format!(
                    "timeout waiting for invocation {}",
                    self.invocation_id
                ))
            })?
    }

    /// Re-types this handle as `InvocationHandle<String>`, keeping the same
    /// invocation id and backends.
    ///
    /// The stored result is then parsed with `serde_json` as a `String`. That
    /// only succeeds when the stored value is itself a JSON string.
    pub fn into_untyped(self) -> InvocationHandle<String> {
        InvocationHandle {
            invocation_id: self.invocation_id,
            orchestrator: self.orchestrator,
            state_backend: self.state_backend,
            _result_type: PhantomData,
        }
    }

    /// Maps an already-fetched `status` to the invocation's outcome.
    ///
    /// Reads the stored result (on `Success`) or stored error (on `Failed`)
    /// from the state backend. Every other status is reported as
    /// "not finished".
    async fn outcome_for(&self, status: InvocationStatus) -> RustvelloResult<R> {
        match status {
            InvocationStatus::Success => {
                let raw = self
                    .state_backend
                    .get_result(&self.invocation_id)
                    .await?
                    .ok_or_else(|| RustvelloError::Internal {
                        message: format!(
                            "invocation {} has SUCCESS status but no stored result",
                            self.invocation_id
                        ),
                    })?;
                serde_json::from_str(&raw).map_err(|e| RustvelloError::Serialization {
                    message: e.to_string(),
                })
            }
            InvocationStatus::Failed => {
                let err = self.state_backend.get_error(&self.invocation_id).await?;
                Err(RustvelloError::runner_err(err.map_or_else(
                    || "unknown error".to_string(),
                    |e| e.to_string(),
                )))
            }
            other => Err(RustvelloError::Internal {
                message: format!(
                    "invocation {} is not finished (status: {})",
                    self.invocation_id, other
                ),
            }),
        }
    }
}
impl<R: DeserializeOwned> std::fmt::Debug for InvocationHandle<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the invocation id is rendered; the orchestrator and state
        // backend trait objects are left out of the output.
        let mut out = f.debug_struct("InvocationHandle");
        out.field("invocation_id", &self.invocation_id);
        out.finish()
    }
}
/// An invocation whose outcome is already known and held inline.
pub struct SyncInvocation<R> {
    /// Identifier of the invocation.
    invocation_id: InvocationId,
    /// Final status recorded for the invocation.
    status: InvocationStatus,
    /// Either the produced value or the error the invocation ended with.
    result: Result<R, RustvelloError>,
}
impl<R> SyncInvocation<R> {
    /// Builds a completed invocation with status `Success` holding `result`.
    pub fn success(invocation_id: InvocationId, result: R) -> Self {
        let status = InvocationStatus::Success;
        Self {
            status,
            result: Ok(result),
            invocation_id,
        }
    }

    /// Builds a completed invocation with status `Failed` holding `error`.
    pub fn failed(invocation_id: InvocationId, error: RustvelloError) -> Self {
        let status = InvocationStatus::Failed;
        Self {
            status,
            result: Err(error),
            invocation_id,
        }
    }

    /// Returns the identifier of this invocation.
    pub fn invocation_id(&self) -> &InvocationId {
        &self.invocation_id
    }

    /// Returns the recorded status (a copy).
    pub fn status(&self) -> InvocationStatus {
        self.status
    }

    /// Always `true`: a synchronous invocation only exists once finished.
    pub fn is_done(&self) -> bool {
        let finished = true;
        finished
    }
}
impl<R> std::fmt::Debug for SyncInvocation<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The result is not rendered, so `R` needs no `Debug` bound.
        let mut out = f.debug_struct("SyncInvocation");
        out.field("invocation_id", &self.invocation_id);
        out.field("status", &self.status);
        out.finish()
    }
}
/// An invocation whose outcome is either already known ([`Invocation::Sync`])
/// or must be looked up through an orchestrator ([`Invocation::Distributed`]).
#[non_exhaustive]
pub enum Invocation<R: DeserializeOwned> {
    /// Already finished; the outcome is stored inline.
    Sync(SyncInvocation<R>),
    /// Tracked through an orchestrator and state backend; the outcome is
    /// fetched on demand.
    Distributed(InvocationHandle<R>),
}
impl<R: DeserializeOwned> Invocation<R> {
pub fn invocation_id(&self) -> &InvocationId {
match self {
Self::Sync(s) => s.invocation_id(),
Self::Distributed(d) => d.invocation_id(),
}
}
pub async fn status(&self) -> RustvelloResult<InvocationStatus> {
match self {
Self::Sync(s) => Ok(s.status()),
Self::Distributed(d) => d.status().await,
}
}
pub async fn is_done(&self) -> RustvelloResult<bool> {
match self {
Self::Sync(s) => Ok(s.is_done()),
Self::Distributed(d) => d.is_done().await,
}
}
pub async fn result(self) -> RustvelloResult<R> {
match self {
Self::Sync(s) => s.result,
Self::Distributed(d) => d.result().await,
}
}
pub async fn wait(self, poll_interval: Duration) -> RustvelloResult<R> {
match self {
Self::Sync(s) => s.result,
Self::Distributed(d) => d.wait(poll_interval).await,
}
}
pub async fn wait_timeout(
self,
timeout: Duration,
poll_interval: Duration,
) -> RustvelloResult<R> {
match self {
Self::Sync(s) => s.result,
Self::Distributed(d) => d.wait_timeout(timeout, poll_interval).await,
}
}
pub fn is_sync(&self) -> bool {
matches!(self, Self::Sync(_))
}
pub fn is_distributed(&self) -> bool {
matches!(self, Self::Distributed(_))
}
}
impl<R: DeserializeOwned> std::fmt::Debug for Invocation<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render as a one-field tuple labelled with the variant name and
        // delegate to the wrapped value's own `Debug` output.
        let (label, inner): (&str, &dyn std::fmt::Debug) = match self {
            Self::Sync(s) => ("Invocation::Sync", s),
            Self::Distributed(d) => ("Invocation::Distributed", d),
        };
        f.debug_tuple(label).field(inner).finish()
    }
}