Skip to main content

celers_broker_redis/
lua_scripts.rs

1//! Lua scripts for atomic Redis operations (Kombu compatibility)
2//!
3//! These scripts ensure atomicity for complex operations that cannot be
4//! achieved with single Redis commands.
5//!
6//! The ScriptManager handles script loading and caching using SCRIPT LOAD
7//! for optimal performance.
8
9use celers_core::{CelersError, Result};
10use redis::{Client, Script};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15use tracing::{debug, info, warn};
16
/// Pop from queue with visibility timeout (Kombu-compatible)
///
/// This script atomically:
/// 1. Pops a message from the tail of the queue (RPOP)
/// 2. Adds it to the unacked sorted set, using `ARGV[1]` as the score (ZADD)
///
/// This ensures that if a worker crashes, the message can be recovered
/// after the visibility timeout expires.
///
/// The sorted-set member is the raw message payload itself, so two
/// byte-identical payloads in flight at the same time share one unacked entry.
///
/// `KEYS[1]`: queue name (e.g., "celery")
/// `KEYS[2]`: unacked set name (e.g., "celery:unacked")
/// `ARGV[1]`: visibility timeout (Unix timestamp)
///
/// Returns: message data, or nil when the queue is empty
pub const POP_WITH_VISIBILITY: &str = r#"
local queue = KEYS[1]
local unacked_set = KEYS[2]
local timeout_at = ARGV[1]

-- Pop from queue (non-blocking)
local msg = redis.call('RPOP', queue)

if msg then
    -- Add to unacked set with timeout score
    redis.call('ZADD', unacked_set, timeout_at, msg)
    return msg
end

return nil
"#;
47
/// Blocking pop with visibility timeout
///
/// Like POP_WITH_VISIBILITY but issues BRPOP instead of RPOP, then records
/// the popped message in the unacked sorted set with `ARGV[1]` as the score.
///
/// NOTE(review): Redis does not let a script pause mid-execution, so a
/// `BRPOP` issued through `redis.call` likely cannot wait `ARGV[2]` seconds
/// the way a client-issued `BRPOP` does (depending on the server version it
/// is either rejected or behaves as an immediate, non-blocking attempt).
/// Confirm against the Redis versions this broker targets.
///
/// `KEYS[1]`: queue name
/// `KEYS[2]`: unacked set name
/// `ARGV[1]`: visibility timeout (Unix timestamp)
/// `ARGV[2]`: block timeout in seconds
///
/// Returns: message data or nil
pub const BRPOP_WITH_VISIBILITY: &str = r#"
local queue = KEYS[1]
local unacked_set = KEYS[2]
local timeout_at = ARGV[1]
local block_timeout = ARGV[2]

-- Blocking pop from queue
local result = redis.call('BRPOP', queue, block_timeout)

if result then
    local msg = result[2]  -- BRPOP returns [queue_name, message]
    -- Add to unacked set with timeout score
    redis.call('ZADD', unacked_set, timeout_at, msg)
    return msg
end

return nil
"#;
76
/// Acknowledge (ACK) a message
///
/// Removes the message from the unacked set with a single ZREM. The message
/// is identified by its payload, which must match the stored member exactly.
///
/// `KEYS[1]`: unacked set name
/// `ARGV[1]`: message data
///
/// Returns: 1 if removed, 0 if not found
pub const ACK_MESSAGE: &str = r#"
local unacked_set = KEYS[1]
local msg = ARGV[1]

return redis.call('ZREM', unacked_set, msg)
"#;
91
/// Reject a message (NACK) with optional requeue
///
/// Removes the message from the unacked set, then pushes it (LPUSH) onto
/// either the original queue or the dead letter queue. The push happens
/// regardless of whether the message was actually present in the unacked set.
///
/// `KEYS[1]`: unacked set name
/// `KEYS[2]`: queue name (for requeue)
/// `KEYS[3]`: dead letter queue name
/// `ARGV[1]`: message data
/// `ARGV[2]`: requeue flag ("1" = requeue, any other value = send to DLQ)
///
/// Returns: "requeued" or "dlq"
pub const NACK_MESSAGE: &str = r#"
local unacked_set = KEYS[1]
local queue = KEYS[2]
local dlq = KEYS[3]
local msg = ARGV[1]
local requeue = ARGV[2]

