dynamo_subscriber/
client.rs1use super::{
2 error::Error,
3 types::{GetRecordsOutput, GetShardsOutput, Shard},
4};
5
6use async_trait::async_trait;
7use aws_config::SdkConfig;
8use aws_sdk_dynamodb::Client as DbClient;
9use aws_sdk_dynamodbstreams::{
10 error::SdkError,
11 operation::{
12 get_records::{GetRecordsError, GetRecordsOutput as SdkGetRecordsOutput},
13 get_shard_iterator::GetShardIteratorError,
14 },
15 types::ShardIteratorType,
16 Client as StreamsClient,
17};
18use tracing::warn;
19
20#[derive(Debug, Clone)]
43pub struct Client {
44 db: DbClient,
45 streams: StreamsClient,
46}
47
48impl Client {
49 pub fn new(config: &SdkConfig) -> Self {
60 Self {
61 db: DbClient::new(config),
62 streams: StreamsClient::new(config),
63 }
64 }
65}
66
67#[async_trait]
69pub trait DynamodbClient: Clone + Send + Sync {
70 async fn get_stream_arn(&self, table_name: impl Into<String> + Send) -> Result<String, Error>;
73
74 async fn get_shards(
76 &self,
77 stream_arn: impl Into<String> + Send,
78 exclusive_start_shard_id: Option<String>,
79 ) -> Result<GetShardsOutput, Error>;
80
81 async fn get_shard_with_iterator(
85 &self,
86 stream_arn: impl Into<String> + Send,
87 shard: Shard,
88 shard_iterator_type: ShardIteratorType,
89 ) -> Result<Option<Shard>, Error>;
90
91 async fn get_records(&self, shard: Shard) -> Result<GetRecordsOutput, Error>;
94}
95
96#[async_trait]
97impl DynamodbClient for Client {
98 async fn get_stream_arn(&self, table_name: impl Into<String> + Send) -> Result<String, Error> {
99 let table_name: String = table_name.into();
100
101 self.db
102 .describe_table()
103 .table_name(&table_name)
104 .send()
105 .await
106 .map_err(|err| Error::SdkError(Box::new(err)))?
107 .table
108 .and_then(|table| table.latest_stream_arn)
109 .ok_or(Error::NotFoundStream(table_name))
110 }
111
112 async fn get_shards(
113 &self,
114 stream_arn: impl Into<String> + Send,
115 exclusive_start_shard_id: Option<String>,
116 ) -> Result<GetShardsOutput, Error> {
117 let stream_arn: String = stream_arn.into();
118
119 self.streams
120 .describe_stream()
121 .stream_arn(&stream_arn)
122 .set_exclusive_start_shard_id(exclusive_start_shard_id)
123 .send()
124 .await
125 .map_err(|err| Error::SdkError(Box::new(err)))?
126 .stream_description
127 .map(|description| {
128 let shards = description
129 .shards
130 .unwrap_or_default()
131 .into_iter()
132 .filter_map(Shard::new)
133 .collect::<Vec<Shard>>();
134 let last_shard_id = description.last_evaluated_shard_id;
135
136 GetShardsOutput {
137 shards,
138 last_shard_id,
139 }
140 })
141 .ok_or(Error::NotFoundStreamDescription(stream_arn))
142 }
143
144 async fn get_shard_with_iterator(
145 &self,
146 stream_arn: impl Into<String> + Send,
147 shard: Shard,
148 shard_iterator_type: ShardIteratorType,
149 ) -> Result<Option<Shard>, Error> {
150 let iterator = self
151 .streams
152 .get_shard_iterator()
153 .stream_arn(stream_arn)
154 .shard_id(shard.id())
155 .shard_iterator_type(shard_iterator_type)
156 .send()
157 .await
158 .map(|output| output.shard_iterator)
159 .or_else(empty_iterator)?;
160
161 Ok(shard.set_iterator(iterator))
162 }
163
164 async fn get_records(&self, shard: Shard) -> Result<GetRecordsOutput, Error> {
165 let iterator = shard.iterator().map(|val| val.to_string());
166
167 self.streams
168 .get_records()
169 .set_shard_iterator(iterator)
170 .send()
171 .await
172 .or_else(empty_records)
173 .map(|output| {
174 let shard = shard.set_iterator(output.next_shard_iterator);
175 let records = output.records.unwrap_or_default();
176
177 GetRecordsOutput { shard, records }
178 })
179 }
180}
181
182fn empty_iterator(err: SdkError<GetShardIteratorError>) -> Result<Option<String>, Error> {
183 use GetShardIteratorError::*;
184
185 match err {
186 SdkError::ServiceError(e) => {
187 let e = e.into_err();
188 match e {
189 ResourceNotFoundException(_) | TrimmedDataAccessException(_) => {
193 warn!("GetShardIterator operation failed due to {e}");
194 warn!("{:#?}", e);
195 Ok(None)
196 }
197 _ => Err(Error::SdkError(Box::new(e))),
198 }
199 }
200 _ => Err(Error::SdkError(Box::new(err))),
201 }
202}
203
204fn empty_records(err: SdkError<GetRecordsError>) -> Result<SdkGetRecordsOutput, Error> {
205 use GetRecordsError::*;
206
207 match err {
208 SdkError::ServiceError(e) => {
209 let e = e.into_err();
210 match e {
211 ExpiredIteratorException(_)
216 | LimitExceededException(_)
217 | ResourceNotFoundException(_)
218 | TrimmedDataAccessException(_) => {
219 warn!("GetRecords operation failed due to {e}");
220 warn!("{:#?}", e);
221 Ok(SdkGetRecordsOutput::builder().build())
222 }
223 _ => Err(Error::SdkError(Box::new(e))),
224 }
225 }
226 _ => Err(Error::SdkError(Box::new(err))),
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_runtime_api::{
        client::{orchestrator::HttpResponse, result::ServiceError},
        http::StatusCode,
    };
    use aws_smithy_types::body::SdkBody;

    /// Wraps `error` in an `SdkError::ServiceError` whose raw response is an
    /// HTTP 400 with an empty body.
    fn service_error<E>(error: E) -> SdkError<E, HttpResponse> {
        let status = StatusCode::try_from(400).unwrap();
        let raw = HttpResponse::new(status, SdkBody::empty());

        SdkError::ServiceError(ServiceError::builder().source(error).raw(raw).build())
    }

    /// `empty_iterator` tolerates "resource not found" and "trimmed data"
    /// service errors, and rejects an internal server error.
    #[test]
    fn empty_iterator_converts_some_errors_to_ok() {
        use aws_sdk_dynamodbstreams::types::error::*;

        let not_found = ResourceNotFoundException::builder()
            .message("error")
            .build();
        let outcome = empty_iterator(service_error(
            GetShardIteratorError::ResourceNotFoundException(not_found),
        ));
        assert!(outcome.is_ok());

        let internal = InternalServerError::builder().message("error").build();
        let outcome = empty_iterator(service_error(GetShardIteratorError::InternalServerError(
            internal,
        )));
        assert!(outcome.is_err());

        let trimmed = TrimmedDataAccessException::builder()
            .message("error")
            .build();
        let outcome = empty_iterator(service_error(
            GetShardIteratorError::TrimmedDataAccessException(trimmed),
        ));
        assert!(outcome.is_ok());
    }

    /// `empty_records` tolerates "resource not found", "expired iterator",
    /// "limit exceeded" and "trimmed data" service errors, and rejects an
    /// internal server error.
    #[test]
    fn empty_records_converts_some_errors_to_ok() {
        use aws_sdk_dynamodbstreams::types::error::*;

        let not_found = ResourceNotFoundException::builder()
            .message("error")
            .build();
        let outcome = empty_records(service_error(GetRecordsError::ResourceNotFoundException(
            not_found,
        )));
        assert!(outcome.is_ok());

        let internal = InternalServerError::builder().message("error").build();
        let outcome = empty_records(service_error(GetRecordsError::InternalServerError(internal)));
        assert!(outcome.is_err());

        let expired = ExpiredIteratorException::builder().message("error").build();
        let outcome = empty_records(service_error(GetRecordsError::ExpiredIteratorException(
            expired,
        )));
        assert!(outcome.is_ok());

        let throttled = LimitExceededException::builder().message("error").build();
        let outcome = empty_records(service_error(GetRecordsError::LimitExceededException(
            throttled,
        )));
        assert!(outcome.is_ok());

        let trimmed = TrimmedDataAccessException::builder()
            .message("error")
            .build();
        let outcome = empty_records(service_error(GetRecordsError::TrimmedDataAccessException(
            trimmed,
        )));
        assert!(outcome.is_ok());
    }
}