basin2_lib/
nbt.rs

1use crate::result::*;
2use bytes::BytesMut;
3use linked_hash_map::LinkedHashMap;
4use crate::{ mcproto, McProtoBase };
5use crate::basin_err;
6
enum_from_primitive! {
/// NBT tag type ids.
///
/// Discriminants are implicit (declaration order, starting at 0). They are written as the
/// tag's type byte via `as u8`, and presumably read back through the conversion generated by
/// `enum_from_primitive!`. The variant order below is therefore the wire encoding and must not
/// change.
#[derive(Clone, Copy, PartialEq, Debug)]
#[repr(u8)]
pub enum NbtType {
    End,       // 0: terminates the sequence of named tags inside a compound
    Byte,      // 1
    Short,     // 2
    Int,       // 3
    Long,      // 4
    Float,     // 5
    Double,    // 6
    ByteArray, // 7
    Str,       // 8
    List,      // 9
    Compound,  // 10
    IntArray,  // 11
    LongArray, // 12
}
}
26
/// An in-memory NBT value.
///
/// Each variant corresponds to the [`NbtType`] of the same name, as reported by
/// [`Nbt::nbt_type`].
#[derive(PartialEq, Clone, Debug)]
pub enum Nbt {
    /// The `TAG_End` marker; carries no data. [`Nbt::parse`] also returns this when the
    /// top-level compound it reads is empty.
    End,
    Byte(i8),
    Short(i16),
    Int(i32),
    Long(i64),
    Float(f32),
    Double(f64),
    /// Raw bytes; prefixed on the wire with an `i32` length.
    ByteArray(Vec<u8>),
    /// Text; prefixed on the wire with a 16-bit byte length and decoded lossily as UTF-8.
    Str(String),
    /// A list whose elements are all expected to be of `item_type`.
    List {
        /// Written as the list's element type id. Serialization does not verify that each
        /// child actually has this type.
        item_type: NbtType,
        /// Elements, serialized as bare payloads (no type id or name per element).
        children: Vec<Nbt>,
    },
    /// Named children.
    Compound {
        /// Kept in insertion order, which is also the order they are serialized in.
        children: LinkedHashMap<String, Nbt>,
    },
    /// Prefixed on the wire with an `i32` element count.
    IntArray(Vec<i32>),
    /// Prefixed on the wire with an `i32` element count.
    LongArray(Vec<i64>),
}
48
49impl Nbt {
50    pub fn make_singleton_compound(key: String, value: Nbt) -> Nbt {
51        let mut children = LinkedHashMap::new();
52        children.insert(key, value);
53        Nbt::Compound { children }
54    }
55
56    pub fn make_compound(values: Vec<(String, Nbt)>) -> Nbt {
57        let mut children = LinkedHashMap::new();
58        for (key, value) in values {
59            children.insert(key, value);
60        }
61        Nbt::Compound { children }
62    }
63
64    pub fn nbt_type(&self) -> NbtType {
65        use Nbt::*;
66        match self {
67            End => NbtType::End,
68            Byte(..) => NbtType::Byte,
69            Short(..) => NbtType::Short,
70            Int(..) => NbtType::Int,
71            Long(..) => NbtType::Long,
72            Float(..) => NbtType::Float,
73            Double(..) => NbtType::Double,
74            ByteArray(..) => NbtType::ByteArray,
75            Str(..) => NbtType::Str,
76            List { .. } => NbtType::List,
77            Compound { .. } => NbtType::Compound,
78            IntArray(..) => NbtType::IntArray,
79            LongArray(..) => NbtType::LongArray,
80        }
81    }
82
83    pub fn parse(buf: &mut BytesMut) -> Result<Nbt> {
84        let direct_nbt = Nbt::parse_list(buf, NbtType::Compound)?;
85        match direct_nbt {
86            Nbt::Compound { children } if children.len() == 1 && children.contains_key("") => {
87                Ok(children[""].clone())
88            },
89            Nbt::Compound { children } if children.len() == 0 => {
90                Ok(Nbt::End)
91            },
92            _ => Ok(direct_nbt)
93        }
94    }
95
96    fn parse_item(buf: &mut BytesMut) -> Result<(Option<String>, Nbt)> {
97        let nbt_type: NbtType = if buf.len() == 0 {
98            NbtType::End
99        } else {
100            buf.get_mc_enum_u8()?
101        };
102        let name = match nbt_type {
103            NbtType::End => None,
104            _ => {
105                let name_length = buf.get_mc_u16()? as usize;
106                if buf.len() < name_length {
107                    return mcproto::invalid_data();
108                }
109                let bytes = buf.split_to(name_length).to_vec();
110                let name = &*String::from_utf8_lossy(&bytes);
111                Some(name.to_string())
112            }
113        };
114        Ok((name, Nbt::parse_list(buf, nbt_type)?))
115    }
116
117    fn parse_list(buf: &mut BytesMut, nbt_type: NbtType) -> Result<Nbt> {
118        use NbtType::*;
119        Ok(match nbt_type {
120            End => Nbt::End,
121            Byte => Nbt::Byte(buf.get_mc_i8()?),
122            Short => Nbt::Short(buf.get_mc_i16()?),
123            Int => Nbt::Int(buf.get_mc_i32()?),
124            Long => Nbt::Long(buf.get_mc_i64()?),
125            Float => Nbt::Float(buf.get_mc_f32()?),
126            Double => Nbt::Double(buf.get_mc_f64()?),
127            ByteArray => {
128                let length = buf.get_mc_i32()? as usize;
129                if length > buf.len() {
130                    return mcproto::invalid_data::<Nbt>();
131                }
132                Nbt::ByteArray(buf.split_to(length).to_vec())
133            }
134            Str => {
135                let string_length = buf.get_mc_u16()? as usize;
136                if buf.len() < string_length {
137                    return mcproto::invalid_data();
138                }
139                let bytes = buf.split_to(string_length).to_vec();
140                let string = &*String::from_utf8_lossy(&bytes);
141                Nbt::Str(string.to_string())
142            }
143            List => {
144                let item_type: NbtType = buf.get_mc_enum_u8()?;
145                let count = buf.get_mc_i32()? as usize;
146                let mut children: Vec<Nbt> = vec![];
147                for _ in 0..count {
148                    let item = Nbt::parse_list(buf, item_type)?;
149                    match item {
150                        Nbt::End => break, // should never happen
151                        _ => (),
152                    }
153                    children.push(item);
154                }
155                Nbt::List {
156                    item_type,
157                    children,
158                }
159            }
160            Compound => {
161                let mut children = LinkedHashMap::new();
162                loop {
163                    let (name, item) = Nbt::parse_item(buf)?;
164                    match item {
165                        Nbt::End => break,
166                        _ => (),
167                    }
168                    children.insert(
169                        name.expect("name not found and should have been")
170                            .to_string(),
171                        item,
172                    );
173                }
174                Nbt::Compound { children }
175            }
176            IntArray => {
177                let length = buf.get_mc_i32()? as usize;
178                if length * 4 > buf.len() {
179                    return mcproto::invalid_data();
180                }
181                Nbt::IntArray(buf.read_primitive_slice(length)?)
182            }
183            LongArray => {
184                let length = buf.get_mc_i32()? as usize;
185                if length * 8 > buf.len() {
186                    return mcproto::invalid_data();
187                }
188                Nbt::LongArray(buf.read_primitive_slice(length)?)
189            }
190        })
191    }
192
193    pub fn serialize(self, buf: &mut BytesMut) {
194        match self {
195            Nbt::Compound { .. } => (),
196            Nbt::End => {
197                buf.set_mc_u8(self.nbt_type() as u8);
198                return;
199            },
200            _ => panic!("attempted to serialize non-compound!"),
201        }
202        self.serialize_list(buf);
203    }
204
205    fn serialize_item(self, name: Option<&str>, buf: &mut BytesMut) {
206        buf.set_mc_u8(self.nbt_type() as u8);
207        match name {
208            Some(name) => {
209                let bytes = name.as_bytes();
210                buf.set_mc_i16(bytes.len() as i16);
211                buf.extend(bytes);
212            }
213            _ => (),
214        }
215        self.serialize_list(buf);
216    }
217
218    fn serialize_list(self, buf: &mut BytesMut) {
219        use Nbt::*;
220        match self {
221            End => {}
222            Byte(value) => {
223                buf.set_mc_i8(value);
224            }
225            Short(value) => {
226                buf.set_mc_i16(value);
227            }
228            Int(value) => {
229                buf.set_mc_i32(value);
230            }
231            Long(value) => {
232                buf.set_mc_i64(value);
233            }
234            Float(value) => {
235                buf.set_mc_f32(value);
236            }
237            Double(value) => {
238                buf.set_mc_f64(value);
239            }
240            ByteArray(value) => {
241                buf.set_mc_i32(value.len() as i32);
242                buf.extend(value);
243            }
244            Str(value) => {
245                let bytes = value.as_bytes();
246                buf.set_mc_i16(bytes.len() as i16);
247                buf.extend(bytes);
248            }
249            List {
250                item_type,
251                children,
252            } => {
253                buf.set_mc_u8(item_type as u8);
254                buf.set_mc_i32(children.len() as i32);
255                for child in children {
256                    child.serialize_list(buf);
257                }
258            }
259            Compound { children } => {
260                for (name, item) in children {
261                    item.serialize_item(Some(&*name), buf);
262                }
263                End.serialize_item(None, buf);
264            }
265            IntArray(value) => {
266                buf.set_mc_i32(value.len() as i32);
267                buf.write_primitive_slice(&value[..]);
268            }
269            LongArray(value) => {
270                buf.set_mc_i32(value.len() as i32);
271                buf.write_primitive_slice(&value[..]);
272            }
273        }
274    }
275
276    pub fn child(&self, key: &str) -> Result<&Nbt> {
277        match self {
278            Nbt::Compound { children } => {
279                children.get(key).map(|item| Ok(item)).unwrap_or(Err(basin_err!("could not find key {} in nbt", key)))
280            },
281            _ => Err(basin_err!("could not get key {} from non-compound nbt tag", key)),
282        }
283    }
284
285    pub fn unwrap_i8(&self) -> Result<i8> {
286        match self {
287            Nbt::Byte(value) => Ok(*value),
288            _ => Err(basin_err!("invalid nbt tag type, expected Byte got {:?}", self.nbt_type())),
289        }
290    }
291
292    pub fn unwrap_i16(&self) -> Result<i16> {
293        match self {
294            Nbt::Short(value) => Ok(*value),
295            _ => Err(basin_err!("invalid nbt tag type, expected Short got {:?}", self.nbt_type())),
296        }
297    }
298
299    pub fn unwrap_i32(&self) -> Result<i32> {
300        match self {
301            Nbt::Int(value) => Ok(*value),
302            _ => Err(basin_err!("invalid nbt tag type, expected Int got {:?}", self.nbt_type())),
303        }
304    }
305
306    pub fn unwrap_i64(&self) -> Result<i64> {
307        match self {
308            Nbt::Long(value) => Ok(*value),
309            _ => Err(basin_err!("invalid nbt tag type, expected Long got {:?}", self.nbt_type())),
310        }
311    }
312
313    pub fn unwrap_f32(&self) -> Result<f32> {
314        match self {
315            Nbt::Float(value) => Ok(*value),
316            _ => Err(basin_err!("invalid nbt tag type, expected Float got {:?}", self.nbt_type())),
317        }
318    }
319
320    pub fn unwrap_f64(&self) -> Result<f64> {
321        match self {
322            Nbt::Double(value) => Ok(*value),
323            _ => Err(basin_err!("invalid nbt tag type, expected Double got {:?}", self.nbt_type())),
324        }
325    }
326
327    pub fn unwrap_bytes(&self) -> Result<&[u8]> {
328        match self {
329            Nbt::ByteArray(value) => Ok(value),
330            _ => Err(basin_err!("invalid nbt tag type, expected ByteArray got {:?}", self.nbt_type())),
331        }
332    }
333
334    pub fn unwrap_str(&self) -> Result<&str> {
335        match self {
336            Nbt::Str(value) => Ok(value),
337            _ => Err(basin_err!("invalid nbt tag type, expected Str got {:?}", self.nbt_type())),
338        }
339    }
340
341    pub fn unwrap_compound(&self) -> Result<&LinkedHashMap<String, Nbt>> {
342        match self {
343            Nbt::Compound { children } => Ok(children),
344            _ => Err(basin_err!("invalid nbt tag type, expected Compound got {:?}", self.nbt_type())),
345        }
346    }
347
348    pub fn unwrap_list(&self) -> Result<&[Nbt]> {
349        match self {
350            Nbt::List { children, .. } => Ok(children),
351            _ => Err(basin_err!("invalid nbt tag type, expected List got {:?}", self.nbt_type())),
352        }
353    }
354
355    pub fn unwrap_ints(&self) -> Result<&[i32]> {
356        match self {
357            Nbt::IntArray(value) => Ok(value),
358            _ => Err(basin_err!("invalid nbt tag type, expected IntArray got {:?}", self.nbt_type())),
359        }
360    }
361
362    pub fn unwrap_longs(&self) -> Result<&[i64]> {
363        match self {
364            Nbt::LongArray(value) => Ok(value),
365            _ => Err(basin_err!("invalid nbt tag type, expected LongArray got {:?}", self.nbt_type())),
366        }
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `nbt`, parses it back, and checks that the value survives unchanged and that
    /// parsing consumed every byte serialization produced.
    fn cycle(nbt: Nbt) -> Result<()> {
        let mut buf = BytesMut::new();
        nbt.clone().serialize(&mut buf);
        let decoded = Nbt::parse(&mut buf)?;
        assert_eq!(nbt, decoded);
        assert!(buf.is_empty(), "parse left {} unread bytes", buf.len());
        Ok(())
    }

    /// Serializes `nbt` and returns whatever `Nbt::parse` makes of the bytes.
    fn reparse(nbt: Nbt) -> Result<Nbt> {
        let mut buf = BytesMut::new();
        nbt.serialize(&mut buf);
        Nbt::parse(&mut buf)
    }

    #[test]
    fn test_simple_compound() -> Result<()> {
        cycle(Nbt::make_compound(vec![
            ("byte".to_string(), Nbt::Byte(120)),
            ("short".to_string(), Nbt::Short(12000)),
            ("int".to_string(), Nbt::Int(43563456)),
            ("long".to_string(), Nbt::Long(435643563456)),
            ("float".to_string(), Nbt::Float(345345.345345)),
            ("double".to_string(), Nbt::Double(34532.53456)),
            (
                "byte_array".to_string(),
                Nbt::ByteArray(vec![0x0a, 0x0b, 0x0c]),
            ),
            ("str".to_string(), Nbt::Str("a string".to_string())),
            (
                "int_array".to_string(),
                Nbt::IntArray(vec![2345, 23453245, 3452345, 324523]),
            ),
            (
                "long_array".to_string(),
                Nbt::LongArray(vec![
                    0xffffffff,
                    345643564356,
                    43563456,
                    456456456,
                    456456456345,
                    56345634563456,
                ]),
            ),
        ]))
    }

    #[test]
    fn test_nested_compound() -> Result<()> {
        cycle(Nbt::make_compound(vec![
            (
                "nest1".to_string(),
                Nbt::make_compound(vec![("int".to_string(), Nbt::Int(43563456))]),
            ),
            ("tail_int".to_string(), Nbt::Int(43563456)),
        ]))
    }

    #[test]
    fn test_nested_list_compound() -> Result<()> {
        cycle(Nbt::make_compound(vec![(
            "list1".to_string(),
            Nbt::List {
                item_type: NbtType::Compound,
                children: vec![
                    Nbt::make_compound(vec![("int1".to_string(), Nbt::Int(43563456))]),
                    Nbt::make_compound(vec![("int2".to_string(), Nbt::Int(43563456))]),
                ],
            },
        )]))
    }

    #[test]
    fn test_end_round_trip() -> Result<()> {
        cycle(Nbt::End)
    }

    #[test]
    fn test_empty_compound_parses_as_end() -> Result<()> {
        assert_eq!(reparse(Nbt::make_compound(vec![]))?, Nbt::End);
        Ok(())
    }

    #[test]
    fn test_unnamed_root_is_unwrapped() -> Result<()> {
        let inner = Nbt::make_compound(vec![("int".to_string(), Nbt::Int(7))]);
        let wrapped = Nbt::make_singleton_compound(String::new(), inner.clone());
        assert_eq!(reparse(wrapped)?, inner);
        Ok(())
    }

    #[test]
    fn test_empty_list_and_non_ascii_string() -> Result<()> {
        cycle(Nbt::make_compound(vec![
            (
                "empty".to_string(),
                Nbt::List {
                    item_type: NbtType::End,
                    children: vec![],
                },
            ),
            ("text".to_string(), Nbt::Str("héllo ✓".to_string())),
        ]))
    }
}