-- Remove from unacked set
redis.call('ZREM', unacked_set, msg)

if requeue == "1" then
    -- Requeue to original queue
    redis.call('LPUSH', queue, msg)
    return "requeued"
else
    -- Send to dead letter queue
    redis.call('LPUSH', dlq, msg)
    return "dlq"
end
"#;
121
/// Recover timed-out messages
///
/// Moves messages from the unacked set back to the queue once their
/// visibility timeout has been reached: every member whose score is less
/// than or equal to `ARGV[1]` is selected (at most `ARGV[2]` per call),
/// removed from the unacked set, and pushed onto the queue with LPUSH.
///
/// `KEYS[1]`: unacked set name
/// `KEYS[2]`: queue name
/// `ARGV[1]`: current time (Unix timestamp)
/// `ARGV[2]`: max messages to recover
///
/// Returns: number of messages recovered (0 when nothing has timed out)
pub const RECOVER_TIMED_OUT: &str = r#"
local unacked_set = KEYS[1]
local queue = KEYS[2]
local current_time = ARGV[1]
local max_count = ARGV[2]

-- Get messages with score (timeout) less than current time
local messages = redis.call('ZRANGEBYSCORE', unacked_set, '-inf', current_time, 'LIMIT', 0, max_count)

if #messages > 0 then
    -- Remove from unacked set
    for i, msg in ipairs(messages) do
        redis.call('ZREM', unacked_set, msg)
        -- Requeue
        redis.call('LPUSH', queue, msg)
    end
    return #messages
end

return 0
"#;
154
/// Priority queue pop
///
/// Pops from multiple queues in priority order: the queues are tried with
/// RPOP in the order given and the first message found wins. That message is
/// added to the unacked sorted set with `ARGV[1]` as its score.
///
/// The unacked set name must be the LAST key; the script detaches it from
/// `KEYS` with `table.remove` before iterating over the remaining queues.
///
/// `KEYS[1..N]`: queue names in priority order (high to low)
/// `KEYS[N+1]`: unacked set name
/// `ARGV[1]`: visibility timeout
///
/// Returns: [queue_name, message] or nil when every queue is empty
pub const POP_PRIORITY_WITH_VISIBILITY: &str = r#"
local unacked_set = table.remove(KEYS)
local timeout_at = ARGV[1]

-- Try each queue in order (high priority first)
for i, queue in ipairs(KEYS) do
    local msg = redis.call('RPOP', queue)
    if msg then
        -- Add to unacked set
        redis.call('ZADD', unacked_set, timeout_at, msg)
        return {queue, msg}
    end
end

return nil
"#;
180
/// Enqueue with priority
///
/// Adds a message (LPUSH) to the appropriate priority queue. A priority that
/// parses to a number greater than zero selects the queue named
/// `<base queue><0x06><0x16><priority>` (the Kombu separator bytes); a
/// missing, non-numeric, zero, or negative priority uses the base queue.
///
/// The separator is spelled with decimal escapes (`\006\022`) because the
/// Lua 5.1 interpreter embedded in Redis has no hexadecimal string escapes;
/// a hex-style spelling would be read as plain text instead of the two
/// control bytes, producing queue names Kombu consumers never look at.
///
/// `KEYS[1]`: base queue name
/// `ARGV[1]`: priority (documented range 0-9; the script does not enforce it)
/// `ARGV[2]`: message data
///
/// Returns: queue name used
pub const ENQUEUE_WITH_PRIORITY: &str = r#"
local base_queue = KEYS[1]
local priority = tonumber(ARGV[1])
local msg = ARGV[2]

local queue_name
if priority and priority > 0 then
    -- Kombu priority queue naming convention (separator bytes 0x06 0x16,
    -- written as decimal escapes because Lua 5.1 lacks hex escapes)
    queue_name = base_queue .. '\006\022' .. priority
else
    queue_name = base_queue
end

