embedded_redis/network/
client.rs1use crate::commands::auth::AuthCommand;
2use crate::commands::builder::CommandBuilder;
3use crate::commands::hello::{HelloCommand, HelloResponse};
4use crate::commands::Command;
5use crate::network::buffer::Network;
6use crate::network::future::Future;
7use crate::network::handler::{ConnectionError, Credentials};
8use crate::network::protocol::{Protocol, Resp3};
9use crate::network::timeout::{Timeout, TimeoutError};
10use crate::subscription::client::{Error, Subscription};
11use crate::subscription::messages::ToPushMessage;
12use alloc::string::String;
13use bytes::Bytes;
14use core::fmt::{Debug, Formatter};
15use embedded_nal::TcpClientStack;
16use embedded_time::duration::Microseconds;
17use embedded_time::Clock;
18
/// Errors reported while sending a command or waiting for its response.
///
/// NOTE(review): most variant descriptions below are inferred from the variant names;
/// confirm them against the code that constructs each variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CommandErrors {
    /// Presumably: no response arrived before the command's timeout elapsed.
    Timeout,
    /// Presumably: the command could not be encoded into a protocol frame.
    EncodingCommandFailed,
    /// Presumably: data read from the connection did not conform to the protocol.
    ProtocolViolation,
    /// Presumably: the future no longer refers to a pending command.
    InvalidFuture,
    /// Presumably: the underlying TCP stack reported an error.
    TcpError,
    /// The timer failed; every `TimeoutError` is converted into this variant.
    TimerError,
    /// Presumably: the response did not have the shape the command expected.
    CommandResponseViolation,
    /// Presumably: the server replied with an error; the payload carries its text.
    ErrorResponse(String),
    /// Presumably: a fixed-capacity buffer ran out of space.
    MemoryFull,
}
50
51pub struct Client<'a, N: TcpClientStack, C: Clock, P: Protocol>
55where
56 HelloCommand: Command<<P as Protocol>::FrameType>,
57{
58 pub(crate) network: Network<'a, N, P>,
59 pub(crate) clock: Option<&'a C>,
60
61 pub(crate) timeout_duration: Microseconds,
63
64 pub(crate) hello_response: Option<&'a <HelloCommand as Command<<P as Protocol>::FrameType>>::Response>,
66}
67
68impl<'a, N: TcpClientStack, C: Clock, P: Protocol> Client<'a, N, C, P>
69where
70 AuthCommand: Command<<P as Protocol>::FrameType>,
71 HelloCommand: Command<<P as Protocol>::FrameType>,
72{
73 pub fn send<Cmd>(&'a self, command: Cmd) -> Result<Future<'a, N, C, P, Cmd>, CommandErrors>
75 where
76 Cmd: Command<P::FrameType>,
77 {
78 let id = self.network.send(command.encode())?;
79
80 Ok(Future::new(
81 id,
82 command,
83 self.network.get_protocol(),
84 &self.network,
85 Timeout::new(self.clock, self.timeout_duration)?,
86 ))
87 }
88
89 pub fn subscribe<const L: usize>(
94 self,
95 channels: [Bytes; L],
96 ) -> Result<Subscription<'a, N, C, P, L>, Error>
97 where
98 <P as Protocol>::FrameType: ToPushMessage,
99 <P as Protocol>::FrameType: From<CommandBuilder>,
100 {
101 Subscription::new(self, channels).subscribe()
102 }
103
104 pub(crate) fn auth(&'a self, credentials: Option<Credentials>) -> Result<(), ConnectionError> {
106 if credentials.is_some() {
107 self.send(AuthCommand::from(credentials.as_ref().unwrap()))
108 .map_err(auth_error)?
109 .wait()
110 .map_err(auth_error)?;
111 }
112
113 Ok(())
114 }
115
116 pub(crate) fn init(
118 &'a self,
119 credentials: Option<Credentials>,
120 ) -> Result<Option<<HelloCommand as Command<<P as Protocol>::FrameType>>::Response>, ConnectionError>
121 {
122 self.auth(credentials)?;
123 if self.network.get_protocol().requires_hello() {
124 return Ok(Some(
125 self.send(HelloCommand {}).map_err(hello_error)?.wait().map_err(hello_error)?,
126 ));
127 }
128
129 Ok(None)
130 }
131
132 pub fn close(&self) {
134 if !self.network.remaining_dropped_futures() {
135 return;
136 }
137
138 let timer = match Timeout::new(self.clock, self.timeout_duration) {
139 Ok(timer) => timer,
140 Err(_) => {
141 return;
142 }
143 };
144
145 while self.network.remaining_dropped_futures() && !timer.expired().unwrap_or(true) {
146 self.network.handle_dropped_futures();
147 }
148 }
149}
150
151impl<N: TcpClientStack, C: Clock> Client<'_, N, C, Resp3> {
152 pub fn get_hello_response(&self) -> &HelloResponse {
155 self.hello_response.as_ref().unwrap()
156 }
157}
158
159impl From<TimeoutError> for CommandErrors {
160 fn from(_: TimeoutError) -> Self {
161 CommandErrors::TimerError
162 }
163}
164
165fn auth_error(error: CommandErrors) -> ConnectionError {
166 ConnectionError::AuthenticationError(error)
167}
168
169#[allow(dead_code)]
170fn hello_error(error: CommandErrors) -> ConnectionError {
171 ConnectionError::ProtocolSwitchError(error)
172}
173
174impl<N: TcpClientStack, C: Clock, P: Protocol> Debug for Client<'_, N, C, P>
175where
176 HelloCommand: Command<<P as Protocol>::FrameType>,
177{
178 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
179 f.debug_struct("Client")
180 .field("network", &self.network)
181 .field("timeout_duration", &self.timeout_duration)
182 .finish()
183 }
184}