mod poll_buffer;
pub use poll_buffer::PollWorkflowTaskBuffer;
use crate::protos::temporal::api::workflowservice::v1::{
RecordActivityTaskHeartbeatRequest, RecordActivityTaskHeartbeatResponse,
};
use crate::{
machines::ProtoCommand,
protos::temporal::api::{
common::v1::{Payloads, WorkflowExecution, WorkflowType},
enums::v1::{TaskQueueKind, WorkflowTaskFailedCause},
failure::v1::Failure,
taskqueue::v1::TaskQueue,
workflowservice::v1::{
workflow_service_client::WorkflowServiceClient, PollActivityTaskQueueRequest,
PollActivityTaskQueueResponse, PollWorkflowTaskQueueRequest,
PollWorkflowTaskQueueResponse, RespondActivityTaskCanceledRequest,
RespondActivityTaskCanceledResponse, RespondActivityTaskCompletedRequest,
RespondActivityTaskCompletedResponse, RespondActivityTaskFailedRequest,
RespondActivityTaskFailedResponse, RespondWorkflowTaskCompletedRequest,
RespondWorkflowTaskCompletedResponse, RespondWorkflowTaskFailedRequest,
RespondWorkflowTaskFailedResponse, SignalWorkflowExecutionRequest,
SignalWorkflowExecutionResponse, StartWorkflowExecutionRequest,
StartWorkflowExecutionResponse,
},
},
CoreInitError,
};
use std::time::Duration;
use tonic::{transport::Channel, Request, Status};
use url::Url;
use uuid::Uuid;
/// Result type used by the gateway APIs; the error type defaults to a gRPC [`Status`].
pub type Result<T, E = Status> = core::result::Result<T, E>;
/// Settings for connecting to, and issuing requests against, the Temporal server.
#[derive(Clone, Debug)]
pub struct ServerGatewayOptions {
    /// Server endpoint; its string form is handed to `Channel::from_shared` by
    /// [`ServerGatewayOptions::connect`].
    pub target_url: Url,
    /// Namespace sent with every gateway request except `start_workflow`, which takes its
    /// namespace as an argument.
    pub namespace: String,
    /// Name of the task queue polled for both workflow and activity tasks.
    pub task_queue: String,
    /// Worker identity reported with polls, task completions/failures, heartbeats and signals.
    pub identity: String,
    /// Sent as `binary_checksum` when polling, completing or failing workflow tasks.
    pub worker_binary_id: String,
    /// Used for the `grpc-timeout` header the interceptor attaches to *every* request made
    /// through the connected client, not only to polls.
    pub long_poll_timeout: Duration,
}
impl ServerGatewayOptions {
    /// Opens a gRPC channel to [`Self::target_url`] and wraps it in a [`ServerGateway`].
    ///
    /// The returned gateway's client has the `intercept` interceptor installed, so every
    /// request carries the `grpc-timeout` derived from [`Self::long_poll_timeout`] and a
    /// `client-name` header. A clone of these options is stored in the gateway.
    ///
    /// # Errors
    ///
    /// Returns a [`CoreInitError`] if the target URL cannot be turned into a channel endpoint,
    /// or if establishing the connection fails.
    pub async fn connect(&self) -> Result<ServerGateway, CoreInitError> {
        let channel = Channel::from_shared(self.target_url.to_string())?
            .connect()
            .await?;
        let interceptor = intercept(self);
        let service = WorkflowServiceClient::with_interceptor(channel, interceptor);
        Ok(ServerGateway {
            service,
            opts: self.clone(),
        })
    }
}
fn intercept(opts: &ServerGatewayOptions) -> impl Fn(Request<()>) -> Result<Request<()>, Status> {
let timeout_str = format!("{}m", opts.long_poll_timeout.as_millis());
move |mut req: Request<()>| {
let metadata = req.metadata_mut();
metadata.insert(
"grpc-timeout",
timeout_str
.parse()
.expect("Timeout string construction cannot fail"),
);
metadata.insert(
"client-name",
"core-sdk".parse().expect("Static value is parsable"),
);
Ok(req)
}
}
/// Production [`ServerGatewayApis`] implementation backed by a tonic gRPC client.
pub struct ServerGateway {
    /// Generated workflow-service client; when built by [`ServerGatewayOptions::connect`] it
    /// has the `intercept` interceptor installed.
    pub service: WorkflowServiceClient<tonic::transport::Channel>,
    /// Options the gateway was built from; supplies namespace, task queue, identity and
    /// binary id for outgoing requests.
    pub opts: ServerGatewayOptions,
}
/// Operations core performs against the Temporal server.
///
/// [`ServerGateway`] implements this over gRPC. In test builds `mockall` also generates a
/// `MockServerGatewayApis` from this trait. All errors are surfaced as gRPC [`Status`] values.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait ServerGatewayApis {
    /// Starts an execution of `workflow_type` with id `workflow_id` in `namespace`, whose tasks
    /// go to `task_queue`.
    async fn start_workflow(
        &self,
        namespace: String,
        task_queue: String,
        workflow_id: String,
        workflow_type: String,
    ) -> Result<StartWorkflowExecutionResponse>;
    /// Polls the implementation's configured task queue for a workflow task.
    async fn poll_workflow_task(&self) -> Result<PollWorkflowTaskQueueResponse>;
    /// Polls the implementation's configured task queue for an activity task.
    async fn poll_activity_task(&self) -> Result<PollActivityTaskQueueResponse>;
    /// Reports the workflow task identified by `task_token` as completed with `commands`.
    async fn complete_workflow_task(
        &self,
        task_token: Vec<u8>,
        commands: Vec<ProtoCommand>,
    ) -> Result<RespondWorkflowTaskCompletedResponse>;
    /// Reports the activity task identified by `task_token` as completed with `result`.
    async fn complete_activity_task(
        &self,
        task_token: Vec<u8>,
        result: Option<Payloads>,
    ) -> Result<RespondActivityTaskCompletedResponse>;
    /// Records a heartbeat, with optional `details`, for the activity task `task_token`.
    async fn record_activity_heartbeat(
        &self,
        task_token: Vec<u8>,
        details: Option<Payloads>,
    ) -> Result<RecordActivityTaskHeartbeatResponse>;
    /// Reports the activity task `task_token` as canceled, with optional `details`.
    async fn cancel_activity_task(
        &self,
        task_token: Vec<u8>,
        details: Option<Payloads>,
    ) -> Result<RespondActivityTaskCanceledResponse>;
    /// Reports the activity task `task_token` as failed, with optional `failure` info.
    async fn fail_activity_task(
        &self,
        task_token: Vec<u8>,
        failure: Option<Failure>,
    ) -> Result<RespondActivityTaskFailedResponse>;
    /// Reports the workflow task `task_token` as failed because of `cause`, with optional
    /// `failure` info.
    async fn fail_workflow_task(
        &self,
        task_token: Vec<u8>,
        cause: WorkflowTaskFailedCause,
        failure: Option<Failure>,
    ) -> Result<RespondWorkflowTaskFailedResponse>;
    /// Sends signal `signal_name` with `payloads` to the execution `workflow_id`/`run_id`.
    async fn signal_workflow_execution(
        &self,
        workflow_id: String,
        run_id: String,
        signal_name: String,
        payloads: Option<Payloads>,
    ) -> Result<SignalWorkflowExecutionResponse>;
}
/// A request to poll one kind of task.
///
/// NOTE(review): the `String` payload is presumably the task queue name — confirm against
/// callers; nothing in this file constructs or matches on this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PollTaskRequest {
    /// Poll for a workflow task.
    Workflow(String),
    /// Poll for an activity task.
    Activity(String),
}
/// gRPC-backed implementation of [`ServerGatewayApis`].
///
/// Each method clones `self.service` before issuing its RPC. The generated client's RPC
/// methods take `&mut self`, while this trait only provides `&self`, and tonic clients are
/// designed to be cloned cheaply. Responses are unwrapped with `into_inner`; any `Status`
/// error is propagated unchanged via `?`.
#[async_trait::async_trait]
impl ServerGatewayApis for ServerGateway {
    /// Starts a workflow execution.
    ///
    /// `namespace` and `task_queue` come from the arguments, not from `self.opts`. A fresh
    /// UUID v4 is generated per call and sent as the request id.
    async fn start_workflow(
        &self,
        namespace: String,
        task_queue: String,
        workflow_id: String,
        workflow_type: String,
    ) -> Result<StartWorkflowExecutionResponse> {
        let request_id = Uuid::new_v4().to_string();
        Ok(self
            .service
            .clone()
            .start_workflow_execution(StartWorkflowExecutionRequest {
                namespace,
                workflow_id,
                workflow_type: Some(WorkflowType {
                    name: workflow_type,
                }),
                task_queue: Some(TaskQueue {
                    name: task_queue,
                    // Spelled via the enum, like the other requests in this impl.
                    kind: TaskQueueKind::Unspecified as i32,
                }),
                request_id,
                ..Default::default()
            })
            .await?
            .into_inner())
    }

