1use std::time::{Duration, SystemTime, UNIX_EPOCH};
22
23use bytes::Bytes;
24use fred::clients::Pool;
25use fred::interfaces::{KeysInterface, SortedSetsInterface, StreamsInterface};
26use ruststream::runtime::RETRY_COUNT_HEADER;
27use ruststream::{AckError, Headers};
28
29use crate::convert::fields_for_publish;
30use crate::envelope::{frame, unframe};
31use crate::error::RedisError;
32
/// Maximum number of due members fetched from the delay sorted set in one
/// `sweep_due` call (the count of its `ZRANGEBYSCORE ... LIMIT 0 <count>`).
const SWEEP_BATCH: i64 = 128;
36
/// Where delayed retries are parked until they become due.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum DelayedRetry {
    /// Keep pending retries in a Redis sorted set, each scored by the
    /// epoch-millisecond time at which it becomes due.
    DurableZset {
        /// Key of the sorted set holding pending retries.
        key: String,
        /// Optional expiry re-applied (via `PEXPIRE`) to the whole sorted set each
        /// time a retry is scheduled. `None` never sets an expiry. Because the
        /// expiry covers the entire set, a set that receives no new retries for
        /// `ttl` expires together with every retry still pending in it.
        ttl: Option<Duration>,
    },
}

/// Crate-internal, flattened form of a [`DelayedRetry`] policy, consumed by the
/// scheduling and sweeping functions of this module.
#[derive(Debug, Clone)]
pub(crate) struct DelayConfig {
    /// Sorted-set key that holds pending retries.
    zset_key: String,
    /// Expiry refreshed on the sorted set after every schedule, if configured.
    ttl: Option<Duration>,
}

