use std::pin::Pin;
use std::sync::Arc;
use tonic::transport::{Channel, ClientTlsConfig};
use force::auth::Authenticator;
use force::session::Session;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::OnceCell;
use tokio_stream::Stream;
use crate::config::{PubSubConfig, ReplayPreset};
use crate::error::{PubSubError, Result};
use crate::interceptor;
use crate::publish_sink::{PublishSink, open_publish_stream};
use crate::publisher::publish_unary;
use crate::schema_cache::SchemaCache;
use crate::subscriber::{subscribe_dynamic, subscribe_typed_dynamic};
use crate::types::{PubSubEvent, PublishResponse};
use crate::proto::eventbus_v1::{SchemaRequest, TopicRequest, pub_sub_client::PubSubClient};
/// Subset of the OAuth2 `userinfo` JSON response that this module reads.
///
/// Only `organization_id` is declared; no `deny_unknown_fields` is set, so serde ignores any
/// other fields in the response body.
#[derive(serde::Deserialize)]
struct UserInfo {
    /// Organization id; `fetch_tenant_id` returns it and the handler caches it as the tenant id.
    organization_id: String,
}
/// Resolves the tenant (organization) id for the session's user.
///
/// Sends `GET {instance_url}/services/oauth2/userinfo` with the session's current bearer token,
/// reads the response body through `read_capped_body` with a 1 MiB cap, and returns the
/// `organization_id` field of the JSON document.
///
/// # Errors
///
/// Token-acquisition errors are propagated with `?`. Every other failure is reported as
/// [`PubSubError::Config`], with a message identifying the failing step. Those failures are: the
/// request cannot be sent, the endpoint answers with a non-success status, the body cannot be
/// read, or the body is not the expected JSON.
async fn fetch_tenant_id<A: Authenticator>(session: &Arc<Session<A>>) -> Result<String> {
    let token = session.token_manager().token().await?;
    let userinfo_url = format!("{}/services/oauth2/userinfo", token.instance_url());
    let resp = reqwest::Client::new()
        .get(&userinfo_url)
        .bearer_auth(token.as_str())
        .send()
        .await
        .map_err(|e| PubSubError::Config(format!("userinfo request failed: {e}")))?;
    let status = resp.status();
    if !status.is_success() {
        return Err(PubSubError::Config(format!(
            "userinfo returned status {status}"
        )));
    }
    // The userinfo payload is a small JSON document; the cap keeps a misbehaving endpoint from
    // forcing us to buffer an arbitrarily large body.
    let body = force::http::read_capped_body(resp, 1024 * 1024)
        .await
        // A read/cap failure is distinct from a JSON parse failure; label it accordingly.
        .map_err(|e| PubSubError::Config(format!("userinfo read failed: {e}")))?;
    let info: UserInfo = serde_json::from_str(&body)
        .map_err(|e| PubSubError::Config(format!("userinfo parse failed: {e}")))?;
    Ok(info.organization_id)
}
/// Topic metadata as returned by the `GetTopic` RPC (see `PubSubHandler::get_topic`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TopicInfo {
    /// Topic name echoed back by the server.
    pub topic_name: String,
    /// Topic URI reported by the server.
    pub topic_uri: String,
    /// Publish-permission flag reported by the server for the calling user.
    pub can_publish: bool,
    /// Subscribe-permission flag reported by the server for the calling user.
    pub can_subscribe: bool,
    /// Id of the schema events on this topic use; it can be passed to
    /// `PubSubHandler::get_schema`.
    pub schema_id: String,
}
/// Schema definition as returned by the `GetSchema` RPC (see `PubSubHandler::get_schema`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaInfo {
    /// Id of the schema.
    pub schema_id: String,
    /// Schema definition text from the response's `schema_json` field
    /// (presumably the Avro schema JSON — confirm against the service docs).
    pub schema_json: String,
}
/// Handle for the Pub/Sub gRPC API, bound to one authenticated session and one channel.
///
/// Clones share the same `Session` and the same lazily resolved tenant id (both behind `Arc`),
/// and a clone of the underlying tonic `Channel`. Whether clones share `schema_cache` state
/// depends on `SchemaCache`'s own `Clone` impl — NOTE(review): presumably shared; confirm.
pub struct PubSubHandler<A: Authenticator> {
    pub(crate) session: Arc<Session<A>>,
    pub(crate) config: PubSubConfig,
    /// Schema cache used when publishing and subscribing.
    pub schema_cache: SchemaCache,
    pub(crate) channel: Channel,
    /// Tenant (organization) id, fetched on first use and shared by all clones.
    tenant_id: Arc<OnceCell<String>>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would add an `A: Clone` bound, but
// `A` is only ever held behind an `Arc`, so the handler is cloneable for any authenticator.
impl<A: Authenticator> Clone for PubSubHandler<A> {
    fn clone(&self) -> Self {
        Self {
            session: Arc::clone(&self.session),
            config: self.config.clone(),
            schema_cache: self.schema_cache.clone(),
            channel: self.channel.clone(),
            tenant_id: Arc::clone(&self.tenant_id),
        }
    }
}
impl<A: Authenticator> PubSubHandler<A> {
    /// Validates `config` and opens a gRPC channel to `config.endpoint`.
    ///
    /// TLS with the bundled WebPKI root certificates is configured only when the endpoint URI
    /// uses the `https` scheme; any other scheme is connected without TLS. The tenant id is not
    /// fetched here: it is resolved on first use and then cached in a cell shared by all clones
    /// of the handler.
    ///
    /// # Errors
    ///
    /// Returns [`PubSubError::Config`] if `config.batch_size` is outside `1..=100` or
    /// `config.endpoint` cannot be parsed as a URI. TLS-configuration and connection errors are
    /// propagated with `?`.
    pub async fn connect(session: Arc<Session<A>>, config: PubSubConfig) -> Result<Self> {
        if !(1..=100).contains(&config.batch_size) {
            return Err(PubSubError::Config(
                "batch_size must be between 1 and 100".to_string(),
            ));
        }
        let endpoint = Channel::from_shared(config.endpoint.clone())
            .map_err(|e| PubSubError::Config(format!("invalid endpoint: {e}")))?;
        let endpoint = if endpoint.uri().scheme_str() == Some("https") {
            endpoint.tls_config(ClientTlsConfig::new().with_webpki_roots())?
        } else {
            endpoint
        };
        let channel = endpoint.connect().await?;
        Ok(Self {
            session,
            config,
            schema_cache: SchemaCache::new(),
            channel,
            tenant_id: Arc::new(OnceCell::new()),
        })
    }

