1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::iter::zip;

use crate::{
    resp::{cmd, Array, BulkString, Command, FromValue, ResultValueExt, Value},
    BitmapCommands, Error, GenericCommands, GeoCommands, HashCommands, HyperLogLogCommands,
    InnerClient, ListCommands, PipelinePreparedCommand, PreparedCommand, Result, ScriptingCommands,
    ServerCommands, SetCommands, SortedSetCommands, StreamCommands, StringCommands,
};

/// Represents an on-going [`transaction`](https://redis.io/docs/manual/transactions/) on a specific client instance.
///
/// Commands are accumulated locally and only sent, as a single batch,
/// when the transaction is executed.
pub struct Transaction {
    /// Client used to send the accumulated batch of commands.
    client: InnerClient,
    /// Commands accumulated so far; the first entry is always `MULTI`,
    /// and `EXEC` is appended when the transaction is executed.
    commands: Vec<Command>,
    /// One flag per entry of `commands` (same order): `true` means the reply
    /// of that command is dropped from the result of the transaction.
    forget_flags: Vec<bool>,
}

impl Transaction {
    pub(crate) fn new(client: InnerClient) -> Transaction {
        let mut transaction = Transaction {
            client,
            commands: Vec::new(),
            forget_flags: Vec::new(),
        };

        transaction.queue(cmd("MULTI"));
        transaction
    }

    /// Queue a command into the transaction.
    pub fn queue(&mut self, command: Command) {
        self.commands.push(command);
        self.forget_flags.push(false);
    }

    /// Queue a command into the transaction and forget its response.
    pub fn forget(&mut self, command: Command) {
        self.commands.push(command);
        self.forget_flags.push(true);
    }

    pub async fn execute<T: FromValue>(mut self) -> Result<T> {
        self.queue(cmd("EXEC"));

        let num_commands = self.commands.len();

        let values: Vec<Value> = self.client.send_batch(self.commands).await?.into()?;
        let mut iter = values.into_iter();

        // MULTI + QUEUED commands
        for _ in 0..num_commands - 1 {
            if let Some(Value::Error(e)) = iter.next() {
                return Err(Error::Redis(e));
            }
        }

        // EXEC
        if let Some(result) = iter.next() {
            match result {
                Value::Array(Array::Vec(results)) => {
                    let mut filtered_results = zip(results, self.forget_flags.iter().skip(1))
                        .filter_map(
                            |(value, forget_flag)| if *forget_flag { None } else { Some(value) },
                        )
                        .collect::<Vec<_>>();

                    if filtered_results.len() == 1 {
                        let value = filtered_results.pop().unwrap();
                        Ok(value).into_result()?.into()
                    } else {
                        Value::Array(Array::Vec(filtered_results)).into()
                    }
                }
                Value::Array(Array::Nil) | Value::BulkString(BulkString::Nil) => Err(Error::Aborted),
                _ => Err(Error::Client("Unexpected transaction reply".to_owned())),
            }
        } else {
            Err(Error::Client(
                "Unexpected result for transaction".to_owned(),
            ))
        }
    }
}

impl<'a, R> PipelinePreparedCommand<'a, R> for PreparedCommand<'a, Transaction, R>
where
    R: FromValue + Send + 'a,
{
    /// Adds the prepared command to the transaction it was prepared on,
    /// keeping its response.
    fn queue(self) {
        let command = self.command;
        self.executor.queue(command);
    }

    /// Adds the prepared command to the transaction it was prepared on,
    /// marking its response as forgotten.
    fn forget(self) {
        let command = self.command;
        self.executor.forget(command);
    }
}

// Command traits available on a transaction (alphabetical order).
// Every impl body is empty, so `Transaction` relies entirely on the
// default methods provided by each trait.
impl BitmapCommands for Transaction {}
impl GenericCommands for Transaction {}
impl GeoCommands for Transaction {}
impl HashCommands for Transaction {}
impl HyperLogLogCommands for Transaction {}
impl ListCommands for Transaction {}
impl ScriptingCommands for Transaction {}
impl ServerCommands for Transaction {}
impl SetCommands for Transaction {}
impl SortedSetCommands for Transaction {}
impl StreamCommands for Transaction {}
impl StringCommands for Transaction {}