// force_pubsub/publish_sink.rs

//! Bidirectional streaming publish sink.
//!
//! [`PublishSink`] wraps the `PublishStream` gRPC bidirectional streaming RPC,
//! providing an ergonomic API for sending batches of Avro-encoded events and
//! receiving per-batch publish responses asynchronously.

use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;

use serde::Serialize;
use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
use tonic::transport::Channel;

use force::auth::Authenticator;
use force::session::Session;

use crate::codec::encode_avro;
use crate::error::{PubSubError, Result};
use crate::interceptor;
use crate::proto::eventbus_v1::{ProducerEvent, PublishRequest, pub_sub_client::PubSubClient};
use crate::schema_cache::SchemaCache;
use crate::types::{PublishResponse, PublishResult, ReplayId};

/// Internal trait to abstract token + instance_url retrieval without
/// carrying the full generic `A: Authenticator` parameter on `PublishSink`.
///
/// Object-safe by design: `PublishSink` stores it as
/// `Arc<dyn TokenGetter + Send + Sync>`, erasing the authenticator type.
#[async_trait::async_trait]
trait TokenGetter {
    /// Return an access token suitable for authenticating Pub/Sub RPCs.
    async fn get_token(&self) -> Result<force::auth::AccessToken>;
}

/// `TokenGetter` backed by a shared [`Session`]; delegates token retrieval
/// to the session's token manager.
struct SessionTokenGetter<A: Authenticator> {
    session: Arc<Session<A>>,
}

37#[async_trait::async_trait]
38impl<A: Authenticator + Send + Sync + 'static> TokenGetter for SessionTokenGetter<A> {
39    async fn get_token(&self) -> Result<force::auth::AccessToken> {
40        self.session
41            .token_manager()
42            .token()
43            .await
44            .map_err(PubSubError::Auth)
45    }
46}
47
/// A bidirectional streaming publish sink for the Salesforce Pub/Sub API.
///
/// Created by [`crate::handler::PubSubHandler::publish_stream`]. Holds an open
/// gRPC `PublishStream` channel and allows callers to send multiple batches of
/// events, streaming publish acknowledgements back.
///
/// # Type parameter
///
/// `T` is the event payload type. It must implement [`serde::Serialize`] so that
/// payloads can be Avro-encoded before transmission.
///
/// # Example
///
/// ```ignore
/// let mut sink = handler.publish_stream::<MyEvent>("/event/MyEvent__e").await?;
///
/// sink.send("schema-id", vec![MyEvent { id: "e1".into() }]).await?;
/// sink.send("schema-id", vec![MyEvent { id: "e2".into() }]).await?;
///
/// // Drain acknowledgement responses
/// let mut acks = sink.responses();
/// while let Some(resp) = acks.next().await {
///     let r = resp?;
///     println!("acked {} event(s) on {}", r.results.len(), r.topic_name);
/// }
/// sink.close().await?;
/// ```
pub struct PublishSink<T> {
    /// Sender half of the mpsc channel feeding the gRPC request stream.
    /// Dropping it signals end-of-input to the server side.
    sender: mpsc::Sender<PublishRequest>,
    /// Boxed response stream, mapping proto messages to domain types.
    /// Items may be transport errors surfaced from the gRPC stream.
    resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>>,
    /// Schema cache shared with the handler.
    schema_cache: SchemaCache,
    /// gRPC channel for fetching schemas on demand.
    channel: Channel,
    /// Type-erased token source (see `TokenGetter`) for fresh auth tokens.
    session_token_getter: Arc<dyn TokenGetter + Send + Sync>,
    /// Pre-fetched 18-char org ID.
    tenant_id: String,
    /// The topic name this sink is publishing to.
    topic: String,
    // Marks the sink as logically parameterized by the payload type `T`
    // even though no `T` value is stored.
    _phantom: PhantomData<T>,
}

