1use std::collections::BTreeMap;
17use std::fmt;
18use std::io::{self, Read, Write};
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use serde_json::value::RawValue;
24
25use crate::types::{
26 LeanWorkerCapabilityMetadata, LeanWorkerDeclarationFilter, LeanWorkerDeclarationRow, LeanWorkerDoctorReport,
27 LeanWorkerElabOptions, LeanWorkerElabResult, LeanWorkerKernelResult, LeanWorkerMetaResult,
28 LeanWorkerMetaTransparency, LeanWorkerProcessFileOutcome, LeanWorkerProcessModuleOutcome, LeanWorkerRendered,
29};
30
/// Wire protocol version. [`Frame::new`] stamps it into every frame it builds and
/// [`read_frame`] rejects any decoded frame whose `version` differs from it.
pub const PROTOCOL_VERSION: u16 = 5;

/// Frame size cap of 1 MiB, in bytes. The tests in this module pass it as the
/// `max_frame_bytes` argument of [`write_frame`] and [`read_frame`].
// NOTE(review): presumably the default cap in force until a
// `Message::ConfigureFrameLimit` changes it — confirm against the caller.
pub const MAX_FRAME_BYTES: u32 = 1024 * 1024;

/// 64 KiB, in bytes.
// NOTE(review): presumably the smallest cap a peer may configure; nothing in
// this file enforces it — confirm against the code handling `ConfigureFrameLimit`.
pub const MIN_FRAME_BYTES: u32 = 64 * 1024;

