1use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
2
3use crate::api::{DurableTaskError, OrchestrationState, PurgeInstanceFilter, Result};
4use crate::internal;
5use crate::proto;
6use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
7
8use super::options::ClientOptions;
9
/// Client for managing orchestrations through a Durable Task sidecar over gRPC.
///
/// RPC methods on this type take `&mut self` because they issue calls through
/// the generated sidecar stub held in `inner`.
pub struct TaskHubGrpcClient {
    /// Generated gRPC stub bound to the channel connected to the sidecar.
    inner: TaskHubSidecarServiceClient<Channel>,
    /// Options this client was built with; `max_identifier_length` is consulted
    /// when validating orchestrator names, instance IDs and event names.
    options: ClientOptions,
}
15
16async fn build_channel(host_address: &str, options: &ClientOptions) -> Result<Channel> {
21 const USER_AGENT: &str = concat!("dapr-durabletask/rust/", env!("CARGO_PKG_VERSION"));
22
23 let mut builder = Channel::from_shared(host_address.to_string())
24 .map_err(|e| DurableTaskError::InvalidAddress(e.to_string()))?
25 .user_agent(USER_AGENT)
26 .map_err(|e| DurableTaskError::InvalidAddress(e.to_string()))?;
27
28 if let Some(tls) = &options.tls {
29 if tls.skip_verify {
30 return Err(DurableTaskError::Other(
31 "skip_verify is not supported; connect without TLS for development".into(),
32 ));
33 }
34
35 let mut tls_config = ClientTlsConfig::new();
36
37 if let Some(ca_pem) = &tls.ca_cert_pem {
38 tls_config = tls_config.ca_certificate(Certificate::from_pem(ca_pem));
39 }
40
41 match (&tls.client_cert_pem, &tls.client_key_pem) {
42 (Some(cert), Some(key)) => {
43 tls_config = tls_config.identity(Identity::from_pem(cert, key));
44 }
45 (None, None) => {}
46 _ => {
47 return Err(DurableTaskError::Other(
48 "client_cert_pem and client_key_pem must both be set for mutual TLS".into(),
49 ));
50 }
51 }
52
53 if let Some(domain) = &tls.domain_name {
54 tls_config = tls_config.domain_name(domain.clone());
55 }
56
57 builder = builder
58 .tls_config(tls_config)
59 .map_err(|e| DurableTaskError::ConnectionFailed(e.to_string()))?;
60 }
61
62 if let Some(timeout) = options.connect_timeout {
63 builder = builder.connect_timeout(timeout);
64 }
65
66 if let Some(interval) = options.keepalive_interval {
67 builder = builder.tcp_keepalive(Some(interval));
68 }
69
70 builder
71 .connect()
72 .await
73 .map_err(|e| DurableTaskError::ConnectionFailed(e.to_string()))
74}
75
76fn make_stub(channel: Channel, options: &ClientOptions) -> TaskHubSidecarServiceClient<Channel> {
78 let mut stub = TaskHubSidecarServiceClient::new(channel);
79 if let Some(size) = options.max_grpc_message_size {
80 stub = stub.max_decoding_message_size(size);
81 }
82 stub
83}
84
85impl TaskHubGrpcClient {
88 pub async fn new(host_address: &str) -> Result<Self> {
97 Self::with_options(host_address, ClientOptions::default()).await
98 }
99
100 pub async fn with_options(host_address: &str, options: ClientOptions) -> Result<Self> {
110 tracing::info!(address = %host_address, "Connecting to sidecar");
111 let channel = build_channel(host_address, &options).await?;
112 tracing::info!(address = %host_address, "Client connected");
113 let inner = make_stub(channel, &options);
114 Ok(Self { inner, options })
115 }
116
117 pub fn from_channel(channel: Channel) -> Self {
119 let options = ClientOptions::default();
120 let inner = make_stub(channel, &options);
121 Self { inner, options }
122 }
123
124 pub fn from_channel_with_options(channel: Channel, options: ClientOptions) -> Self {
126 let inner = make_stub(channel, &options);
127 Self { inner, options }
128 }
129
130 pub fn close(self) {
135 drop(self);
136 }
137
138 pub async fn schedule_new_orchestration(
146 &mut self,
147 orchestrator_name: &str,
148 input: Option<String>,
149 instance_id: Option<String>,
150 start_at: Option<chrono::DateTime<chrono::Utc>>,
151 ) -> Result<String> {
152 internal::validate_identifier(
153 orchestrator_name,
154 "orchestrator name",
155 self.options.max_identifier_length,
156 )?;
157 let instance_id = instance_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
158 internal::validate_identifier(
159 &instance_id,
160 "instance ID",
161 self.options.max_identifier_length,
162 )?;
163
164 tracing::info!(
165 instance_id = %instance_id,
166 orchestrator = %orchestrator_name,
167 "Scheduling new orchestration"
168 );
169
170 #[cfg(feature = "opentelemetry")]
171 let (parent_trace_context, _otel_ctx) = {
172 let parent_ctx = opentelemetry::Context::current();
173 let ctx = internal::otel::start_create_orchestration_span(
174 &parent_ctx,
175 orchestrator_name,
176 &instance_id,
177 );
178 let sc = opentelemetry::trace::TraceContextExt::span(&ctx)
179 .span_context()
180 .clone();
181 let tc = internal::otel::trace_context_from_span_context(&sc);
182 (tc, ctx)
183 };
184 #[cfg(not(feature = "opentelemetry"))]
185 let parent_trace_context: Option<proto::TraceContext> = None;
186
187 let request = proto::CreateInstanceRequest {
188 instance_id: instance_id.clone(),
189 name: orchestrator_name.to_string(),
190 input,
191 scheduled_start_timestamp: start_at.map(internal::to_timestamp),
192 version: None,
193 execution_id: None,
194 tags: std::collections::HashMap::new(),
195 parent_trace_context,
196 };
197
198 let response = self.inner.start_instance(request).await?;
199 let result_id = response.into_inner().instance_id;
200
201 #[cfg(feature = "opentelemetry")]
202 internal::otel::end_span(&_otel_ctx);
203
204 tracing::debug!(instance_id = %result_id, "Orchestration scheduled");
205 Ok(result_id)
206 }
207
208 pub async fn get_orchestration_state(
215 &mut self,
216 instance_id: &str,
217 fetch_payloads: bool,
218 ) -> Result<Option<OrchestrationState>> {
219 internal::validate_identifier(
220 instance_id,
221 "instance ID",
222 self.options.max_identifier_length,
223 )?;
224 let request = proto::GetInstanceRequest {
225 instance_id: instance_id.to_string(),
226 get_inputs_and_outputs: fetch_payloads,
227 };
228 let response = self.inner.get_instance(request).await?;
229 Ok(OrchestrationState::try_from(&response.into_inner()).ok())
230 }
231
232 pub async fn wait_for_orchestration_start(
239 &mut self,
240 instance_id: &str,
241 fetch_payloads: bool,
242 timeout: Option<std::time::Duration>,
243 ) -> Result<Option<OrchestrationState>> {
244 internal::validate_identifier(
245 instance_id,
246 "instance ID",
247 self.options.max_identifier_length,
248 )?;
249 tracing::debug!(instance_id = %instance_id, "Waiting for orchestration to start");
250
251 let request = proto::GetInstanceRequest {
252 instance_id: instance_id.to_string(),
253 get_inputs_and_outputs: fetch_payloads,
254 };
255
256 let fut = self.inner.wait_for_instance_start(request);
257
258 let response = if let Some(timeout_dur) = timeout {
259 tokio::time::timeout(timeout_dur, fut)
260 .await
261 .map_err(|_| DurableTaskError::Timeout)??
262 } else {
263 fut.await?
264 };
265
266 let state = OrchestrationState::try_from(&response.into_inner()).ok();
267 tracing::debug!(
268 instance_id = %instance_id,
269 status = ?state.as_ref().map(|s| &s.runtime_status),
270 "Orchestration started"
271 );
272 Ok(state)
273 }
274
275 pub async fn wait_for_orchestration_completion(
282 &mut self,
283 instance_id: &str,
284 fetch_payloads: bool,
285 timeout: Option<std::time::Duration>,
286 ) -> Result<Option<OrchestrationState>> {
287 internal::validate_identifier(
288 instance_id,
289 "instance ID",
290 self.options.max_identifier_length,
291 )?;
292 tracing::debug!(instance_id = %instance_id, "Waiting for orchestration completion");
293
294 let request = proto::GetInstanceRequest {
295 instance_id: instance_id.to_string(),
296 get_inputs_and_outputs: fetch_payloads,
297 };
298
299 let fut = self.inner.wait_for_instance_completion(request);
300
301 let response = if let Some(timeout_dur) = timeout {
302 tokio::time::timeout(timeout_dur, fut)
303 .await
304 .map_err(|_| DurableTaskError::Timeout)??
305 } else {
306 fut.await?
307 };
308
309 let state = OrchestrationState::try_from(&response.into_inner()).ok();
310 tracing::debug!(
311 instance_id = %instance_id,
312 status = ?state.as_ref().map(|s| &s.runtime_status),
313 "Orchestration completed"
314 );
315 Ok(state)
316 }
317
318 pub async fn raise_orchestration_event(
324 &mut self,
325 instance_id: &str,
326 event_name: &str,
327 data: Option<String>,
328 ) -> Result<()> {
329 internal::validate_identifier(
330 instance_id,
331 "instance ID",
332 self.options.max_identifier_length,
333 )?;
334 internal::validate_identifier(
335 event_name,
336 "event name",
337 self.options.max_identifier_length,
338 )?;
339 tracing::info!(
340 instance_id = %instance_id,
341 event_name = %event_name,
342 "Raising orchestration event"
343 );
344 let request = proto::RaiseEventRequest {
345 instance_id: instance_id.to_string(),
346 name: event_name.to_string(),
347 input: data,
348 };
349 self.inner.raise_event(request).await?;
350 Ok(())
351 }
352
353 pub async fn terminate_orchestration(
359 &mut self,
360 instance_id: &str,
361 output: Option<String>,
362 recursive: bool,
363 ) -> Result<()> {
364 internal::validate_identifier(
365 instance_id,
366 "instance ID",
367 self.options.max_identifier_length,
368 )?;
369 tracing::info!(
370 instance_id = %instance_id,
371 recursive = recursive,
372 "Terminating orchestration"
373 );
374 let request = proto::TerminateRequest {
375 instance_id: instance_id.to_string(),
376 output,
377 recursive,
378 };
379 self.inner.terminate_instance(request).await?;
380 Ok(())
381 }
382
383 pub async fn suspend_orchestration(
389 &mut self,
390 instance_id: &str,
391 reason: Option<String>,
392 ) -> Result<()> {
393 internal::validate_identifier(
394 instance_id,
395 "instance ID",
396 self.options.max_identifier_length,
397 )?;
398 tracing::info!(instance_id = %instance_id, "Suspending orchestration");
399 let request = proto::SuspendRequest {
400 instance_id: instance_id.to_string(),
401 reason,
402 };
403 self.inner.suspend_instance(request).await?;
404 Ok(())
405 }
406
407 pub async fn resume_orchestration(
413 &mut self,
414 instance_id: &str,
415 reason: Option<String>,
416 ) -> Result<()> {
417 internal::validate_identifier(
418 instance_id,
419 "instance ID",
420 self.options.max_identifier_length,
421 )?;
422 tracing::info!(instance_id = %instance_id, "Resuming orchestration");
423 let request = proto::ResumeRequest {
424 instance_id: instance_id.to_string(),
425 reason,
426 };
427 self.inner.resume_instance(request).await?;
428 Ok(())
429 }
430
431 pub async fn purge_orchestration(&mut self, instance_id: &str, recursive: bool) -> Result<i32> {
439 internal::validate_identifier(
440 instance_id,
441 "instance ID",
442 self.options.max_identifier_length,
443 )?;
444 tracing::info!(instance_id = %instance_id, "Purging orchestration");
445 let request = proto::PurgeInstancesRequest {
446 request: Some(proto::purge_instances_request::Request::InstanceId(
447 instance_id.to_string(),
448 )),
449 recursive,
450 force: None,
451 };
452 let response = self.inner.purge_instances(request).await?;
453 let count = response.into_inner().deleted_instance_count;
454 tracing::debug!(instance_id = %instance_id, deleted = count, "Purge complete");
455 Ok(count)
456 }
457
458 pub async fn purge_orchestrations_by_filter(
480 &mut self,
481 filter: PurgeInstanceFilter,
482 recursive: bool,
483 ) -> Result<i32> {
484 tracing::info!(?filter, "Purging orchestrations by filter");
485 let request = proto::PurgeInstancesRequest {
486 request: Some(
487 proto::purge_instances_request::Request::PurgeInstanceFilter(filter.into_proto()),
488 ),
489 recursive,
490 force: None,
491 };
492 let response = self.inner.purge_instances(request).await?;
493 let count = response.into_inner().deleted_instance_count;
494 tracing::debug!(deleted = count, "Purge by filter complete");
495 Ok(count)
496 }
497}