use std::path::Path;
use std::fmt;
use crate::auth::{
decode_jwt, is_token_expired, retrieve_auth_token_client_credentials, set_jwks, AuthInterceptor,
};
use crate::query::QueryEntitiesReturn;
use crate::types::error::HstpError;
use crate::upsert::upsert;
use crate::utils::read_hsml_json;
use crate::{query::query_t, types::entity::HSMLEntity};
use kortex_gen_grpc::hstp::v1::hstp_service_client::HstpServiceClient;
use kortex_gen_grpc::hstp::v1::CollisionStrategy;
use serde_json::Value;
use tonic::codegen::InterceptedService;
use tonic::transport::{Channel, Endpoint};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
#[cfg(feature = "napi")]
use napi_derive::napi;
use serde::de::DeserializeOwned;
/// OAuth2 client-credentials used to obtain tokens for a [`Client`].
///
/// `Debug` is implemented by hand so that `client_secret` is never written to
/// logs or panic messages.
#[derive(Clone, Default)]
pub struct ClientCredentials {
    /// OAuth2 client identifier.
    pub client_id: String,
    /// OAuth2 client secret. Redacted in `Debug` output.
    pub client_secret: String,
    /// Auth domain; used for JWKS setup and token retrieval.
    pub auth_domain: String,
    /// Audience requested when retrieving a token.
    pub audience: String,
}

impl fmt::Debug for ClientCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ClientCredentials")
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("auth_domain", &self.auth_domain)
            .field("audience", &self.audience)
            .finish()
    }
}
/// Request timeout and retry budget used when building a [`Client`].
///
/// The default is a 30 second timeout and 3 retries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutAndRetries {
    /// Timeout handed to the gRPC channel's `Endpoint::timeout` at construction.
    /// (`tokio::time::Duration` is a re-export of this same type.)
    pub timeout: std::time::Duration,
    /// Retry count forwarded to the query/upsert helpers.
    pub retries: u32,
}

impl Default for TimeoutAndRetries {
    fn default() -> Self {
        Self {
            timeout: std::time::Duration::from_secs(30),
            retries: 3,
        }
    }
}
/// Generated HSTP gRPC client whose channel is wrapped by an [`AuthInterceptor`].
///
/// NOTE(review): `Client` builds the interceptor once, from the token it has at
/// construction time. Later updates to `Client::token` (refresh, `set_token`) are
/// therefore presumably not seen by outgoing requests, unless `AuthInterceptor`
/// sources the token elsewhere — confirm.
pub(crate) type InternalClient = HstpServiceClient<InterceptedService<Channel, AuthInterceptor>>;
/// Authenticated HSTP client.
///
/// There is a single definition for every build configuration. Each language
/// binding only adds its own attribute through `cfg_attr`. This keeps the field
/// list from drifting between features, and avoids the duplicate-definition error
/// that three separately gated copies caused when both `pyo3` and `napi` were enabled.
#[cfg_attr(feature = "pyo3", pyclass)]
#[cfg_attr(feature = "napi", napi)]
pub struct Client {
    /// Generated gRPC client; its interceptor carries the token given at construction.
    client: InternalClient,
    /// `Some` only for clients created from client credentials; enables token refresh.
    client_credentials: Option<ClientCredentials>,
    /// Current token as known to this struct (see `get_token` / `set_token`).
    token: String,
    /// Retry count forwarded to the query/upsert helpers.
    retries: u32,
}
impl fmt::Debug for Client {
    /// Formats the client without exposing credentials.
    ///
    /// The token is never printed; only whether it is empty. The client secret is
    /// omitted entirely, and only the client id of any stored credentials is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token_display = if self.token.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("Client")
            .field(
                "client_id",
                &self
                    .client_credentials
                    .as_ref()
                    .map(|credentials| credentials.client_id.as_str()),
            )
            .field("token", &token_display)
            .field("retries", &self.retries)
            .finish_non_exhaustive()
    }
}
/// URL scheme used to reach the HSTP endpoint. Defaults to [`Protocol::HTTPS`].
// Variant names are public API; renaming them to `Http`/`Https` would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Protocol {
    /// The `http` scheme.
    HTTP,
    /// The `https` scheme (the default).
    #[default]
    HTTPS,
}