/// 256 MiB, in bytes.
// NOTE(review): presumably the largest cap a peer may configure; nothing in
// this file enforces it — confirm against the code handling `ConfigureFrameLimit`.
pub const MAX_FRAME_BYTES_HARD_CAP: u32 = 256 * 1024 * 1024;
53
/// Envelope pairing a [`Message`] with the protocol version of its sender.
///
/// [`write_frame`] serializes a `Frame` to JSON and [`read_frame`] decodes one.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Frame {
    /// Protocol version of the frame; [`read_frame`] requires it to equal
    /// [`PROTOCOL_VERSION`].
    pub version: u16,
    /// The wrapped message.
    pub message: Message,
}
63
64impl Frame {
65 #[must_use]
67 pub fn new(message: Message) -> Self {
68 Self {
69 version: PROTOCOL_VERSION,
70 message,
71 }
72 }
73}
74
/// Every message kind that can travel inside a [`Frame`].
///
/// Serialized adjacently tagged: the snake_case variant name goes in `"type"`
/// and the variant's content in `"body"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "type", content = "body", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Message {
    /// Version announcement.
    // NOTE(review): presumably sent once by the worker at startup — confirm.
    Handshake {
        /// Version string of the worker.
        worker_version: String,
        /// Protocol version the sender speaks.
        protocol_version: u16,
    },
    /// Carries a frame size limit.
    // NOTE(review): presumably asks the peer to adopt `max_frame_bytes` as the
    // cap for subsequent frames — confirm direction and bounds with the caller.
    ConfigureFrameLimit {
        /// The frame size limit being configured.
        max_frame_bytes: u32,
    },
    /// An operation request; see [`Request`].
    Request(Request),
    /// The result of a request; see [`Response`].
    Response(Response),
    /// A coded diagnostic; see [`Diagnostic`].
    Diagnostic(Diagnostic),
    /// A progress report; see [`ProgressTick`].
    ProgressTick(ProgressTick),
    /// One row of streamed data; see [`DataRow`].
    DataRow(DataRow),
    /// A fatal-exit report; see [`FatalExit`].
    FatalExit(FatalExit),
}
112
/// Operations carried by [`Message::Request`].
///
/// Serialized internally tagged: the snake_case variant name is stored in the
/// `"op"` field alongside the variant's own fields.
// NOTE(review): the per-variant semantics are implemented outside this file;
// the notes below describe only the payload shape. Several variants carry a
// `progress` flag — presumably it asks for `Message::ProgressTick` updates
// while the request runs; confirm against the worker implementation.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "op", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Request {
    /// No payload.
    Health,
    /// Identifies a fixture by its root path.
    LoadFixtureCapability {
        fixture_root: String,
    },
    /// A fixture root plus two `u64` operands.
    CallFixtureMul {
        fixture_root: String,
        lhs: u64,
        rhs: u64,
    },
    /// Identifies a fixture by its root path.
    TriggerLeanPanic {
        fixture_root: String,
    },
    /// A project root, a [`HostSessionMode`], and a list of import names.
    OpenHostSession {
        project_root: String,
        mode: HostSessionMode,
        imports: Vec<String>,
    },
    /// Source text plus elaboration options.
    Elaborate {
        source: String,
        options: LeanWorkerElabOptions,
    },
    /// Source text, elaboration options, and a progress flag.
    KernelCheck {
        source: String,
        options: LeanWorkerElabOptions,
        progress: bool,
    },
    /// A list of names and a progress flag.
    DeclarationKinds {
        names: Vec<String>,
        progress: bool,
    },
    /// A list of names and a progress flag.
    DeclarationNames {
        names: Vec<String>,
        progress: bool,
    },
    /// An export name, a JSON request string, and a progress flag.
    RunDataStream {
        export: String,
        request_json: String,
        progress: bool,
    },
    /// An export name and a JSON request string.
    CapabilityMetadata {
        export: String,
        request_json: String,
    },
    /// An export name and a JSON request string.
    CapabilityDoctor {
        export: String,
        request_json: String,
    },
    /// An export name and a JSON request string.
    JsonCommand {
        export: String,
        request_json: String,
    },
    /// Source text plus elaboration options.
    InferType {
        source: String,
        options: LeanWorkerElabOptions,
    },
    /// Source text plus elaboration options.
    Whnf {
        source: String,
        options: LeanWorkerElabOptions,
    },
    /// Two source strings, a transparency setting, and elaboration options.
    IsDefEq {
        lhs: String,
        rhs: String,
        transparency: LeanWorkerMetaTransparency,
        options: LeanWorkerElabOptions,
    },
    /// A single name.
    Describe {
        name: String,
    },
    /// A declaration filter and a progress flag.
    ListDeclarationsStrings {
        filter: LeanWorkerDeclarationFilter,
        progress: bool,
    },
    /// A list of names and a progress flag.
    DescribeBulk {
        names: Vec<String>,
        progress: bool,
    },
    /// Source text plus elaboration options.
    ProcessFile {
        source: String,
        options: LeanWorkerElabOptions,
    },
    /// Source text plus elaboration options.
    ProcessModule {
        source: String,
        options: LeanWorkerElabOptions,
    },
    /// A list of stream names.
    // NOTE(review): this and the two `EmitTestRowsThen*` variants look like
    // test-only hooks — confirm.
    EmitTestRows {
        streams: Vec<String>,
    },
    /// No payload.
    EmitTestRowsThenExit,
    /// No payload.
    EmitTestRowsThenPanic,
    /// No payload.
    Terminate,
}
211
/// Session mode carried by [`Request::OpenHostSession`].
///
/// Serialized internally tagged: the snake_case variant name is stored in the
/// `"kind"` field.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum HostSessionMode {
    /// Names a package and a library within it.
    Capability { package: String, lib_name: String },
    /// No payload.
    // NOTE(review): presumably opens a session without loading a capability
    // library — confirm against the worker implementation.
    ShimsOnly,
}
222
/// Results carried by [`Message::Response`].
///
/// Serialized internally tagged: the snake_case variant name is stored in the
/// `"status"` field alongside the variant's own fields.
// NOTE(review): which request produces which response is decided outside this
// file; the pairing suggested by matching variant names is unverified here.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Response {
    /// No payload.
    HealthOk,
    /// No payload.
    CapabilityLoaded,
    /// A single `u64` value.
    U64 {
        value: u64,
    },
    /// No payload.
    HostSessionOpened,
    /// An elaboration result.
    Elaboration {
        outcome: LeanWorkerElabResult,
    },
    /// A kernel-check result.
    KernelCheck {
        outcome: LeanWorkerKernelResult,
    },
    /// A list of strings.
    Strings {
        values: Vec<String>,
    },
    /// A [`StreamSummary`].
    StreamComplete {
        summary: StreamSummary,
    },
    /// A status byte.
    StreamExportFailed {
        status_byte: u8,
    },
    /// A status byte plus a description string.
    StreamCallbackFailed {
        status_byte: u8,
        description: String,
    },
    /// A message string.
    StreamRowMalformed {
        message: String,
    },
    /// Capability metadata.
    CapabilityMetadata {
        metadata: LeanWorkerCapabilityMetadata,
    },
    /// A doctor report.
    CapabilityDoctor {
        report: LeanWorkerDoctorReport,
    },
    /// A message string.
    CapabilityMetadataMalformed {
        message: String,
    },
    /// A message string.
    CapabilityDoctorMalformed {
        message: String,
    },
    /// A JSON response string.
    JsonCommand {
        response_json: String,
    },
    /// A meta-level result wrapping a rendered value.
    MetaExpr {
        result: LeanWorkerMetaResult<LeanWorkerRendered>,
    },
    /// A meta-level result wrapping a boolean.
    MetaBool {
        result: LeanWorkerMetaResult<bool>,
    },
    /// An optional declaration row.
    Declaration {
        row: Option<LeanWorkerDeclarationRow>,
    },
    /// A list of declaration rows.
    DeclarationBulk {
        rows: Vec<LeanWorkerDeclarationRow>,
    },
    /// A process-file outcome.
    ProcessFile {
        outcome: LeanWorkerProcessFileOutcome,
    },
    /// A process-module outcome.
    ProcessModule {
        outcome: LeanWorkerProcessModuleOutcome,
    },
    /// A row count.
    RowsComplete {
        count: u64,
    },
    /// No payload.
    Terminating,
    /// An error code plus a message string.
    Error {
        code: String,
        message: String,
    },
}
298
/// Payload of [`Message::Diagnostic`]: a code string plus a message string.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Diagnostic {
    /// Code identifying the diagnostic.
    pub code: String,
    /// Text of the diagnostic.
    pub message: String,
}
308
309impl Diagnostic {
310 #[must_use]
312 pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
313 Self {
314 code: code.into(),
315 message: message.into(),
316 }
317 }
318}
319
/// Payload of [`Message::ProgressTick`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct ProgressTick {
    /// Name of the phase being reported.
    pub phase: String,
    /// Current position within the phase.
    pub current: u64,
    /// Optional total for the phase.
    // NOTE(review): `None` presumably means the total is unknown — confirm.
    pub total: Option<u64>,
}
332
333impl ProgressTick {
334 #[must_use]
336 pub fn new(phase: impl Into<String>, current: u64, total: Option<u64>) -> Self {
337 Self {
338 phase: phase.into(),
339 current,
340 total,
341 }
342 }
343}
344
/// Payload of [`Message::DataRow`]: one row of a named stream.
///
/// `PartialEq`/`Eq` are implemented by hand rather than derived; equality
/// compares the payload's raw JSON text.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DataRow {
    /// Name of the stream the row belongs to.
    pub stream: String,
    /// Sequence number of the row; [`DataRowEmitter`] numbers each stream from 0.
    pub sequence: u64,
    /// The row's content, kept as raw JSON text.
    pub payload: Box<RawValue>,
}
360
361impl PartialEq for DataRow {
362 fn eq(&self, other: &Self) -> bool {
363 self.stream == other.stream && self.sequence == other.sequence && self.payload.get() == other.payload.get()
364 }
365}
366
367impl Eq for DataRow {}
368
/// Summary carried by [`Response::StreamComplete`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct StreamSummary {
    /// Total number of rows.
    pub total_rows: u64,
    /// Row count keyed by stream name.
    pub per_stream_counts: BTreeMap<String, u64>,
    /// Elapsed time in microseconds; [`StreamSummary::new`] saturates it at
    /// `u64::MAX`.
    pub elapsed_micros: u64,
    /// Optional free-form JSON metadata.
    pub metadata: Option<Value>,
}
382
383impl StreamSummary {
384 #[must_use]
387 pub fn new(
388 total_rows: u64,
389 per_stream_counts: BTreeMap<String, u64>,
390 elapsed: Duration,
391 metadata: Option<Value>,
392 ) -> Self {
393 Self {
394 total_rows,
395 per_stream_counts,
396 elapsed_micros: elapsed.as_micros().try_into().unwrap_or(u64::MAX),
397 metadata,
398 }
399 }
400}
401
/// Hands out [`DataRow`]s with a separate sequence counter per stream and
/// tracks how many rows it has produced in total.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct DataRowEmitter {
    // Next sequence number to assign, keyed by stream name. Because numbering
    // starts at 0, this is also the number of rows produced for that stream.
    sequences: BTreeMap<String, u64>,
    // Total number of rows produced across all streams.
    count: u64,
}
410
411impl DataRowEmitter {
412 pub fn next(&mut self, stream: impl Into<String>, payload: Box<RawValue>) -> DataRow {
415 let stream = stream.into();
416 let sequence = self.sequences.entry(stream.clone()).or_insert(0);
417 let row = DataRow {
418 stream,
419 sequence: *sequence,
420 payload,
421 };
422 *sequence = sequence.saturating_add(1);
423 self.count = self.count.saturating_add(1);
424 row
425 }
426
427 #[cfg(test)]
428 fn emit(
429 &mut self,
430 writer: &mut impl Write,
431 stream: impl Into<String>,
432 payload: &Value,
433 ) -> Result<(), ProtocolError> {
434 let row = self.next(stream, serde_json::value::to_raw_value(payload)?);
435 write_frame(writer, Message::DataRow(row), MAX_FRAME_BYTES)
436 }
437
438 #[must_use]
440 pub fn count(&self) -> u64 {
441 self.count
442 }
443
444 #[must_use]
446 pub fn per_stream_counts(&self) -> BTreeMap<String, u64> {
447 self.sequences.clone()
448 }
449}
450
/// Payload of [`Message::FatalExit`].
// NOTE(review): presumably describes a worker process that died — `status`
// its exit status and `stderr` its captured standard error; confirm with the
// code that constructs it.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct FatalExit {
    /// Status, as a string.
    pub status: String,
    /// Standard-error text.
    pub stderr: String,
}
461
462impl FatalExit {
463 #[must_use]
465 pub fn new(status: impl Into<String>, stderr: impl Into<String>) -> Self {
466 Self {
467 status: status.into(),
468 stderr: stderr.into(),
469 }
470 }
471}
472
/// Failures reported by [`write_frame`] and [`read_frame`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ProtocolError {
    /// The underlying reader or writer failed.
    Io(io::Error),
    /// `serde_json` failed, either encoding a frame to write or decoding one
    /// that was read.
    Json(serde_json::Error),
    /// A frame's length exceeded the cap in force.
    FrameTooLarge {
        /// Length of the offending frame in bytes; `u32::MAX` when the encoded
        /// length did not fit in a `u32`.
        len: u32,
        /// The cap that was exceeded.
        max: u32,
    },
    /// A decoded frame carried a version other than [`PROTOCOL_VERSION`].
    VersionMismatch {
        /// The version this side speaks.
        expected: u16,
        /// The version found in the frame.
        actual: u16,
    },
}
496
497impl ProtocolError {
498 #[must_use]
502 pub fn is_eof(&self) -> bool {
503 matches!(self, Self::Io(err) if err.kind() == io::ErrorKind::UnexpectedEof)
504 }
505}
506
507impl fmt::Display for ProtocolError {
508 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
509 match self {
510 Self::Io(err) => write!(f, "worker protocol I/O failed: {err}"),
511 Self::Json(err) => write!(f, "worker protocol JSON decode failed: {err}"),
512 Self::FrameTooLarge { len, max } => {
513 write!(f, "worker protocol frame too large: {len} bytes exceeds {max}")
514 }
515 Self::VersionMismatch { expected, actual } => {
516 write!(
517 f,
518 "worker protocol version mismatch: expected {expected}, received {actual}"
519 )
520 }
521 }
522 }
523}
524
// The default `source()` (which returns `None`) is kept: the wrapped I/O or
// JSON error is already rendered inside this type's `Display` text.
impl std::error::Error for ProtocolError {}
526
527impl From<io::Error> for ProtocolError {
528 fn from(value: io::Error) -> Self {
529 Self::Io(value)
530 }
531}
532
533impl From<serde_json::Error> for ProtocolError {
534 fn from(value: serde_json::Error) -> Self {
535 Self::Json(value)
536 }
537}
538
539pub fn write_frame(writer: &mut impl Write, message: Message, max_frame_bytes: u32) -> Result<(), ProtocolError> {
552 let bytes = serde_json::to_vec(&Frame::new(message))?;
553 let len = u32::try_from(bytes.len()).map_err(|_| ProtocolError::FrameTooLarge {
554 len: u32::MAX,
555 max: max_frame_bytes,
556 })?;
557 if len > max_frame_bytes {
558 return Err(ProtocolError::FrameTooLarge {
559 len,
560 max: max_frame_bytes,
561 });
562 }
563 writer.write_all(&len.to_be_bytes())?;
564 writer.write_all(&bytes)?;
565 writer.flush()?;
566 Ok(())
567}
568
569pub fn read_frame(reader: &mut impl Read, max_frame_bytes: u32) -> Result<Frame, ProtocolError> {
582 let mut len_bytes = [0_u8; 4];
583 reader.read_exact(&mut len_bytes)?;
584 let len = u32::from_be_bytes(len_bytes);
585 if len > max_frame_bytes {
586 return Err(ProtocolError::FrameTooLarge {
587 len,
588 max: max_frame_bytes,
589 });
590 }
591 let mut bytes = vec![0_u8; len as usize];
592 reader.read_exact(&mut bytes)?;
593 let frame: Frame = serde_json::from_slice(&bytes)?;
594 if frame.version != PROTOCOL_VERSION {
595 return Err(ProtocolError::VersionMismatch {
596 expected: PROTOCOL_VERSION,
597 actual: frame.version,
598 });
599 }
600 Ok(frame)
601}
602
#[cfg(test)]
mod tests {
    // Tests use `expect`/`panic!` so a failure names the step that broke.
    #![allow(clippy::expect_used, clippy::panic)]

