// redis_oxide/connection.rs
use crate::core::{
    config::{ConnectionConfig, TopologyMode},
    error::{RedisError, RedisResult},
    value::RespValue,
};
use crate::protocol::{ProtocolConnection, RespDecoder, RespEncoder};
use bytes::{Buf, BytesMut};
use std::io::Cursor;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, info, warn};
18
/// Deployment topology of the Redis server a connection talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyType {
    /// A single, non-clustered Redis server.
    Standalone,
    /// A Redis Cluster deployment.
    Cluster,
}
27
28pub struct RedisConnection {
30 stream: TcpStream,
31 read_buffer: BytesMut,
32 config: ConnectionConfig,
33}
34
35impl RedisConnection {
36 pub async fn connect(host: &str, port: u16, config: ConnectionConfig) -> RedisResult<Self> {
38 let addr = format!("{}:{}", host, port);
39 debug!("Connecting to Redis at {}", addr);
40
41 let stream = timeout(config.connect_timeout, TcpStream::connect(&addr))
42 .await
43 .map_err(|_| RedisError::Timeout)?
44 .map_err(|e| RedisError::Connection(format!("Failed to connect to {}: {}", addr, e)))?;
45
46 if let Some(keepalive_duration) = config.tcp_keepalive {
48 let socket = socket2::Socket::from(stream.into_std()?);
49 let keepalive = socket2::TcpKeepalive::new().with_time(keepalive_duration);
50 socket.set_tcp_keepalive(&keepalive).map_err(|e| {
51 RedisError::Connection(format!("Failed to set TCP keepalive: {}", e))
52 })?;
53 let stream = TcpStream::from_std(socket.into())?;
54
55 let mut conn = Self {
56 stream,
57 read_buffer: BytesMut::with_capacity(8192),
58 config: config.clone(),
59 };
60
61 if let Some(ref password) = config.password {
63 conn.authenticate(password).await?;
64 }
65
66 Ok(conn)
67 } else {
68 let mut conn = Self {
69 stream,
70 read_buffer: BytesMut::with_capacity(8192),
71 config: config.clone(),
72 };
73
74 if let Some(ref password) = config.password {
76 conn.authenticate(password).await?;
77 }
78
79 Ok(conn)
80 }
81 }
82
83 async fn authenticate(&mut self, password: &str) -> RedisResult<()> {
85 debug!("Authenticating with Redis server");
86 let response = self
87 .execute_command("AUTH", &[RespValue::from(password)])
88 .await?;
89
90 match response {
91 RespValue::SimpleString(ref s) if s == "OK" => Ok(()),
92 RespValue::Error(e) => Err(RedisError::Auth(e)),
93 _ => Err(RedisError::Auth(
94 "Unexpected authentication response".to_string(),
95 )),
96 }
97 }
98
99 pub async fn send_command(&mut self, command: &RespValue) -> RedisResult<()> {
101 let mut buffer = BytesMut::new();
102 RespEncoder::encode(command, &mut buffer)?;
103 self.stream.write_all(&buffer).await?;
104 Ok(())
105 }
106
107 pub async fn execute_command(
109 &mut self,
110 command: &str,
111 args: &[RespValue],
112 ) -> RedisResult<RespValue> {
113 let encoded = RespEncoder::encode_command(command, args)?;
115
116 timeout(
118 self.config.operation_timeout,
119 self.stream.write_all(&encoded),
120 )
121 .await
122 .map_err(|_| RedisError::Timeout)?
123 .map_err(RedisError::Io)?;
124
125 let response = timeout(self.config.operation_timeout, self.read_response())
127 .await
128 .map_err(|_| RedisError::Timeout)??;
129
130 if let RespValue::Error(ref msg) = response {
132 if let Some(redirect_error) = RedisError::parse_redirect(msg) {
133 return Err(redirect_error);
134 }
135 return Err(RedisError::Server(msg.clone()));
136 }
137
138 Ok(response)
139 }
140
141 pub async fn read_response(&mut self) -> RedisResult<RespValue> {
143 loop {
144 let mut cursor = Cursor::new(&self.read_buffer[..]);
146 if let Some(value) = RespDecoder::decode(&mut cursor)? {
147 let pos = cursor.position() as usize;
148 self.read_buffer.advance(pos);
149 return Ok(value);
150 }
151
152 let n = self.stream.read_buf(&mut self.read_buffer).await?;
154 if n == 0 {
155 return Err(RedisError::Connection(
156 "Connection closed by server".to_string(),
157 ));
158 }
159 }
160 }
161
162 pub async fn detect_topology(&mut self) -> RedisResult<TopologyType> {
164 info!("Detecting Redis topology");
165
166 match self
168 .execute_command("CLUSTER", &[RespValue::from("INFO")])
169 .await
170 {
171 Ok(RespValue::BulkString(data)) => {
172 let info_str = String::from_utf8(data.to_vec())
173 .map_err(|e| RedisError::Protocol(format!("Invalid UTF-8: {}", e)))?;
174
175 if info_str.contains("cluster_enabled:1") || info_str.contains("cluster_state:ok") {
177 info!("Detected Redis Cluster");
178 return Ok(TopologyType::Cluster);
179 }
180 }
181 Ok(RespValue::SimpleString(info_str)) => {
182 if info_str.contains("cluster_enabled:1") || info_str.contains("cluster_state:ok") {
184 info!("Detected Redis Cluster");
185 return Ok(TopologyType::Cluster);
186 }
187 }
188 Ok(RespValue::Error(ref e))
189 if e.contains("command not supported")
190 || e.contains("unknown command")
191 || e.contains("disabled") =>
192 {
193 info!("Detected Standalone Redis (CLUSTER command not available)");
195 return Ok(TopologyType::Standalone);
196 }
197 Err(RedisError::Server(ref e))
198 if e.contains("command not supported")
199 || e.contains("unknown command")
200 || e.contains("disabled") =>
201 {
202 info!("Detected Standalone Redis (CLUSTER command not available)");
203 return Ok(TopologyType::Standalone);
204 }
205 Err(e) => {
206 warn!("Error detecting topology: {:?}, assuming standalone", e);
207 return Ok(TopologyType::Standalone);
208 }
209 _ => {}
210 }
211
212 info!("Detected Standalone Redis");
213 Ok(TopologyType::Standalone)
214 }
215
216 pub async fn select_database(&mut self, db: u8) -> RedisResult<()> {
218 let response = self
219 .execute_command("SELECT", &[RespValue::from(db as i64)])
220 .await?;
221
222 match response {
223 RespValue::SimpleString(ref s) if s == "OK" => Ok(()),
224 RespValue::Error(e) => Err(RedisError::Server(e)),
225 _ => Err(RedisError::UnexpectedResponse(format!("{:?}", response))),
226 }
227 }
228}
229
230#[async_trait::async_trait]
231impl ProtocolConnection for RedisConnection {
232 async fn send_command(&mut self, command: &RespValue) -> RedisResult<()> {
233 self.send_command(command).await
234 }
235
236 async fn read_response(&mut self) -> RedisResult<RespValue> {
237 self.read_response().await
238 }
239}
240
241pub struct ConnectionManager {
243 config: ConnectionConfig,
244 topology: Option<TopologyType>,
245}
246
247impl ConnectionManager {
248 pub fn new(config: ConnectionConfig) -> Self {
250 Self {
251 config,
252 topology: None,
253 }
254 }
255
256 pub async fn get_topology(&mut self) -> RedisResult<TopologyType> {
258 if let Some(topology) = self.topology {
259 return Ok(topology);
260 }
261
262 match self.config.topology_mode {
264 TopologyMode::Standalone => {
265 self.topology = Some(TopologyType::Standalone);
266 Ok(TopologyType::Standalone)
267 }
268 TopologyMode::Cluster => {
269 self.topology = Some(TopologyType::Cluster);
270 Ok(TopologyType::Cluster)
271 }
272 TopologyMode::Auto => {
273 let endpoints = self.config.parse_endpoints();
275 if endpoints.is_empty() {
276 return Err(RedisError::Config("No endpoints specified".to_string()));
277 }
278
279 let (host, port) = &endpoints[0];
280 let mut conn = RedisConnection::connect(host, *port, self.config.clone()).await?;
281 let topology = conn.detect_topology().await?;
282 self.topology = Some(topology);
283 Ok(topology)
284 }
285 }
286 }
287
288 pub async fn create_connection(&self, host: &str, port: u16) -> RedisResult<RedisConnection> {
290 RedisConnection::connect(host, port, self.config.clone()).await
291 }
292
293 pub fn config(&self) -> &ConnectionConfig {
295 &self.config
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager has not resolved any topology yet.
    #[test]
    fn test_connection_manager_creation() {
        let manager = ConnectionManager::new(ConnectionConfig::new("redis://localhost:6379"));

        assert!(manager.topology.is_none());
    }

    /// A topology mode set on the configuration is carried into the manager.
    #[test]
    fn test_forced_topology() {
        let forced = ConnectionConfig::new("redis://localhost:6379")
            .with_topology_mode(TopologyMode::Standalone);

        let manager = ConnectionManager::new(forced);

        assert_eq!(manager.config.topology_mode, TopologyMode::Standalone);
    }
}