1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_channel::Receiver;
8use tokio::sync::oneshot;
9use tokio::sync::Mutex;
10use tokio::task::JoinHandle;
11use tokio::time::timeout;
12
13use google_cloud_gax::grpc::Status;
14use google_cloud_gax::retry::RetrySetting;
15use google_cloud_googleapis::pubsub::v1::{PublishRequest, PubsubMessage};
16
17use crate::apiv1::publisher_client::PublisherClient;
18use crate::util::ToUsize;
19
20pub(crate) struct ReservedMessage {
21 pub producer: oneshot::Sender<Result<String, Status>>,
22 pub message: PubsubMessage,
23}
24
25pub(crate) enum Reserved {
26 Single(ReservedMessage),
27 Multi(Vec<ReservedMessage>),
28}
29
30#[derive(Debug, Clone)]
31pub struct PublisherConfig {
32 pub workers: usize,
34 pub flush_interval: Duration,
36 pub bundle_size: usize,
38 pub retry_setting: Option<RetrySetting>,
39}
40
41impl Default for PublisherConfig {
42 fn default() -> Self {
43 Self {
44 workers: 3,
45 flush_interval: Duration::from_millis(100),
46 bundle_size: 3,
47 retry_setting: None,
48 }
49 }
50}
51
52pub struct Awaiter {
53 consumer: oneshot::Receiver<Result<String, Status>>,
54}
55
56impl Awaiter {
57 pub(crate) fn new(consumer: oneshot::Receiver<Result<String, Status>>) -> Self {
58 Self { consumer }
59 }
60 pub async fn get(self) -> Result<String, Status> {
61 match self.consumer.await {
62 Ok(v) => v,
63 Err(_e) => Err(Status::cancelled("closed")),
64 }
65 }
66}
67
68#[derive(Clone, Debug)]
73pub struct Publisher {
74 ordering_senders: Arc<Vec<async_channel::Sender<Reserved>>>,
75 sender: async_channel::Sender<Reserved>,
76 tasks: Arc<Mutex<Tasks>>,
77 fqtn: String,
78 pubc: PublisherClient,
79}
80
81impl Publisher {
82 pub(crate) fn new(fqtn: String, pubc: PublisherClient, config: Option<PublisherConfig>) -> Self {
83 let config = config.unwrap_or_default();
84 let (sender, receiver) = async_channel::unbounded::<Reserved>();
85 let mut receivers = Vec::with_capacity(config.workers * 2);
86 let mut ordering_senders = Vec::with_capacity(config.workers);
87
88 for _ in 0..config.workers {
90 tracing::trace!("start non-ordering publisher : {}", fqtn.clone());
91 receivers.push(receiver.clone());
92 }
93
94 for _ in 0..config.workers {
96 tracing::trace!("start ordering publisher : {}", fqtn.clone());
97 let (sender, receiver) = async_channel::unbounded::<Reserved>();
98 receivers.push(receiver);
99 ordering_senders.push(sender);
100 }
101
102 Self {
103 sender,
104 ordering_senders: Arc::new(ordering_senders),
105 tasks: Arc::new(Mutex::new(Tasks::new(fqtn.clone(), pubc.clone(), receivers, config))),
106 fqtn,
107 pubc,
108 }
109 }
110
111 pub async fn publish_immediately(
113 &self,
114 messages: Vec<PubsubMessage>,
115 retry: Option<RetrySetting>,
116 ) -> Result<Vec<String>, Status> {
117 self.pubc
118 .publish(
119 PublishRequest {
120 topic: self.fqtn.clone(),
121 messages,
122 },
123 retry,
124 )
125 .await
126 .map(|v| v.into_inner().message_ids)
127 }
128
129 pub async fn publish(&self, message: PubsubMessage) -> Awaiter {
135 let (producer, consumer) = oneshot::channel();
136 if message.ordering_key.is_empty() {
137 let _ = self
138 .sender
139 .send(Reserved::Single(ReservedMessage { producer, message }))
140 .await;
141 } else {
142 let key = message.ordering_key.as_str().to_usize();
143 let index = key % self.ordering_senders.len();
144 let _ = self.ordering_senders[index]
145 .send(Reserved::Single(ReservedMessage { producer, message }))
146 .await;
147 }
148 Awaiter::new(consumer)
149 }
150
151 pub fn publish_blocking(&self, message: PubsubMessage) -> Awaiter {
155 let (producer, consumer) = oneshot::channel();
156 if message.ordering_key.is_empty() {
157 let _ = self
158 .sender
159 .send_blocking(Reserved::Single(ReservedMessage { producer, message }));
160 } else {
161 let key = message.ordering_key.as_str().to_usize();
162 let index = key % self.ordering_senders.len();
163 let _ = self.ordering_senders[index].send_blocking(Reserved::Single(ReservedMessage { producer, message }));
164 }
165 Awaiter::new(consumer)
166 }
167
168 pub async fn publish_bulk(&self, messages: Vec<PubsubMessage>) -> Vec<Awaiter> {
174 let mut awaiters = Vec::with_capacity(messages.len());
175 let mut split_by_key = HashMap::<String, Vec<ReservedMessage>>::with_capacity(messages.len());
176 for message in messages {
177 let (producer, consumer) = oneshot::channel();
178 awaiters.push(Awaiter::new(consumer));
179 split_by_key
180 .entry(message.ordering_key.clone())
181 .or_default()
182 .push(ReservedMessage { producer, message });
183 }
184
185 for e in split_by_key {
186 if e.0.is_empty() {
187 let _ = self.sender.send(Reserved::Multi(e.1)).await;
188 } else {
189 let key = e.0.as_str().to_usize();
190 let index = key % self.ordering_senders.len();
191 let _ = self.ordering_senders[index].send(Reserved::Multi(e.1)).await;
192 }
193 }
194 awaiters
195 }
196
197 pub async fn shutdown(&mut self) {
198 self.sender.close();
199 for s in self.ordering_senders.iter() {
200 s.close();
201 }
202 self.tasks.lock().await.done().await;
203 }
204}
205
206#[derive(Debug)]
207struct Tasks {
208 inner: Option<Vec<JoinHandle<()>>>,
209}
210
211impl Tasks {
212 pub fn new(
213 topic: String,
214 pubc: PublisherClient,
215 receivers: Vec<async_channel::Receiver<Reserved>>,
216 config: PublisherConfig,
217 ) -> Self {
218 let tasks = receivers
219 .into_iter()
220 .map(|receiver| {
221 Self::run_task(
222 receiver,
223 pubc.clone(),
224 topic.clone(),
225 config.retry_setting.clone(),
226 config.flush_interval,
227 config.bundle_size,
228 )
229 })
230 .collect();
231
232 Self { inner: Some(tasks) }
233 }
234
235 fn run_task(
236 receiver: Receiver<Reserved>,
237 mut client: PublisherClient,
238 topic: String,
239 retry: Option<RetrySetting>,
240 flush_interval: Duration,
241 bundle_size: usize,
242 ) -> JoinHandle<()> {
243 tokio::spawn(async move {
244 let mut bundle = MessageBundle::new();
246 while !receiver.is_closed() {
247 let result = match timeout(flush_interval, &mut receiver.recv()).await {
248 Ok(result) => result,
249 Err(_e) => {
251 if !bundle.is_empty() {
252 tracing::trace!("elapsed: flush buffer : {}", topic);
253 for value in bundle.key_by() {
254 Self::flush(&mut client, topic.as_str(), value, retry.clone()).await;
255 }
256 bundle = MessageBundle::new();
257 }
258 continue;
259 }
260 };
261 match result {
262 Ok(reserved) => {
263 match reserved {
264 Reserved::Single(message) => bundle.push(message),
265 Reserved::Multi(messages) => bundle.extend(messages),
266 }
267 if bundle.len() >= bundle_size {
268 tracing::trace!("bundle size max: {}", topic);
269 for value in bundle.key_by() {
270 Self::flush(&mut client, topic.as_str(), value, retry.clone()).await;
271 }
272 bundle = MessageBundle::new();
273 }
274 }
275 Err(_e) => break,
277 };
278 }
279
280 tracing::trace!("stop publisher : {}", topic);
281 if !bundle.is_empty() {
282 tracing::trace!("flush rest buffer : {}", topic);
283 for value in bundle.key_by() {
284 Self::flush(&mut client, topic.as_str(), value, retry.clone()).await;
285 }
286 }
287 })
288 }
289
290 async fn flush(
292 client: &mut PublisherClient,
293 topic: &str,
294 bundle: Vec<ReservedMessage>,
295 retry_setting: Option<RetrySetting>,
296 ) {
297 let mut data = Vec::<PubsubMessage>::with_capacity(bundle.len());
298 let mut callback = Vec::<oneshot::Sender<Result<String, Status>>>::with_capacity(bundle.len());
299 bundle.into_iter().for_each(|r| {
300 data.push(r.message);
301 callback.push(r.producer);
302 });
303 let req = PublishRequest {
304 topic: topic.to_string(),
305 messages: data,
306 };
307 let result = client
308 .publish(req, retry_setting)
309 .await
310 .map(|v| v.into_inner().message_ids);
311
312 match result {
314 Ok(message_ids) => {
315 for (i, p) in callback.into_iter().enumerate() {
316 let message_id = &message_ids[i];
317 if p.send(Ok(message_id.to_string())).is_err() {
318 tracing::error!("failed to notify : id={message_id}");
319 }
320 }
321 }
322 Err(status) => {
323 for p in callback.into_iter() {
324 let code = status.code();
325 let status = Status::new(code, (*status.message()).to_string());
326 if p.send(Err(status)).is_err() {
327 tracing::error!("failed to notify : status={}", code);
328 }
329 }
330 }
331 };
332 }
333
334 pub async fn done(&mut self) {
336 if let Some(tasks) = self.inner.take() {
337 for task in tasks {
338 let _ = task.await;
339 }
340 }
341 }
342}
343
344struct MessageBundle {
345 inner: Vec<ReservedMessage>,
346}
347
348impl MessageBundle {
349 fn new() -> Self {
350 Self { inner: vec![] }
351 }
352
353 fn key_by(self) -> Vec<Vec<ReservedMessage>> {
354 let mut values = HashMap::<String, Vec<ReservedMessage>>::new();
355 for v in self.inner {
356 let key = v.message.ordering_key.to_string();
357 match values.get_mut(&key) {
358 Some(e) => {
359 e.push(v);
360 }
361 None => {
362 values.insert(key, vec![v]);
363 }
364 }
365 }
366 let mut result = Vec::with_capacity(values.len());
367 for (_, v) in values.into_iter() {
368 result.push(v);
369 }
370 result
371 }
372}
373
374impl Deref for MessageBundle {
375 type Target = Vec<ReservedMessage>;
376
377 fn deref(&self) -> &Self::Target {
378 &self.inner
379 }
380}
381
382impl DerefMut for MessageBundle {
383 fn deref_mut(&mut self) -> &mut Self::Target {
384 &mut self.inner
385 }
386}
387
#[cfg(test)]
mod tests {
    use crate::publisher::{MessageBundle, ReservedMessage};
    use google_cloud_googleapis::pubsub::v1::PubsubMessage;
    use tokio::sync::oneshot;

    /// Builds a reserved message with only `ordering_key` set. The outcome receiver is
    /// discarded because these tests never wait on it.
    fn msg(key: &str) -> ReservedMessage {
        let (producer, _) = oneshot::channel();
        let message = PubsubMessage {
            ordering_key: key.to_string(),
            ..Default::default()
        };
        ReservedMessage { producer, message }
    }

    #[test]
    fn test_message_bundle_key_by() {
        let keys = ["", "a", "b", "c", "A", "", "D", "a"];
        let mut bundle = MessageBundle::new();
        keys.iter().for_each(|key| bundle.push(msg(key)));

        let groups = bundle.key_by();
        // Distinct keys: "", "a", "b", "c", "A", "D"; keys are case-sensitive.
        assert_eq!(6, groups.len());
        for group in groups {
            let expected = match group.first().unwrap().message.ordering_key.as_str() {
                "" | "a" => 2,
                _ => 1,
            };
            assert_eq!(expected, group.len());
        }
    }
}