1use celers_core::{CelersError, Result};
10use redis::{Client, Script};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15use tracing::{debug, info, warn};
16
/// Lua script: non-blocking pop that registers the message as in-flight.
///
/// - `KEYS[1]`: source queue (a list); the message is taken from its tail with `RPOP`.
/// - `KEYS[2]`: unacked sorted set.
/// - `ARGV[1]`: score stored with the message in the unacked set (its visibility
///   deadline, compared against `ARGV[1]` of `RECOVER_TIMED_OUT`).
///
/// Returns the popped message, or nil when the queue is empty. `RPOP` and `ZADD`
/// run inside one script, so Redis applies them as a single atomic step.
pub const POP_WITH_VISIBILITY: &str = r#"
local queue = KEYS[1]
local unacked_set = KEYS[2]
local timeout_at = ARGV[1]

-- Pop from queue (non-blocking)
local msg = redis.call('RPOP', queue)

if msg then
    -- Add to unacked set with timeout score
    redis.call('ZADD', unacked_set, timeout_at, msg)
    return msg
end

return nil
"#;
47
/// Lua script: pop via `BRPOP` and register the message as in-flight.
///
/// - `KEYS[1]`: source queue (a list).
/// - `KEYS[2]`: unacked sorted set.
/// - `ARGV[1]`: score (visibility deadline) stored with the message in the unacked set.
/// - `ARGV[2]`: timeout argument passed straight through to `BRPOP`.
///
/// Returns the message (element 2 of the `[queue, message]` `BRPOP` reply), or nil.
///
/// NOTE(review): Redis executes scripts with blocking disabled, so `BRPOP` inside a
/// script returns immediately (nil on an empty queue) instead of waiting `ARGV[2]`
/// seconds. In practice this behaves like `POP_WITH_VISIBILITY`; callers that need
/// to wait must poll or block outside the script (see the Redis `BLPOP` docs).
pub const BRPOP_WITH_VISIBILITY: &str = r#"
local queue = KEYS[1]
local unacked_set = KEYS[2]
local timeout_at = ARGV[1]
local block_timeout = ARGV[2]

-- Blocking pop from queue
local result = redis.call('BRPOP', queue, block_timeout)

if result then
    local msg = result[2] -- BRPOP returns [queue_name, message]
    -- Add to unacked set with timeout score
    redis.call('ZADD', unacked_set, timeout_at, msg)
    return msg
end

return nil
"#;
76
/// Lua script: acknowledge a message by removing it from the unacked set.
///
/// - `KEYS[1]`: unacked sorted set.
/// - `ARGV[1]`: the exact message payload that was popped.
///
/// Returns the `ZREM` reply: 1 if the message was in-flight, 0 otherwise
/// (for example after it was already recovered or acknowledged).
pub const ACK_MESSAGE: &str = r#"
local unacked_set = KEYS[1]
local msg = ARGV[1]

return redis.call('ZREM', unacked_set, msg)
"#;
91
/// Lua script: negative-acknowledge a message.
///
/// The message is always removed from the unacked set, then either requeued or
/// dead-lettered.
///
/// - `KEYS[1]`: unacked sorted set.
/// - `KEYS[2]`: original queue.
/// - `KEYS[3]`: dead-letter queue.
/// - `ARGV[1]`: the message payload.
/// - `ARGV[2]`: `"1"` to requeue; any other value sends the message to the DLQ.
///
/// Returns `"requeued"` or `"dlq"`. The `ZREM` result is ignored, so the message is
/// pushed even if it was not in the unacked set. Both branches use `LPUSH` (list
/// head), while the pop scripts take from the tail with `RPOP`, so a requeued
/// message is served after the existing backlog.
pub const NACK_MESSAGE: &str = r#"
local unacked_set = KEYS[1]
local queue = KEYS[2]
local dlq = KEYS[3]
local msg = ARGV[1]
local requeue = ARGV[2]

-- Remove from unacked set
redis.call('ZREM', unacked_set, msg)

if requeue == "1" then
    -- Requeue to original queue
    redis.call('LPUSH', queue, msg)
    return "requeued"
else
    -- Send to dead letter queue
    redis.call('LPUSH', dlq, msg)
    return "dlq"