redis.call('LPUSH', queue_name, msg)
return queue_name
"#;
206
/// Current script version (increment when scripts change)
///
/// Copied into every `ScriptManager` at construction and surfaced through
/// `ScriptManager::version` and `ScriptStats::version`.
pub const SCRIPT_VERSION: u32 = 1;
209
210/// Script identifier for easy lookup
211#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
212pub enum ScriptId {
213    /// Pop with visibility timeout
214    PopWithVisibility,
215    /// Blocking pop with visibility timeout
216    BrpopWithVisibility,
217    /// Acknowledge message
218    AckMessage,
219    /// Reject message
220    NackMessage,
221    /// Recover timed-out messages
222    RecoverTimedOut,
223    /// Priority queue pop with visibility
224    PopPriorityWithVisibility,
225    /// Enqueue with priority
226    EnqueueWithPriority,
227}
228
229impl ScriptId {
230    /// Get the script source code
231    pub fn source(&self) -> &'static str {
232        match self {
233            ScriptId::PopWithVisibility => POP_WITH_VISIBILITY,
234            ScriptId::BrpopWithVisibility => BRPOP_WITH_VISIBILITY,
235            ScriptId::AckMessage => ACK_MESSAGE,
236            ScriptId::NackMessage => NACK_MESSAGE,
237            ScriptId::RecoverTimedOut => RECOVER_TIMED_OUT,
238            ScriptId::PopPriorityWithVisibility => POP_PRIORITY_WITH_VISIBILITY,
239            ScriptId::EnqueueWithPriority => ENQUEUE_WITH_PRIORITY,
240        }
241    }
242
243    /// Get a human-readable name
244    pub fn name(&self) -> &'static str {
245        match self {
246            ScriptId::PopWithVisibility => "pop_with_visibility",
247            ScriptId::BrpopWithVisibility => "brpop_with_visibility",
248            ScriptId::AckMessage => "ack_message",
249            ScriptId::NackMessage => "nack_message",
250            ScriptId::RecoverTimedOut => "recover_timed_out",
251            ScriptId::PopPriorityWithVisibility => "pop_priority_with_visibility",
252            ScriptId::EnqueueWithPriority => "enqueue_with_priority",
253        }
254    }
255
256    /// Get all script IDs
257    pub fn all() -> Vec<ScriptId> {
258        vec![
259            ScriptId::PopWithVisibility,
260            ScriptId::BrpopWithVisibility,
261            ScriptId::AckMessage,
262            ScriptId::NackMessage,
263            ScriptId::RecoverTimedOut,
264            ScriptId::PopPriorityWithVisibility,
265            ScriptId::EnqueueWithPriority,
266        ]
267    }
268}
269
/// Performance metrics for a single script
#[derive(Debug, Clone, Default)]
pub struct ScriptPerformance {
    /// Number of times the script was executed
    pub execution_count: u64,
    /// Total execution time
    pub total_duration: Duration,
    /// Minimum execution time
    pub min_duration: Option<Duration>,
    /// Maximum execution time
    pub max_duration: Option<Duration>,
    /// Moment the most recent execution was recorded
    pub last_execution: Option<Instant>,
}