93impl<T: Serialize + Send + 'static> PublishSink<T> {
94    /// Build a `PublishSink` from its constituent parts.
95    ///
96    /// Called exclusively by [`crate::handler::PubSubHandler::publish_stream`].
97    #[allow(clippy::too_many_arguments)]
98    pub(crate) fn new<A: Authenticator + Send + Sync + 'static>(
99        sender: mpsc::Sender<PublishRequest>,
100        resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>>,
101        schema_cache: SchemaCache,
102        channel: Channel,
103        session: Arc<Session<A>>,
104        tenant_id: String,
105        topic: String,
106    ) -> Self {
107        Self {
108            sender,
109            resp_stream,
110            schema_cache,
111            channel,
112            session_token_getter: Arc::new(SessionTokenGetter { session }),
113            tenant_id,
114            topic,
115            _phantom: PhantomData,
116        }
117    }
118
119    /// Encode a batch of events and send them to the open `PublishStream`.
120    ///
121    /// The Avro schema is resolved via the schema cache (or fetched from the
122    /// `GetSchema` RPC on a miss). Subsequent calls reuse the cached schema.
123    ///
124    /// # Errors
125    ///
126    /// - [`PubSubError::Avro`] if an event cannot be Avro-encoded.
127    /// - [`PubSubError::Transport`] if the `GetSchema` RPC fails on a cache miss.
128    /// - [`PubSubError::Config`] if the channel to the gRPC stream is closed.
129    pub async fn send(&mut self, schema_id: &str, events: Vec<T>) -> Result<()> {
130        // Fetch schema (cache hit is O(1)).
131        let token = self.session_token_getter.get_token().await?;
132        let meta = interceptor::build_metadata(&token, token.instance_url(), &self.tenant_id)?;
133        let schema = self
134            .schema_cache
135            .get_or_fetch(schema_id, &self.channel, meta)
136            .await?;
137
138        // Encode each event to Avro bytes.
139        let mut producer_events = Vec::with_capacity(events.len());
140        for event in &events {
141            let payload = encode_avro(&schema, event)?;
142            producer_events.push(ProducerEvent {
143                schema_id: schema_id.to_string(),
144                payload,
145            });
146        }
147
148        let request = PublishRequest {
149            topic_name: self.topic.clone(),
150            events: producer_events,
151        };
152
153        self.sender.send(request).await.map_err(|_| {
154            PubSubError::Config(
155                "PublishStream channel closed — server may have terminated the stream".to_string(),
156            )
157        })
158    }
159
160    /// Return a reference to the server acknowledgement response stream.
161    ///
162    /// Each item is a [`PublishResponse`] containing per-event results for the
163    /// most recently acknowledged batch.
164    ///
165    /// # Errors
166    ///
167    /// Items may be `Err(PubSubError::Transport)` if the gRPC stream reports
168    /// an error.
169    pub fn responses(&mut self) -> &mut (impl Stream<Item = Result<PublishResponse>> + '_) {
170        &mut self.resp_stream
171    }
172
173    /// Close the sink.
174    ///
175    /// Drops the sender side of the mpsc channel, which signals to tonic that
176    /// the client input stream is complete. Then drains any remaining server
177    /// acknowledgement responses so the gRPC stream shuts down cleanly.
178    ///
179    /// # Errors
180    ///
181    /// Returns the first transport error encountered while draining responses,
182    /// if any.
183    pub async fn close(mut self) -> Result<()> {
184        // Drop sender → closes mpsc → tonic stream sees EOF.
185        drop(self.sender);
186
187        // Drain remaining responses.
188        while let Some(item) = self.resp_stream.next().await {
189            item?;
190        }
191
192        Ok(())
193    }
194}
195
196/// Map a single proto `PublishResponse` to the domain `PublishResponse`.
197fn map_proto_response(proto_resp: crate::proto::eventbus_v1::PublishResponse) -> PublishResponse {
198    let results = proto_resp
199        .results
200        .into_iter()
201        .map(|r| PublishResult {
202            replay_id: if r.replay_id.is_empty() {
203                None
204            } else {
205                Some(ReplayId::from_bytes(r.replay_id))
206            },
207            error: r.error.and_then(|e| {
208                if e.code == 0 && e.msg.is_empty() {
209                    None
210                } else {
211                    Some(e.msg)
212                }
213            }),
214        })
215        .collect();
216    PublishResponse {
217        topic_name: proto_resp.topic_name,
218        results,
219    }
220}
221
222/// Open a `PublishStream` RPC and return a [`PublishSink`].
223///
224/// Called by [`crate::handler::PubSubHandler::publish_stream`] after
225/// auth setup.
226pub async fn open_publish_stream<A, T>(
227    session: Arc<Session<A>>,
228    channel: Channel,
229    schema_cache: SchemaCache,
230    tenant_id: String,
231    topic: String,
232    token: &force::auth::AccessToken,
233) -> Result<PublishSink<T>>
234where
235    A: Authenticator + Send + Sync + 'static,
236    T: Serialize + Send + 'static,
237{
238    let (tx, rx) = mpsc::channel::<PublishRequest>(32);
239    let meta = interceptor::build_metadata(token, token.instance_url(), &tenant_id)?;
240
241    let mut req = tonic::Request::new(ReceiverStream::new(rx));
242    *req.metadata_mut() = meta;
243
244    let streaming = PubSubClient::new(channel.clone())
245        .publish_stream(req)
246        .await?
247        .into_inner();
248
249    // Convert the tonic Streaming<ProtoPublishResponse> into a pinned boxed
250    // domain stream so PublishSink doesn't need to be generic over Streaming.
251    let resp_stream: Pin<Box<dyn Stream<Item = Result<PublishResponse>> + Send>> =
252        Box::pin(streaming.map(|item| match item {
253            Ok(proto_resp) => Ok(map_proto_response(proto_resp)),
254            Err(status) => Err(PubSubError::Transport(status)),
255        }));
256
257    Ok(PublishSink::new(
258        tx,
259        resp_stream,
260        schema_cache,
261        channel,
262        session,
263        tenant_id,
264        topic,
265    ))
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-empty replay id and absent error map to a successful result.
    #[test]
    fn test_map_proto_response_success() {
        use crate::proto::eventbus_v1::{
            PublishResponse as ProtoResp, PublishResult as ProtoResult,
        };
        let input = ProtoResp {
            topic_name: "/event/Test__e".to_string(),
            results: vec![ProtoResult {
                replay_id: vec![1, 2, 3],
                error: None,
            }],
            rpc_id: None,
        };

        let mapped = map_proto_response(input);

        assert_eq!(mapped.topic_name, "/event/Test__e");
        assert_eq!(mapped.results.len(), 1);
        let result = &mapped.results[0];
        assert!(result.is_success());
        let replay_id = result.replay_id.as_ref().expect("expected replay_id");
        assert_eq!(replay_id.as_bytes(), &[1, 2, 3]);
    }

    /// A populated proto error maps to `Some(msg)` and a failed result.
    #[test]
    fn test_map_proto_response_error_result() {
        use crate::proto::eventbus_v1::{
            PubSubError as ProtoErr, PublishResponse as ProtoResp, PublishResult as ProtoResult,
        };
        let input = ProtoResp {
            topic_name: "/event/Test__e".to_string(),
            results: vec![ProtoResult {
                replay_id: Vec::new(),
                error: Some(ProtoErr {
                    code: 1,
                    msg: "INVALID_PAYLOAD".to_string(),
                    key: None,
                }),
            }],
            rpc_id: None,
        };

        let mapped = map_proto_response(input);

        let result = &mapped.results[0];
        assert!(!result.is_success());
        assert_eq!(result.error.as_deref(), Some("INVALID_PAYLOAD"));
    }

    /// A result carrying a replay id and no error is a success.
    #[test]
    fn test_publish_result_success_is_success() {
        let ok = PublishResult {
            replay_id: Some(ReplayId::from_bytes(vec![1, 2, 3])),
            error: None,
        };
        assert!(ok.is_success());
    }

    /// A result carrying an error string is not a success.
    #[test]
    fn test_publish_result_error_is_not_success() {
        let failed = PublishResult {
            replay_id: None,
            error: Some("PUBLISH_ERROR".to_string()),
        };
        assert!(!failed.is_success());
    }
}