1use std::collections::BTreeMap;
17use std::fmt;
18use std::io::{self, Read, Write};
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use serde_json::value::RawValue;
24
25use crate::types::{
26 LeanWorkerCapabilityMetadata, LeanWorkerDeclarationFilter, LeanWorkerDeclarationRow, LeanWorkerDoctorReport,
27 LeanWorkerElabOptions, LeanWorkerElabResult, LeanWorkerKernelResult, LeanWorkerMetaResult,
28 LeanWorkerMetaTransparency, LeanWorkerProcessFileOutcome, LeanWorkerProcessModuleOutcome, LeanWorkerRendered,
29};
30
/// Version number that [`Frame::new`] stamps on every outgoing frame and that
/// [`read_frame`] requires on every frame it decodes.
pub const PROTOCOL_VERSION: u16 = 4;

/// Frame-size cap of 1 MiB, counted in bytes of serialized JSON.
///
/// The tests in this module treat it as the default cap passed to
/// [`write_frame`] and [`read_frame`].
pub const MAX_FRAME_BYTES: u32 = 1 << 20;

/// Smallest frame-size cap, 64 KiB.
///
/// NOTE(review): presumably the lower bound accepted for
/// `Message::ConfigureFrameLimit`; the enforcement is not in this file.
pub const MIN_FRAME_BYTES: u32 = 64 << 10;