impl ScriptPerformance {
    /// Get average execution time
    ///
    /// Returns `None` until at least one execution has been recorded. The
    /// division is carried out on the full 64-bit execution count in
    /// nanoseconds, so counts above `u32::MAX` neither truncate nor divide
    /// by zero.
    pub fn avg_duration(&self) -> Option<Duration> {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.execution_count == 0 {
            return None;
        }

        let avg_nanos = self.total_duration.as_nanos() / u128::from(self.execution_count);
        // The average never exceeds the total, so the seconds part fits in u64
        // and the remainder is below one second; both casts are lossless.
        Some(Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        ))
    }

    /// Record a new execution
    ///
    /// Updates the count, the running total, the min/max bounds, and the
    /// last-execution timestamp. Count and total saturate instead of
    /// panicking on overflow.
    pub fn record(&mut self, duration: Duration) {
        self.execution_count = self.execution_count.saturating_add(1);
        self.total_duration = self.total_duration.saturating_add(duration);
        self.last_execution = Some(Instant::now());

        self.min_duration = Some(
            self.min_duration
                .map_or(duration, |current| current.min(duration)),
        );
        self.max_duration = Some(
            self.max_duration
                .map_or(duration, |current| current.max(duration)),
        );
    }

    /// Reset all statistics
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
319
320/// Manages Lua script loading and caching
321pub struct ScriptManager {
322    client: Client,
323    /// Maps script ID to SHA1 hash
324    sha_cache: Arc<RwLock<HashMap<ScriptId, String>>>,
325    /// Maps script ID to Script object
326    script_cache: Arc<RwLock<HashMap<ScriptId, Script>>>,
327    /// Performance tracking for each script
328    performance: Arc<RwLock<HashMap<ScriptId, ScriptPerformance>>>,
329    /// Script version
330    version: u32,
331}
332
333impl ScriptManager {
334    /// Create a new script manager
335    pub fn new(client: Client) -> Self {
336        Self {
337            client,
338            sha_cache: Arc::new(RwLock::new(HashMap::new())),
339            script_cache: Arc::new(RwLock::new(HashMap::new())),
340            performance: Arc::new(RwLock::new(HashMap::new())),
341            version: SCRIPT_VERSION,
342        }
343    }
344
345    /// Get the script version
346    pub fn version(&self) -> u32 {
347        self.version
348    }
349
350    /// Record script execution performance
351    pub async fn record_execution(&self, script_id: ScriptId, duration: Duration) {
352        let mut perf = self.performance.write().await;
353        perf.entry(script_id).or_default().record(duration);
354
355        // Warn if execution is slow
356        if duration.as_millis() > 100 {
357            warn!(
358                "Slow script execution: {} took {}ms",
359                script_id.name(),
360                duration.as_millis()
361            );
362        }
363    }
364
365    /// Get performance metrics for a script
366    pub async fn get_performance(&self, script_id: ScriptId) -> Option<ScriptPerformance> {
367        self.performance.read().await.get(&script_id).cloned()
368    }
369
370    /// Get all performance metrics
371    pub async fn get_all_performance(&self) -> HashMap<ScriptId, ScriptPerformance> {
372        self.performance.read().await.clone()
373    }
374
375    /// Reset performance metrics for a specific script
376    pub async fn reset_performance(&self, script_id: ScriptId) {
377        let mut perf = self.performance.write().await;
378        if let Some(p) = perf.get_mut(&script_id) {
379            p.reset();
380        }
381    }
382
383    /// Reset all performance metrics
384    pub async fn reset_all_performance(&self) {
385        let mut perf = self.performance.write().await;
386        for p in perf.values_mut() {
387            p.reset();
388        }
389    }
390
391    /// Load all scripts into Redis and cache their SHA1 hashes
392    pub async fn load_all(&self) -> Result<()> {
393        let mut conn = self
394            .client
395            .get_multiplexed_async_connection()
396            .await
397            .map_err(|e| CelersError::Broker(format!("Failed to get connection: {}", e)))?;
398
399        let mut sha_cache = self.sha_cache.write().await;
400        let mut script_cache = self.script_cache.write().await;
401
402        for script_id in ScriptId::all() {
403            let source = script_id.source();
404            let script = Script::new(source);
405
406            // Load script and get SHA1
407            let sha: String = redis::cmd("SCRIPT")
408                .arg("LOAD")
409                .arg(source)
410                .query_async(&mut conn)
411                .await
412                .map_err(|e| {
413                    CelersError::Broker(format!(
414                        "Failed to load script {}: {}",
415                        script_id.name(),
416                        e
417                    ))
418                })?;
419
420            debug!("Loaded script {} with SHA: {}", script_id.name(), sha);
421
422            sha_cache.insert(script_id, sha);
423            script_cache.insert(script_id, script);
424        }
425
426        info!("Loaded {} Lua scripts into Redis", ScriptId::all().len());
427
428        Ok(())
429    }
430
431    /// Get the SHA1 hash for a script
432    pub async fn get_sha(&self, script_id: ScriptId) -> Option<String> {
433        self.sha_cache.read().await.get(&script_id).cloned()
434    }
435
436    /// Get a Script object for execution
437    pub async fn get_script(&self, script_id: ScriptId) -> Option<Script> {
438        self.script_cache.read().await.get(&script_id).cloned()
439    }
440
441    /// Load a single script
442    pub async fn load_script(&self, script_id: ScriptId) -> Result<String> {
443        let mut conn = self
444            .client
445            .get_multiplexed_async_connection()
446            .await
447            .map_err(|e| CelersError::Broker(format!("Failed to get connection: {}", e)))?;
448
449        let source = script_id.source();
450        let script = Script::new(source);
451
452        // Load script and get SHA1
453        let sha: String = redis::cmd("SCRIPT")
454            .arg("LOAD")
455            .arg(source)
456            .query_async(&mut conn)
457            .await
458            .map_err(|e| {
459                CelersError::Broker(format!("Failed to load script {}: {}", script_id.name(), e))
460            })?;
461
462        debug!("Loaded script {} with SHA: {}", script_id.name(), sha);
463
464        // Update caches
465        let mut sha_cache = self.sha_cache.write().await;
466        let mut script_cache = self.script_cache.write().await;
467
468        sha_cache.insert(script_id, sha.clone());
469        script_cache.insert(script_id, script);
470
471        Ok(sha)
472    }
473
474    /// Check if a script is loaded in Redis
475    pub async fn is_loaded(&self, script_id: ScriptId) -> Result<bool> {
476        let sha = match self.get_sha(script_id).await {
477            Some(sha) => sha,
478            None => return Ok(false),
479        };
480
481        let mut conn = self
482            .client
483            .get_multiplexed_async_connection()
484            .await
485            .map_err(|e| CelersError::Broker(format!("Failed to get connection: {}", e)))?;
486
487        let exists: Vec<bool> = redis::cmd("SCRIPT")
488            .arg("EXISTS")
489            .arg(&sha)
490            .query_async(&mut conn)
491            .await
492            .map_err(|e| CelersError::Broker(format!("Failed to check script: {}", e)))?;
493
494        Ok(exists.first().copied().unwrap_or(false))
495    }
496
497    /// Clear the script cache (useful for testing or after Redis restart)
498    pub async fn clear_cache(&self) {
499        let mut sha_cache = self.sha_cache.write().await;
500        let mut script_cache = self.script_cache.write().await;
501
502        sha_cache.clear();
503        script_cache.clear();
504
505        debug!("Cleared script cache");
506    }
507
508    /// Get statistics about loaded scripts
509    pub async fn stats(&self) -> ScriptStats {
510        let sha_cache = self.sha_cache.read().await;
511        let script_cache = self.script_cache.read().await;
512        let perf = self.performance.read().await;
513
514        let total_executions: u64 = perf.values().map(|p| p.execution_count).sum();
515
516        ScriptStats {
517            total_scripts: ScriptId::all().len(),
518            loaded_scripts: sha_cache.len(),
519            cached_scripts: script_cache.len(),
520            version: self.version,
521            total_executions,
522        }
523    }
524}
525
/// Snapshot of the script manager's cache and execution counters
#[derive(Debug, Clone)]
pub struct ScriptStats {
    /// How many scripts exist in total
    pub total_scripts: usize,
    /// How many scripts have a SHA1 recorded
    pub loaded_scripts: usize,
    /// How many Script objects are held in memory
    pub cached_scripts: usize,
    /// Script version
    pub version: u32,
    /// Executions summed across every script
    pub total_executions: u64,
}

