1use std::io::{ErrorKind, Read, Write};
8
9use anyhow::{Result, bail};
10
11use crate::confirm::SearchOptions;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum Request {
15 Search {
17 opts: SearchOptions,
18 pattern: String,
19 },
20 Find {
23 needle: String,
24 after: Option<String>,
25 limit: u32,
26 },
27 Status,
29 Watch,
32 Shutdown,
34 CursorStore { blob: Vec<u8> },
36 CursorTake { token: String },
39}
40
41pub(crate) fn pack_opts(o: &SearchOptions) -> u8 {
42 (o.case_insensitive as u8)
43 | ((o.multi_line as u8) << 1)
44 | ((o.dot_matches_new_line as u8) << 2)
45 | ((o.word as u8) << 3)
46 | ((o.fixed_strings as u8) << 4)
47}
48
49pub(crate) fn unpack_opts(b: u8, before: u32, after: u32) -> SearchOptions {
50 SearchOptions {
51 case_insensitive: b & 1 != 0,
52 multi_line: b & 2 != 0,
53 dot_matches_new_line: b & 4 != 0,
54 word: b & 8 != 0,
55 fixed_strings: b & 16 != 0,
56 before_context: before as usize,
57 after_context: after as usize,
58 }
59}
60
61pub fn write_request(w: &mut impl Write, req: &Request) -> Result<()> {
62 let mut body = Vec::new();
63 match req {
64 Request::Search { opts, pattern } => {
65 body.push(b'S');
66 body.push(pack_opts(opts));
67 body.extend_from_slice(&(opts.before_context as u32).to_le_bytes());
68 body.extend_from_slice(&(opts.after_context as u32).to_le_bytes());
69 put_bytes(&mut body, pattern.as_bytes());
70 }
71 Request::Find {
72 needle,
73 after,
74 limit,
75 } => {
76 body.push(b'F');
77 body.extend_from_slice(&limit.to_le_bytes());
78 put_bytes(&mut body, needle.as_bytes());
79 put_bytes(&mut body, after.as_deref().unwrap_or("").as_bytes());
80 }
81 Request::Status => body.push(b'T'),
82 Request::Watch => body.push(b'W'),
83 Request::Shutdown => body.push(b'Q'),
84 Request::CursorStore { blob } => {
85 body.push(b'P');
86 put_bytes(&mut body, blob);
87 }
88 Request::CursorTake { token } => {
89 body.push(b'G');
90 put_bytes(&mut body, token.as_bytes());
91 }
92 }
93 write_frame(w, &body)
94}
95
96pub fn read_request(r: &mut impl Read) -> Result<Request> {
97 let body = read_frame(r)?;
98 let mut cur = &body[..];
99 let tag = take_u8(&mut cur)?;
100 Ok(match tag {
101 b'S' => {
102 let flags = take_u8(&mut cur)?;
103 let before = take_u32(&mut cur)?;
104 let after = take_u32(&mut cur)?;
105 let opts = unpack_opts(flags, before, after);
106 let pattern = String::from_utf8(take_bytes(&mut cur)?)?;
107 Request::Search { opts, pattern }
108 }
109 b'F' => {
110 let limit = take_u32(&mut cur)?;
111 let needle = String::from_utf8(take_bytes(&mut cur)?)?;
112 let after = String::from_utf8(take_bytes(&mut cur)?)?;
113 let after = (!after.is_empty()).then_some(after);
114 Request::Find {
115 needle,
116 after,
117 limit,
118 }
119 }
120 b'T' => Request::Status,
121 b'W' => Request::Watch,
122 b'Q' => Request::Shutdown,
123 b'P' => Request::CursorStore {
124 blob: take_bytes(&mut cur)?,
125 },
126 b'G' => Request::CursorTake {
127 token: String::from_utf8(take_bytes(&mut cur)?)?,
128 },
129 other => bail!("unknown request tag {other}"),
130 })
131}
132
/// First byte of a find-response header line (U+0001, a control character).
///
/// [`parse_find_header`] treats a blob as headerless unless it starts with this byte.
pub const FIND_HEADER_SENTINEL: u8 = b'\x01';
139
/// Metadata carried in the optional header line of a find response.
///
/// The header is produced by [`format_find_header`] and decoded by
/// [`parse_find_header`]. The derives make it printable and comparable,
/// for example in tests and logs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FindHeader {
    /// First header field. Presumably the total number of matches; confirm
    /// this against the server.
    pub total: usize,
    /// Second header field. Presumably the offset of the first returned entry.
    pub start: usize,
    /// Third header field. Presumably the number of entries in this response.
    pub returned: usize,
    /// Fourth header field. `None` when that field is empty or absent.
    /// NOTE(review): looks like the cursor for `Request::Find { after, .. }`.
    pub next_after: Option<String>,
}
146
147pub fn format_find_header(
148 total: usize,
149 start: usize,
150 returned: usize,
151 next_after: Option<&str>,
152) -> String {
153 format!(
154 "{}{total}\t{start}\t{returned}\t{}\n",
155 FIND_HEADER_SENTINEL as char,
156 next_after.unwrap_or("")
157 )
158}
159
160pub fn parse_find_header(blob: &[u8]) -> (Option<FindHeader>, &[u8]) {
162 if blob.first() != Some(&FIND_HEADER_SENTINEL) {
163 return (None, blob);
164 }
165 let Some(nl) = blob.iter().position(|&b| b == b'\n') else {
166 return (None, blob);
167 };
168 let line = String::from_utf8_lossy(&blob[1..nl]);
169 let mut parts = line.splitn(4, '\t');
172 let total = parts.next().and_then(|s| s.parse().ok());
173 let start = parts.next().and_then(|s| s.parse().ok());
174 let returned = parts.next().and_then(|s| s.parse().ok());
175 match (total, start, returned) {
176 (Some(total), Some(start), Some(returned)) => {
177 let next_after = parts.next().filter(|s| !s.is_empty()).map(str::to_string);
178 (
179 Some(FindHeader {
180 total,
181 start,
182 returned,
183 next_after,
184 }),
185 &blob[nl + 1..],
186 )
187 }
188 _ => (None, blob),
189 }
190}
191
192pub fn write_data(w: &mut impl Write, data: &[u8]) -> Result<()> {
196 if !data.is_empty() {
197 write_frame(w, data)?;
198 }
199 Ok(())
200}
201
202pub fn end_stream(w: &mut impl Write) -> Result<()> {
203 w.write_all(&0u32.to_le_bytes())?;
204 w.flush()?;
205 Ok(())
206}
207
208pub fn read_stream(r: &mut impl Read, sink: &mut impl Write) -> Result<usize> {
210 let mut total = 0;
211 loop {
212 let n = read_len(r)?;
213 if n == 0 {
214 return Ok(total);
215 }
216 let mut body = vec![0u8; n];
217 r.read_exact(&mut body)?;
218 sink.write_all(&body)?;
219 total += n;
220 }
221}
222
223pub fn read_stream_to_vec(r: &mut impl Read) -> Result<Vec<u8>> {
225 let mut v = Vec::new();
226 read_stream(r, &mut v)?;
227 Ok(v)
228}
229
230pub fn read_watch_frame(r: &mut impl Read) -> Result<Option<Vec<u8>>> {
233 let mut len = [0u8; 4];
234 match r.read_exact(&mut len) {
235 Ok(()) => {}
236 Err(e) if e.kind() == ErrorKind::UnexpectedEof => return Ok(None),
237 Err(e) => return Err(e.into()),
238 }
239 let n = u32::from_le_bytes(len) as usize;
240 if n == 0 {
241 return Ok(None);
242 }
243 if n > MAX_FRAME {
244 bail!("frame length {n} exceeds maximum {MAX_FRAME}");
245 }
246 let mut body = vec![0u8; n];
247 r.read_exact(&mut body)?;
248 Ok(Some(body))
249}
250
/// Largest frame body accepted by the readers in this module: 512 MiB.
const MAX_FRAME: usize = 512 << 20;
254
255fn read_len(r: &mut impl Read) -> Result<usize> {
256 let mut len = [0u8; 4];
257 r.read_exact(&mut len)?;
258 let n = u32::from_le_bytes(len) as usize;
259 if n > MAX_FRAME {
260 bail!("frame length {n} exceeds maximum {MAX_FRAME}");
261 }
262 Ok(n)
263}
264
265fn write_frame(w: &mut impl Write, body: &[u8]) -> Result<()> {
266 w.write_all(&(body.len() as u32).to_le_bytes())?;
267 w.write_all(body)?;
268 w.flush()?;
269 Ok(())
270}
271
272fn read_frame(r: &mut impl Read) -> Result<Vec<u8>> {
273 let mut body = vec![0u8; read_len(r)?];
274 r.read_exact(&mut body)?;
275 Ok(body)
276}
277
/// Appends `b` to `buf` as a little-endian `u32` length followed by the bytes.
///
/// NOTE(review): the length is narrowed with `as u32`, so a slice longer than
/// `u32::MAX` would get a truncated prefix. Callers must keep fields below that
/// size.
fn put_bytes(buf: &mut Vec<u8>, b: &[u8]) {
    let len_prefix = (b.len() as u32).to_le_bytes();
    buf.extend(len_prefix.iter().chain(b.iter()).copied());
}
282
283fn take_u8(cur: &mut &[u8]) -> Result<u8> {
284 let (&b, rest) = cur
285 .split_first()
286 .ok_or_else(|| anyhow::anyhow!("short frame"))?;
287 *cur = rest;
288 Ok(b)
289}
290
291fn take_u32(cur: &mut &[u8]) -> Result<u32> {
292 if cur.len() < 4 {
293 bail!("short frame");
294 }
295 let (head, rest) = cur.split_at(4);
296 *cur = rest;
297 Ok(u32::from_le_bytes(head.try_into().unwrap()))
298}
299
300fn take_bytes(cur: &mut &[u8]) -> Result<Vec<u8>> {
301 let n = take_u32(cur)? as usize;
302 if cur.len() < n {
303 bail!("short frame");
304 }
305 let (head, rest) = cur.split_at(n);
306 *cur = rest;
307 Ok(head.to_vec())
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `req`, decodes it back, and asserts the two are equal.
    fn roundtrip(req: Request) {
        let mut buf = Vec::new();
        write_request(&mut buf, &req).unwrap();
        let got = read_request(&mut &buf[..]).unwrap();
        assert_eq!(req, got);
    }

    #[test]
    fn request_roundtrips() {
        roundtrip(Request::Search {
            opts: SearchOptions {
                case_insensitive: true,
                ..Default::default()
            },
            pattern: "Foo|Bar".to_string(),
        });
        roundtrip(Request::Find {
            needle: "config".into(),
            after: None,
            limit: 50,
        });
        roundtrip(Request::Find {
            needle: "config".into(),
            after: Some("src/config.rs".into()),
            limit: 50,
        });
        roundtrip(Request::Status);
        roundtrip(Request::Watch);
        roundtrip(Request::Shutdown);
        roundtrip(Request::CursorStore {
            blob: vec![0, 1, 2, 255],
        });
        roundtrip(Request::CursorTake {
            token: "0000abcd5".to_string(),
        });
    }

    #[test]
    fn empty_find_cursor_decodes_as_none() {
        // `None` and `Some("")` share one wire encoding; pin that on purpose.
        let mut buf = Vec::new();
        write_request(
            &mut buf,
            &Request::Find {
                needle: "x".into(),
                after: Some(String::new()),
                limit: 1,
            },
        )
        .unwrap();
        assert_eq!(
            read_request(&mut &buf[..]).unwrap(),
            Request::Find {
                needle: "x".into(),
                after: None,
                limit: 1,
            }
        );
    }

    #[test]
    fn read_request_rejects_unknown_tags_and_short_frames() {
        // One-byte frame carrying an unassigned tag.
        assert!(read_request(&mut &[1u8, 0, 0, 0, b'Z'][..]).is_err());
        // Search tag with none of its fields.
        assert!(read_request(&mut &[1u8, 0, 0, 0, b'S'][..]).is_err());
        // Length prefix promises more bytes than the input holds.
        assert!(read_request(&mut &[5u8, 0, 0, 0, b'T'][..]).is_err());
    }

    #[test]
    fn find_header_roundtrips_and_tolerates_headerless() {
        let blob = format!(
            "{}src/a.rs\nsrc/b.rs\n",
            format_find_header(1342, 200, 2, Some("src/b.rs"))
        );
        let (header, rest) = parse_find_header(blob.as_bytes());
        let header = header.unwrap();
        assert_eq!(header.total, 1342);
        assert_eq!(header.start, 200);
        assert_eq!(header.returned, 2);
        assert_eq!(header.next_after.as_deref(), Some("src/b.rs"));
        assert_eq!(rest, b"src/a.rs\nsrc/b.rs\n");

        let (none, rest) = parse_find_header(b"src/a.rs\n");
        assert!(none.is_none());
        assert_eq!(rest, b"src/a.rs\n");
    }

    #[test]
    fn response_stream_roundtrips() {
        let mut buf = Vec::new();
        write_data(&mut buf, b"path:1:hello\n").unwrap();
        // An empty chunk must not emit the zero-length end marker.
        write_data(&mut buf, b"").unwrap();
        write_data(&mut buf, b"path:2:world\n").unwrap();
        end_stream(&mut buf).unwrap();
        assert_eq!(
            read_stream_to_vec(&mut &buf[..]).unwrap(),
            b"path:1:hello\npath:2:world\n"
        );
    }

    #[test]
    fn read_stream_reports_byte_count() {
        let mut buf = Vec::new();
        write_data(&mut buf, b"abc").unwrap();
        write_data(&mut buf, b"de").unwrap();
        end_stream(&mut buf).unwrap();
        let mut sink = Vec::new();
        assert_eq!(read_stream(&mut &buf[..], &mut sink).unwrap(), 5);
        assert_eq!(sink, b"abcde");
    }

    #[test]
    fn watch_frames_end_on_marker_or_eof() {
        let mut buf = Vec::new();
        write_data(&mut buf, b"ev1").unwrap();
        end_stream(&mut buf).unwrap();
        let mut r = &buf[..];
        assert_eq!(read_watch_frame(&mut r).unwrap(), Some(b"ev1".to_vec()));
        assert_eq!(read_watch_frame(&mut r).unwrap(), None);
        // Nothing left at all: a clean EOF also ends the stream.
        assert_eq!(read_watch_frame(&mut r).unwrap(), None);
    }

    #[test]
    fn oversized_length_prefix_is_rejected() {
        let huge = u32::MAX.to_le_bytes();
        assert!(read_watch_frame(&mut &huge[..]).is_err());
        assert!(read_stream_to_vec(&mut &huge[..]).is_err());
    }
}