/// Largest frame-size cap, 256 MiB.
///
/// NOTE(review): presumably the upper bound accepted for
/// `Message::ConfigureFrameLimit`; the enforcement is not in this file.
pub const MAX_FRAME_BYTES_HARD_CAP: u32 = 256 << 20;
53
54#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
56#[non_exhaustive]
57pub struct Frame {
58 pub version: u16,
60 pub message: Message,
62}
63
64impl Frame {
65 #[must_use]
67 pub fn new(message: Message) -> Self {
68 Self {
69 version: PROTOCOL_VERSION,
70 message,
71 }
72 }
73}
74
75#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
77#[serde(tag = "type", content = "body", rename_all = "snake_case")]
78#[non_exhaustive]
79pub enum Message {
80 Handshake {
83 worker_version: String,
85 protocol_version: u16,
87 },
88 ConfigureFrameLimit {
95 max_frame_bytes: u32,
97 },
98 Request(Request),
100 Response(Response),
102 Diagnostic(Diagnostic),
104 ProgressTick(ProgressTick),
106 DataRow(DataRow),
108 FatalExit(FatalExit),
111}
112
113#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
115#[serde(tag = "op", rename_all = "snake_case")]
116#[non_exhaustive]
117pub enum Request {
118 Health,
119 LoadFixtureCapability {
120 fixture_root: String,
121 },
122 CallFixtureMul {
123 fixture_root: String,
124 lhs: u64,
125 rhs: u64,
126 },
127 TriggerLeanPanic {
128 fixture_root: String,
129 },
130 OpenHostSession {
131 project_root: String,
132 package: String,
133 lib_name: String,
134 imports: Vec<String>,
135 },
136 Elaborate {
137 source: String,
138 options: LeanWorkerElabOptions,
139 },
140 KernelCheck {
141 source: String,
142 options: LeanWorkerElabOptions,
143 progress: bool,
144 },
145 DeclarationKinds {
146 names: Vec<String>,
147 progress: bool,
148 },
149 DeclarationNames {
150 names: Vec<String>,
151 progress: bool,
152 },
153 RunDataStream {
154 export: String,
155 request_json: String,
156 progress: bool,
157 },
158 CapabilityMetadata {
159 export: String,
160 request_json: String,
161 },
162 CapabilityDoctor {
163 export: String,
164 request_json: String,
165 },
166 JsonCommand {
167 export: String,
168 request_json: String,
169 },
170 InferType {
171 source: String,
172 options: LeanWorkerElabOptions,
173 },
174 Whnf {
175 source: String,
176 options: LeanWorkerElabOptions,
177 },
178 IsDefEq {
179 lhs: String,
180 rhs: String,
181 transparency: LeanWorkerMetaTransparency,
182 options: LeanWorkerElabOptions,
183 },
184 Describe {
185 name: String,
186 },
187 ListDeclarationsStrings {
188 filter: LeanWorkerDeclarationFilter,
189 progress: bool,
190 },
191 DescribeBulk {
192 names: Vec<String>,
193 progress: bool,
194 },
195 ProcessFile {
196 source: String,
197 options: LeanWorkerElabOptions,
198 },
199 ProcessModule {
200 source: String,
201 options: LeanWorkerElabOptions,
202 },
203 EmitTestRows {
206 streams: Vec<String>,
207 },
208 EmitTestRowsThenExit,
209 EmitTestRowsThenPanic,
210 Terminate,
211}
212
213#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
215#[serde(tag = "status", rename_all = "snake_case")]
216#[non_exhaustive]
217pub enum Response {
218 HealthOk,
219 CapabilityLoaded,
220 U64 {
221 value: u64,
222 },
223 HostSessionOpened,
224 Elaboration {
225 outcome: LeanWorkerElabResult,
226 },
227 KernelCheck {
228 outcome: LeanWorkerKernelResult,
229 },
230 Strings {
231 values: Vec<String>,
232 },
233 StreamComplete {
234 summary: StreamSummary,
235 },
236 StreamExportFailed {
237 status_byte: u8,
238 },
239 StreamCallbackFailed {
240 status_byte: u8,
241 description: String,
242 },
243 StreamRowMalformed {
244 message: String,
245 },
246 CapabilityMetadata {
247 metadata: LeanWorkerCapabilityMetadata,
248 },
249 CapabilityDoctor {
250 report: LeanWorkerDoctorReport,
251 },
252 CapabilityMetadataMalformed {
253 message: String,
254 },
255 CapabilityDoctorMalformed {
256 message: String,
257 },
258 JsonCommand {
259 response_json: String,
260 },
261 MetaExpr {
262 result: LeanWorkerMetaResult<LeanWorkerRendered>,
263 },
264 MetaBool {
265 result: LeanWorkerMetaResult<bool>,
266 },
267 Declaration {
268 row: Option<LeanWorkerDeclarationRow>,
269 },
270 DeclarationBulk {
271 rows: Vec<LeanWorkerDeclarationRow>,
272 },
273 ProcessFile {
274 outcome: LeanWorkerProcessFileOutcome,
275 },
276 ProcessModule {
277 outcome: LeanWorkerProcessModuleOutcome,
278 },
279 RowsComplete {
280 count: u64,
281 },
282 Terminating,
283 Error {
284 code: String,
285 message: String,
286 },
287}
288
289#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
291#[non_exhaustive]
292pub struct Diagnostic {
293 pub code: String,
295 pub message: String,
297}
298
299impl Diagnostic {
300 #[must_use]
302 pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
303 Self {
304 code: code.into(),
305 message: message.into(),
306 }
307 }
308}
309
310#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
313#[non_exhaustive]
314pub struct ProgressTick {
315 pub phase: String,
317 pub current: u64,
319 pub total: Option<u64>,
321}
322
323impl ProgressTick {
324 #[must_use]
326 pub fn new(phase: impl Into<String>, current: u64, total: Option<u64>) -> Self {
327 Self {
328 phase: phase.into(),
329 current,
330 total,
331 }
332 }
333}
334
335#[derive(Clone, Debug, Deserialize, Serialize)]
342pub struct DataRow {
343 pub stream: String,
345 pub sequence: u64,
347 pub payload: Box<RawValue>,
349}
350
351impl PartialEq for DataRow {
352 fn eq(&self, other: &Self) -> bool {
353 self.stream == other.stream && self.sequence == other.sequence && self.payload.get() == other.payload.get()
354 }
355}
356
357impl Eq for DataRow {}
358
359#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
361#[non_exhaustive]
362pub struct StreamSummary {
363 pub total_rows: u64,
365 pub per_stream_counts: BTreeMap<String, u64>,
367 pub elapsed_micros: u64,
369 pub metadata: Option<Value>,
371}
372
373impl StreamSummary {
374 #[must_use]
377 pub fn new(
378 total_rows: u64,
379 per_stream_counts: BTreeMap<String, u64>,
380 elapsed: Duration,
381 metadata: Option<Value>,
382 ) -> Self {
383 Self {
384 total_rows,
385 per_stream_counts,
386 elapsed_micros: elapsed.as_micros().try_into().unwrap_or(u64::MAX),
387 metadata,
388 }
389 }
390}
391
/// Hands out [`DataRow`]s with per-stream sequence numbers and keeps running
/// totals of what it has produced.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct DataRowEmitter {
    /// Next sequence number to assign, keyed by stream name.
    sequences: BTreeMap<String, u64>,
    /// Total number of rows produced across all streams.
    count: u64,
}
400
401impl DataRowEmitter {
402 pub fn next(&mut self, stream: impl Into<String>, payload: Box<RawValue>) -> DataRow {
405 let stream = stream.into();
406 let sequence = self.sequences.entry(stream.clone()).or_insert(0);
407 let row = DataRow {
408 stream,
409 sequence: *sequence,
410 payload,
411 };
412 *sequence = sequence.saturating_add(1);
413 self.count = self.count.saturating_add(1);
414 row
415 }
416
417 #[cfg(test)]
418 fn emit(
419 &mut self,
420 writer: &mut impl Write,
421 stream: impl Into<String>,
422 payload: &Value,
423 ) -> Result<(), ProtocolError> {
424 let row = self.next(stream, serde_json::value::to_raw_value(payload)?);
425 write_frame(writer, Message::DataRow(row), MAX_FRAME_BYTES)
426 }
427
428 #[must_use]
430 pub fn count(&self) -> u64 {
431 self.count
432 }
433
434 #[must_use]
436 pub fn per_stream_counts(&self) -> BTreeMap<String, u64> {
437 self.sequences.clone()
438 }
439}
440
441#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
444#[non_exhaustive]
445pub struct FatalExit {
446 pub status: String,
448 pub stderr: String,
450}
451
452impl FatalExit {
453 #[must_use]
455 pub fn new(status: impl Into<String>, stderr: impl Into<String>) -> Self {
456 Self {
457 status: status.into(),
458 stderr: stderr.into(),
459 }
460 }
461}
462
463#[derive(Debug)]
465#[non_exhaustive]
466pub enum ProtocolError {
467 Io(io::Error),
469 Json(serde_json::Error),
471 FrameTooLarge {
473 len: u32,
475 max: u32,
477 },
478 VersionMismatch {
480 expected: u16,
482 actual: u16,
484 },
485}
486
487impl ProtocolError {
488 #[must_use]
492 pub fn is_eof(&self) -> bool {
493 matches!(self, Self::Io(err) if err.kind() == io::ErrorKind::UnexpectedEof)
494 }
495}
496
497impl fmt::Display for ProtocolError {
498 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
499 match self {
500 Self::Io(err) => write!(f, "worker protocol I/O failed: {err}"),
501 Self::Json(err) => write!(f, "worker protocol JSON decode failed: {err}"),
502 Self::FrameTooLarge { len, max } => {
503 write!(f, "worker protocol frame too large: {len} bytes exceeds {max}")
504 }
505 Self::VersionMismatch { expected, actual } => {
506 write!(
507 f,
508 "worker protocol version mismatch: expected {expected}, received {actual}"
509 )
510 }
511 }
512 }
513}
514
515impl std::error::Error for ProtocolError {}
516
517impl From<io::Error> for ProtocolError {
518 fn from(value: io::Error) -> Self {
519 Self::Io(value)
520 }
521}
522
523impl From<serde_json::Error> for ProtocolError {
524 fn from(value: serde_json::Error) -> Self {
525 Self::Json(value)
526 }
527}
528
529pub fn write_frame(writer: &mut impl Write, message: Message, max_frame_bytes: u32) -> Result<(), ProtocolError> {
542 let bytes = serde_json::to_vec(&Frame::new(message))?;
543 let len = u32::try_from(bytes.len()).map_err(|_| ProtocolError::FrameTooLarge {
544 len: u32::MAX,
545 max: max_frame_bytes,
546 })?;
547 if len > max_frame_bytes {
548 return Err(ProtocolError::FrameTooLarge {
549 len,
550 max: max_frame_bytes,
551 });
552 }
553 writer.write_all(&len.to_be_bytes())?;
554 writer.write_all(&bytes)?;
555 writer.flush()?;
556 Ok(())
557}
558
559pub fn read_frame(reader: &mut impl Read, max_frame_bytes: u32) -> Result<Frame, ProtocolError> {
572 let mut len_bytes = [0_u8; 4];
573 reader.read_exact(&mut len_bytes)?;
574 let len = u32::from_be_bytes(len_bytes);
575 if len > max_frame_bytes {
576 return Err(ProtocolError::FrameTooLarge {
577 len,
578 max: max_frame_bytes,
579 });
580 }
581 let mut bytes = vec![0_u8; len as usize];
582 reader.read_exact(&mut bytes)?;
583 let frame: Frame = serde_json::from_slice(&bytes)?;
584 if frame.version != PROTOCOL_VERSION {
585 return Err(ProtocolError::VersionMismatch {
586 expected: PROTOCOL_VERSION,
587 actual: frame.version,
588 });
589 }
590 Ok(frame)
591}
592
#[cfg(test)]
mod tests {
    // Tests assert with `expect`/`panic!` on purpose.
    #![allow(clippy::expect_used, clippy::panic)]

