// kafka_wire_protocol/strings.rs

1use std::io::{Error, ErrorKind, Read, Result, Write};
2use crate::readable_writable::{Readable, Writable};
3use crate::utils::{read_len_i16, write_len_i16};
4
5impl Readable for String {
6    fn read(#[allow(unused)] input: &mut impl Read) -> Result<Self> {
7        unimplemented!()
8    }
9
10    #[inline]
11    fn read_ext(input: &mut impl Read, field_name: &str, compact: bool) -> Result<Self> {
12        let len = read_len_i16(input, invalid_len_message(field_name), compact)?;
13        if len < 0 {
14            Err(Error::new(
15                ErrorKind::Other,
16                format!("non-nullable field {field_name} was serialized as null"),
17            ))
18        } else {
19            read_string(input, len)
20        }
21    }
22}
23
24impl Writable for String {
25    fn write(&self, #[allow(unused)] output: &mut impl Write) -> Result<()> {
26        unimplemented!()
27    }
28
29    #[inline]
30    fn write_ext(&self, output: &mut impl Write, field_name: &str, compact: bool) -> Result<()> {
31        let len = self.len();
32        if len > i16::MAX as usize {
33            Err(Error::new(ErrorKind::Other, invalid_len_message(field_name)(len as i64)))
34        } else {
35            write_len_i16(output, invalid_len_message(field_name), len as i16, compact)?;
36            output.write(self.as_bytes()).map(|_| ())
37        }
38    }
39}
40
41impl Readable for Option<String> {
42    fn read(#[allow(unused)] input: &mut impl Read) -> Result<Self> {
43        unimplemented!()
44    }
45
46    #[inline]
47    fn read_ext(input: &mut impl Read, field_name: &str, compact: bool) -> Result<Self> {
48        let len = read_len_i16(input, invalid_len_message(field_name), compact)?;
49        if len < 0 {
50            Ok(None)
51        } else {
52            read_string(input, len).map(Some)
53        }
54    }
55}
56
57impl Writable for Option<String> {
58    fn write(&self, #[allow(unused)] output: &mut impl Write) -> Result<()> {
59        unimplemented!()
60    }
61
62    #[inline]
63    fn write_ext(&self, output: &mut impl Write, field_name: &str, compact: bool) -> Result<()> {
64        if let Some(string) = self {
65            string.write_ext(output, field_name, compact)
66        } else {
67            write_len_i16(output, invalid_len_message(field_name), -1, compact)
68        }
69    }
70}
71
/// Reads exactly `str_len` bytes from `input` and decodes them as UTF-8.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD rather than rejected. Valid input is moved
/// into the returned `String` without an extra copy.
///
/// # Errors
/// - `ErrorKind::InvalidData` if `str_len` is negative. Callers are expected to have handled
///   the null marker before calling this.
/// - Any error from `read_exact`, e.g. `UnexpectedEof` when fewer than `str_len` bytes remain.
#[inline]
fn read_string(input: &mut impl Read, str_len: i16) -> Result<String> {
    // A plain `as usize` cast would turn a negative length into `usize::MAX` and panic in `vec!`.
    let len = usize::try_from(str_len).map_err(|_| {
        Error::new(
            ErrorKind::InvalidData,
            format!("string length {str_len} is negative"),
        )
    })?;
    let mut buf = vec![0_u8; len];
    input.read_exact(&mut buf)?;
    Ok(match String::from_utf8(buf) {
        Ok(string) => string,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    })
}
78
/// Builds a lazily evaluated error message for a string field whose length is invalid.
///
/// The returned closure borrows `field_name` instead of copying it. Nothing is allocated unless
/// the message is actually produced, which only happens on the error path.
#[inline]
fn invalid_len_message(field_name: &str) -> impl FnOnce(i64) -> String + '_ {
    move |len| format!("string field {field_name} had invalid length {len}")
}
86
#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use std::io::{Cursor, Seek, SeekFrom};

    use byteorder::{BigEndian, WriteBytesExt};
    use proptest::prelude::*;
    use rstest::rstest;
    use varint_rs::VarintWriter;

    use super::*;

    /// Writes `original_data` with `write_ext`, reads it back with `read_ext`, and asserts that
    /// the round trip returns the same value.
    fn check_serde<T>(original_data: T, compact: bool)
    where
        T: Readable + Writable + PartialEq + Debug,
    {
        let mut cur = Cursor::new(Vec::<u8>::new());
        original_data.write_ext(&mut cur, "test", compact).unwrap();

        cur.seek(SeekFrom::Start(0)).unwrap();
        let read_data = T::read_ext(&mut cur, "test", compact).unwrap();

        assert_eq!(read_data, original_data);
    }

    #[rstest]
    #[case(None, false)]
    #[case(None, true)]
    #[case(Some("".to_string()), false)]
    #[case(Some("".to_string()), true)]
    #[case(Some("aaa".to_string()), false)]
    #[case(Some("aaa".to_string()), true)]
    fn test_serde_nullable(#[case] original_data: Option<String>, #[case] compact: bool) {
        check_serde(original_data, compact);
    }

    proptest! {
        #[test]
        fn test_prop_serde_nullable_non_compact(original_data: Option<String>) {
            check_serde(original_data, false);
        }

        #[test]
        fn test_prop_serde_nullable_compact(original_data: Option<String>) {
            check_serde(original_data, true);
        }
    }

    #[rstest]
    #[case("".to_string(), false)]
    #[case("".to_string(), true)]
    #[case("aaa".to_string(), false)]
    #[case("aaa".to_string(), true)]
    fn test_serde_non_nullable(#[case] original_data: String, #[case] compact: bool) {
        check_serde(original_data, compact);
    }

    proptest! {
        #[test]
        fn test_prop_serde_non_nullable_non_compact(original_data: String) {
            check_serde(original_data, false);
        }

        #[test]
        fn test_prop_serde_non_nullable_compact(original_data: String) {
            check_serde(original_data, true);
        }
    }

    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_write_long_string_non_nullable(#[case] compact: bool) {
        let long_string = "a".repeat(i16::MAX as usize + 1);
        let mut cur = Cursor::new(Vec::<u8>::new());
        let error = long_string.write_ext(&mut cur, "test", compact)
            .expect_err("must be error");
        assert_eq!(error.to_string(), "string field test had invalid length 32768");
    }

    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_write_long_string_nullable(#[case] compact: bool) {
        let long_string = "a".repeat(i16::MAX as usize + 1);
        let mut cur = Cursor::new(Vec::<u8>::new());
        let error = Some(long_string).write_ext(&mut cur, "test", compact)
            .expect_err("must be error");
        assert_eq!(error.to_string(), "string field test had invalid length 32768");
    }

    #[test]
    fn test_read_null_string_non_nullable_non_compact() {
        let mut cur = Cursor::new(Vec::<u8>::new());
        cur.write_i16::<BigEndian>(-1).unwrap();
        cur.seek(SeekFrom::Start(0)).unwrap();
        let error = String::read_ext(&mut cur, "test", false)
            .expect_err("must be error");
        assert_eq!(error.to_string(), "non-nullable field test was serialized as null");
    }

    #[test]
    fn test_read_null_string_non_nullable_compact() {
        let mut cur = Cursor::new(Vec::<u8>::new());
        cur.write_u32_varint(0).unwrap();
        cur.seek(SeekFrom::Start(0)).unwrap();
        let error = String::read_ext(&mut cur, "test", true)
            .expect_err("must be error");
        assert_eq!(error.to_string(), "non-nullable field test was serialized as null");
    }

    // The long-string read tests below exist only for the compact form. The non-compact length
    // prefix is an `i16`, so a length greater than `i16::MAX` cannot be encoded at all.

    #[test]
    fn test_read_long_string_non_nullable_compact() {
        let mut cur = Cursor::new(Vec::<u8>::new());
        cur.write_u32_varint(i16::MAX as u32 + 2).unwrap();
        cur.seek(SeekFrom::Start(0)).unwrap();
        let error = String::read_ext(&mut cur, "test", true)
            .expect_err("must be error");
        assert_eq!(error.to_string(), "string field test had invalid length 32768");
    }

    #[test]
    fn test_read_long_string_nullable_compact() {
        let mut cur = Cursor::new(Vec::<u8>::new());
        cur.write_u32_varint(i16::MAX as u32 + 2).unwrap();
        cur.seek(SeekFrom::Start(0)).unwrap();
        let error = Option::<String>::read_ext(&mut cur, "test", true)
            .expect_err("must be error");
        assert_eq!(error.to_string(), "string field test had invalid length 32768");
    }

    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_read_invalid_utf8_is_replaced(#[case] compact: bool) {
        let mut cur = Cursor::new(Vec::<u8>::new());
        if compact {
            // Compact prefixes store `len + 1` (0 is null), so 2 encodes a 1-byte string.
            cur.write_u32_varint(2).unwrap();
        } else {
            cur.write_i16::<BigEndian>(1).unwrap();
        }
        cur.write_u8(0xff).unwrap();
        cur.seek(SeekFrom::Start(0)).unwrap();
        let read_data = String::read_ext(&mut cur, "test", compact).unwrap();
        assert_eq!(read_data, "\u{FFFD}");
    }
}