1use std::marker::PhantomData;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use serde::Serialize;
12use tokio::sync::mpsc;
13use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
14use tonic::transport::Channel;
15
16use force::auth::Authenticator;
17use force::session::Session;
18
19use crate::codec::encode_avro;
20use crate::error::{PubSubError, Result};
21use crate::interceptor;
22use crate::proto::eventbus_v1::{ProducerEvent, PublishRequest, pub_sub_client::PubSubClient};
23use crate::schema_cache::SchemaCache;
24use crate::types::{PublishResponse, PublishResult, ReplayId};
25
26#[async_trait::async_trait]
29trait TokenGetter {
30 async fn get_token(&self) -> Result<force::auth::AccessToken>;
31}
32
33struct SessionTokenGetter<A: Authenticator> {
34 session: Arc<Session<A>>,
35}
36
37#[async_trait::async_trait]
38impl<A: Authenticator + Send + Sync + 'static> TokenGetter for SessionTokenGetter<A> {
39 async fn get_token(&self) -> Result<force::auth::AccessToken> {
40 self.session
41 .token_manager()
42 .token()
43 .await
44 .map_err(PubSubError::Auth)
45 }
46}
47
48pub struct PublishSink<T> {
76 sender: mpsc::Sender<PublishRequest>,
78 resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>>,
80 schema_cache: SchemaCache,
82 channel: Channel,
84 session_token_getter: Arc<dyn TokenGetter + Send + Sync>,
86 tenant_id: String,
88 topic: String,
90 _phantom: PhantomData<T>,
91}
92
93impl<T: Serialize + Send + 'static> PublishSink<T> {
94 #[allow(clippy::too_many_arguments)]
98 pub(crate) fn new<A: Authenticator + Send + Sync + 'static>(
99 sender: mpsc::Sender<PublishRequest>,
100 resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>>,
101 schema_cache: SchemaCache,
102 channel: Channel,
103 session: Arc<Session<A>>,
104 tenant_id: String,
105 topic: String,
106 ) -> Self {
107 Self {
108 sender,
109 resp_stream,
110 schema_cache,
111 channel,
112 session_token_getter: Arc::new(SessionTokenGetter { session }),
113 tenant_id,
114 topic,
115 _phantom: PhantomData,
116 }
117 }
118
119 pub async fn send(&mut self, schema_id: &str, events: Vec<T>) -> Result<()> {
130 let token = self.session_token_getter.get_token().await?;
132 let meta = interceptor::build_metadata(&token, token.instance_url(), &self.tenant_id)?;
133 let schema = self
134 .schema_cache
135 .get_or_fetch(schema_id, &self.channel, meta)
136 .await?;
137
138 let mut producer_events = Vec::with_capacity(events.len());
140 for event in &events {
141 let payload = encode_avro(&schema, event)?;
142 producer_events.push(ProducerEvent {
143 schema_id: schema_id.to_string(),
144 payload,
145 });
146 }
147
148 let request = PublishRequest {
149 topic_name: self.topic.clone(),
150 events: producer_events,
151 };
152
153 self.sender.send(request).await.map_err(|_| {
154 PubSubError::Config(
155 "PublishStream channel closed — server may have terminated the stream".to_string(),
156 )
157 })
158 }
159
160 pub fn responses(&mut self) -> &mut (impl Stream<Item = Result<PublishResponse>> + '_) {
170 &mut self.resp_stream
171 }
172
173 pub async fn close(mut self) -> Result<()> {
184 drop(self.sender);
186
187 while let Some(item) = self.resp_stream.next().await {
189 item?;
190 }
191
192 Ok(())
193 }
194}
195
196fn map_proto_response(proto_resp: crate::proto::eventbus_v1::PublishResponse) -> PublishResponse {
198 let results = proto_resp
199 .results
200 .into_iter()
201 .map(|r| PublishResult {
202 replay_id: if r.replay_id.is_empty() {
203 None
204 } else {
205 Some(ReplayId::from_bytes(r.replay_id))
206 },
207 error: r.error.and_then(|e| {
208 if e.code == 0 && e.msg.is_empty() {
209 None
210 } else {
211 Some(e.msg)
212 }
213 }),
214 })
215 .collect();
216 PublishResponse {
217 topic_name: proto_resp.topic_name,
218 results,
219 }
220}
221
222pub async fn open_publish_stream<A, T>(
227 session: Arc<Session<A>>,
228 channel: Channel,
229 schema_cache: SchemaCache,
230 tenant_id: String,
231 topic: String,
232 token: &force::auth::AccessToken,
233) -> Result<PublishSink<T>>
234where
235 A: Authenticator + Send + Sync + 'static,
236 T: Serialize + Send + 'static,
237{
238 let (tx, rx) = mpsc::channel::<PublishRequest>(32);
239 let meta = interceptor::build_metadata(token, token.instance_url(), &tenant_id)?;
240
241 let mut req = tonic::Request::new(ReceiverStream::new(rx));
242 *req.metadata_mut() = meta;
243
244 let streaming = PubSubClient::new(channel.clone())
245 .publish_stream(req)
246 .await?
247 .into_inner();
248
249 let resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>> =
252 Box::pin(streaming.map(|item| match item {
253 Ok(proto_resp) => Ok(map_proto_response(proto_resp)),
254 Err(status) => Err(PubSubError::Transport(status)),
255 }));
256
257 Ok(PublishSink::new(
258 tx,
259 resp_stream,
260 schema_cache,
261 channel,
262 session,
263 tenant_id,
264 topic,
265 ))
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::eventbus_v1::{
        PubSubError as ProtoErr, PublishResponse as ProtoResp, PublishResult as ProtoResult,
    };

    /// Builds a proto response for the test topic containing a single result.
    fn response_with(result: ProtoResult) -> ProtoResp {
        ProtoResp {
            topic_name: "/event/Test__e".to_string(),
            results: vec![result],
            rpc_id: None,
        }
    }

    #[test]
    fn test_map_proto_response_success() {
        let mapped = map_proto_response(response_with(ProtoResult {
            replay_id: vec![1, 2, 3],
            error: None,
        }));

        assert_eq!(mapped.topic_name, "/event/Test__e");
        assert_eq!(mapped.results.len(), 1);

        let first = &mapped.results[0];
        assert!(first.is_success());
        let replay_id = first.replay_id.as_ref().expect("expected replay_id");
        assert_eq!(replay_id.as_bytes(), &[1, 2, 3]);
    }

    #[test]
    fn test_map_proto_response_error_result() {
        let mapped = map_proto_response(response_with(ProtoResult {
            replay_id: vec![],
            error: Some(ProtoErr {
                code: 1,
                msg: "INVALID_PAYLOAD".to_string(),
                key: None,
            }),
        }));

        let first = &mapped.results[0];
        assert!(!first.is_success());
        assert_eq!(first.error.as_deref(), Some("INVALID_PAYLOAD"));
    }

    #[test]
    fn test_publish_result_success_is_success() {
        let succeeded = PublishResult {
            replay_id: Some(ReplayId::from_bytes(vec![1, 2, 3])),
            error: None,
        };
        assert!(succeeded.is_success());
    }

    #[test]
    fn test_publish_result_error_is_not_success() {
        let failed = PublishResult {
            replay_id: None,
            error: Some("PUBLISH_ERROR".to_string()),
        };
        assert!(!failed.is_success());
    }
}