    /// Polls `opts.task_queue` in `opts.namespace` for a workflow task, reporting
    /// `opts.identity` and sending `opts.worker_binary_id` as the binary checksum.
    async fn poll_workflow_task(&self) -> Result<PollWorkflowTaskQueueResponse> {
        let request = PollWorkflowTaskQueueRequest {
            namespace: self.opts.namespace.clone(),
            task_queue: Some(TaskQueue {
                name: self.opts.task_queue.clone(),
                // NOTE(review): workflow polls send `Unspecified` while activity polls send
                // `Normal`; confirm the asymmetry is intentional.
                kind: TaskQueueKind::Unspecified as i32,
            }),
            identity: self.opts.identity.clone(),
            binary_checksum: self.opts.worker_binary_id.clone(),
        };
        Ok(self
            .service
            .clone()
            .poll_workflow_task_queue(request)
            .await?
            .into_inner())
    }

    /// Polls `opts.task_queue` (as a `Normal` queue) in `opts.namespace` for an activity
    /// task. No task-queue metadata is sent.
    async fn poll_activity_task(&self) -> Result<PollActivityTaskQueueResponse> {
        let request = PollActivityTaskQueueRequest {
            namespace: self.opts.namespace.clone(),
            task_queue: Some(TaskQueue {
                name: self.opts.task_queue.clone(),
                kind: TaskQueueKind::Normal as i32,
            }),
            identity: self.opts.identity.clone(),
            task_queue_metadata: None,
        };
        Ok(self
            .service
            .clone()
            .poll_activity_task_queue(request)
            .await?
            .into_inner())
    }

