dynamo_subscriber/stream/
dynamodb.rs1use super::{
2 channel::{self, ConsumerChannel, ProducerChannel},
3 types::{GetShardsOutput, Lineages, Shard},
4 DynamodbClient, Error,
5};
6use aws_sdk_dynamodbstreams::types::{Record, ShardIteratorType};
7use std::{
8 cmp,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use tokio::{
14 sync::mpsc,
15 time::{sleep, Duration},
16};
17use tokio_stream::Stream;
18use tracing::error;
19
20#[derive(Debug)]
22pub struct DynamodbStreamProducer<Client>
23where
24 Client: DynamodbClient + 'static,
25{
26 table_name: String,
27 stream_arn: String,
28 shards: Option<Vec<Shard>>,
29 channel: ProducerChannel,
30 client: Client,
31 shard_iterator_type: ShardIteratorType,
32 interval: Option<Duration>,
33 sender: mpsc::Sender<Vec<Record>>,
34}
35
36impl<Client> DynamodbStreamProducer<Client>
37where
38 Client: DynamodbClient + 'static,
39{
40 fn client(&self) -> Arc<Client> {
41 Arc::new(self.client.clone())
42 }
43
44 async fn init(&mut self) -> Result<(), Error> {
46 let stream_arn = self.client.get_stream_arn(&self.table_name).await?;
47 self.stream_arn = stream_arn;
48
49 let shards = self.get_all_shards().await?;
50 let shards = self
51 .get_shard_iterators(shards, self.shard_iterator_type.clone())
52 .await;
53
54 self.shards = Some(shards);
55 self.channel.send_init();
56
57 Ok(())
58 }
59
60 async fn iterate(&mut self) -> Result<Vec<Record>, Error> {
62 let lineages: Lineages = self.shards.take().unwrap_or_default().into();
63 let (mut shards, records) = lineages.get_records(self.client()).await;
64
65 let new_shards = self
66 .get_all_shards()
67 .await?
68 .into_iter()
69 .filter(|shard| !shards.iter().any(|s| s.id() == shard.id()))
70 .collect::<Vec<Shard>>();
71 let mut new_shards = self
72 .get_shard_iterators(new_shards, ShardIteratorType::Latest)
73 .await;
74
75 shards.append(&mut new_shards);
76 self.shards = Some(shards);
77
78 Ok(records)
79 }
80
81 async fn streaming(&mut self) {
83 ok_or_return!(self.init().await, |err| {
84 error!(
85 "Unexpected error during initialization: {err}. Skip polling {} table.",
86 self.table_name,
87 );
88 });
89
90 loop {
91 let records = ok_or_return!(self.iterate().await, |err| {
92 error!(
93 "Unexpected error during iteration: {err}. Stop polling {} table.",
94 self.table_name,
95 );
96 });
97
98 if self.channel.should_close() {
99 return;
100 }
101
102 if !records.is_empty() && self.sender.send(records).await.is_err() {
103 return;
104 }
105
106 if let Some(duration) = self.interval {
107 sleep(duration).await;
108 }
109 }
110 }
111
112 async fn get_all_shards(&self) -> Result<Vec<Shard>, Error> {
114 let GetShardsOutput {
115 mut shards,
116 mut last_shard_id,
117 } = self.client.get_shards(&self.stream_arn, None).await?;
118
119 while last_shard_id.is_some() {
120 let mut output = self
121 .client
122 .get_shards(&self.stream_arn, last_shard_id.take())
123 .await?;
124 shards.append(&mut output.shards);
125 last_shard_id = output.last_shard_id;
126 }
127
128 Ok(shards)
129 }
130
131 async fn get_shard_iterators(
133 &self,
134 shards: Vec<Shard>,
135 shard_iterator_type: ShardIteratorType,
136 ) -> Vec<Shard> {
137 let buf = cmp::max(1, shards.len());
139 let (tx, mut rx) = mpsc::channel::<Shard>(buf);
140 let mut output: Vec<Shard> = vec![];
141 let client = self.client();
142
143 for shard in shards {
144 let tx = tx.clone();
145 let client = Arc::clone(&client);
146 let stream_arn = self.stream_arn.clone();
147 let shard_iterator_type = shard_iterator_type.clone();
148
149 tokio::spawn(async move {
150 let result = client.get_shard_with_iterator(stream_arn, shard, shard_iterator_type);
151 let shard_opt = ok_or_return!(result.await, |err| {
152 error!("Unexpected error during getting shard iterator: {err}");
153 });
154
155 if let Some(shard) = shard_opt {
156 if let Err(err) = tx.send(shard).await {
157 error!("Unexpected error during sending shard: {err}");
158 }
159 }
160 });
161 }
162
163 drop(tx);
164
165 while let Some(shard) = rx.recv().await {
166 output.push(shard);
167 }
168
169 output
170 }
171}
172
173#[derive(Debug)]
177pub struct DynamodbStream {
178 receiver: mpsc::Receiver<Vec<Record>>,
179 channel: Option<ConsumerChannel>,
180}
181
182impl DynamodbStream {
183 pub fn take_channel(&mut self) -> Option<ConsumerChannel> {
207 self.channel.take()
208 }
209}
210
211impl Stream for DynamodbStream {
212 type Item = Vec<Record>;
213
214 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
215 self.receiver.poll_recv(cx)
216 }
217}
218
219impl Drop for DynamodbStream {
220 fn drop(&mut self) {
221 self.receiver.close();
222 if let Some(mut channel) = self.take_channel() {
223 channel.close(|| {});
224 }
225 }
226}
227
228impl AsRef<mpsc::Receiver<Vec<Record>>> for DynamodbStream {
229 fn as_ref(&self) -> &mpsc::Receiver<Vec<Record>> {
230 &self.receiver
231 }
232}
233
234impl AsMut<mpsc::Receiver<Vec<Record>>> for DynamodbStream {
235 fn as_mut(&mut self) -> &mut mpsc::Receiver<Vec<Record>> {
236 &mut self.receiver
237 }
238}
239
240#[derive(Debug)]
242pub struct DynamodbStreamBuilder<Client>
243where
244 Client: DynamodbClient + 'static,
245{
246 table_name: Option<String>,
247 client: Option<Client>,
248 shard_iterator_type: ShardIteratorType,
249 interval: Option<Duration>,
250 buffer: usize,
251}
252
253impl<Client> DynamodbStreamBuilder<Client>
254where
255 Client: DynamodbClient + 'static,
256{
257 pub fn new() -> Self {
259 Self {
260 table_name: None,
261 client: None,
262 shard_iterator_type: ShardIteratorType::Latest,
263 interval: Some(Duration::from_secs(3)),
264 buffer: 100,
265 }
266 }
267
268 pub fn table_name(self, table_name: impl Into<String>) -> Self {
272 Self {
273 table_name: Some(table_name.into()),
274 ..self
275 }
276 }
277
278 pub fn client(self, client: Client) -> Self {
282 Self {
283 client: Some(client),
284 ..self
285 }
286 }
287
288 pub fn shard_iterator_type(self, shard_iterator_type: ShardIteratorType) -> Self {
295 Self {
296 shard_iterator_type,
297 ..self
298 }
299 }
300
301 pub fn interval(self, interval: Option<Duration>) -> Self {
307 Self { interval, ..self }
308 }
309
310 pub fn buffer(self, buffer: usize) -> Self {
321 if buffer == 0 {
322 panic!("buffer must be positive.");
323 }
324
325 Self { buffer, ..self }
326 }
327
328 pub fn build(self) -> DynamodbStream {
332 let (c_half, rx) = self.build_producer();
333
334 DynamodbStream {
335 receiver: rx,
336 channel: Some(c_half),
337 }
338 }
339
340 fn build_producer(self) -> (ConsumerChannel, mpsc::Receiver<Vec<Record>>) {
341 let table_name = self.table_name.expect("`table_name` is required");
342 let client = self.client.expect("`client` is required");
343
344 let (p_half, c_half) = channel::new();
345 let (tx_mpsc, rx_mpsc) = mpsc::channel::<Vec<Record>>(self.buffer);
346
347 let mut producer = DynamodbStreamProducer {
348 table_name,
349 stream_arn: "".to_string(),
350 shards: None,
351 channel: p_half,
352 client,
353 shard_iterator_type: self.shard_iterator_type,
354 interval: self.interval,
355 sender: tx_mpsc,
356 };
357
358 tokio::spawn(async move {
359 producer.streaming().await;
360 });
361
362 (c_half, rx_mpsc)
363 }
364}
365
366impl<Client> Default for DynamodbStreamBuilder<Client>
367where
368 Client: DynamodbClient + 'static,
369{
370 fn default() -> Self {
371 Self::new()
372 }
373}