#[cfg(feature = "napi")]
#[napi]
impl Protocol {}

impl Protocol {
    /// Returns the lowercase URL scheme: `"http"` or `"https"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Protocol::HTTP => "http",
            Protocol::HTTPS => "https",
        }
    }
}

impl From<Protocol> for &str {
    fn from(value: Protocol) -> Self {
        value.as_str()
    }
}

/// Error returned when a string is not a recognised [`Protocol`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseProtocolError {
    input: String,
}

impl fmt::Display for ParseProtocolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Invalid protocol: {:?} (expected \"http\" or \"https\")",
            self.input
        )
    }
}

impl std::error::Error for ParseProtocolError {}

impl std::str::FromStr for Protocol {
    type Err = ParseProtocolError;

    /// Parses `"http"` / `"https"`, ignoring case.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "http" => Ok(Protocol::HTTP),
            "https" => Ok(Protocol::HTTPS),
            _ => Err(ParseProtocolError {
                input: s.to_owned(),
            }),
        }
    }
}

impl From<&str> for Protocol {
    /// Converts `"http"` / `"https"` (case-insensitive) into a [`Protocol`].
    ///
    /// # Panics
    ///
    /// Panics on any other input. Use `str::parse::<Protocol>()` for a fallible
    /// conversion.
    fn from(value: &str) -> Self {
        match value.parse() {
            Ok(protocol) => protocol,
            Err(err) => panic!("{err}"),
        }
    }
}
/// Settings consumed by `Client::new_client_credentials`.
///
/// This type deliberately does not implement `Debug`, so `client_secret` cannot
/// end up in formatted output by accident.
#[derive(Default)]
pub struct ClientConfig {
    // --- where to connect ---
    /// URL scheme for the endpoint.
    pub protocol: Protocol,
    /// Host name of the endpoint.
    pub host: String,
    /// Port of the endpoint, as text.
    pub port: String,

    // --- how to authenticate ---
    /// Auth domain used for JWKS setup and token retrieval.
    pub auth_domain: String,
    /// Audience requested when retrieving a token.
    pub audience: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// OAuth2 client secret.
    pub client_secret: String,
}
#[cfg(not(feature = "napi"))]
impl Client {
    /// Replaces the token stored on this client.
    ///
    /// NOTE(review): this updates only `self.token`, which is read by `get_token`,
    /// `get_user_id` and the expiry check in `refresh_token`. The `AuthInterceptor`
    /// inside `self.client` was built with the token given at construction time, so
    /// outgoing requests presumably keep using that token — confirm against
    /// `AuthInterceptor`.
    pub fn set_token<S: Into<String>>(&mut self, token: S) {
        self.token = token.into();
    }

    /// Returns the token currently stored on this client.
    pub fn get_token(&self) -> &str {
        &self.token
    }

