use hyper_util::rt::TokioIo;
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::net::UnixStream;
use tonic::Request as GrpcRequest;
use tonic::codegen::async_trait;
use tonic::metadata::MetadataValue;
use tonic::service::Interceptor;
use tonic::service::interceptor::InterceptedService;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
use tower::service_fn;
use crate::OperationResult;
use crate::api::{Request, current_request_context};
use crate::app_decode::{decode_app_result, decode_graphql_result};
use crate::env::{ENV_HOST_SERVICE_SOCKET, ENV_HOST_SERVICE_TOKEN};
use crate::generated::v1::{self as pb, app_client::AppClient as ProtoAppClient};
use crate::protocol;
/// Transport type used by the generated app client: a tonic `Channel` whose
/// requests pass through `RelayTokenInterceptor` before being sent.
type AppTransport = InterceptedService<Channel, RelayTokenInterceptor>;
/// Metadata key under which `RelayTokenInterceptor` inserts the relay token.
const APP_RELAY_TOKEN_HEADER: &str = "x-gestalt-host-service-relay-token";
/// Errors returned by the app client in this module.
///
/// Every variant's `Display` output is just its inner value (`{0}`); no
/// variant-specific prefix is added.
#[derive(Debug, thiserror::Error)]
pub enum AppError {
/// Error from the tonic transport layer (building an endpoint, configuring
/// TLS, or connecting).
#[error("{0}")]
Transport(#[from] tonic::transport::Error),
/// gRPC status returned by a call to the host service.
#[error("{0}")]
Status(#[from] tonic::Status),
/// Problem with the connection configuration: the socket environment
/// variable is missing, the transport target cannot be parsed, or the relay
/// token is not a valid metadata value.
#[error("{0}")]
Env(String),
/// JSON (de)serialization failure reported by `serde_json`.
#[error("{0}")]
Json(#[from] serde_json::Error),
/// A request or response did not have the expected shape, e.g. a value that
/// does not serialize to a JSON object, an empty GraphQL document, or a
/// response status that does not fit in a `u16`.
#[error("{0}")]
Protocol(String),
/// A boxed `crate::InvokeError`; constructible via `From` from either an
/// owned or an already boxed `InvokeError`.
#[error("{0}")]
Invoke(#[source] Box<crate::InvokeError>),
}
impl From<crate::InvokeError> for AppError {
fn from(error: crate::InvokeError) -> Self {
Self::Invoke(Box::new(error))
}
}
impl From<Box<crate::InvokeError>> for AppError {
fn from(error: Box<crate::InvokeError>) -> Self {
Self::Invoke(error)
}
}
/// Optional settings for `App::invoke` / `App::invoke_raw`.
///
/// When no options are supplied, every field is sent as an empty string.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct InvokeOptions {
/// Copied unchanged into the request's `connection` field.
pub connection: String,
/// Copied unchanged into the request's `instance` field.
pub instance: String,
/// Trimmed of surrounding whitespace, then sent as the request's
/// `idempotency_key`.
pub idempotency_key: String,
/// Trimmed of surrounding whitespace, then sent as the request's
/// `credential_mode`.
pub credential_mode: String,
}
/// Optional settings for `App::invoke_graphql` / `App::invoke_graphql_raw`.
///
/// When no options are supplied, every field is sent as an empty string.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct InvokeGraphQLOptions {
/// Copied unchanged into the request's `connection` field.
pub connection: String,
/// Copied unchanged into the request's `instance` field.
pub instance: String,
/// Trimmed of surrounding whitespace, then sent as the request's
/// `idempotency_key`.
pub idempotency_key: String,
}
/// Non-generic, `async_trait`-based counterpart of the inherent `App` methods.
///
/// All inputs are owned (`String`, `serde_json::Value`) and decoded results
/// are returned as `serde_json::Value`. `App` implements this trait by
/// delegating to its inherent methods of the same names.
#[async_trait]
pub trait AppContract: Send {
/// Invokes `operation` on `plugin` with `params` and returns the decoded
/// result as JSON.
async fn invoke(
&mut self,
plugin: String,
operation: String,
params: serde_json::Value,
options: Option<InvokeOptions>,
) -> std::result::Result<serde_json::Value, AppError>;
/// Like `invoke`, but returns the undecoded `OperationResult`
/// (status, headers, body).
async fn invoke_raw(
&mut self,
plugin: String,
operation: String,
params: serde_json::Value,
options: Option<InvokeOptions>,
) -> std::result::Result<OperationResult, AppError>;
/// Sends a GraphQL `document` with optional `variables` to `plugin` and
/// returns the decoded result as JSON.
async fn invoke_graphql(
&mut self,
plugin: String,
document: String,
variables: Option<serde_json::Value>,
options: Option<InvokeGraphQLOptions>,
) -> std::result::Result<serde_json::Value, AppError>;
/// Like `invoke_graphql`, but returns the undecoded `OperationResult`
/// (status, headers, body).
async fn invoke_graphql_raw(
&mut self,
plugin: String,
document: String,
variables: Option<serde_json::Value>,
options: Option<InvokeGraphQLOptions>,
) -> std::result::Result<OperationResult, AppError>;
}
/// Client for the host app service, created with `App::connect`.
pub struct App {
/// Generated gRPC client running over the relay-token-intercepted channel.
client: ProtoAppClient<AppTransport>,
/// Request context captured when the client was connected; a clone is
/// attached to every outgoing request.
context: Option<pb::RequestContext>,
}
impl App {
/// Connects to the host app service described by the process environment.
///
/// The transport target is read from `ENV_HOST_SERVICE_SOCKET` and parsed by
/// `parse_app_target` (Unix socket path, `tcp://`, `tls://` or `unix://`).
/// The optional relay token is read from `ENV_HOST_SERVICE_TOKEN`; a missing
/// variable is treated as an empty token. The value returned by
/// `current_request_context()` at call time is stored and attached to every
/// later request. The `_request` argument is currently unused.
///
/// # Errors
///
/// * `AppError::Env` if the socket variable is not set, the target cannot be
///   parsed, or the relay token is not a valid metadata value.
/// * `AppError::Transport` if the endpoint cannot be built or connected.
pub async fn connect(_request: &Request) -> std::result::Result<Self, AppError> {
    let context = current_request_context();

    let Ok(raw_target) = std::env::var(ENV_HOST_SERVICE_SOCKET) else {
        return Err(AppError::Env(format!("{ENV_HOST_SERVICE_SOCKET} is not set")));
    };
    let relay_token = std::env::var(ENV_HOST_SERVICE_TOKEN).unwrap_or_default();

    let target = parse_app_target(&raw_target)?;
    let channel = match target {
        AppTarget::Unix(path) => {
            // The endpoint needs some URI, but the connector ignores the one
            // it is handed and always dials the Unix socket at `path`.
            let connector = service_fn(move |_: Uri| {
                let path = path.clone();
                async move { UnixStream::connect(path).await.map(TokioIo::new) }
            });
            let endpoint = Endpoint::try_from("http://[::]:50051")?;
            endpoint.connect_with_connector(connector).await?
        }
        AppTarget::Tcp(address) => {
            let endpoint = Endpoint::from_shared(format!("http://{address}"))?;
            endpoint.connect().await?
        }
        AppTarget::Tls(address) => {
            let tls = ClientTlsConfig::new().with_native_roots();
            let endpoint = Endpoint::from_shared(format!("https://{address}"))?.tls_config(tls)?;
            endpoint.connect().await?
        }
    };

    let interceptor = relay_token_interceptor(relay_token.trim())?;
    let client = ProtoAppClient::with_interceptor(channel, interceptor);
    Ok(Self { client, context })
}
pub async fn invoke<P, T>(
&mut self,
plugin: &str,
operation: &str,
params: P,
options: Option<InvokeOptions>,
) -> std::result::Result<T, AppError>
where
P: Serialize,
T: DeserializeOwned,
{
let result = self.invoke_raw(plugin, operation, params, options).await?;
let decoded = decode_app_result(plugin, operation, &result)?;
Ok(serde_json::from_value(decoded)?)
}
pub async fn invoke_raw<P>(
&mut self,
plugin: &str,
operation: &str,
params: P,
options: Option<InvokeOptions>,
) -> std::result::Result<OperationResult, AppError>
where
P: Serialize,
{
let response = self
.client
.invoke(pb::AppInvokeRequest {
app: plugin.to_string(),
operation: operation.to_string(),
params: Some(serializable_to_struct(params, "params")?),
connection: options
.as_ref()
.map(|opts| opts.connection.clone())
.unwrap_or_default(),
instance: options
.as_ref()
.map(|opts| opts.instance.clone())
.unwrap_or_default(),
idempotency_key: options
.as_ref()
.map(|opts| opts.idempotency_key.trim().to_string())
.unwrap_or_default(),
credential_mode: options
.as_ref()
.map(|opts| opts.credential_mode.trim().to_string())
.unwrap_or_default(),
context: self.context.clone(),
})
.await?
.into_inner();
let status = u16::try_from(response.status).map_err(|_| {
AppError::Protocol(format!("app: invalid response status {}", response.status))
})?;
Ok(OperationResult {
status,
headers: protocol::string_lists_from_proto(&response.headers),
body: response.body,
})
}
pub async fn invoke_graphql<V, T>(
&mut self,
plugin: &str,
document: &str,
variables: Option<V>,
options: Option<InvokeGraphQLOptions>,
) -> std::result::Result<T, AppError>
where
V: Serialize,
T: DeserializeOwned,
{
let result = self
.invoke_graphql_raw(plugin, document, variables, options)
.await?;
let decoded = decode_graphql_result(plugin, &result)?;
Ok(serde_json::from_value(decoded)?)
}
pub async fn invoke_graphql_raw<V>(
&mut self,
plugin: &str,
document: &str,
variables: Option<V>,
options: Option<InvokeGraphQLOptions>,
) -> std::result::Result<OperationResult, AppError>
where
V: Serialize,
{
let document = document.trim();
if document.is_empty() {
return Err(AppError::Protocol(
"app: graphql document is required".to_string(),
));
}
let response = self
.client
.invoke_graph_ql(pb::AppInvokeGraphQlRequest {
app: plugin.to_string(),
document: document.to_string(),
variables: variables
.map(|value| serializable_to_optional_struct(value, "variables"))
.transpose()?
.flatten(),
connection: options
.as_ref()
.map(|opts| opts.connection.clone())
.unwrap_or_default(),
instance: options
.as_ref()
.map(|opts| opts.instance.clone())
.unwrap_or_default(),
idempotency_key: options
.as_ref()
.map(|opts| opts.idempotency_key.trim().to_string())
.unwrap_or_default(),
context: self.context.clone(),
})
.await?
.into_inner();
let status = u16::try_from(response.status).map_err(|_| {
AppError::Protocol(format!("app: invalid response status {}", response.status))
})?;
Ok(OperationResult {
status,
headers: protocol::string_lists_from_proto(&response.headers),
body: response.body,
})
}
}
/// Implements `AppContract` for `App` by forwarding to the inherent `App`
/// methods of the same names.
///
/// Each body uses an explicit `App::` path, passes the owned strings by
/// reference, and (for the decoding variants) fixes the output type to
/// `serde_json::Value`.
#[async_trait]
impl AppContract for App {
// Forwards to the generic inherent `App::invoke`, decoding into JSON.
async fn invoke(
&mut self,
plugin: String,
operation: String,
params: serde_json::Value,
options: Option<InvokeOptions>,
) -> std::result::Result<serde_json::Value, AppError> {
App::invoke::<_, serde_json::Value>(self, &plugin, &operation, params, options).await
}
// Forwards to the inherent `App::invoke_raw`.
async fn invoke_raw(
&mut self,
plugin: String,
operation: String,
params: serde_json::Value,
options: Option<InvokeOptions>,
) -> std::result::Result<OperationResult, AppError> {
App::invoke_raw(self, &plugin, &operation, params, options).await
}
// Forwards to the generic inherent `App::invoke_graphql`, decoding into JSON.
async fn invoke_graphql(
&mut self,
plugin: String,
document: String,
variables: Option<serde_json::Value>,
options: Option<InvokeGraphQLOptions>,
) -> std::result::Result<serde_json::Value, AppError> {
App::invoke_graphql::<_, serde_json::Value>(self, &plugin, &document, variables, options)
.await
}
// Forwards to the inherent `App::invoke_graphql_raw`.
async fn invoke_graphql_raw(
&mut self,
plugin: String,
document: String,
variables: Option<serde_json::Value>,
options: Option<InvokeGraphQLOptions>,
) -> std::result::Result<OperationResult, AppError> {
App::invoke_graphql_raw(self, &plugin, &document, variables, options).await
}
}
/// Parsed transport target for the host app service, as produced by
/// `parse_app_target`.
enum AppTarget {
/// Filesystem path of a Unix domain socket.
Unix(String),
/// Address reached over plaintext `http://`.
Tcp(String),
/// Address reached over `https://` with native root certificates.
Tls(String),
}
/// Parses a transport target string into an [`AppTarget`].
///
/// Accepted forms, after trimming surrounding whitespace:
///
/// * `tcp://<address>`  -> `AppTarget::Tcp`
/// * `tls://<address>`  -> `AppTarget::Tls`
/// * `unix://<path>`    -> `AppTarget::Unix`
/// * anything without `://` -> `AppTarget::Unix` (treated as a socket path)
///
/// The part after the scheme is trimmed as well.
///
/// # Errors
///
/// Returns `AppError::Env` when the target is empty, when the part after a
/// known scheme is empty, or when the scheme is not one of the above.
fn parse_app_target(raw_target: &str) -> Result<AppTarget, AppError> {
    let target = raw_target.trim();
    if target.is_empty() {
        return Err(AppError::Env(
            "app: transport target is required".to_string(),
        ));
    }

    // No scheme separator at all: the whole target is a Unix socket path.
    let Some((scheme, rest)) = target.split_once("://") else {
        return Ok(AppTarget::Unix(target.to_string()));
    };
    let rest = rest.trim();

    match scheme {
        "tcp" if rest.is_empty() => Err(AppError::Env(format!(
            "app: tcp target {raw_target:?} is missing host:port"
        ))),
        "tcp" => Ok(AppTarget::Tcp(rest.to_string())),
        "tls" if rest.is_empty() => Err(AppError::Env(format!(
            "app: tls target {raw_target:?} is missing host:port"
        ))),
        "tls" => Ok(AppTarget::Tls(rest.to_string())),
        "unix" if rest.is_empty() => Err(AppError::Env(format!(
            "app: unix target {raw_target:?} is missing a socket path"
        ))),
        "unix" => Ok(AppTarget::Unix(rest.to_string())),
        _ => Err(AppError::Env(format!(
            "app: unsupported target scheme {scheme:?}"
        ))),
    }
}
/// Builds the interceptor that attaches the relay token to outgoing requests.
///
/// Surrounding whitespace is stripped from `token` before it is used. An
/// empty or whitespace-only token yields an interceptor that adds no header.
///
/// # Errors
///
/// Returns `AppError::Env` when the trimmed token cannot be used as an ASCII
/// metadata value.
fn relay_token_interceptor(token: &str) -> Result<RelayTokenInterceptor, AppError> {
    // Trim once so the emptiness check and the header value always agree;
    // previously a padded token passed the check but was sent (or rejected)
    // with its whitespace intact.
    let token = token.trim();
    let header = if token.is_empty() {
        None
    } else {
        Some(
            MetadataValue::try_from(token)
                .map_err(|err| AppError::Env(format!("invalid app relay token metadata: {err}")))?,
        )
    };
    Ok(RelayTokenInterceptor { header })
}
/// Interceptor that adds the relay token to each outgoing request's metadata.
#[derive(Clone)]
struct RelayTokenInterceptor {
/// Pre-built metadata value for the token; `None` means no header is added.
header: Option<MetadataValue<tonic::metadata::Ascii>>,
}
/// Inserts the stored relay token, if any, under `APP_RELAY_TOKEN_HEADER`.
impl Interceptor for RelayTokenInterceptor {
    fn call(&mut self, mut request: GrpcRequest<()>) -> Result<GrpcRequest<()>, tonic::Status> {
        // Without a configured token the request passes through untouched.
        let Some(token) = self.header.as_ref() else {
            return Ok(request);
        };
        request
            .metadata_mut()
            .insert(APP_RELAY_TOKEN_HEADER, token.clone());
        Ok(request)
    }
}
/// Serializes `value` into a protobuf `Struct`.
///
/// A value that serializes to JSON `null` becomes an empty `Struct`.
/// `field_name` is only used in the error message.
///
/// # Errors
///
/// Returns an error if serialization fails or if the value serializes to
/// anything other than a JSON object or `null`.
fn serializable_to_struct<T: Serialize>(
    value: T,
    field_name: &str,
) -> std::result::Result<prost_types::Struct, AppError> {
    let json = protocol::json_value_from_serializable(value)?;
    match json_to_optional_struct(json, field_name)? {
        Some(fields) => Ok(fields),
        None => Ok(prost_types::Struct::default()),
    }
}
/// Converts a JSON value into an optional protobuf `Struct`.
///
/// * `null` -> `Ok(None)`
/// * object -> `Ok(Some(struct))` via `protocol::struct_from_json`
/// * anything else -> `AppError::Protocol` naming `field_name`
///
/// A conversion failure from `protocol::struct_from_json` is also reported
/// as `AppError::Protocol`, carrying that error's message.
fn json_to_optional_struct(
    value: serde_json::Value,
    field_name: &str,
) -> std::result::Result<Option<prost_types::Struct>, AppError> {
    match value {
        serde_json::Value::Null => Ok(None),
        object @ serde_json::Value::Object(_) => match protocol::struct_from_json(object) {
            Ok(fields) => Ok(Some(fields)),
            Err(err) => Err(AppError::Protocol(err.to_string())),
        },
        _ => Err(AppError::Protocol(format!(
            "app: {field_name} must serialize to a JSON object"
        ))),
    }
}
fn serializable_to_optional_struct<T: Serialize>(
value: T,
field_name: &str,
) -> std::result::Result<Option<prost_types::Struct>, AppError> {
let value = protocol::json_value_from_serializable(value)?;
json_to_optional_struct(value, field_name)
}