dynamo_subscriber/
client.rs

1use super::{
2    error::Error,
3    types::{GetRecordsOutput, GetShardsOutput, Shard},
4};
5
6use async_trait::async_trait;
7use aws_config::SdkConfig;
8use aws_sdk_dynamodb::Client as DbClient;
9use aws_sdk_dynamodbstreams::{
10    error::SdkError,
11    operation::{
12        get_records::{GetRecordsError, GetRecordsOutput as SdkGetRecordsOutput},
13        get_shard_iterator::GetShardIteratorError,
14    },
15    types::ShardIteratorType,
16    Client as StreamsClient,
17};
18use tracing::warn;
19
20/// Client for both Amazon DynamoDB and Amazon DynamoDB Streams.
21///
22/// A [`SdkConfig`] is required to construct a client.
23/// You can select any ways to get [`SdkConfig`] and pass it
24/// to a client.
25///
26/// For example, if you want to subscribe dynamodb streams from your dynamodb-local
27/// running on localhost:8000, set `endpoint_url` to your [`SdkConfig`].
28///
29/// ```rust,no_run
30/// use dynamo_subscriber::Client;
31///
32/// # async fn wrapper() {
33/// let config = aws_config::load_from_env()
34///     .await
35///     .into_builder()
36///     .endpoint_url("http://localhost:8000")
37///     .build();
38/// let client = Client::new(&config);
39/// # }
40/// ```
41/// See the [`aws-config` docs](aws_config) for more information on customizing configuration.
42#[derive(Debug, Clone)]
43pub struct Client {
44    db: DbClient,
45    streams: StreamsClient,
46}
47
48impl Client {
49    /// Create a new client using passed configuration.
50    ///
51    /// ```rust,no_run
52    /// use dynamo_subscriber::Client;
53    ///
54    /// # async fn wrapper() {
55    /// let config = aws_config::load_from_env().await;
56    /// let client = Client::new(&config);
57    /// # }
58    /// ```
59    pub fn new(config: &SdkConfig) -> Self {
60        Self {
61            db: DbClient::new(config),
62            streams: StreamsClient::new(config),
63        }
64    }
65}
66
67/// An interface to receive DynamoDB Streams records.
68#[async_trait]
69pub trait DynamodbClient: Clone + Send + Sync {
70    /// Return DynamoDB Stream Arn from DynamoDB
71    /// [`TableDescription`](aws_sdk_dynamodb::types::TableDescription).
72    async fn get_stream_arn(&self, table_name: impl Into<String> + Send) -> Result<String, Error>;
73
74    /// Return a vector of [`Shard`](crate::types::Shard) and shard id for next iteration.
75    async fn get_shards(
76        &self,
77        stream_arn: impl Into<String> + Send,
78        exclusive_start_shard_id: Option<String>,
79    ) -> Result<GetShardsOutput, Error>;
80
81    /// Return a [`Option<Shard>`](crate::types::Shard) that is the shard passed as an argument with shard
82    /// iterator id.
83    /// Return None if the aws sdk operation fails due to `ResourceNotFound` of `TrimmedDataAccess`.
84    async fn get_shard_with_iterator(
85        &self,
86        stream_arn: impl Into<String> + Send,
87        shard: Shard,
88        shard_iterator_type: ShardIteratorType,
89    ) -> Result<Option<Shard>, Error>;
90
91    /// Return a vector of [`Record`](aws_sdk_dynamodbstreams::types::Record) and a
92    /// [`Shard`](crate::types::Shard) with shard iterator id for next getting records call.
93    async fn get_records(&self, shard: Shard) -> Result<GetRecordsOutput, Error>;
94}
95
96#[async_trait]
97impl DynamodbClient for Client {
98    async fn get_stream_arn(&self, table_name: impl Into<String> + Send) -> Result<String, Error> {
99        let table_name: String = table_name.into();
100
101        self.db
102            .describe_table()
103            .table_name(&table_name)
104            .send()
105            .await
106            .map_err(|err| Error::SdkError(Box::new(err)))?
107            .table
108            .and_then(|table| table.latest_stream_arn)
109            .ok_or(Error::NotFoundStream(table_name))
110    }
111
112    async fn get_shards(
113        &self,
114        stream_arn: impl Into<String> + Send,
115        exclusive_start_shard_id: Option<String>,
116    ) -> Result<GetShardsOutput, Error> {
117        let stream_arn: String = stream_arn.into();
118
119        self.streams
120            .describe_stream()
121            .stream_arn(&stream_arn)
122            .set_exclusive_start_shard_id(exclusive_start_shard_id)
123            .send()
124            .await
125            .map_err(|err| Error::SdkError(Box::new(err)))?
126            .stream_description
127            .map(|description| {
128                let shards = description
129                    .shards
130                    .unwrap_or_default()
131                    .into_iter()
132                    .filter_map(Shard::new)
133                    .collect::<Vec<Shard>>();
134                let last_shard_id = description.last_evaluated_shard_id;
135
136                GetShardsOutput {
137                    shards,
138                    last_shard_id,
139                }
140            })
141            .ok_or(Error::NotFoundStreamDescription(stream_arn))
142    }
143
144    async fn get_shard_with_iterator(
145        &self,
146        stream_arn: impl Into<String> + Send,
147        shard: Shard,
148        shard_iterator_type: ShardIteratorType,
149    ) -> Result<Option<Shard>, Error> {
150        let iterator = self
151            .streams
152            .get_shard_iterator()
153            .stream_arn(stream_arn)
154            .shard_id(shard.id())
155            .shard_iterator_type(shard_iterator_type)
156            .send()
157            .await
158            .map(|output| output.shard_iterator)
159            .or_else(empty_iterator)?;
160
161        Ok(shard.set_iterator(iterator))
162    }
163
164    async fn get_records(&self, shard: Shard) -> Result<GetRecordsOutput, Error> {
165        let iterator = shard.iterator().map(|val| val.to_string());
166
167        self.streams
168            .get_records()
169            .set_shard_iterator(iterator)
170            .send()
171            .await
172            .or_else(empty_records)
173            .map(|output| {
174                let shard = shard.set_iterator(output.next_shard_iterator);
175                let records = output.records.unwrap_or_default();
176
177                GetRecordsOutput { shard, records }
178            })
179    }
180}
181
182fn empty_iterator(err: SdkError<GetShardIteratorError>) -> Result<Option<String>, Error> {
183    use GetShardIteratorError::*;
184
185    match err {
186        SdkError::ServiceError(e) => {
187            let e = e.into_err();
188            match e {
189                // Retrun Ok(None) if the response is either `ResourceNotFound` or `TrimmedDataAccess`
190                // This means the shard will drop silently because returning None as shard iterator
191                // id results in returning Ok(None) from `get_shard_with_iterator` method.
192                ResourceNotFoundException(_) | TrimmedDataAccessException(_) => {
193                    warn!("GetShardIterator operation failed due to {e}");
194                    warn!("{:#?}", e);
195                    Ok(None)
196                }
197                _ => Err(Error::SdkError(Box::new(e))),
198            }
199        }
200        _ => Err(Error::SdkError(Box::new(err))),
201    }
202}
203
204fn empty_records(err: SdkError<GetRecordsError>) -> Result<SdkGetRecordsOutput, Error> {
205    use GetRecordsError::*;
206
207    match err {
208        SdkError::ServiceError(e) => {
209            let e = e.into_err();
210            match e {
211                // Retrun Ok with default SdkGetRecordsOutput if the response is one of
212                // `ExpiredIterator`, `LimitExceeded`, `ResourceNotFound` and `TrimmedDataAccess`.
213                // This means the shard will drop silently because returning None as shard iterator
214                // id results in returning None as shard in GetRecordsOutput from `get_records` method.
215                ExpiredIteratorException(_)
216                | LimitExceededException(_)
217                | ResourceNotFoundException(_)
218                | TrimmedDataAccessException(_) => {
219                    warn!("GetRecords operation failed due to {e}");
220                    warn!("{:#?}", e);
221                    Ok(SdkGetRecordsOutput::builder().build())
222                }
223                _ => Err(Error::SdkError(Box::new(e))),
224            }
225        }
226        _ => Err(Error::SdkError(Box::new(err))),
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_runtime_api::{
        client::{orchestrator::HttpResponse, result::ServiceError},
        http::StatusCode,
    };
    use aws_smithy_types::body::SdkBody;

    #[test]
    fn empty_iterator_converts_some_errors_to_ok() {
        use aws_sdk_dynamodbstreams::types::error::*;

        // Swallowed errors must yield `Ok(None)` so that the shard is dropped.
        let e = ResourceNotFoundException::builder()
            .message("error")
            .build();
        let err = service_error(GetShardIteratorError::ResourceNotFoundException(e));
        assert!(matches!(empty_iterator(err), Ok(None)));

        let e = TrimmedDataAccessException::builder()
            .message("error")
            .build();
        let err = service_error(GetShardIteratorError::TrimmedDataAccessException(e));
        assert!(matches!(empty_iterator(err), Ok(None)));

        // Any other service error must still be reported.
        let e = InternalServerError::builder().message("error").build();
        let err = service_error(GetShardIteratorError::InternalServerError(e));
        assert!(empty_iterator(err).is_err());
    }

    #[test]
    fn empty_records_converts_some_errors_to_ok() {
        use aws_sdk_dynamodbstreams::types::error::*;

        let e = ResourceNotFoundException::builder()
            .message("error")
            .build();
        assert_swallowed(service_error(GetRecordsError::ResourceNotFoundException(e)));

        let e = ExpiredIteratorException::builder().message("error").build();
        assert_swallowed(service_error(GetRecordsError::ExpiredIteratorException(e)));

        let e = LimitExceededException::builder().message("error").build();
        assert_swallowed(service_error(GetRecordsError::LimitExceededException(e)));

        let e = TrimmedDataAccessException::builder()
            .message("error")
            .build();
        assert_swallowed(service_error(GetRecordsError::TrimmedDataAccessException(e)));

        // Any other service error must still be reported.
        let e = InternalServerError::builder().message("error").build();
        let err = service_error(GetRecordsError::InternalServerError(e));
        assert!(empty_records(err).is_err());
    }

    /// Assert that `empty_records` turns `err` into an empty output: no records and
    /// no next shard iterator, which is what makes `get_records` drop the shard.
    fn assert_swallowed(err: SdkError<GetRecordsError, HttpResponse>) {
        assert!(matches!(
            empty_records(err),
            Ok(output) if output.records.is_none() && output.next_shard_iterator.is_none()
        ));
    }

    /// Wrap a modeled operation error in an `SdkError::ServiceError` with a bare
    /// HTTP 400 response, as the SDK would produce for a service-side failure.
    fn service_error<E>(error: E) -> SdkError<E, HttpResponse> {
        let resp = HttpResponse::new(StatusCode::try_from(400).unwrap(), SdkBody::empty());
        let inner = ServiceError::builder().source(error).raw(resp).build();
        SdkError::ServiceError(inner)
    }
}