end
"#;
121
/// Lua script: move in-flight messages whose visibility deadline has passed back
/// onto the queue.
///
/// - `KEYS[1]`: unacked sorted set.
/// - `KEYS[2]`: queue to push recovered messages onto (with `LPUSH`).
/// - `ARGV[1]`: current time, in the same units as the scores written by the pop scripts.
/// - `ARGV[2]`: maximum number of messages to recover in this call (`LIMIT 0 <count>`).
///
/// Selects members with score `<= ARGV[1]`. The `ZRANGEBYSCORE` range is inclusive,
/// despite the "less than" wording in the Lua comment. Returns the number of
/// messages moved.
pub const RECOVER_TIMED_OUT: &str = r#"
local unacked_set = KEYS[1]
local queue = KEYS[2]
local current_time = ARGV[1]
local max_count = ARGV[2]

-- Get messages with score (timeout) less than current time
local messages = redis.call('ZRANGEBYSCORE', unacked_set, '-inf', current_time, 'LIMIT', 0, max_count)

if #messages > 0 then
    -- Remove from unacked set
    for i, msg in ipairs(messages) do
        redis.call('ZREM', unacked_set, msg)
        -- Requeue
        redis.call('LPUSH', queue, msg)
    end
    return #messages
end

return 0
"#;
154
/// Lua script: pop from the first non-empty queue in a priority-ordered list and
/// register the message as in-flight.
///
/// - `KEYS[1..n-1]`: queues, tried in order (the first one is checked first).
/// - `KEYS[n]`: unacked sorted set. `table.remove(KEYS)` takes the *last* key, so
///   the unacked set must always be passed last.
/// - `ARGV[1]`: score (visibility deadline) stored with the message.
///
/// Returns `{queue_name, message}` for the queue that yielded a message, or nil when
/// every queue is empty. With only one key, no queue is polled at all.
pub const POP_PRIORITY_WITH_VISIBILITY: &str = r#"
local unacked_set = table.remove(KEYS)
local timeout_at = ARGV[1]

-- Try each queue in order (high priority first)
for i, queue in ipairs(KEYS) do
    local msg = redis.call('RPOP', queue)
    if msg then
        -- Add to unacked set
        redis.call('ZADD', unacked_set, timeout_at, msg)
        return {queue, msg}
    end
end

return nil
"#;
180
/// Lua script: push a message onto the queue for its priority, following Kombu's
/// priority-queue naming convention.
///
/// - `KEYS[1]`: base queue name.
/// - `ARGV[1]`: priority, parsed with `tonumber`. A missing, non-numeric or
///   non-positive value selects the base queue.
/// - `ARGV[2]`: message payload.
///
/// A positive priority selects `<base>` + bytes `0x06 0x16` + `<priority>`.
/// Returns the name of the list the message was pushed onto.
///
/// The separator is built with `string.char(6, 22)` because Redis embeds Lua 5.1,
/// whose lexer has no `\xNN` escapes. The previous literal `'\x06\x16'` silently
/// became the text `x06x16`, so priority messages landed in queues that no
/// Kombu/Celery consumer reads.
///
/// NOTE(review): the destination key is derived inside the script instead of being
/// passed in `KEYS`, which Redis Cluster does not permit for keys it cannot route.
pub const ENQUEUE_WITH_PRIORITY: &str = r#"
local base_queue = KEYS[1]
local priority = tonumber(ARGV[1])
local msg = ARGV[2]

-- Kombu priority separator: the bytes 0x06 0x16. Built with string.char because
-- Redis embeds Lua 5.1, which does not support hexadecimal string escapes.
local priority_sep = string.char(6, 22)

local queue_name
if priority and priority > 0 then
    -- Kombu priority queue naming convention
    queue_name = base_queue .. priority_sep .. priority
else
    queue_name = base_queue
end

