1use std::collections::BTreeMap;
17use std::fmt;
18use std::io::{self, Read, Write};
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use serde_json::value::RawValue;
24
25use crate::types::{
26 LeanWorkerCapabilityMetadata, LeanWorkerDeclarationFilter, LeanWorkerDeclarationRow, LeanWorkerDoctorReport,
27 LeanWorkerElabOptions, LeanWorkerElabResult, LeanWorkerKernelResult, LeanWorkerMetaResult,
28 LeanWorkerMetaTransparency, LeanWorkerProcessFileOutcome, LeanWorkerProcessModuleOutcome, LeanWorkerRendered,
29};
30
/// Version number stamped into every [`Frame`] by [`Frame::new`]; [`read_frame`]
/// rejects frames that carry any other value.
pub const PROTOCOL_VERSION: u16 = 3;

/// Largest JSON frame body, in bytes, that [`write_frame`] will send and [`read_frame`]
/// will accept. The 4-byte length prefix is not counted. Equal to 1 MiB (`1024 * 1024`).
pub const MAX_FRAME_BYTES: u32 = 1_048_576;
40
41#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
43#[non_exhaustive]
44pub struct Frame {
45 pub version: u16,
47 pub message: Message,
49}
50
51impl Frame {
52 #[must_use]
54 pub fn new(message: Message) -> Self {
55 Self {
56 version: PROTOCOL_VERSION,
57 message,
58 }
59 }
60}
61
62#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
64#[serde(tag = "type", content = "body", rename_all = "snake_case")]
65#[non_exhaustive]
66pub enum Message {
67 Handshake {
70 worker_version: String,
72 protocol_version: u16,
74 },
75 Request(Request),
77 Response(Response),
79 Diagnostic(Diagnostic),
81 ProgressTick(ProgressTick),
83 DataRow(DataRow),
85 FatalExit(FatalExit),
88}
89
90#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
92#[serde(tag = "op", rename_all = "snake_case")]
93#[non_exhaustive]
94pub enum Request {
95 Health,
96 LoadFixtureCapability {
97 fixture_root: String,
98 },
99 CallFixtureMul {
100 fixture_root: String,
101 lhs: u64,
102 rhs: u64,
103 },
104 TriggerLeanPanic {
105 fixture_root: String,
106 },
107 OpenHostSession {
108 project_root: String,
109 package: String,
110 lib_name: String,
111 imports: Vec<String>,
112 },
113 Elaborate {
114 source: String,
115 options: LeanWorkerElabOptions,
116 },
117 KernelCheck {
118 source: String,
119 options: LeanWorkerElabOptions,
120 progress: bool,
121 },
122 DeclarationKinds {
123 names: Vec<String>,
124 progress: bool,
125 },
126 DeclarationNames {
127 names: Vec<String>,
128 progress: bool,
129 },
130 RunDataStream {
131 export: String,
132 request_json: String,
133 progress: bool,
134 },
135 CapabilityMetadata {
136 export: String,
137 request_json: String,
138 },
139 CapabilityDoctor {
140 export: String,
141 request_json: String,
142 },
143 JsonCommand {
144 export: String,
145 request_json: String,
146 },
147 InferType {
148 source: String,
149 options: LeanWorkerElabOptions,
150 },
151 Whnf {
152 source: String,
153 options: LeanWorkerElabOptions,
154 },
155 IsDefEq {
156 lhs: String,
157 rhs: String,
158 transparency: LeanWorkerMetaTransparency,
159 options: LeanWorkerElabOptions,
160 },
161 Describe {
162 name: String,
163 },
164 ListDeclarationsStrings {
165 filter: LeanWorkerDeclarationFilter,
166 progress: bool,
167 },
168 DescribeBulk {
169 names: Vec<String>,
170 progress: bool,
171 },
172 ProcessFile {
173 source: String,
174 options: LeanWorkerElabOptions,
175 },
176 ProcessModule {
177 source: String,
178 options: LeanWorkerElabOptions,
179 },
180 EmitTestRows {
183 streams: Vec<String>,
184 },
185 EmitTestRowsThenExit,
186 EmitTestRowsThenPanic,
187 Terminate,
188}
189
190#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
192#[serde(tag = "status", rename_all = "snake_case")]
193#[non_exhaustive]
194pub enum Response {
195 HealthOk,
196 CapabilityLoaded,
197 U64 {
198 value: u64,
199 },
200 HostSessionOpened,
201 Elaboration {
202 outcome: LeanWorkerElabResult,
203 },
204 KernelCheck {
205 outcome: LeanWorkerKernelResult,
206 },
207 Strings {
208 values: Vec<String>,
209 },
210 StreamComplete {
211 summary: StreamSummary,
212 },
213 StreamExportFailed {
214 status_byte: u8,
215 },
216 StreamCallbackFailed {
217 status_byte: u8,
218 description: String,
219 },
220 StreamRowMalformed {
221 message: String,
222 },
223 CapabilityMetadata {
224 metadata: LeanWorkerCapabilityMetadata,
225 },
226 CapabilityDoctor {
227 report: LeanWorkerDoctorReport,
228 },
229 CapabilityMetadataMalformed {
230 message: String,
231 },
232 CapabilityDoctorMalformed {
233 message: String,
234 },
235 JsonCommand {
236 response_json: String,
237 },
238 MetaExpr {
239 result: LeanWorkerMetaResult<LeanWorkerRendered>,
240 },
241 MetaBool {
242 result: LeanWorkerMetaResult<bool>,
243 },
244 Declaration {
245 row: Option<LeanWorkerDeclarationRow>,
246 },
247 DeclarationBulk {
248 rows: Vec<LeanWorkerDeclarationRow>,
249 },
250 ProcessFile {
251 outcome: LeanWorkerProcessFileOutcome,
252 },
253 ProcessModule {
254 outcome: LeanWorkerProcessModuleOutcome,
255 },
256 RowsComplete {
257 count: u64,
258 },
259 Terminating,
260 Error {
261 code: String,
262 message: String,
263 },
264}
265
266#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
268#[non_exhaustive]
269pub struct Diagnostic {
270 pub code: String,
272 pub message: String,
274}
275
276impl Diagnostic {
277 #[must_use]
279 pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
280 Self {
281 code: code.into(),
282 message: message.into(),
283 }
284 }
285}
286
287#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
290#[non_exhaustive]
291pub struct ProgressTick {
292 pub phase: String,
294 pub current: u64,
296 pub total: Option<u64>,
298}
299
300impl ProgressTick {
301 #[must_use]
303 pub fn new(phase: impl Into<String>, current: u64, total: Option<u64>) -> Self {
304 Self {
305 phase: phase.into(),
306 current,
307 total,
308 }
309 }
310}
311
312#[derive(Clone, Debug, Deserialize, Serialize)]
319pub struct DataRow {
320 pub stream: String,
322 pub sequence: u64,
324 pub payload: Box<RawValue>,
326}
327
328impl PartialEq for DataRow {
329 fn eq(&self, other: &Self) -> bool {
330 self.stream == other.stream && self.sequence == other.sequence && self.payload.get() == other.payload.get()
331 }
332}
333
334impl Eq for DataRow {}
335
336#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
338#[non_exhaustive]
339pub struct StreamSummary {
340 pub total_rows: u64,
342 pub per_stream_counts: BTreeMap<String, u64>,
344 pub elapsed_micros: u64,
346 pub metadata: Option<Value>,
348}
349
350impl StreamSummary {
351 #[must_use]
354 pub fn new(
355 total_rows: u64,
356 per_stream_counts: BTreeMap<String, u64>,
357 elapsed: Duration,
358 metadata: Option<Value>,
359 ) -> Self {
360 Self {
361 total_rows,
362 per_stream_counts,
363 elapsed_micros: elapsed.as_micros().try_into().unwrap_or(u64::MAX),
364 metadata,
365 }
366 }
367}
368
369#[derive(Debug, Default)]
372#[non_exhaustive]
373pub struct DataRowEmitter {
374 sequences: BTreeMap<String, u64>,
375 count: u64,
376}
377
378impl DataRowEmitter {
379 pub fn next(&mut self, stream: impl Into<String>, payload: Box<RawValue>) -> DataRow {
382 let stream = stream.into();
383 let sequence = self.sequences.entry(stream.clone()).or_insert(0);
384 let row = DataRow {
385 stream,
386 sequence: *sequence,
387 payload,
388 };
389 *sequence = sequence.saturating_add(1);
390 self.count = self.count.saturating_add(1);
391 row
392 }
393
394 #[cfg(test)]
395 fn emit(
396 &mut self,
397 writer: &mut impl Write,
398 stream: impl Into<String>,
399 payload: &Value,
400 ) -> Result<(), ProtocolError> {
401 let row = self.next(stream, serde_json::value::to_raw_value(payload)?);
402 write_frame(writer, Message::DataRow(row))
403 }
404
405 #[must_use]
407 pub fn count(&self) -> u64 {
408 self.count
409 }
410
411 #[must_use]
413 pub fn per_stream_counts(&self) -> BTreeMap<String, u64> {
414 self.sequences.clone()
415 }
416}
417
418#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
421#[non_exhaustive]
422pub struct FatalExit {
423 pub status: String,
425 pub stderr: String,
427}
428
429impl FatalExit {
430 #[must_use]
432 pub fn new(status: impl Into<String>, stderr: impl Into<String>) -> Self {
433 Self {
434 status: status.into(),
435 stderr: stderr.into(),
436 }
437 }
438}
439
440#[derive(Debug)]
442#[non_exhaustive]
443pub enum ProtocolError {
444 Io(io::Error),
446 Json(serde_json::Error),
448 FrameTooLarge {
450 len: u32,
452 max: u32,
454 },
455 VersionMismatch {
457 expected: u16,
459 actual: u16,
461 },
462}
463
464impl ProtocolError {
465 #[must_use]
469 pub fn is_eof(&self) -> bool {
470 matches!(self, Self::Io(err) if err.kind() == io::ErrorKind::UnexpectedEof)
471 }
472}
473
474impl fmt::Display for ProtocolError {
475 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476 match self {
477 Self::Io(err) => write!(f, "worker protocol I/O failed: {err}"),
478 Self::Json(err) => write!(f, "worker protocol JSON decode failed: {err}"),
479 Self::FrameTooLarge { len, max } => {
480 write!(f, "worker protocol frame too large: {len} bytes exceeds {max}")
481 }
482 Self::VersionMismatch { expected, actual } => {
483 write!(
484 f,
485 "worker protocol version mismatch: expected {expected}, received {actual}"
486 )
487 }
488 }
489 }
490}
491
492impl std::error::Error for ProtocolError {}
493
494impl From<io::Error> for ProtocolError {
495 fn from(value: io::Error) -> Self {
496 Self::Io(value)
497 }
498}
499
500impl From<serde_json::Error> for ProtocolError {
501 fn from(value: serde_json::Error) -> Self {
502 Self::Json(value)
503 }
504}
505
506pub fn write_frame(writer: &mut impl Write, message: Message) -> Result<(), ProtocolError> {
514 let bytes = serde_json::to_vec(&Frame::new(message))?;
515 let len = u32::try_from(bytes.len()).map_err(|_| ProtocolError::FrameTooLarge {
516 len: u32::MAX,
517 max: MAX_FRAME_BYTES,
518 })?;
519 if len > MAX_FRAME_BYTES {
520 return Err(ProtocolError::FrameTooLarge {
521 len,
522 max: MAX_FRAME_BYTES,
523 });
524 }
525 writer.write_all(&len.to_be_bytes())?;
526 writer.write_all(&bytes)?;
527 writer.flush()?;
528 Ok(())
529}
530
531pub fn read_frame(reader: &mut impl Read) -> Result<Frame, ProtocolError> {
541 let mut len_bytes = [0_u8; 4];
542 reader.read_exact(&mut len_bytes)?;
543 let len = u32::from_be_bytes(len_bytes);
544 if len > MAX_FRAME_BYTES {
545 return Err(ProtocolError::FrameTooLarge {
546 len,
547 max: MAX_FRAME_BYTES,
548 });
549 }
550 let mut bytes = vec![0_u8; len as usize];
551 reader.read_exact(&mut bytes)?;
552 let frame: Frame = serde_json::from_slice(&bytes)?;
553 if frame.version != PROTOCOL_VERSION {
554 return Err(ProtocolError::VersionMismatch {
555 expected: PROTOCOL_VERSION,
556 actual: frame.version,
557 });
558 }
559 Ok(frame)
560}
561
#[cfg(test)]
mod tests {
    // Tests fail by panicking, so `expect` and `panic!` are allowed here (presumably
    // these restriction lints are denied elsewhere in the crate).
    #![allow(clippy::expect_used, clippy::panic)]