    use std::io::Cursor;

    use serde_json::json;
    use serde_json::value::RawValue;

    use super::{
        DataRow, DataRowEmitter, Frame, MAX_FRAME_BYTES, MAX_FRAME_BYTES_HARD_CAP, MIN_FRAME_BYTES, Message,
        PROTOCOL_VERSION, ProtocolError, Response, read_frame, write_frame,
    };

    /// Converts a JSON value into the raw form stored in `DataRow::payload`.
    fn raw_json(value: &serde_json::Value) -> Box<RawValue> {
        serde_json::value::to_raw_value(value).expect("test JSON converts to raw value")
    }

    #[test]
    fn data_row_round_trips_through_length_delimited_frame() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 7,
            payload: raw_json(&json!({ "name": "Nat.add", "score": 3 })),
        };
        let mut bytes = Vec::new();
        write_frame(&mut bytes, Message::DataRow(row.clone()), MAX_FRAME_BYTES).expect("data row writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("data row reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    #[test]
    fn data_row_emitter_assigns_per_stream_sequences() {
        let mut emitter = DataRowEmitter::default();
        let mut bytes = Vec::new();
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 0 }))
            .expect("first row writes");
        emitter
            .emit(&mut bytes, "warnings", &json!({ "i": 1 }))
            .expect("second row writes");
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 2 }))
            .expect("third row writes");
        assert_eq!(emitter.count(), 3);

        let mut cursor = Cursor::new(bytes);
        let rows = [
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("first row reads"),
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("second row reads"),
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("third row reads"),
        ];
        assert_eq!(
            rows.map(|frame| frame.message),
            [
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 0 })),
                }),
                Message::DataRow(DataRow {
                    stream: "warnings".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 1 })),
                }),
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 1,
                    payload: raw_json(&json!({ "i": 2 })),
                }),
            ],
        );
    }

    #[test]
    fn oversized_data_row_is_rejected_before_write() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(MAX_FRAME_BYTES as usize) })),
        };
        let mut bytes = Vec::new();
        let err =
            write_frame(&mut bytes, Message::DataRow(row), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert!(len > max);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    #[test]
    fn oversized_data_row_is_rejected_before_read_allocation() {
        // Only the length prefix is supplied: the cap check must fire before
        // any payload bytes are requested.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&(MAX_FRAME_BYTES.saturating_add(1)).to_be_bytes());
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert_eq!(len, MAX_FRAME_BYTES + 1);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    #[test]
    fn larger_cap_accepts_frame_rejected_under_default() {
        let raised = MAX_FRAME_BYTES.saturating_mul(8);
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(2 * MAX_FRAME_BYTES as usize) })),
        };
        let mut buf = Vec::new();
        write_frame(&mut buf, Message::DataRow(row.clone()), raised).expect("oversize-under-default frame writes");
        let frame = read_frame(&mut Cursor::new(buf), raised).expect("oversize-under-default frame reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    #[test]
    fn frame_cap_bounds_constants_are_consistent() {
        const { assert!(MIN_FRAME_BYTES <= MAX_FRAME_BYTES) };
        const { assert!(MAX_FRAME_BYTES <= MAX_FRAME_BYTES_HARD_CAP) };
    }

    #[test]
    fn malformed_frame_payload_is_protocol_error() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&1_u32.to_be_bytes());
        bytes.push(b'{');
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("malformed JSON is rejected");
        match err {
            ProtocolError::Json(_) => {}
            other @ (ProtocolError::Io(_)
            | ProtocolError::FrameTooLarge { .. }
            | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected Json error, got {other:?}");
            }
        }
    }

    #[test]
    fn rows_complete_response_round_trips() {
        let mut bytes = Vec::new();
        write_frame(
            &mut bytes,
            Message::Response(Response::RowsComplete { count: 2 }),
            MAX_FRAME_BYTES,
        )
        .expect("rows complete writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("rows complete reads");
        assert_eq!(frame.message, Message::Response(Response::RowsComplete { count: 2 }));
    }

    #[test]
    fn mismatched_frame_version_is_rejected() {
        // `write_frame` always stamps PROTOCOL_VERSION, so build the frame by
        // hand with a different version.
        let foreign_version = PROTOCOL_VERSION.wrapping_add(1);
        let frame = Frame {
            version: foreign_version,
            message: Message::Response(Response::HealthOk),
        };
        let payload = serde_json::to_vec(&frame).expect("frame serializes");
        let len = u32::try_from(payload.len()).expect("test frame length fits in u32");
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&len.to_be_bytes());
        bytes.extend_from_slice(&payload);

        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("foreign version is rejected");
        match err {
            ProtocolError::VersionMismatch { expected, actual } => {
                assert_eq!(expected, PROTOCOL_VERSION);
                assert_eq!(actual, foreign_version);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::FrameTooLarge { .. }) => {
                panic!("expected VersionMismatch, got {other:?}");
            }
        }
    }

    #[test]
    fn truncated_stream_reports_eof() {
        // Stream ends before the length prefix.
        let err = read_frame(&mut Cursor::new(Vec::<u8>::new()), MAX_FRAME_BYTES)
            .expect_err("empty stream holds no frame");
        assert!(err.is_eof());

        // Prefix announces 8 payload bytes but only one follows.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&8_u32.to_be_bytes());
        bytes.push(b'{');
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("truncated payload is rejected");
        assert!(err.is_eof());
    }
}