    /// Builds a gRPC client over a clone of the handler's channel.
    fn grpc_client(&self) -> PubSubClient<Channel> {
        PubSubClient::new(self.channel.clone())
    }

    /// Returns the tenant id, fetching it from the userinfo endpoint on the first call.
    ///
    /// Uses `OnceCell::get_or_try_init`, so concurrent first callers wait on a single fetch. If
    /// the fetch fails, the cell stays empty and a later call retries.
    pub(crate) async fn get_tenant_id(&self) -> Result<&str> {
        self.tenant_id
            .get_or_try_init(|| fetch_tenant_id(&self.session))
            .await
            .map(String::as_str)
    }

    /// Wraps `message` in a tonic request whose metadata is replaced by the output of
    /// `interceptor::build_metadata`. That output is built from the current token, its instance
    /// URL, and the tenant id.
    async fn auth_request<T>(&self, message: T) -> Result<tonic::Request<T>> {
        let token = self.session.token_manager().token().await?;
        // Borrowed straight from the cache; no need to allocate an owned copy here.
        let tenant_id = self.get_tenant_id().await?;
        let meta = interceptor::build_metadata(&token, token.instance_url(), tenant_id)?;
        let mut req = tonic::Request::new(message);
        *req.metadata_mut() = meta;
        Ok(req)
    }

    /// Fetches metadata for `topic_name` via the `GetTopic` RPC.
    ///
    /// # Errors
    ///
    /// Propagates authentication/metadata errors and the gRPC status of a failed call.
    pub async fn get_topic(&self, topic_name: &str) -> Result<TopicInfo> {
        let req = self
            .auth_request(TopicRequest {
                topic_name: topic_name.to_string(),
            })
            .await?;
        let info = self.grpc_client().get_topic(req).await?.into_inner();
        Ok(TopicInfo {
            topic_name: info.topic_name,
            topic_uri: info.topic_uri,
            can_publish: info.can_publish,
            can_subscribe: info.can_subscribe,
            schema_id: info.schema_id,
        })
    }

    /// Fetches the schema identified by `schema_id` via the `GetSchema` RPC.
    ///
    /// This always performs the RPC; it neither reads nor populates `schema_cache`.
    ///
    /// # Errors
    ///
    /// Propagates authentication/metadata errors and the gRPC status of a failed call.
    pub async fn get_schema(&self, schema_id: &str) -> Result<SchemaInfo> {
        let req = self
            .auth_request(SchemaRequest {
                schema_id: schema_id.to_string(),
            })
            .await?;
        let info = self.grpc_client().get_schema(req).await?.into_inner();
        Ok(SchemaInfo {
            schema_id: info.schema_id,
            schema_json: info.schema_json,
        })
    }
}
impl<A: Authenticator + Send + Sync + 'static> PubSubHandler<A> {
pub async fn publish<T: Serialize + Send>(
&self,
topic: &str,
events: Vec<T>,
) -> Result<PublishResponse> {
let topic_info = self.get_topic(topic).await?;
let schema_id = &topic_info.schema_id;
let token = self.session.token_manager().token().await?;
let tenant_id = self.get_tenant_id().await?.to_string();
let meta = interceptor::build_metadata(&token, token.instance_url(), &tenant_id)?;
self.schema_cache
.get_or_fetch(schema_id, &self.channel, meta)
.await?;
publish_unary(
&self.session,
&self.channel,
&self.schema_cache,
schema_id,
topic,
events,
&tenant_id,
)
.await
}
pub async fn subscribe(
&self,
topic: &str,
replay: ReplayPreset,
) -> Result<Pin<Box<dyn Stream<Item = Result<PubSubEvent<Value>>> + Send>>> {
let tenant_id = self.get_tenant_id().await?.to_string();
Ok(subscribe_dynamic(
Arc::clone(&self.session),
self.config.clone(),
self.schema_cache.clone(),
self.channel.clone(),
topic.to_string(),
replay,
tenant_id,
))
}
pub async fn subscribe_typed<T>(
&self,
topic: &str,
replay: ReplayPreset,
) -> Result<Pin<Box<dyn Stream<Item = Result<PubSubEvent<T>>> + Send>>>
where
T: DeserializeOwned + Send + 'static,
{
let tenant_id = self.get_tenant_id().await?.to_string();
Ok(subscribe_typed_dynamic(
Arc::clone(&self.session),
self.config.clone(),
self.schema_cache.clone(),
self.channel.clone(),
topic.to_string(),
replay,
tenant_id,
))
}
pub async fn publish_stream<T: Serialize + Send + 'static>(
&self,
topic: &str,
) -> Result<PublishSink<T>> {
let token = self.session.token_manager().token().await?;
let tenant_id = self.get_tenant_id().await?.to_string();
open_publish_stream(
Arc::clone(&self.session),
self.channel.clone(),
self.schema_cache.clone(),
tenant_id,
topic.to_string(),
&token,
)
.await
}
}