1use std::time::Duration;
2
3use hyper_util::rt::TokioIo;
4use serde::Serialize;
5use tokio::net::UnixStream;
6use tonic::Request;
7use tonic::metadata::MetadataValue;
8use tonic::service::Interceptor;
9use tonic::service::interceptor::InterceptedService;
10use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
11use tower::service_fn;
12
13use crate::OperationResult;
14use crate::generated::v1::{
15 self as pb, plugin_invoker_client::PluginInvokerClient as ProtoPluginInvokerClient,
16};
17
18type PluginInvokerTransport = InterceptedService<Channel, RelayTokenInterceptor>;
19
/// Environment variable naming the plugin invoker transport target
/// (`tcp://host:port`, `tls://host:port`, `unix://path`, or a bare socket path).
pub const ENV_PLUGIN_INVOKER_SOCKET: &str = "GESTALT_PLUGIN_INVOKER_SOCKET";
/// Optional environment variable holding the relay token attached to every
/// plugin invoker request; unset or blank means no relay header is sent.
pub const ENV_PLUGIN_INVOKER_SOCKET_TOKEN: &str = "GESTALT_PLUGIN_INVOKER_SOCKET_TOKEN";
/// gRPC metadata key under which the relay token is sent.
const PLUGIN_INVOKER_RELAY_TOKEN_HEADER: &str = "x-gestalt-host-service-relay-token";
25
26#[derive(Debug, thiserror::Error)]
27pub enum PluginInvokerError {
29 #[error("plugin invoker: invocation token is not available")]
31 MissingInvocationToken,
32 #[error("{0}")]
34 Transport(#[from] tonic::transport::Error),
35 #[error("{0}")]
37 Status(#[from] tonic::Status),
38 #[error("{0}")]
40 Env(String),
41 #[error("{0}")]
43 Json(#[from] serde_json::Error),
44 #[error("{0}")]
46 Protocol(String),
47}
48
/// One permission entry requested when exchanging an invocation token.
///
/// Values are normalised when encoded for the wire: the grant is dropped if
/// `plugin` is blank, entries are trimmed, blank `operations`/`surfaces` are
/// removed, and `surfaces` are ASCII-lowercased.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct InvocationGrant {
    /// Plugin the grant refers to; trimmed before sending.
    pub plugin: String,
    /// Operation names covered by the grant; trimmed, blanks dropped.
    pub operations: Vec<String>,
    /// Surface names covered by the grant; trimmed, lowercased, blanks dropped.
    pub surfaces: Vec<String>,
    /// Forwarded as-is; presumably grants every operation of `plugin` when
    /// set (semantics are defined by the host — confirm there).
    pub all_operations: bool,
}
61
/// Optional per-call settings for `PluginInvoker::invoke` and
/// `PluginInvoker::invoke_graphql`.
///
/// Every field defaults to an empty string, which is exactly what is sent
/// when no options are passed at all.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct InvokeOptions {
    /// Forwarded unchanged as the request's `connection` field.
    pub connection: String,
    /// Forwarded unchanged as the request's `instance` field.
    pub instance: String,
    /// Forwarded as the request's `idempotency_key` after trimming whitespace.
    pub idempotency_key: String,
}
72
73pub struct PluginInvoker {
75 client: ProtoPluginInvokerClient<PluginInvokerTransport>,
76 invocation_token: String,
77}
78
79impl PluginInvoker {
80 pub async fn connect(
82 invocation_token: impl AsRef<str>,
83 ) -> std::result::Result<Self, PluginInvokerError> {
84 let invocation_token = invocation_token.as_ref().trim().to_owned();
85 if invocation_token.is_empty() {
86 return Err(PluginInvokerError::MissingInvocationToken);
87 }
88
89 let socket_path = std::env::var(ENV_PLUGIN_INVOKER_SOCKET).map_err(|_| {
90 PluginInvokerError::Env(format!("{ENV_PLUGIN_INVOKER_SOCKET} is not set"))
91 })?;
92 let relay_token = std::env::var(ENV_PLUGIN_INVOKER_SOCKET_TOKEN).unwrap_or_default();
93
94 let channel = match parse_plugin_invoker_target(&socket_path)? {
95 PluginInvokerTarget::Unix(path) => {
96 Endpoint::try_from("http://[::]:50051")?
97 .connect_with_connector(service_fn(move |_: Uri| {
98 let path = path.clone();
99 async move { UnixStream::connect(path).await.map(TokioIo::new) }
100 }))
101 .await?
102 }
103 PluginInvokerTarget::Tcp(address) => {
104 Endpoint::from_shared(format!("http://{address}"))?
105 .connect()
106 .await?
107 }
108 PluginInvokerTarget::Tls(address) => {
109 Endpoint::from_shared(format!("https://{address}"))?
110 .tls_config(ClientTlsConfig::new().with_native_roots())?
111 .connect()
112 .await?
113 }
114 };
115
116 Ok(Self {
117 client: ProtoPluginInvokerClient::with_interceptor(
118 channel,
119 relay_token_interceptor(relay_token.trim())?,
120 ),
121 invocation_token,
122 })
123 }
124
125 pub async fn invoke<P>(
127 &mut self,
128 plugin: &str,
129 operation: &str,
130 params: P,
131 options: Option<InvokeOptions>,
132 ) -> std::result::Result<OperationResult, PluginInvokerError>
133 where
134 P: Serialize,
135 {
136 let response = self
137 .client
138 .invoke(pb::PluginInvokeRequest {
139 plugin: plugin.to_string(),
140 operation: operation.to_string(),
141 params: Some(json_to_struct(serde_json::to_value(params)?)?),
142 connection: options
143 .as_ref()
144 .map(|opts| opts.connection.clone())
145 .unwrap_or_default(),
146 instance: options
147 .as_ref()
148 .map(|opts| opts.instance.clone())
149 .unwrap_or_default(),
150 invocation_token: self.invocation_token.clone(),
151 idempotency_key: options
152 .as_ref()
153 .map(|opts| opts.idempotency_key.trim().to_string())
154 .unwrap_or_default(),
155 })
156 .await?
157 .into_inner();
158
159 let status = u16::try_from(response.status).map_err(|_| {
160 PluginInvokerError::Protocol(format!(
161 "plugin invoker: invalid response status {}",
162 response.status
163 ))
164 })?;
165
166 Ok(OperationResult {
167 status,
168 body: response.body,
169 })
170 }
171
172 pub async fn invoke_graphql<V>(
174 &mut self,
175 plugin: &str,
176 document: &str,
177 variables: Option<V>,
178 options: Option<InvokeOptions>,
179 ) -> std::result::Result<OperationResult, PluginInvokerError>
180 where
181 V: Serialize,
182 {
183 let document = document.trim();
184 if document.is_empty() {
185 return Err(PluginInvokerError::Protocol(
186 "plugin invoker: graphql document is required".to_string(),
187 ));
188 }
189
190 let response = self
191 .client
192 .invoke_graph_ql(pb::PluginInvokeGraphQlRequest {
193 plugin: plugin.to_string(),
194 document: document.to_string(),
195 variables: variables
196 .map(serde_json::to_value)
197 .transpose()?
198 .map(|value| json_to_optional_struct(value, "variables"))
199 .transpose()?
200 .flatten(),
201 connection: options
202 .as_ref()
203 .map(|opts| opts.connection.clone())
204 .unwrap_or_default(),
205 instance: options
206 .as_ref()
207 .map(|opts| opts.instance.clone())
208 .unwrap_or_default(),
209 invocation_token: self.invocation_token.clone(),
210 idempotency_key: options
211 .as_ref()
212 .map(|opts| opts.idempotency_key.trim().to_string())
213 .unwrap_or_default(),
214 })
215 .await?
216 .into_inner();
217
218 let status = u16::try_from(response.status).map_err(|_| {
219 PluginInvokerError::Protocol(format!(
220 "plugin invoker: invalid response status {}",
221 response.status
222 ))
223 })?;
224
225 Ok(OperationResult {
226 status,
227 body: response.body,
228 })
229 }
230
231 pub async fn exchange_invocation_token(
233 &mut self,
234 grants: &[InvocationGrant],
235 ttl: Option<Duration>,
236 ) -> std::result::Result<String, PluginInvokerError> {
237 let ttl_seconds = ttl
238 .map(duration_to_ttl_seconds)
239 .transpose()?
240 .unwrap_or_default();
241 let response = self
242 .client
243 .exchange_invocation_token(pb::ExchangeInvocationTokenRequest {
244 parent_invocation_token: self.invocation_token.clone(),
245 grants: encode_invocation_grants(grants),
246 ttl_seconds,
247 })
248 .await?
249 .into_inner();
250
251 Ok(response.invocation_token)
252 }
253}
254
/// Where the plugin invoker service is reached, as parsed from
/// [`ENV_PLUGIN_INVOKER_SOCKET`].
enum PluginInvokerTarget {
    /// Filesystem path of a Unix domain socket.
    Unix(String),
    /// `host:port` reached over plaintext HTTP/2.
    Tcp(String),
    /// `host:port` reached over TLS using the platform's native root certificates.
    Tls(String),
}
260
261fn parse_plugin_invoker_target(
262 raw_target: &str,
263) -> Result<PluginInvokerTarget, PluginInvokerError> {
264 let target = raw_target.trim();
265 if target.is_empty() {
266 return Err(PluginInvokerError::Env(
267 "plugin invoker: transport target is required".to_string(),
268 ));
269 }
270 if let Some(address) = target.strip_prefix("tcp://") {
271 let address = address.trim();
272 if address.is_empty() {
273 return Err(PluginInvokerError::Env(format!(
274 "plugin invoker: tcp target {raw_target:?} is missing host:port"
275 )));
276 }
277 return Ok(PluginInvokerTarget::Tcp(address.to_string()));
278 }
279 if let Some(address) = target.strip_prefix("tls://") {
280 let address = address.trim();
281 if address.is_empty() {
282 return Err(PluginInvokerError::Env(format!(
283 "plugin invoker: tls target {raw_target:?} is missing host:port"
284 )));
285 }
286 return Ok(PluginInvokerTarget::Tls(address.to_string()));
287 }
288 if let Some(path) = target.strip_prefix("unix://") {
289 let path = path.trim();
290 if path.is_empty() {
291 return Err(PluginInvokerError::Env(format!(
292 "plugin invoker: unix target {raw_target:?} is missing a socket path"
293 )));
294 }
295 return Ok(PluginInvokerTarget::Unix(path.to_string()));
296 }
297 if target.contains("://") {
298 let scheme = target.split("://").next().unwrap_or_default();
299 return Err(PluginInvokerError::Env(format!(
300 "plugin invoker: unsupported target scheme {scheme:?}"
301 )));
302 }
303 Ok(PluginInvokerTarget::Unix(target.to_string()))
304}
305
306fn encode_invocation_grants(grants: &[InvocationGrant]) -> Vec<pb::PluginInvocationGrant> {
307 grants
308 .iter()
309 .filter_map(|grant| {
310 let plugin = grant.plugin.trim();
311 if plugin.is_empty() {
312 return None;
313 }
314 let operations = grant
315 .operations
316 .iter()
317 .map(|operation| operation.trim())
318 .filter(|operation| !operation.is_empty())
319 .map(ToOwned::to_owned)
320 .collect();
321 let surfaces = grant
322 .surfaces
323 .iter()
324 .map(|surface| surface.trim())
325 .filter(|surface| !surface.is_empty())
326 .map(|surface| surface.to_ascii_lowercase())
327 .collect();
328
329 Some(pb::PluginInvocationGrant {
330 plugin: plugin.to_owned(),
331 operations,
332 surfaces,
333 all_operations: grant.all_operations,
334 })
335 })
336 .collect()
337}
338
339fn duration_to_ttl_seconds(ttl: Duration) -> std::result::Result<i64, PluginInvokerError> {
340 if ttl.is_zero() {
341 return Ok(0);
342 }
343
344 let ttl_seconds = ttl.as_secs().max(1);
345 i64::try_from(ttl_seconds).map_err(|_| {
346 PluginInvokerError::Protocol(
347 "plugin invoker: exchange token ttl exceeds supported range".to_string(),
348 )
349 })
350}
351
352fn relay_token_interceptor(token: &str) -> Result<RelayTokenInterceptor, PluginInvokerError> {
353 let header = if token.trim().is_empty() {
354 None
355 } else {
356 Some(MetadataValue::try_from(token.to_string()).map_err(|err| {
357 PluginInvokerError::Env(format!(
358 "invalid plugin invoker relay token metadata: {err}"
359 ))
360 })?)
361 };
362 Ok(RelayTokenInterceptor { header })
363}
364
365#[derive(Clone)]
366struct RelayTokenInterceptor {
367 header: Option<MetadataValue<tonic::metadata::Ascii>>,
368}
369
370impl Interceptor for RelayTokenInterceptor {
371 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, tonic::Status> {
372 if let Some(header) = self.header.clone() {
373 request
374 .metadata_mut()
375 .insert(PLUGIN_INVOKER_RELAY_TOKEN_HEADER, header);
376 }
377 Ok(request)
378 }
379}
380
381fn json_to_struct(
382 value: serde_json::Value,
383) -> std::result::Result<prost_types::Struct, PluginInvokerError> {
384 Ok(json_to_optional_struct(value, "params")?.unwrap_or_default())
385}
386
387fn json_to_optional_struct(
388 value: serde_json::Value,
389 field_name: &str,
390) -> std::result::Result<Option<prost_types::Struct>, PluginInvokerError> {
391 let serde_json::Value::Object(fields) = value else {
392 if value.is_null() {
393 return Ok(None);
394 }
395 return Err(PluginInvokerError::Protocol(format!(
396 "plugin invoker: {field_name} must serialize to a JSON object"
397 )));
398 };
399
400 Ok(Some(prost_types::Struct {
401 fields: fields
402 .into_iter()
403 .map(|(key, value)| (key, json_value_to_prost(value)))
404 .collect(),
405 }))
406}
407
408fn json_value_to_prost(value: serde_json::Value) -> prost_types::Value {
409 use prost_types::value::Kind;
410
411 let kind = match value {
412 serde_json::Value::Null => Kind::NullValue(0),
413 serde_json::Value::Bool(boolean) => Kind::BoolValue(boolean),
414 serde_json::Value::Number(number) => Kind::NumberValue(number.as_f64().unwrap_or_default()),
415 serde_json::Value::String(string) => Kind::StringValue(string),
416 serde_json::Value::Array(items) => Kind::ListValue(prost_types::ListValue {
417 values: items.into_iter().map(json_value_to_prost).collect(),
418 }),
419 serde_json::Value::Object(fields) => Kind::StructValue(prost_types::Struct {
420 fields: fields
421 .into_iter()
422 .map(|(key, value)| (key, json_value_to_prost(value)))
423 .collect(),
424 }),
425 };
426
427 prost_types::Value { kind: Some(kind) }
428}