Skip to main content

lean_rs_worker_protocol/
protocol.rs

//! Length-delimited frame codec and message payload types for the
//! parent ↔ child worker process boundary.
//!
//! Each frame on the wire is a 4-byte big-endian length prefix followed by
//! that many bytes of JSON encoding one [`Frame`].
//!
//! ## Additive evolution
//!
//! Every public enum here is `#[non_exhaustive]` so the wire format can gain
//! a new request, response, or message kind without forcing a semver-major
//! bump on consumers. Most structs are also `#[non_exhaustive]` and expose
//! `pub fn new(...)` constructors so their shapes can grow fields without
//! breaking downstream code that builds them. The exception is [`DataRow`],
//! which is built so often with struct-literal syntax (tests, harnesses,
//! fakes) that the ergonomic cost of `#[non_exhaustive]` outweighs the
//! additive-evolution benefit; its wire schema is also already fixed by the
//! stream contract.
15
16use std::collections::BTreeMap;
17use std::fmt;
18use std::io::{self, Read, Write};
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use serde_json::value::RawValue;
24
25use crate::types::{
26    LeanWorkerCapabilityMetadata, LeanWorkerDeclarationFilter, LeanWorkerDeclarationRow, LeanWorkerDoctorReport,
27    LeanWorkerElabOptions, LeanWorkerElabResult, LeanWorkerKernelResult, LeanWorkerMetaResult,
28    LeanWorkerMetaTransparency, LeanWorkerProcessFileOutcome, LeanWorkerProcessModuleOutcome, LeanWorkerRendered,
29};
30
/// Revision of the wire protocol spoken by this build.
///
/// It is exchanged during the [`Message::Handshake`] and stamped on every
/// [`Frame`] by [`Frame::new`]; [`read_frame`] refuses frames carrying any
/// other value. Increment it only for a wire change older peers cannot decode.
pub const PROTOCOL_VERSION: u16 = 3;

