use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use serde::Serialize;
use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
use tonic::transport::Channel;
use force::auth::Authenticator;
use force::session::Session;
use crate::codec::encode_avro;
use crate::error::{PubSubError, Result};
use crate::interceptor;
use crate::proto::eventbus_v1::{ProducerEvent, PublishRequest, pub_sub_client::PubSubClient};
use crate::schema_cache::SchemaCache;
use crate::types::{PublishResponse, PublishResult, ReplayId};
/// Abstraction over access-token retrieval so `PublishSink` can hold a
/// type-erased token source (`Arc<dyn TokenGetter>`) instead of being
/// generic over the concrete `Authenticator` type.
#[async_trait::async_trait]
trait TokenGetter {
    /// Returns an access token for authenticating outbound requests.
    async fn get_token(&self) -> Result<force::auth::AccessToken>;
}
/// `TokenGetter` implementation backed by an authenticated `force` session.
struct SessionTokenGetter<A: Authenticator> {
    /// Shared session whose token manager supplies access tokens.
    session: Arc<Session<A>>,
}
#[async_trait::async_trait]
impl<A: Authenticator + Send + Sync + 'static> TokenGetter for SessionTokenGetter<A> {
    /// Delegates to the session's token manager, translating any auth
    /// failure into `PubSubError::Auth`.
    async fn get_token(&self) -> Result<force::auth::AccessToken> {
        let manager = self.session.token_manager();
        match manager.token().await {
            Ok(token) => Ok(token),
            Err(err) => Err(PubSubError::Auth(err)),
        }
    }
}
/// A typed sink for publishing events of type `T` to one topic over a
/// long-lived bidirectional gRPC publish stream.
pub struct PublishSink<T> {
    /// Outgoing half of the stream; requests pushed here flow to the server.
    sender: mpsc::Sender<PublishRequest>,
    /// Incoming half: server responses already mapped into crate types.
    resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>>,
    /// Cache of Avro schemas, fetched lazily by schema id in `send`.
    schema_cache: SchemaCache,
    /// Transport channel reused for out-of-band schema fetches.
    channel: Channel,
    /// Type-erased token source (wraps the authenticated session).
    session_token_getter: Arc<dyn TokenGetter + Send + Sync>,
    /// Tenant id propagated in request metadata.
    tenant_id: String,
    /// Topic name every published request targets.
    topic: String,
    /// Ties the sink to the event payload type `T` without storing one.
    _phantom: PhantomData<T>,
}
impl<T: Serialize + Send + 'static> PublishSink<T> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new<A: Authenticator + Send + Sync + 'static>(
sender: mpsc::Sender<PublishRequest>,
resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>>,
schema_cache: SchemaCache,
channel: Channel,
session: Arc<Session<A>>,
tenant_id: String,
topic: String,
) -> Self {
Self {
sender,
resp_stream,
schema_cache,
channel,
session_token_getter: Arc::new(SessionTokenGetter { session }),
tenant_id,
topic,
_phantom: PhantomData,
}
}
pub async fn send(&mut self, schema_id: &str, events: Vec<T>) -> Result<()> {
let token = self.session_token_getter.get_token().await?;
let meta = interceptor::build_metadata(&token, token.instance_url(), &self.tenant_id)?;
let schema = self
.schema_cache
.get_or_fetch(schema_id, &self.channel, meta)
.await?;
let mut producer_events = Vec::with_capacity(events.len());
for event in &events {
let payload = encode_avro(&schema, event)?;
producer_events.push(ProducerEvent {
schema_id: schema_id.to_string(),
payload,
});
}
let request = PublishRequest {
topic_name: self.topic.clone(),
events: producer_events,
};
self.sender.send(request).await.map_err(|_| {
PubSubError::Config(
"PublishStream channel closed — server may have terminated the stream".to_string(),
)
})
}
pub fn responses(&mut self) -> &mut (impl Stream<Item = Result<PublishResponse>> + '_) {
&mut self.resp_stream
}
pub async fn close(mut self) -> Result<()> {
drop(self.sender);
while let Some(item) = self.resp_stream.next().await {
item?;
}
Ok(())
}
}
fn map_proto_response(proto_resp: crate::proto::eventbus_v1::PublishResponse) -> PublishResponse {
let results = proto_resp
.results
.into_iter()
.map(|r| PublishResult {
replay_id: if r.replay_id.is_empty() {
None
} else {
Some(ReplayId::from_bytes(r.replay_id))
},
error: r.error.and_then(|e| {
if e.code == 0 && e.msg.is_empty() {
None
} else {
Some(e.msg)
}
}),
})
.collect();
PublishResponse {
topic_name: proto_resp.topic_name,
results,
}
}
pub async fn open_publish_stream<A, T>(
session: Arc<Session<A>>,
channel: Channel,
schema_cache: SchemaCache,
tenant_id: String,
topic: String,
token: &force::auth::AccessToken,
) -> Result<PublishSink<T>>
where
A: Authenticator + Send + Sync + 'static,
T: Serialize + Send + 'static,
{
let (tx, rx) = mpsc::channel::<PublishRequest>(32);
let meta = interceptor::build_metadata(token, token.instance_url(), &tenant_id)?;
let mut req = tonic::Request::new(ReceiverStream::new(rx));
*req.metadata_mut() = meta;
let streaming = PubSubClient::new(channel.clone())
.publish_stream(req)
.await?
.into_inner();
let resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>> =
Box::pin(streaming.map(|item| match item {
Ok(proto_resp) => Ok(map_proto_response(proto_resp)),
Err(status) => Err(PubSubError::Transport(status)),
}));
Ok(PublishSink::new(
tx,
resp_stream,
schema_cache,
channel,
session,
tenant_id,
topic,
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A result with a replay id and no error maps to a successful
    /// `PublishResult` that preserves the replay-id bytes.
    #[test]
    fn test_map_proto_response_success() {
        use crate::proto::eventbus_v1::{
            PublishResponse as ProtoResp, PublishResult as ProtoResult,
        };
        let mapped = map_proto_response(ProtoResp {
            topic_name: "/event/Test__e".to_string(),
            results: vec![ProtoResult {
                replay_id: vec![1, 2, 3],
                error: None,
            }],
            rpc_id: None,
        });
        assert_eq!(mapped.topic_name, "/event/Test__e");
        assert_eq!(mapped.results.len(), 1);
        let result = &mapped.results[0];
        assert!(result.is_success());
        match result.replay_id.as_ref() {
            Some(replay_id) => assert_eq!(replay_id.as_bytes(), &[1, 2, 3]),
            None => panic!("expected replay_id"),
        }
    }

    /// A proto error with a non-zero code surfaces as a failed result
    /// carrying the error message.
    #[test]
    fn test_map_proto_response_error_result() {
        use crate::proto::eventbus_v1::{
            PubSubError as ProtoErr, PublishResponse as ProtoResp, PublishResult as ProtoResult,
        };
        let mapped = map_proto_response(ProtoResp {
            topic_name: "/event/Test__e".to_string(),
            results: vec![ProtoResult {
                replay_id: Vec::new(),
                error: Some(ProtoErr {
                    code: 1,
                    msg: "INVALID_PAYLOAD".to_string(),
                    key: None,
                }),
            }],
            rpc_id: None,
        });
        let result = &mapped.results[0];
        assert!(!result.is_success());
        assert_eq!(result.error.as_deref(), Some("INVALID_PAYLOAD"));
    }

    /// `is_success` is true when there is a replay id and no error.
    #[test]
    fn test_publish_result_success_is_success() {
        let ok = PublishResult {
            replay_id: Some(ReplayId::from_bytes(vec![1, 2, 3])),
            error: None,
        };
        assert!(ok.is_success());
    }

    /// `is_success` is false when an error message is present.
    #[test]
    fn test_publish_result_error_is_not_success() {
        let failed = PublishResult {
            replay_id: None,
            error: Some("PUBLISH_ERROR".to_string()),
        };
        assert!(!failed.is_success());
    }
}