redis_oxide/
transaction.rs1use crate::commands::Command;
32use crate::core::{
33    error::{RedisError, RedisResult},
34    value::RespValue,
35};
36use std::collections::VecDeque;
37use std::sync::Arc;
38use tokio::sync::Mutex;
39
/// A client-side batch of commands to be executed as one transaction.
///
/// Commands are buffered locally and only handed to the executor when
/// `exec` is called.
pub struct Transaction {
    /// Commands buffered locally, in insertion order.
    commands: VecDeque<Box<dyn TransactionCommand>>,
    /// Shared executor used to issue the transaction; guarded by an async mutex.
    connection: Arc<Mutex<dyn TransactionExecutor + Send + Sync>>,
    /// Keys recorded by successful `watch` calls; cleared by `unwatch`.
    watched_keys: Vec<String>,
    /// Set by `exec` once `multi()` has succeeded; `watch` is rejected while set.
    is_started: bool,
}
47
/// A command that can be buffered in a [`Transaction`].
///
/// Implementors must be `Send + Sync` so boxed commands can be moved across
/// await points and threads.
pub trait TransactionCommand: Send + Sync {
    /// Name identifying the command.
    fn name(&self) -> &str;

    /// The command's arguments as RESP values.
    fn args(&self) -> Vec<RespValue>;

    /// A key the command operates on, or `None` if it has none.
    ///
    /// The implementations in this file return the command's first key.
    fn key(&self) -> Option<String>;
}
59
/// Backend that carries out the individual steps of a transaction.
///
/// [`Transaction`] drives an implementor through a `tokio` mutex; the method
/// names mirror the Redis transaction commands (presumably each issues the
/// command of the same name — confirm against the concrete implementation).
#[async_trait::async_trait]
pub trait TransactionExecutor {
    /// Opens a transaction (the `MULTI` step).
    async fn multi(&mut self) -> RedisResult<()>;

    /// Queues one command inside the open transaction, taking ownership of it.
    async fn queue_command(&mut self, command: Box<dyn TransactionCommand>) -> RedisResult<()>;

    /// Executes the queued commands and returns their replies.
    async fn exec(&mut self) -> RedisResult<Vec<RespValue>>;

    /// Abandons the open transaction (the `DISCARD` step).
    async fn discard(&mut self) -> RedisResult<()>;

    /// Starts watching the given keys (the `WATCH` step).
    async fn watch(&mut self, keys: Vec<String>) -> RedisResult<()>;