redis.call('LPUSH', queue_name, msg)
return queue_name
"#;
206
/// Version of the bundled Lua script set.
///
/// Reported by `ScriptManager::version` and copied into `ScriptStats::version`.
/// NOTE(review): presumably meant to be bumped whenever a script body changes — confirm.
pub const SCRIPT_VERSION: u32 = 1;
209
/// Identifies one of the Lua scripts managed by `ScriptManager`.
///
/// Used as the key of the manager's SHA, `Script` and performance caches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScriptId {
    /// `POP_WITH_VISIBILITY`: non-blocking pop plus registration in the unacked set.
    PopWithVisibility,
    /// `BRPOP_WITH_VISIBILITY`: `BRPOP`-based pop plus registration in the unacked set.
    BrpopWithVisibility,
    /// `ACK_MESSAGE`: removes a message from the unacked set.
    AckMessage,
    /// `NACK_MESSAGE`: removes a message from the unacked set and requeues or dead-letters it.
    NackMessage,
    /// `RECOVER_TIMED_OUT`: moves expired in-flight messages back to the queue.
    RecoverTimedOut,
    /// `POP_PRIORITY_WITH_VISIBILITY`: pops from the first non-empty queue of an ordered list.
    PopPriorityWithVisibility,
    /// `ENQUEUE_WITH_PRIORITY`: pushes onto the Kombu-named queue for a priority.
    EnqueueWithPriority,
}
228
229impl ScriptId {
230 pub fn source(&self) -> &'static str {
232 match self {
233 ScriptId::PopWithVisibility => POP_WITH_VISIBILITY,
234 ScriptId::BrpopWithVisibility => BRPOP_WITH_VISIBILITY,
235 ScriptId::AckMessage => ACK_MESSAGE,
236 ScriptId::NackMessage => NACK_MESSAGE,
237 ScriptId::RecoverTimedOut => RECOVER_TIMED_OUT,
238 ScriptId::PopPriorityWithVisibility => POP_PRIORITY_WITH_VISIBILITY,
239 ScriptId::EnqueueWithPriority => ENQUEUE_WITH_PRIORITY,
240 }
241 }
242
243 pub fn name(&self) -> &'static str {
245 match self {
246 ScriptId::PopWithVisibility => "pop_with_visibility",
247 ScriptId::BrpopWithVisibility => "brpop_with_visibility",
248 ScriptId::AckMessage => "ack_message",
249 ScriptId::NackMessage => "nack_message",
250 ScriptId::RecoverTimedOut => "recover_timed_out",
251 ScriptId::PopPriorityWithVisibility => "pop_priority_with_visibility",
252 ScriptId::EnqueueWithPriority => "enqueue_with_priority",
253 }
254 }
255
256 pub fn all() -> Vec<ScriptId> {
258 vec![
259 ScriptId::PopWithVisibility,
260 ScriptId::BrpopWithVisibility,
261 ScriptId::AckMessage,
262 ScriptId::NackMessage,
263 ScriptId::RecoverTimedOut,
264 ScriptId::PopPriorityWithVisibility,
265 ScriptId::EnqueueWithPriority,
266 ]
267 }
268}
269
/// Execution-time statistics collected for one Lua script.
#[derive(Debug, Clone, Default)]
pub struct ScriptPerformance {
    /// Number of recorded executions (saturates at `u64::MAX`).
    pub execution_count: u64,
    /// Sum of all recorded durations (saturates at `Duration::MAX`).
    pub total_duration: Duration,
    /// Shortest recorded duration; `None` until the first `record`.
    pub min_duration: Option<Duration>,
    /// Longest recorded duration; `None` until the first `record`.
    pub max_duration: Option<Duration>,
    /// `Instant` at which the most recent `record` call happened.
    pub last_execution: Option<Instant>,
}

impl ScriptPerformance {
    /// Mean duration per recorded execution, or `None` if nothing was recorded.
    ///
    /// The division is done on the nanosecond total as `u128`. Dividing the
    /// `Duration` by `execution_count as u32` would truncate the count, and would
    /// panic with a division by zero whenever the count is a multiple of 2^32.
    pub fn avg_duration(&self) -> Option<Duration> {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.execution_count == 0 {
            return None;
        }
        let avg_nanos = self.total_duration.as_nanos() / u128::from(self.execution_count);
        // The average never exceeds the total, so the seconds part fits in u64;
        // saturate anyway rather than panic.
        let secs = u64::try_from(avg_nanos / NANOS_PER_SEC).unwrap_or(u64::MAX);
        // The remainder is < 1_000_000_000, so it always fits in u32.
        let nanos = (avg_nanos % NANOS_PER_SEC) as u32;
        Some(Duration::new(secs, nanos))
    }

    /// Records one execution that took `duration`.
    ///
    /// Updates the count, total, min, max and last-execution timestamp.
    pub fn record(&mut self, duration: Duration) {
        // Saturate instead of panicking: `Duration`'s `+=` panics on overflow.
        self.execution_count = self.execution_count.saturating_add(1);
        self.total_duration = self.total_duration.saturating_add(duration);
        self.last_execution = Some(Instant::now());

        self.min_duration = Some(self.min_duration.map_or(duration, |min| min.min(duration)));
        self.max_duration = Some(self.max_duration.map_or(duration, |max| max.max(duration)));
    }

