1use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
2
3use crate::api::{DurableTaskError, OrchestrationState, PurgeInstanceFilter, Result};
4use crate::internal;
5use crate::proto;
6use crate::proto::task_hub_sidecar_service_client::TaskHubSidecarServiceClient;
7
8use super::options::ClientOptions;
9
/// gRPC client for the Durable Task sidecar's `TaskHubSidecarService`.
///
/// Pairs the generated tonic stub with the [`ClientOptions`] the client was created with; the
/// options are consulted on individual calls (e.g. `max_identifier_length` during validation).
pub struct TaskHubGrpcClient {
    /// Generated tonic stub through which every RPC is issued.
    inner: TaskHubSidecarServiceClient<Channel>,
    /// Options supplied at construction time.
    options: ClientOptions,
}
15
16async fn build_channel(host_address: &str, options: &ClientOptions) -> Result<Channel> {
21 let mut builder = Channel::from_shared(host_address.to_string())
22 .map_err(|e| DurableTaskError::Other(e.to_string()))?;
23
24 if let Some(tls) = &options.tls {
25 if tls.skip_verify {
26 return Err(DurableTaskError::Other(
27 "skip_verify is not supported; connect without TLS for development".into(),
28 ));
29 }
30
31 let mut tls_config = ClientTlsConfig::new();
32
33 if let Some(ca_pem) = &tls.ca_cert_pem {
34 tls_config = tls_config.ca_certificate(Certificate::from_pem(ca_pem));
35 }
36
37 match (&tls.client_cert_pem, &tls.client_key_pem) {
38 (Some(cert), Some(key)) => {
39 tls_config = tls_config.identity(Identity::from_pem(cert, key));
40 }
41 (None, None) => {}
42 _ => {
43 return Err(DurableTaskError::Other(
44 "client_cert_pem and client_key_pem must both be set for mutual TLS".into(),
45 ));
46 }
47 }
48
49 if let Some(domain) = &tls.domain_name {
50 tls_config = tls_config.domain_name(domain.clone());
51 }
52
53 builder = builder
54 .tls_config(tls_config)
55 .map_err(|e| DurableTaskError::Other(e.to_string()))?;
56 }
57
58 if let Some(timeout) = options.connect_timeout {
59 builder = builder.connect_timeout(timeout);
60 }
61
62 if let Some(interval) = options.keepalive_interval {
63 builder = builder.tcp_keepalive(Some(interval));
64 }
65
66 builder
67 .connect()
68 .await
69 .map_err(|e| DurableTaskError::Other(e.to_string()))
70}
71
72fn make_stub(channel: Channel, options: &ClientOptions) -> TaskHubSidecarServiceClient<Channel> {
74 let mut stub = TaskHubSidecarServiceClient::new(channel);
75 if let Some(size) = options.max_grpc_message_size {
76 stub = stub.max_decoding_message_size(size);
77 }
78 stub
79}
80
81impl TaskHubGrpcClient {
84 pub async fn new(host_address: &str) -> Result<Self> {
88 Self::with_options(host_address, ClientOptions::default()).await
89 }
90
91 pub async fn with_options(host_address: &str, options: ClientOptions) -> Result<Self> {
93 tracing::info!(address = %host_address, "Connecting to sidecar");
94 let channel = build_channel(host_address, &options).await?;
95 tracing::info!(address = %host_address, "Client connected");
96 let inner = make_stub(channel, &options);
97 Ok(Self { inner, options })
98 }
99
100 pub fn from_channel(channel: Channel) -> Self {
102 let options = ClientOptions::default();
103 let inner = make_stub(channel, &options);
104 Self { inner, options }
105 }
106
107 pub fn from_channel_with_options(channel: Channel, options: ClientOptions) -> Self {
109 let inner = make_stub(channel, &options);
110 Self { inner, options }
111 }
112
113 pub fn close(self) {
118 drop(self);
119 }
120
121 pub async fn schedule_new_orchestration(
123 &mut self,
124 orchestrator_name: &str,
125 input: Option<String>,
126 instance_id: Option<String>,
127 start_at: Option<chrono::DateTime<chrono::Utc>>,
128 ) -> Result<String> {
129 internal::validate_identifier(
130 orchestrator_name,
131 "orchestrator name",
132 self.options.max_identifier_length,
133 )?;
134 let instance_id = instance_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
135 internal::validate_identifier(
136 &instance_id,
137 "instance ID",
138 self.options.max_identifier_length,
139 )?;
140
141 tracing::info!(
142 instance_id = %instance_id,
143 orchestrator = %orchestrator_name,
144 "Scheduling new orchestration"
145 );
146
147 #[cfg(feature = "opentelemetry")]
148 let (parent_trace_context, _otel_ctx) = {
149 let parent_ctx = opentelemetry::Context::current();
150 let ctx = internal::otel::start_create_orchestration_span(
151 &parent_ctx,
152 orchestrator_name,
153 &instance_id,
154 );
155 let sc = opentelemetry::trace::TraceContextExt::span(&ctx)
156 .span_context()
157 .clone();
158 let tc = internal::otel::trace_context_from_span_context(&sc);
159 (tc, ctx)
160 };
161 #[cfg(not(feature = "opentelemetry"))]
162 let parent_trace_context: Option<proto::TraceContext> = None;
163
164 let request = proto::CreateInstanceRequest {
165 instance_id: instance_id.clone(),
166 name: orchestrator_name.to_string(),
167 input,
168 scheduled_start_timestamp: start_at.map(internal::to_timestamp),
169 version: None,
170 execution_id: None,
171 tags: std::collections::HashMap::new(),
172 parent_trace_context,
173 };
174
175 let response = self.inner.start_instance(request).await?;
176 let result_id = response.into_inner().instance_id;
177
178 #[cfg(feature = "opentelemetry")]
179 internal::otel::end_span(&_otel_ctx);
180
181 tracing::debug!(instance_id = %result_id, "Orchestration scheduled");
182 Ok(result_id)
183 }
184
185 pub async fn get_orchestration_state(
187 &mut self,
188 instance_id: &str,
189 fetch_payloads: bool,
190 ) -> Result<Option<OrchestrationState>> {
191 internal::validate_identifier(
192 instance_id,
193 "instance ID",
194 self.options.max_identifier_length,
195 )?;
196 let request = proto::GetInstanceRequest {
197 instance_id: instance_id.to_string(),
198 get_inputs_and_outputs: fetch_payloads,
199 };
200 let response = self.inner.get_instance(request).await?;
201 Ok(OrchestrationState::try_from(&response.into_inner()).ok())
202 }
203
204 pub async fn wait_for_orchestration_start(
206 &mut self,
207 instance_id: &str,
208 fetch_payloads: bool,
209 timeout: Option<std::time::Duration>,
210 ) -> Result<Option<OrchestrationState>> {
211 internal::validate_identifier(
212 instance_id,
213 "instance ID",
214 self.options.max_identifier_length,
215 )?;
216 tracing::debug!(instance_id = %instance_id, "Waiting for orchestration to start");
217
218 let request = proto::GetInstanceRequest {
219 instance_id: instance_id.to_string(),
220 get_inputs_and_outputs: fetch_payloads,
221 };
222
223 let fut = self.inner.wait_for_instance_start(request);
224
225 let response = if let Some(timeout_dur) = timeout {
226 tokio::time::timeout(timeout_dur, fut)
227 .await
228 .map_err(|_| DurableTaskError::Timeout)??
229 } else {
230 fut.await?
231 };
232
233 let state = OrchestrationState::try_from(&response.into_inner()).ok();
234 tracing::debug!(
235 instance_id = %instance_id,
236 status = ?state.as_ref().map(|s| &s.runtime_status),
237 "Orchestration started"
238 );
239 Ok(state)
240 }
241
242 pub async fn wait_for_orchestration_completion(
244 &mut self,
245 instance_id: &str,
246 fetch_payloads: bool,
247 timeout: Option<std::time::Duration>,
248 ) -> Result<Option<OrchestrationState>> {
249 internal::validate_identifier(
250 instance_id,
251 "instance ID",
252 self.options.max_identifier_length,
253 )?;
254 tracing::debug!(instance_id = %instance_id, "Waiting for orchestration completion");
255
256 let request = proto::GetInstanceRequest {
257 instance_id: instance_id.to_string(),
258 get_inputs_and_outputs: fetch_payloads,
259 };
260
261 let fut = self.inner.wait_for_instance_completion(request);
262
263 let response = if let Some(timeout_dur) = timeout {
264 tokio::time::timeout(timeout_dur, fut)
265 .await
266 .map_err(|_| DurableTaskError::Timeout)??
267 } else {
268 fut.await?
269 };
270
271 let state = OrchestrationState::try_from(&response.into_inner()).ok();
272 tracing::debug!(
273 instance_id = %instance_id,
274 status = ?state.as_ref().map(|s| &s.runtime_status),
275 "Orchestration completed"
276 );
277 Ok(state)
278 }
279
280 pub async fn raise_orchestration_event(
282 &mut self,
283 instance_id: &str,
284 event_name: &str,
285 data: Option<String>,
286 ) -> Result<()> {
287 internal::validate_identifier(
288 instance_id,
289 "instance ID",
290 self.options.max_identifier_length,
291 )?;
292 internal::validate_identifier(
293 event_name,
294 "event name",
295 self.options.max_identifier_length,
296 )?;
297 tracing::info!(
298 instance_id = %instance_id,
299 event_name = %event_name,
300 "Raising orchestration event"
301 );
302 let request = proto::RaiseEventRequest {
303 instance_id: instance_id.to_string(),
304 name: event_name.to_string(),
305 input: data,
306 };
307 self.inner.raise_event(request).await?;
308 Ok(())
309 }
310
311 pub async fn terminate_orchestration(
313 &mut self,
314 instance_id: &str,
315 output: Option<String>,
316 recursive: bool,
317 ) -> Result<()> {
318 internal::validate_identifier(
319 instance_id,
320 "instance ID",
321 self.options.max_identifier_length,
322 )?;
323 tracing::info!(
324 instance_id = %instance_id,
325 recursive = recursive,
326 "Terminating orchestration"
327 );
328 let request = proto::TerminateRequest {
329 instance_id: instance_id.to_string(),
330 output,
331 recursive,
332 };
333 self.inner.terminate_instance(request).await?;
334 Ok(())
335 }
336
337 pub async fn suspend_orchestration(
339 &mut self,
340 instance_id: &str,
341 reason: Option<String>,
342 ) -> Result<()> {
343 internal::validate_identifier(
344 instance_id,
345 "instance ID",
346 self.options.max_identifier_length,
347 )?;
348 tracing::info!(instance_id = %instance_id, "Suspending orchestration");
349 let request = proto::SuspendRequest {
350 instance_id: instance_id.to_string(),
351 reason,
352 };
353 self.inner.suspend_instance(request).await?;
354 Ok(())
355 }
356
357 pub async fn resume_orchestration(
359 &mut self,
360 instance_id: &str,
361 reason: Option<String>,
362 ) -> Result<()> {
363 internal::validate_identifier(
364 instance_id,
365 "instance ID",
366 self.options.max_identifier_length,
367 )?;
368 tracing::info!(instance_id = %instance_id, "Resuming orchestration");
369 let request = proto::ResumeRequest {
370 instance_id: instance_id.to_string(),
371 reason,
372 };
373 self.inner.resume_instance(request).await?;
374 Ok(())
375 }
376
377 pub async fn purge_orchestration(&mut self, instance_id: &str, recursive: bool) -> Result<i32> {
381 internal::validate_identifier(
382 instance_id,
383 "instance ID",
384 self.options.max_identifier_length,
385 )?;
386 tracing::info!(instance_id = %instance_id, "Purging orchestration");
387 let request = proto::PurgeInstancesRequest {
388 request: Some(proto::purge_instances_request::Request::InstanceId(
389 instance_id.to_string(),
390 )),
391 recursive,
392 force: None,
393 };
394 let response = self.inner.purge_instances(request).await?;
395 let count = response.into_inner().deleted_instance_count;
396 tracing::debug!(instance_id = %instance_id, deleted = count, "Purge complete");
397 Ok(count)
398 }
399
400 pub async fn purge_orchestrations_by_filter(
419 &mut self,
420 filter: PurgeInstanceFilter,
421 recursive: bool,
422 ) -> Result<i32> {
423 tracing::info!(?filter, "Purging orchestrations by filter");
424 let request = proto::PurgeInstancesRequest {
425 request: Some(
426 proto::purge_instances_request::Request::PurgeInstanceFilter(filter.into_proto()),
427 ),
428 recursive,
429 force: None,
430 };
431 let response = self.inner.purge_instances(request).await?;
432 let count = response.into_inner().deleted_instance_count;
433 tracing::debug!(deleted = count, "Purge by filter complete");
434 Ok(count)
435 }
436}