    /// Stops watching all keys (the `UNWATCH` step).
    async fn unwatch(&mut self) -> RedisResult<()>;
}
81
82impl Transaction {
83    pub fn new(connection: Arc<Mutex<dyn TransactionExecutor + Send + Sync>>) -> Self {
85        Self {
86            commands: VecDeque::new(),
87            connection,
88            watched_keys: Vec::new(),
89            is_started: false,
90        }
91    }
92
93    pub async fn watch(&mut self, keys: Vec<String>) -> RedisResult<()> {
120        if self.is_started {
121            return Err(RedisError::Protocol("Cannot WATCH after MULTI".to_string()));
122        }
123
124        let mut connection = self.connection.lock().await;
125        connection.watch(keys.clone()).await?;
126        self.watched_keys.extend(keys);
127        Ok(())
128    }
129
130    pub async fn unwatch(&mut self) -> RedisResult<()> {
132        let mut connection = self.connection.lock().await;
133        connection.unwatch().await?;
134        self.watched_keys.clear();
135        Ok(())
136    }
137
138    pub fn add_command(&mut self, command: Box<dyn TransactionCommand>) -> &mut Self {
140        self.commands.push_back(command);
141        self
142    }
143
144    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
146        use crate::commands::SetCommand;
147        let cmd = SetCommand::new(key.into(), value.into());
148        self.add_command(Box::new(cmd))
149    }
150
151    pub fn get(&mut self, key: impl Into<String>) -> &mut Self {
153        use crate::commands::GetCommand;
154        let cmd = GetCommand::new(key.into());
155        self.add_command(Box::new(cmd))
156    }
157
158    pub fn del(&mut self, keys: Vec<String>) -> &mut Self {
160        use crate::commands::DelCommand;
161        let cmd = DelCommand::new(keys);
162        self.add_command(Box::new(cmd))
163    }
164
165    pub fn incr(&mut self, key: impl Into<String>) -> &mut Self {
167        use crate::commands::IncrCommand;
168        let cmd = IncrCommand::new(key.into());
169        self.add_command(Box::new(cmd))
170    }
171
172    pub fn decr(&mut self, key: impl Into<String>) -> &mut Self {
174        use crate::commands::DecrCommand;
175        let cmd = DecrCommand::new(key.into());
176        self.add_command(Box::new(cmd))
177    }
178
179    pub fn incr_by(&mut self, key: impl Into<String>, increment: i64) -> &mut Self {
181        use crate::commands::IncrByCommand;
182        let cmd = IncrByCommand::new(key.into(), increment);
183        self.add_command(Box::new(cmd))
184    }
185
186    pub fn decr_by(&mut self, key: impl Into<String>, decrement: i64) -> &mut Self {
188        use crate::commands::DecrByCommand;
189        let cmd = DecrByCommand::new(key.into(), decrement);
190        self.add_command(Box::new(cmd))
191    }
192
193    pub fn exists(&mut self, keys: Vec<String>) -> &mut Self {
195        use crate::commands::ExistsCommand;
196        let cmd = ExistsCommand::new(keys);
197        self.add_command(Box::new(cmd))
198    }
199
200    pub fn expire(&mut self, key: impl Into<String>, seconds: std::time::Duration) -> &mut Self {
202        use crate::commands::ExpireCommand;
203        let cmd = ExpireCommand::new(key.into(), seconds);
204        self.add_command(Box::new(cmd))
205    }
206
207    pub fn ttl(&mut self, key: impl Into<String>) -> &mut Self {
209        use crate::commands::TtlCommand;
210        let cmd = TtlCommand::new(key.into());
211        self.add_command(Box::new(cmd))
212    }
213
214    pub fn hget(&mut self, key: impl Into<String>, field: impl Into<String>) -> &mut Self {
218        use crate::commands::HGetCommand;
219        let cmd = HGetCommand::new(key.into(), field.into());
220        self.add_command(Box::new(cmd))
221    }
222
223    pub fn hset(
225        &mut self,
226        key: impl Into<String>,
227        field: impl Into<String>,
228        value: impl Into<String>,
229    ) -> &mut Self {
230        use crate::commands::HSetCommand;
231        let cmd = HSetCommand::new(key.into(), field.into(), value.into());
232        self.add_command(Box::new(cmd))
233    }
234
235    #[must_use]
237    pub fn len(&self) -> usize {
238        self.commands.len()
239    }
240
241    #[must_use]
243    pub fn is_empty(&self) -> bool {
244        self.commands.is_empty()
245    }
246
247    pub fn clear(&mut self) {
249        self.commands.clear();
250    }
251
252    pub async fn exec(&mut self) -> RedisResult<Vec<RespValue>> {
283        if self.commands.is_empty() {
284            return Err(RedisError::Protocol("Transaction is empty".to_string()));
285        }
286
287        let mut connection = self.connection.lock().await;
288
289        if !self.is_started {
291            connection.multi().await?;
292            self.is_started = true;
293        }
294
295        let commands: Vec<Box<dyn TransactionCommand>> = self.commands.drain(..).collect();
297        for command in commands {
298            connection.queue_command(command).await?;
299        }
300
301        let results = connection.exec().await?;
303        self.is_started = false;
304
305        Ok(results)
306    }
307
308    pub async fn discard(&mut self) -> RedisResult<()> {
330        let mut connection = self.connection.lock().await;
331        connection.discard().await?;
332        self.commands.clear();
333        self.is_started = false;
334        Ok(())
335    }
336}
337
338#[derive(Debug, Clone)]
340pub struct TransactionResult {
341    results: Vec<RespValue>,
342    index: usize,
343}
344
345impl TransactionResult {
346    #[must_use]
348    pub fn new(results: Vec<RespValue>) -> Self {
349        Self { results, index: 0 }
350    }
351
352    pub fn next<T>(&mut self) -> RedisResult<T>
358    where
359        T: TryFrom<RespValue>,
360        T::Error: Into<RedisError>,
361    {
362        if self.index >= self.results.len() {
363            return Err(RedisError::Protocol(
364                "No more results in transaction".to_string(),
365            ));
366        }
367
368        let result = self.results[self.index].clone();
369        self.index += 1;
370
371        T::try_from(result).map_err(Into::into)
372    }
373
374    pub fn get<T>(&self, index: usize) -> RedisResult<T>
380    where
381        T: TryFrom<RespValue>,
382        T::Error: Into<RedisError>,
383    {
384        if index >= self.results.len() {
385            return Err(RedisError::Protocol(format!(
386                "Index {} out of bounds",
387                index
388            )));
389        }
390
391        let result = self.results[index].clone();
392        T::try_from(result).map_err(Into::into)
393    }
394
395    #[must_use]
397    pub fn len(&self) -> usize {
398        self.results.len()
399    }
400
401    #[must_use]
403    pub fn is_empty(&self) -> bool {
404        self.results.is_empty()
405    }
406
407    #[must_use]
409    pub fn into_results(self) -> Vec<RespValue> {
410        self.results
411    }
412}
413
414impl TransactionCommand for crate::commands::GetCommand {
416    fn name(&self) -> &str {
417        self.command_name()
418    }
419
420    fn args(&self) -> Vec<RespValue> {
421        <Self as Command>::args(self)
422    }
423
424    fn key(&self) -> Option<String> {
425        Some(self.keys()[0].iter().map(|&b| b as char).collect())
426    }
427}
428
429impl TransactionCommand for crate::commands::SetCommand {
430    fn name(&self) -> &str {
431        self.command_name()
432    }
433
434    fn args(&self) -> Vec<RespValue> {
435        <Self as Command>::args(self)
436    }
437
438    fn key(&self) -> Option<String> {
439        Some(self.keys()[0].iter().map(|&b| b as char).collect())
440    }
441}
442
443impl TransactionCommand for crate::commands::DelCommand {
444    fn name(&self) -> &str {
445        self.command_name()
446    }
447
448    fn args(&self) -> Vec<RespValue> {
449        <Self as Command>::args(self)
450    }
451
452    fn key(&self) -> Option<String> {
453        if let Some(first_key) = self.keys().first() {
454            Some(first_key.iter().map(|&b| b as char).collect())
455        } else {
456            None
457        }
458    }
459}
460
461impl TransactionCommand for crate::commands::IncrCommand {
462    fn name(&self) -> &str {
463        self.command_name()
464    }
465
466    fn args(&self) -> Vec<RespValue> {
467        <Self as Command>::args(self)
468    }
469
470    fn key(&self) -> Option<String> {
471        Some(self.keys()[0].iter().map(|&b| b as char).collect())
472    }
473}
474
475impl TransactionCommand for crate::commands::DecrCommand {
476    fn name(&self) -> &str {
477        self.command_name()
478    }
479
480    fn args(&self) -> Vec<RespValue> {
481        <Self as Command>::args(self)
482    }
483
484    fn key(&self) -> Option<String> {
485        Some(self.keys()[0].iter().map(|&b| b as char).collect())
486    }
487}
488
489impl TransactionCommand for crate::commands::IncrByCommand {
490    fn name(&self) -> &str {
491        self.command_name()
492    }
493
494    fn args(&self) -> Vec<RespValue> {
495        <Self as Command>::args(self)
496    }
497
498    fn key(&self) -> Option<String> {
499        Some(self.keys()[0].iter().map(|&b| b as char).collect())
500    }
501}
502
503impl TransactionCommand for crate::commands::DecrByCommand {
504    fn name(&self) -> &str {
505        self.command_name()
506    }
507
508    fn args(&self) -> Vec<RespValue> {
509        <Self as Command>::args(self)
510    }
511
512    fn key(&self) -> Option<String> {
513        Some(self.keys()[0].iter().map(|&b| b as char).collect())
514    }
515}
516
517impl TransactionCommand for crate::commands::ExistsCommand {
518    fn name(&self) -> &str {
519        self.command_name()
520    }
521
522    fn args(&self) -> Vec<RespValue> {
523        <Self as Command>::args(self)
524    }
525
526    fn key(&self) -> Option<String> {
527        if let Some(first_key) = self.keys().first() {
528            Some(first_key.iter().map(|&b| b as char).collect())
529        } else {
530            None
531        }
532    }
533}
534
535impl TransactionCommand for crate::commands::ExpireCommand {
536    fn name(&self) -> &str {
537        self.command_name()
538    }
539
540    fn args(&self) -> Vec<RespValue> {
541        <Self as Command>::args(self)
542    }
543
544    fn key(&self) -> Option<String> {
545        Some(self.keys()[0].iter().map(|&b| b as char).collect())
546    }
547}
548
549impl TransactionCommand for crate::commands::TtlCommand {
550    fn name(&self) -> &str {
551        self.command_name()
552    }
553
554    fn args(&self) -> Vec<RespValue> {
555        <Self as Command>::args(self)
556    }
557
558    fn key(&self) -> Option<String> {
559        Some(self.keys()[0].iter().map(|&b| b as char).collect())
560    }
561}
562
563impl TransactionCommand for crate::commands::HGetCommand {
564    fn name(&self) -> &str {
565        self.command_name()
566    }
567
568    fn args(&self) -> Vec<RespValue> {
569        <Self as Command>::args(self)
570    }
571
572    fn key(&self) -> Option<String> {
573        Some(self.keys()[0].iter().map(|&b| b as char).collect())
574    }
575}
576
577impl TransactionCommand for crate::commands::HSetCommand {
578    fn name(&self) -> &str {
579        self.command_name()
580    }
581
582    fn args(&self) -> Vec<RespValue> {
583        <Self as Command>::args(self)
584    }
585
586    fn key(&self) -> Option<String> {
587        Some(self.keys()[0].iter().map(|&b| b as char).collect())
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Executor double: records the names of queued commands and answers each
    /// of them with an `OK` simple string when `exec` is called.
    struct MockTransactionExecutor {
        commands: Vec<String>,
        multi_called: bool,
        exec_called: bool,
    }

    impl MockTransactionExecutor {
        fn new() -> Self {
            Self {
                commands: Vec::new(),
                multi_called: false,
                exec_called: false,
            }
        }
    }

    #[async_trait::async_trait]
    impl TransactionExecutor for MockTransactionExecutor {
        async fn multi(&mut self) -> RedisResult<()> {
            self.multi_called = true;
            Ok(())
        }

        async fn queue_command(&mut self, command: Box<dyn TransactionCommand>) -> RedisResult<()> {
            let queued_name = command.name().to_string();
            self.commands.push(queued_name);
            Ok(())
        }

        async fn exec(&mut self) -> RedisResult<Vec<RespValue>> {
            self.exec_called = true;
            // One reply per command queued so far.
            let replies = self
                .commands
                .iter()
                .map(|_| RespValue::SimpleString("OK".to_string()))
                .collect();
            Ok(replies)
        }

        async fn discard(&mut self) -> RedisResult<()> {
            self.multi_called = false;
            self.commands.clear();
            Ok(())
        }

        async fn watch(&mut self, _keys: Vec<String>) -> RedisResult<()> {
            Ok(())
        }

        async fn unwatch(&mut self) -> RedisResult<()> {
            Ok(())
        }
    }

    /// Builds a transaction backed by a fresh mock executor.
    fn mock_transaction() -> Transaction {
        Transaction::new(Arc::new(Mutex::new(MockTransactionExecutor::new())))
    }

    #[tokio::test]
    async fn test_transaction_creation() {
        let txn = mock_transaction();

        assert!(txn.is_empty());
        assert_eq!(txn.len(), 0);
    }

    #[tokio::test]
    async fn test_transaction_add_commands() {
        let mut txn = mock_transaction();

        txn.set("key1", "value1").get("key1");

        assert_eq!(txn.len(), 2);
        assert!(!txn.is_empty());
    }

    #[tokio::test]
    async fn test_transaction_exec() {
        let mut txn = mock_transaction();

        txn.set("key1", "value1").get("key1");

        let replies = txn.exec().await.unwrap();
        assert_eq!(replies.len(), 2);
        // Executing consumes the local queue.
        assert!(txn.is_empty());
    }

    #[tokio::test]
    async fn test_transaction_discard() {
        let mut txn = mock_transaction();

        txn.set("key1", "value1").get("key1");
        assert_eq!(txn.len(), 2);

        txn.discard().await.unwrap();
        assert!(txn.is_empty());
    }

    #[tokio::test]
    async fn test_transaction_result() {
        let mut outcome = TransactionResult::new(vec![
            RespValue::SimpleString("OK".to_string()),
            RespValue::BulkString(b"value1".to_vec().into()),
            RespValue::Integer(42),
        ]);

        assert_eq!(outcome.len(), 3);
        assert!(!outcome.is_empty());

        // Sequential access through the cursor.
        let first: String = outcome.next().unwrap();
        assert_eq!(first, "OK");

        // Random access by position.
        let second: String = outcome.get(1).unwrap();
        assert_eq!(second, "value1");
    }
}
715}