reredis/
rdb.rs

1use crate::server::Server;
2use std::io;
3use std::io::{Write, BufWriter, BufReader, Read};
4use std::fs::{File, OpenOptions, rename};
5use crate::db::DB;
6use crate::object::{RobjPtr, RobjEncoding, RobjType, Robj};
7use std::time::SystemTime;
8use crate::util::{unix_timestamp, to_system_time};
9use std::rc::Rc;
10use crate::object::linked_list::LinkedList;
11use crate::object::dict::Dict;
12use crate::hash;
13use rand::Rng;
14use std::os::unix::io::AsRawFd;
15use nix::unistd::{fork, ForkResult, Pid};
16use nix::sys::ptrace::kill;
17use std::process::exit;
18use std::error::Error;
19
// ---- Opcodes that give the file its structure ----

/// Opcode starting a database section; the db index follows, length-encoded.
const RDB_DB_SELECT_FLAG: u8 = 0xFE;
/// Opcode closing the whole dump.
const RDB_DB_END_FLAG: u8 = 0xFF;
/// Opcode placed before an entry that has an expire time (8 bytes follow).
const RDB_KV_EXPIRE_FLAG: u8 = 0xFC;

// ---- Value type bytes, written right before each key ----

const RDB_STRING_FLAG: u8 = 0;
const RDB_LIST_FLAG: u8 = 1;
const RDB_SET_FLAG: u8 = 2;
const RDB_ZSET_FLAG: u8 = 3;
const RDB_HASH_FLAG: u8 = 4;
const RDB_ZIPMAP_FLAG: u8 = 9;
const RDB_ZIPLIST_FLAG: u8 = 10;
const RDB_INTSET_FLAG: u8 = 11;
const RDB_ZSET_ZIPLIST_FLAG: u8 = 12;
const RDB_HASH_ZIPLIST_FLAG: u8 = 13;

// ---- Integer-encoded strings: `11` in the two high bits, then the width ----

/// Tag for a string stored as a 4-byte little-endian integer.
const RDB_INT_32_FLAG: u8 = 0xC2;
/// Tag for a string stored as a 2-byte little-endian integer.
const RDB_INT_16_FLAG: u8 = 0xC1;
/// Tag for a string stored as a 1-byte integer.
const RDB_INT_8_FLAG: u8 = 0xC0;

// ---- Ready-made byte slices for `write_all` ----

/// Magic string plus the 4-digit format version that opens every dump.
const RDB_VERSION: &[u8] = b"REDIS0005";
const RDB_SELECT_DB: &[u8] = &[RDB_DB_SELECT_FLAG];
const RDB_END_BUF: &[u8] = &[RDB_DB_END_FLAG];
/// Eight zero bytes written where a checksum trailer would go.
const RDB_NO_CHECKSUM: &[u8] = &[0; 8];
const RDB_EXPIRE_MS_BUF: &[u8] = &[RDB_KV_EXPIRE_FLAG];