    use std::io::Cursor;

    use serde_json::json;
    use serde_json::value::RawValue;

    use super::{
        DataRow, DataRowEmitter, MAX_FRAME_BYTES, MAX_FRAME_BYTES_HARD_CAP, MIN_FRAME_BYTES, Message, ProtocolError,
        Response, read_frame, write_frame,
    };

    /// Converts a JSON value into the raw form `DataRow::payload` stores.
    fn raw_json(value: &serde_json::Value) -> Box<RawValue> {
        serde_json::value::to_raw_value(value).expect("test JSON converts to raw value")
    }

    /// A row written with `write_frame` comes back unchanged from `read_frame`.
    #[test]
    fn data_row_round_trips_through_length_delimited_frame() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 7,
            payload: raw_json(&json!({ "name": "Nat.add", "score": 3 })),
        };
        let mut bytes = Vec::new();
        write_frame(&mut bytes, Message::DataRow(row.clone()), MAX_FRAME_BYTES).expect("data row writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("data row reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    /// Each stream is numbered independently, starting at 0.
    #[test]
    fn data_row_emitter_assigns_per_stream_sequences() {
        let mut emitter = DataRowEmitter::default();
        let mut bytes = Vec::new();
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 0 }))
            .expect("first row writes");
        emitter
            .emit(&mut bytes, "warnings", &json!({ "i": 1 }))
            .expect("second row writes");
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 2 }))
            .expect("third row writes");
        assert_eq!(emitter.count(), 3);

        let mut cursor = Cursor::new(bytes);
        let rows = [
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("first row reads"),
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("second row reads"),
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("third row reads"),
        ];
        assert_eq!(
            rows.map(|frame| frame.message),
            [
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 0 })),
                }),
                Message::DataRow(DataRow {
                    stream: "warnings".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 1 })),
                }),
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 1,
                    payload: raw_json(&json!({ "i": 2 })),
                }),
            ],
        );
    }

    /// A frame over the cap is refused and nothing reaches the writer.
    #[test]
    fn oversized_data_row_is_rejected_before_write() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(MAX_FRAME_BYTES as usize) })),
        };
        let mut bytes = Vec::new();
        let err =
            write_frame(&mut bytes, Message::DataRow(row), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert!(len > max);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
        // The "before write" half of the contract: not even the length prefix
        // may have been emitted.
        assert!(bytes.is_empty(), "rejected frame must not write any bytes");
    }

    /// An over-cap length prefix is refused without needing the payload bytes.
    #[test]
    fn oversized_data_row_is_rejected_before_read_allocation() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&(MAX_FRAME_BYTES.saturating_add(1)).to_be_bytes());
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert_eq!(len, MAX_FRAME_BYTES + 1);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    /// A frame too big for the default cap is refused under it and round-trips
    /// under a raised cap.
    #[test]
    fn larger_cap_accepts_frame_rejected_under_default() {
        let raised = MAX_FRAME_BYTES.saturating_mul(8);
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(2 * MAX_FRAME_BYTES as usize) })),
        };

        // Establish the premise in the test name: the default cap refuses this frame.
        let mut rejected = Vec::new();
        let err = write_frame(&mut rejected, Message::DataRow(row.clone()), MAX_FRAME_BYTES)
            .expect_err("frame exceeds the default cap");
        match err {
            ProtocolError::FrameTooLarge { max, .. } => assert_eq!(max, MAX_FRAME_BYTES),
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }

        let mut buf = Vec::new();
        write_frame(&mut buf, Message::DataRow(row.clone()), raised).expect("oversize-under-default frame writes");
        let frame = read_frame(&mut Cursor::new(buf), raised).expect("oversize-under-default frame reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    /// The three cap constants are ordered min <= default <= hard cap; checked
    /// at compile time.
    #[test]
    fn frame_cap_bounds_constants_are_consistent() {
        const { assert!(MIN_FRAME_BYTES <= MAX_FRAME_BYTES) };
        const { assert!(MAX_FRAME_BYTES <= MAX_FRAME_BYTES_HARD_CAP) };
    }

    /// A payload that is not valid JSON surfaces as `ProtocolError::Json`.
    #[test]
    fn malformed_frame_payload_is_protocol_error() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&1_u32.to_be_bytes());
        bytes.push(b'{');
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("malformed JSON is rejected");
        match err {
            ProtocolError::Json(_) => {}
            other @ (ProtocolError::Io(_)
            | ProtocolError::FrameTooLarge { .. }
            | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected Json error, got {other:?}");
            }
        }
    }

    /// A `RowsComplete` response survives a write/read round trip.
    #[test]
    fn rows_complete_response_round_trips() {
        let mut bytes = Vec::new();
        write_frame(
            &mut bytes,
            Message::Response(Response::RowsComplete { count: 2 }),
            MAX_FRAME_BYTES,
        )
        .expect("rows complete writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("rows complete reads");
        assert_eq!(frame.message, Message::Response(Response::RowsComplete { count: 2 }));
    }
}