1use kevy_resp::Argv;
22
/// A successfully parsed `REPLICATE FROM <offset> ID <replica-id>` request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HandshakeReq {
    /// Offset parsed from the argument that follows the `FROM` keyword.
    pub from_offset: u64,
    /// Identifier taken from the argument that follows the `ID` keyword
    /// (non-empty UTF-8 when produced by `parse_replicate_from`).
    pub replica_id: String,
}
32
/// Reasons a `REPLICATE FROM <offset> ID <replica-id>` handshake is rejected.
///
/// Every variant is either a unit or holds a plain `usize`, so the type is
/// `Copy` and can be freely duplicated into replies, logs, or metrics keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HandshakeError {
    /// The first argument was not `REPLICATE` (compared ASCII case-insensitively).
    BadCommand,
    /// The command did not have exactly five arguments; holds the count received.
    WrongArity(usize),
    /// The second argument was not the `FROM` keyword.
    BadFromKeyword,
    /// The offset argument was not an unsigned decimal that fits in a `u64`.
    BadOffset,
    /// The fourth argument was not the `ID` keyword.
    BadIdKeyword,
    /// The replica id argument was empty or not valid UTF-8.
    BadReplicaId,
}
49
50impl std::fmt::Display for HandshakeError {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 Self::BadCommand => write!(f, "expected REPLICATE command"),
54 Self::WrongArity(n) => write!(f, "REPLICATE expects 5 args, got {n}"),
55 Self::BadFromKeyword => write!(f, "expected 'FROM' keyword"),
56 Self::BadOffset => write!(f, "from-offset must be an unsigned decimal"),
57 Self::BadIdKeyword => write!(f, "expected 'ID' keyword"),
58 Self::BadReplicaId => write!(f, "replica id must be non-empty UTF-8"),
59 }
60 }
61}
62
63impl std::error::Error for HandshakeError {}
64
65pub fn parse_replicate_from(argv: &Argv) -> Result<HandshakeReq, HandshakeError> {
69 if argv.len() != 5 {
70 return Err(HandshakeError::WrongArity(argv.len()));
71 }
72 if !eq_ascii_ci(argv.get(0).unwrap(), b"REPLICATE") {
73 return Err(HandshakeError::BadCommand);
74 }
75 if !eq_ascii_ci(argv.get(1).unwrap(), b"FROM") {
76 return Err(HandshakeError::BadFromKeyword);
77 }
78 let from_offset =
79 parse_decimal_u64(argv.get(2).unwrap()).ok_or(HandshakeError::BadOffset)?;
80 if !eq_ascii_ci(argv.get(3).unwrap(), b"ID") {
81 return Err(HandshakeError::BadIdKeyword);
82 }
83 let id_bytes = argv.get(4).unwrap();
84 if id_bytes.is_empty() {
85 return Err(HandshakeError::BadReplicaId);
86 }
87 let replica_id =
88 std::str::from_utf8(id_bytes).map_err(|_| HandshakeError::BadReplicaId)?.to_string();
89 Ok(HandshakeReq {
90 from_offset,
91 replica_id,
92 })
93}
94
95pub fn encode_ack(current_offset: u64) -> Vec<u8> {
97 let mut out = Vec::with_capacity(8 + 20 + 2);
98 out.extend_from_slice(b"+ACK ");
99 push_u64(&mut out, current_offset);
100 out.extend_from_slice(b"\r\n");
101 out
102}
103
/// Returns whether two byte strings are equal, ignoring ASCII case.
///
/// Slices of different lengths are never equal, and non-ASCII bytes must
/// match exactly. This delegates to the standard library's
/// `<[u8]>::eq_ignore_ascii_case` instead of re-implementing it.
fn eq_ascii_ci(a: &[u8], b: &[u8]) -> bool {
    a.eq_ignore_ascii_case(b)
}
110
/// Parses an unsigned base-10 integer from raw bytes.
///
/// Only ASCII digits are accepted: no sign, whitespace, or separators. This is
/// intentionally stricter than `str::parse::<u64>`, which would accept a
/// leading `+`. Returns `None` for empty input, any non-digit byte, or a value
/// greater than `u64::MAX`.
fn parse_decimal_u64(bytes: &[u8]) -> Option<u64> {
    match bytes {
        [] => None,
        digits => digits.iter().try_fold(0u64, |acc, &byte| {
            if byte.is_ascii_digit() {
                acc.checked_mul(10)?.checked_add(u64::from(byte - b'0'))
            } else {
                None
            }
        }),
    }
}
124
/// Appends the decimal representation of `n` to `out`, without building an
/// intermediate `String`.
fn push_u64(out: &mut Vec<u8>, n: u64) {
    // 20 bytes holds u64::MAX (18446744073709551615).
    let mut digits = [0u8; 20];
    let mut start = digits.len();
    let mut rest = n;
    // Emit at least one digit, so zero needs no special case.
    loop {
        start -= 1;
        // `rest % 10` is always < 10, so the narrowing cast cannot truncate.
        digits[start] = b'0' + (rest % 10) as u8;
        rest /= 10;
        if rest == 0 {
            break;
        }
    }
    out.extend_from_slice(&digits[start..]);
}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`Argv`] from raw argument byte strings.
    fn argv(args: &[&[u8]]) -> Argv {
        let mut a = Argv::default();
        for arg in args {
            a.push(arg);
        }
        a
    }

    #[test]
    fn parses_fresh_replica_from_zero() {
        let req = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"0",
            b"ID",
            b"replica-a",
        ]))
        .unwrap();
        assert_eq!(req.from_offset, 0);
        assert_eq!(req.replica_id, "replica-a");
    }

    #[test]
    fn parses_reconnect_with_large_offset() {
        let req = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"4294967296",
            b"ID",
            b"node-7",
        ]))
        .unwrap();
        assert_eq!(req.from_offset, 4_294_967_296);
        assert_eq!(req.replica_id, "node-7");
    }

    #[test]
    fn parses_max_u64_offset() {
        let req = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"18446744073709551615",
            b"ID",
            b"r",
        ]))
        .unwrap();
        assert_eq!(req.from_offset, u64::MAX);
    }

    #[test]
    fn keywords_are_case_insensitive() {
        let req =
            parse_replicate_from(&argv(&[b"replicate", b"from", b"1", b"id", b"x"])).unwrap();
        assert_eq!(req.from_offset, 1);
        assert_eq!(req.replica_id, "x");
    }

    #[test]
    fn wrong_arity_rejected_with_actual_count() {
        let err = parse_replicate_from(&argv(&[b"REPLICATE", b"FROM", b"0"])).unwrap_err();
        assert_eq!(err, HandshakeError::WrongArity(3));
    }

    #[test]
    fn empty_argv_rejected_as_wrong_arity() {
        let err = parse_replicate_from(&argv(&[])).unwrap_err();
        assert_eq!(err, HandshakeError::WrongArity(0));
    }

    #[test]
    fn too_many_args_rejected_with_actual_count() {
        let err = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"0",
            b"ID",
            b"a",
            b"extra",
        ]))
        .unwrap_err();
        assert_eq!(err, HandshakeError::WrongArity(6));
    }

    #[test]
    fn wrong_command_rejected() {
        let err =
            parse_replicate_from(&argv(&[b"SUBSCRIBE", b"FROM", b"0", b"ID", b"a"])).unwrap_err();
        assert_eq!(err, HandshakeError::BadCommand);
    }

    #[test]
    fn command_with_extra_characters_rejected() {
        let err = parse_replicate_from(&argv(&[b"REPLICATEX", b"FROM", b"0", b"ID", b"a"]))
            .unwrap_err();
        assert_eq!(err, HandshakeError::BadCommand);
    }

    #[test]
    fn wrong_from_keyword_rejected() {
        let err =
            parse_replicate_from(&argv(&[b"REPLICATE", b"AT", b"0", b"ID", b"a"])).unwrap_err();
        assert_eq!(err, HandshakeError::BadFromKeyword);
    }

    #[test]
    fn wrong_id_keyword_rejected() {
        let err = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"0",
            b"NAME",
            b"a",
        ]))
        .unwrap_err();
        assert_eq!(err, HandshakeError::BadIdKeyword);
    }

    #[test]
    fn non_decimal_offset_rejected() {
        let err = parse_replicate_from(&argv(&[b"REPLICATE", b"FROM", b"NaN", b"ID", b"a"]))
            .unwrap_err();
        assert_eq!(err, HandshakeError::BadOffset);
    }

    #[test]
    fn negative_offset_rejected_as_bad_offset() {
        let err = parse_replicate_from(&argv(&[b"REPLICATE", b"FROM", b"-1", b"ID", b"a"]))
            .unwrap_err();
        assert_eq!(err, HandshakeError::BadOffset);
    }

    #[test]
    fn overflowing_offset_rejected_as_bad_offset() {
        let err = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"18446744073709551616",
            b"ID",
            b"r",
        ]))
        .unwrap_err();
        assert_eq!(err, HandshakeError::BadOffset);
    }

    #[test]
    fn signed_empty_or_padded_offset_rejected() {
        // `str::parse::<u64>` would accept "+1"; the handshake must not.
        let offsets: [&[u8]; 3] = [b"+1", b"", b" 1"];
        for &offset in offsets.iter() {
            let err = parse_replicate_from(&argv(&[b"REPLICATE", b"FROM", offset, b"ID", b"r"]))
                .unwrap_err();
            assert_eq!(err, HandshakeError::BadOffset);
        }
    }

    #[test]
    fn empty_replica_id_rejected() {
        let err = parse_replicate_from(&argv(&[b"REPLICATE", b"FROM", b"0", b"ID", b""]))
            .unwrap_err();
        assert_eq!(err, HandshakeError::BadReplicaId);
    }

    #[test]
    fn non_utf8_replica_id_rejected() {
        let err = parse_replicate_from(&argv(&[
            b"REPLICATE",
            b"FROM",
            b"0",
            b"ID",
            &[0xFF, 0xFE, 0xFD],
        ]))
        .unwrap_err();
        assert_eq!(err, HandshakeError::BadReplicaId);
    }

    #[test]
    fn ack_format_for_zero() {
        assert_eq!(encode_ack(0), b"+ACK 0\r\n");
    }

    #[test]
    fn ack_format_for_large_offset() {
        assert_eq!(encode_ack(987_654_321), b"+ACK 987654321\r\n");
    }

    #[test]
    fn ack_format_for_max_offset() {
        assert_eq!(encode_ack(u64::MAX), b"+ACK 18446744073709551615\r\n");
    }

    #[test]
    fn error_messages_are_stable() {
        assert_eq!(
            HandshakeError::WrongArity(3).to_string(),
            "REPLICATE expects 5 args, got 3"
        );
        assert_eq!(
            HandshakeError::BadOffset.to_string(),
            "from-offset must be an unsigned decimal"
        );
        assert_eq!(
            HandshakeError::BadReplicaId.to_string(),
            "replica id must be non-empty UTF-8"
        );
    }
}