    /// Connects using OAuth2 client credentials.
    ///
    /// Steps:
    /// 1. Run `set_jwks` for `config.auth_domain`.
    /// 2. Fetch an initial token with the credentials.
    /// 3. Open the gRPC channel.
    ///
    /// The credentials are kept so that `refresh_token` can fetch a new token once
    /// the current one has expired. `None` for `timeout_and_retries` means
    /// [`TimeoutAndRetries::default`].
    ///
    /// # Errors
    ///
    /// Fails if JWKS setup, token retrieval, or connecting the channel fails.
    #[deprecated(since = "0.3.0-rc12", note = "Please use `new_with_oauth2_token` instead")]
    pub async fn new_client_credentials(
        config: ClientConfig,
        timeout_and_retries: Option<TimeoutAndRetries>,
    ) -> Result<Self, HstpError> {
        let timeout_and_retries = timeout_and_retries.unwrap_or_default();
        set_jwks(&config.auth_domain).await?;
        let client_credentials = ClientCredentials {
            client_id: config.client_id,
            client_secret: config.client_secret,
            auth_domain: config.auth_domain,
            audience: config.audience,
        };
        let token = retrieve_auth_token_client_credentials(&client_credentials).await?;
        let client = Self::construct_internal_client(
            config.protocol.into(),
            config.host,
            config.port,
            token.clone(),
            &timeout_and_retries,
        )
        .await?;
        Ok(Self {
            client,
            client_credentials: Some(client_credentials),
            token,
            retries: timeout_and_retries.retries,
        })
    }

    /// Connects using a caller-supplied token.
    ///
    /// No credentials are stored, so `refresh_token` never replaces this token.
    /// `None` for `timeout_and_retries` means [`TimeoutAndRetries::default`].
    ///
    /// # Errors
    ///
    /// Fails if the endpoint URL is invalid or the channel cannot connect.
    pub async fn new_with_oauth2_token<S: Into<String>, T: Into<String>, U: Into<String>>(
        protocol: Protocol,
        host: S,
        port: T,
        token: U,
        timeout_and_retries: Option<TimeoutAndRetries>,
    ) -> Result<Self, HstpError> {
        let timeout_and_retries = timeout_and_retries.unwrap_or_default();
        let token = token.into();
        let client = Self::construct_internal_client(
            protocol.into(),
            host.into(),
            port.into(),
            token.clone(),
            &timeout_and_retries,
        )
        .await?;
        Ok(Self {
            client,
            client_credentials: None,
            retries: timeout_and_retries.retries,
            token,
        })
    }

    /// Opens a channel to `{protocol}://{host}:{port}` and wraps it in the
    /// generated client with an `AuthInterceptor` carrying `token`.
    async fn construct_internal_client(
        protocol: &str,
        host: String,
        port: String,
        token: String,
        timeout_and_retries: &TimeoutAndRetries,
    ) -> Result<InternalClient, tonic::transport::Error> {
        // Limit applied to both encoded and decoded messages: 2 GiB.
        const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024 * 1024;

        let endpoint = format!("{protocol}://{host}:{port}");
        let channel = Endpoint::from_shared(endpoint)?
            .tls_config(tonic::transport::ClientTlsConfig::default())?
            // Pass the Duration through unchanged. Round-tripping through
            // `as_secs()` truncated sub-second timeouts (e.g. 500 ms became 0 s).
            .timeout(timeout_and_retries.timeout)
            .connect()
            .await?;
        let client = HstpServiceClient::with_interceptor(channel, AuthInterceptor { token });
        Ok(client
            .max_decoding_message_size(MAX_MESSAGE_SIZE)
            .max_encoding_message_size(MAX_MESSAGE_SIZE))
    }

    /// Fetches a new token when the stored one is expired.
    ///
    /// This is a no-op for clients built without client credentials.
    ///
    /// NOTE(review): only `self.token` is updated. The interceptor in `self.client`
    /// still holds the token it was built with, so requests after an expiry
    /// presumably continue to send the stale token. Fixing this needs the
    /// interceptor to share the token (e.g. via `Arc<RwLock<_>>`) or the client to
    /// be rebuilt — confirm against `AuthInterceptor`.
    async fn refresh_token(&mut self) -> Result<(), HstpError> {
        if let Some(client_credentials) = &self.client_credentials {
            let claims = decode_jwt(&self.token).await?;
            if is_token_expired(&claims) {
                self.token = retrieve_auth_token_client_credentials(client_credentials).await?;
            }
        }
        Ok(())
    }

    /// Returns the `sub` claim of the stored token.
    ///
    /// The token is not refreshed first.
    pub async fn get_user_id(&mut self) -> Result<String, HstpError> {
        let claims = decode_jwt(&self.token).await?;
        Ok(claims.sub)
    }

