1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
use crate::commands::auth::AuthCommand;
use crate::commands::builder::{CommandBuilder, ToStringOption};
use crate::commands::hello::HelloCommand;
use crate::commands::ping::PingCommand;
use crate::commands::Command;
use crate::network::buffer::Network;
use crate::network::client::{Client, CommandErrors};
use crate::network::handler::ConnectionError::{TcpConnectionFailed, TcpSocketError};
use crate::network::protocol::{Protocol, Resp2, Resp3};
use crate::network::response::MemoryParameters;
use alloc::string::{String, ToString};
use core::cell::RefCell;
use embedded_nal::{SocketAddr, TcpClientStack};
use embedded_time::duration::Extensions;
use embedded_time::duration::Microseconds;
use embedded_time::Clock;

/// Error handling for connection management
///
/// Returned when a connection cannot be established or initialized.
#[derive(Debug, Eq, PartialEq)]
pub enum ConnectionError {
    /// Unable to get a socket from network layer
    /// (the stack's `socket()` call returned an error).
    TcpSocketError,

    /// TCP Connect failed.
    /// The socket obtained for this attempt has already been closed again.
    TcpConnectionFailed,

    /// Authentication failed with the given sub error
    AuthenticationError(CommandErrors),

    /// Protocol switch (switch to RESP3) failed with the given sub error
    /// NOTE(review): presumably raised when the HELLO handshake fails — confirm in Client::init.
    ProtocolSwitchError(CommandErrors),
}

/// Credentials used to authenticate a newly opened Redis connection.
///
/// `username` is `Some` for ACL based authentication and `None` for password-only
/// (`requirepass`) authentication. No `Debug` impl is derived, so the password cannot be
/// printed through `{:?}` by accident.
#[derive(Clone)]
pub struct Credentials {
    pub(crate) username: Option<String>,
    pub(crate) password: String,
}

impl Credentials {
    /// Builds credentials for ACL based authentication (username + password).
    ///
    /// The server must run Redis >= 6 with ACLs enabled.
    pub fn acl(username: &str, password: &str) -> Self {
        Self {
            username: Some(String::from(username)),
            password: String::from(password),
        }
    }

    /// Builds credentials for password-only authentication.
    ///
    /// The password is checked against the one configured via `requirepass`
    /// in the Redis server configuration.
    pub fn password_only(password: &str) -> Self {
        let password = String::from(password);
        Self { username: None, password }
    }
}

/// Connection handler for Redis client
///
/// While the Client is not Send, the connection handler is.
/// NOTE(review): this presumably holds only when the stack's `TcpSocket` type is itself Send.
/// The handler is designed with the approach that the creation of new clients is cheap.
/// Thus, the use of short-lived clients in concurrent applications is not a problem.
///
/// The handler owns at most one cached TCP socket. Clients created from it borrow that
/// socket and the network stack, so only one client can exist at a time per handler.
pub struct ConnectionHandler<N: TcpClientStack, P: Protocol>
where
    HelloCommand: Command<<P as Protocol>::FrameType>,
{
    /// Network details of Redis server
    remote: SocketAddr,

    /// Authentication credentials. None in case of no authentication.
    authentication: Option<Credentials>,

    /// Cached socket. `Some` while a connection is open, `None` after `disconnect()`.
    socket: Option<N::TcpSocket>,

    /// Set when initializing a new connection (authentication / protocol switch) failed.
    /// The socket is then kept cached and gets closed on the next connect().
    auth_failed: bool,

    /// Max. duration waiting for Redis responses.
    /// Initialized to zero by the constructors.
    /// NOTE(review): presumably zero means "no timeout" — confirm in Client.
    timeout: Microseconds,

    /// Parameters for memory allocation, passed on to every client's network buffer
    memory: MemoryParameters,

    /// Redis protocol
    /// RESP3 requires Redis version >= 6.0
    protocol: P,

    /// Use PING command for testing connection before a cached socket is reused
    use_ping: bool,

    /// Response to HELLO command, only used for RESP3.
    /// Set by a successful connection initialization and handed to every client.
    pub(crate) hello_response: Option<<HelloCommand as Command<<P as Protocol>::FrameType>>::Response>,
}

