#![warn(missing_docs)]
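//! Client for communicating with a Temporal server over gRPC: establishing connections,
//! starting and inspecting workflows, and calling the various Temporal services.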
#[macro_use]
extern crate tracing;
mod async_activity_handle;
pub mod callback_based;
mod dns;
#[cfg(feature = "envconfig")]
pub mod envconfig;
pub mod errors;
pub mod grpc;
mod metrics;
mod options_structs;
#[doc(hidden)]
pub mod proxy;
mod replaceable;
pub mod request_extensions;
mod retry;
pub mod schedules;
pub mod worker;
mod workflow_handle;
pub use crate::{
proxy::HttpConnectProxyOptions,
retry::{CallType, RETRYABLE_ERROR_CODES},
};
pub use async_activity_handle::{
ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle,
};
pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME};
pub use options_structs::*;
pub use replaceable::SharedReplaceableClient;
pub use retry::RetryOptions;
pub use tonic;
pub use workflow_handle::{
UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle,
WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle,
WorkflowHistory, WorkflowUpdateHandle,
};
use crate::{
grpc::{
AttachMetricLabels, CloudService, HealthService, OperatorService, TestService,
WorkflowService,
},
metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
request_extensions::RequestExt,
worker::ClientWorkerSet,
};
use errors::*;
use futures_util::{stream, stream::Stream};
use http::Uri;
use parking_lot::RwLock;
use std::{
collections::{HashMap, VecDeque},
fmt::Debug,
pin::Pin,
str::FromStr,
sync::{Arc, OnceLock},
task::{Context, Poll},
time::{Duration, SystemTime},
};
use temporalio_common::{
HasWorkflowDefinition,
data_converters::{
DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
SerializationContextData,
},
protos::{
coresdk::IntoPayloadsExt,
grpc::health::v1::health_client::HealthClient,
proto_ts_to_system_time,
temporal::api::{
cloud::cloudservice::v1::cloud_service_client::CloudServiceClient,
common::v1::{Memo, Payload, SearchAttributes, WorkflowType},
enums::v1::{TaskQueueKind, WorkflowExecutionStatus},
errordetails::v1::WorkflowExecutionAlreadyStartedFailure,
operatorservice::v1::operator_service_client::OperatorServiceClient,
sdk::v1::UserMetadata,
taskqueue::v1::TaskQueue,
testservice::v1::test_service_client::TestServiceClient,
workflow::v1 as workflow,
workflowservice::v1::{
count_workflow_executions_response, workflow_service_client::WorkflowServiceClient,
*,
},
},
utilities::decode_status_detail,
},
};
use tonic::{
Code, IntoRequest,
body::Body,
client::GrpcService,
codegen::InterceptedService,
metadata::{
AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
MetadataValue,
},
service::Interceptor,
transport::{Certificate, Channel, Endpoint, Identity},
};
use tower::ServiceBuilder;
use uuid::Uuid;
static CLIENT_NAME_HEADER_KEY: &str = "client-name";
static CLIENT_VERSION_HEADER_KEY: &str = "client-version";
static TEMPORAL_NAMESPACE_HEADER_KEY: &str = "temporal-namespace";
#[doc(hidden)]
pub static MESSAGE_TOO_LARGE_KEY: &str = "message-too-large";
#[doc(hidden)]
pub static ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT: &str = "short-circuit";
const LONG_POLL_TIMEOUT: Duration = Duration::from_secs(70);
const OTHER_CALL_TIMEOUT: Duration = Duration::from_secs(30);
const VERSION: &str = env!("CARGO_PKG_VERSION");
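/// A connection to a Temporal server, shared by one or more clients.
///
/// Cloning is cheap: all clones share the same underlying channel, headers, and worker set.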
#[derive(Clone)]
pub struct Connection {
inner: Arc<ConnectionInner>,
}
#[derive(Clone)]
struct ConnectionInner {
service: TemporalServiceClient,
retry_options: RetryOptions,
identity: String,
headers: Arc<RwLock<ClientHeaders>>,
client_name: String,
client_version: String,
capabilities: Option<get_system_info_response::Capabilities>,
workers: Arc<ClientWorkerSet>,
_dns_task: Option<Arc<dns::DnsReresolutionHandle>>,
}
impl Connection {
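    /// Establishes a connection to the Temporal server described by `options`.
    ///
    /// Unless `skip_get_system_info` is set, the server's capabilities are fetched as part of
    /// connecting and exposed via [`Connection::capabilities`].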
pub async fn connect(options: ConnectionOptions) -> Result<Self, ClientConnectError> {
let dns_lb_opts = dns::validate_and_get_dns_lb(&options)?.cloned();
let (service, dns_task) = if let Some(service_override) = options.service_override {
(
GrpcMetricSvc {
inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
metrics: options.metrics_meter.clone().map(MetricsContext::new),
disable_errcode_label: options.disable_error_code_metric_tags,
},
None,
)
} else if let Some(dns_opts) = &dns_lb_opts {
let (channel, sender) = dns::create_balanced_channel(&options).await?;
let handle = dns::spawn_dns_reresolution(
sender,
options.target.clone(),
options.tls_options.clone(),
options.keep_alive.clone(),
options.override_origin.clone(),
dns_opts.resolution_interval,
);
(
ServiceBuilder::new()
.layer_fn(move |channel| GrpcMetricSvc {
inner: ChannelOrGrpcOverride::Channel(channel),
metrics: options.metrics_meter.clone().map(MetricsContext::new),
disable_errcode_label: options.disable_error_code_metric_tags,
})
.service(channel),
Some(handle),
)
} else {
let channel = Channel::from_shared(options.target.to_string())?;
let channel = add_tls_to_channel(options.tls_options.as_ref(), channel).await?;
let channel = if let Some(keep_alive) = options.keep_alive.as_ref() {
channel
.keep_alive_while_idle(true)
.http2_keep_alive_interval(keep_alive.interval)
.keep_alive_timeout(keep_alive.timeout)
} else {
channel
};
let channel = if let Some(origin) = options.override_origin.clone() {
channel.origin(origin)
} else {
channel
};
let channel = if let Some(proxy) = options.http_connect_proxy.as_ref() {
proxy.connect_endpoint(&channel).await?
} else {
channel.connect().await?
};
(
ServiceBuilder::new()
.layer_fn(move |channel| GrpcMetricSvc {
inner: ChannelOrGrpcOverride::Channel(channel),
metrics: options.metrics_meter.clone().map(MetricsContext::new),
disable_errcode_label: options.disable_error_code_metric_tags,
})
.service(channel),
None,
)
};
let headers = Arc::new(RwLock::new(ClientHeaders {
user_headers: parse_ascii_headers(options.headers.clone().unwrap_or_default())?,
user_binary_headers: parse_binary_headers(
options.binary_headers.clone().unwrap_or_default(),
)?,
api_key: options.api_key.clone(),
}));
let interceptor = ServiceCallInterceptor {
client_name: options.client_name.clone(),
client_version: options.client_version.clone(),
headers: headers.clone(),
};
let svc = InterceptedService::new(service, interceptor);
let mut svc_client = TemporalServiceClient::new(svc);
let capabilities = if !options.skip_get_system_info {
match svc_client
.get_system_info(GetSystemInfoRequest::default().into_request())
.await
{
Ok(sysinfo) => sysinfo.into_inner().capabilities,
Err(status) => match status.code() {
Code::Unimplemented => None,
_ => return Err(ClientConnectError::SystemInfoCallError(status)),
},
}
} else {
None
};
Ok(Self {
inner: Arc::new(ConnectionInner {
service: svc_client,
retry_options: options.retry_options,
identity: options.identity,
headers,
client_name: options.client_name,
client_version: options.client_version,
capabilities,
workers: Arc::new(ClientWorkerSet::new()),
_dns_task: dns_task,
}),
})
}
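    /// Sets (or clears, when `None`) the API key sent as a bearer token in the
    /// `authorization` header of every request.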
pub fn set_api_key(&self, api_key: Option<String>) {
self.inner.headers.write().api_key = api_key;
}
pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
self.inner.headers.write().user_headers = parse_ascii_headers(headers)?;
Ok(())
}
pub fn set_binary_headers(
&self,
binary_headers: HashMap<String, Vec<u8>>,
) -> Result<(), InvalidHeaderError> {
self.inner.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
Ok(())
}
pub fn client_name(&self) -> &str {
&self.inner.client_name
}
pub fn client_version(&self) -> &str {
&self.inner.client_version
}
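    /// Returns the server capabilities retrieved when the connection was established, if any.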
pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> {
self.inner.capabilities.as_ref()
}
pub fn retry_options_mut(&mut self) -> &mut RetryOptions {
&mut Arc::make_mut(&mut self.inner).retry_options
}
pub fn identity(&self) -> &str {
&self.inner.identity
}
pub fn identity_mut(&mut self) -> &mut String {
&mut Arc::make_mut(&mut self.inner).identity
}
pub fn workers(&self) -> Arc<ClientWorkerSet> {
self.inner.workers.clone()
}
pub fn worker_grouping_key(&self) -> Uuid {
self.inner.workers.worker_grouping_key()
}
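    /// Returns a client for the Temporal workflow service.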
pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
self.inner.service.workflow_service()
}
pub fn operator_service(&self) -> Box<dyn OperatorService> {
self.inner.service.operator_service()
}
pub fn cloud_service(&self) -> Box<dyn CloudService> {
self.inner.service.cloud_service()
}
pub fn test_service(&self) -> Box<dyn TestService> {
self.inner.service.test_service()
}
pub fn health_service(&self) -> Box<dyn HealthService> {
self.inner.service.health_service()
}
}
#[derive(Debug)]
struct ClientHeaders {
user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
api_key: Option<String>,
}
impl ClientHeaders {
fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
for (key, val) in self.user_headers.iter() {
if !metadata.contains_key(key) {
metadata.insert(key, val.clone());
}
}
for (key, val) in self.user_binary_headers.iter() {
if !metadata.contains_key(key) {
metadata.insert_bin(key, val.clone());
}
}
if let Some(api_key) = &self.api_key {
if !metadata.contains_key("authorization")
&& let Ok(val) = format!("Bearer {api_key}").parse()
{
metadata.insert("authorization", val);
}
}
}
}
async fn add_tls_to_channel(
tls_options: Option<&TlsOptions>,
mut channel: Endpoint,
) -> Result<Endpoint, ClientConnectError> {
if let Some(tls_cfg) = tls_options {
let mut tls = tonic::transport::ClientTlsConfig::new();
if let Some(root_cert) = &tls_cfg.server_root_ca_cert {
let server_root_ca_cert = Certificate::from_pem(root_cert);
tls = tls.ca_certificate(server_root_ca_cert);
} else {
tls = tls.with_native_roots();
}
if let Some(domain) = &tls_cfg.domain {
tls = tls.domain_name(domain);
let uri: Uri = format!("https://{domain}").parse()?;
channel = channel.origin(uri);
}
if let Some(client_opts) = &tls_cfg.client_tls_options {
let client_identity =
Identity::from_pem(&client_opts.client_cert, &client_opts.client_private_key);
tls = tls.identity(client_identity);
}
return channel.tls_config(tls).map_err(Into::into);
}
Ok(channel)
}
fn parse_ascii_headers(
headers: HashMap<String, String>,
) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
let mut parsed_headers = HashMap::with_capacity(headers.len());
for (k, v) in headers.into_iter() {
let key = match AsciiMetadataKey::from_str(&k) {
Ok(key) => key,
Err(err) => {
return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
key: k,
source: err,
});
}
};
let value = match MetadataValue::from_str(&v) {
Ok(value) => value,
Err(err) => {
return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
key: k,
value: v,
source: err,
});
}
};
parsed_headers.insert(key, value);
}
Ok(parsed_headers)
}
fn parse_binary_headers(
headers: HashMap<String, Vec<u8>>,
) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
let mut parsed_headers = HashMap::with_capacity(headers.len());
for (k, v) in headers.into_iter() {
let key = match BinaryMetadataKey::from_str(&k) {
Ok(key) => key,
Err(err) => {
return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
key: k,
source: err,
});
}
};
let value = BinaryMetadataValue::from_bytes(&v);
parsed_headers.insert(key, value);
}
Ok(parsed_headers)
}
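/// A gRPC interceptor that attaches client name/version headers, user-supplied headers, the
/// API key, and a default timeout to every outgoing request.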
#[derive(Clone)]
pub struct ServiceCallInterceptor {
client_name: String,
client_version: String,
headers: Arc<RwLock<ClientHeaders>>,
}
impl Interceptor for ServiceCallInterceptor {
fn call(
&mut self,
mut request: tonic::Request<()>,
) -> Result<tonic::Request<()>, tonic::Status> {
let metadata = request.metadata_mut();
if !metadata.contains_key(CLIENT_NAME_HEADER_KEY) {
metadata.insert(
CLIENT_NAME_HEADER_KEY,
self.client_name
.parse()
.unwrap_or_else(|_| MetadataValue::from_static("")),
);
}
if !metadata.contains_key(CLIENT_VERSION_HEADER_KEY) {
metadata.insert(
CLIENT_VERSION_HEADER_KEY,
self.client_version
.parse()
.unwrap_or_else(|_| MetadataValue::from_static("")),
);
}
self.headers.read().apply_to_metadata(metadata);
request.set_default_timeout(OTHER_CALL_TIMEOUT);
Ok(request)
}
}
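/// Bundles the generated gRPC clients for each Temporal service behind one type.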
#[derive(Clone)]
pub struct TemporalServiceClient {
workflow_svc_client: Box<dyn WorkflowService>,
operator_svc_client: Box<dyn OperatorService>,
cloud_svc_client: Box<dyn CloudService>,
test_svc_client: Box<dyn TestService>,
health_svc_client: Box<dyn HealthService>,
}
fn get_decode_max_size() -> usize {
    // Cached limit for incoming gRPC message size, overridable via the
    // `TEMPORAL_MAX_INCOMING_GRPC_BYTES` env var; defaults to 128 MiB.
    static DECODE_MAX_SIZE: OnceLock<usize> = OnceLock::new();
    *DECODE_MAX_SIZE.get_or_init(|| {
        std::env::var("TEMPORAL_MAX_INCOMING_GRPC_BYTES")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(128 * 1024 * 1024)
    })
}
impl TemporalServiceClient {
fn new<T>(svc: T) -> Self
where
T: GrpcService<Body> + Send + Sync + Clone + 'static,
T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
T::Error: Into<tonic::codegen::StdError>,
<T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
<T as GrpcService<Body>>::Future: Send,
{
let workflow_svc_client = Box::new(
WorkflowServiceClient::new(svc.clone())
.max_decoding_message_size(get_decode_max_size()),
);
let operator_svc_client = Box::new(
OperatorServiceClient::new(svc.clone())
.max_decoding_message_size(get_decode_max_size()),
);
let cloud_svc_client = Box::new(
CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
);
let test_svc_client = Box::new(
TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
);
let health_svc_client = Box::new(
HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()),
);
Self {
workflow_svc_client,
operator_svc_client,
cloud_svc_client,
test_svc_client,
health_svc_client,
}
}
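    /// Builds a service client from pre-constructed per-service clients, allowing callers to
    /// substitute their own implementations (for example, mocks in tests).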
pub fn from_services(
workflow: Box<dyn WorkflowService>,
operator: Box<dyn OperatorService>,
cloud: Box<dyn CloudService>,
test: Box<dyn TestService>,
health: Box<dyn HealthService>,
) -> Self {
Self {
workflow_svc_client: workflow,
operator_svc_client: operator,
cloud_svc_client: cloud,
test_svc_client: test,
health_svc_client: health,
}
}
pub fn workflow_service(&self) -> Box<dyn WorkflowService> {
self.workflow_svc_client.clone()
}
pub fn operator_service(&self) -> Box<dyn OperatorService> {
self.operator_svc_client.clone()
}
pub fn cloud_service(&self) -> Box<dyn CloudService> {
self.cloud_svc_client.clone()
}
pub fn test_service(&self) -> Box<dyn TestService> {
self.test_svc_client.clone()
}
pub fn health_service(&self) -> Box<dyn HealthService> {
self.health_svc_client.clone()
}
}
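/// A Temporal client scoped to a namespace, combining a [`Connection`] with [`ClientOptions`].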
#[derive(Clone)]
pub struct Client {
connection: Connection,
options: Arc<ClientOptions>,
}
impl Client {
pub fn new(connection: Connection, options: ClientOptions) -> Result<Self, ClientNewError> {
Ok(Client {
connection,
options: Arc::new(options),
})
}
pub fn options(&self) -> &ClientOptions {
&self.options
}
pub fn options_mut(&mut self) -> &mut ClientOptions {
Arc::make_mut(&mut self.options)
}
pub fn connection(&self) -> &Connection {
&self.connection
}
pub fn connection_mut(&mut self) -> &mut Connection {
&mut self.connection
}
}
impl Client {
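    /// Starts a new workflow execution and returns a handle to it.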
pub async fn start_workflow<W>(
&self,
workflow: W,
input: W::Input,
options: WorkflowStartOptions,
) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
where
W: HasWorkflowDefinition,
W::Input: Send,
{
WorkflowClientTrait::start_workflow(self, workflow, input, options).await
}
pub fn get_workflow_handle<W: HasWorkflowDefinition>(
&self,
workflow_id: impl Into<String>,
) -> WorkflowHandle<Self, W> {
WorkflowClientTrait::get_workflow_handle(self, workflow_id)
}
pub fn list_workflows(
&self,
query: impl Into<String>,
opts: WorkflowListOptions,
) -> ListWorkflowsStream {
WorkflowClientTrait::list_workflows(self, query, opts)
}
pub async fn count_workflows(
&self,
query: impl Into<String>,
opts: WorkflowCountOptions,
) -> Result<WorkflowExecutionCount, ClientError> {
WorkflowClientTrait::count_workflows(self, query, opts).await
}
pub fn get_async_activity_handle(
&self,
identifier: ActivityIdentifier,
) -> AsyncActivityHandle<Self> {
WorkflowClientTrait::get_async_activity_handle(self, identifier)
}
}
impl NamespacedClient for Client {
fn namespace(&self) -> String {
self.options.namespace.clone()
}
fn identity(&self) -> String {
self.connection.identity().to_owned()
}
fn data_converter(&self) -> &DataConverter {
&self.options.data_converter
}
}
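/// Identifies a namespace either by name or by ID.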
#[derive(Clone)]
pub enum Namespace {
Name(String),
Id(String),
}
impl Namespace {
pub fn into_describe_namespace_request(self) -> DescribeNamespaceRequest {
let (namespace, id) = match self {
Namespace::Name(n) => (n, "".to_owned()),
Namespace::Id(n) => ("".to_owned(), n),
};
DescribeNamespaceRequest { namespace, id }
}
}
pub(crate) trait WorkflowClientTrait: NamespacedClient {
fn start_workflow<W>(
&self,
workflow: W,
input: W::Input,
options: WorkflowStartOptions,
) -> impl Future<Output = Result<WorkflowHandle<Self, W>, WorkflowStartError>>
where
Self: Sized,
W: HasWorkflowDefinition,
W::Input: Send;
fn get_workflow_handle<W: HasWorkflowDefinition>(
&self,
workflow_id: impl Into<String>,
) -> WorkflowHandle<Self, W>
where
Self: Sized;
fn list_workflows(
&self,
query: impl Into<String>,
opts: WorkflowListOptions,
) -> ListWorkflowsStream;
fn count_workflows(
&self,
query: impl Into<String>,
opts: WorkflowCountOptions,
) -> impl Future<Output = Result<WorkflowExecutionCount, ClientError>>;
fn get_async_activity_handle(
&self,
identifier: ActivityIdentifier,
) -> AsyncActivityHandle<Self>
where
Self: Sized;
}
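/// A client bound to a namespace and an identity, with an associated data converter.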
pub trait NamespacedClient {
fn namespace(&self) -> String;
fn identity(&self) -> String;
fn data_converter(&self) -> &DataConverter {
static DEFAULT: OnceLock<DataConverter> = OnceLock::new();
DEFAULT.get_or_init(DataConverter::default)
}
}
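/// A summary of a workflow execution, wrapping the raw `WorkflowExecutionInfo` proto with
/// convenience accessors.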
#[derive(Debug, Clone)]
pub struct WorkflowExecution {
raw: workflow::WorkflowExecutionInfo,
}
impl WorkflowExecution {
pub fn new(raw: workflow::WorkflowExecutionInfo) -> Self {
Self { raw }
}
pub fn id(&self) -> &str {
self.raw
.execution
.as_ref()
.map(|e| e.workflow_id.as_str())
.unwrap_or("")
}
pub fn run_id(&self) -> &str {
self.raw
.execution
.as_ref()
.map(|e| e.run_id.as_str())
.unwrap_or("")
}
pub fn workflow_type(&self) -> &str {
self.raw
.r#type
.as_ref()
.map(|t| t.name.as_str())
.unwrap_or("")
}
pub fn status(&self) -> WorkflowExecutionStatus {
self.raw.status()
}
pub fn start_time(&self) -> Option<SystemTime> {
self.raw
.start_time
.as_ref()
.and_then(proto_ts_to_system_time)
}
pub fn execution_time(&self) -> Option<SystemTime> {
self.raw
.execution_time
.as_ref()
.and_then(proto_ts_to_system_time)
}
pub fn close_time(&self) -> Option<SystemTime> {
self.raw
.close_time
.as_ref()
.and_then(proto_ts_to_system_time)
}
pub fn task_queue(&self) -> &str {
&self.raw.task_queue
}
pub fn history_length(&self) -> i64 {
self.raw.history_length
}
pub fn memo(&self) -> Option<&Memo> {
self.raw.memo.as_ref()
}
pub fn parent_id(&self) -> Option<&str> {
self.raw
.parent_execution
.as_ref()
.map(|e| e.workflow_id.as_str())
}
pub fn parent_run_id(&self) -> Option<&str> {
self.raw
.parent_execution
.as_ref()
.map(|e| e.run_id.as_str())
}
pub fn search_attributes(&self) -> Option<&SearchAttributes> {
self.raw.search_attributes.as_ref()
}
pub fn raw(&self) -> &workflow::WorkflowExecutionInfo {
&self.raw
}
pub fn into_raw(self) -> workflow::WorkflowExecutionInfo {
self.raw
}
}
impl From<workflow::WorkflowExecutionInfo> for WorkflowExecution {
fn from(raw: workflow::WorkflowExecutionInfo) -> Self {
Self::new(raw)
}
}
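/// A stream of workflow executions returned by `list_workflows`, fetching additional pages
/// from the server as it is polled.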
pub struct ListWorkflowsStream {
inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
}
impl ListWorkflowsStream {
fn new(
inner: Pin<Box<dyn Stream<Item = Result<WorkflowExecution, ClientError>> + Send>>,
) -> Self {
Self { inner }
}
}
impl Stream for ListWorkflowsStream {
type Item = Result<WorkflowExecution, ClientError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.as_mut().poll_next(cx)
}
}
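/// The result of a count-workflows request: a total count plus any aggregation groups.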
#[derive(Debug, Clone)]
pub struct WorkflowExecutionCount {
count: usize,
groups: Vec<WorkflowCountAggregationGroup>,
}
impl WorkflowExecutionCount {
pub(crate) fn from_response(resp: CountWorkflowExecutionsResponse) -> Self {
Self {
count: resp.count as usize,
groups: resp
.groups
.into_iter()
.map(WorkflowCountAggregationGroup::from_proto)
.collect(),
}
}
pub fn count(&self) -> usize {
self.count
}
pub fn groups(&self) -> &[WorkflowCountAggregationGroup] {
&self.groups
}
}
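/// A single aggregation group within a [`WorkflowExecutionCount`].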
#[derive(Debug, Clone)]
pub struct WorkflowCountAggregationGroup {
group_values: Vec<Payload>,
count: usize,
}
impl WorkflowCountAggregationGroup {
fn from_proto(proto: count_workflow_executions_response::AggregationGroup) -> Self {
Self {
group_values: proto.group_values,
count: proto.count as usize,
}
}
pub fn group_values(&self) -> &[Payload] {
&self.group_values
}
pub fn count(&self) -> usize {
self.count
}
}
impl<T> WorkflowClientTrait for T
where
T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static,
{
async fn start_workflow<W>(
&self,
workflow: W,
input: W::Input,
options: WorkflowStartOptions,
) -> Result<WorkflowHandle<Self, W>, WorkflowStartError>
where
W: HasWorkflowDefinition,
W::Input: Send,
{
let payloads = self
.data_converter()
.to_payloads(&SerializationContextData::Workflow, &input)
.await?;
let namespace = self.namespace();
let workflow_id = options.workflow_id.clone();
let task_queue_name = options.task_queue.clone();
let user_metadata = if options.static_summary.is_some() || options.static_details.is_some()
{
let payload_converter = PayloadConverter::default();
let context = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &payload_converter,
};
Some(UserMetadata {
summary: options.static_summary.map(|s| {
payload_converter
.to_payload(&context, &s)
.expect("String-to-JSON payload serialization is infallible")
}),
details: options.static_details.map(|s| {
payload_converter
.to_payload(&context, &s)
.expect("String-to-JSON payload serialization is infallible")
}),
})
} else {
None
};
let run_id = if let Some(start_signal) = options.start_signal {
let res = WorkflowService::signal_with_start_workflow_execution(
&mut self.clone(),
SignalWithStartWorkflowExecutionRequest {
namespace: namespace.clone(),
workflow_id: workflow_id.clone(),
workflow_type: Some(WorkflowType {
name: workflow.name().to_string(),
}),
task_queue: Some(TaskQueue {
name: task_queue_name,
kind: TaskQueueKind::Normal as i32,
normal_name: "".to_string(),
}),
input: payloads.into_payloads(),
signal_name: start_signal.signal_name,
signal_input: start_signal.input,
identity: self.identity(),
request_id: Uuid::new_v4().to_string(),
workflow_id_reuse_policy: options.id_reuse_policy as i32,
workflow_id_conflict_policy: options.id_conflict_policy as i32,
workflow_execution_timeout: options
.execution_timeout
.and_then(|d| d.try_into().ok()),
workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
search_attributes: options.search_attributes.map(|d| d.into()),
cron_schedule: options.cron_schedule.unwrap_or_default(),
header: options.header.or(start_signal.header),
user_metadata,
..Default::default()
}
.into_request(),
)
.await?
.into_inner();
res.run_id
} else {
let res = self
.clone()
.start_workflow_execution(
StartWorkflowExecutionRequest {
namespace: namespace.clone(),
input: payloads.into_payloads(),
workflow_id: workflow_id.clone(),
workflow_type: Some(WorkflowType {
name: workflow.name().to_string(),
}),
task_queue: Some(TaskQueue {
name: task_queue_name,
kind: TaskQueueKind::Unspecified as i32,
normal_name: "".to_string(),
}),
request_id: Uuid::new_v4().to_string(),
workflow_id_reuse_policy: options.id_reuse_policy as i32,
workflow_id_conflict_policy: options.id_conflict_policy as i32,
workflow_execution_timeout: options
.execution_timeout
.and_then(|d| d.try_into().ok()),
workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()),
workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()),
search_attributes: options.search_attributes.map(|d| d.into()),
cron_schedule: options.cron_schedule.unwrap_or_default(),
request_eager_execution: options.enable_eager_workflow_start,
retry_policy: options.retry_policy,
links: options.links,
completion_callbacks: options.completion_callbacks,
priority: Some(options.priority.into()),
header: options.header,
user_metadata,
..Default::default()
}
.into_request(),
)
.await
.map_err(|status| {
if status.code() == Code::AlreadyExists {
let run_id =
decode_status_detail::<WorkflowExecutionAlreadyStartedFailure>(
status.details(),
)
.map(|f| f.run_id);
WorkflowStartError::AlreadyStarted {
run_id,
source: status,
}
} else {
WorkflowStartError::Rpc(status)
}
})?
.into_inner();
res.run_id
};
Ok(WorkflowHandle::new(
self.clone(),
WorkflowExecutionInfo {
namespace,
workflow_id,
run_id: Some(run_id.clone()),
first_execution_run_id: Some(run_id),
},
))
}
fn get_workflow_handle<W: HasWorkflowDefinition>(
&self,
workflow_id: impl Into<String>,
) -> WorkflowHandle<Self, W>
where
Self: Sized,
{
WorkflowHandle::new(
self.clone(),
WorkflowExecutionInfo {
namespace: self.namespace(),
workflow_id: workflow_id.into(),
run_id: None,
first_execution_run_id: None,
},
)
}
fn list_workflows(
&self,
query: impl Into<String>,
opts: WorkflowListOptions,
) -> ListWorkflowsStream {
let client = self.clone();
let namespace = self.namespace();
let query = query.into();
let limit = opts.limit;
let initial_state = (Vec::new(), VecDeque::new(), 0, false);
let stream = stream::unfold(
initial_state,
move |(next_page_token, mut buffer, mut yielded, exhausted)| {
let mut client = client.clone();
let namespace = namespace.clone();
let query = query.clone();
async move {
if let Some(l) = limit
&& yielded >= l
{
return None;
}
if let Some(exec) = buffer.pop_front() {
yielded += 1;
return Some((Ok(exec), (next_page_token, buffer, yielded, exhausted)));
}
if exhausted {
return None;
}
let response = WorkflowService::list_workflow_executions(
&mut client,
ListWorkflowExecutionsRequest {
namespace,
                        page_size: 0,
                        next_page_token: next_page_token.clone(),
query,
}
.into_request(),
)
.await;
match response {
Ok(resp) => {
let resp = resp.into_inner();
let new_exhausted = resp.next_page_token.is_empty();
let new_token = resp.next_page_token;
buffer = resp
.executions
.into_iter()
.map(WorkflowExecution::from)
.collect();
if let Some(exec) = buffer.pop_front() {
yielded += 1;
Some((Ok(exec), (new_token, buffer, yielded, new_exhausted)))
} else {
None
}
}
Err(e) => Some((Err(e.into()), (next_page_token, buffer, yielded, true))),
}
}
},
);
ListWorkflowsStream::new(Box::pin(stream))
}
async fn count_workflows(
&self,
query: impl Into<String>,
_opts: WorkflowCountOptions,
) -> Result<WorkflowExecutionCount, ClientError> {
let resp = WorkflowService::count_workflow_executions(
&mut self.clone(),
CountWorkflowExecutionsRequest {
namespace: self.namespace(),
query: query.into(),
}
.into_request(),
)
.await?
.into_inner();
Ok(WorkflowExecutionCount::from_response(resp))
}
fn get_async_activity_handle(&self, identifier: ActivityIdentifier) -> AsyncActivityHandle<Self>
where
Self: Sized,
{
AsyncActivityHandle::new(self.clone(), identifier)
}
}
macro_rules! dbg_panic {
($($arg:tt)*) => {
use tracing::error;
error!($($arg)*);
debug_assert!(false, $($arg)*);
};
}
pub(crate) use dbg_panic;
#[cfg(test)]
mod tests {
use super::*;
use tonic::metadata::Ascii;
use url::Url;
#[test]
fn applies_headers() {
let headers = Arc::new(RwLock::new(ClientHeaders {
user_headers: HashMap::new(),
user_binary_headers: HashMap::new(),
api_key: Some("my-api-key".to_owned()),
}));
        headers.write().user_headers.insert(
            "my-meta-key".parse().unwrap(),
            "my-meta-val".parse().unwrap(),
        );
        headers.write().user_binary_headers.insert(
            "my-bin-meta-key-bin".parse().unwrap(),
            vec![1, 2, 3].try_into().unwrap(),
        );
let mut interceptor = ServiceCallInterceptor {
client_name: "cute-kitty".to_string(),
client_version: "0.1.0".to_string(),
headers: headers.clone(),
};
let req = interceptor.call(tonic::Request::new(())).unwrap();
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
assert_eq!(
req.metadata().get("authorization").unwrap(),
"Bearer my-api-key"
);
assert_eq!(
req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
vec![1, 2, 3].as_slice()
);
let mut req = tonic::Request::new(());
req.metadata_mut()
.insert("my-meta-key", "my-meta-val2".parse().unwrap());
req.metadata_mut()
.insert("authorization", "my-api-key2".parse().unwrap());
req.metadata_mut()
.insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
let req = interceptor.call(req).unwrap();
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
assert_eq!(
req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
vec![4, 5, 6].as_slice()
);
        headers.write().user_headers.insert(
            "authorization".parse().unwrap(),
            "my-api-key3".parse().unwrap(),
        );
let req = interceptor.call(tonic::Request::new(())).unwrap();
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");
        headers.write().user_headers.clear();
        headers.write().user_binary_headers.clear();
        headers.write().api_key.take();
let req = interceptor.call(tonic::Request::new(())).unwrap();
assert!(!req.metadata().contains_key("my-meta-key"));
assert!(!req.metadata().contains_key("authorization"));
assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));
let mut req = tonic::Request::new(());
req.metadata_mut()
.insert("grpc-timeout", "1S".parse().unwrap());
let req = interceptor.call(req).unwrap();
assert_eq!(
req.metadata().get("grpc-timeout").unwrap(),
"1S".parse::<MetadataValue<Ascii>>().unwrap()
);
}
#[test]
fn invalid_ascii_header_key() {
let invalid_headers = {
let mut h = HashMap::new();
h.insert("x-binary-key-bin".to_owned(), "value".to_owned());
h
};
let result = parse_ascii_headers(invalid_headers);
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
);
}
#[test]
fn invalid_ascii_header_value() {
let invalid_headers = {
let mut h = HashMap::new();
h.insert("x-ascii-key".to_owned(), "\x00value".to_owned());
h
};
let result = parse_ascii_headers(invalid_headers);
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
);
}
#[test]
fn invalid_binary_header_key() {
let invalid_headers = {
let mut h = HashMap::new();
h.insert("x-ascii-key".to_owned(), vec![1, 2, 3]);
h
};
let result = parse_binary_headers(invalid_headers);
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
);
}
#[test]
fn keep_alive_defaults() {
let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
.identity("enchicat".to_string())
.client_name("cute-kitty".to_string())
.client_version("0.1.0".to_string())
.build();
assert_eq!(
opts.keep_alive.clone().unwrap().interval,
ClientKeepAliveOptions::default().interval
);
assert_eq!(
opts.keep_alive.clone().unwrap().timeout,
ClientKeepAliveOptions::default().timeout
);
let opts = ConnectionOptions::new(Url::parse("https://smolkitty").unwrap())
.identity("enchicat".to_string())
.client_name("cute-kitty".to_string())
.client_version("0.1.0".to_string())
.keep_alive(None)
.build();
assert!(opts.keep_alive.is_none());
}
mod list_workflows_tests {
use super::*;
use futures_util::{FutureExt, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use temporalio_common::protos::temporal::api::common::v1::WorkflowExecution as ProtoWorkflowExecution;
use tonic::{Request, Response};
#[derive(Clone)]
struct MockListWorkflowsClient {
call_count: Arc<AtomicUsize>,
page_size: usize,
total_workflows: usize,
}
impl NamespacedClient for MockListWorkflowsClient {
fn namespace(&self) -> String {
"test-namespace".to_string()
}
fn identity(&self) -> String {
"test-identity".to_string()
}
}
impl WorkflowService for MockListWorkflowsClient {
fn list_workflow_executions(
&mut self,
request: Request<ListWorkflowExecutionsRequest>,
) -> futures_util::future::BoxFuture<
'_,
Result<Response<ListWorkflowExecutionsResponse>, tonic::Status>,
> {
self.call_count.fetch_add(1, Ordering::SeqCst);
let req = request.into_inner();
let offset: usize = if req.next_page_token.is_empty() {
0
} else {
String::from_utf8(req.next_page_token)
.unwrap()
.parse()
.unwrap()
};
let remaining = self.total_workflows.saturating_sub(offset);
let count = remaining.min(self.page_size);
let new_offset = offset + count;
let executions: Vec<_> = (offset..offset + count)
.map(|i| workflow::WorkflowExecutionInfo {
execution: Some(ProtoWorkflowExecution {
workflow_id: format!("wf-{i}"),
run_id: format!("run-{i}"),
}),
r#type: Some(WorkflowType {
name: "TestWorkflow".to_string(),
}),
task_queue: "test-queue".to_string(),
..Default::default()
})
.collect();
let next_page_token = if new_offset < self.total_workflows {
new_offset.to_string().into_bytes()
} else {
vec![]
};
async move {
Ok(Response::new(ListWorkflowExecutionsResponse {
executions,
next_page_token,
}))
}
.boxed()
}
}
#[tokio::test]
async fn list_workflows_paginates_through_all_results() {
let call_count = Arc::new(AtomicUsize::new(0));
let client = MockListWorkflowsClient {
call_count: call_count.clone(),
page_size: 3,
total_workflows: 10,
};
let stream = client.list_workflows("", WorkflowListOptions::default());
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 10);
for (i, result) in results.iter().enumerate() {
let wf = result.as_ref().unwrap();
assert_eq!(wf.id(), format!("wf-{i}"));
assert_eq!(wf.run_id(), format!("run-{i}"));
}
assert_eq!(call_count.load(Ordering::SeqCst), 4);
}
#[tokio::test]
async fn list_workflows_respects_limit() {
let call_count = Arc::new(AtomicUsize::new(0));
let client = MockListWorkflowsClient {
call_count: call_count.clone(),
page_size: 3,
total_workflows: 10,
};
let opts = WorkflowListOptions::builder().limit(5).build();
let stream = client.list_workflows("", opts);
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 5);
for (i, result) in results.iter().enumerate() {
let wf = result.as_ref().unwrap();
assert_eq!(wf.id(), format!("wf-{i}"));
}
assert_eq!(call_count.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn list_workflows_limit_less_than_page_size() {
let call_count = Arc::new(AtomicUsize::new(0));
let client = MockListWorkflowsClient {
call_count: call_count.clone(),
page_size: 10,
total_workflows: 100,
};
let opts = WorkflowListOptions::builder().limit(3).build();
let stream = client.list_workflows("", opts);
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 3);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn list_workflows_empty_results() {
let call_count = Arc::new(AtomicUsize::new(0));
let client = MockListWorkflowsClient {
call_count: call_count.clone(),
page_size: 10,
total_workflows: 0,
};
let stream = client.list_workflows("", WorkflowListOptions::default());
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 0);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
}
}