impl ScriptStats {
    /// True when the loaded count equals the total script count
    pub fn all_loaded(&self) -> bool {
        let Self {
            total_scripts,
            loaded_scripts,
            ..
        } = self;
        total_scripts == loaded_scripts
    }
}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Every script identifier paired with the constant it must resolve to
    const CATALOG: [(ScriptId, &str); 7] = [
        (ScriptId::PopWithVisibility, POP_WITH_VISIBILITY),
        (ScriptId::BrpopWithVisibility, BRPOP_WITH_VISIBILITY),
        (ScriptId::AckMessage, ACK_MESSAGE),
        (ScriptId::NackMessage, NACK_MESSAGE),
        (ScriptId::RecoverTimedOut, RECOVER_TIMED_OUT),
        (
            ScriptId::PopPriorityWithVisibility,
            POP_PRIORITY_WITH_VISIBILITY,
        ),
        (ScriptId::EnqueueWithPriority, ENQUEUE_WITH_PRIORITY),
    ];

    #[test]
    fn test_scripts_are_valid() {
        // No script constant may be empty
        for (_, lua) in CATALOG.iter() {
            assert!(!lua.is_empty());
        }
    }

    #[test]
    fn test_script_syntax() {
        // Smoke check: each script mentions the Redis command it is built around
        let expectations = [
            (POP_WITH_VISIBILITY, "RPOP"),
            (POP_WITH_VISIBILITY, "ZADD"),
            (ACK_MESSAGE, "ZREM"),
            (NACK_MESSAGE, "LPUSH"),
            (RECOVER_TIMED_OUT, "ZRANGEBYSCORE"),
        ];

        for (lua, command) in expectations.iter() {
            assert!(lua.contains(command));
        }
    }

    #[test]
    fn test_script_id_source() {
        for (id, lua) in CATALOG.iter() {
            assert_eq!(id.source(), *lua);
        }
    }

    #[test]
    fn test_script_id_name() {
        let labels = [
            (ScriptId::PopWithVisibility, "pop_with_visibility"),
            (ScriptId::BrpopWithVisibility, "brpop_with_visibility"),
            (ScriptId::AckMessage, "ack_message"),
            (ScriptId::NackMessage, "nack_message"),
            (ScriptId::RecoverTimedOut, "recover_timed_out"),
            (
                ScriptId::PopPriorityWithVisibility,
                "pop_priority_with_visibility",
            ),
            (ScriptId::EnqueueWithPriority, "enqueue_with_priority"),
        ];

        for (id, label) in labels.iter() {
            assert_eq!(id.name(), *label);
        }
    }

    #[test]
    fn test_script_id_all() {
        let listed = ScriptId::all();
        assert_eq!(listed.len(), 7);

        for (id, _) in CATALOG.iter() {
            assert!(listed.contains(id));
        }
    }

    #[test]
    fn test_script_stats() {
        let complete = ScriptStats {
            total_scripts: 7,
            loaded_scripts: 7,
            cached_scripts: 7,
            version: SCRIPT_VERSION,
            total_executions: 0,
        };
        assert!(complete.all_loaded());
        assert_eq!(complete.version, SCRIPT_VERSION);

        let partial = ScriptStats {
            total_scripts: 7,
            loaded_scripts: 5,
            cached_scripts: 5,
            version: SCRIPT_VERSION,
            total_executions: 0,
        };
        assert!(!partial.all_loaded());
    }

    #[test]
    fn test_script_performance() {
        let ten = Duration::from_millis(10);
        let twenty = Duration::from_millis(20);

        // Fresh tracker: nothing recorded yet
        let mut tracker = ScriptPerformance::default();
        assert_eq!(tracker.execution_count, 0);
        assert_eq!(tracker.avg_duration(), None);

        // One sample: avg, min and max all equal that sample
        tracker.record(ten);
        assert_eq!(tracker.execution_count, 1);
        assert_eq!(tracker.avg_duration(), Some(ten));
        assert_eq!(tracker.min_duration, Some(ten));
        assert_eq!(tracker.max_duration, Some(ten));

        // Two samples: avg is the midpoint, bounds widen
        tracker.record(twenty);
        assert_eq!(tracker.execution_count, 2);
        assert_eq!(tracker.avg_duration(), Some(Duration::from_millis(15)));
        assert_eq!(tracker.min_duration, Some(ten));
        assert_eq!(tracker.max_duration, Some(twenty));

        // Reset returns to the empty state
        tracker.reset();
        assert_eq!(tracker.execution_count, 0);
        assert_eq!(tracker.avg_duration(), None);
    }

    #[test]
    fn test_script_version() {
        assert_eq!(SCRIPT_VERSION, 1);
    }
}