impl DelayConfig {
    /// Copies the relevant settings out of a public retry policy.
    pub(crate) fn from_retry(retry: &DelayedRetry) -> Self {
        let (zset_key, ttl) = match retry {
            DelayedRetry::DurableZset { key, ttl } => (key.to_owned(), *ttl),
        };
        Self { zset_key, ttl }
    }
}
87
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and saturates
/// at `u64::MAX` if the millisecond count does not fit in a `u64`.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
94
/// Converts a retry delay to whole milliseconds (truncating any sub-millisecond
/// remainder), saturating at `u64::MAX`.
fn delay_millis(delay: Duration) -> u64 {
    delay.as_millis().try_into().unwrap_or(u64::MAX)
}
98
/// Converts an epoch-millisecond timestamp into a sorted-set score (`f64`).
///
/// Values up to 2^53 convert exactly; larger values (only reachable with
/// enormous delays) are rounded to the nearest representable `f64`.
#[allow(
    clippy::cast_precision_loss,
    reason = "epoch-ms < 2^53 is exact in f64"
)]
fn as_score(ms: u64) -> f64 {
    ms as f64
}
108
/// Converts a TTL into the millisecond timeout passed to `PEXPIRE`.
///
/// Saturates at `i64::MAX` and never yields less than `1`. A sub-millisecond TTL
/// therefore still becomes a positive timeout rather than `0`; Redis deletes the
/// key outright when given a non-positive expiry.
fn ttl_millis(ttl: Duration) -> i64 {
    match i64::try_from(ttl.as_millis()) {
        // `as_millis` is unsigned, so 0 is the only value below the floor.
        Ok(0) => 1,
        Ok(ms) => ms,
        Err(_) => i64::MAX,
    }
}
113
114fn encode_member(id: &str, payload: &[u8], headers: &Headers) -> Vec<u8> {
118 let body = frame(None, payload, headers);
119 let id = id.as_bytes();
120 let id_len = u32::try_from(id.len()).unwrap_or(u32::MAX);
121 let mut buf = Vec::with_capacity(4 + id.len() + body.len());
122 buf.extend_from_slice(&id_len.to_be_bytes());
123 buf.extend_from_slice(id);
124 buf.extend_from_slice(&body);
125 buf
126}
127
128fn decode_member(member: &[u8]) -> Option<(Bytes, Headers)> {
130 let id_len = u32::from_be_bytes(member.get(0..4)?.try_into().ok()?) as usize;
131 let body = member.get(4usize.checked_add(id_len)?..)?;
132 Some(unframe(None, body))
133}
134
135fn next_retry_count(headers: &Headers) -> u64 {
136 headers
137 .get_str(RETRY_COUNT_HEADER)
138 .and_then(|v| v.parse::<u64>().ok())
139 .unwrap_or(0)
140 + 1
141}
142
143fn broker_err(err: fred::error::Error) -> AckError {
144 AckError::Broker(Box::new(err))
145}
146
147pub(crate) async fn schedule(
151 pool: &Pool,
152 cfg: &DelayConfig,
153 id: &str,
154 payload: &[u8],
155 headers: &Headers,
156 delay: Duration,
157) -> Result<(), AckError> {
158 let fire_at = now_ms().saturating_add(delay_millis(delay));
159
160 let mut next = headers.clone();
161 next.insert(RETRY_COUNT_HEADER, next_retry_count(headers).to_string());
162 let member = encode_member(id, payload, &next);
163
164 let _: i64 = pool
165 .zadd(
166 cfg.zset_key.as_str(),
167 None,
168 None,
169 false,
170 false,
171 (as_score(fire_at), member),
172 )
173 .await
174 .map_err(broker_err)?;
175 if let Some(ttl) = cfg.ttl {
176 let _: i64 = pool
177 .pexpire(cfg.zset_key.as_str(), ttl_millis(ttl), None)
178 .await
179 .map_err(broker_err)?;
180 }
181 Ok(())
182}
183
184pub(crate) async fn sweep_due(
190 pool: &Pool,
191 cfg: &DelayConfig,
192 stream_key: &str,
193) -> Result<(), RedisError> {
194 let now = as_score(now_ms());
195 let due: Vec<Bytes> = pool
196 .zrangebyscore(
197 cfg.zset_key.as_str(),
198 0.0,
199 now,
200 false,
201 Some((0, SWEEP_BATCH)),
202 )
203 .await
204 .map_err(RedisError::stream)?;
205
206 for member in due {
207 let removed: i64 = pool
208 .zrem(cfg.zset_key.as_str(), member.clone())
209 .await
210 .map_err(RedisError::stream)?;
211 if removed != 1 {
213 continue;
214 }
215 let Some((payload, headers)) = decode_member(&member) else {
216 continue;
217 };
218 let fields = fields_for_publish(&payload, &headers);
219 let _: String = pool
220 .xadd(stream_key, false, None::<()>, "*", fields)
221 .await
222 .map_err(RedisError::stream)?;
223 }
224 Ok(())
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn member_round_trips_payload_and_headers() {
        let mut headers = Headers::new();
        headers.insert("content-type", "application/json");
        headers.insert(RETRY_COUNT_HEADER, "2");

        let member = encode_member("1700000000000-0", b"{}", &headers);
        let (payload, decoded) = decode_member(&member).expect("decodes");
        assert_eq!(payload.as_ref(), b"{}");
        assert_eq!(decoded.content_type(), Some("application/json"));
        assert_eq!(decoded.get_str(RETRY_COUNT_HEADER), Some("2"));
    }

    #[test]
    fn distinct_ids_yield_distinct_members_for_equal_payloads() {
        let headers = Headers::new();
        let a = encode_member("1-0", b"dup", &headers);
        let b = encode_member("2-0", b"dup", &headers);
        assert_ne!(
            a, b,
            "the id salt must keep equal payloads from colliding in the ZSET"
        );
    }

    /// `sweep_due` relies on `decode_member` returning `None` (rather than
    /// panicking) for members too short to hold their declared id.
    #[test]
    fn decode_member_rejects_truncated_members() {
        // Shorter than the 4-byte id-length prefix.
        assert!(decode_member(&[]).is_none());
        assert!(decode_member(&[0, 0, 0]).is_none());
        // The prefix declares a 9-byte id, but only one byte follows it.
        assert!(decode_member(&[0, 0, 0, 9, b'a']).is_none());
    }

    #[test]
    fn next_retry_count_starts_at_one_and_increments() {
        let mut headers = Headers::new();
        assert_eq!(next_retry_count(&headers), 1);
        headers.insert(RETRY_COUNT_HEADER, "4");
        assert_eq!(next_retry_count(&headers), 5);
        headers.insert(RETRY_COUNT_HEADER, "not-a-number");
        // An unparsable count is treated as "no previous retries".
        assert_eq!(next_retry_count(&headers), 1);
    }

    #[test]
    fn ttl_millis_clamps_sub_millisecond_to_one() {
        assert_eq!(ttl_millis(Duration::from_secs(30)), 30_000);
        assert_eq!(ttl_millis(Duration::ZERO), 1);
    }

    #[test]
    fn millisecond_conversions_saturate() {
        assert_eq!(ttl_millis(Duration::MAX), i64::MAX);
        assert_eq!(delay_millis(Duration::MAX), u64::MAX);
    }
}