redis_oxide/
connection.rs1use crate::core::{
7    config::{ConnectionConfig, TopologyMode},
8    error::{RedisError, RedisResult},
9    value::RespValue,
10};
11use crate::protocol::{ProtocolConnection, RespDecoder, RespEncoder};
12use bytes::{Buf, BytesMut};
13use std::io::Cursor;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::TcpStream;
16use tokio::time::timeout;
17use tracing::{debug, info, warn};
18
/// Deployment topology of the Redis server(s) a client talks to.
///
/// Produced either by forcing it through configuration or by probing a server
/// with `CLUSTER INFO` (see `RedisConnection::detect_topology`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyType {
    /// Non-cluster Redis: commands are sent to a single server endpoint.
    Standalone,
    /// Redis Cluster: the server reported that cluster mode is enabled.
    Cluster,
}

28pub struct RedisConnection {
30    stream: TcpStream,
31    read_buffer: BytesMut,
32    config: ConnectionConfig,
33}
34
35impl RedisConnection {
36    pub async fn connect(host: &str, port: u16, config: ConnectionConfig) -> RedisResult<Self> {
38        let addr = format!("{}:{}", host, port);
39        debug!("Connecting to Redis at {}", addr);
40
41        let stream = timeout(config.connect_timeout, TcpStream::connect(&addr))
42            .await
43            .map_err(|_| RedisError::Timeout)?
44            .map_err(|e| RedisError::Connection(format!("Failed to connect to {}: {}", addr, e)))?;
45
46        if let Some(keepalive_duration) = config.tcp_keepalive {
48            let socket = socket2::Socket::from(stream.into_std()?);
49            let keepalive = socket2::TcpKeepalive::new().with_time(keepalive_duration);
50            socket.set_tcp_keepalive(&keepalive).map_err(|e| {
51                RedisError::Connection(format!("Failed to set TCP keepalive: {}", e))
52            })?;
53            let stream = TcpStream::from_std(socket.into())?;
54
55            let mut conn = Self {
56                stream,
57                read_buffer: BytesMut::with_capacity(8192),
58                config: config.clone(),
59            };
60
61            if let Some(ref password) = config.password {
63                conn.authenticate(password).await?;
64            }
65
66            Ok(conn)
67        } else {
68            let mut conn = Self {
69                stream,
70                read_buffer: BytesMut::with_capacity(8192),
71                config: config.clone(),
72            };
73
74            if let Some(ref password) = config.password {
76                conn.authenticate(password).await?;
77            }
78
79            Ok(conn)
80        }
81    }
82
83    async fn authenticate(&mut self, password: &str) -> RedisResult<()> {
85        debug!("Authenticating with Redis server");
86        let response = self
87            .execute_command("AUTH", &[RespValue::from(password)])
88            .await?;
89
90        match response {
91            RespValue::SimpleString(ref s) if s == "OK" => Ok(()),
92            RespValue::Error(e) => Err(RedisError::Auth(e)),
93            _ => Err(RedisError::Auth(
94                "Unexpected authentication response".to_string(),
95            )),
96        }
97    }
98
99    pub async fn send_command(&mut self, command: &RespValue) -> RedisResult<()> {
101        let mut buffer = BytesMut::new();
102        RespEncoder::encode(command, &mut buffer)?;
103        self.stream.write_all(&buffer).await?;
104        Ok(())
105    }
106
107    pub async fn execute_command(
109        &mut self,
110        command: &str,
111        args: &[RespValue],
112    ) -> RedisResult<RespValue> {
113        let encoded = RespEncoder::encode_command(command, args)?;
115
116        timeout(
118            self.config.operation_timeout,
119            self.stream.write_all(&encoded),
120        )
121        .await
122        .map_err(|_| RedisError::Timeout)?
123        .map_err(RedisError::Io)?;
124
125        let response = timeout(self.config.operation_timeout, self.read_response())
127            .await
128            .map_err(|_| RedisError::Timeout)??;
129
130        if let RespValue::Error(ref msg) = response {
132            if let Some(redirect_error) = RedisError::parse_redirect(msg) {
133                return Err(redirect_error);
134            }
135            return Err(RedisError::Server(msg.clone()));
136        }
137
138        Ok(response)
139    }
140
141    pub async fn read_response(&mut self) -> RedisResult<RespValue> {
143        loop {
144            let mut cursor = Cursor::new(&self.read_buffer[..]);
146            if let Some(value) = RespDecoder::decode(&mut cursor)? {
147                let pos = cursor.position() as usize;
148                self.read_buffer.advance(pos);
149                return Ok(value);
150            }
151
152            let n = self.stream.read_buf(&mut self.read_buffer).await?;
154            if n == 0 {
155                return Err(RedisError::Connection(
156                    "Connection closed by server".to_string(),
157                ));
158            }
159        }
160    }
161
162    pub async fn detect_topology(&mut self) -> RedisResult<TopologyType> {
164        info!("Detecting Redis topology");
165
166        match self
168            .execute_command("CLUSTER", &[RespValue::from("INFO")])
169            .await
170        {
171            Ok(RespValue::BulkString(data)) => {
172                let info_str = String::from_utf8(data.to_vec())
173                    .map_err(|e| RedisError::Protocol(format!("Invalid UTF-8: {}", e)))?;
174
175                if info_str.contains("cluster_enabled:1") || info_str.contains("cluster_state:ok") {
177                    info!("Detected Redis Cluster");
178                    return Ok(TopologyType::Cluster);
179                }
180            }
181            Ok(RespValue::SimpleString(info_str)) => {
182                if info_str.contains("cluster_enabled:1") || info_str.contains("cluster_state:ok") {
184                    info!("Detected Redis Cluster");
185                    return Ok(TopologyType::Cluster);
186                }
187            }
188            Ok(RespValue::Error(ref e))
189                if e.contains("command not supported")
190                    || e.contains("unknown command")
191                    || e.contains("disabled") =>
192            {
193                info!("Detected Standalone Redis (CLUSTER command not available)");
195                return Ok(TopologyType::Standalone);
196            }
197            Err(RedisError::Server(ref e))
198                if e.contains("command not supported")
199                    || e.contains("unknown command")
200                    || e.contains("disabled") =>
201            {
202                info!("Detected Standalone Redis (CLUSTER command not available)");
203                return Ok(TopologyType::Standalone);
204            }
205            Err(e) => {
206                warn!("Error detecting topology: {:?}, assuming standalone", e);
207                return Ok(TopologyType::Standalone);
208            }
209            _ => {}
210        }
211
212        info!("Detected Standalone Redis");
213        Ok(TopologyType::Standalone)
214    }
215
216    pub async fn select_database(&mut self, db: u8) -> RedisResult<()> {
218        let response = self
219            .execute_command("SELECT", &[RespValue::from(db as i64)])
220            .await?;
221
222        match response {
223            RespValue::SimpleString(ref s) if s == "OK" => Ok(()),
224            RespValue::Error(e) => Err(RedisError::Server(e)),
225            _ => Err(RedisError::UnexpectedResponse(format!("{:?}", response))),
226        }
227    }
228}
229
230#[async_trait::async_trait]
231impl ProtocolConnection for RedisConnection {
232    async fn send_command(&mut self, command: &RespValue) -> RedisResult<()> {
233        self.send_command(command).await
234    }
235
236    async fn read_response(&mut self) -> RedisResult<RespValue> {
237        self.read_response().await
238    }
239}
240
241pub struct ConnectionManager {
243    config: ConnectionConfig,
244    topology: Option<TopologyType>,
245}
246
247impl ConnectionManager {
248    pub fn new(config: ConnectionConfig) -> Self {
250        Self {
251            config,
252            topology: None,
253        }
254    }
255
256    pub async fn get_topology(&mut self) -> RedisResult<TopologyType> {
258        if let Some(topology) = self.topology {
259            return Ok(topology);
260        }
261
262        match self.config.topology_mode {
264            TopologyMode::Standalone => {
265                self.topology = Some(TopologyType::Standalone);
266                Ok(TopologyType::Standalone)
267            }
268            TopologyMode::Cluster => {
269                self.topology = Some(TopologyType::Cluster);
270                Ok(TopologyType::Cluster)
271            }
272            TopologyMode::Auto => {
273                let endpoints = self.config.parse_endpoints();
275                if endpoints.is_empty() {
276                    return Err(RedisError::Config("No endpoints specified".to_string()));
277                }
278
279                let (host, port) = &endpoints[0];
280                let mut conn = RedisConnection::connect(host, *port, self.config.clone()).await?;
281                let topology = conn.detect_topology().await?;
282                self.topology = Some(topology);
283                Ok(topology)
284            }
285        }
286    }
287
288    pub async fn create_connection(&self, host: &str, port: u16) -> RedisResult<RedisConnection> {
290        RedisConnection::connect(host, port, self.config.clone()).await
291    }
292
293    pub fn config(&self) -> &ConnectionConfig {
295        &self.config
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager has not resolved any topology yet.
    #[test]
    fn test_connection_manager_creation() {
        let manager = ConnectionManager::new(ConnectionConfig::new("redis://localhost:6379"));
        assert!(manager.topology.is_none());
    }

    /// A forced topology mode is kept in the manager's configuration.
    #[test]
    fn test_forced_topology() {
        let forced = ConnectionConfig::new("redis://localhost:6379")
            .with_topology_mode(TopologyMode::Standalone);
        let manager = ConnectionManager::new(forced);

        assert_eq!(manager.config().topology_mode, TopologyMode::Standalone);
    }
}