/// Largest permitted JSON body of a single frame, in bytes (1 MiB).
///
/// Enforced on both sides — by [`write_frame`] before anything is written and
/// by [`read_frame`] before the body buffer is allocated — so a runaway peer
/// cannot force an unbounded allocation.
pub const MAX_FRAME_BYTES: u32 = 1 << 20;
40
41/// Versioned envelope around a single protocol [`Message`].
42#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
43#[non_exhaustive]
44pub struct Frame {
45    /// Protocol version the sender used. Receivers reject mismatches.
46    pub version: u16,
47    /// Inner message payload.
48    pub message: Message,
49}
50
51impl Frame {
52    /// Wrap `message` in a frame tagged with the current [`PROTOCOL_VERSION`].
53    #[must_use]
54    pub fn new(message: Message) -> Self {
55        Self {
56            version: PROTOCOL_VERSION,
57            message,
58        }
59    }
60}
61
62/// One protocol message exchanged over the worker boundary.
63#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
64#[serde(tag = "type", content = "body", rename_all = "snake_case")]
65#[non_exhaustive]
66pub enum Message {
67    /// Sent by the child immediately after spawn to advertise its version and
68    /// supported protocol revision.
69    Handshake {
70        /// `lean-rs-worker-child` package version.
71        worker_version: String,
72        /// Protocol version the child speaks. Parent rejects mismatches.
73        protocol_version: u16,
74    },
75    /// Parent → child request frame.
76    Request(Request),
77    /// Child → parent terminal response for one request.
78    Response(Response),
79    /// Child → parent intermediate diagnostic frame.
80    Diagnostic(Diagnostic),
81    /// Child → parent intermediate progress frame.
82    ProgressTick(ProgressTick),
83    /// Child → parent streaming data row frame.
84    DataRow(DataRow),
85    /// Child → parent fatal exit notification carrying the captured stderr
86    /// just before the child process tears down.
87    FatalExit(FatalExit),
88}
89
90/// Parent-issued worker request body.
91#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
92#[serde(tag = "op", rename_all = "snake_case")]
93#[non_exhaustive]
94pub enum Request {
95    Health,
96    LoadFixtureCapability {
97        fixture_root: String,
98    },
99    CallFixtureMul {
100        fixture_root: String,
101        lhs: u64,
102        rhs: u64,
103    },
104    TriggerLeanPanic {
105        fixture_root: String,
106    },
107    OpenHostSession {
108        project_root: String,
109        package: String,
110        lib_name: String,
111        imports: Vec<String>,
112    },
113    Elaborate {
114        source: String,
115        options: LeanWorkerElabOptions,
116    },
117    KernelCheck {
118        source: String,
119        options: LeanWorkerElabOptions,
120        progress: bool,
121    },
122    DeclarationKinds {
123        names: Vec<String>,
124        progress: bool,
125    },
126    DeclarationNames {
127        names: Vec<String>,
128        progress: bool,
129    },
130    RunDataStream {
131        export: String,
132        request_json: String,
133        progress: bool,
134    },
135    CapabilityMetadata {
136        export: String,
137        request_json: String,
138    },
139    CapabilityDoctor {
140        export: String,
141        request_json: String,
142    },
143    JsonCommand {
144        export: String,
145        request_json: String,
146    },
147    InferType {
148        source: String,
149        options: LeanWorkerElabOptions,
150    },
151    Whnf {
152        source: String,
153        options: LeanWorkerElabOptions,
154    },
155    IsDefEq {
156        lhs: String,
157        rhs: String,
158        transparency: LeanWorkerMetaTransparency,
159        options: LeanWorkerElabOptions,
160    },
161    Describe {
162        name: String,
163    },
164    ListDeclarationsStrings {
165        filter: LeanWorkerDeclarationFilter,
166        progress: bool,
167    },
168    DescribeBulk {
169        names: Vec<String>,
170        progress: bool,
171    },
172    ProcessFile {
173        source: String,
174        options: LeanWorkerElabOptions,
175    },
176    ProcessModule {
177        source: String,
178        options: LeanWorkerElabOptions,
179    },
180    // Private harness requests that exercise streaming frame behavior.
181    // Not part of the public row sink API.
182    EmitTestRows {
183        streams: Vec<String>,
184    },
185    EmitTestRowsThenExit,
186    EmitTestRowsThenPanic,
187    Terminate,
188}
189
190/// Child-issued terminal response body for one [`Request`].
191#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
192#[serde(tag = "status", rename_all = "snake_case")]
193#[non_exhaustive]
194pub enum Response {
195    HealthOk,
196    CapabilityLoaded,
197    U64 {
198        value: u64,
199    },
200    HostSessionOpened,
201    Elaboration {
202        outcome: LeanWorkerElabResult,
203    },
204    KernelCheck {
205        outcome: LeanWorkerKernelResult,
206    },
207    Strings {
208        values: Vec<String>,
209    },
210    StreamComplete {
211        summary: StreamSummary,
212    },
213    StreamExportFailed {
214        status_byte: u8,
215    },
216    StreamCallbackFailed {
217        status_byte: u8,
218        description: String,
219    },
220    StreamRowMalformed {
221        message: String,
222    },
223    CapabilityMetadata {
224        metadata: LeanWorkerCapabilityMetadata,
225    },
226    CapabilityDoctor {
227        report: LeanWorkerDoctorReport,
228    },
229    CapabilityMetadataMalformed {
230        message: String,
231    },
232    CapabilityDoctorMalformed {
233        message: String,
234    },
235    JsonCommand {
236        response_json: String,
237    },
238    MetaExpr {
239        result: LeanWorkerMetaResult<LeanWorkerRendered>,
240    },
241    MetaBool {
242        result: LeanWorkerMetaResult<bool>,
243    },
244    Declaration {
245        row: Option<LeanWorkerDeclarationRow>,
246    },
247    DeclarationBulk {
248        rows: Vec<LeanWorkerDeclarationRow>,
249    },
250    ProcessFile {
251        outcome: LeanWorkerProcessFileOutcome,
252    },
253    ProcessModule {
254        outcome: LeanWorkerProcessModuleOutcome,
255    },
256    RowsComplete {
257        count: u64,
258    },
259    Terminating,
260    Error {
261        code: String,
262        message: String,
263    },
264}
265
266/// Intermediate diagnostic frame emitted by the child during a request.
267#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
268#[non_exhaustive]
269pub struct Diagnostic {
270    /// Stable diagnostic code identifier.
271    pub code: String,
272    /// Bounded human-readable diagnostic message.
273    pub message: String,
274}
275
276impl Diagnostic {
277    /// Build a diagnostic frame payload.
278    #[must_use]
279    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
280        Self {
281            code: code.into(),
282            message: message.into(),
283        }
284    }
285}
286
287/// Intermediate progress frame emitted by the child during a long-running
288/// request.
289#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
290#[non_exhaustive]
291pub struct ProgressTick {
292    /// Phase name the child is reporting progress for.
293    pub phase: String,
294    /// Items completed so far in this phase.
295    pub current: u64,
296    /// Total expected items in this phase, if known.
297    pub total: Option<u64>,
298}
299
300impl ProgressTick {
301    /// Build a progress-tick frame payload.
302    #[must_use]
303    pub fn new(phase: impl Into<String>, current: u64, total: Option<u64>) -> Self {
304        Self {
305            phase: phase.into(),
306            current,
307            total,
308        }
309    }
310}
311
312/// One row in a streaming response.
313///
314/// Construction goes through [`DataRowEmitter::next`] in the child runtime;
315/// direct struct-literal construction is permitted in tests and harnesses.
316/// This struct intentionally stays exhaustive: see the module-level note on
317/// additive evolution.
318#[derive(Clone, Debug, Deserialize, Serialize)]
319pub struct DataRow {
320    /// Logical stream this row belongs to.
321    pub stream: String,
322    /// Per-stream monotonically increasing sequence number.
323    pub sequence: u64,
324    /// Opaque JSON payload (deserialised lazily by the parent).
325    pub payload: Box<RawValue>,
326}
327
328impl PartialEq for DataRow {
329    fn eq(&self, other: &Self) -> bool {
330        self.stream == other.stream && self.sequence == other.sequence && self.payload.get() == other.payload.get()
331    }
332}
333
334impl Eq for DataRow {}
335
336/// Terminal stream-completion summary returned alongside a streaming response.
337#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
338#[non_exhaustive]
339pub struct StreamSummary {
340    /// Total rows emitted across all streams.
341    pub total_rows: u64,
342    /// Per-stream row counts at completion.
343    pub per_stream_counts: BTreeMap<String, u64>,
344    /// Child-side elapsed time in microseconds.
345    pub elapsed_micros: u64,
346    /// Optional downstream-defined terminal metadata.
347    pub metadata: Option<Value>,
348}
349
350impl StreamSummary {
351    /// Build a stream-completion summary, clamping the elapsed duration into
352    /// the `u64` micros field.
353    #[must_use]
354    pub fn new(
355        total_rows: u64,
356        per_stream_counts: BTreeMap<String, u64>,
357        elapsed: Duration,
358        metadata: Option<Value>,
359    ) -> Self {
360        Self {
361            total_rows,
362            per_stream_counts,
363            elapsed_micros: elapsed.as_micros().try_into().unwrap_or(u64::MAX),
364            metadata,
365        }
366    }
367}
368
369/// Stateful emitter that assigns per-stream sequence numbers and tracks the
370/// running row count for the terminal [`StreamSummary`].
371#[derive(Debug, Default)]
372#[non_exhaustive]
373pub struct DataRowEmitter {
374    sequences: BTreeMap<String, u64>,
375    count: u64,
376}
377
378impl DataRowEmitter {
379    /// Allocate the next [`DataRow`] for `stream`, advancing the per-stream
380    /// sequence and the overall count.
381    pub fn next(&mut self, stream: impl Into<String>, payload: Box<RawValue>) -> DataRow {
382        let stream = stream.into();
383        let sequence = self.sequences.entry(stream.clone()).or_insert(0);
384        let row = DataRow {
385            stream,
386            sequence: *sequence,
387            payload,
388        };
389        *sequence = sequence.saturating_add(1);
390        self.count = self.count.saturating_add(1);
391        row
392    }
393
394    #[cfg(test)]
395    fn emit(
396        &mut self,
397        writer: &mut impl Write,
398        stream: impl Into<String>,
399        payload: &Value,
400    ) -> Result<(), ProtocolError> {
401        let row = self.next(stream, serde_json::value::to_raw_value(payload)?);
402        write_frame(writer, Message::DataRow(row))
403    }
404
405    /// Total rows emitted across all streams.
406    #[must_use]
407    pub fn count(&self) -> u64 {
408        self.count
409    }
410
411    /// Snapshot of per-stream row counts.
412    #[must_use]
413    pub fn per_stream_counts(&self) -> BTreeMap<String, u64> {
414        self.sequences.clone()
415    }
416}
417
418/// Final frame the child writes before it tears down on an unrecoverable
419/// failure.
420#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
421#[non_exhaustive]
422pub struct FatalExit {
423    /// Stringified `ExitStatus` of the child process.
424    pub status: String,
425    /// Captured stderr tail at fatal-exit time.
426    pub stderr: String,
427}
428
429impl FatalExit {
430    /// Build a fatal-exit frame payload.
431    #[must_use]
432    pub fn new(status: impl Into<String>, stderr: impl Into<String>) -> Self {
433        Self {
434            status: status.into(),
435            stderr: stderr.into(),
436        }
437    }
438}
439
440/// Failure modes the codec can produce while reading or writing a frame.
441#[derive(Debug)]
442#[non_exhaustive]
443pub enum ProtocolError {
444    /// Underlying I/O failure (pipe closed, partial read, etc.).
445    Io(io::Error),
446    /// JSON serialisation or deserialisation failure.
447    Json(serde_json::Error),
448    /// A frame body exceeded [`MAX_FRAME_BYTES`].
449    FrameTooLarge {
450        /// Observed frame size in bytes.
451        len: u32,
452        /// Maximum allowed frame size.
453        max: u32,
454    },
455    /// Peer's frame version did not match this binary's [`PROTOCOL_VERSION`].
456    VersionMismatch {
457        /// Version this binary expected.
458        expected: u16,
459        /// Version the peer used.
460        actual: u16,
461    },
462}
463
464impl ProtocolError {
465    /// Whether the underlying I/O error indicates the peer's pipe was closed
466    /// (`UnexpectedEof`). Used by callers to distinguish a clean fatal exit
467    /// from a true protocol failure.
468    #[must_use]
469    pub fn is_eof(&self) -> bool {
470        matches!(self, Self::Io(err) if err.kind() == io::ErrorKind::UnexpectedEof)
471    }
472}
473
474impl fmt::Display for ProtocolError {
475    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476        match self {
477            Self::Io(err) => write!(f, "worker protocol I/O failed: {err}"),
478            Self::Json(err) => write!(f, "worker protocol JSON decode failed: {err}"),
479            Self::FrameTooLarge { len, max } => {
480                write!(f, "worker protocol frame too large: {len} bytes exceeds {max}")
481            }
482            Self::VersionMismatch { expected, actual } => {
483                write!(
484                    f,
485                    "worker protocol version mismatch: expected {expected}, received {actual}"
486                )
487            }
488        }
489    }
490}
491
492impl std::error::Error for ProtocolError {}
493
494impl From<io::Error> for ProtocolError {
495    fn from(value: io::Error) -> Self {
496        Self::Io(value)
497    }
498}
499
500impl From<serde_json::Error> for ProtocolError {
501    fn from(value: serde_json::Error) -> Self {
502        Self::Json(value)
503    }
504}
505
506/// Serialise `message` as a length-delimited JSON frame to `writer`.
507///
508/// # Errors
509///
510/// Returns [`ProtocolError::FrameTooLarge`] if the serialised body would
511/// exceed [`MAX_FRAME_BYTES`], or the underlying [`ProtocolError::Io`] /
512/// [`ProtocolError::Json`] for codec failures.
513pub fn write_frame(writer: &mut impl Write, message: Message) -> Result<(), ProtocolError> {
514    let bytes = serde_json::to_vec(&Frame::new(message))?;
515    let len = u32::try_from(bytes.len()).map_err(|_| ProtocolError::FrameTooLarge {
516        len: u32::MAX,
517        max: MAX_FRAME_BYTES,
518    })?;
519    if len > MAX_FRAME_BYTES {
520        return Err(ProtocolError::FrameTooLarge {
521            len,
522            max: MAX_FRAME_BYTES,
523        });
524    }
525    writer.write_all(&len.to_be_bytes())?;
526    writer.write_all(&bytes)?;
527    writer.flush()?;
528    Ok(())
529}
530
531/// Read one length-delimited JSON frame from `reader`.
532///
533/// # Errors
534///
535/// Returns [`ProtocolError::FrameTooLarge`] if the framed length exceeds
536/// [`MAX_FRAME_BYTES`] (rejected before allocation),
537/// [`ProtocolError::VersionMismatch`] if the peer's version does not match
538/// [`PROTOCOL_VERSION`], or the underlying [`ProtocolError::Io`] /
539/// [`ProtocolError::Json`] for codec failures.
540pub fn read_frame(reader: &mut impl Read) -> Result<Frame, ProtocolError> {
541    let mut len_bytes = [0_u8; 4];
542    reader.read_exact(&mut len_bytes)?;
543    let len = u32::from_be_bytes(len_bytes);
544    if len > MAX_FRAME_BYTES {
545        return Err(ProtocolError::FrameTooLarge {
546            len,
547            max: MAX_FRAME_BYTES,
548        });
549    }
550    let mut bytes = vec![0_u8; len as usize];
551    reader.read_exact(&mut bytes)?;
552    let frame: Frame = serde_json::from_slice(&bytes)?;
553    if frame.version != PROTOCOL_VERSION {
554        return Err(ProtocolError::VersionMismatch {
555            expected: PROTOCOL_VERSION,
556            actual: frame.version,
557        });
558    }
559    Ok(frame)
560}
561
562#[cfg(test)]
563mod tests {
564    #![allow(clippy::expect_used, clippy::panic)]
565
566    use std::io::Cursor;
567
568    use serde_json::json;
569    use serde_json::value::RawValue;
570
571    use super::{DataRow, DataRowEmitter, MAX_FRAME_BYTES, Message, ProtocolError, Response, read_frame, write_frame};
572
573    fn raw_json(value: &serde_json::Value) -> Box<RawValue> {
574        serde_json::value::to_raw_value(value).expect("test JSON converts to raw value")
575    }
576
577    #[test]
578    fn data_row_round_trips_through_length_delimited_frame() {
579        let row = DataRow {
580            stream: "rows".to_owned(),
581            sequence: 7,
582            payload: raw_json(&json!({ "name": "Nat.add", "score": 3 })),
583        };
584        let mut bytes = Vec::new();
585        write_frame(&mut bytes, Message::DataRow(row.clone())).expect("data row writes");
586        let frame = read_frame(&mut Cursor::new(bytes)).expect("data row reads");
587        assert_eq!(frame.message, Message::DataRow(row));
588    }
589
590    #[test]
591    fn data_row_emitter_assigns_per_stream_sequences() {
592        let mut emitter = DataRowEmitter::default();
593        let mut bytes = Vec::new();
594        emitter
595            .emit(&mut bytes, "rows", &json!({ "i": 0 }))
596            .expect("first row writes");
597        emitter
598            .emit(&mut bytes, "warnings", &json!({ "i": 1 }))
599            .expect("second row writes");
600        emitter
601            .emit(&mut bytes, "rows", &json!({ "i": 2 }))
602            .expect("third row writes");
603        assert_eq!(emitter.count(), 3);
604
605        let mut cursor = Cursor::new(bytes);
606        let rows = [
607            read_frame(&mut cursor).expect("first row reads"),
608            read_frame(&mut cursor).expect("second row reads"),
609            read_frame(&mut cursor).expect("third row reads"),
610        ];
611        assert_eq!(
612            rows.map(|frame| frame.message),
613            [
614                Message::DataRow(DataRow {
615                    stream: "rows".to_owned(),
616                    sequence: 0,
617                    payload: raw_json(&json!({ "i": 0 })),
618                }),
619                Message::DataRow(DataRow {
620                    stream: "warnings".to_owned(),
621                    sequence: 0,
622                    payload: raw_json(&json!({ "i": 1 })),
623                }),
624                Message::DataRow(DataRow {
625                    stream: "rows".to_owned(),
626                    sequence: 1,
627                    payload: raw_json(&json!({ "i": 2 })),
628                }),
629            ],
630        );
631    }
632
633    #[test]
634    fn oversized_data_row_is_rejected_before_write() {
635        let row = DataRow {
636            stream: "rows".to_owned(),
637            sequence: 0,
638            payload: raw_json(&json!({ "blob": "x".repeat(MAX_FRAME_BYTES as usize) })),
639        };
640        let mut bytes = Vec::new();
641        let err = write_frame(&mut bytes, Message::DataRow(row)).expect_err("oversized frame is rejected");
642        match err {
643            ProtocolError::FrameTooLarge { len, max } => {
644                assert!(len > max);
645                assert_eq!(max, MAX_FRAME_BYTES);
646            }
647            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
648                panic!("expected FrameTooLarge, got {other:?}");
649            }
650        }
651    }
652
653    #[test]
654    fn oversized_data_row_is_rejected_before_read_allocation() {
655        let mut bytes = Vec::new();
656        bytes.extend_from_slice(&(MAX_FRAME_BYTES.saturating_add(1)).to_be_bytes());
657        let err = read_frame(&mut Cursor::new(bytes)).expect_err("oversized frame is rejected");
658        match err {
659            ProtocolError::FrameTooLarge { len, max } => {
660                assert_eq!(len, MAX_FRAME_BYTES + 1);
661                assert_eq!(max, MAX_FRAME_BYTES);
662            }
663            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
664                panic!("expected FrameTooLarge, got {other:?}");
665            }
666        }
667    }
668
669    #[test]
670    fn malformed_frame_payload_is_protocol_error() {
671        let mut bytes = Vec::new();
672        bytes.extend_from_slice(&1_u32.to_be_bytes());
673        bytes.push(b'{');
674        let err = read_frame(&mut Cursor::new(bytes)).expect_err("malformed JSON is rejected");
675        match err {
676            ProtocolError::Json(_) => {}
677            other @ (ProtocolError::Io(_)
678            | ProtocolError::FrameTooLarge { .. }
679            | ProtocolError::VersionMismatch { .. }) => {
680                panic!("expected Json error, got {other:?}");
681            }
682        }
683    }
684
685    #[test]
686    fn rows_complete_response_round_trips() {
687        let mut bytes = Vec::new();
688        write_frame(&mut bytes, Message::Response(Response::RowsComplete { count: 2 })).expect("rows complete writes");
689        let frame = read_frame(&mut Cursor::new(bytes)).expect("rows complete reads");
690        assert_eq!(frame.message, Message::Response(Response::RowsComplete { count: 2 }));
691    }
692}