Skip to main content

gestalt/
invoker.rs

1use std::time::Duration;
2
3use hyper_util::rt::TokioIo;
4use serde::Serialize;
5use tokio::net::UnixStream;
6use tonic::Request;
7use tonic::metadata::MetadataValue;
8use tonic::service::Interceptor;
9use tonic::service::interceptor::InterceptedService;
10use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
11use tower::service_fn;
12
13use crate::OperationResult;
14use crate::generated::v1::{
15    self as pb, plugin_invoker_client::PluginInvokerClient as ProtoPluginInvokerClient,
16};
17
18type PluginInvokerTransport = InterceptedService<Channel, RelayTokenInterceptor>;
19
/// Name of the environment variable that holds the plugin-invoker host-service
/// target (`tcp://`, `tls://`, `unix://`, or a bare socket path).
pub const ENV_PLUGIN_INVOKER_SOCKET: &str = "GESTALT_PLUGIN_INVOKER_SOCKET";

/// Name of the environment variable that may hold a relay token; when set to a
/// non-blank value it is forwarded as gRPC metadata on every request.
pub const ENV_PLUGIN_INVOKER_SOCKET_TOKEN: &str = "GESTALT_PLUGIN_INVOKER_SOCKET_TOKEN";

/// gRPC metadata key under which the relay token is sent (lowercase, as metadata
/// keys must be).
const PLUGIN_INVOKER_RELAY_TOKEN_HEADER: &str = "x-gestalt-host-service-relay-token";
25
26#[derive(Debug, thiserror::Error)]
27/// Errors returned by [`PluginInvoker`].
28pub enum PluginInvokerError {
29    /// The invocation token was empty.
30    #[error("plugin invoker: invocation token is not available")]
31    MissingInvocationToken,
32    /// The host-service transport could not be created.
33    #[error("{0}")]
34    Transport(#[from] tonic::transport::Error),
35    /// The host-service RPC returned a gRPC status.
36    #[error("{0}")]
37    Status(#[from] tonic::Status),
38    /// Required environment or target configuration was invalid.
39    #[error("{0}")]
40    Env(String),
41    /// Invocation parameters or variables could not be serialized.
42    #[error("{0}")]
43    Json(#[from] serde_json::Error),
44    /// The host returned a protocol value the SDK could not represent.
45    #[error("{0}")]
46    Protocol(String),
47}
48
/// One permission entry used when exchanging an invocation token for a narrower
/// child token.
///
/// Whitespace around every string is trimmed before sending; blank entries are
/// dropped, and surface names are lowercased.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct InvocationGrant {
    /// Name of the plugin the child token is allowed to call.
    pub plugin: String,
    /// Individual operation ids the child token is allowed to call.
    pub operations: Vec<String>,
    /// Surface names (for example a GraphQL surface) the child token may use.
    pub surfaces: Vec<String>,
    /// When `true`, every operation of `plugin` is allowed.
    pub all_operations: bool,
}
61
/// Per-call options that pick which connection an invocation runs against.
///
/// Empty strings mean "not specified".
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct InvokeOptions {
    /// Id or name of the connected account to use.
    pub connection: String,
    /// Id or name of the provider instance to use.
    pub instance: String,
    /// Idempotency key passed through to the target operation (trimmed before
    /// sending).
    pub idempotency_key: String,
}
72
73/// Client for invoking sibling plugin operations through the host.
74pub struct PluginInvoker {
75    client: ProtoPluginInvokerClient<PluginInvokerTransport>,
76    invocation_token: String,
77}
78
79impl PluginInvoker {
80    /// Connects to the plugin invoker with an invocation token from the host.
81    pub async fn connect(
82        invocation_token: impl AsRef<str>,
83    ) -> std::result::Result<Self, PluginInvokerError> {
84        let invocation_token = invocation_token.as_ref().trim().to_owned();
85        if invocation_token.is_empty() {
86            return Err(PluginInvokerError::MissingInvocationToken);
87        }
88
89        let socket_path = std::env::var(ENV_PLUGIN_INVOKER_SOCKET).map_err(|_| {
90            PluginInvokerError::Env(format!("{ENV_PLUGIN_INVOKER_SOCKET} is not set"))
91        })?;
92        let relay_token = std::env::var(ENV_PLUGIN_INVOKER_SOCKET_TOKEN).unwrap_or_default();
93
94        let channel = match parse_plugin_invoker_target(&socket_path)? {
95            PluginInvokerTarget::Unix(path) => {
96                Endpoint::try_from("http://[::]:50051")?
97                    .connect_with_connector(service_fn(move |_: Uri| {
98                        let path = path.clone();
99                        async move { UnixStream::connect(path).await.map(TokioIo::new) }
100                    }))
101                    .await?
102            }
103            PluginInvokerTarget::Tcp(address) => {
104                Endpoint::from_shared(format!("http://{address}"))?
105                    .connect()
106                    .await?
107            }
108            PluginInvokerTarget::Tls(address) => {
109                Endpoint::from_shared(format!("https://{address}"))?
110                    .tls_config(ClientTlsConfig::new().with_native_roots())?
111                    .connect()
112                    .await?
113            }
114        };
115
116        Ok(Self {
117            client: ProtoPluginInvokerClient::with_interceptor(
118                channel,
119                relay_token_interceptor(relay_token.trim())?,
120            ),
121            invocation_token,
122        })
123    }
124
125    /// Invokes one operation on another plugin.
126    pub async fn invoke<P>(
127        &mut self,
128        plugin: &str,
129        operation: &str,
130        params: P,
131        options: Option<InvokeOptions>,
132    ) -> std::result::Result<OperationResult, PluginInvokerError>
133    where
134        P: Serialize,
135    {
136        let response = self
137            .client
138            .invoke(pb::PluginInvokeRequest {
139                plugin: plugin.to_string(),
140                operation: operation.to_string(),
141                params: Some(json_to_struct(serde_json::to_value(params)?)?),
142                connection: options
143                    .as_ref()
144                    .map(|opts| opts.connection.clone())
145                    .unwrap_or_default(),
146                instance: options
147                    .as_ref()
148                    .map(|opts| opts.instance.clone())
149                    .unwrap_or_default(),
150                invocation_token: self.invocation_token.clone(),
151                idempotency_key: options
152                    .as_ref()
153                    .map(|opts| opts.idempotency_key.trim().to_string())
154                    .unwrap_or_default(),
155            })
156            .await?
157            .into_inner();
158
159        let status = u16::try_from(response.status).map_err(|_| {
160            PluginInvokerError::Protocol(format!(
161                "plugin invoker: invalid response status {}",
162                response.status
163            ))
164        })?;
165
166        Ok(OperationResult {
167            status,
168            body: response.body,
169        })
170    }
171
172    /// Invokes another plugin's GraphQL surface.
173    pub async fn invoke_graphql<V>(
174        &mut self,
175        plugin: &str,
176        document: &str,
177        variables: Option<V>,
178        options: Option<InvokeOptions>,
179    ) -> std::result::Result<OperationResult, PluginInvokerError>
180    where
181        V: Serialize,
182    {
183        let document = document.trim();
184        if document.is_empty() {
185            return Err(PluginInvokerError::Protocol(
186                "plugin invoker: graphql document is required".to_string(),
187            ));
188        }
189
190        let response = self
191            .client
192            .invoke_graph_ql(pb::PluginInvokeGraphQlRequest {
193                plugin: plugin.to_string(),
194                document: document.to_string(),
195                variables: variables
196                    .map(serde_json::to_value)
197                    .transpose()?
198                    .map(|value| json_to_optional_struct(value, "variables"))
199                    .transpose()?
200                    .flatten(),
201                connection: options
202                    .as_ref()
203                    .map(|opts| opts.connection.clone())
204                    .unwrap_or_default(),
205                instance: options
206                    .as_ref()
207                    .map(|opts| opts.instance.clone())
208                    .unwrap_or_default(),
209                invocation_token: self.invocation_token.clone(),
210                idempotency_key: options
211                    .as_ref()
212                    .map(|opts| opts.idempotency_key.trim().to_string())
213                    .unwrap_or_default(),
214            })
215            .await?
216            .into_inner();
217
218        let status = u16::try_from(response.status).map_err(|_| {
219            PluginInvokerError::Protocol(format!(
220                "plugin invoker: invalid response status {}",
221                response.status
222            ))
223        })?;
224
225        Ok(OperationResult {
226            status,
227            body: response.body,
228        })
229    }
230
231    /// Exchanges this invocation token for a narrower child token.
232    pub async fn exchange_invocation_token(
233        &mut self,
234        grants: &[InvocationGrant],
235        ttl: Option<Duration>,
236    ) -> std::result::Result<String, PluginInvokerError> {
237        let ttl_seconds = ttl
238            .map(duration_to_ttl_seconds)
239            .transpose()?
240            .unwrap_or_default();
241        let response = self
242            .client
243            .exchange_invocation_token(pb::ExchangeInvocationTokenRequest {
244                parent_invocation_token: self.invocation_token.clone(),
245                grants: encode_invocation_grants(grants),
246                ttl_seconds,
247            })
248            .await?
249            .into_inner();
250
251        Ok(response.invocation_token)
252    }
253}
254
/// Parsed form of the `GESTALT_PLUGIN_INVOKER_SOCKET` value: where the
/// plugin-invoker host service listens.
enum PluginInvokerTarget {
    /// Filesystem path of a Unix domain socket.
    Unix(String),
    /// `host:port` dialed in plaintext (`http://`).
    Tcp(String),
    /// `host:port` dialed over TLS (`https://`).
    Tls(String),
}
260
261fn parse_plugin_invoker_target(
262    raw_target: &str,
263) -> Result<PluginInvokerTarget, PluginInvokerError> {
264    let target = raw_target.trim();
265    if target.is_empty() {
266        return Err(PluginInvokerError::Env(
267            "plugin invoker: transport target is required".to_string(),
268        ));
269    }
270    if let Some(address) = target.strip_prefix("tcp://") {
271        let address = address.trim();
272        if address.is_empty() {
273            return Err(PluginInvokerError::Env(format!(
274                "plugin invoker: tcp target {raw_target:?} is missing host:port"
275            )));
276        }
277        return Ok(PluginInvokerTarget::Tcp(address.to_string()));
278    }
279    if let Some(address) = target.strip_prefix("tls://") {
280        let address = address.trim();
281        if address.is_empty() {
282            return Err(PluginInvokerError::Env(format!(
283                "plugin invoker: tls target {raw_target:?} is missing host:port"
284            )));
285        }
286        return Ok(PluginInvokerTarget::Tls(address.to_string()));
287    }
288    if let Some(path) = target.strip_prefix("unix://") {
289        let path = path.trim();
290        if path.is_empty() {
291            return Err(PluginInvokerError::Env(format!(
292                "plugin invoker: unix target {raw_target:?} is missing a socket path"
293            )));
294        }
295        return Ok(PluginInvokerTarget::Unix(path.to_string()));
296    }
297    if target.contains("://") {
298        let scheme = target.split("://").next().unwrap_or_default();
299        return Err(PluginInvokerError::Env(format!(
300            "plugin invoker: unsupported target scheme {scheme:?}"
301        )));
302    }
303    Ok(PluginInvokerTarget::Unix(target.to_string()))
304}
305
306fn encode_invocation_grants(grants: &[InvocationGrant]) -> Vec<pb::PluginInvocationGrant> {
307    grants
308        .iter()
309        .filter_map(|grant| {
310            let plugin = grant.plugin.trim();
311            if plugin.is_empty() {
312                return None;
313            }
314            let operations = grant
315                .operations
316                .iter()
317                .map(|operation| operation.trim())
318                .filter(|operation| !operation.is_empty())
319                .map(ToOwned::to_owned)
320                .collect();
321            let surfaces = grant
322                .surfaces
323                .iter()
324                .map(|surface| surface.trim())
325                .filter(|surface| !surface.is_empty())
326                .map(|surface| surface.to_ascii_lowercase())
327                .collect();
328
329            Some(pb::PluginInvocationGrant {
330                plugin: plugin.to_owned(),
331                operations,
332                surfaces,
333                all_operations: grant.all_operations,
334            })
335        })
336        .collect()
337}
338
339fn duration_to_ttl_seconds(ttl: Duration) -> std::result::Result<i64, PluginInvokerError> {
340    if ttl.is_zero() {
341        return Ok(0);
342    }
343
344    let ttl_seconds = ttl.as_secs().max(1);
345    i64::try_from(ttl_seconds).map_err(|_| {
346        PluginInvokerError::Protocol(
347            "plugin invoker: exchange token ttl exceeds supported range".to_string(),
348        )
349    })
350}
351
352fn relay_token_interceptor(token: &str) -> Result<RelayTokenInterceptor, PluginInvokerError> {
353    let header = if token.trim().is_empty() {
354        None
355    } else {
356        Some(MetadataValue::try_from(token.to_string()).map_err(|err| {
357            PluginInvokerError::Env(format!(
358                "invalid plugin invoker relay token metadata: {err}"
359            ))
360        })?)
361    };
362    Ok(RelayTokenInterceptor { header })
363}
364
365#[derive(Clone)]
366struct RelayTokenInterceptor {
367    header: Option<MetadataValue<tonic::metadata::Ascii>>,
368}
369
370impl Interceptor for RelayTokenInterceptor {
371    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, tonic::Status> {
372        if let Some(header) = self.header.clone() {
373            request
374                .metadata_mut()
375                .insert(PLUGIN_INVOKER_RELAY_TOKEN_HEADER, header);
376        }
377        Ok(request)
378    }
379}
380
381fn json_to_struct(
382    value: serde_json::Value,
383) -> std::result::Result<prost_types::Struct, PluginInvokerError> {
384    Ok(json_to_optional_struct(value, "params")?.unwrap_or_default())
385}
386
387fn json_to_optional_struct(
388    value: serde_json::Value,
389    field_name: &str,
390) -> std::result::Result<Option<prost_types::Struct>, PluginInvokerError> {
391    let serde_json::Value::Object(fields) = value else {
392        if value.is_null() {
393            return Ok(None);
394        }
395        return Err(PluginInvokerError::Protocol(format!(
396            "plugin invoker: {field_name} must serialize to a JSON object"
397        )));
398    };
399
400    Ok(Some(prost_types::Struct {
401        fields: fields
402            .into_iter()
403            .map(|(key, value)| (key, json_value_to_prost(value)))
404            .collect(),
405    }))
406}
407
408fn json_value_to_prost(value: serde_json::Value) -> prost_types::Value {
409    use prost_types::value::Kind;
410
411    let kind = match value {
412        serde_json::Value::Null => Kind::NullValue(0),
413        serde_json::Value::Bool(boolean) => Kind::BoolValue(boolean),
414        serde_json::Value::Number(number) => Kind::NumberValue(number.as_f64().unwrap_or_default()),
415        serde_json::Value::String(string) => Kind::StringValue(string),
416        serde_json::Value::Array(items) => Kind::ListValue(prost_types::ListValue {
417            values: items.into_iter().map(json_value_to_prost).collect(),
418        }),
419        serde_json::Value::Object(fields) => Kind::StructValue(prost_types::Struct {
420            fields: fields
421                .into_iter()
422                .map(|(key, value)| (key, json_value_to_prost(value)))
423                .collect(),
424        }),
425    };
426
427    prost_types::Value { kind: Some(kind) }
428}