Skip to main content

rgx/
proto.rs

//! The daemon wire protocol: length-prefixed frames over a Unix socket.
//!
//! Deliberately tiny and hand-encoded (no serde): a request is a single frame holding one tag
//! byte plus its fields; the response is a stream of non-empty data frames terminated by a
//! zero-length frame, so results (rendered `path:line:text`, a file list, or status text) flow
//! without either side buffering huge result sets.
6
7use std::io::{ErrorKind, Read, Write};
8
9use anyhow::{Result, bail};
10
11use crate::confirm::SearchOptions;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum Request {
15    /// Content search: render `path:line:text` for `pattern`.
16    Search {
17        opts: SearchOptions,
18        pattern: String,
19    },
20    /// File/dir name lookup (fd/find-style). `after` resumes a keyset page: only paths strictly
21    /// greater than it are returned (empty/None = from the start).
22    Find {
23        needle: String,
24        after: Option<String>,
25        limit: u32,
26    },
27    /// Index health summary.
28    Status,
29    /// Subscribe to live status: the daemon streams a fresh status frame on each change (and a
30    /// periodic heartbeat) until the client disconnects.
31    Watch,
32    /// Ask the daemon to exit.
33    Shutdown,
34    /// Park a pagination cursor blob; the daemon replies with a short opaque token.
35    CursorStore { blob: Vec<u8> },
36    /// Redeem a pagination token; the daemon replies with the blob, or an empty frame if it has
37    /// expired or was already used.
38    CursorTake { token: String },
39}
40
41pub(crate) fn pack_opts(o: &SearchOptions) -> u8 {
42    (o.case_insensitive as u8)
43        | ((o.multi_line as u8) << 1)
44        | ((o.dot_matches_new_line as u8) << 2)
45        | ((o.word as u8) << 3)
46        | ((o.fixed_strings as u8) << 4)
47}
48
49pub(crate) fn unpack_opts(b: u8, before: u32, after: u32) -> SearchOptions {
50    SearchOptions {
51        case_insensitive: b & 1 != 0,
52        multi_line: b & 2 != 0,
53        dot_matches_new_line: b & 4 != 0,
54        word: b & 8 != 0,
55        fixed_strings: b & 16 != 0,
56        before_context: before as usize,
57        after_context: after as usize,
58    }
59}
60
61pub fn write_request(w: &mut impl Write, req: &Request) -> Result<()> {
62    let mut body = Vec::new();
63    match req {
64        Request::Search { opts, pattern } => {
65            body.push(b'S');
66            body.push(pack_opts(opts));
67            body.extend_from_slice(&(opts.before_context as u32).to_le_bytes());
68            body.extend_from_slice(&(opts.after_context as u32).to_le_bytes());
69            put_bytes(&mut body, pattern.as_bytes());
70        }
71        Request::Find {
72            needle,
73            after,
74            limit,
75        } => {
76            body.push(b'F');
77            body.extend_from_slice(&limit.to_le_bytes());
78            put_bytes(&mut body, needle.as_bytes());
79            put_bytes(&mut body, after.as_deref().unwrap_or("").as_bytes());
80        }
81        Request::Status => body.push(b'T'),
82        Request::Watch => body.push(b'W'),
83        Request::Shutdown => body.push(b'Q'),
84        Request::CursorStore { blob } => {
85            body.push(b'P');
86            put_bytes(&mut body, blob);
87        }
88        Request::CursorTake { token } => {
89            body.push(b'G');
90            put_bytes(&mut body, token.as_bytes());
91        }
92    }
93    write_frame(w, &body)
94}
95
96pub fn read_request(r: &mut impl Read) -> Result<Request> {
97    let body = read_frame(r)?;
98    let mut cur = &body[..];
99    let tag = take_u8(&mut cur)?;
100    Ok(match tag {
101        b'S' => {
102            let flags = take_u8(&mut cur)?;
103            let before = take_u32(&mut cur)?;
104            let after = take_u32(&mut cur)?;
105            let opts = unpack_opts(flags, before, after);
106            let pattern = String::from_utf8(take_bytes(&mut cur)?)?;
107            Request::Search { opts, pattern }
108        }
109        b'F' => {
110            let limit = take_u32(&mut cur)?;
111            let needle = String::from_utf8(take_bytes(&mut cur)?)?;
112            let after = String::from_utf8(take_bytes(&mut cur)?)?;
113            let after = (!after.is_empty()).then_some(after);
114            Request::Find {
115                needle,
116                after,
117                limit,
118            }
119        }
120        b'T' => Request::Status,
121        b'W' => Request::Watch,
122        b'Q' => Request::Shutdown,
123        b'P' => Request::CursorStore {
124            blob: take_bytes(&mut cur)?,
125        },
126        b'G' => Request::CursorTake {
127            token: String::from_utf8(take_bytes(&mut cur)?)?,
128        },
129        other => bail!("unknown request tag {other}"),
130    })
131}
132
/// First byte of the optional header line that may lead a `Find` response:
/// `\x01<total>\t<start>\t<returned>\t<next_after>\n`.
///
/// The header lets the client do three things:
/// - report the true total instead of just the size of the truncated page
/// - show an honest "X-Y of N" range, where `start` is the number of items skipped before this page
/// - resume from `next_after`
///
/// Real path lines do not in practice start with the `0x01` control byte. Older blobs without a
/// header parse as plain path lists.
pub const FIND_HEADER_SENTINEL: u8 = b'\x01';
139
/// Decoded form of the optional header line that may lead a `Find` response
/// (`\x01<total>\t<start>\t<returned>\t<next_after>\n`); produced by [`parse_find_header`].
///
/// Derives the standard traits so callers can log, compare, and copy a parsed header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FindHeader {
    /// True total number of matches, not just those in this (possibly truncated) page.
    pub total: usize,
    /// Keyset offset: number of items skipped before this page.
    pub start: usize,
    /// Number of items returned in this page.
    pub returned: usize,
    /// Key to resume the next page from; `None` when the header carried an empty field.
    pub next_after: Option<String>,
}
146
147pub fn format_find_header(
148    total: usize,
149    start: usize,
150    returned: usize,
151    next_after: Option<&str>,
152) -> String {
153    format!(
154        "{}{total}\t{start}\t{returned}\t{}\n",
155        FIND_HEADER_SENTINEL as char,
156        next_after.unwrap_or("")
157    )
158}
159
160/// Split a `Find` response blob into its optional header and the remaining path lines.
161pub fn parse_find_header(blob: &[u8]) -> (Option<FindHeader>, &[u8]) {
162    if blob.first() != Some(&FIND_HEADER_SENTINEL) {
163        return (None, blob);
164    }
165    let Some(nl) = blob.iter().position(|&b| b == b'\n') else {
166        return (None, blob);
167    };
168    let line = String::from_utf8_lossy(&blob[1..nl]);
169    // `splitn(4)` keeps next_after (a file path) intact even if it contains a tab — the three numeric
170    // fields are tab-delimited, and everything after the third tab is the path verbatim.
171    let mut parts = line.splitn(4, '\t');
172    let total = parts.next().and_then(|s| s.parse().ok());
173    let start = parts.next().and_then(|s| s.parse().ok());
174    let returned = parts.next().and_then(|s| s.parse().ok());
175    match (total, start, returned) {
176        (Some(total), Some(start), Some(returned)) => {
177            let next_after = parts.next().filter(|s| !s.is_empty()).map(str::to_string);
178            (
179                Some(FindHeader {
180                    total,
181                    start,
182                    returned,
183                    next_after,
184                }),
185                &blob[nl + 1..],
186            )
187        }
188        _ => (None, blob),
189    }
190}
191
192/// Responses are a stream of non-empty data frames terminated by a zero-length frame, so the daemon
193/// can emit results as it finds them and the client writes them straight to stdout (no buffering of
194/// huge result sets on either side).
195pub fn write_data(w: &mut impl Write, data: &[u8]) -> Result<()> {
196    if !data.is_empty() {
197        write_frame(w, data)?;
198    }
199    Ok(())
200}
201
202pub fn end_stream(w: &mut impl Write) -> Result<()> {
203    w.write_all(&0u32.to_le_bytes())?;
204    w.flush()?;
205    Ok(())
206}
207
208/// Read a response stream, writing each chunk to `sink`; returns total bytes written.
209pub fn read_stream(r: &mut impl Read, sink: &mut impl Write) -> Result<usize> {
210    let mut total = 0;
211    loop {
212        let n = read_len(r)?;
213        if n == 0 {
214            return Ok(total);
215        }
216        let mut body = vec![0u8; n];
217        r.read_exact(&mut body)?;
218        sink.write_all(&body)?;
219        total += n;
220    }
221}
222
223/// Convenience: collect a whole response stream into a `Vec` (for small responses like status/find).
224pub fn read_stream_to_vec(r: &mut impl Read) -> Result<Vec<u8>> {
225    let mut v = Vec::new();
226    read_stream(r, &mut v)?;
227    Ok(v)
228}
229
230/// Read one frame from an open-ended stream (e.g. `Watch`), returning `None` when the stream ends
231/// (zero-length terminator or the daemon closing the connection).
232pub fn read_watch_frame(r: &mut impl Read) -> Result<Option<Vec<u8>>> {
233    let mut len = [0u8; 4];
234    match r.read_exact(&mut len) {
235        Ok(()) => {}
236        Err(e) if e.kind() == ErrorKind::UnexpectedEof => return Ok(None),
237        Err(e) => return Err(e.into()),
238    }
239    let n = u32::from_le_bytes(len) as usize;
240    if n == 0 {
241        return Ok(None);
242    }
243    if n > MAX_FRAME {
244        bail!("frame length {n} exceeds maximum {MAX_FRAME}");
245    }
246    let mut body = vec![0u8; n];
247    r.read_exact(&mut body)?;
248    Ok(Some(body))
249}
250
/// Largest frame body the readers here accept: 512 MiB.
///
/// The cap stops a bogus or desynced length prefix from triggering a multi-gigabyte allocation.
/// It is still generous, since search results arrive as many small frames and requests are tiny.
const MAX_FRAME: usize = 512 << 20;
254
255fn read_len(r: &mut impl Read) -> Result<usize> {
256    let mut len = [0u8; 4];
257    r.read_exact(&mut len)?;
258    let n = u32::from_le_bytes(len) as usize;
259    if n > MAX_FRAME {
260        bail!("frame length {n} exceeds maximum {MAX_FRAME}");
261    }
262    Ok(n)
263}
264
265fn write_frame(w: &mut impl Write, body: &[u8]) -> Result<()> {
266    w.write_all(&(body.len() as u32).to_le_bytes())?;
267    w.write_all(body)?;
268    w.flush()?;
269    Ok(())
270}
271
272fn read_frame(r: &mut impl Read) -> Result<Vec<u8>> {
273    let mut body = vec![0u8; read_len(r)?];
274    r.read_exact(&mut body)?;
275    Ok(body)
276}
277
/// Append `b` to `buf` as a length-prefixed field: a little-endian `u32` byte count, then the
/// bytes themselves.
///
/// NOTE(review): the count is an `as u32` cast, so a field of 4 GiB or more gets a wrapped
/// prefix. Such a field could never be read back intact anyway, because readers reject frames
/// over `MAX_FRAME`.
fn put_bytes(buf: &mut Vec<u8>, b: &[u8]) {
    let prefix = (b.len() as u32).to_le_bytes();
    buf.extend(prefix);
    buf.extend_from_slice(b);
}
282
283fn take_u8(cur: &mut &[u8]) -> Result<u8> {
284    let (&b, rest) = cur
285        .split_first()
286        .ok_or_else(|| anyhow::anyhow!("short frame"))?;
287    *cur = rest;
288    Ok(b)
289}
290
291fn take_u32(cur: &mut &[u8]) -> Result<u32> {
292    if cur.len() < 4 {
293        bail!("short frame");
294    }
295    let (head, rest) = cur.split_at(4);
296    *cur = rest;
297    Ok(u32::from_le_bytes(head.try_into().unwrap()))
298}
299
300fn take_bytes(cur: &mut &[u8]) -> Result<Vec<u8>> {
301    let n = take_u32(cur)? as usize;
302    if cur.len() < n {
303        bail!("short frame");
304    }
305    let (head, rest) = cur.split_at(n);
306    *cur = rest;
307    Ok(head.to_vec())
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `req`, decode it back, and require the result to equal the input.
    fn roundtrip(req: Request) {
        let mut buf = Vec::new();
        write_request(&mut buf, &req).unwrap();
        let got = read_request(&mut &buf[..]).unwrap();
        assert_eq!(req, got);
    }

    #[test]
    fn request_roundtrips() {
        roundtrip(Request::Search {
            opts: SearchOptions {
                case_insensitive: true,
                ..Default::default()
            },
            pattern: "Foo|Bar".to_string(),
        });
        roundtrip(Request::Search {
            opts: SearchOptions {
                word: true,
                fixed_strings: true,
                before_context: 3,
                after_context: 7,
                ..Default::default()
            },
            pattern: "needle".to_string(),
        });
        roundtrip(Request::Find {
            needle: "config".into(),
            after: None,
            limit: 50,
        });
        roundtrip(Request::Find {
            needle: "config".into(),
            after: Some("src/config.rs".into()),
            limit: 50,
        });
        roundtrip(Request::Status);
        roundtrip(Request::Watch);
        roundtrip(Request::Shutdown);
        roundtrip(Request::CursorStore {
            blob: vec![0, 1, 2, 255],
        });
        roundtrip(Request::CursorTake {
            token: "0000abcd5".to_string(),
        });
    }

    #[test]
    fn find_header_roundtrips_and_tolerates_headerless() {
        let blob = format!(
            "{}src/a.rs\nsrc/b.rs\n",
            format_find_header(1342, 200, 2, Some("src/b.rs"))
        );
        let (header, rest) = parse_find_header(blob.as_bytes());
        let header = header.unwrap();
        assert_eq!(header.total, 1342);
        assert_eq!(header.start, 200);
        assert_eq!(header.returned, 2);
        assert_eq!(header.next_after.as_deref(), Some("src/b.rs"));
        assert_eq!(rest, b"src/a.rs\nsrc/b.rs\n");

        // A headerless blob (no sentinel) parses as all paths.
        let (none, rest) = parse_find_header(b"src/a.rs\n");
        assert!(none.is_none());
        assert_eq!(rest, b"src/a.rs\n");
    }

    #[test]
    fn find_header_keeps_tabs_in_next_after_and_maps_empty_to_none() {
        let blob = format_find_header(5, 0, 1, Some("dir/a\tb.rs"));
        let (header, rest) = parse_find_header(blob.as_bytes());
        assert_eq!(header.unwrap().next_after.as_deref(), Some("dir/a\tb.rs"));
        assert!(rest.is_empty());

        let blob = format_find_header(0, 0, 0, None);
        let (header, _) = parse_find_header(blob.as_bytes());
        assert_eq!(header.unwrap().next_after, None);
    }

    #[test]
    fn malformed_find_header_is_treated_as_paths() {
        // Sentinel but no terminating newline.
        let blob: &[u8] = b"\x01123";
        let (header, rest) = parse_find_header(blob);
        assert!(header.is_none());
        assert_eq!(rest, blob);

        // Non-numeric count field.
        let blob: &[u8] = b"\x01x\t0\t0\t\nsrc/a.rs\n";
        let (header, rest) = parse_find_header(blob);
        assert!(header.is_none());
        assert_eq!(rest, blob);
    }

    #[test]
    fn malformed_requests_are_rejected() {
        // Unknown tag.
        let mut buf = Vec::new();
        write_frame(&mut buf, b"Z").unwrap();
        assert!(read_request(&mut &buf[..]).is_err());

        // Search frame cut off right after its tag.
        let mut buf = Vec::new();
        write_frame(&mut buf, b"S").unwrap();
        assert!(read_request(&mut &buf[..]).is_err());

        // Frame body shorter than its length prefix claims.
        let buf = [8u8, 0, 0, 0, b'T'];
        assert!(read_request(&mut &buf[..]).is_err());
    }

    #[test]
    fn response_stream_roundtrips() {
        let mut buf = Vec::new();
        write_data(&mut buf, b"path:1:hello\n").unwrap();
        write_data(&mut buf, b"").unwrap(); // empty chunk is a no-op, not a terminator
        write_data(&mut buf, b"path:2:world\n").unwrap();
        end_stream(&mut buf).unwrap();
        assert_eq!(
            read_stream_to_vec(&mut &buf[..]).unwrap(),
            b"path:1:hello\npath:2:world\n"
        );
    }

    #[test]
    fn read_stream_counts_bytes_and_rejects_missing_terminator() {
        let mut buf = Vec::new();
        write_data(&mut buf, b"abc").unwrap();
        write_data(&mut buf, b"de").unwrap();
        end_stream(&mut buf).unwrap();

        let mut sink = Vec::new();
        assert_eq!(read_stream(&mut &buf[..], &mut sink).unwrap(), 5);
        assert_eq!(sink, b"abcde");

        // A stream cut before its terminator is an error, not a silent short read.
        let truncated = &buf[..buf.len() - 4];
        assert!(read_stream(&mut &truncated[..], &mut Vec::new()).is_err());
    }

    #[test]
    fn watch_frames_end_on_terminator_or_eof() {
        let mut buf = Vec::new();
        write_data(&mut buf, b"status 1").unwrap();
        end_stream(&mut buf).unwrap();

        let mut r = &buf[..];
        assert_eq!(
            read_watch_frame(&mut r).unwrap().as_deref(),
            Some(&b"status 1"[..])
        );
        assert_eq!(read_watch_frame(&mut r).unwrap(), None);

        // Daemon hung up without sending a terminator.
        assert_eq!(read_watch_frame(&mut &b""[..]).unwrap(), None);
    }

    #[test]
    fn oversized_length_prefix_is_rejected() {
        let bogus = (MAX_FRAME as u32 + 1).to_le_bytes();
        assert!(read_stream_to_vec(&mut &bogus[..]).is_err());
        assert!(read_watch_frame(&mut &bogus[..]).is_err());
        assert!(read_request(&mut &bogus[..]).is_err());
    }
}