    /// Runs `query` and asks `query_t` for a [`QueryEntitiesReturn`].
    pub async fn query_for_entity_array<S: Into<String>>(
        &mut self,
        query: S,
    ) -> Result<QueryEntitiesReturn, HstpError> {
        self.refresh_token().await?;
        query_t::<QueryEntitiesReturn>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for a single [`HSMLEntity`].
    pub async fn query_for_entity<S: Into<String>>(
        &mut self,
        query: S,
    ) -> Result<HSMLEntity, HstpError> {
        self.refresh_token().await?;
        query_t::<HSMLEntity>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for a `Vec<Value>`.
    pub async fn query_for_value_array<S: Into<String>>(
        &mut self,
        query: S,
    ) -> Result<Vec<Value>, HstpError> {
        self.refresh_token().await?;
        query_t::<Vec<Value>>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for a single JSON [`Value`].
    pub async fn query_for_value<S: Into<String>>(&mut self, query: S) -> Result<Value, HstpError> {
        self.refresh_token().await?;
        query_t::<Value>(&mut self.client, query.into(), self.retries).await
    }

    /// Alias for [`Client::query_for_value`].
    pub async fn query<S: Into<String>>(&mut self, query: S) -> Result<Value, HstpError> {
        self.query_for_value(query).await
    }

    /// Upserts `entities` with `collision_strategy` and returns the entities
    /// produced by the `upsert` helper.
    pub async fn upsert<V: AsRef<[HSMLEntity]>>(
        &mut self,
        entities: V,
        collision_strategy: CollisionStrategy,
    ) -> Result<Vec<HSMLEntity>, HstpError> {
        self.refresh_token().await?;
        upsert(
            &mut self.client,
            entities.as_ref(),
            collision_strategy,
            self.retries,
        )
        .await
    }

    /// Upserts a single entity and returns the first entity of the result.
    ///
    /// # Panics
    ///
    /// Panics if the upsert succeeds but returns no entities.
    pub async fn upsert_one(
        &mut self,
        entity: &HSMLEntity,
        collision_strategy: CollisionStrategy,
    ) -> Result<HSMLEntity, HstpError> {
        // `slice::from_ref` borrows the entity as a one-element slice, so the input
        // is not cloned.
        let upserted = self
            .upsert(std::slice::from_ref(entity), collision_strategy)
            .await?;
        Ok(upserted
            .into_iter()
            .next()
            .expect("upsert of a single entity returned no entities"))
    }

    /// Loads entities with `read_hsml_json(path)` and upserts them.
    ///
    /// NOTE(review): `read_hsml_json` is called directly inside this async fn. If
    /// it does blocking file I/O it will block the executor thread — confirm.
    pub async fn upsert_hsml_json<P: AsRef<Path>>(
        &mut self,
        path: P,
        collision_strategy: CollisionStrategy,
    ) -> Result<Vec<HSMLEntity>, HstpError> {
        let entities = read_hsml_json(path)?;
        self.upsert(entities, collision_strategy).await
    }

    /// Creates a [`crate::listen::Listener`] on this client's gRPC connection.
    pub async fn create_listener(&mut self) -> Result<crate::listen::Listener, HstpError> {
        self.refresh_token().await?;
        crate::listen::Listener::new(&mut self.client).await
    }
}
#[cfg(feature = "napi")]
#[napi]
impl Client {
    /// Replaces the token stored on this client.
    ///
    /// NOTE(review): this updates only `self.token`. The `AuthInterceptor` inside
    /// `self.client` was built with the construction-time token, so outgoing
    /// requests presumably keep using that token — confirm against `AuthInterceptor`.
    pub fn set_token<S: Into<String>>(&mut self, token: S) {
        self.token = token.into();
    }

    /// Returns the token currently stored on this client.
    pub fn get_token(&self) -> &str {
        &self.token
    }

    /// Connects using OAuth2 client credentials.
    ///
    /// Steps:
    /// 1. Run `set_jwks` for `config.auth_domain`.
    /// 2. Fetch an initial token with the credentials.
    /// 3. Open the gRPC channel.
    ///
    /// The credentials are kept for `refresh_token`. `None` for
    /// `timeout_and_retries` means [`TimeoutAndRetries::default`].
    ///
    /// # Errors
    ///
    /// Fails if JWKS setup, token retrieval, or connecting the channel fails.
    pub async fn new_client_credentials(
        config: ClientConfig,
        timeout_and_retries: Option<TimeoutAndRetries>,
    ) -> Result<Self, HstpError> {
        let timeout_and_retries = timeout_and_retries.unwrap_or_default();
        set_jwks(&config.auth_domain).await?;
        let client_credentials = ClientCredentials {
            client_id: config.client_id,
            client_secret: config.client_secret,
            auth_domain: config.auth_domain,
            audience: config.audience,
        };
        let token = retrieve_auth_token_client_credentials(&client_credentials).await?;
        let client = Self::construct_internal_client(
            config.protocol.into(),
            config.host,
            config.port,
            token.clone(),
            &timeout_and_retries,
        )
        .await?;
        Ok(Self {
            client,
            client_credentials: Some(client_credentials),
            token,
            retries: timeout_and_retries.retries,
        })
    }

    /// Connects using a caller-supplied token that is never refreshed.
    ///
    /// `None` for `timeout_and_retries` means [`TimeoutAndRetries::default`].
    ///
    /// # Errors
    ///
    /// Fails if the endpoint URL is invalid or the channel cannot connect.
    pub async fn new_with_perpetual_token<S: Into<String>, T: Into<String>, U: Into<String>>(
        protocol: Protocol,
        host: S,
        port: T,
        token: U,
        timeout_and_retries: Option<TimeoutAndRetries>,
    ) -> Result<Self, HstpError> {
        let timeout_and_retries = timeout_and_retries.unwrap_or_default();
        let token = token.into();
        let client = Self::construct_internal_client(
            protocol.into(),
            host.into(),
            port.into(),
            token.clone(),
            &timeout_and_retries,
        )
        .await?;
        Ok(Self {
            client,
            client_credentials: None,
            retries: timeout_and_retries.retries,
            token,
        })
    }

    /// Opens a channel to `{protocol}://{host}:{port}` and wraps it in the
    /// generated client with an `AuthInterceptor` carrying `token`.
    async fn construct_internal_client(
        protocol: &str,
        host: String,
        port: String,
        token: String,
        timeout_and_retries: &TimeoutAndRetries,
    ) -> Result<InternalClient, tonic::transport::Error> {
        // Limit applied to both encoded and decoded messages: 2 GiB.
        const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024 * 1024;

        let endpoint = format!("{protocol}://{host}:{port}");
        let channel = Endpoint::from_shared(endpoint)?
            .tls_config(tonic::transport::ClientTlsConfig::default())?
            // Pass the Duration through unchanged. Round-tripping through
            // `as_secs()` truncated sub-second timeouts (e.g. 500 ms became 0 s).
            .timeout(timeout_and_retries.timeout)
            .connect()
            .await?;
        let client = HstpServiceClient::with_interceptor(channel, AuthInterceptor { token });
        Ok(client
            .max_decoding_message_size(MAX_MESSAGE_SIZE)
            .max_encoding_message_size(MAX_MESSAGE_SIZE))
    }

    /// Fetches a new token when the stored one is expired.
    ///
    /// This is a no-op without client credentials.
    ///
    /// NOTE(review): only `self.token` is updated, not the token inside the
    /// interceptor of `self.client` — confirm against `AuthInterceptor`.
    async fn refresh_token(&mut self) -> Result<(), HstpError> {
        if let Some(client_credentials) = &self.client_credentials {
            let claims = decode_jwt(&self.token).await?;
            if is_token_expired(&claims) {
                self.token = retrieve_auth_token_client_credentials(client_credentials).await?;
            }
        }
        Ok(())
    }

    /// Returns the `sub` claim of the stored token.
    ///
    /// The token is not refreshed first.
    pub async fn get_user_id(&mut self) -> Result<String, HstpError> {
        let claims = decode_jwt(&self.token).await?;
        Ok(claims.sub)
    }

    /// Runs `query` and asks `query_t` for a [`QueryEntitiesReturn`].
    pub async fn query_for_entity_array<S: Into<String>>(
        &mut self,
        query: S,
    ) -> Result<QueryEntitiesReturn, HstpError> {
        self.refresh_token().await?;
        query_t::<QueryEntitiesReturn>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for a single [`HSMLEntity`].
    pub async fn query_for_entity<S: Into<String>>(
        &mut self,
        query: S,
    ) -> Result<HSMLEntity, HstpError> {
        self.refresh_token().await?;
        query_t::<HSMLEntity>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for a `Vec<Value>`.
    pub async fn query_for_value_array<S: Into<String>>(
        &mut self,
        query: S,
    ) -> Result<Vec<Value>, HstpError> {
        self.refresh_token().await?;
        query_t::<Vec<Value>>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for a single JSON [`Value`].
    pub async fn query_for_value<S: Into<String>>(&mut self, query: S) -> Result<Value, HstpError> {
        self.refresh_token().await?;
        query_t::<Value>(&mut self.client, query.into(), self.retries).await
    }

    /// Runs `query` and asks `query_t` for any deserializable `T`.
    pub async fn query<S: Into<String>, T: DeserializeOwned>(
        &mut self,
        query: S,
    ) -> Result<T, HstpError> {
        self.refresh_token().await?;
        query_t::<T>(&mut self.client, query.into(), self.retries).await
    }

    /// Upserts `entities` with `collision_strategy` and returns the entities
    /// produced by the `upsert` helper.
    pub async fn upsert<V: AsRef<[HSMLEntity]>>(
        &mut self,
        entities: V,
        collision_strategy: CollisionStrategy,
    ) -> Result<Vec<HSMLEntity>, HstpError> {
        self.refresh_token().await?;
        upsert(
            &mut self.client,
            entities.as_ref(),
            collision_strategy,
            self.retries,
        )
        .await
    }

    /// Upserts a single entity and returns the first entity of the result.
    ///
    /// # Panics
    ///
    /// Panics if the upsert succeeds but returns no entities.
    pub async fn upsert_one(
        &mut self,
        entity: &HSMLEntity,
        collision_strategy: CollisionStrategy,
    ) -> Result<HSMLEntity, HstpError> {
        // Borrow the entity as a one-element slice instead of cloning it into a Vec.
        let upserted = self
            .upsert(std::slice::from_ref(entity), collision_strategy)
            .await?;
        Ok(upserted
            .into_iter()
            .next()
            .expect("upsert of a single entity returned no entities"))
    }

    /// Loads entities with `read_hsml_json(path)` and upserts them.
    ///
    /// NOTE(review): `read_hsml_json` is called directly inside this async fn. If
    /// it does blocking file I/O it will block the executor thread — confirm.
    pub async fn upsert_hsml_json<P: AsRef<Path>>(
        &mut self,
        path: P,
        collision_strategy: CollisionStrategy,
    ) -> Result<Vec<HSMLEntity>, HstpError> {
        let entities = read_hsml_json(path)?;
        self.upsert(entities, collision_strategy).await
    }

    /// Creates a [`crate::listen::Listener`] on this client's gRPC connection.
    pub async fn create_listener(&mut self) -> Result<crate::listen::Listener, HstpError> {
        self.refresh_token().await?;
        crate::listen::Listener::new(&mut self.client).await
    }
}