1pub mod error;
2use crate::error::DynamoDbMutexError as Error;
3use chrono::prelude::*;
4use rusoto_core::Region;
5pub use rusoto_core;
6use rusoto_dynamodb::{
7 AttributeDefinition, AttributeValue, CreateTableInput, DynamoDb, DynamoDbClient,
8 KeySchemaElement, ProvisionedThroughput, UpdateItemInput, DeleteItemInput,
9};
10use std::{collections::HashMap, str::FromStr};
11
/// Table used by `DynamoDbMutex::new` when no explicit table name is given.
const TABLE_NAME: &str = "mutexes";
/// Name of the string hash-key attribute that identifies one mutex item.
const TABLE_KEY: &str = "mutex_code";
14
15fn make_key(mutex_code: &str) -> HashMap<String, AttributeValue> {
16 let mut map = HashMap::new();
17 insert_str_attribute(&mut map, "mutex_code", mutex_code);
18 map
19}
20
21async fn update(
22 client: &DynamoDbClient,
23 update_item_inupt: UpdateItemInput,
24) -> Result<DynamoDbMutexResult, Error> {
25 match client.update_item(update_item_inupt).await {
26 Ok(res) => match res.attributes {
27 Some(value) => {
28 let mutex_status = value.get("mutex_status").unwrap().s.as_ref().unwrap();
29 let updated_at = value.get("updated_at").unwrap().n.as_ref().unwrap();
30 let status = DynamoDbMutexStatus::from_str(&mutex_status)?;
31 Ok(DynamoDbMutexResult::Success(
32 Some(status),
33 updated_at.parse().unwrap(),
34 ))
35 }
36 None => Ok(DynamoDbMutexResult::Success(None, 0)),
37 },
38 Err(rusoto_core::RusotoError::Service(
39 rusoto_dynamodb::UpdateItemError::ConditionalCheckFailed(_),
40 )) => Ok(DynamoDbMutexResult::Failure),
41 Err(err) => Err(err.into()),
42 }
43}
44
45fn insert_str_attribute(map: &mut HashMap<String, AttributeValue>, key: &str, value: &str) {
46 map.insert(
47 key.to_owned(),
48 AttributeValue {
49 s: Some(value.to_owned()),
50 ..Default::default()
51 },
52 );
53}
54
55fn insert_num_attribute(map: &mut HashMap<String, AttributeValue>, key: &str, value: i64) {
56 map.insert(
57 key.to_owned(),
58 AttributeValue {
59 n: Some(value.to_string()),
60 ..Default::default()
61 },
62 );
63}
64
65pub struct DynamoDbMutex {
67 table_name: String,
68 client: DynamoDbClient,
69 done_after_milli_seconds: u64,
70 failed_after_milli_seconds: u64,
71 running_after_milli_seconds: u64,
72}
73
74impl DynamoDbMutex {
75 pub fn new(
76 region: Region,
77 done_after_milli_seconds: u64,
78 failed_after_milli_seconds: u64,
79 running_after_milli_seconds: u64,
80 table_name: Option<&str>,
81 ) -> Self {
82 Self {
83 table_name: table_name.unwrap_or(TABLE_NAME).to_owned(),
84 client: DynamoDbClient::new(region),
85 done_after_milli_seconds,
86 failed_after_milli_seconds,
87 running_after_milli_seconds,
88 }
89 }
90
91 pub async fn make_table(&self) -> Result<(), Error> {
93 let input = CreateTableInput {
94 attribute_definitions: vec![AttributeDefinition {
95 attribute_name: TABLE_KEY.to_owned(),
96 attribute_type: "S".to_owned(),
97 }],
98 billing_mode: Some("PROVISIONED".to_owned()),
99 provisioned_throughput: Some(ProvisionedThroughput {
100 read_capacity_units: 1,
101 write_capacity_units: 1,
102 }),
103 table_name: self.table_name.clone(),
104 key_schema: vec![KeySchemaElement {
105 attribute_name: TABLE_KEY.to_owned(),
106 key_type: "HASH".to_owned(),
107 }],
108 ..Default::default()
109 };
110 let _ = self.client.create_table(input).await?;
111 Ok(())
112 }
113
114 pub async fn lock(&self, mutex_code: &str) -> Result<DynamoDbMutexResult, Error> {
117 let now: DateTime<Utc> = Utc::now();
118 let now_millis = now.timestamp_millis();
119
120 let mut map = HashMap::new();
121 insert_str_attribute(
122 &mut map,
123 ":condion_done_status",
124 &DynamoDbMutexStatus::Done.to_string(),
125 );
126 insert_num_attribute(
127 &mut map,
128 ":condion_done_millis",
129 now_millis - self.done_after_milli_seconds as i64,
130 );
131 insert_str_attribute(
132 &mut map,
133 ":condion_failed_status",
134 &DynamoDbMutexStatus::Failed.to_string(),
135 );
136 insert_num_attribute(
137 &mut map,
138 ":condion_failed_millis",
139 now_millis - self.failed_after_milli_seconds as i64,
140 );
141 insert_str_attribute(
142 &mut map,
143 ":condion_running_status",
144 &DynamoDbMutexStatus::Running.to_string(),
145 );
146 insert_num_attribute(
147 &mut map,
148 ":condion_running_millis",
149 now_millis - self.running_after_milli_seconds as i64,
150 );
151 insert_str_attribute(
152 &mut map,
153 ":update_status",
154 &DynamoDbMutexStatus::Running.to_string(),
155 );
156 insert_num_attribute(&mut map, ":now", now_millis);
157
158 let condition = String::from("attribute_not_exists(mutex_status) OR mutex_status = :condion_done_status AND updated_at <= :condion_done_millis OR mutex_status = :condion_failed_status AND updated_at <= :condion_failed_millis OR mutex_status = :condion_running_status AND updated_at <= :condion_running_millis");
159
160 let input = UpdateItemInput {
161 key: make_key(mutex_code),
162 table_name: self.table_name.clone(),
163 condition_expression: Some(condition),
164 update_expression: Some(
165 "SET mutex_status = :update_status, updated_at = :now".to_owned(),
166 ),
167 expression_attribute_values: Some(map),
168 return_values: Some(String::from("ALL_OLD")),
169 ..Default::default()
170 };
171 update(&self.client, input).await
172 }
173
174 pub async fn unlock(&self, mutex_code: &str, is_success: bool) -> Result<(), Error> {
178 let now: DateTime<Utc> = Utc::now();
179 let now_millis = now.timestamp_millis();
180 let status = if is_success {DynamoDbMutexStatus::Done} else {DynamoDbMutexStatus::Failed};
181
182 let mut map = HashMap::new();
183 insert_str_attribute(
184 &mut map,
185 ":condion_status",
186 &DynamoDbMutexStatus::Running.to_string(),
187 );
188 insert_str_attribute(&mut map, ":update_status", &status.to_string());
189 insert_num_attribute(&mut map, ":now", now_millis);
190
191 let condition = String::from("mutex_status = :condion_status");
192
193 let input = UpdateItemInput {
194 key: make_key(mutex_code),
195 table_name: self.table_name.clone(),
196 condition_expression: Some(condition),
197 update_expression: Some(
198 "SET mutex_status = :update_status, updated_at = :now".to_owned(),
199 ),
200 expression_attribute_values: Some(map),
201 return_values: Some(String::from("NONE")),
202 ..Default::default()
203 };
204 let _ = update(&self.client, input).await?;
205 Ok(())
206 }
207
208 pub async fn remove(&self, mutex_code: &str) -> Result<(), Error> {
211 let input = DeleteItemInput {
212 key: make_key(mutex_code),
213 table_name: self.table_name.clone(),
214 return_values: Some(String::from("NONE")),
215 ..Default::default()
216 };
217 let _ = &self.client.delete_item(input).await?;
218 Ok(())
219 }
220}
221
/// Outcome of a conditional write to a mutex item.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoDbMutexResult {
    /// The write was applied. Carries the item's previous status (`None` when
    /// none was returned) and its previous `updated_at` in epoch milliseconds
    /// (`0` when none was returned).
    Success(Option<DynamoDbMutexStatus>, u64),
    /// The write's condition did not hold, so nothing was changed.
    Failure,
}

