redis_oxide/script.rs
1//! Lua scripting support for Redis
2//!
//! This module provides functionality for executing Lua scripts on Redis servers
//! using the EVAL and EVALSHA commands. A [`Script`] is identified by the SHA1 of
//! its source: execution tries EVALSHA first and falls back to EVAL when the
//! server has not cached the script yet.
6//!
7//! # Examples
8//!
9//! ## Basic Script Execution
10//!
11//! ```no_run
12//! use redis_oxide::{Client, ConnectionConfig};
13//!
14//! # #[tokio::main]
15//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//! let config = ConnectionConfig::new("redis://localhost:6379");
17//! let client = Client::connect(config).await?;
18//!
19//! // Execute a simple Lua script
20//! let script = "return redis.call('GET', KEYS[1])";
21//! let result: Option<String> = client.eval(script, vec!["mykey".to_string()], vec![]).await?;
22//! println!("Result: {:?}", result);
23//! # Ok(())
24//! # }
25//! ```
26//!
27//! ## Script with Arguments
28//!
29//! ```no_run
30//! use redis_oxide::{Client, ConnectionConfig};
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! let config = ConnectionConfig::new("redis://localhost:6379");
35//! let client = Client::connect(config).await?;
36//!
37//! // Script that increments a counter by a given amount
38//! let script = r#"
39//!     local current = redis.call('GET', KEYS[1]) or 0
40//!     local increment = tonumber(ARGV[1])
41//!     local new_value = tonumber(current) + increment
42//!     redis.call('SET', KEYS[1], new_value)
43//!     return new_value
44//! "#;
45//!
46//! let result: i64 = client.eval(
47//!     script,
48//!     vec!["counter".to_string()],
49//!     vec!["5".to_string()]
50//! ).await?;
51//! println!("New counter value: {}", result);
52//! # Ok(())
53//! # }
54//! ```
55//!
//! ## Using `Script` for Caching
57//!
58//! ```no_run
59//! use redis_oxide::{Client, ConnectionConfig, Script};
60//!
61//! # #[tokio::main]
62//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
63//! let config = ConnectionConfig::new("redis://localhost:6379");
64//! let client = Client::connect(config).await?;
65//!
66//! // Create a reusable script
67//! let script = Script::new(r#"
68//!     local key = KEYS[1]
69//!     local value = ARGV[1]
70//!     redis.call('SET', key, value)
71//!     return redis.call('GET', key)
72//! "#);
73//!
74//! // Execute the script (automatically uses EVALSHA if cached)
75//! let result: String = script.execute(
76//!     &client,
77//!     vec!["mykey".to_string()],
78//!     vec!["myvalue".to_string()]
79//! ).await?;
80//! println!("Result: {}", result);
81//! # Ok(())
82//! # }
83//! ```
84
85use crate::core::{
86    error::{RedisError, RedisResult},
87    value::RespValue,
88};
89use sha1::{Digest, Sha1};
90use std::collections::HashMap;
91use std::convert::TryFrom;
92use std::sync::Arc;
93use tokio::sync::RwLock;
94
95/// A Lua script that can be executed on Redis
96#[derive(Debug, Clone)]
97pub struct Script {
98    /// The Lua script source code
99    source: String,
100    /// SHA1 hash of the script (for EVALSHA)
101    sha: String,
102}
103
104impl Script {
105    /// Create a new script from Lua source code
106    ///
107    /// The script is automatically hashed for use with EVALSHA.
108    ///
109    /// # Examples
110    ///
111    /// ```
112    /// use redis_oxide::Script;
113    ///
114    /// let script = Script::new("return redis.call('GET', KEYS[1])");
115    /// println!("Script SHA: {}", script.sha());
116    /// ```
117    pub fn new(source: impl Into<String>) -> Self {
118        let source = source.into();
119        let sha = calculate_sha1(&source);
120
121        Self { source, sha }
122    }
123
124    /// Get the SHA1 hash of the script
125    #[must_use]
126    pub fn sha(&self) -> &str {
127        &self.sha
128    }
129
130    /// Get the source code of the script
131    #[must_use]
132    pub fn source(&self) -> &str {
133        &self.source
134    }
135
136    /// Execute the script on the given client
137    ///
138    /// This method will first try to use EVALSHA (if the script is cached on the server),
139    /// and fall back to EVAL if the script is not cached.
140    ///
141    /// # Arguments
142    ///
143    /// * `client` - The Redis client to execute the script on
144    /// * `keys` - List of Redis keys that the script will access (KEYS array in Lua)
145    /// * `args` - List of arguments to pass to the script (ARGV array in Lua)
146    ///
147    /// # Examples
148    ///
149    /// ```no_run
150    /// # use redis_oxide::{Client, ConnectionConfig, Script};
151    /// # #[tokio::main]
152    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
153    /// # let config = ConnectionConfig::new("redis://localhost:6379");
154    /// # let client = Client::connect(config).await?;
155    /// let script = Script::new("return KEYS[1] .. ':' .. ARGV[1]");
156    ///
157    /// let result: String = script.execute(
158    ///     &client,
159    ///     vec!["user".to_string()],
160    ///     vec!["123".to_string()]
161    /// ).await?;
162    ///
163    /// assert_eq!(result, "user:123");
164    /// # Ok(())
165    /// # }
166    /// ```
167    pub async fn execute<T>(
168        &self,
169        client: &crate::Client,
170        keys: Vec<String>,
171        args: Vec<String>,
172    ) -> RedisResult<T>
173    where
174        T: TryFrom<RespValue>,
175        T::Error: Into<RedisError>,
176    {
177        // First try EVALSHA
178        match client.evalsha(&self.sha, keys.clone(), args.clone()).await {
179            Ok(result) => Ok(result),
180            Err(RedisError::Protocol(msg)) if msg.contains("NOSCRIPT") => {
181                // Script not cached, use EVAL
182                client.eval(&self.source, keys, args).await
183            }
184            Err(e) => Err(e),
185        }
186    }
187
188    /// Load the script into Redis cache
189    ///
190    /// This sends the script to Redis using SCRIPT LOAD, which caches it
191    /// for future EVALSHA calls.
192    ///
193    /// # Examples
194    ///
195    /// ```no_run
196    /// # use redis_oxide::{Client, ConnectionConfig, Script};
197    /// # #[tokio::main]
198    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
199    /// # let config = ConnectionConfig::new("redis://localhost:6379");
200    /// # let client = Client::connect(config).await?;
201    /// let script = Script::new("return 'Hello, World!'");
202    ///
203    /// // Preload the script
204    /// let sha = script.load(&client).await?;
205    /// println!("Script loaded with SHA: {}", sha);
206    /// # Ok(())
207    /// # }
208    /// ```
209    pub async fn load(&self, client: &crate::Client) -> RedisResult<String> {
210        client.script_load(&self.source).await
211    }
212}
213
214/// Script manager for caching and managing multiple scripts
215#[derive(Debug)]
216pub struct ScriptManager {
217    scripts: Arc<RwLock<HashMap<String, Script>>>,
218}
219
220impl ScriptManager {
221    /// Create a new script manager
222    #[must_use]
223    pub fn new() -> Self {
224        Self {
225            scripts: Arc::new(RwLock::new(HashMap::new())),
226        }
227    }
228
229    /// Register a script with the manager
230    ///
231    /// # Examples
232    ///
233    /// ```no_run
234    /// use redis_oxide::{Script, ScriptManager};
235    ///
236    /// # #[tokio::main]
237    /// # async fn main() {
238    /// let mut manager = ScriptManager::new();
239    /// let script = Script::new("return 'Hello'");
240    ///
241    /// manager.register("greeting", script).await;
242    /// # }
243    /// ```
244    pub async fn register(&self, name: impl Into<String>, script: Script) {
245        let mut scripts = self.scripts.write().await;
246        scripts.insert(name.into(), script);
247    }
248
249    /// Get a script by name
250    ///
251    /// # Examples
252    ///
253    /// ```
254    /// # use redis_oxide::{Script, ScriptManager};
255    /// # #[tokio::main]
256    /// # async fn main() {
257    /// let manager = ScriptManager::new();
258    /// let script = Script::new("return 'Hello'");
259    /// manager.register("greeting", script).await;
260    ///
261    /// if let Some(script) = manager.get("greeting").await {
262    ///     println!("Found script with SHA: {}", script.sha());
263    /// }
264    /// # }
265    /// ```
266    pub async fn get(&self, name: &str) -> Option<Script> {
267        let scripts = self.scripts.read().await;
268        scripts.get(name).cloned()
269    }
270
271    /// Execute a script by name
272    ///
273    /// # Examples
274    ///
275    /// ```no_run
276    /// # use redis_oxide::{Client, ConnectionConfig, Script, ScriptManager};
277    /// # #[tokio::main]
278    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
279    /// # let config = ConnectionConfig::new("redis://localhost:6379");
280    /// # let client = Client::connect(config).await?;
281    /// let manager = ScriptManager::new();
282    /// let script = Script::new("return KEYS[1]");
283    /// manager.register("get_key", script).await;
284    ///
285    /// let result: String = manager.execute(
286    ///     "get_key",
287    ///     &client,
288    ///     vec!["mykey".to_string()],
289    ///     vec![]
290    /// ).await?;
291    /// # Ok(())
292    /// # }
293    /// ```
294    pub async fn execute<T>(
295        &self,
296        name: &str,
297        client: &crate::Client,
298        keys: Vec<String>,
299        args: Vec<String>,
300    ) -> RedisResult<T>
301    where
302        T: TryFrom<RespValue>,
303        T::Error: Into<RedisError>,
304    {
305        let script = self
306            .get(name)
307            .await
308            .ok_or_else(|| RedisError::Protocol(format!("Script '{}' not found", name)))?;
309
310        script.execute(client, keys, args).await
311    }
312
313    /// Load all registered scripts into Redis cache
314    ///
315    /// # Examples
316    ///
317    /// ```no_run
318    /// # use redis_oxide::{Client, ConnectionConfig, Script, ScriptManager};
319    /// # #[tokio::main]
320    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
321    /// # let config = ConnectionConfig::new("redis://localhost:6379");
322    /// # let client = Client::connect(config).await?;
323    /// let manager = ScriptManager::new();
324    ///
325    /// // Register some scripts
326    /// manager.register("script1", Script::new("return 1")).await;
327    /// manager.register("script2", Script::new("return 2")).await;
328    ///
329    /// // Load all scripts at once
330    /// let results = manager.load_all(&client).await?;
331    /// println!("Loaded {} scripts", results.len());
332    /// # Ok(())
333    /// # }
334    /// ```
335    pub async fn load_all(&self, client: &crate::Client) -> RedisResult<HashMap<String, String>> {
336        let scripts = self.scripts.read().await;
337        let mut results = HashMap::new();
338
339        for (name, script) in scripts.iter() {
340            let sha = script.load(client).await?;
341            results.insert(name.clone(), sha);
342        }
343
344        Ok(results)
345    }
346
347    /// Get the number of registered scripts
348    #[must_use]
349    pub async fn len(&self) -> usize {
350        let scripts = self.scripts.read().await;
351        scripts.len()
352    }
353
354    /// Check if the manager has any scripts
355    #[must_use]
356    pub async fn is_empty(&self) -> bool {
357        let scripts = self.scripts.read().await;
358        scripts.is_empty()
359    }
360
361    /// List all registered script names
362    pub async fn list_scripts(&self) -> Vec<String> {
363        let scripts = self.scripts.read().await;
364        scripts.keys().cloned().collect()
365    }
366
367    /// Remove a script from the manager
368    pub async fn remove(&self, name: &str) -> Option<Script> {
369        let mut scripts = self.scripts.write().await;
370        scripts.remove(name)
371    }
372
373    /// Clear all scripts from the manager
374    pub async fn clear(&self) {
375        let mut scripts = self.scripts.write().await;
376        scripts.clear();
377    }
378}
379
380impl Default for ScriptManager {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
386/// Calculate SHA1 hash of a string
387fn calculate_sha1(input: &str) -> String {
388    let mut hasher = Sha1::new();
389    hasher.update(input.as_bytes());
390    let result = hasher.finalize();
391    hex::encode(result)
392}
393
394/// Common Lua script patterns
395pub mod patterns {
396    use super::Script;
397
398    /// Atomic increment with expiration
399    ///
400    /// # Arguments
401    /// - KEYS[1]: The key to increment
402    /// - ARGV\[1\]: Increment amount
403    /// - ARGV\[2\]: Expiration time in seconds
404    pub fn atomic_increment_with_expiration() -> Script {
405        Script::new(
406            r"
407            local key = KEYS[1]
408            local increment = tonumber(ARGV[1])
409            local expiration = tonumber(ARGV[2])
410            
411            local current = redis.call('GET', key)
412            local new_value
413            
414            if current == false then
415                new_value = increment
416            else
417                new_value = tonumber(current) + increment
418            end
419            
420            redis.call('SET', key, new_value)
421            redis.call('EXPIRE', key, expiration)
422            
423            return new_value
424        ",
425        )
426    }
427
428    /// Conditional set (SET if value matches)
429    ///
430    /// # Arguments
431    /// - KEYS\[1\]: The key to set
432    /// - ARGV\[1\]: Expected current value
433    /// - ARGV\[2\]: New value to set
434    pub fn conditional_set() -> Script {
435        Script::new(
436            r"
437            local key = KEYS[1]
438            local expected = ARGV[1]
439            local new_value = ARGV[2]
440            
441            local current = redis.call('GET', key)
442            
443            if current == expected then
444                redis.call('SET', key, new_value)
445                return 1
446            else
447                return 0
448            end
449        ",
450        )
451    }
452
453    /// Rate limiting with sliding window
454    ///
455    /// # Arguments
456    /// - KEYS\[1\]: The rate limit key
457    /// - ARGV\[1\]: Window size in seconds
458    /// - ARGV\[2\]: Maximum requests per window
459    pub fn sliding_window_rate_limit() -> Script {
460        Script::new(
461            r#"
462            local key = KEYS[1]
463            local window = tonumber(ARGV[1])
464            local limit = tonumber(ARGV[2])
465            local now = redis.call('TIME')[1]
466            
467            -- Remove old entries
468            redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
469            
470            -- Count current entries
471            local current = redis.call('ZCARD', key)
472            
473            if current < limit then
474                -- Add current request
475                redis.call('ZADD', key, now, now)
476                redis.call('EXPIRE', key, window)
477                return { 1, limit - current - 1 }
478            else
479                return { 0, 0 }
480            end
481        "#,
482        )
483    }
484
485    /// Distributed lock with expiration
486    ///
487    /// # Arguments
488    /// - KEYS\[1\]: The lock key
489    /// - ARGV\[1\]: Lock identifier (unique per client)
490    /// - ARGV\[2\]: Lock expiration in seconds
491    pub fn distributed_lock() -> Script {
492        Script::new(
493            r#"
494            local key = KEYS[1]
495            local identifier = ARGV[1]
496            local expiration = tonumber(ARGV[2])
497            
498            if redis.call('SET', key, identifier, 'NX', 'EX', expiration) then
499                return 1
500            else
501                return 0
502            end
503        "#,
504        )
505    }
506
507    /// Release distributed lock
508    ///
509    /// # Arguments
510    /// - KEYS\[1\]: The lock key
511    /// - ARGV\[1\]: Lock identifier (must match)
512    pub fn release_lock() -> Script {
513        Script::new(
514            r#"
515            local key = KEYS[1]
516            local identifier = ARGV[1]
517            
518            if redis.call('GET', key) == identifier then
519                return redis.call('DEL', key)
520            else
521                return 0
522            end
523        "#,
524        )
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_script_creation() {
        let script = Script::new("return 'hello'");
        assert_eq!(script.source(), "return 'hello'");
        assert!(!script.sha().is_empty());
        assert_eq!(script.sha().len(), 40); // SHA1 is 40 hex characters
    }

    #[test]
    fn test_sha1_calculation() {
        let sha = calculate_sha1("hello world");
        assert_eq!(sha, "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed");
    }

    #[test]
    fn test_sha1_known_vectors() {
        // Reference digests from RFC 3174 / FIPS 180
        assert_eq!(calculate_sha1(""), "da39a3ee5e6b4b0d3255bfef95601890afd80709");
        assert_eq!(calculate_sha1("abc"), "a9993e364706816aba3e25717850c26c9cd0d89d");
    }

    #[test]
    fn test_script_sha_is_lowercase_hex_of_source() {
        let script = Script::new("return redis.call('GET', KEYS[1])");
        assert_eq!(script.sha(), calculate_sha1(script.source()));
        assert!(script
            .sha()
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
    }

    #[test]
    fn test_script_sha_consistency() {
        let script1 = Script::new("return 1");
        let script2 = Script::new("return 1");
        assert_eq!(script1.sha(), script2.sha());
    }

    #[test]
    fn test_script_sha_uniqueness() {
        let script1 = Script::new("return 1");
        let script2 = Script::new("return 2");
        assert_ne!(script1.sha(), script2.sha());
    }

    #[tokio::test]
    async fn test_script_manager_creation() {
        let manager = ScriptManager::new();
        assert!(manager.is_empty().await);
        assert_eq!(manager.len().await, 0);
    }

    #[tokio::test]
    async fn test_script_manager_register_and_get() {
        let manager = ScriptManager::new();
        let script = Script::new("return 'test'");
        let sha = script.sha().to_string();

        manager.register("test_script", script).await;

        assert!(!manager.is_empty().await);
        assert_eq!(manager.len().await, 1);

        let retrieved = manager.get("test_script").await.unwrap();
        assert_eq!(retrieved.sha(), sha);
        assert_eq!(retrieved.source(), "return 'test'");
    }

    #[tokio::test]
    async fn test_script_manager_get_missing() {
        let manager = ScriptManager::new();
        assert!(manager.get("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_script_manager_register_overwrites() {
        let manager = ScriptManager::new();
        manager.register("dup", Script::new("return 1")).await;
        manager.register("dup", Script::new("return 2")).await;

        assert_eq!(manager.len().await, 1);
        assert_eq!(manager.get("dup").await.unwrap().source(), "return 2");
    }

    #[tokio::test]
    async fn test_script_manager_remove() {
        let manager = ScriptManager::new();
        let script = Script::new("return 'test'");

        manager.register("test_script", script).await;
        assert_eq!(manager.len().await, 1);

        let removed = manager.remove("test_script").await;
        assert!(removed.is_some());
        assert_eq!(manager.len().await, 0);

        let not_found = manager.remove("nonexistent").await;
        assert!(not_found.is_none());
    }

    #[tokio::test]
    async fn test_script_manager_clear() {
        let manager = ScriptManager::new();

        manager.register("script1", Script::new("return 1")).await;
        manager.register("script2", Script::new("return 2")).await;
        assert_eq!(manager.len().await, 2);

        manager.clear().await;
        assert_eq!(manager.len().await, 0);
        assert!(manager.is_empty().await);
    }

    #[tokio::test]
    async fn test_script_manager_list_scripts() {
        let manager = ScriptManager::new();

        manager
            .register("script_a", Script::new("return 'a'"))
            .await;
        manager
            .register("script_b", Script::new("return 'b'"))
            .await;

        let mut scripts = manager.list_scripts().await;
        scripts.sort();

        assert_eq!(scripts, vec!["script_a", "script_b"]);
    }

    #[test]
    fn test_pattern_scripts() {
        let scripts = [
            patterns::atomic_increment_with_expiration(),
            patterns::conditional_set(),
            patterns::sliding_window_rate_limit(),
            patterns::distributed_lock(),
            patterns::release_lock(),
        ];

        // Every pattern carries a well-formed digest of its own source.
        for script in &scripts {
            assert_eq!(script.sha().len(), 40);
            assert_eq!(script.sha(), calculate_sha1(script.source()));
        }

        // Every pattern is a distinct script, so the digests must not collide.
        let mut shas: Vec<&str> = scripts.iter().map(Script::sha).collect();
        shas.sort_unstable();
        shas.dedup();
        assert_eq!(shas.len(), scripts.len());
    }
}