    /// Clears all statistics back to their `Default` state.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
318}
319
320pub struct ScriptManager {
322 client: Client,
323 sha_cache: Arc<RwLock<HashMap<ScriptId, String>>>,
325 script_cache: Arc<RwLock<HashMap<ScriptId, Script>>>,
327 performance: Arc<RwLock<HashMap<ScriptId, ScriptPerformance>>>,
329 version: u32,
331}
332
333impl ScriptManager {
334 pub fn new(client: Client) -> Self {
336 Self {
337 client,
338 sha_cache: Arc::new(RwLock::new(HashMap::new())),
339 script_cache: Arc::new(RwLock::new(HashMap::new())),
340 performance: Arc::new(RwLock::new(HashMap::new())),
341 version: SCRIPT_VERSION,
342 }
343 }
344
345 pub fn version(&self) -> u32 {
347 self.version
348 }
349
350 pub async fn record_execution(&self, script_id: ScriptId, duration: Duration) {
352 let mut perf = self.performance.write().await;
353 perf.entry(script_id).or_default().record(duration);
354
355 if duration.as_millis() > 100 {
357 warn!(
358 "Slow script execution: {} took {}ms",
359 script_id.name(),
360 duration.as_millis()
361 );
362 }
363 }
364
365 pub async fn get_performance(&self, script_id: ScriptId) -> Option<ScriptPerformance> {
367 self.performance.read().await.get(&script_id).cloned()
368 }
369
370 pub async fn get_all_performance(&self) -> HashMap<ScriptId, ScriptPerformance> {
372 self.performance.read().await.clone()
373 }
374
375 pub async fn reset_performance(&self, script_id: ScriptId) {
377 let mut perf = self.performance.write().await;
378 if let Some(p) = perf.get_mut(&script_id) {
379 p.reset();
380 }
381 }
382
383 pub async fn reset_all_performance(&self) {
385 let mut perf = self.performance.write().await;
386 for p in perf.values_mut() {
387 p.reset();
388 }
389 }
390
391 pub async fn load_all(&self) -> Result<()> {
393 let mut conn = self
394 .client
395 .get_multiplexed_async_connection()
396 .await
397 .map_err(|e| CelersError::Broker(format!("Failed to get connection: {}", e)))?;
398
399 let mut sha_cache = self.sha_cache.write().await;
400 let mut script_cache = self.script_cache.write().await;
401
402 for script_id in ScriptId::all() {
403 let source = script_id.source();
404 let script = Script::new(source);
405
406 let sha: String = redis::cmd("SCRIPT")
408 .arg("LOAD")
409 .arg(source)
410 .query_async(&mut conn)
411 .await
412 .map_err(|e| {
413 CelersError::Broker(format!(
414 "Failed to load script {}: {}",
415 script_id.name(),
416 e
417 ))
418 })?;
419
420 debug!("Loaded script {} with SHA: {}", script_id.name(), sha);
421
422 sha_cache.insert(script_id, sha);
423 script_cache.insert(script_id, script);
424 }
425
426 info!("Loaded {} Lua scripts into Redis", ScriptId::all().len());
427
428 Ok(())
429 }
430
431 pub async fn get_sha(&self, script_id: ScriptId) -> Option<String> {
433 self.sha_cache.read().await.get(&script_id).cloned()
434 }
435
436 pub async fn get_script(&self, script_id: ScriptId) -> Option<Script> {
438 self.script_cache.read().await.get(&script_id).cloned()
439 }
440
441 pub async fn load_script(&self, script_id: ScriptId) -> Result<String> {
443 let mut conn = self
444 .client
445 .get_multiplexed_async_connection()
446 .await
447 .map_err(|e| CelersError::Broker(format!("Failed to get connection: {}", e)))?;
448
449 let source = script_id.source();
450 let script = Script::new(source);
451
452 let sha: String = redis::cmd("SCRIPT")
454 .arg("LOAD")
455 .arg(source)
456 .query_async(&mut conn)
457 .await
458 .map_err(|e| {
459 CelersError::Broker(format!("Failed to load script {}: {}", script_id.name(), e))
460 })?;
461
462 debug!("Loaded script {} with SHA: {}", script_id.name(), sha);
463
464 let mut sha_cache = self.sha_cache.write().await;
466 let mut script_cache = self.script_cache.write().await;
467
468 sha_cache.insert(script_id, sha.clone());
469 script_cache.insert(script_id, script);
470
471 Ok(sha)
472 }
473
474 pub async fn is_loaded(&self, script_id: ScriptId) -> Result<bool> {
476 let sha = match self.get_sha(script_id).await {
477 Some(sha) => sha,
478 None => return Ok(false),
479 };
480
481 let mut conn = self
482 .client
483 .get_multiplexed_async_connection()
484 .await
485 .map_err(|e| CelersError::Broker(format!("Failed to get connection: {}", e)))?;
486
487 let exists: Vec<bool> = redis::cmd("SCRIPT")
488 .arg("EXISTS")
489 .arg(&sha)
490 .query_async(&mut conn)
491 .await
492 .map_err(|e| CelersError::Broker(format!("Failed to check script: {}", e)))?;
493
494 Ok(exists.first().copied().unwrap_or(false))
495 }
496
497 pub async fn clear_cache(&self) {
499 let mut sha_cache = self.sha_cache.write().await;
500 let mut script_cache = self.script_cache.write().await;
501
502 sha_cache.clear();
503 script_cache.clear();
504
505 debug!("Cleared script cache");
506 }
507
508 pub async fn stats(&self) -> ScriptStats {
510 let sha_cache = self.sha_cache.read().await;
511 let script_cache = self.script_cache.read().await;
512 let perf = self.performance.read().await;
513
514 let total_executions: u64 = perf.values().map(|p| p.execution_count).sum();
515
516 ScriptStats {
517 total_scripts: ScriptId::all().len(),
518 loaded_scripts: sha_cache.len(),
519 cached_scripts: script_cache.len(),
520 version: self.version,
521 total_executions,
522 }
523 }
524}
525
/// Point-in-time summary of a script manager's caches and counters.
#[derive(Debug, Clone)]
pub struct ScriptStats {
    /// Number of scripts known to `ScriptId::all()`.
    pub total_scripts: usize,
    /// Number of cached SHA1 digests.
    pub loaded_scripts: usize,
    /// Number of cached `Script` wrappers.
    pub cached_scripts: usize,
    /// Script-set version reported by the manager.
    pub version: u32,
    /// Sum of `execution_count` over all tracked scripts.
    pub total_executions: u64,
}