    use std::collections::BTreeMap;
    use std::io::Cursor;

    use serde_json::json;
    use serde_json::value::RawValue;

    use super::{
        DataRow, DataRowEmitter, Frame, MAX_FRAME_BYTES, Message, PROTOCOL_VERSION, ProtocolError, Response,
        read_frame, write_frame,
    };

    fn raw_json(value: &serde_json::Value) -> Box<RawValue> {
        serde_json::value::to_raw_value(value).expect("test JSON converts to raw value")
    }

    /// Prepends the 4-byte big-endian length prefix that `read_frame` expects.
    fn length_prefixed(body: &[u8]) -> Vec<u8> {
        let len = u32::try_from(body.len()).expect("test frame fits in a u32 length prefix");
        let mut bytes = len.to_be_bytes().to_vec();
        bytes.extend_from_slice(body);
        bytes
    }

    #[test]
    fn data_row_round_trips_through_length_delimited_frame() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 7,
            payload: raw_json(&json!({ "name": "Nat.add", "score": 3 })),
        };
        let mut bytes = Vec::new();
        write_frame(&mut bytes, Message::DataRow(row.clone())).expect("data row writes");
        let frame = read_frame(&mut Cursor::new(bytes)).expect("data row reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    #[test]
    fn data_row_emitter_assigns_per_stream_sequences() {
        let mut emitter = DataRowEmitter::default();
        let mut bytes = Vec::new();
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 0 }))
            .expect("first row writes");
        emitter
            .emit(&mut bytes, "warnings", &json!({ "i": 1 }))
            .expect("second row writes");
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 2 }))
            .expect("third row writes");
        assert_eq!(emitter.count(), 3);
        assert_eq!(
            emitter.per_stream_counts(),
            BTreeMap::from([("rows".to_owned(), 2), ("warnings".to_owned(), 1)]),
        );

        let mut cursor = Cursor::new(bytes);
        let rows = [
            read_frame(&mut cursor).expect("first row reads"),
            read_frame(&mut cursor).expect("second row reads"),
            read_frame(&mut cursor).expect("third row reads"),
        ];
        assert_eq!(
            rows.map(|frame| frame.message),
            [
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 0 })),
                }),
                Message::DataRow(DataRow {
                    stream: "warnings".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 1 })),
                }),
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 1,
                    payload: raw_json(&json!({ "i": 2 })),
                }),
            ],
        );
    }

    #[test]
    fn oversized_data_row_is_rejected_before_write() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(MAX_FRAME_BYTES as usize) })),
        };
        let mut bytes = Vec::new();
        let err = write_frame(&mut bytes, Message::DataRow(row)).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert!(len > max);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    #[test]
    fn oversized_data_row_is_rejected_before_read_allocation() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&(MAX_FRAME_BYTES.saturating_add(1)).to_be_bytes());
        let err = read_frame(&mut Cursor::new(bytes)).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert_eq!(len, MAX_FRAME_BYTES + 1);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    #[test]
    fn malformed_frame_payload_is_protocol_error() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&1_u32.to_be_bytes());
        bytes.push(b'{');
        let err = read_frame(&mut Cursor::new(bytes)).expect_err("malformed JSON is rejected");
        match err {
            ProtocolError::Json(_) => {}
            other @ (ProtocolError::Io(_)
            | ProtocolError::FrameTooLarge { .. }
            | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected Json error, got {other:?}");
            }
        }
    }

    #[test]
    fn mismatched_protocol_version_is_rejected() {
        let foreign_version = PROTOCOL_VERSION.wrapping_add(1);
        let mut frame = Frame::new(Message::Response(Response::HealthOk));
        frame.version = foreign_version;
        let body = serde_json::to_vec(&frame).expect("frame serializes");
        let err = read_frame(&mut Cursor::new(length_prefixed(&body))).expect_err("foreign version is rejected");
        match err {
            ProtocolError::VersionMismatch { expected, actual } => {
                assert_eq!(expected, PROTOCOL_VERSION);
                assert_eq!(actual, foreign_version);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::FrameTooLarge { .. }) => {
                panic!("expected VersionMismatch, got {other:?}");
            }
        }
    }

    #[test]
    fn truncated_length_prefix_is_eof() {
        let err = read_frame(&mut Cursor::new(vec![0_u8, 0])).expect_err("short prefix is rejected");
        assert!(err.is_eof(), "expected EOF, got {err:?}");
    }

    #[test]
    fn truncated_frame_body_is_eof() {
        let mut bytes = 10_u32.to_be_bytes().to_vec();
        bytes.extend_from_slice(b"{}");
        let err = read_frame(&mut Cursor::new(bytes)).expect_err("short body is rejected");
        assert!(err.is_eof(), "expected EOF, got {err:?}");
    }

    #[test]
    fn rows_complete_response_round_trips() {
        let mut bytes = Vec::new();
        write_frame(&mut bytes, Message::Response(Response::RowsComplete { count: 2 })).expect("rows complete writes");
        let frame = read_frame(&mut Cursor::new(bytes)).expect("rows complete reads");
        assert_eq!(frame.message, Message::Response(Response::RowsComplete { count: 2 }));
    }
}