Skip to main content

dapr_durabletask/client/
grpc_client.rs

1use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
2
3use crate::api::{DurableTaskError, OrchestrationState, PurgeInstanceFilter, Result};
4use crate::internal;
5use crate::proto;
6use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
7
8use super::options::ClientOptions;
9
10/// Client for managing orchestrations via a gRPC connection to a sidecar.
11pub struct TaskHubGrpcClient {
12    inner: TaskHubSidecarServiceClient<Channel>,
13    options: ClientOptions,
14}
15
16// ─── Channel construction ────────────────────────────────────────────────────
17
18/// Build a tonic [`Channel`] from the host address and client options,
19/// applying TLS, keepalive, connect timeout, and message-size limits.
20async fn build_channel(host_address: &str, options: &ClientOptions) -> Result<Channel> {
21    const USER_AGENT: &str = concat!("dapr-durabletask/rust/", env!("CARGO_PKG_VERSION"));
22
23    let mut builder = Channel::from_shared(host_address.to_string())
24        .map_err(|e| DurableTaskError::InvalidAddress(e.to_string()))?
25        .user_agent(USER_AGENT)
26        .map_err(|e| DurableTaskError::InvalidAddress(e.to_string()))?;
27
28    if let Some(tls) = &options.tls {
29        if tls.skip_verify {
30            return Err(DurableTaskError::Other(
31                "skip_verify is not supported; connect without TLS for development".into(),
32            ));
33        }
34
35        let mut tls_config = ClientTlsConfig::new();
36
37        if let Some(ca_pem) = &tls.ca_cert_pem {
38            tls_config = tls_config.ca_certificate(Certificate::from_pem(ca_pem));
39        }
40
41        match (&tls.client_cert_pem, &tls.client_key_pem) {
42            (Some(cert), Some(key)) => {
43                tls_config = tls_config.identity(Identity::from_pem(cert, key));
44            }
45            (None, None) => {}
46            _ => {
47                return Err(DurableTaskError::Other(
48                    "client_cert_pem and client_key_pem must both be set for mutual TLS".into(),
49                ));
50            }
51        }
52
53        if let Some(domain) = &tls.domain_name {
54            tls_config = tls_config.domain_name(domain.clone());
55        }
56
57        builder = builder
58            .tls_config(tls_config)
59            .map_err(|e| DurableTaskError::ConnectionFailed(e.to_string()))?;
60    }
61
62    if let Some(timeout) = options.connect_timeout {
63        builder = builder.connect_timeout(timeout);
64    }
65
66    if let Some(interval) = options.keepalive_interval {
67        builder = builder.tcp_keepalive(Some(interval));
68    }
69
70    builder
71        .connect()
72        .await
73        .map_err(|e| DurableTaskError::ConnectionFailed(e.to_string()))
74}
75
76/// Wrap a channel in the gRPC client stub, applying the max-message-size limit.
77fn make_stub(channel: Channel, options: &ClientOptions) -> TaskHubSidecarServiceClient<Channel> {
78    let mut stub = TaskHubSidecarServiceClient::new(channel);
79    if let Some(size) = options.max_grpc_message_size {
80        stub = stub.max_decoding_message_size(size);
81    }
82    stub
83}
84
85// ─── TaskHubGrpcClient ───────────────────────────────────────────────────────
86
87impl TaskHubGrpcClient {
88    /// Create a new client connected to the given host address.
89    ///
90    /// The default address is `http://localhost:4001`.
91    ///
92    /// # Errors
93    /// Returns [`DurableTaskError::InvalidAddress`] if `host_address` is not a
94    /// valid URI, or [`DurableTaskError::ConnectionFailed`] / [`DurableTaskError::GrpcError`]
95    /// if the underlying transport cannot be established.
96    pub async fn new(host_address: &str) -> Result<Self> {
97        Self::with_options(host_address, ClientOptions::default()).await
98    }
99
100    /// Create a new client connected to the given host address with custom options.
101    ///
102    /// # Errors
103    /// Returns [`DurableTaskError::InvalidAddress`] if `host_address` is not a
104    /// valid URI, [`DurableTaskError::Other`] if TLS options are inconsistent
105    /// (e.g. only one of `client_cert_pem` / `client_key_pem` is set, or
106    /// `skip_verify` is requested), or [`DurableTaskError::ConnectionFailed`] /
107    /// [`DurableTaskError::GrpcError`] if the underlying transport cannot be
108    /// established.
109    pub async fn with_options(host_address: &str, options: ClientOptions) -> Result<Self> {
110        tracing::info!(address = %host_address, "Connecting to sidecar");
111        let channel = build_channel(host_address, &options).await?;
112        tracing::info!(address = %host_address, "Client connected");
113        let inner = make_stub(channel, &options);
114        Ok(Self { inner, options })
115    }
116
117    /// Create a client from an existing tonic [`Channel`].
118    pub fn from_channel(channel: Channel) -> Self {
119        let options = ClientOptions::default();
120        let inner = make_stub(channel, &options);
121        Self { inner, options }
122    }
123
124    /// Create a client from an existing tonic [`Channel`] with custom options.
125    pub fn from_channel_with_options(channel: Channel, options: ClientOptions) -> Self {
126        let inner = make_stub(channel, &options);
127        Self { inner, options }
128    }
129
130    /// Close the client, releasing the underlying gRPC channel.
131    ///
132    /// The channel is also released when the client is dropped. This method
133    /// provides an explicit, named alternative.
134    pub fn close(self) {
135        drop(self);
136    }
137
138    /// Schedule a new orchestration instance and return its instance ID.
139    ///
140    /// # Errors
141    /// Returns [`DurableTaskError::Other`] if `orchestrator_name` or
142    /// `instance_id` is empty, exceeds the configured identifier length, or
143    /// contains control characters. Returns [`DurableTaskError::GrpcError`] if
144    /// the sidecar RPC fails.
145    pub async fn schedule_new_orchestration(
146        &mut self,
147        orchestrator_name: &str,
148        input: Option<String>,
149        instance_id: Option<String>,
150        start_at: Option<chrono::DateTime<chrono::Utc>>,
151    ) -> Result<String> {
152        internal::validate_identifier(
153            orchestrator_name,
154            "orchestrator name",
155            self.options.max_identifier_length,
156        )?;
157        let instance_id = instance_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
158        internal::validate_identifier(
159            &instance_id,
160            "instance ID",
161            self.options.max_identifier_length,
162        )?;
163
164        tracing::info!(
165            instance_id = %instance_id,
166            orchestrator = %orchestrator_name,
167            "Scheduling new orchestration"
168        );
169
170        #[cfg(feature = "opentelemetry")]
171        let (parent_trace_context, _otel_ctx) = {
172            let parent_ctx = opentelemetry::Context::current();
173            let ctx = internal::otel::start_create_orchestration_span(
174                &parent_ctx,
175                orchestrator_name,
176                &instance_id,
177            );
178            let sc = opentelemetry::trace::TraceContextExt::span(&ctx)
179                .span_context()
180                .clone();
181            let tc = internal::otel::trace_context_from_span_context(&sc);
182            (tc, ctx)
183        };
184        #[cfg(not(feature = "opentelemetry"))]
185        let parent_trace_context: Option<proto::TraceContext> = None;
186
187        let request = proto::CreateInstanceRequest {
188            instance_id: instance_id.clone(),
189            name: orchestrator_name.to_string(),
190            input,
191            scheduled_start_timestamp: start_at.map(internal::to_timestamp),
192            version: None,
193            execution_id: None,
194            tags: std::collections::HashMap::new(),
195            parent_trace_context,
196        };
197
198        let response = self.inner.start_instance(request).await?;
199        let result_id = response.into_inner().instance_id;
200
201        #[cfg(feature = "opentelemetry")]
202        internal::otel::end_span(&_otel_ctx);
203
204        tracing::debug!(instance_id = %result_id, "Orchestration scheduled");
205        Ok(result_id)
206    }
207
208    /// Get the current state of an orchestration.
209    ///
210    /// # Errors
211    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid, or
212    /// [`DurableTaskError::GrpcError`] if the sidecar RPC fails. The successful
213    /// result is `Ok(None)` if the instance does not exist.
214    pub async fn get_orchestration_state(
215        &mut self,
216        instance_id: &str,
217        fetch_payloads: bool,
218    ) -> Result<Option<OrchestrationState>> {
219        internal::validate_identifier(
220            instance_id,
221            "instance ID",
222            self.options.max_identifier_length,
223        )?;
224        let request = proto::GetInstanceRequest {
225            instance_id: instance_id.to_string(),
226            get_inputs_and_outputs: fetch_payloads,
227        };
228        let response = self.inner.get_instance(request).await?;
229        Ok(OrchestrationState::try_from(&response.into_inner()).ok())
230    }
231
232    /// Wait for an orchestration to start running.
233    ///
234    /// # Errors
235    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid,
236    /// [`DurableTaskError::Timeout`] if `timeout` elapses before the instance
237    /// starts, or [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
238    pub async fn wait_for_orchestration_start(
239        &mut self,
240        instance_id: &str,
241        fetch_payloads: bool,
242        timeout: Option<std::time::Duration>,
243    ) -> Result<Option<OrchestrationState>> {
244        internal::validate_identifier(
245            instance_id,
246            "instance ID",
247            self.options.max_identifier_length,
248        )?;
249        tracing::debug!(instance_id = %instance_id, "Waiting for orchestration to start");
250
251        let request = proto::GetInstanceRequest {
252            instance_id: instance_id.to_string(),
253            get_inputs_and_outputs: fetch_payloads,
254        };
255
256        let fut = self.inner.wait_for_instance_start(request);
257
258        let response = if let Some(timeout_dur) = timeout {
259            tokio::time::timeout(timeout_dur, fut)
260                .await
261                .map_err(|_| DurableTaskError::Timeout)??
262        } else {
263            fut.await?
264        };
265
266        let state = OrchestrationState::try_from(&response.into_inner()).ok();
267        tracing::debug!(
268            instance_id = %instance_id,
269            status = ?state.as_ref().map(|s| &s.runtime_status),
270            "Orchestration started"
271        );
272        Ok(state)
273    }
274
275    /// Wait for an orchestration to reach a terminal state.
276    ///
277    /// # Errors
278    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid,
279    /// [`DurableTaskError::Timeout`] if `timeout` elapses before completion, or
280    /// [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
281    pub async fn wait_for_orchestration_completion(
282        &mut self,
283        instance_id: &str,
284        fetch_payloads: bool,
285        timeout: Option<std::time::Duration>,
286    ) -> Result<Option<OrchestrationState>> {
287        internal::validate_identifier(
288            instance_id,
289            "instance ID",
290            self.options.max_identifier_length,
291        )?;
292        tracing::debug!(instance_id = %instance_id, "Waiting for orchestration completion");
293
294        let request = proto::GetInstanceRequest {
295            instance_id: instance_id.to_string(),
296            get_inputs_and_outputs: fetch_payloads,
297        };
298
299        let fut = self.inner.wait_for_instance_completion(request);
300
301        let response = if let Some(timeout_dur) = timeout {
302            tokio::time::timeout(timeout_dur, fut)
303                .await
304                .map_err(|_| DurableTaskError::Timeout)??
305        } else {
306            fut.await?
307        };
308
309        let state = OrchestrationState::try_from(&response.into_inner()).ok();
310        tracing::debug!(
311            instance_id = %instance_id,
312            status = ?state.as_ref().map(|s| &s.runtime_status),
313            "Orchestration completed"
314        );
315        Ok(state)
316    }
317
318    /// Raise an event to an orchestration instance.
319    ///
320    /// # Errors
321    /// Returns [`DurableTaskError::Other`] if `instance_id` or `event_name` is
322    /// invalid, or [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
323    pub async fn raise_orchestration_event(
324        &mut self,
325        instance_id: &str,
326        event_name: &str,
327        data: Option<String>,
328    ) -> Result<()> {
329        internal::validate_identifier(
330            instance_id,
331            "instance ID",
332            self.options.max_identifier_length,
333        )?;
334        internal::validate_identifier(
335            event_name,
336            "event name",
337            self.options.max_identifier_length,
338        )?;
339        tracing::info!(
340            instance_id = %instance_id,
341            event_name = %event_name,
342            "Raising orchestration event"
343        );
344        let request = proto::RaiseEventRequest {
345            instance_id: instance_id.to_string(),
346            name: event_name.to_string(),
347            input: data,
348        };
349        self.inner.raise_event(request).await?;
350        Ok(())
351    }
352
353    /// Terminate a running orchestration.
354    ///
355    /// # Errors
356    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid, or
357    /// [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
358    pub async fn terminate_orchestration(
359        &mut self,
360        instance_id: &str,
361        output: Option<String>,
362        recursive: bool,
363    ) -> Result<()> {
364        internal::validate_identifier(
365            instance_id,
366            "instance ID",
367            self.options.max_identifier_length,
368        )?;
369        tracing::info!(
370            instance_id = %instance_id,
371            recursive = recursive,
372            "Terminating orchestration"
373        );
374        let request = proto::TerminateRequest {
375            instance_id: instance_id.to_string(),
376            output,
377            recursive,
378        };
379        self.inner.terminate_instance(request).await?;
380        Ok(())
381    }
382
383    /// Suspend a running orchestration.
384    ///
385    /// # Errors
386    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid, or
387    /// [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
388    pub async fn suspend_orchestration(
389        &mut self,
390        instance_id: &str,
391        reason: Option<String>,
392    ) -> Result<()> {
393        internal::validate_identifier(
394            instance_id,
395            "instance ID",
396            self.options.max_identifier_length,
397        )?;
398        tracing::info!(instance_id = %instance_id, "Suspending orchestration");
399        let request = proto::SuspendRequest {
400            instance_id: instance_id.to_string(),
401            reason,
402        };
403        self.inner.suspend_instance(request).await?;
404        Ok(())
405    }
406
407    /// Resume a suspended orchestration.
408    ///
409    /// # Errors
410    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid, or
411    /// [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
412    pub async fn resume_orchestration(
413        &mut self,
414        instance_id: &str,
415        reason: Option<String>,
416    ) -> Result<()> {
417        internal::validate_identifier(
418            instance_id,
419            "instance ID",
420            self.options.max_identifier_length,
421        )?;
422        tracing::info!(instance_id = %instance_id, "Resuming orchestration");
423        let request = proto::ResumeRequest {
424            instance_id: instance_id.to_string(),
425            reason,
426        };
427        self.inner.resume_instance(request).await?;
428        Ok(())
429    }
430
431    /// Purge an orchestration's history and state by instance ID.
432    ///
433    /// Returns the number of deleted instances.
434    ///
435    /// # Errors
436    /// Returns [`DurableTaskError::Other`] if `instance_id` is invalid, or
437    /// [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
438    pub async fn purge_orchestration(&mut self, instance_id: &str, recursive: bool) -> Result<i32> {
439        internal::validate_identifier(
440            instance_id,
441            "instance ID",
442            self.options.max_identifier_length,
443        )?;
444        tracing::info!(instance_id = %instance_id, "Purging orchestration");
445        let request = proto::PurgeInstancesRequest {
446            request: Some(proto::purge_instances_request::Request::InstanceId(
447                instance_id.to_string(),
448            )),
449            recursive,
450            force: None,
451        };
452        let response = self.inner.purge_instances(request).await?;
453        let count = response.into_inner().deleted_instance_count;
454        tracing::debug!(instance_id = %instance_id, deleted = count, "Purge complete");
455        Ok(count)
456    }
457
458    /// Purge orchestrations matching the given filter criteria.
459    ///
460    /// Returns the number of deleted instances.
461    ///
462    /// # Examples
463    ///
464    /// ```rust,no_run
465    /// use dapr_durabletask::api::{OrchestrationStatus, PurgeInstanceFilter};
466    ///
467    /// # async fn example(mut client: dapr_durabletask::client::TaskHubGrpcClient) {
468    /// let filter = PurgeInstanceFilter::new()
469    ///     .with_created_time_from(chrono::Utc::now() - chrono::Duration::hours(24))
470    ///     .with_runtime_status([OrchestrationStatus::Completed, OrchestrationStatus::Failed]);
471    ///
472    /// let deleted = client.purge_orchestrations_by_filter(filter, false).await.unwrap();
473    /// println!("Deleted {deleted} orchestrations");
474    /// # }
475    /// ```
476    ///
477    /// # Errors
478    /// Returns [`DurableTaskError::GrpcError`] if the sidecar RPC fails.
479    pub async fn purge_orchestrations_by_filter(
480        &mut self,
481        filter: PurgeInstanceFilter,
482        recursive: bool,
483    ) -> Result<i32> {
484        tracing::info!(?filter, "Purging orchestrations by filter");
485        let request = proto::PurgeInstancesRequest {
486            request: Some(
487                proto::purge_instances_request::Request::PurgeInstanceFilter(filter.into_proto()),
488            ),
489            recursive,
490            force: None,
491        };
492        let response = self.inner.purge_instances(request).await?;
493        let count = response.into_inner().deleted_instance_count;
494        tracing::debug!(deleted = count, "Purge by filter complete");
495        Ok(count)
496    }
497}