impl ScriptStats {
    /// `true` when a SHA is cached for every known script.
    pub fn all_loaded(&self) -> bool {
        let Self {
            total_scripts,
            loaded_scripts,
            ..
        } = self;
        total_scripts == loaded_scripts
    }
}
547
// Unit tests for the script constants, the `ScriptId` helpers, `ScriptStats` and
// `ScriptPerformance`. None of them talk to Redis.
#[cfg(test)]
mod tests {
    use super::*;

    // Every script constant must have a non-empty body.
    // Clippy sees that the operands are consts and flags `is_empty` as redundant;
    // the assertions are kept as a guard against an accidentally emptied script.
    #[test]
    #[allow(clippy::const_is_empty)]
    fn test_scripts_are_valid() {
        assert!(!POP_WITH_VISIBILITY.is_empty());
        assert!(!BRPOP_WITH_VISIBILITY.is_empty());
        assert!(!ACK_MESSAGE.is_empty());
        assert!(!NACK_MESSAGE.is_empty());
        assert!(!RECOVER_TIMED_OUT.is_empty());
        assert!(!POP_PRIORITY_WITH_VISIBILITY.is_empty());
        assert!(!ENQUEUE_WITH_PRIORITY.is_empty());
    }

    // Spot-check that scripts call the Redis commands they are built around.
    #[test]
    fn test_script_syntax() {
        assert!(POP_WITH_VISIBILITY.contains("RPOP"));
        assert!(POP_WITH_VISIBILITY.contains("ZADD"));

        assert!(ACK_MESSAGE.contains("ZREM"));

        assert!(NACK_MESSAGE.contains("LPUSH"));

        assert!(RECOVER_TIMED_OUT.contains("ZRANGEBYSCORE"));
    }