    /// Responds to workflow task `task_token` with `commands`. Request fields not set here
    /// keep their `Default` values.
    async fn complete_workflow_task(
        &self,
        task_token: Vec<u8>,
        commands: Vec<ProtoCommand>,
    ) -> Result<RespondWorkflowTaskCompletedResponse> {
        let request = RespondWorkflowTaskCompletedRequest {
            task_token,
            commands,
            identity: self.opts.identity.clone(),
            binary_checksum: self.opts.worker_binary_id.clone(),
            namespace: self.opts.namespace.clone(),
            ..Default::default()
        };
        Ok(self
            .service
            .clone()
            .respond_workflow_task_completed(request)
            .await?
            .into_inner())
    }

    /// Reports activity task `task_token` as completed with `result`.
    async fn complete_activity_task(
        &self,
        task_token: Vec<u8>,
        result: Option<Payloads>,
    ) -> Result<RespondActivityTaskCompletedResponse> {
        Ok(self
            .service
            .clone()
            .respond_activity_task_completed(RespondActivityTaskCompletedRequest {
                task_token,
                result,
                identity: self.opts.identity.clone(),
                namespace: self.opts.namespace.clone(),
            })
            .await?
            .into_inner())
    }

    /// Records a heartbeat carrying `details` for activity task `task_token`.
    async fn record_activity_heartbeat(
        &self,
        task_token: Vec<u8>,
        details: Option<Payloads>,
    ) -> Result<RecordActivityTaskHeartbeatResponse> {
        Ok(self
            .service
            .clone()
            .record_activity_task_heartbeat(RecordActivityTaskHeartbeatRequest {
                task_token,
                details,
                identity: self.opts.identity.clone(),
                namespace: self.opts.namespace.clone(),
            })
            .await?
            .into_inner())
    }

    /// Reports activity task `task_token` as canceled, attaching `details`.
    async fn cancel_activity_task(
        &self,
        task_token: Vec<u8>,
        details: Option<Payloads>,
    ) -> Result<RespondActivityTaskCanceledResponse> {
        Ok(self
            .service
            .clone()
            .respond_activity_task_canceled(RespondActivityTaskCanceledRequest {
                task_token,
                details,
                identity: self.opts.identity.clone(),
                namespace: self.opts.namespace.clone(),
            })
            .await?
            .into_inner())
    }

    /// Reports activity task `task_token` as failed with `failure`.
    async fn fail_activity_task(
        &self,
        task_token: Vec<u8>,
        failure: Option<Failure>,
    ) -> Result<RespondActivityTaskFailedResponse> {
        Ok(self
            .service
            .clone()
            .respond_activity_task_failed(RespondActivityTaskFailedRequest {
                task_token,
                failure,
                identity: self.opts.identity.clone(),
                namespace: self.opts.namespace.clone(),
            })
            .await?
            .into_inner())
    }

