// File: dapr_durabletask/client/grpc_client.rs

use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};

use crate::api::{DurableTaskError, OrchestrationState, PurgeInstanceFilter, Result};
use crate::internal;
use crate::proto;
use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;

use super::options::ClientOptions;

10/// Client for managing orchestrations via a gRPC connection to a sidecar.
11pub struct TaskHubGrpcClient {
12    inner: TaskHubSidecarServiceClient<Channel>,
13    options: ClientOptions,
14}

// ─── Channel construction ────────────────────────────────────────────────────

18/// Build a tonic [`Channel`] from the host address and client options,
19/// applying TLS, keepalive, connect timeout, and message-size limits.
20async fn build_channel(host_address: &str, options: &ClientOptions) -> Result<Channel> {
21    let mut builder = Channel::from_shared(host_address.to_string())
22        .map_err(|e| DurableTaskError::Other(e.to_string()))?;
23
24    if let Some(tls) = &options.tls {
25        if tls.skip_verify {
26            return Err(DurableTaskError::Other(
27                "skip_verify is not supported; connect without TLS for development".into(),
28            ));
29        }
30
31        let mut tls_config = ClientTlsConfig::new();
32
33        if let Some(ca_pem) = &tls.ca_cert_pem {
34            tls_config = tls_config.ca_certificate(Certificate::from_pem(ca_pem));
35        }
36
37        match (&tls.client_cert_pem, &tls.client_key_pem) {
38            (Some(cert), Some(key)) => {
39                tls_config = tls_config.identity(Identity::from_pem(cert, key));
40            }
41            (None, None) => {}
42            _ => {
43                return Err(DurableTaskError::Other(
44                    "client_cert_pem and client_key_pem must both be set for mutual TLS".into(),
45                ));
46            }
47        }
48
49        if let Some(domain) = &tls.domain_name {
50            tls_config = tls_config.domain_name(domain.clone());
51        }
52
53        builder = builder
54            .tls_config(tls_config)
55            .map_err(|e| DurableTaskError::Other(e.to_string()))?;
56    }
57
58    if let Some(timeout) = options.connect_timeout {
59        builder = builder.connect_timeout(timeout);
60    }
61
62    if let Some(interval) = options.keepalive_interval {
63        builder = builder.tcp_keepalive(Some(interval));
64    }
65
66    builder
67        .connect()
68        .await
69        .map_err(|e| DurableTaskError::Other(e.to_string()))
70}
71
72/// Wrap a channel in the gRPC client stub, applying the max-message-size limit.
73fn make_stub(channel: Channel, options: &ClientOptions) -> TaskHubSidecarServiceClient<Channel> {
74    let mut stub = TaskHubSidecarServiceClient::new(channel);
75    if let Some(size) = options.max_grpc_message_size {
76        stub = stub.max_decoding_message_size(size);
77    }
78    stub
79}

// ─── TaskHubGrpcClient ───────────────────────────────────────────────────────

