1use std::{
7 fmt::{Display, Formatter},
8 sync::Arc,
9 time::Duration as StdDuration,
10};
11
12use anyhow::{Context, Result};
13use async_nats::jetstream::{
14 Context as JsContext,
15 context::KeyValueErrorKind,
16 kv::{self, CreateErrorKind},
17};
18use async_trait::async_trait;
19use gsm_telemetry::{TelemetryLabels, record_counter};
20use serde::{Deserialize, Serialize};
21use time::{Duration, OffsetDateTime};
22use tokio::sync::RwLock;
23use tracing::{instrument, warn};
24
25#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
27pub struct IdKey {
28 pub tenant: String,
29 pub platform: String,
30 pub msg_id: String,
31}
32
33impl Display for IdKey {
34 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35 write!(f, "{}:{}:{}", self.tenant, self.platform, self.msg_id)
36 }
37}
38
39#[async_trait]
41pub trait IdemStore: Send + Sync {
42 async fn put_if_absent(&self, key: &str, ttl_s: u64) -> Result<bool>;
46}
47
/// Cheaply clonable, type-erased handle to any [`IdemStore`] implementation.
pub type SharedIdemStore = Arc<dyn IdemStore>;
50
/// Process-local [`IdemStore`] that keeps keys in a map behind an async `RwLock`.
///
/// Clones share the same map because it sits behind an `Arc`. Expired entries are
/// not evicted automatically. They stay until `put_if_absent` overwrites them or
/// `purge_expired` is called.
#[derive(Clone, Default)]
pub struct InMemoryIdemStore {
    // Key -> instant after which the key no longer counts as seen.
    inner: Arc<RwLock<std::collections::HashMap<String, OffsetDateTime>>>,
}
56
57impl InMemoryIdemStore {
58 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub async fn purge_expired(&self, now: OffsetDateTime) {
63 let mut guard = self.inner.write().await;
64 guard.retain(|_, expires| *expires > now);
65 }
66}
67
68#[async_trait]
69impl IdemStore for InMemoryIdemStore {
70 async fn put_if_absent(&self, key: &str, ttl_s: u64) -> Result<bool> {
71 let ttl = Duration::seconds(ttl_s as i64);
72 let now = OffsetDateTime::now_utc();
73 let mut guard = self.inner.write().await;
74 match guard.get(key) {
75 Some(exp) if *exp > now => Ok(false),
76 _ => {
77 guard.insert(key.to_string(), now + ttl);
78 Ok(true)
79 }
80 }
81 }
82}
83
84pub struct NatsKvIdemStore {
86 bucket: kv::Store,
87}
88
89impl NatsKvIdemStore {
90 pub async fn new(js: &JsContext, namespace: &str) -> Result<Self> {
92 let bucket = match js.get_key_value(namespace).await {
93 Ok(store) => store,
94 Err(err) if err.kind() == KeyValueErrorKind::GetBucket => js
95 .create_key_value(kv::Config {
96 bucket: namespace.to_string(),
97 history: 1,
98 max_age: StdDuration::from_secs(0),
99 ..Default::default()
100 })
101 .await
102 .with_context(|| format!("create JetStream KV bucket {namespace}"))?,
103 Err(err) => anyhow::bail!("idempotency kv init failed: {err}"),
104 };
105 Ok(Self { bucket })
106 }
107}
108
109#[async_trait]
110impl IdemStore for NatsKvIdemStore {
111 #[instrument(name = "idempotency.put_if_absent", skip(self), fields(key = %key))]
112 async fn put_if_absent(&self, key: &str, ttl_s: u64) -> Result<bool> {
113 let ttl = StdDuration::from_secs(ttl_s.max(1));
114 let seen_at = OffsetDateTime::now_utc()
115 .format(&time::format_description::well_known::Rfc3339)
116 .unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string());
117 let payload = serde_json::to_vec(&serde_json::json!({ "seen_at": seen_at }))?;
118
119 match self.bucket.create_with_ttl(key, payload.into(), ttl).await {
120 Ok(_) => Ok(true),
121 Err(err) if err.kind() == CreateErrorKind::AlreadyExists => Ok(false),
122 Err(err) => Err(anyhow::anyhow!(err)
123 .context(format!("put idempotency key {key} with ttl {ttl_s}s"))),
124 }
125 }
126}
127
/// Tunables for idempotency tracking: how long keys are remembered and which
/// store namespace holds them.
#[derive(Clone, Debug)]
pub struct IdempotencyConfig {
    /// Retention window for seen keys, in hours.
    pub ttl_hours: u64,
    /// Store namespace. Presumably passed to `NatsKvIdemStore::new` as the KV bucket
    /// name; confirm at the call site.
    pub namespace: String,
}