impl<N: TcpClientStack> ConnectionHandler<N, Resp2> {
    /// Builds a handler that speaks the RESP2 protocol with the server at `remote`.
    ///
    /// The handler starts without a socket; the TCP connection is only opened
    /// when a client is requested via `connect()`.
    pub fn resp2(remote: SocketAddr) -> ConnectionHandler<N, Resp2> {
        let protocol = Resp2 {};
        Self::new(remote, protocol)
    }
}

impl<N: TcpClientStack> ConnectionHandler<N, Resp3> {
    /// Builds a handler that speaks the RESP3 protocol with the server at `remote`.
    ///
    /// RESP3 requires Redis >= 6.0. The handler starts without a socket; the TCP
    /// connection is only opened when a client is requested via `connect()`.
    pub fn resp3(remote: SocketAddr) -> ConnectionHandler<N, Resp3> {
        let protocol = Resp3 {};
        Self::new(remote, protocol)
    }
}

impl<N: TcpClientStack, P: Protocol> ConnectionHandler<N, P>
where
    AuthCommand: Command<<P as Protocol>::FrameType>,
    HelloCommand: Command<<P as Protocol>::FrameType>,
    PingCommand: Command<<P as Protocol>::FrameType>,
    <P as Protocol>::FrameType: ToStringOption,
    <P as Protocol>::FrameType: From<CommandBuilder>,
{
    /// Creates a handler without an open socket.
    ///
    /// No credentials are set, PING testing is disabled, the timeout is zero and the
    /// default memory parameters are used.
    fn new(remote: SocketAddr, protocol: P) -> Self {
        ConnectionHandler {
            remote,
            authentication: None,
            socket: None,
            auth_failed: false,
            timeout: 0.microseconds(),
            memory: MemoryParameters::default(),
            protocol,
            use_ping: false,
            hello_response: None,
        }
    }

    /// Returns a Redis client. Caches the connection for future reuse.
    /// The client has the same lifetime as the network reference.
    ///
    /// As the connection is cached, later calls are cheap.
    /// So a new client may be created when switching threads, tasks, etc.
    ///
    /// *Authentication*
    /// Authentication is done automatically when creating a new connection. So the caller can
    /// expect an already authenticated and ready-to-use client.
    ///
    /// *Connection testing*
    /// Before a cached socket is reused it is checked with the stack's `is_connected()` and,
    /// if enabled via `use_ping()`, with a PING command. A socket failing either check is
    /// closed and a new connection is opened.
    ///
    /// # Arguments
    ///
    /// * `network`: Mutable borrow of embedded-nal network stack
    /// * `clock`: Optional borrow of embedded-time clock
    ///
    /// # Errors
    ///
    /// * `TcpSocketError` if the stack cannot provide a socket
    /// * `TcpConnectionFailed` if the TCP connect fails
    /// * any error returned by the client initialization (credentials / protocol handshake).
    ///   In that case the socket stays cached and is closed at the start of the next call.
    ///
    /// returns: Result<Client<N, C, P>, ConnectionError>
    pub fn connect<'a, C: Clock>(
        &'a mut self,
        network: &'a mut N,
        clock: Option<&'a C>,
    ) -> Result<Client<'a, N, C, P>, ConnectionError> {
        // The socket of a failed initialization may be in an undefined protocol state
        if self.auth_failed {
            self.disconnect(network);
        }

        // Drops the cached socket if it is no longer usable
        self.test_socket(network, clock);

        // Reuse existing connection
        if self.socket.is_some() {
            return Ok(self.create_client(network, clock));
        }

        self.new_client(network, clock)
    }

    /// Opens a new TCP connection and initializes (authenticates) a client on it.
    ///
    /// On initialization failure the socket is kept and `auth_failed` is set, so the
    /// next `connect()` closes it.
    fn new_client<'a, C: Clock>(
        &'a mut self,
        network: &'a mut N,
        clock: Option<&'a C>,
    ) -> Result<Client<'a, N, C, P>, ConnectionError> {
        self.connect_socket(network)?;

        // Cloned because the client below mutably borrows the whole handler
        let credentials = self.authentication.clone();
        let client = self.create_client(network, clock);

        match client.init(credentials) {
            Ok(response) => {
                self.hello_response = response;
                Ok(self.create_client(network, clock))
            }
            Err(error) => {
                self.auth_failed = true;
                Err(error)
            }
        }
    }

    /// Tests if the cached socket is still connected; if not, it is closed.
    ///
    /// Does nothing when no socket is cached. When `use_ping` is enabled, a failing
    /// PING round trip also closes the socket.
    fn test_socket<C: Clock>(&mut self, network: &mut N, clock: Option<&C>) {
        let still_connected = match self.socket.as_ref() {
            None => return,
            // An error while querying the state is treated as "not connected"
            Some(socket) => network.is_connected(socket).unwrap_or(false),
        };

        if !still_connected {
            self.disconnect(network);
            return;
        }

        if self.use_ping && self.ping(network, clock).is_err() {
            self.disconnect(network);
        }
    }

    /// Sends a PING command over the cached socket and waits for the response.
    ///
    /// Requires a cached socket (see `create_client`).
    fn ping<'a, C: Clock>(
        &'a mut self,
        network: &'a mut N,
        clock: Option<&'a C>,
    ) -> Result<(), CommandErrors> {
        self.create_client(network, clock).ping()?.wait()?;
        Ok(())
    }

    /// Disconnects the connection by closing the cached socket, if any.
    ///
    /// Errors reported by the stack while closing are ignored, as the socket is discarded
    /// either way. Afterwards no socket is cached and no deferred close is pending.
    pub fn disconnect(&mut self, network: &mut N) {
        if let Some(socket) = self.socket.take() {
            let _ = network.close(socket);
        }
        self.auth_failed = false;
    }

    /// Creates a new TCP connection to `remote` and caches the socket.
    ///
    /// If connecting fails, the freshly obtained socket is closed again before returning.
    fn connect_socket(&mut self, network: &mut N) -> Result<(), ConnectionError> {
        let mut socket = network.socket().map_err(|_| TcpSocketError)?;

        if network.connect(&mut socket, self.remote).is_err() {
            let _ = network.close(socket);
            return Err(TcpConnectionFailed);
        }

        self.socket = Some(socket);
        Ok(())
    }

    /// Creates a new client instance on top of the cached socket.
    ///
    /// # Panics
    ///
    /// Panics if no socket is cached; all callers ensure a socket exists beforehand.
    fn create_client<'a, C: Clock>(
        &'a mut self,
        stack: &'a mut N,
        clock: Option<&'a C>,
    ) -> Client<'a, N, C, P> {
        let socket = self
            .socket
            .as_mut()
            .expect("create_client() called without a connected socket");

        Client {
            network: Network::new(
                RefCell::new(stack),
                RefCell::new(socket),
                self.protocol.clone(),
                self.memory.clone(),
            ),
            timeout_duration: self.timeout,
            clock,
            hello_response: self.hello_response.as_ref(),
        }
    }
}

impl<N: TcpClientStack, P: Protocol> ConnectionHandler<N, P>
where
    HelloCommand: Command<<P as Protocol>::FrameType>,
{
    /// Stores credentials used to authenticate every newly opened connection.
    ///
    /// Returns the handler so configuration calls can be chained.
    pub fn auth(&mut self, credentials: Credentials) -> &mut Self {
        self.authentication = Some(credentials);
        self
    }

    /// Configures the maximum time to wait for a Redis response.
    ///
    /// Returns the handler so configuration calls can be chained.
    pub fn timeout(&mut self, timeout: Microseconds) -> &mut Self {
        self.timeout = timeout;
        self
    }

    /// Replaces the memory allocation parameters handed to new clients.
    ///
    /// Returns the handler so configuration calls can be chained.
    pub fn memory(&mut self, parameters: MemoryParameters) -> &mut Self {
        self.memory = parameters;
        self
    }

    /// Enables a PING round trip to verify a cached connection before it is reused.
    ///
    /// Returns the handler so configuration calls can be chained.
    pub fn use_ping(&mut self) -> &mut Self {
        self.use_ping = true;
        self
    }
}