const RDB_INT_32_BUF: &[u8] = &[RDB_INT_32_FLAG];
const RDB_INT_16_BUF: &[u8] = &[RDB_INT_16_FLAG];
const RDB_INT_8_BUF: &[u8] = &[RDB_INT_8_FLAG];
48
49pub fn rdb_save_in_background(server: &mut Server) -> Result<(), ()> {
50    if server.bg_save_in_progress {
51        return Err(());
52    }
53    match fork() {
54        Ok(ForkResult::Parent { child, .. }) => {
55            info!("Background saving started by pid {}", child);
56            server.bg_save_in_progress = true;
57            server.bg_save_child_pid = child.as_raw();
58            return Ok(());
59        }
60        Ok(ForkResult::Child) => {
61            let fd = server.fd.borrow().unwrap_listener().as_raw_fd();
62            let _ = nix::unistd::close(fd);
63            if let Ok(()) = rdb_save(server) {
64                exit(0);
65            } else {
66                exit(1);
67            }
68        }
69        Err(e) => {
70            warn!("Can't save in background: fork: {}", e.description());
71            return Err(());
72        }
73    }
74}
75
76pub fn rdb_kill_background_saving(server: &Server) {
77    let _ = kill(Pid::from_raw(server.bg_save_child_pid));
78}
79
80pub fn rdb_save(server: &Server) -> io::Result<()> {
81    let temp_file_name = format!("temp-{}.rdb", rand::thread_rng().gen::<usize>());
82    let file: File = OpenOptions::new()
83        .write(true)
84        .create(true)
85        .truncate(true)
86        .open(&temp_file_name)?;
87
88    let mut writer = BufWriter::new(file);
89
90    writer.write_all(RDB_VERSION)?;
91
92    for db in server.db.iter() {
93        if db.dict.len() > 0 {
94            writer.dump_db(db)?;
95        }
96    }
97
98    writer.write_all(RDB_END_BUF)?;
99    writer.write_all(RDB_NO_CHECKSUM)?;
100    writer.flush()?;
101
102    rename(&temp_file_name, &server.db_filename)?;
103    Ok(())
104}
105
106trait RdbWriter: io::Write {
107    fn dump_db(&mut self, db: &DB) -> io::Result<()> {
108        self.write_all(RDB_SELECT_DB)?;
109        self.dump_length(db.id)?;
110        for (k, v) in db.dict.iter() {
111            let exp = db.expires.find(k)
112                .map(|p| p.1);
113            self.dump_key_value(k, v, exp)?;
114        }
115        Ok(())
116    }
117
118    fn dump_length(&mut self, size: usize) -> io::Result<()> {
119        if size < 64 {
120            self.write_all(&[size as u8])?;
121        } else if size < 16384 {
122            let mut bytes: [u8; 2] = (size as u16).to_le_bytes();
123            bytes[0] |= 0b0100_0000;
124            self.write_all(&bytes)?;
125        } else if size < std::u32::MAX as usize {
126            let bytes: [u8; 4] = (size as u32).to_le_bytes();
127            let encoded: [u8; 5] = [
128                0b1000_0000, bytes[0], bytes[1], bytes[2], bytes[3]
129            ];
130            self.write_all(&encoded)?;
131        } else {
132            return Err(io::Error::new(
133                io::ErrorKind::Other,
134                "cannot be encoded as length",
135            ));
136        }
137        Ok(())
138    }
139
140    fn dump_key_value(
141        &mut self,
142        k: &RobjPtr,
143        v: &RobjPtr,
144        exp: Option<&SystemTime>,
145    ) -> io::Result<()> {
146        if let Some(t) = exp {
147            self.write_all(RDB_EXPIRE_MS_BUF)?;
148            self.dump_timestamp(t)?;
149        }
150
151        self.write_all(&[value_type_flag(v)])?;
152
153        self.dump_string(k)?;
154
155        self.dump_object(v)?;
156
157        Ok(())
158    }
159
160    fn dump_timestamp(&mut self, t: &SystemTime) -> io::Result<()> {
161        let unix_t = unix_timestamp(t);
162        self.write_all(&unix_t.to_le_bytes())?;
163        Ok(())
164    }
165
166    fn dump_object(&mut self, obj: &RobjPtr) -> io::Result<()> {
167        use RobjType::*;
168        use RobjEncoding::*;
169        match (obj.borrow().object_type(), obj.borrow().encoding()) {
170            (String, _) => self.dump_string(obj)?,
171            (List, LinkedList) => self.dump_list(obj)?,
172            (Set, Ht) => self.dump_set(obj)?,
173            (Zset, SkipList) => self.dump_zset(obj)?,
174            (Hash, Ht) => self.dump_hash(obj)?,
175            (Hash, ZipMap) => self.dump_zmap(obj)?,
176            (List, ZipList) => self.dump_ziplist(obj)?,
177            (Set, IntSet) => self.dump_intset(obj)?,
178            (Zset, ZipList) => self.dump_zset_ziplist(obj)?,
179            (Hash, ZipList) => self.dump_hash_ziplist(obj)?,
180            (_, _) => panic!("no such type-encoding pair"),
181        }
182        Ok(())
183    }
184
185    fn dump_string(&mut self, obj: &RobjPtr) -> io::Result<()> {
186        let obj_ref = obj.borrow();
187        if let RobjEncoding::Int = obj_ref.encoding() {
188            self.dump_integer(obj_ref.integer())?;
189            return Ok(());
190        }
191        self.dump_bytes(obj_ref.string())?;
192        Ok(())
193    }
194
195    fn dump_bytes(&mut self, s: &[u8]) -> io::Result<()> {
196        self.dump_length(s.len())?;
197        self.write_all(s)?;
198        Ok(())
199    }
200
201    fn dump_integer(&mut self, i: i64) -> io::Result<()> {
202        if i < std::i32::MIN as i64 || i > std::i32::MAX as i64 {
203            self.dump_bytes(i.to_string().as_bytes())?;
204        } else if i < std::i16::MIN as i64 || i > std::i16::MAX as i64 {
205            self.write_all(RDB_INT_32_BUF)?;
206            let bytes: [u8; 4] = (i as i32).to_le_bytes();
207            self.write_all(&bytes)?;
208        } else if i < std::i8::MIN as i64 || i > std::i8::MAX as i64 {
209            self.write_all(RDB_INT_16_BUF)?;
210            let bytes: [u8; 2] = (i as i16).to_le_bytes();
211            self.write_all(&bytes)?;
212        } else {
213            self.write_all(RDB_INT_8_BUF)?;
214            let bytes: [u8; 1] = (i as i8).to_le_bytes();
215            self.write_all(&bytes)?;
216        }
217        Ok(())
218    }
219
220    fn dump_list(&mut self, obj: &RobjPtr) -> io::Result<()> {
221        self.dump_linear(obj)
222    }
223
224    fn dump_set(&mut self, obj: &RobjPtr) -> io::Result<()> {
225        self.dump_linear(obj)
226    }
227
228    fn dump_linear(&mut self, obj: &RobjPtr) -> io::Result<()> {
229        let obj_ref = obj.borrow();
230        self.dump_length(obj_ref.linear_len())?;
231        for str_obj in obj_ref.linear_iter() {
232            self.dump_string(&str_obj)?;
233        }
234        Ok(())
235    }
236
237    fn dump_zset(&mut self, _obj: &RobjPtr) -> io::Result<()> {
238        unimplemented!()
239    }
240
241    fn dump_hash(&mut self, _obj: &RobjPtr) -> io::Result<()> {
242        unimplemented!()
243    }
244
245    fn dump_zmap(&mut self, _obj: &RobjPtr) -> io::Result<()> {
246        unimplemented!()
247    }
248
249    fn dump_ziplist(&mut self, obj: &RobjPtr) -> io::Result<()> {
250        self.dump_bytes(obj.borrow().raw_data())
251    }
252
253    fn dump_intset(&mut self, obj: &RobjPtr) -> io::Result<()> {
254        self.dump_bytes(obj.borrow().raw_data())
255    }
256
257    fn dump_zset_ziplist(&mut self, _obj: &RobjPtr) -> io::Result<()> {
258        unimplemented!()
259    }
260
261    fn dump_hash_ziplist(&mut self, _obj: &RobjPtr) -> io::Result<()> {
262        unimplemented!()
263    }
264}
265
266impl RdbWriter for BufWriter<File> {}
267
268fn value_type_flag(o: &RobjPtr) -> u8 {
269    use RobjEncoding::*;
270    use RobjType::*;
271
272    match (o.borrow().object_type(), o.borrow().encoding()) {
273        (String, _) => RDB_STRING_FLAG,
274        (List, LinkedList) => RDB_LIST_FLAG,
275        (Set, Ht) => RDB_SET_FLAG,
276        (Zset, SkipList) => RDB_ZSET_FLAG,
277        (Hash, Ht) => RDB_HASH_FLAG,
278        (Hash, ZipMap) => RDB_ZIPMAP_FLAG,
279        (List, ZipList) => RDB_ZIPLIST_FLAG,
280        (Set, IntSet) => RDB_INTSET_FLAG,
281        (Zset, ZipList) => RDB_ZSET_ZIPLIST_FLAG,
282        (Hash, ZipList) => RDB_HASH_ZIPLIST_FLAG,
283        (_, _) => panic!("no such type-encoding pair"),
284    }
285}
286
287
288pub fn rdb_load(server: &mut Server) -> io::Result<()> {
289    let file = OpenOptions::new()
290        .read(true)
291        .open(&server.db_filename)?;
292
293    let mut buf: [u8; 9] = [0; 9];
294    let mut reader = BufReader::new(file);
295
296    reader.read_exact(&mut buf[0..9])?;
297    check_magic_number(&buf[0..5])?;
298    let first_db_selector = reader.load_u8()?;
299    if let Err(_) = check_db_selector(first_db_selector) {
300        info!("Empty rdb file");
301        return Ok(());
302    }
303
304    loop {
305        let not_end = reader.load_db(server)?;
306        if !not_end {
307            break;
308        }
309    }
310
311    Ok(())
312}
313
314fn other_io_err(s: &str) -> io::Error {
315    io::Error::new(io::ErrorKind::Other, s)
316}
317
318fn check_magic_number(buf: &[u8]) -> io::Result<()> {
319    if buf != b"REDIS" {
320        return Err(other_io_err("Wrong magic number"));
321    }
322    Ok(())
323}
324
325fn check_db_selector(ch: u8) -> io::Result<()> {
326    if ch != RDB_DB_SELECT_FLAG {
327        return Err(other_io_err("Wrong db selector"));
328    }
329    Ok(())
330}
331
332fn check_db_idx(server: &Server, idx: usize) -> io::Result<()> {
333    if idx > server.db.len() {
334        return Err(other_io_err("Wrong db selector"));
335    }
336    if server.db[idx].dict.len() > 0 {
337        return Err(other_io_err("duplicate db selector"));
338    }
339    Ok(())
340}
341
342trait RdbReader: io::Read {
343    fn load_db(&mut self, server: &mut Server) -> io::Result<bool> {
344        let db_idx = self.load_length()?;
345        check_db_idx(server, db_idx)?;
346        let db = &mut server.db[db_idx];
347
348        loop {
349            let stat = self.load_key_value(db)?;
350            match stat {
351                LoadStatus::Ok => {}
352                LoadStatus::EndDB => return Ok(true),
353                LoadStatus::EndAll => return Ok(false),
354            }
355        }
356    }
357
358    fn load_length(&mut self) -> io::Result<usize> {
359        let len = self.load_length_or_integer()?;
360        match len {
361            LengthOrInteger::Len(l) => Ok(l),
362            _ => Err(other_io_err("require a length rather a integer string")),
363        }
364    }
365
366    fn load_u8(&mut self) -> io::Result<u8> {
367        let mut buf: [u8; 1] = [0; 1];
368        self.read_exact(&mut buf)?;
369        Ok(buf[0])
370    }
371
372    fn load_key_value(&mut self, db: &mut DB) -> io::Result<LoadStatus> {
373        let mut flag = self.load_u8()?;
374        let mut expire: Option<SystemTime> = None;
375
376        match flag {
377            RDB_DB_END_FLAG => return Ok(LoadStatus::EndAll),
378            RDB_DB_SELECT_FLAG => return Ok(LoadStatus::EndDB),
379            _ => {}
380        }
381
382        if flag == RDB_KV_EXPIRE_FLAG {
383            let t = self.load_time()?;
384            expire = Some(t);
385            flag = self.load_u8()?;
386        }
387
388        let key = self.load_string_object()?;
389        let value = self.load_object(flag)?;
390
391        if let Some(t) = expire {
392            let _ = db.set_expire(Rc::clone(&key), t);
393        }
394
395        db.dict.replace(key, value);
396
397        Ok(LoadStatus::Ok)
398    }
399
400    fn load_time(&mut self) -> io::Result<SystemTime> {
401        let mut buf: [u8; 8] = [0; 8];
402        self.read_exact(&mut buf)?;
403        let stamp: u64 = u64::from_le_bytes(buf);
404        Ok(to_system_time(stamp))
405    }
406
407    fn load_string_object(&mut self) -> io::Result<RobjPtr> {
408        let prefix = self.load_length_or_integer()?;
409        match prefix {
410            LengthOrInteger::Int(i) => {
411                Ok(Robj::create_int_object(i))
412            }
413            LengthOrInteger::Len(l) => {
414                let mut buf: Vec<u8> = vec![0; l];
415                self.load_n_bytes(&mut buf)?;
416                Ok(Robj::from_bytes(buf))
417            }
418        }
419    }
420
421    fn load_n_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
422        self.read_exact(buf)
423    }
424
425    fn load_length_or_integer(&mut self) -> io::Result<LengthOrInteger> {
426        let flag = self.load_u8()?;
427        match flag >> 6 {
428            0b0000 => {
429                return Ok(LengthOrInteger::Len(flag as usize));
430            }
431            0b0001 => {
432                let another = self.load_u8()?;
433                let buf = [flag & 0b0011_1111, another];
434                let len = u16::from_le_bytes(buf);
435                return Ok(LengthOrInteger::Len(len as usize));
436            }
437            0b0010 => {
438                let mut buf: [u8; 4] = [0; 4];
439                self.read_exact(&mut buf)?;
440                let len = u32::from_le_bytes(buf);
441                return Ok(LengthOrInteger::Len(len as usize));
442            }
443            0b0011 => {}
444            _ => unreachable!()
445        }
446
447        match flag & 0b0000_0011 {
448            0 => {
449                let i = self.load_u8()?;
450                let i = i8::from_le_bytes([i]);
451                return Ok(LengthOrInteger::Int(i as i64));
452            }
453            1 => {
454                let mut buf: [u8; 2] = [0; 2];
455                self.read_exact(&mut buf)?;
456                let i = i16::from_le_bytes(buf);
457                return Ok(LengthOrInteger::Int(i as i64));
458            }
459            2 => {
460                let mut buf: [u8; 4] = [0; 4];
461                self.read_exact(&mut buf)?;
462                let i = i32::from_le_bytes(buf);
463                return Ok(LengthOrInteger::Int(i as i64));
464            }
465            _ => {
466                return Err(other_io_err("Wrong length or integer prefix"));
467            }
468        }
469    }
470
471    fn load_object(&mut self, flag: u8) -> io::Result<RobjPtr> {
472        match flag {
473            RDB_STRING_FLAG => self.load_string_object(),
474            RDB_LIST_FLAG => self.load_list_object(),
475            RDB_SET_FLAG => self.load_set_object(),
476            RDB_ZSET_FLAG => self.load_zset_object(),
477            RDB_HASH_FLAG => self.load_hash_object(),
478            RDB_ZIPMAP_FLAG => self.load_zipmap_object(),
479            RDB_ZIPLIST_FLAG => self.load_zip_list_object(),
480            RDB_INTSET_FLAG => self.load_int_set_object(),
481            RDB_ZSET_ZIPLIST_FLAG => self.load_zset_ziplist_object(),
482            RDB_HASH_ZIPLIST_FLAG => self.load_hash_ziplist_object(),
483            _ => Err(other_io_err("No such value type"))
484        }
485    }
486
487    fn load_list_object(&mut self) -> io::Result<RobjPtr> {
488        let len = self.load_length()?;
489        let mut list: LinkedList<RobjPtr> = LinkedList::new();
490        for _ in 0..len {
491            let obj = self.load_string_object()?;
492            list.push_back(obj);
493        }
494        Ok(Robj::from_linked_list(list))
495    }
496
497    fn load_set_object(&mut self) -> io::Result<RobjPtr> {
498        let len = self.load_length()?;
499
500        let mut rng = rand::thread_rng();
501        let num: u64 = rng.gen();
502        let mut s: Dict<RobjPtr, ()> = Dict::new(hash::string_object_hash, num);
503
504        for _ in 0..len {
505            let obj = self.load_string_object()?;
506            let _ = s.add(obj, ());
507        }
508        Ok(Robj::from_set(s))
509    }
510
511    fn load_zset_object(&mut self) -> io::Result<RobjPtr> {
512        unimplemented!()
513    }
514
515    fn load_hash_object(&mut self) -> io::Result<RobjPtr> {
516        unimplemented!()
517    }
518
519    fn load_zipmap_object(&mut self) -> io::Result<RobjPtr> {
520        unimplemented!()
521    }
522
523    fn load_zip_list_object(&mut self) -> io::Result<RobjPtr> {
524        let len = self.load_length()?;
525        let mut buf: Vec<u8> = vec![0; len];
526        self.load_n_bytes(&mut buf)?;
527        Ok(Robj::zip_list_from_bytes(buf))
528    }
529
530    fn load_int_set_object(&mut self) -> io::Result<RobjPtr> {
531        let len = self.load_length()?;
532        let mut buf: Vec<u8> = vec![0; len];
533        self.load_n_bytes(&mut buf)?;
534        Ok(Robj::int_set_from_bytes(buf))
535    }
536
537    fn load_zset_ziplist_object(&mut self) -> io::Result<RobjPtr> {
538        unimplemented!()
539    }
540
541    fn load_hash_ziplist_object(&mut self) -> io::Result<RobjPtr> {
542        unimplemented!()
543    }
544}
545
546impl RdbReader for io::BufReader<File> {}
547
/// Outcome of reading one entry inside a database section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoadStatus {
    /// A key/value pair was loaded; keep reading the current database.
    Ok,
    /// A SELECTDB opcode was read: the current section ended and another
    /// database section follows.
    EndDB,
    /// The END opcode was read: nothing more to load.
    EndAll,
}
553
/// A decoded length prefix: either a plain length or the value of a string
/// that was stored with the compact integer encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LengthOrInteger {
    /// Integer value of an integer-encoded string (`11` in the two high bits).
    Int(i64),
    /// A plain length: byte count, element count or database index.
    Len(usize),
}