impl Default for IdempotencyConfig {
    /// 36-hour retention in the `idempotency` namespace.
    fn default() -> Self {
        Self {
            ttl_hours: 36,
            namespace: "idempotency".to_string(),
        }
    }
}

impl IdempotencyConfig {
    /// Builds a config from optional settings, falling back to [`Default`] values.
    ///
    /// * `ttl_hours` is clamped to at least 1 hour.
    /// * `namespace` has surrounding whitespace trimmed. Blank or whitespace-only
    ///   values are ignored.
    pub fn from_settings(ttl_hours: Option<u64>, namespace: Option<String>) -> Self {
        let mut cfg = Self::default();
        if let Some(hours) = ttl_hours {
            cfg.ttl_hours = hours.max(1);
        }
        // Store the trimmed value, not just test it. NATS KV bucket names allow only
        // `[A-Za-z0-9_-]`, so a padded name like " dedupe " would be rejected later.
        if let Some(ns) = namespace
            .as_deref()
            .map(str::trim)
            .filter(|ns| !ns.is_empty())
        {
            cfg.namespace = ns.to_string();
        }
        cfg
    }
}
158
159#[derive(Clone)]
161pub struct IdempotencyGuard {
162 ttl_secs: u64,
163 store: SharedIdemStore,
164}
165
166impl IdempotencyGuard {
167 pub fn new(store: SharedIdemStore, ttl_hours: u64) -> Self {
168 Self {
169 store,
170 ttl_secs: ttl_hours.saturating_mul(3600).max(60),
171 }
172 }
173
174 pub async fn should_process(&self, key: &IdKey) -> Result<bool> {
196 let inserted = self
197 .store
198 .put_if_absent(&key.to_string(), self.ttl_secs)
199 .await?;
200 if !inserted {
201 warn!(tenant = %key.tenant, platform = %key.platform, msg_id = %key.msg_id, "duplicate message dropped");
202 let labels = TelemetryLabels {
203 tenant: key.tenant.clone(),
204 platform: None,
205 chat_id: None,
206 msg_id: None,
207 extra: Vec::new(),
208 };
209 record_counter("idempotency_hit", 1, &labels);
210 }
211 Ok(inserted)
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use time::Duration;

    #[tokio::test]
    async fn memory_store_dedupes() {
        let store = InMemoryIdemStore::new();

        let first = store.put_if_absent("k", 10).await.unwrap();
        assert!(first);
        let repeat = store.put_if_absent("k", 10).await.unwrap();
        assert!(!repeat);

        // Seed an entry whose expiry is already in the past; it must be replaceable.
        let stale_expiry = OffsetDateTime::now_utc() - Duration::seconds(5);
        store
            .inner
            .write()
            .await
            .insert("expired".into(), stale_expiry);
        let revived = store.put_if_absent("expired", 1).await.unwrap();
        assert!(revived);
    }

    #[tokio::test]
    async fn guard_should_process() {
        let key = IdKey {
            tenant: "t1".into(),
            platform: "slack".into(),
            msg_id: "abc".into(),
        };
        let guard = IdempotencyGuard::new(Arc::new(InMemoryIdemStore::new()), 1);

        let first = guard.should_process(&key).await.unwrap();
        assert!(first);
        let second = guard.should_process(&key).await.unwrap();
        assert!(!second);
    }
}