1#![warn(missing_docs)]
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::sync::Arc;
5use tokio_util::codec::{Decoder, Encoder};
6use uuid::Uuid;
7use yykv_types::{DsError, DsValue, Redundancy};
8
9use std::path::PathBuf;
10
/// Shorthand for results whose error type is [`DsError`].
type Result<T> = std::result::Result<T, DsError>;
12use yykv_wal::{OpType, WalManager};
13
14pub mod adapter;
15pub mod connection;
16pub mod transaction;
17
18pub use adapter::RedisAdapter;
19pub use connection::RedisConnection;
20pub use transaction::RedisTransaction;
21
22#[derive(Debug, Clone, PartialEq)]
23pub enum RedisFrame {
24 SimpleString(String),
25 Error(String),
26 Integer(i64),
27 BulkString(Option<Bytes>),
28 Array(Option<Vec<RedisFrame>>),
29}
30
31use futures::{SinkExt, StreamExt};
32use tokio::net::TcpStream;
33use tokio_util::codec::Framed;
34use tracing::{error, info};
35
36pub async fn handle_connection(stream: TcpStream, service: Arc<RedisService>) -> Result<()> {
37 let mut framed = Framed::new(stream, RedisCodec);
38
39 while let Some(result) = framed.next().await {
40 let frame = match result {
41 Ok(f) => f,
42 Err(e) => {
43 error!("Decode error: {}", e);
44 break;
45 }
46 };
47
48 let cmd = match RedisCommand::from_frame(frame) {
50 Ok(c) => c,
51 Err(e) => {
52 let err_frame = RedisFrame::Error(format!("ERR {}", e));
53 let _ = framed.send(err_frame).await;
54 continue;
55 }
56 };
57
58 let response = service.handle_command(cmd).await?;
60
61 framed.send(response).await?;
63 }
64
65 info!("Connection closed");
66 Ok(())
67}
68
/// Stateless RESP codec: decodes client bytes into [`RedisFrame`]s and encodes replies.
#[derive(Debug, Clone, Default)]
pub struct RedisCodec;
70
71impl Decoder for RedisCodec {
72 type Item = RedisFrame;
73 type Error = DsError;
74
75 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
76 if src.is_empty() {
77 return Ok(None);
78 }
79
80 match src[0] {
81 b'+' => self.decode_simple_string(src),
82 b'-' => self.decode_error(src),
83 b':' => self.decode_integer(src),
84 b'$' => self.decode_bulk_string(src),
85 b'*' => self.decode_array(src),
86 _ => Err(DsError::protocol(format!(
87 "Invalid Redis frame type: {}",
88 src[0] as char
89 ))),
90 }
91 }
92}
93
94impl RedisCodec {
95 fn decode_simple_string(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
96 if let Some(i) = self.find_crlf(src) {
97 let line = src.split_to(i + 2);
98 let s = String::from_utf8(line[1..i].to_vec())?;
99 Ok(Some(RedisFrame::SimpleString(s)))
100 } else {
101 Ok(None)
102 }
103 }
104
105 fn decode_error(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
106 if let Some(i) = self.find_crlf(src) {
107 let line = src.split_to(i + 2);
108 let s = String::from_utf8(line[1..i].to_vec())?;
109 Ok(Some(RedisFrame::Error(s)))
110 } else {
111 Ok(None)
112 }
113 }
114
115 fn decode_integer(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
116 if let Some(i) = self.find_crlf(src) {
117 let line = src.split_to(i + 2);
118 let s = std::str::from_utf8(&line[1..i])?;
119 let n = s.parse::<i64>()?;
120 Ok(Some(RedisFrame::Integer(n)))
121 } else {
122 Ok(None)
123 }
124 }
125
126 fn decode_bulk_string(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
127 if let Some(i) = self.find_crlf(src) {
128 let line = &src[..i];
129 let len = std::str::from_utf8(&line[1..])?.parse::<isize>()?;
130
131 if len == -1 {
132 src.advance(i + 2);
133 return Ok(Some(RedisFrame::BulkString(None)));
134 }
135
136 let bulk_len = len as usize;
137 if src.len() < i + 2 + bulk_len + 2 {
138 return Ok(None);
139 }
140
141 src.advance(i + 2);
142 let data = src.split_to(bulk_len).freeze();
143 src.advance(2); Ok(Some(RedisFrame::BulkString(Some(data))))
145 } else {
146 Ok(None)
147 }
148 }
149
150 fn decode_array(&mut self, src: &mut BytesMut) -> Result<Option<RedisFrame>> {
151 if let Some(i) = self.find_crlf(src) {
152 let line = &src[..i];
153 let len = std::str::from_utf8(&line[1..])?.parse::<isize>()?;
154
155 if len == -1 {
156 src.advance(i + 2);
157 return Ok(Some(RedisFrame::Array(None)));
158 }
159
160 let array_len = len as usize;
161 src.advance(i + 2);
162
163 let mut frames = Vec::with_capacity(array_len);
164 for _ in 0..array_len {
165 match self.decode(src)? {
166 Some(frame) => frames.push(frame),
167 None => return Ok(None), }
169 }
170 Ok(Some(RedisFrame::Array(Some(frames))))
171 } else {
172 Ok(None)
173 }
174 }
175
176 fn find_crlf(&self, src: &[u8]) -> Option<usize> {
177 src.windows(2).position(|w| w == b"\r\n")
178 }
179}
180
181impl Encoder<RedisFrame> for RedisCodec {
182 type Error = DsError;
183
184 fn encode(&mut self, item: RedisFrame, dst: &mut BytesMut) -> Result<()> {
185 match item {
186 RedisFrame::SimpleString(s) => {
187 dst.reserve(1 + s.len() + 2);
188 dst.put_u8(b'+');
189 dst.put_slice(s.as_bytes());
190 dst.put_slice(b"\r\n");
191 }
192 RedisFrame::Error(s) => {
193 dst.reserve(1 + s.len() + 2);
194 dst.put_u8(b'-');
195 dst.put_slice(s.as_bytes());
196 dst.put_slice(b"\r\n");
197 }
198 RedisFrame::Integer(n) => {
199 let s = n.to_string();
200 dst.reserve(1 + s.len() + 2);
201 dst.put_u8(b':');
202 dst.put_slice(s.as_bytes());
203 dst.put_slice(b"\r\n");
204 }
205 RedisFrame::BulkString(opt_data) => match opt_data {
206 Some(data) => {
207 let len_s = data.len().to_string();
208 dst.reserve(1 + len_s.len() + 2 + data.len() + 2);
209 dst.put_u8(b'$');
210 dst.put_slice(len_s.as_bytes());
211 dst.put_slice(b"\r\n");
212 dst.put_slice(&data);
213 dst.put_slice(b"\r\n");
214 }
215 None => {
216 dst.put_slice(b"$-1\r\n");
217 }
218 },
219 RedisFrame::Array(opt_frames) => match opt_frames {
220 Some(frames) => {
221 let len_s = frames.len().to_string();
222 dst.reserve(1 + len_s.len() + 2);
223 dst.put_u8(b'*');
224 dst.put_slice(len_s.as_bytes());
225 dst.put_slice(b"\r\n");
226 for frame in frames {
227 self.encode(frame, dst)?;
228 }
229 }
230 None => {
231 dst.put_slice(b"*-1\r\n");
232 }
233 },
234 }
235 Ok(())
236 }
237}
238
239#[derive(Debug, Clone)]
240pub enum RedisCommand {
241 Ping,
242 Get(String),
243 Set(String, Bytes),
244 Del(Vec<String>),
245 Exists(Vec<String>),
246 HSet(String, String, Bytes),
248 HGet(String, String),
249 HDel(String, Vec<String>),
250 Incr(String),
252 Decr(String),
253 IncrBy(String, i64),
254 Expire(String, u64),
256 Ttl(String),
257 Keys(String),
259}
260
261impl RedisCommand {
262 pub fn from_frame(frame: RedisFrame) -> Result<Self> {
263 match frame {
264 RedisFrame::Array(Some(frames)) => {
265 if frames.is_empty() {
266 return Err(DsError::protocol("Empty command array"));
267 }
268
269 let cmd_name = match &frames[0] {
270 RedisFrame::BulkString(Some(data)) => {
271 String::from_utf8_lossy(data).to_uppercase()
272 }
273 _ => {
274 return Err(DsError::protocol("Command name must be a bulk string"));
275 }
276 };
277
278 match cmd_name.as_str() {
279 "PING" => Ok(RedisCommand::Ping),
280 "GET" => {
281 if frames.len() != 2 {
282 return Err(DsError::protocol("GET requires exactly 1 argument"));
283 }
284 let key = match &frames[1] {
285 RedisFrame::BulkString(Some(data)) => {
286 String::from_utf8_lossy(data).to_string()
287 }
288 _ => {
289 return Err(DsError::protocol("GET key must be a bulk string"));
290 }
291 };
292 Ok(RedisCommand::Get(key))
293 }
294 "SET" => {
295 if frames.len() != 3 {
296 return Err(DsError::protocol("SET requires exactly 2 arguments"));
297 }
298 let key = match &frames[1] {
299 RedisFrame::BulkString(Some(data)) => {
300 String::from_utf8_lossy(data).to_string()
301 }
302 _ => {
303 return Err(DsError::protocol("SET key must be a bulk string"));
304 }
305 };
306 let value = match &frames[2] {
307 RedisFrame::BulkString(Some(data)) => data.clone(),
308 _ => {
309 return Err(DsError::protocol("SET value must be a bulk string"));
310 }
311 };
312 Ok(RedisCommand::Set(key, value))
313 }
314 "DEL" => {
315 let mut keys = Vec::new();
316 for frame in frames.iter().skip(1) {
317 match frame {
318 RedisFrame::BulkString(Some(data)) => {
319 keys.push(String::from_utf8_lossy(data).to_string())
320 }
321 _ => {
322 return Err(DsError::protocol("DEL key must be a bulk string"));
323 }
324 }
325 }
326 Ok(RedisCommand::Del(keys))
327 }
328 "EXISTS" => {
329 let mut keys = Vec::new();
330 for frame in frames.iter().skip(1) {
331 match frame {
332 RedisFrame::BulkString(Some(data)) => {
333 keys.push(String::from_utf8_lossy(data).to_string())
334 }
335 _ => {
336 return Err(DsError::protocol(
337 "EXISTS key must be a bulk string",
338 ));
339 }
340 }
341 }
342 Ok(RedisCommand::Exists(keys))
343 }
344 "HSET" => {
345 if frames.len() != 4 {
346 return Err(DsError::protocol("HSET requires exactly 3 arguments"));
347 }
348 let key = match &frames[1] {
349 RedisFrame::BulkString(Some(data)) => {
350 String::from_utf8_lossy(data).to_string()
351 }
352 _ => {
353 return Err(DsError::protocol("HSET key must be a bulk string"));
354 }
355 };
356 let field = match &frames[2] {
357 RedisFrame::BulkString(Some(data)) => {
358 String::from_utf8_lossy(data).to_string()
359 }
360 _ => {
361 return Err(DsError::protocol("HSET field must be a bulk string"));
362 }
363 };
364 let value = match &frames[3] {
365 RedisFrame::BulkString(Some(data)) => data.clone(),
366 _ => {
367 return Err(DsError::protocol("HSET value must be a bulk string"));
368 }
369 };
370 Ok(RedisCommand::HSet(key, field, value))
371 }
372 "HGET" => {
373 if frames.len() != 3 {
374 return Err(DsError::protocol("HGET requires exactly 2 arguments"));
375 }
376 let key = match &frames[1] {
377 RedisFrame::BulkString(Some(data)) => {
378 String::from_utf8_lossy(data).to_string()
379 }
380 _ => {
381 return Err(DsError::protocol("HGET key must be a bulk string"));
382 }
383 };
384 let field = match &frames[2] {
385 RedisFrame::BulkString(Some(data)) => {
386 String::from_utf8_lossy(data).to_string()
387 }
388 _ => {
389 return Err(DsError::protocol("HGET field must be a bulk string"));
390 }
391 };
392 Ok(RedisCommand::HGet(key, field))
393 }
394 "HDEL" => {
395 if frames.len() < 3 {
396 return Err(DsError::protocol("HDEL requires at least 2 arguments"));
397 }
398 let key = match &frames[1] {
399 RedisFrame::BulkString(Some(data)) => {
400 String::from_utf8_lossy(data).to_string()
401 }
402 _ => {
403 return Err(DsError::protocol("HDEL key must be a bulk string"));
404 }
405 };
406 let mut fields = Vec::new();
407 for frame in frames.iter().skip(2) {
408 match frame {
409 RedisFrame::BulkString(Some(data)) => {
410 fields.push(String::from_utf8_lossy(data).to_string())
411 }
412 _ => {
413 return Err(DsError::protocol(
414 "HDEL field must be a bulk string",
415 ));
416 }
417 }
418 }
419 Ok(RedisCommand::HDel(key, fields))
420 }
421 "INCR" => {
422 if frames.len() != 2 {
423 return Err(DsError::protocol("INCR requires exactly 1 argument"));
424 }
425 let key = match &frames[1] {
426 RedisFrame::BulkString(Some(data)) => {
427 String::from_utf8_lossy(data).to_string()
428 }
429 _ => {
430 return Err(DsError::protocol("INCR key must be a bulk string"));
431 }
432 };
433 Ok(RedisCommand::Incr(key))
434 }
435 "DECR" => {
436 if frames.len() != 2 {
437 return Err(DsError::protocol("DECR requires exactly 1 argument"));
438 }
439 let key = match &frames[1] {
440 RedisFrame::BulkString(Some(data)) => {
441 String::from_utf8_lossy(data).to_string()
442 }
443 _ => {
444 return Err(DsError::protocol("DECR key must be a bulk string"));
445 }
446 };
447 Ok(RedisCommand::Decr(key))
448 }
449 "INCRBY" => {
450 if frames.len() != 3 {
451 return Err(DsError::protocol("INCRBY requires exactly 2 arguments"));
452 }
453 let key = match &frames[1] {
454 RedisFrame::BulkString(Some(data)) => {
455 String::from_utf8_lossy(data).to_string()
456 }
457 _ => {
458 return Err(DsError::protocol("INCRBY key must be a bulk string"));
459 }
460 };
461 let amount = match &frames[2] {
462 RedisFrame::BulkString(Some(data)) => {
463 String::from_utf8_lossy(data).parse::<i64>().map_err(|e| {
464 DsError::protocol(format!("Invalid INCRBY amount: {}", e))
465 })?
466 }
467 _ => {
468 return Err(DsError::protocol(
469 "INCRBY amount must be a bulk string",
470 ));
471 }
472 };
473 Ok(RedisCommand::IncrBy(key, amount))
474 }
475 "EXPIRE" => {
476 if frames.len() != 3 {
477 return Err(DsError::protocol("EXPIRE requires exactly 2 arguments"));
478 }
479 let key = match &frames[1] {
480 RedisFrame::BulkString(Some(data)) => {
481 String::from_utf8_lossy(data).to_string()
482 }
483 _ => {
484 return Err(DsError::protocol("EXPIRE key must be a bulk string"));
485 }
486 };
487 let seconds = match &frames[2] {
488 RedisFrame::BulkString(Some(data)) => {
489 String::from_utf8_lossy(data).parse::<u64>().map_err(|e| {
490 DsError::protocol(format!("Invalid EXPIRE seconds: {}", e))
491 })?
492 }
493 _ => {
494 return Err(DsError::protocol(
495 "EXPIRE seconds must be a bulk string",
496 ));
497 }
498 };
499 Ok(RedisCommand::Expire(key, seconds))
500 }
501 "TTL" => {
502 if frames.len() != 2 {
503 return Err(DsError::protocol("TTL requires exactly 1 argument"));
504 }
505 let key = match &frames[1] {
506 RedisFrame::BulkString(Some(data)) => {
507 String::from_utf8_lossy(data).to_string()
508 }
509 _ => {
510 return Err(DsError::protocol("TTL key must be a bulk string"));
511 }
512 };
513 Ok(RedisCommand::Ttl(key))
514 }
515 "KEYS" => {
516 if frames.len() != 2 {
517 return Err(DsError::protocol("KEYS requires exactly 1 argument"));
518 }
519 let pattern = match &frames[1] {
520 RedisFrame::BulkString(Some(data)) => {
521 String::from_utf8_lossy(data).to_string()
522 }
523 _ => {
524 return Err(DsError::protocol(
525 "KEYS pattern must be a bulk string",
526 ));
527 }
528 };
529 Ok(RedisCommand::Keys(pattern))
530 }
531 _ => Err(DsError::protocol(format!(
532 "Unsupported command: {}",
533 cmd_name
534 ))),
535 }
536 }
537 _ => Err(DsError::protocol("Invalid command frame: must be an array")),
538 }
539 }
540}
541
542pub struct RedisService {
543 wal: Arc<WalManager>,
544 tenant_id: Uuid,
545 kv: Arc<tokio::sync::RwLock<std::collections::HashMap<String, Bytes>>>,
547 hash: Arc<
548 tokio::sync::RwLock<
549 std::collections::HashMap<String, std::collections::HashMap<String, Bytes>>,
550 >,
551 >,
552 ttl: Arc<tokio::sync::RwLock<std::collections::HashMap<String, u64>>>,
553}
554
555impl RedisService {
556 pub async fn new() -> Result<Self> {
557 let wal = Arc::new(
558 WalManager::new(PathBuf::from("wal_redis"))
559 .await
560 .map_err(|e| DsError::storage(e.to_string()))?,
561 );
562 Ok(Self {
565 wal,
566 tenant_id: Uuid::new_v4(),
567 kv: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
568 hash: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
569 ttl: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
570 })
571 }
572
573 pub async fn handle_command(&self, cmd: RedisCommand) -> Result<RedisFrame> {
574 match cmd {
575 RedisCommand::Ping => Ok(RedisFrame::SimpleString("PONG".to_string())),
576 RedisCommand::Set(key, value) => {
577 {
579 let mut kv = self.kv.write().await;
580 kv.insert(key.clone(), value.clone());
581 }
582
583 self.wal
585 .write(
586 self.tenant_id,
587 "redis_kv".to_string(),
588 key.clone(),
589 OpType::Insert,
590 Redundancy::SINGLE,
591 DsValue::Binary(value),
592 )
593 .await?;
594
595 Ok(RedisFrame::SimpleString("OK".to_string()))
596 }
597 RedisCommand::Get(key) => {
598 let kv = self.kv.read().await;
599 match kv.get(&key) {
600 Some(val) => Ok(RedisFrame::BulkString(Some(val.clone()))),
601 None => Ok(RedisFrame::BulkString(None)),
602 }
603 }
604 RedisCommand::Del(keys) => {
605 let mut kv = self.kv.write().await;
606 let mut count = 0;
607 for key in keys {
608 if kv.remove(&key).is_some() {
609 count += 1;
610 self.wal
611 .write(
612 self.tenant_id,
613 "redis_kv".to_string(),
614 key.clone(),
615 OpType::Delete,
616 Redundancy::SINGLE,
617 DsValue::Text(key),
618 )
619 .await?;
620 }
621 }
622 Ok(RedisFrame::Integer(count))
623 }
624 RedisCommand::Exists(keys) => {
625 let kv = self.kv.read().await;
626 let mut count = 0;
627 for key in keys {
628 if kv.contains_key(&key) {
629 count += 1;
630 }
631 }
632 Ok(RedisFrame::Integer(count))
633 }
634 RedisCommand::HSet(key, field, value) => {
635 let mut hash = self.hash.write().await;
636 let entry = hash
637 .entry(key)
638 .or_insert_with(std::collections::HashMap::new);
639 entry.insert(field, value);
640 Ok(RedisFrame::Integer(1))
641 }
642 RedisCommand::HGet(key, field) => {
643 let hash = self.hash.read().await;
644 match hash.get(&key).and_then(|m| m.get(&field)) {
645 Some(val) => Ok(RedisFrame::BulkString(Some(val.clone()))),
646 None => Ok(RedisFrame::BulkString(None)),
647 }
648 }
649 RedisCommand::HDel(key, fields) => {
650 let mut hash = self.hash.write().await;
651 let mut count = 0;
652 if let Some(m) = hash.get_mut(&key) {
653 for field in fields {
654 if m.remove(&field).is_some() {
655 count += 1;
656 }
657 }
658 }
659 Ok(RedisFrame::Integer(count))
660 }
661 RedisCommand::Incr(key) => {
662 let mut kv = self.kv.write().await;
663 let val = kv
664 .get(&key)
665 .map(|v| String::from_utf8_lossy(v).parse::<i64>().unwrap_or(0))
666 .unwrap_or(0);
667 let new_val = val + 1;
668 kv.insert(key, Bytes::from(new_val.to_string()));
669 Ok(RedisFrame::Integer(new_val))
670 }
671 RedisCommand::Decr(key) => {
672 let mut kv = self.kv.write().await;
673 let val = kv
674 .get(&key)
675 .map(|v| String::from_utf8_lossy(v).parse::<i64>().unwrap_or(0))
676 .unwrap_or(0);
677 let new_val = val - 1;
678 kv.insert(key, Bytes::from(new_val.to_string()));
679 Ok(RedisFrame::Integer(new_val))
680 }
681 RedisCommand::IncrBy(key, amount) => {
682 let mut kv = self.kv.write().await;
683 let val = kv
684 .get(&key)
685 .map(|v| String::from_utf8_lossy(v).parse::<i64>().unwrap_or(0))
686 .unwrap_or(0);
687 let new_val = val + amount;
688 kv.insert(key, Bytes::from(new_val.to_string()));
689 Ok(RedisFrame::Integer(new_val))
690 }
691 RedisCommand::Expire(key, seconds) => {
692 let mut ttl = self.ttl.write().await;
693 ttl.insert(key, seconds);
694 Ok(RedisFrame::Integer(1))
695 }
696 RedisCommand::Ttl(key) => {
697 let ttl = self.ttl.read().await;
698 match ttl.get(&key) {
699 Some(s) => Ok(RedisFrame::Integer(*s as i64)),
700 None => Ok(RedisFrame::Integer(-1)),
701 }
702 }
703 RedisCommand::Keys(pattern) => {
704 let kv = self.kv.read().await;
705 let keys: Vec<RedisFrame> = kv
706 .keys()
707 .filter(|k| k.contains(&pattern.replace("*", "")))
708 .map(|k| RedisFrame::BulkString(Some(Bytes::copy_from_slice(k.as_bytes()))))
709 .collect();
710 Ok(RedisFrame::Array(Some(keys)))
711 }
712 }
713 }
714}