Skip to main content

we_trust_redis/
lib.rs

1#![warn(missing_docs)]
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::sync::Arc;
5use tokio_util::codec::{Decoder, Encoder};
6use uuid::Uuid;
7use yykv_types::{DsError, DsValue, Redundancy};
8
9use std::path::PathBuf;
10
11type Result<T> = std::result::Result<T, DsError>;
12use yykv_wal::{OpType, WalManager};
13
14pub mod adapter;
15pub mod connection;
16pub mod transaction;
17
18pub use adapter::RedisAdapter;
19pub use connection::RedisConnection;
20pub use transaction::RedisTransaction;
21
22#[derive(Debug, Clone, PartialEq)]
23pub enum RedisFrame {
24    SimpleString(String),
25    Error(String),
26    Integer(i64),
27    BulkString(Option<Bytes>),
28    Array(Option<Vec<RedisFrame>>),
29}
30
31use futures::{SinkExt, StreamExt};
32use tokio::net::TcpStream;
33use tokio_util::codec::Framed;
34use tracing::{error, info};
35
36pub async fn handle_connection(stream: TcpStream, service: Arc<RedisService>) -> Result<()> {
37    let mut framed = Framed::new(stream, RedisCodec);
38
39    while let Some(result) = framed.next().await {
40        let frame = match result {
41            Ok(f) => f,
42            Err(e) => {
43                error!("Decode error: {}", e);
44                break;
45            }
46        };
47
48        // Parse command from frame
49        let cmd = match RedisCommand::from_frame(frame) {
50            Ok(c) => c,
51            Err(e) => {
52                let err_frame = RedisFrame::Error(format!("ERR {}", e));
53                let _ = framed.send(err_frame).await;
54                continue;
55            }
56        };
57
58        // Handle command via RedisService (YY ecosystem)
59        let response = service.handle_command(cmd).await?;
60
61        // Send response back
62        framed.send(response).await?;
63    }
64
65    info!("Connection closed");
66    Ok(())
67}
68
/// Stateless codec that decodes and encodes [`RedisFrame`]s in the RESP2 wire format.
#[derive(Debug, Clone, Copy, Default)]
pub struct RedisCodec;
70
71impl Decoder for RedisCodec {
72    type Item = RedisFrame;
73    type Error = DsError;
74
75    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
76        if src.is_empty() {
77            return Ok(None);
78        }
79
80        match src[0] {
81            b'+' => self.decode_simple_string(src),
82            b'-' => self.decode_error(src),
83            b':' => self.decode_integer(src),
84            b'$' => self.decode_bulk_string(src),
85            b'*' => self.decode_array(src),
86            _ => Err(DsError::protocol(format!(
87                "Invalid Redis frame type: {}",
88                src[0] as char
89            ))),
90        }
91    }
92}
93
94impl RedisCodec {
95    fn decode_simple_string(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
96        if let Some(i) = self.find_crlf(src) {
97            let line = src.split_to(i + 2);
98            let s = String::from_utf8(line[1..i].to_vec())?;
99            Ok(Some(RedisFrame::SimpleString(s)))
100        } else {
101            Ok(None)
102        }
103    }
104
105    fn decode_error(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
106        if let Some(i) = self.find_crlf(src) {
107            let line = src.split_to(i + 2);
108            let s = String::from_utf8(line[1..i].to_vec())?;
109            Ok(Some(RedisFrame::Error(s)))
110        } else {
111            Ok(None)
112        }
113    }
114
115    fn decode_integer(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
116        if let Some(i) = self.find_crlf(src) {
117            let line = src.split_to(i + 2);
118            let s = std::str::from_utf8(&line[1..i])?;
119            let n = s.parse::<i64>()?;
120            Ok(Some(RedisFrame::Integer(n)))
121        } else {
122            Ok(None)
123        }
124    }
125
126    fn decode_bulk_string(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
127        if let Some(i) = self.find_crlf(src) {
128            let line = &src[..i];
129            let len = std::str::from_utf8(&line[1..])?.parse::<isize>()?;
130
131            if len == -1 {
132                src.advance(i + 2);
133                return Ok(Some(RedisFrame::BulkString(None)));
134            }
135
136            let bulk_len = len as usize;
137            if src.len() < i + 2 + bulk_len + 2 {
138                return Ok(None);
139            }
140
141            src.advance(i + 2);
142            let data = src.split_to(bulk_len).freeze();
143            src.advance(2); // CRLF
144            Ok(Some(RedisFrame::BulkString(Some(data))))
145        } else {
146            Ok(None)
147        }
148    }
149
150    fn decode_array(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
151        if let Some(i) = self.find_crlf(src) {
152            let line = &src[..i];
153            let len = std::str::from_utf8(&line[1..])?.parse::<isize>()?;
154
155            if len == -1 {
156                src.advance(i + 2);
157                return Ok(Some(RedisFrame::Array(None)));
158            }
159
160            let array_len = len as usize;
161            src.advance(i + 2);
162
163            let mut frames = Vec::with_capacity(array_len);
164            for _ in 0..array_len {
165                match self.decode(src)? {
166                    Some(frame) => frames.push(frame),
167                    None => return Ok(None), // Should technically buffer or error
168                }
169            }
170            Ok(Some(RedisFrame::Array(Some(frames))))
171        } else {
172            Ok(None)
173        }
174    }
175
176    fn find_crlf(&self, src: &[u8]) -> Option<usize> {
177        src.windows(2).position(|w| w == b"\r\n")
178    }
179}
180
181impl Encoder<RedisFrame> for RedisCodec {
182    type Error = DsError;
183
184    fn encode(&mut self, item: RedisFrame, dst: &mut BytesMut) -> Result<()> {
185        match item {
186            RedisFrame::SimpleString(s) => {
187                dst.reserve(1 + s.len() + 2);
188                dst.put_u8(b'+');
189                dst.put_slice(s.as_bytes());
190                dst.put_slice(b"\r\n");
191            }
192            RedisFrame::Error(s) => {
193                dst.reserve(1 + s.len() + 2);
194                dst.put_u8(b'-');
195                dst.put_slice(s.as_bytes());
196                dst.put_slice(b"\r\n");
197            }
198            RedisFrame::Integer(n) => {
199                let s = n.to_string();
200                dst.reserve(1 + s.len() + 2);
201                dst.put_u8(b':');
202                dst.put_slice(s.as_bytes());
203                dst.put_slice(b"\r\n");
204            }
205            RedisFrame::BulkString(opt_data) => match opt_data {
206                Some(data) => {
207                    let len_s = data.len().to_string();
208                    dst.reserve(1 + len_s.len() + 2 + data.len() + 2);
209                    dst.put_u8(b'$');
210                    dst.put_slice(len_s.as_bytes());
211                    dst.put_slice(b"\r\n");
212                    dst.put_slice(&data);
213                    dst.put_slice(b"\r\n");
214                }
215                None => {
216                    dst.put_slice(b"$-1\r\n");
217                }
218            },
219            RedisFrame::Array(opt_frames) => match opt_frames {
220                Some(frames) => {
221                    let len_s = frames.len().to_string();
222                    dst.reserve(1 + len_s.len() + 2);
223                    dst.put_u8(b'*');
224                    dst.put_slice(len_s.as_bytes());
225                    dst.put_slice(b"\r\n");
226                    for frame in frames {
227                        self.encode(frame, dst)?;
228                    }
229                }
230                None => {
231                    dst.put_slice(b"*-1\r\n");
232                }
233            },
234        }
235        Ok(())
236    }
237}
238
239#[derive(Debug, Clone)]
240pub enum RedisCommand {
241    Ping,
242    Get(String),
243    Set(String, Bytes),
244    Del(Vec<String>),
245    Exists(Vec<String>),
246    // Hash commands
247    HSet(String, String, Bytes),
248    HGet(String, String),
249    HDel(String, Vec<String>),
250    // Atomic increments
251    Incr(String),
252    Decr(String),
253    IncrBy(String, i64),
254    // TTL commands
255    Expire(String, u64),
256    Ttl(String),
257    // Pattern matching
258    Keys(String),
259}
260
261impl RedisCommand {
262    pub fn from_frame(frame: RedisFrame) -> Result<Self> {
263        match frame {
264            RedisFrame::Array(Some(frames)) => {
265                if frames.is_empty() {
266                    return Err(DsError::protocol("Empty command array"));
267                }
268
269                let cmd_name = match &frames[0] {
270                    RedisFrame::BulkString(Some(data)) => {
271                        String::from_utf8_lossy(data).to_uppercase()
272                    }
273                    _ => {
274                        return Err(DsError::protocol("Command name must be a bulk string"));
275                    }
276                };
277
278                match cmd_name.as_str() {
279                    "PING" => Ok(RedisCommand::Ping),
280                    "GET" => {
281                        if frames.len() != 2 {
282                            return Err(DsError::protocol("GET requires exactly 1 argument"));
283                        }
284                        let key = match &frames[1] {
285                            RedisFrame::BulkString(Some(data)) => {
286                                String::from_utf8_lossy(data).to_string()
287                            }
288                            _ => {
289                                return Err(DsError::protocol("GET key must be a bulk string"));
290                            }
291                        };
292                        Ok(RedisCommand::Get(key))
293                    }
294                    "SET" => {
295                        if frames.len() != 3 {
296                            return Err(DsError::protocol("SET requires exactly 2 arguments"));
297                        }
298                        let key = match &frames[1] {
299                            RedisFrame::BulkString(Some(data)) => {
300                                String::from_utf8_lossy(data).to_string()
301                            }
302                            _ => {
303                                return Err(DsError::protocol("SET key must be a bulk string"));
304                            }
305                        };
306                        let value = match &frames[2] {
307                            RedisFrame::BulkString(Some(data)) => data.clone(),
308                            _ => {
309                                return Err(DsError::protocol("SET value must be a bulk string"));
310                            }
311                        };
312                        Ok(RedisCommand::Set(key, value))
313                    }
314                    "DEL" => {
315                        let mut keys = Vec::new();
316                        for frame in frames.iter().skip(1) {
317                            match frame {
318                                RedisFrame::BulkString(Some(data)) => {
319                                    keys.push(String::from_utf8_lossy(data).to_string())
320                                }
321                                _ => {
322                                    return Err(DsError::protocol("DEL key must be a bulk string"));
323                                }
324                            }
325                        }
326                        Ok(RedisCommand::Del(keys))
327                    }
328                    "EXISTS" => {
329                        let mut keys = Vec::new();
330                        for frame in frames.iter().skip(1) {
331                            match frame {
332                                RedisFrame::BulkString(Some(data)) => {
333                                    keys.push(String::from_utf8_lossy(data).to_string())
334                                }
335                                _ => {
336                                    return Err(DsError::protocol(
337                                        "EXISTS key must be a bulk string",
338                                    ));
339                                }
340                            }
341                        }
342                        Ok(RedisCommand::Exists(keys))
343                    }
344                    "HSET" => {
345                        if frames.len() != 4 {
346                            return Err(DsError::protocol("HSET requires exactly 3 arguments"));
347                        }
348                        let key = match &frames[1] {
349                            RedisFrame::BulkString(Some(data)) => {
350                                String::from_utf8_lossy(data).to_string()
351                            }
352                            _ => {
353                                return Err(DsError::protocol("HSET key must be a bulk string"));
354                            }
355                        };
356                        let field = match &frames[2] {
357                            RedisFrame::BulkString(Some(data)) => {
358                                String::from_utf8_lossy(data).to_string()
359                            }
360                            _ => {
361                                return Err(DsError::protocol("HSET field must be a bulk string"));
362                            }
363                        };
364                        let value = match &frames[3] {
365                            RedisFrame::BulkString(Some(data)) => data.clone(),
366                            _ => {
367                                return Err(DsError::protocol("HSET value must be a bulk string"));
368                            }
369                        };
370                        Ok(RedisCommand::HSet(key, field, value))
371                    }
372                    "HGET" => {
373                        if frames.len() != 3 {
374                            return Err(DsError::protocol("HGET requires exactly 2 arguments"));
375                        }
376                        let key = match &frames[1] {
377                            RedisFrame::BulkString(Some(data)) => {
378                                String::from_utf8_lossy(data).to_string()
379                            }
380                            _ => {
381                                return Err(DsError::protocol("HGET key must be a bulk string"));
382                            }
383                        };
384                        let field = match &frames[2] {
385                            RedisFrame::BulkString(Some(data)) => {
386                                String::from_utf8_lossy(data).to_string()
387                            }
388                            _ => {
389                                return Err(DsError::protocol("HGET field must be a bulk string"));
390                            }
391                        };
392                        Ok(RedisCommand::HGet(key, field))
393                    }
394                    "HDEL" => {
395                        if frames.len() < 3 {
396                            return Err(DsError::protocol("HDEL requires at least 2 arguments"));
397                        }
398                        let key = match &frames[1] {
399                            RedisFrame::BulkString(Some(data)) => {
400                                String::from_utf8_lossy(data).to_string()
401                            }
402                            _ => {
403                                return Err(DsError::protocol("HDEL key must be a bulk string"));
404                            }
405                        };
406                        let mut fields = Vec::new();
407                        for frame in frames.iter().skip(2) {
408                            match frame {
409                                RedisFrame::BulkString(Some(data)) => {
410                                    fields.push(String::from_utf8_lossy(data).to_string())
411                                }
412                                _ => {
413                                    return Err(DsError::protocol(
414                                        "HDEL field must be a bulk string",
415                                    ));
416                                }
417                            }
418                        }
419                        Ok(RedisCommand::HDel(key, fields))
420                    }
421                    "INCR" => {
422                        if frames.len() != 2 {
423                            return Err(DsError::protocol("INCR requires exactly 1 argument"));
424                        }
425                        let key = match &frames[1] {
426                            RedisFrame::BulkString(Some(data)) => {
427                                String::from_utf8_lossy(data).to_string()
428                            }
429                            _ => {
430                                return Err(DsError::protocol("INCR key must be a bulk string"));
431                            }
432                        };
433                        Ok(RedisCommand::Incr(key))
434                    }
435                    "DECR" => {
436                        if frames.len() != 2 {
437                            return Err(DsError::protocol("DECR requires exactly 1 argument"));
438                        }
439                        let key = match &frames[1] {
440                            RedisFrame::BulkString(Some(data)) => {
441                                String::from_utf8_lossy(data).to_string()
442                            }
443                            _ => {
444                                return Err(DsError::protocol("DECR key must be a bulk string"));
445                            }
446                        };
447                        Ok(RedisCommand::Decr(key))
448                    }
449                    "INCRBY" => {
450                        if frames.len() != 3 {
451                            return Err(DsError::protocol("INCRBY requires exactly 2 arguments"));
452                        }
453                        let key = match &frames[1] {
454                            RedisFrame::BulkString(Some(data)) => {
455                                String::from_utf8_lossy(data).to_string()
456                            }
457                            _ => {
458                                return Err(DsError::protocol("INCRBY key must be a bulk string"));
459                            }
460                        };
461                        let amount = match &frames[2] {
462                            RedisFrame::BulkString(Some(data)) => {
463                                String::from_utf8_lossy(data).parse::<i64>().map_err(|e| {
464                                    DsError::protocol(format!("Invalid INCRBY amount: {}", e))
465                                })?
466                            }
467                            _ => {
468                                return Err(DsError::protocol(
469                                    "INCRBY amount must be a bulk string",
470                                ));
471                            }
472                        };
473                        Ok(RedisCommand::IncrBy(key, amount))
474                    }
475                    "EXPIRE" => {
476                        if frames.len() != 3 {
477                            return Err(DsError::protocol("EXPIRE requires exactly 2 arguments"));
478                        }
479                        let key = match &frames[1] {
480                            RedisFrame::BulkString(Some(data)) => {
481                                String::from_utf8_lossy(data).to_string()
482                            }
483                            _ => {
484                                return Err(DsError::protocol("EXPIRE key must be a bulk string"));
485                            }
486                        };
487                        let seconds = match &frames[2] {
488                            RedisFrame::BulkString(Some(data)) => {
489                                String::from_utf8_lossy(data).parse::<u64>().map_err(|e| {
490                                    DsError::protocol(format!("Invalid EXPIRE seconds: {}", e))
491                                })?
492                            }
493                            _ => {
494                                return Err(DsError::protocol(
495                                    "EXPIRE seconds must be a bulk string",
496                                ));
497                            }
498                        };
499                        Ok(RedisCommand::Expire(key, seconds))
500                    }
501                    "TTL" => {
502                        if frames.len() != 2 {
503                            return Err(DsError::protocol("TTL requires exactly 1 argument"));
504                        }
505                        let key = match &frames[1] {
506                            RedisFrame::BulkString(Some(data)) => {
507                                String::from_utf8_lossy(data).to_string()
508                            }
509                            _ => {
510                                return Err(DsError::protocol("TTL key must be a bulk string"));
511                            }
512                        };
513                        Ok(RedisCommand::Ttl(key))
514                    }
515                    "KEYS" => {
516                        if frames.len() != 2 {
517                            return Err(DsError::protocol("KEYS requires exactly 1 argument"));
518                        }
519                        let pattern = match &frames[1] {
520                            RedisFrame::BulkString(Some(data)) => {
521                                String::from_utf8_lossy(data).to_string()
522                            }
523                            _ => {
524                                return Err(DsError::protocol(
525                                    "KEYS pattern must be a bulk string",
526                                ));
527                            }
528                        };
529                        Ok(RedisCommand::Keys(pattern))
530                    }
531                    _ => Err(DsError::protocol(format!(
532                        "Unsupported command: {}",
533                        cmd_name
534                    ))),
535                }
536            }
537            _ => Err(DsError::protocol("Invalid command frame: must be an array")),
538        }
539    }
540}
541
542pub struct RedisService {
543    wal: Arc<WalManager>,
544    tenant_id: Uuid,
545    // In-memory state for the example to actually work
546    kv: Arc<tokio::sync::RwLock<std::collections::HashMap<String, Bytes>>>,
547    hash: Arc<
548        tokio::sync::RwLock<
549            std::collections::HashMap<String, std::collections::HashMap<String, Bytes>>,
550        >,
551    >,
552    ttl: Arc<tokio::sync::RwLock<std::collections::HashMap<String, u64>>>,
553}
554
555impl RedisService {
556    pub async fn new() -> Result<Self> {
557        let wal = Arc::new(
558            WalManager::new(PathBuf::from("wal_redis"))
559                .await
560                .map_err(|e| DsError::storage(e.to_string()))?,
561        );
562        // let index = Arc::new(SearchIndexManager::new_in_memory()?);
563
564        Ok(Self {
565            wal,
566            tenant_id: Uuid::new_v4(),
567            kv: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
568            hash: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
569            ttl: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
570        })
571    }
572
573    pub async fn handle_command(&self, cmd: RedisCommand) -> Result<RedisFrame> {
574        match cmd {
575            RedisCommand::Ping => Ok(RedisFrame::SimpleString("PONG".to_string())),
576            RedisCommand::Set(key, value) => {
577                // 1. Update in-memory state
578                {
579                    let mut kv = self.kv.write().await;
580                    kv.insert(key.clone(), value.clone());
581                }
582
583                // 2. Log to WAL (Simplified)
584                self.wal
585                    .write(
586                        self.tenant_id,
587                        "redis_kv".to_string(),
588                        key.clone(),
589                        OpType::Insert,
590                        Redundancy::SINGLE,
591                        DsValue::Binary(value),
592                    )
593                    .await?;
594
595                Ok(RedisFrame::SimpleString("OK".to_string()))
596            }
597            RedisCommand::Get(key) => {
598                let kv = self.kv.read().await;
599                match kv.get(&key) {
600                    Some(val) => Ok(RedisFrame::BulkString(Some(val.clone()))),
601                    None => Ok(RedisFrame::BulkString(None)),
602                }
603            }
604            RedisCommand::Del(keys) => {
605                let mut kv = self.kv.write().await;
606                let mut count = 0;
607                for key in keys {
608                    if kv.remove(&key).is_some() {
609                        count += 1;
610                        self.wal
611                            .write(
612                                self.tenant_id,
613                                "redis_kv".to_string(),
614                                key.clone(),
615                                OpType::Delete,
616                                Redundancy::SINGLE,
617                                DsValue::Text(key),
618                            )
619                            .await?;
620                    }
621                }
622                Ok(RedisFrame::Integer(count))
623            }
624            RedisCommand::Exists(keys) => {
625                let kv = self.kv.read().await;
626                let mut count = 0;
627                for key in keys {
628                    if kv.contains_key(&key) {
629                        count += 1;
630                    }
631                }
632                Ok(RedisFrame::Integer(count))
633            }
634            RedisCommand::HSet(key, field, value) => {
635                let mut hash = self.hash.write().await;
636                let entry = hash
637                    .entry(key)
638                    .or_insert_with(std::collections::HashMap::new);
639                entry.insert(field, value);
640                Ok(RedisFrame::Integer(1))
641            }
642            RedisCommand::HGet(key, field) => {
643                let hash = self.hash.read().await;
644                match hash.get(&key).and_then(|m| m.get(&field)) {
645                    Some(val) => Ok(RedisFrame::BulkString(Some(val.clone()))),
646                    None => Ok(RedisFrame::BulkString(None)),
647                }
648            }
649            RedisCommand::HDel(key, fields) => {
650                let mut hash = self.hash.write().await;
651                let mut count = 0;
652                if let Some(m) = hash.get_mut(&key) {
653                    for field in fields {
654                        if m.remove(&field).is_some() {
655                            count += 1;
656                        }
657                    }
658                }
659                Ok(RedisFrame::Integer(count))
660            }
661            RedisCommand::Incr(key) => {
662                let mut kv = self.kv.write().await;
663                let val = kv
664                    .get(&key)
665                    .map(|v| String::from_utf8_lossy(v).parse::<i64>().unwrap_or(0))
666                    .unwrap_or(0);
667                let new_val = val + 1;
668                kv.insert(key, Bytes::from(new_val.to_string()));
669                Ok(RedisFrame::Integer(new_val))
670            }
671            RedisCommand::Decr(key) => {
672                let mut kv = self.kv.write().await;
673                let val = kv
674                    .get(&key)
675                    .map(|v| String::from_utf8_lossy(v).parse::<i64>().unwrap_or(0))
676                    .unwrap_or(0);
677                let new_val = val - 1;
678                kv.insert(key, Bytes::from(new_val.to_string()));
679                Ok(RedisFrame::Integer(new_val))
680            }
681            RedisCommand::IncrBy(key, amount) => {
682                let mut kv = self.kv.write().await;
683                let val = kv
684                    .get(&key)
685                    .map(|v| String::from_utf8_lossy(v).parse::<i64>().unwrap_or(0))
686                    .unwrap_or(0);
687                let new_val = val + amount;
688                kv.insert(key, Bytes::from(new_val.to_string()));
689                Ok(RedisFrame::Integer(new_val))
690            }
691            RedisCommand::Expire(key, seconds) => {
692                let mut ttl = self.ttl.write().await;
693                ttl.insert(key, seconds);
694                Ok(RedisFrame::Integer(1))
695            }
696            RedisCommand::Ttl(key) => {
697                let ttl = self.ttl.read().await;
698                match ttl.get(&key) {
699                    Some(s) => Ok(RedisFrame::Integer(*s as i64)),
700                    None => Ok(RedisFrame::Integer(-1)),
701                }
702            }
703            RedisCommand::Keys(pattern) => {
704                let kv = self.kv.read().await;
705                let keys: Vec<RedisFrame> = kv
706                    .keys()
707                    .filter(|k| k.contains(&pattern.replace("*", "")))
708                    .map(|k| RedisFrame::BulkString(Some(Bytes::copy_from_slice(k.as_bytes()))))
709                    .collect();
710                Ok(RedisFrame::Array(Some(keys)))
711            }
712        }
713    }
714}