    // `ScriptId::source` must map each variant to its matching constant.
    #[test]
    fn test_script_id_source() {
        assert_eq!(ScriptId::PopWithVisibility.source(), POP_WITH_VISIBILITY);
        assert_eq!(
            ScriptId::BrpopWithVisibility.source(),
            BRPOP_WITH_VISIBILITY
        );
        assert_eq!(ScriptId::AckMessage.source(), ACK_MESSAGE);
        assert_eq!(ScriptId::NackMessage.source(), NACK_MESSAGE);
        assert_eq!(ScriptId::RecoverTimedOut.source(), RECOVER_TIMED_OUT);
        assert_eq!(
            ScriptId::PopPriorityWithVisibility.source(),
            POP_PRIORITY_WITH_VISIBILITY
        );
        assert_eq!(
            ScriptId::EnqueueWithPriority.source(),
            ENQUEUE_WITH_PRIORITY
        );
    }

    // `ScriptId::name` labels appear in logs and errors, so pin them exactly.
    #[test]
    fn test_script_id_name() {
        assert_eq!(ScriptId::PopWithVisibility.name(), "pop_with_visibility");
        assert_eq!(
            ScriptId::BrpopWithVisibility.name(),
            "brpop_with_visibility"
        );
        assert_eq!(ScriptId::AckMessage.name(), "ack_message");
        assert_eq!(ScriptId::NackMessage.name(), "nack_message");
        assert_eq!(ScriptId::RecoverTimedOut.name(), "recover_timed_out");
        assert_eq!(
            ScriptId::PopPriorityWithVisibility.name(),
            "pop_priority_with_visibility"
        );
        assert_eq!(
            ScriptId::EnqueueWithPriority.name(),
            "enqueue_with_priority"
        );
    }

    // `ScriptId::all` must list each variant exactly once (7 in total).
    #[test]
    fn test_script_id_all() {
        let all_scripts = ScriptId::all();
        assert_eq!(all_scripts.len(), 7);
        assert!(all_scripts.contains(&ScriptId::PopWithVisibility));
        assert!(all_scripts.contains(&ScriptId::BrpopWithVisibility));
        assert!(all_scripts.contains(&ScriptId::AckMessage));
        assert!(all_scripts.contains(&ScriptId::NackMessage));
        assert!(all_scripts.contains(&ScriptId::RecoverTimedOut));
        assert!(all_scripts.contains(&ScriptId::PopPriorityWithVisibility));
        assert!(all_scripts.contains(&ScriptId::EnqueueWithPriority));
    }

    // `all_loaded` compares loaded SHAs with the total script count.
    #[test]
    fn test_script_stats() {
        let stats = ScriptStats {
            total_scripts: 7,
            loaded_scripts: 7,
            cached_scripts: 7,
            version: SCRIPT_VERSION,
            total_executions: 0,
        };

        assert!(stats.all_loaded());
        assert_eq!(stats.version, SCRIPT_VERSION);

        let stats_incomplete = ScriptStats {
            total_scripts: 7,
            loaded_scripts: 5,
            cached_scripts: 5,
            version: SCRIPT_VERSION,
            total_executions: 0,
        };

        assert!(!stats_incomplete.all_loaded());
    }

    // Count, average, min and max after two samples, then a full reset.
    #[test]
    fn test_script_performance() {
        let mut perf = ScriptPerformance::default();
        assert_eq!(perf.execution_count, 0);
        assert_eq!(perf.avg_duration(), None);

        perf.record(Duration::from_millis(10));
        assert_eq!(perf.execution_count, 1);
        assert_eq!(perf.avg_duration(), Some(Duration::from_millis(10)));
        assert_eq!(perf.min_duration, Some(Duration::from_millis(10)));
        assert_eq!(perf.max_duration, Some(Duration::from_millis(10)));

        perf.record(Duration::from_millis(20));
        assert_eq!(perf.execution_count, 2);
        assert_eq!(perf.avg_duration(), Some(Duration::from_millis(15)));
        assert_eq!(perf.min_duration, Some(Duration::from_millis(10)));
        assert_eq!(perf.max_duration, Some(Duration::from_millis(20)));

        perf.reset();
        assert_eq!(perf.execution_count, 0);
        assert_eq!(perf.avg_duration(), None);
    }

    // Guards against an unintended change of the published script-set version.
    #[test]
    fn test_script_version() {
        assert_eq!(SCRIPT_VERSION, 1);
    }
}