    /// Reports workflow task `task_token` as failed. `cause` is sent as its `i32`
    /// discriminant, which is how the proto field stores enums.
    async fn fail_workflow_task(
        &self,
        task_token: Vec<u8>,
        cause: WorkflowTaskFailedCause,
        failure: Option<Failure>,
    ) -> Result<RespondWorkflowTaskFailedResponse> {
        let request = RespondWorkflowTaskFailedRequest {
            task_token,
            cause: cause as i32,
            failure,
            identity: self.opts.identity.clone(),
            binary_checksum: self.opts.worker_binary_id.clone(),
            namespace: self.opts.namespace.clone(),
        };
        Ok(self
            .service
            .clone()
            .respond_workflow_task_failed(request)
            .await?
            .into_inner())
    }

    /// Signals execution `workflow_id`/`run_id` in `opts.namespace`. `payloads` becomes the
    /// signal's `input`; other request fields keep their `Default` values.
    async fn signal_workflow_execution(
        &self,
        workflow_id: String,
        run_id: String,
        signal_name: String,
        payloads: Option<Payloads>,
    ) -> Result<SignalWorkflowExecutionResponse> {
        Ok(self
            .service
            .clone()
            .signal_workflow_execution(SignalWorkflowExecutionRequest {
                namespace: self.opts.namespace.clone(),
                workflow_execution: Some(WorkflowExecution {
                    workflow_id,
                    run_id,
                }),
                signal_name,
                input: payloads,
                identity: self.opts.identity.clone(),
                ..Default::default()
            })
            .await?
            .into_inner())
    }
}
/// Hand-written `mockall` mock of [`ServerGatewayApis`], compiled only for tests.
///
/// The `automock`-generated mock on the trait goes through `async_trait`. Here, instead,
/// every method is declared to return `impl Future<Output = ...> + Send`, so expectations
/// supply the returned future itself rather than just its output.
///
/// NOTE(review): presumably this lets tests control *when* a call resolves (e.g. a poll that
/// stays pending) — confirm against the tests that use `MockManualGateway`.
#[cfg(test)]
mod manual_mock {
    use super::*;
    use std::future::Future;
    // The `'a`/`'b` lifetimes and `where 'a: 'b, Self: 'b` bounds mirror the signatures that
    // `async_trait` generates for the trait's methods.
    // NOTE(review): only `poll_workflow_task` ties `'a` to `&'a self`; the others take a plain
    // `&self`. Confirm whether that difference is intentional.
    mockall::mock! {
        pub ManualGateway {}
        impl ServerGatewayApis for ManualGateway {
            fn start_workflow<'a, 'b>(
                &self,
                namespace: String,
                task_queue: String,
                workflow_id: String,
                workflow_type: String,
            ) -> impl Future<Output = Result<StartWorkflowExecutionResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn poll_workflow_task<'a, 'b>(&'a self)
                -> impl Future<Output = Result<PollWorkflowTaskQueueResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn poll_activity_task<'a, 'b>(&self)
                -> impl Future<Output = Result<PollActivityTaskQueueResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn complete_workflow_task<'a, 'b>(
                &self,
                task_token: Vec<u8>,
                commands: Vec<ProtoCommand>,
            ) -> impl Future<Output = Result<RespondWorkflowTaskCompletedResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn complete_activity_task<'a, 'b>(
                &self,
                task_token: Vec<u8>,
                result: Option<Payloads>,
            ) -> impl Future<Output = Result<RespondActivityTaskCompletedResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn cancel_activity_task<'a, 'b>(
                &self,
                task_token: Vec<u8>,
                details: Option<Payloads>,
            ) -> impl Future<Output = Result<RespondActivityTaskCanceledResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn fail_activity_task<'a, 'b>(
                &self,
                task_token: Vec<u8>,
                failure: Option<Failure>,
            ) -> impl Future<Output = Result<RespondActivityTaskFailedResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn fail_workflow_task<'a, 'b>(
                &self,
                task_token: Vec<u8>,
                cause: WorkflowTaskFailedCause,
                failure: Option<Failure>,
            ) -> impl Future<Output = Result<RespondWorkflowTaskFailedResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn signal_workflow_execution<'a, 'b>(
                &self,
                workflow_id: String,
                run_id: String,
                signal_name: String,
                payloads: Option<Payloads>,
            ) -> impl Future<Output = Result<SignalWorkflowExecutionResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
            fn record_activity_heartbeat<'a, 'b>(
                &self,
                task_token: Vec<u8>,
                details: Option<Payloads>,
            ) -> impl Future<Output = Result<RecordActivityTaskHeartbeatResponse>> + Send + 'b
                where 'a: 'b, Self: 'b;
        }
    }
}