83impl TaskHubGrpcClient {
84    /// Create a new client connected to the given host address.
85    ///
86    /// The default address is `http://localhost:4001`.
87    pub async fn new(host_address: &str) -> Result<Self> {
88        Self::with_options(host_address, ClientOptions::default()).await
89    }
90
91    /// Create a new client connected to the given host address with custom options.
92    pub async fn with_options(host_address: &str, options: ClientOptions) -> Result<Self> {
93        tracing::info!(address = %host_address, "Connecting to sidecar");
94        let channel = build_channel(host_address, &options).await?;
95        tracing::info!(address = %host_address, "Client connected");
96        let inner = make_stub(channel, &options);
97        Ok(Self { inner, options })
98    }
99
100    /// Create a client from an existing tonic [`Channel`].
101    pub fn from_channel(channel: Channel) -> Self {
102        let options = ClientOptions::default();
103        let inner = make_stub(channel, &options);
104        Self { inner, options }
105    }
106
107    /// Create a client from an existing tonic [`Channel`] with custom options.
108    pub fn from_channel_with_options(channel: Channel, options: ClientOptions) -> Self {
109        let inner = make_stub(channel, &options);
110        Self { inner, options }
111    }
112
113    /// Close the client, releasing the underlying gRPC channel.
114    ///
115    /// The channel is also released when the client is dropped. This method
116    /// provides an explicit, named alternative.
117    pub fn close(self) {
118        drop(self);
119    }
120
121    /// Schedule a new orchestration instance and return its instance ID.
122    pub async fn schedule_new_orchestration(
123        &mut self,
124        orchestrator_name: &str,
125        input: Option<String>,
126        instance_id: Option<String>,
127        start_at: Option<chrono::DateTime<chrono::Utc>>,
128    ) -> Result<String> {
129        internal::validate_identifier(
130            orchestrator_name,
131            "orchestrator name",
132            self.options.max_identifier_length,
133        )?;
134        let instance_id = instance_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
135        internal::validate_identifier(
136            &instance_id,
137            "instance ID",
138            self.options.max_identifier_length,
139        )?;
140
141        tracing::info!(
142            instance_id = %instance_id,
143            orchestrator = %orchestrator_name,
144            "Scheduling new orchestration"
145        );
146
147        #[cfg(feature = "opentelemetry")]
148        let (parent_trace_context, _otel_ctx) = {
149            let parent_ctx = opentelemetry::Context::current();
150            let ctx = internal::otel::start_create_orchestration_span(
151                &parent_ctx,
152                orchestrator_name,
153                &instance_id,
154            );
155            let sc = opentelemetry::trace::TraceContextExt::span(&ctx)
156                .span_context()
157                .clone();
158            let tc = internal::otel::trace_context_from_span_context(&sc);
159            (tc, ctx)
160        };
161        #[cfg(not(feature = "opentelemetry"))]
162        let parent_trace_context: Option<proto::TraceContext> = None;
163
164        let request = proto::CreateInstanceRequest {
165            instance_id: instance_id.clone(),
166            name: orchestrator_name.to_string(),
167            input,
168            scheduled_start_timestamp: start_at.map(internal::to_timestamp),
169            version: None,
170            execution_id: None,
171            tags: std::collections::HashMap::new(),
172            parent_trace_context,
173        };
174
175        let response = self.inner.start_instance(request).await?;
176        let result_id = response.into_inner().instance_id;
177
178        #[cfg(feature = "opentelemetry")]
179        internal::otel::end_span(&_otel_ctx);
180
181        tracing::debug!(instance_id = %result_id, "Orchestration scheduled");
182        Ok(result_id)
183    }
184
185    /// Get the current state of an orchestration.
186    pub async fn get_orchestration_state(
187        &mut self,
188        instance_id: &str,
189        fetch_payloads: bool,
190    ) -> Result<Option<OrchestrationState>> {
191        internal::validate_identifier(
192            instance_id,
193            "instance ID",
194            self.options.max_identifier_length,
195        )?;
196        let request = proto::GetInstanceRequest {
197            instance_id: instance_id.to_string(),
198            get_inputs_and_outputs: fetch_payloads,
199        };
200        let response = self.inner.get_instance(request).await?;
201        Ok(OrchestrationState::try_from(&response.into_inner()).ok())
202    }
203
204    /// Wait for an orchestration to start running.
205    pub async fn wait_for_orchestration_start(
206        &mut self,
207        instance_id: &str,
208        fetch_payloads: bool,
209        timeout: Option<std::time::Duration>,
210    ) -> Result<Option<OrchestrationState>> {
211        internal::validate_identifier(
212            instance_id,
213            "instance ID",
214            self.options.max_identifier_length,
215        )?;
216        tracing::debug!(instance_id = %instance_id, "Waiting for orchestration to start");
217
218        let request = proto::GetInstanceRequest {
219            instance_id: instance_id.to_string(),
220            get_inputs_and_outputs: fetch_payloads,
221        };
222
223        let fut = self.inner.wait_for_instance_start(request);
224
225        let response = if let Some(timeout_dur) = timeout {
226            tokio::time::timeout(timeout_dur, fut)
227                .await
228                .map_err(|_| DurableTaskError::Timeout)??
229        } else {
230            fut.await?
231        };
232
233        let state = OrchestrationState::try_from(&response.into_inner()).ok();
234        tracing::debug!(
235            instance_id = %instance_id,
236            status = ?state.as_ref().map(|s| &s.runtime_status),
237            "Orchestration started"
238        );
239        Ok(state)
240    }
241
242    /// Wait for an orchestration to reach a terminal state.
243    pub async fn wait_for_orchestration_completion(
244        &mut self,
245        instance_id: &str,
246        fetch_payloads: bool,
247        timeout: Option<std::time::Duration>,
248    ) -> Result<Option<OrchestrationState>> {
249        internal::validate_identifier(
250            instance_id,
251            "instance ID",
252            self.options.max_identifier_length,
253        )?;
254        tracing::debug!(instance_id = %instance_id, "Waiting for orchestration completion");
255
256        let request = proto::GetInstanceRequest {
257            instance_id: instance_id.to_string(),
258            get_inputs_and_outputs: fetch_payloads,
259        };
260
261        let fut = self.inner.wait_for_instance_completion(request);
262
263        let response = if let Some(timeout_dur) = timeout {
264            tokio::time::timeout(timeout_dur, fut)
265                .await
266                .map_err(|_| DurableTaskError::Timeout)??
267        } else {
268            fut.await?
269        };
270
271        let state = OrchestrationState::try_from(&response.into_inner()).ok();
272        tracing::debug!(
273            instance_id = %instance_id,
274            status = ?state.as_ref().map(|s| &s.runtime_status),
275            "Orchestration completed"
276        );
277        Ok(state)
278    }
279
280    /// Raise an event to an orchestration instance.
281    pub async fn raise_orchestration_event(
282        &mut self,
283        instance_id: &str,
284        event_name: &str,
285        data: Option<String>,
286    ) -> Result<()> {
287        internal::validate_identifier(
288            instance_id,
289            "instance ID",
290            self.options.max_identifier_length,
291        )?;
292        internal::validate_identifier(
293            event_name,
294            "event name",
295            self.options.max_identifier_length,
296        )?;
297        tracing::info!(
298            instance_id = %instance_id,
299            event_name = %event_name,
300            "Raising orchestration event"
301        );
302        let request = proto::RaiseEventRequest {
303            instance_id: instance_id.to_string(),
304            name: event_name.to_string(),
305            input: data,
306        };
307        self.inner.raise_event(request).await?;
308        Ok(())
309    }
310
311    /// Terminate a running orchestration.
312    pub async fn terminate_orchestration(
313        &mut self,
314        instance_id: &str,
315        output: Option<String>,
316        recursive: bool,
317    ) -> Result<()> {
318        internal::validate_identifier(
319            instance_id,
320            "instance ID",
321            self.options.max_identifier_length,
322        )?;
323        tracing::info!(
324            instance_id = %instance_id,
325            recursive = recursive,
326            "Terminating orchestration"
327        );
328        let request = proto::TerminateRequest {
329            instance_id: instance_id.to_string(),
330            output,
331            recursive,
332        };
333        self.inner.terminate_instance(request).await?;
334        Ok(())
335    }
336
337    /// Suspend a running orchestration.
338    pub async fn suspend_orchestration(
339        &mut self,
340        instance_id: &str,
341        reason: Option<String>,
342    ) -> Result<()> {
343        internal::validate_identifier(
344            instance_id,
345            "instance ID",
346            self.options.max_identifier_length,
347        )?;
348        tracing::info!(instance_id = %instance_id, "Suspending orchestration");
349        let request = proto::SuspendRequest {
350            instance_id: instance_id.to_string(),
351            reason,
352        };
353        self.inner.suspend_instance(request).await?;
354        Ok(())
355    }
356
357    /// Resume a suspended orchestration.
358    pub async fn resume_orchestration(
359        &mut self,
360        instance_id: &str,
361        reason: Option<String>,
362    ) -> Result<()> {
363        internal::validate_identifier(
364            instance_id,
365            "instance ID",
366            self.options.max_identifier_length,
367        )?;
368        tracing::info!(instance_id = %instance_id, "Resuming orchestration");
369        let request = proto::ResumeRequest {
370            instance_id: instance_id.to_string(),
371            reason,
372        };
373        self.inner.resume_instance(request).await?;
374        Ok(())
375    }
376
377    /// Purge an orchestration's history and state by instance ID.
378    ///
379    /// Returns the number of deleted instances.
380    pub async fn purge_orchestration(&mut self, instance_id: &str, recursive: bool) -> Result<i32> {
381        internal::validate_identifier(
382            instance_id,
383            "instance ID",
384            self.options.max_identifier_length,
385        )?;
386        tracing::info!(instance_id = %instance_id, "Purging orchestration");
387        let request = proto::PurgeInstancesRequest {
388            request: Some(proto::purge_instances_request::Request::InstanceId(
389                instance_id.to_string(),
390            )),
391            recursive,
392            force: None,
393        };
394        let response = self.inner.purge_instances(request).await?;
395        let count = response.into_inner().deleted_instance_count;
396        tracing::debug!(instance_id = %instance_id, deleted = count, "Purge complete");
397        Ok(count)
398    }
399
400    /// Purge orchestrations matching the given filter criteria.
401    ///
402    /// Returns the number of deleted instances.
403    ///
404    /// # Examples
405    ///
406    /// ```rust,no_run
407    /// use dapr_durabletask::api::{OrchestrationStatus, PurgeInstanceFilter};
408    ///
409    /// # async fn example(mut client: dapr_durabletask::client::TaskHubGrpcClient) {
410    /// let filter = PurgeInstanceFilter::new()
411    ///     .with_created_time_from(chrono::Utc::now() - chrono::Duration::hours(24))
412    ///     .with_runtime_status([OrchestrationStatus::Completed, OrchestrationStatus::Failed]);
413    ///
414    /// let deleted = client.purge_orchestrations_by_filter(filter, false).await.unwrap();
415    /// println!("Deleted {deleted} orchestrations");
416    /// # }
417    /// ```
418    pub async fn purge_orchestrations_by_filter(
419        &mut self,
420        filter: PurgeInstanceFilter,
421        recursive: bool,
422    ) -> Result<i32> {
423        tracing::info!(?filter, "Purging orchestrations by filter");
424        let request = proto::PurgeInstancesRequest {
425            request: Some(
426                proto::purge_instances_request::Request::PurgeInstanceFilter(filter.into_proto()),
427            ),
428            recursive,
429            force: None,
430        };
431        let response = self.inner.purge_instances(request).await?;
432        let count = response.into_inner().deleted_instance_count;
433        tracing::debug!(deleted = count, "Purge by filter complete");
434        Ok(count)
435    }
436}