//! MessagePack encoding and decoding over Tokio's async I/O traits (`messagepack_async/tokio.rs`).

use crate::{Ext, Float, Int, Value};

use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

6pub trait ReadFrom: Sized {
7    fn read_from<T: AsyncRead + Unpin>(
8        source: &mut T,
9    ) -> impl std::future::Future<Output = Result<Self>>;
10}
11
12pub trait WriteTo {
13    fn write_to<T: AsyncWrite + Unpin>(
14        &self,
15        sink: &mut T,
16    ) -> impl std::future::Future<Output = Result<()>>;
17}
18
19impl WriteTo for bool {
20    async fn write_to<T: AsyncWrite + Unpin>(&self, sink: &mut T) -> Result<()> {
21        let byte: u8 = match self {
22            false => 0xc2,
23            true => 0xc3,
24        };
25        sink.write_u8(byte).await
26    }
27}
28
29impl WriteTo for Int {
30    async fn write_to<T: AsyncWrite + Unpin>(&self, sink: &mut T) -> Result<()> {
31        match self {
32            Int::U8(i) => {
33                if *i >= 0b10000000 {
34                    sink.write_u8(0xcc).await?;
35                }
36                sink.write_u8(*i).await
37            }
38            Int::U16(i) => {
39                sink.write_u8(0xcd).await?;
40                sink.write_u16(*i).await
41            }
42            Int::U32(i) => {
43                sink.write_u8(0xce).await?;
44                sink.write_u32(*i).await
45            }
46            Int::U64(i) => {
47                sink.write_u8(0xcf).await?;
48                sink.write_u64(*i).await
49            }
50            Int::I8(i) => {
51                if *i > 0 || *i < -0b00100000 {
52                    sink.write_u8(0xd0).await?;
53                    sink.write_i8(*i).await
54                } else {
55                    let i: u8 = -*i as u8;
56                    sink.write_u8(0b11100000 | i).await
57                }
58            }
59            Int::I16(i) => {
60                sink.write_u8(0xd1).await?;
61                sink.write_i16(*i).await
62            }
63            Int::I32(i) => {
64                sink.write_u8(0xd2).await?;
65                sink.write_i32(*i).await
66            }
67            Int::I64(i) => {
68                sink.write_u8(0xd3).await?;
69                sink.write_i64(*i).await
70            }
71        }
72    }
73}
74
75impl WriteTo for Float {
76    async fn write_to<T: AsyncWrite + Unpin>(&self, sink: &mut T) -> Result<()> {
77        match self {
78            Float::F32(f) => {
79                sink.write_u8(0xca).await?;
80                sink.write_f32(*f).await?;
81            }
82            Float::F64(f) => {
83                sink.write_u8(0xcb).await?;
84                sink.write_f64(*f).await?;
85            }
86        }
87        Ok(())
88    }
89}
90
91impl WriteTo for Ext {
92    async fn write_to<T: AsyncWrite + Unpin>(&self, sink: &mut T) -> Result<()> {
93        let len = self.data.len();
94        match len {
95            1 => sink.write_u8(0xd4).await?,
96            2 => sink.write_u8(0xd5).await?,
97            4 => sink.write_u8(0xd6).await?,
98            8 => sink.write_u8(0xd7).await?,
99            16 => sink.write_u8(0xd8).await?,
100            i if i <= u8::MAX as usize => {
101                sink.write_u8(0xc7).await?;
102                sink.write_u8(i.try_into().unwrap()).await?;
103            }
104            i if i <= u16::MAX as usize => {
105                sink.write_u8(0xc8).await?;
106                sink.write_u16(i.try_into().unwrap()).await?;
107            }
108            i if i <= u32::MAX as usize => {
109                sink.write_u8(0xc9).await?;
110                sink.write_u32(i.try_into().unwrap()).await?;
111            }
112            _ => panic!(),
113        }
114        sink.write_u8(self.r#type).await?;
115        sink.write_all(self.data.as_slice()).await
116    }
117}
118
119impl ReadFrom for Value {
120    async fn read_from<T: AsyncRead + Unpin>(source: &mut T) -> Result<Self> {
121        let leading = source.read_u8().await?;
122        Ok(match leading {
123            // Nil
124            0xc0 => Value::Nil,
125            // Bools
126            0xc2 => Value::Bool(false),
127            0xc3 => Value::Bool(true),
128            // Ints
129            // Positive fixint
130            i if i < 0b10000000 => Value::Int(Int::U8(i)),
131            // Negative fixint
132            i if i >= 0b11100000 => Value::Int(Int::I8(-((i - 0b11100000) as i8))),
133            // Unsigned
134            0xcc => Value::Int(Int::U8(source.read_u8().await?)),
135            0xcd => Value::Int(Int::U16(source.read_u16().await?)),
136            0xce => Value::Int(Int::U32(source.read_u32().await?)),
137            0xcf => Value::Int(Int::U64(source.read_u64().await?)),
138            // Signed
139            0xd0 => Value::Int(Int::I8(source.read_i8().await?)),
140            0xd1 => Value::Int(Int::I16(source.read_i16().await?)),
141            0xd2 => Value::Int(Int::I32(source.read_i32().await?)),
142            0xd3 => Value::Int(Int::I64(source.read_i64().await?)),
143            // Floats
144            0xca => Value::Float(Float::F32(source.read_f32().await?)),
145            0xcb => Value::Float(Float::F64(source.read_f64().await?)),
146            // Strings
147            i if i & 0b11100000 == 0b10100000 => {
148                let len: usize = (i & 0b00011111).into();
149                let mut bytes: Vec<u8> = vec![0u8; len];
150                source.read_exact(&mut bytes).await?;
151                Value::Str(String::from_utf8(bytes).unwrap())
152            }
153            i @ (0xd9..=0xdb) => {
154                let len: usize = match i {
155                    0xd9 => source.read_u8().await?.into(),
156                    0xda => source.read_u16().await?.into(),
157                    0xdb => source.read_u32().await?.try_into().unwrap(),
158                    _ => panic!(),
159                };
160                let mut bytes: Vec<u8> = vec![0u8; len];
161                source.read_exact(&mut bytes).await?;
162                Value::Str(String::from_utf8(bytes).unwrap())
163            }
164            // Bin
165            i @ (0xc4..=0xc6) => {
166                let len: usize = match i {
167                    0xc4 => source.read_u8().await?.into(),
168                    0xc5 => source.read_u16().await?.into(),
169                    0xc6 => source.read_u32().await?.try_into().unwrap(),
170                    _ => panic!(),
171                };
172                let mut bytes: Vec<u8> = vec![0u8; len];
173                source.read_exact(&mut bytes).await?;
174                Value::Bin(bytes)
175            }
176            // Array
177            i if i & 0b11110000 == 0b10010000 => {
178                let len: usize = (i & 0b00001111).into();
179                let mut arr: Vec<Value> = vec![];
180                for _ in 0..len {
181                    arr.push(Box::pin(Value::read_from(&mut *source)).await?);
182                }
183                Value::Arr(arr)
184            }
185            i @ (0xdc | 0xdd) => {
186                let len: usize = match i {
187                    0xdc => source.read_u16().await?.into(),
188                    0xdd => source.read_u32().await?.try_into().unwrap(),
189                    _ => panic!(),
190                };
191
192                let mut arr: Vec<Value> = vec![];
193                for _ in 0..len {
194                    arr.push(Box::pin(Value::read_from(&mut *source)).await?);
195                }
196                Value::Arr(arr)
197            }
198            // Maps
199            i if (i & 0b11110000) == 0b10000000 => {
200                let len: usize = (i & 0b00001111).into();
201                let mut map = vec![];
202                for _ in 0..len {
203                    let key: Value = Box::pin(Value::read_from(&mut *source)).await?;
204                    let value: Value = Box::pin(Value::read_from(&mut *source)).await?;
205                    map.push((key, value));
206                }
207                Value::Map(map)
208            }
209            i @ (0xde | 0xdf) => {
210                let len: usize = match i {
211                    0xde => source.read_u16().await?.into(),
212                    0xdf => source.read_u32().await?.try_into().unwrap(),
213                    _ => panic!(),
214                };
215                let mut map = vec![];
216                for _ in 0..len {
217                    let key: Value = Box::pin(Value::read_from(&mut *source)).await?;
218                    let value: Value = Box::pin(Value::read_from(&mut *source)).await?;
219                    map.push((key, value));
220                }
221                Value::Map(map)
222            }
223            // Ext
224            i @ (0xd4 | 0xd5 | 0xd6 | 0xd7 | 0xd8 | 0xc7 | 0xc8 | 0xc9) => {
225                let len: usize = match i {
226                    0xd4 => 1,
227                    0xd5 => 2,
228                    0xd6 => 4,
229                    0xd7 => 8,
230                    0xd8 => 16,
231                    0xc7 => source.read_u8().await?.into(),
232                    0xc8 => source.read_u16().await?.into(),
233                    0xc9 => source.read_u32().await?.try_into().unwrap(),
234                    _ => panic!(),
235                };
236                let r#type = source.read_u8().await?;
237                let mut data = vec![0u8; len];
238                source.read_exact(&mut data).await?;
239                Value::Ext(Ext { r#type, data })
240            }
241            i => {
242                panic!("Whaaaa?: {i:x?}");
243            }
244        })
245    }
246}
247
248impl WriteTo for Value {
249    async fn write_to<T: AsyncWrite + Unpin>(&self, sink: &mut T) -> Result<()> {
250        match self {
251            Value::Nil => sink.write_u8(0xc0).await,
252            Value::Bool(b) => b.write_to(sink).await,
253            Value::Int(i) => i.write_to(sink).await,
254            Value::Float(f) => f.write_to(sink).await,
255            Value::Str(bytes) => {
256                let len = bytes.len();
257                if len < 0b100000 {
258                    let len: u8 = len.try_into().unwrap();
259                    sink.write_u8(0b10100000 | len).await?;
260                } else if let Ok(len) = TryInto::<u8>::try_into(len) {
261                    sink.write_u8(0xd9).await?;
262                    sink.write_u8(len).await?;
263                } else if let Ok(len) = TryInto::<u16>::try_into(len) {
264                    sink.write_u8(0xda).await?;
265                    sink.write_u16(len).await?;
266                } else if let Ok(len) = TryInto::<u32>::try_into(len) {
267                    sink.write_u8(0xdb).await?;
268                    sink.write_u32(len).await?;
269                } else {
270                    panic!()
271                }
272                sink.write_all(bytes.as_bytes()).await
273            }
274            Value::Bin(bytes) => {
275                let len = bytes.len();
276                if let Ok(len) = TryInto::<u8>::try_into(len) {
277                    sink.write_u8(0xc4).await?;
278                    sink.write_u8(len).await?;
279                } else if let Ok(len) = TryInto::<u16>::try_into(len) {
280                    sink.write_u8(0xc5).await?;
281                    sink.write_u16(len).await?;
282                } else if let Ok(len) = TryInto::<u32>::try_into(len) {
283                    sink.write_u8(0xc6).await?;
284                    sink.write_u32(len).await?;
285                }
286                sink.write_all(bytes).await
287            }
288            Value::Arr(arr) => {
289                let len = arr.len();
290                if len < 0b10000 {
291                    let len: u8 = len.try_into().unwrap();
292                    sink.write_u8(0b10010000 | len).await?;
293                } else if let Ok(len) = TryInto::<u16>::try_into(len) {
294                    sink.write_u8(0xdc).await?;
295                    sink.write_u16(len).await?;
296                } else if let Ok(len) = TryInto::<u32>::try_into(len) {
297                    sink.write_u8(0xdd).await?;
298                    sink.write_u32(len).await?;
299                } else {
300                    panic!();
301                }
302                for v in arr {
303                    Box::pin(v.write_to(&mut *sink)).await?;
304                }
305                Ok(())
306            }
307            Value::Map(map) => {
308                let len = map.len();
309                if len < 0b10000 {
310                    let len: u8 = len.try_into().unwrap();
311                    sink.write_u8(0b10000000 | len).await?;
312                } else if let Ok(len) = TryInto::<u16>::try_into(len) {
313                    sink.write_u8(0xde).await?;
314                    sink.write_u16(len).await?;
315                } else if let Ok(len) = TryInto::<u32>::try_into(len) {
316                    sink.write_u8(0xdf).await?;
317                    sink.write_u32(len).await?;
318                } else {
319                    panic!();
320                }
321                for (k, v) in map {
322                    Box::pin(k.write_to(&mut *sink)).await?;
323                    Box::pin(v.write_to(&mut *sink)).await?;
324                }
325                Ok(())
326            }
327            Value::Ext(e) => e.write_to(sink).await,
328        }
329    }
330}
331
332#[cfg(test)]
333mod tests {
334    use super::*;
335    async fn assert_read_write(value: Value, bytes: &[u8]) {
336        // Write
337        let mut candidate_bytes: Vec<u8> = vec![];
338        value.write_to(&mut candidate_bytes).await.unwrap();
339        assert_eq!(bytes, candidate_bytes);
340        // Read
341        let mut cursor = bytes;
342        let candidate_value = Value::read_from(&mut cursor).await.unwrap();
343        assert_eq!(value, candidate_value)
344    }
345
346    #[tokio::test]
347    async fn nil_read_write() {
348        assert_read_write(Value::Nil, &[0xc0]).await;
349    }
350
351    #[tokio::test]
352    async fn bool_read_write() {
353        assert_read_write(Value::Bool(false), &[0xc2]).await;
354        assert_read_write(Value::Bool(true), &[0xc3]).await;
355    }
356
357    #[tokio::test]
358    async fn int_read_write() {
359        // Fixint (positive)
360        assert_read_write(Value::Int(Int::U8(0)), &[0]).await;
361        assert_read_write(Value::Int(Int::U8(5)), &[5]).await;
362        // Uint 8
363        assert_read_write(Value::Int(Int::U8(230)), &[0xcc, 230]).await;
364        // Uint 16
365        assert_read_write(Value::Int(Int::U16(256)), &[0xcd, 1, 0]).await;
366        // Uint 32
367        assert_read_write(Value::Int(Int::U32(65_536)), &[0xce, 0, 1, 0, 0]).await;
368        // Uint 64
369        assert_read_write(
370            Value::Int(Int::U64(4_294_967_296)),
371            &[0xcf, 0, 0, 0, 1, 0, 0, 0, 0],
372        )
373        .await;
374        // Fixint (negative)
375        assert_read_write(Value::Int(Int::I8(-6)), &[0b11100000 + 6]).await;
376        // Int 8
377        assert_read_write(Value::Int(Int::I8(-100)), &[0xd0, u8::MAX - 100 + 1]).await;
378        // Int 16
379        assert_read_write(
380            Value::Int(Int::I16(-100)),
381            &[0xd1, u8::MAX, u8::MAX - 100 + 1],
382        )
383        .await;
384        // Int 32
385        assert_read_write(
386            Value::Int(Int::I32(-100)),
387            &[0xd2, u8::MAX, u8::MAX, u8::MAX, u8::MAX - 100 + 1],
388        )
389        .await;
390        // Int 64
391        assert_read_write(
392            Value::Int(Int::I64(-100)),
393            &[
394                0xd3,
395                u8::MAX,
396                u8::MAX,
397                u8::MAX,
398                u8::MAX,
399                u8::MAX,
400                u8::MAX,
401                u8::MAX,
402                u8::MAX - 100 + 1,
403            ],
404        )
405        .await;
406    }
407
408    #[tokio::test]
409    async fn float_read_write() {
410        // F32
411        assert_read_write(
412            Value::Float(Float::F32(0.3)),
413            &[0xca, 0x3e, 0x99, 0x99, 0x9a],
414        )
415        .await;
416        // F64
417        assert_read_write(
418            Value::Float(Float::F64(0.3)),
419            &[0xcb, 0x3f, 0xd3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33],
420        )
421        .await;
422    }
423
424    #[tokio::test]
425    async fn str_bin_read_write() {
426        let bytes: Vec<u8> = b"Oh hello there, isn't this an interesting test".into();
427        let len: u8 = bytes.len().try_into().unwrap();
428        assert_read_write(
429            Value::Str(String::from_utf8(bytes.clone()).unwrap()),
430            &[[0xd9, len].as_slice(), bytes.as_slice()].concat(),
431        )
432        .await;
433        assert_read_write(
434            Value::Bin(bytes.clone()),
435            &[[0xc4, len].as_slice(), bytes.as_slice()].concat(),
436        )
437        .await;
438    }
439
440    #[tokio::test]
441    async fn arr_read_write() {
442        assert_read_write(
443            Value::Arr(vec![Value::Int(Int::U8(1)), Value::Int(Int::U8(2))]),
444            &[0b10010000 | 2, 1, 2],
445        )
446        .await;
447    }
448
449    #[tokio::test]
450    async fn map_read_write() {
451        assert_read_write(
452            Value::Map(vec![
453                (Value::Int(Int::U8(1)), Value::Int(Int::U8(10))),
454                (Value::Int(Int::U8(2)), Value::Int(Int::U8(20))),
455            ]),
456            &[0b10000000 | 2, 1, 10, 2, 20],
457        )
458        .await;
459        assert_read_write(Value::Map(vec![]), &[128]).await;
460    }
461
462    #[tokio::test]
463    async fn ext_read_write() {
464        assert_read_write(
465            Value::Ext(Ext {
466                r#type: 0xa8,
467                data: b"Hi there".into(),
468            }),
469            &[[0xd7, 0xa8].as_slice(), b"Hi there".as_slice()].concat(),
470        )
471        .await;
472    }
473
474    #[tokio::test]
475    async fn compound_read_write() {
476        assert_read_write(
477            Value::Arr(
478                [
479                    Value::Int(Int::U8(0)),
480                    Value::Int(Int::U32(0xf264e4f)),
481                    Value::Str("nvim_subscribe".into()),
482                    Value::Arr([Value::Str("rsnote_open_window".into())].into()),
483                ]
484                .into(),
485            ),
486            &[
487                148, 0, 206, 15, 38, 78, 79, 174, 110, 118, 105, 109, 95, 115, 117, 98, 115, 99,
488                114, 105, 98, 101, 145, 178, 114, 115, 110, 111, 116, 101, 95, 111, 112, 101, 110,
489                95, 119, 105, 110, 100, 111, 119,
490            ],
491        )
492        .await;
493        assert_read_write(
494            Value::Arr(
495                [
496                    Value::Int(Int::I8(0)),
497                    Value::Int(Int::U32(0x3f0c4a25)),
498                    Value::Str("nvim_buf_set_keymap".into()),
499                    Value::Arr(
500                        [
501                            Value::Ext(Ext {
502                                r#type: 0,
503                                data: [0xcd, 1, 0x2f].into(),
504                            }),
505                            Value::Str("n".into()),
506                            Value::Str("<ESC>".into()),
507                            Value::Str("<Cmd>w! .tasks<CR><Cmd>q<CR>".into()),
508                            Value::Map([].into()),
509                        ]
510                        .into(),
511                    ),
512                ]
513                .into(),
514            ),
515            &[
516                148, 224, 206, 63, 12, 74, 37, 179, 110, 118, 105, 109, 95, 98, 117, 102, 95, 115,
517                101, 116, 95, 107, 101, 121, 109, 97, 112, 149, 199, 3, 0, 205, 1, 47, 161, 110,
518                165, 60, 69, 83, 67, 62, 188, 60, 67, 109, 100, 62, 119, 33, 32, 46, 116, 97, 115,
519                107, 115, 60, 67, 82, 62, 60, 67, 109, 100, 62, 113, 60, 67, 82, 62, 128,
520            ],
521        )
522        .await;
523        assert_read_write(
524            Value::Arr(
525                [
526                    Value::Int(Int::U8(0)),
527                    Value::Int(Int::U32(0xc032e486)),
528                    Value::Str("nvim_exec_lua".into()),
529                    Value::Arr(
530                        [
531                            Value::Str("return vim.o.columns".into()),
532                            Value::Arr([].into()),
533                        ]
534                        .into(),
535                    ),
536                ]
537                .into(),
538            ),
539            &[
540                148, 0, 206, 192, 50, 228, 134, 173, 110, 118, 105, 109, 95, 101, 120, 101, 99, 95,
541                108, 117, 97, 146, 180, 114, 101, 116, 117, 114, 110, 32, 118, 105, 109, 46, 111,
542                46, 99, 111, 108, 117, 109, 110, 115, 144,
543            ],
544        )
545        .await;
546        assert_read_write(
547            Value::Arr(
548                [
549                    Value::Int(Int::U8(0)),
550                    Value::Int(Int::U32(0x49d10de0)),
551                    Value::Str("nvim_buf_set_lines".into()),
552                    Value::Arr(
553                        [
554                            Value::Ext(Ext {
555                                r#type: 0,
556                                data: [25].into(),
557                            }),
558                            Value::Int(Int::U8(0)),
559                            Value::Int(Int::U8(1)),
560                            Value::Bool(false),
561                            Value::Arr(
562                                [
563                                    Value::Str("Add the messagepack-rs module".into()),
564                                    Value::Str(
565                                        "Delete the
566 local code"
567                                            .into(),
568                                    ),
569                                    Value::Str("Refactor to use the messagepack library".into()),
570                                    Value::Str("".into()),
571                                ]
572                                .into(),
573                            ),
574                        ]
575                        .into(),
576                    ),
577                ]
578                .into(),
579            ),
580            &[
581                148, 0, 206, 73, 209, 13, 224, 178, 110, 118, 105, 109, 95, 98, 117, 102, 95, 115,
582                101, 116, 95, 108, 105, 110, 101, 115, 149, 212, 0, 25, 0, 1, 194, 148, 189, 65,
583                100, 100, 32, 116, 104, 101, 32, 109, 101, 115, 115, 97, 103, 101, 112, 97, 99,
584                107, 45, 114, 115, 32, 109, 111, 100, 117, 108, 101, 182, 68, 101, 108, 101, 116,
585                101, 32, 116, 104, 101, 10, 32, 108, 111, 99, 97, 108, 32, 99, 111, 100, 101, 217,
586                39, 82, 101, 102, 97, 99, 116, 111, 114, 32, 116, 111, 32, 117, 115, 101, 32, 116,
587                104, 101, 32, 109, 101, 115, 115, 97, 103, 101, 112, 97, 99, 107, 32, 108, 105, 98,
588                114, 97, 114, 121, 160,
589            ],
590        )
591        .await;
592    }
593}