dynamo_subscriber/stream/
dynamodb.rs

1use super::{
2    channel::{self, ConsumerChannel, ProducerChannel},
3    types::{GetShardsOutput, Lineages, Shard},
4    DynamodbClient, Error,
5};
6use aws_sdk_dynamodbstreams::types::{Record, ShardIteratorType};
7use std::{
8    cmp,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use tokio::{
14    sync::mpsc,
15    time::{sleep, Duration},
16};
17use tokio_stream::Stream;
18use tracing::error;
19
20/// The polling half of DynamoDB Streams.
21#[derive(Debug)]
22pub struct DynamodbStreamProducer<Client>
23where
24    Client: DynamodbClient + 'static,
25{
26    table_name: String,
27    stream_arn: String,
28    shards: Option<Vec<Shard>>,
29    channel: ProducerChannel,
30    client: Client,
31    shard_iterator_type: ShardIteratorType,
32    interval: Option<Duration>,
33    sender: mpsc::Sender<Vec<Record>>,
34}
35
36impl<Client> DynamodbStreamProducer<Client>
37where
38    Client: DynamodbClient + 'static,
39{
40    fn client(&self) -> Arc<Client> {
41        Arc::new(self.client.clone())
42    }
43
44    /// Get shards and shard iterator ids for first attempt to get records.
45    async fn init(&mut self) -> Result<(), Error> {
46        let stream_arn = self.client.get_stream_arn(&self.table_name).await?;
47        self.stream_arn = stream_arn;
48
49        let shards = self.get_all_shards().await?;
50        let shards = self
51            .get_shard_iterators(shards, self.shard_iterator_type.clone())
52            .await;
53
54        self.shards = Some(shards);
55        self.channel.send_init();
56
57        Ok(())
58    }
59
60    /// Get records and renew shards for next iteration.
61    async fn iterate(&mut self) -> Result<Vec<Record>, Error> {
62        let lineages: Lineages = self.shards.take().unwrap_or_default().into();
63        let (mut shards, records) = lineages.get_records(self.client()).await;
64
65        let new_shards = self
66            .get_all_shards()
67            .await?
68            .into_iter()
69            .filter(|shard| !shards.iter().any(|s| s.id() == shard.id()))
70            .collect::<Vec<Shard>>();
71        let mut new_shards = self
72            .get_shard_iterators(new_shards, ShardIteratorType::Latest)
73            .await;
74
75        shards.append(&mut new_shards);
76        self.shards = Some(shards);
77
78        Ok(records)
79    }
80
81    /// Poll the DynamoDB Streams.
82    async fn streaming(&mut self) {
83        ok_or_return!(self.init().await, |err| {
84            error!(
85                "Unexpected error during initialization: {err}. Skip polling {} table.",
86                self.table_name,
87            );
88        });
89
90        loop {
91            let records = ok_or_return!(self.iterate().await, |err| {
92                error!(
93                    "Unexpected error during iteration: {err}. Stop polling {} table.",
94                    self.table_name,
95                );
96            });
97
98            if self.channel.should_close() {
99                return;
100            }
101
102            if !records.is_empty() && self.sender.send(records).await.is_err() {
103                return;
104            }
105
106            if let Some(duration) = self.interval {
107                sleep(duration).await;
108            }
109        }
110    }
111
112    /// Get all shards from the DynamoDB table.
113    async fn get_all_shards(&self) -> Result<Vec<Shard>, Error> {
114        let GetShardsOutput {
115            mut shards,
116            mut last_shard_id,
117        } = self.client.get_shards(&self.stream_arn, None).await?;
118
119        while last_shard_id.is_some() {
120            let mut output = self
121                .client
122                .get_shards(&self.stream_arn, last_shard_id.take())
123                .await?;
124            shards.append(&mut output.shards);
125            last_shard_id = output.last_shard_id;
126        }
127
128        Ok(shards)
129    }
130
131    /// Get and set shard iterator.
132    async fn get_shard_iterators(
133        &self,
134        shards: Vec<Shard>,
135        shard_iterator_type: ShardIteratorType,
136    ) -> Vec<Shard> {
137        // The buffer size must be positive (not zero).
138        let buf = cmp::max(1, shards.len());
139        let (tx, mut rx) = mpsc::channel::<Shard>(buf);
140        let mut output: Vec<Shard> = vec![];
141        let client = self.client();
142
143        for shard in shards {
144            let tx = tx.clone();
145            let client = Arc::clone(&client);
146            let stream_arn = self.stream_arn.clone();
147            let shard_iterator_type = shard_iterator_type.clone();
148
149            tokio::spawn(async move {
150                let result = client.get_shard_with_iterator(stream_arn, shard, shard_iterator_type);
151                let shard_opt = ok_or_return!(result.await, |err| {
152                    error!("Unexpected error during getting shard iterator: {err}");
153                });
154
155                if let Some(shard) = shard_opt {
156                    if let Err(err) = tx.send(shard).await {
157                        error!("Unexpected error during sending shard: {err}");
158                    }
159                }
160            });
161        }
162
163        drop(tx);
164
165        while let Some(shard) = rx.recv().await {
166            output.push(shard);
167        }
168
169        output
170    }
171}
172
173/// Represent DynamoDB Stream.
174///
175/// This struct receives DynamoDB Stream records from polling half and emit them as Rust Stream.
176#[derive(Debug)]
177pub struct DynamodbStream {
178    receiver: mpsc::Receiver<Vec<Record>>,
179    channel: Option<ConsumerChannel>,
180}
181
182impl DynamodbStream {
183    /// Get [`ConsumerChannel`] as communication channel to the stream.
184    ///
185    /// Once you take a channel from this method, you can't take it anymore from the same channel
186    /// because this method also passes the ownership of the channel.
187    ///
188    /// ```rust,no_run
189    /// use aws_config::BehaviorVersion;
190    /// use dynamo_subscriber as subscriber;
191    ///
192    /// # async fn wrapper() {
193    /// # let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
194    /// # let client = subscriber::Client::new(&config);
195    /// let mut stream = subscriber::stream::builder()
196    ///     .client(client)
197    ///     .table_name("People")
198    ///     .build();
199    /// let channel = stream.take_channel();
200    /// assert!(channel.is_some());
201    ///
202    /// let channel = stream.take_channel();
203    /// assert!(channel.is_none());
204    /// # }
205    /// ```
206    pub fn take_channel(&mut self) -> Option<ConsumerChannel> {
207        self.channel.take()
208    }
209}
210
211impl Stream for DynamodbStream {
212    type Item = Vec<Record>;
213
214    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
215        self.receiver.poll_recv(cx)
216    }
217}
218
219impl Drop for DynamodbStream {
220    fn drop(&mut self) {
221        self.receiver.close();
222        if let Some(mut channel) = self.take_channel() {
223            channel.close(|| {});
224        }
225    }
226}
227
228impl AsRef<mpsc::Receiver<Vec<Record>>> for DynamodbStream {
229    fn as_ref(&self) -> &mpsc::Receiver<Vec<Record>> {
230        &self.receiver
231    }
232}
233
234impl AsMut<mpsc::Receiver<Vec<Record>>> for DynamodbStream {
235    fn as_mut(&mut self) -> &mut mpsc::Receiver<Vec<Record>> {
236        &mut self.receiver
237    }
238}
239
240/// A builder for [`DynamodbStream`].
241#[derive(Debug)]
242pub struct DynamodbStreamBuilder<Client>
243where
244    Client: DynamodbClient + 'static,
245{
246    table_name: Option<String>,
247    client: Option<Client>,
248    shard_iterator_type: ShardIteratorType,
249    interval: Option<Duration>,
250    buffer: usize,
251}
252
253impl<Client> DynamodbStreamBuilder<Client>
254where
255    Client: DynamodbClient + 'static,
256{
257    /// Create a new `DynamodbStreamBuilder`.
258    pub fn new() -> Self {
259        Self {
260            table_name: None,
261            client: None,
262            shard_iterator_type: ShardIteratorType::Latest,
263            interval: Some(Duration::from_secs(3)),
264            buffer: 100,
265        }
266    }
267
268    /// Set table name you want to retrieve records from.
269    ///
270    /// **Setting any table name is required** before the build method is called.
271    pub fn table_name(self, table_name: impl Into<String>) -> Self {
272        Self {
273            table_name: Some(table_name.into()),
274            ..self
275        }
276    }
277
278    /// Set client to call AWS APIs.
279    ///
280    /// **Setting any client is required** before the build method is called.
281    pub fn client(self, client: Client) -> Self {
282        Self {
283            client: Some(client),
284            ..self
285        }
286    }
287
288    /// Set [`ShardIteratorType`] to get records for the first time.
289    /// After the first time, the DynamodbStream uses the shard iterator from the previous
290    /// `get records` operation outputs.
291    ///
292    /// Setting any shard iterator type is optional. If you omit calling this method,
293    /// `ShardIteratorType::Latest` is used as default value.
294    pub fn shard_iterator_type(self, shard_iterator_type: ShardIteratorType) -> Self {
295        Self {
296            shard_iterator_type,
297            ..self
298        }
299    }
300
301    /// Set interval between polling attempts. When None is provided there are no intervals between
302    /// polling iterations.
303    ///
304    /// Setting any interval is optional. If you omit calling this method,
305    /// `3 seconds` is used as default value.
306    pub fn interval(self, interval: Option<Duration>) -> Self {
307        Self { interval, ..self }
308    }
309
310    /// Set the buffer for [`tokio::sync::mpsc::channel`](tokio::sync::mpsc::channel).
311    ///
312    /// The stream records are stored up to the buffer size unless the records are consumed.
313    /// Once the buffer is full, attempts to receive records from the DynamoDB Streams will
314    /// wait until the records is consumed.
315    ///
316    /// This method will panic when given zero as buffer size.
317    ///
318    /// Setting buffer size is optional. If you omit calling this method,
319    /// `100` is used as default value.
320    pub fn buffer(self, buffer: usize) -> Self {
321        if buffer == 0 {
322            panic!("buffer must be positive.");
323        }
324
325        Self { buffer, ..self }
326    }
327
328    /// Consumes the builder and constructs a [`DynamodbStream`].
329    ///
330    /// This method will panic if no table name is set or no client is set.
331    pub fn build(self) -> DynamodbStream {
332        let (c_half, rx) = self.build_producer();
333
334        DynamodbStream {
335            receiver: rx,
336            channel: Some(c_half),
337        }
338    }
339
340    fn build_producer(self) -> (ConsumerChannel, mpsc::Receiver<Vec<Record>>) {
341        let table_name = self.table_name.expect("`table_name` is required");
342        let client = self.client.expect("`client` is required");
343
344        let (p_half, c_half) = channel::new();
345        let (tx_mpsc, rx_mpsc) = mpsc::channel::<Vec<Record>>(self.buffer);
346
347        let mut producer = DynamodbStreamProducer {
348            table_name,
349            stream_arn: "".to_string(),
350            shards: None,
351            channel: p_half,
352            client,
353            shard_iterator_type: self.shard_iterator_type,
354            interval: self.interval,
355            sender: tx_mpsc,
356        };
357
358        tokio::spawn(async move {
359            producer.streaming().await;
360        });
361
362        (c_half, rx_mpsc)
363    }
364}
365
366impl<Client> Default for DynamodbStreamBuilder<Client>
367where
368    Client: DynamodbClient + 'static,
369{
370    fn default() -> Self {
371        Self::new()
372    }
373}