/// Value stored in a mutex item's `mutex_status` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DynamoDbMutexStatus {
    /// The mutex is currently held.
    Running,
    /// The mutex was released after a successful run.
    Done,
    /// The mutex was released after a failed run.
    Failed,
}
236
237impl ToString for DynamoDbMutexStatus {
238 fn to_string(&self) -> String {
239 match self {
240 Self::Running => "RUNNING",
241 Self::Done => "DONE",
242 Self::Failed => "FAILED",
243 }
244 .to_owned()
245 }
246}
247
248impl FromStr for DynamoDbMutexStatus {
249 type Err = Error;
250
251 fn from_str(s: &str) -> Result<Self, Self::Err> {
252 match s {
253 "RUNNING" => Ok(Self::Running),
254 "DONE" => Ok(Self::Done),
255 "FAILED" => Ok(Self::Failed),
256 _ => Err(Error::FailDbValue),
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{stream, StreamExt};
    use std::sync::Arc;

    /// Result of one concurrent `lock` attempt in `check_async`.
    enum Attempt {
        Acquired,
        Contended,
        Errored(String),
    }

    // NOTE(review): both tests talk to a real DynamoDB table (default table
    // name, us-east-1), so they presumably need AWS credentials and an
    // existing table. Confirm before running them in CI.

    /// Acquires and then deletes a single mutex, printing the lock result.
    #[tokio::test]
    async fn it_works() -> Result<(), Error> {
        let mutex = DynamoDbMutex::new(Region::UsEast1, 10000, 10000, 10000, None);
        let res = mutex.lock("test3").await?;
        println!("{:?}", res);
        mutex.remove("test3").await?;
        Ok(())
    }

    /// Races several concurrent `lock` calls on the same mutex code and
    /// asserts that at most one of them acquires it.
    #[tokio::test]
    async fn check_async() -> Result<(), Error> {
        const TASKS: u32 = 10;
        let mutex = Arc::new(DynamoDbMutex::new(
            Region::UsEast1,
            10000,
            10000,
            10000,
            None,
        ));

        let results = stream::iter(0..TASKS)
            .map(|id| {
                let mutex = Arc::clone(&mutex);
                tokio::spawn(async move {
                    let attempt = match mutex.lock("test").await {
                        Ok(DynamoDbMutexResult::Success(_, _)) => Attempt::Acquired,
                        Ok(DynamoDbMutexResult::Failure) => Attempt::Contended,
                        Err(e) => Attempt::Errored(format!("{:?}", e)),
                    };
                    (id, attempt)
                })
            })
            .buffer_unordered(TASKS as usize)
            .collect::<Vec<_>>()
            .await;

        let mut acquired = 0;
        for result in results {
            match result {
                Ok((id, Attempt::Acquired)) => {
                    println!("{}:1", id);
                    acquired += 1;
                }
                Ok((id, Attempt::Contended)) => println!("{}:0", id),
                Ok((id, Attempt::Errored(msg))) => eprintln!("{}: lock error: {}", id, msg),
                Err(e) => eprintln!("Got a tokio::JoinError: {}", e),
            }
        }

        // Clean up before asserting so a failing run does not leave the item behind.
        mutex.remove("test").await?;
        assert!(
            acquired <= 1,
            "{} concurrent tasks acquired the same mutex",
            acquired
        );
        Ok(())
    }
}