1use std::collections::BTreeMap;
17use std::fmt;
18use std::io::{self, Read, Write};
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use serde_json::value::RawValue;
24
25use crate::types::{
26 LeanWorkerCapabilityMetadata, LeanWorkerDeclarationFilter, LeanWorkerDeclarationRow, LeanWorkerDoctorReport,
27 LeanWorkerElabOptions, LeanWorkerElabResult, LeanWorkerKernelResult, LeanWorkerMetaResult,
28 LeanWorkerMetaTransparency, LeanWorkerModuleQuery, LeanWorkerModuleQueryOutcome, LeanWorkerRendered,
29};
30
/// Wire protocol version.
///
/// [`Frame::new`] stamps this value on every outgoing frame and [`read_frame`]
/// rejects any incoming frame whose `version` differs from it.
pub const PROTOCOL_VERSION: u16 = 6;
34
/// Frame size cap of 1 MiB, in bytes.
///
/// NOTE(review): presumably the default cap used when no other limit has been
/// negotiated via [`Message::ConfigureFrameLimit`] — confirm against callers.
pub const MAX_FRAME_BYTES: u32 = 1024 * 1024;

/// 64 KiB, in bytes. Never larger than [`MAX_FRAME_BYTES`].
///
/// NOTE(review): presumably the smallest frame cap a peer may configure —
/// confirm against the code that validates `ConfigureFrameLimit`.
pub const MIN_FRAME_BYTES: u32 = 64 * 1024;

/// 256 MiB, in bytes. Never smaller than [`MAX_FRAME_BYTES`].
///
/// NOTE(review): presumably the largest frame cap a peer may configure —
/// confirm against the code that validates `ConfigureFrameLimit`.
pub const MAX_FRAME_BYTES_HARD_CAP: u32 = 256 * 1024 * 1024;
53
/// One unit of the worker protocol: a versioned envelope around a [`Message`].
///
/// [`write_frame`] serializes a frame as JSON behind a 4-byte big-endian length
/// prefix; [`read_frame`] performs the inverse.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Frame {
    /// Protocol version of the sender; compared against [`PROTOCOL_VERSION`]
    /// by [`read_frame`].
    pub version: u16,
    /// The payload carried by this frame.
    pub message: Message,
}
63
64impl Frame {
65 #[must_use]
67 pub fn new(message: Message) -> Self {
68 Self {
69 version: PROTOCOL_VERSION,
70 message,
71 }
72 }
73}
74
/// Payload carried by a [`Frame`].
///
/// Serialized adjacently tagged: the snake_case variant name is stored under
/// `"type"` and the variant's content under `"body"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "type", content = "body", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Message {
    /// Announces a worker build together with the protocol version it speaks.
    // NOTE(review): presumably the first message a worker sends — confirm.
    Handshake {
        /// Version string of the worker.
        worker_version: String,
        /// Protocol version reported by the worker.
        protocol_version: u16,
    },
    /// Carries a frame size limit in bytes.
    // NOTE(review): presumably bounded by `MIN_FRAME_BYTES` and
    // `MAX_FRAME_BYTES_HARD_CAP`; the validation is not in this file.
    ConfigureFrameLimit {
        /// Requested cap on the encoded frame size, in bytes.
        max_frame_bytes: u32,
    },
    /// An operation request; see [`Request`].
    Request(Request),
    /// The reply to a request; see [`Response`].
    Response(Response),
    /// A coded diagnostic message; see [`Diagnostic`].
    Diagnostic(Diagnostic),
    /// A progress report; see [`ProgressTick`].
    ProgressTick(ProgressTick),
    /// One row of streamed data; see [`DataRow`].
    DataRow(DataRow),
    /// A report of a fatal exit; see [`FatalExit`].
    FatalExit(FatalExit),
}
112
/// Operations that can be sent inside a [`Message::Request`].
///
/// Serialized internally tagged: the snake_case variant name is stored under
/// `"op"` next to the variant's own fields.
///
/// NOTE(review): the handlers for these operations live outside this file, so
/// the per-variant notes below are inferred from names and field types only —
/// confirm against the worker implementation.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "op", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Request {
    // Presumably a liveness probe.
    Health,
    // Fixture-oriented operations; each names a fixture root as a string.
    LoadFixtureCapability {
        fixture_root: String,
    },
    CallFixtureMul {
        fixture_root: String,
        lhs: u64,
        rhs: u64,
    },
    TriggerLeanPanic {
        fixture_root: String,
    },
    // Opens a host session for a project with the listed imports.
    OpenHostSession {
        project_root: String,
        mode: HostSessionMode,
        imports: Vec<String>,
    },
    // Source-level operations taking elaboration options.
    Elaborate {
        source: String,
        options: LeanWorkerElabOptions,
    },
    KernelCheck {
        source: String,
        options: LeanWorkerElabOptions,
        // Presumably opts in to `ProgressTick` messages; same for the other
        // `progress` flags below.
        progress: bool,
    },
    DeclarationKinds {
        names: Vec<String>,
        progress: bool,
    },
    DeclarationNames {
        names: Vec<String>,
        progress: bool,
    },
    // Export-oriented operations: `export` names the target and
    // `request_json` carries its argument as a string (presumably JSON text).
    RunDataStream {
        export: String,
        request_json: String,
        progress: bool,
    },
    CapabilityMetadata {
        export: String,
        request_json: String,
    },
    CapabilityDoctor {
        export: String,
        request_json: String,
    },
    JsonCommand {
        export: String,
        request_json: String,
    },
    // Meta-level queries over source expressions.
    InferType {
        source: String,
        options: LeanWorkerElabOptions,
    },
    Whnf {
        source: String,
        options: LeanWorkerElabOptions,
    },
    IsDefEq {
        lhs: String,
        rhs: String,
        transparency: LeanWorkerMetaTransparency,
        options: LeanWorkerElabOptions,
    },
    // Declaration lookups, singly, filtered, or in bulk.
    Describe {
        name: String,
    },
    ListDeclarationsStrings {
        filter: LeanWorkerDeclarationFilter,
        progress: bool,
    },
    DescribeBulk {
        names: Vec<String>,
        progress: bool,
    },
    // Runs a module-level query against `source`.
    ProcessModuleQuery {
        source: String,
        query: LeanWorkerModuleQuery,
        options: LeanWorkerElabOptions,
    },
    // Presumably test-only operations exercising row streaming and failure
    // paths.
    EmitTestRows {
        streams: Vec<String>,
    },
    EmitTestRowsThenExit,
    EmitTestRowsThenPanic,
    // Presumably asks the worker to shut down.
    Terminate,
}
208
/// Mode selector carried by [`Request::OpenHostSession`].
///
/// Serialized internally tagged: the snake_case variant name is stored under
/// `"kind"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum HostSessionMode {
    /// Session bound to a capability identified by package and library name.
    // NOTE(review): exact meaning of `package` / `lib_name` is defined by the
    // worker, not visible here.
    Capability { package: String, lib_name: String },
    /// Session without a capability.
    // NOTE(review): inferred from the name only — confirm.
    ShimsOnly,
}
219
/// Replies that can be sent inside a [`Message::Response`].
///
/// Serialized internally tagged: the snake_case variant name is stored under
/// `"status"` next to the variant's own fields.
///
/// NOTE(review): which request produces which response is decided outside
/// this file; the groupings below are inferred from names — confirm against
/// the worker implementation.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Response {
    HealthOk,
    CapabilityLoaded,
    // A single unsigned integer result.
    U64 {
        value: u64,
    },
    HostSessionOpened,
    Elaboration {
        outcome: LeanWorkerElabResult,
    },
    KernelCheck {
        outcome: LeanWorkerKernelResult,
    },
    // A list of strings.
    Strings {
        values: Vec<String>,
    },
    // Data-stream outcomes: a completion summary or one of several failures.
    StreamComplete {
        summary: StreamSummary,
    },
    StreamExportFailed {
        status_byte: u8,
    },
    StreamCallbackFailed {
        status_byte: u8,
        description: String,
    },
    StreamRowMalformed {
        message: String,
    },
    // Capability metadata / doctor results and their malformed counterparts.
    CapabilityMetadata {
        metadata: LeanWorkerCapabilityMetadata,
    },
    CapabilityDoctor {
        report: LeanWorkerDoctorReport,
    },
    CapabilityMetadataMalformed {
        message: String,
    },
    CapabilityDoctorMalformed {
        message: String,
    },
    // Result of a JSON command, carried as a string (presumably JSON text).
    JsonCommand {
        response_json: String,
    },
    // Meta-level query results: a rendered expression or a boolean.
    MetaExpr {
        result: LeanWorkerMetaResult<LeanWorkerRendered>,
    },
    MetaBool {
        result: LeanWorkerMetaResult<bool>,
    },
    // Declaration lookups; `row` is `None` when no row is available.
    Declaration {
        row: Option<LeanWorkerDeclarationRow>,
    },
    DeclarationBulk {
        rows: Vec<LeanWorkerDeclarationRow>,
    },
    ProcessModuleQuery {
        outcome: LeanWorkerModuleQueryOutcome,
    },
    // Carries a row count.
    RowsComplete {
        count: u64,
    },
    Terminating,
    // Generic failure with a code and a message.
    Error {
        code: String,
        message: String,
    },
}
292
/// A coded diagnostic carried by [`Message::Diagnostic`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Diagnostic {
    /// Identifier for the kind of diagnostic.
    pub code: String,
    /// Text accompanying the code.
    pub message: String,
}
302
303impl Diagnostic {
304 #[must_use]
306 pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
307 Self {
308 code: code.into(),
309 message: message.into(),
310 }
311 }
312}
313
/// A progress report carried by [`Message::ProgressTick`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct ProgressTick {
    /// Name of the phase being reported.
    pub phase: String,
    /// Current position within the phase.
    pub current: u64,
    /// Total for the phase, when one is available.
    pub total: Option<u64>,
}
326
327impl ProgressTick {
328 #[must_use]
330 pub fn new(phase: impl Into<String>, current: u64, total: Option<u64>) -> Self {
331 Self {
332 phase: phase.into(),
333 current,
334 total,
335 }
336 }
337}
338
/// One row of streamed data carried by [`Message::DataRow`].
///
/// Equality is implemented by hand rather than derived; it compares the raw
/// text of `payload`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DataRow {
    /// Name of the stream this row belongs to.
    pub stream: String,
    /// Position of the row within its stream ([`DataRowEmitter`] starts each
    /// stream at 0).
    pub sequence: u64,
    /// Row content, kept as raw JSON.
    pub payload: Box<RawValue>,
}
354
355impl PartialEq for DataRow {
356 fn eq(&self, other: &Self) -> bool {
357 self.stream == other.stream && self.sequence == other.sequence && self.payload.get() == other.payload.get()
358 }
359}
360
361impl Eq for DataRow {}
362
/// Summary carried by [`Response::StreamComplete`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct StreamSummary {
    /// Total number of rows.
    pub total_rows: u64,
    /// Row count keyed by stream name.
    pub per_stream_counts: BTreeMap<String, u64>,
    /// Elapsed time in microseconds; [`StreamSummary::new`] clamps it to
    /// `u64::MAX`.
    pub elapsed_micros: u64,
    /// Optional JSON value attached to the summary.
    pub metadata: Option<Value>,
}
376
377impl StreamSummary {
378 #[must_use]
381 pub fn new(
382 total_rows: u64,
383 per_stream_counts: BTreeMap<String, u64>,
384 elapsed: Duration,
385 metadata: Option<Value>,
386 ) -> Self {
387 Self {
388 total_rows,
389 per_stream_counts,
390 elapsed_micros: elapsed.as_micros().try_into().unwrap_or(u64::MAX),
391 metadata,
392 }
393 }
394}
395
/// Hands out [`DataRow`]s with per-stream sequence numbers and keeps running
/// counts of the rows it has produced.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct DataRowEmitter {
    // Next sequence number for each stream. Sequences start at 0 and advance
    // by one per row, so this doubles as the number of rows produced per
    // stream.
    sequences: BTreeMap<String, u64>,
    // Total number of rows produced across all streams.
    count: u64,
}
404
405impl DataRowEmitter {
406 pub fn next(&mut self, stream: impl Into<String>, payload: Box<RawValue>) -> DataRow {
409 let stream = stream.into();
410 let sequence = self.sequences.entry(stream.clone()).or_insert(0);
411 let row = DataRow {
412 stream,
413 sequence: *sequence,
414 payload,
415 };
416 *sequence = sequence.saturating_add(1);
417 self.count = self.count.saturating_add(1);
418 row
419 }
420
421 #[cfg(test)]
422 fn emit(
423 &mut self,
424 writer: &mut impl Write,
425 stream: impl Into<String>,
426 payload: &Value,
427 ) -> Result<(), ProtocolError> {
428 let row = self.next(stream, serde_json::value::to_raw_value(payload)?);
429 write_frame(writer, Message::DataRow(row), MAX_FRAME_BYTES)
430 }
431
432 #[must_use]
434 pub fn count(&self) -> u64 {
435 self.count
436 }
437
438 #[must_use]
440 pub fn per_stream_counts(&self) -> BTreeMap<String, u64> {
441 self.sequences.clone()
442 }
443}
444
/// Report of a fatal exit, carried by [`Message::FatalExit`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct FatalExit {
    /// Exit status, as text.
    pub status: String,
    /// Standard-error text associated with the exit.
    // NOTE(review): presumably the captured stderr of the worker — confirm.
    pub stderr: String,
}
455
456impl FatalExit {
457 #[must_use]
459 pub fn new(status: impl Into<String>, stderr: impl Into<String>) -> Self {
460 Self {
461 status: status.into(),
462 stderr: stderr.into(),
463 }
464 }
465}
466
/// Errors produced while reading or writing protocol frames.
#[derive(Debug)]
#[non_exhaustive]
pub enum ProtocolError {
    /// The underlying reader or writer failed.
    Io(io::Error),
    /// `serde_json` failed while serializing or deserializing a frame.
    Json(serde_json::Error),
    /// A frame exceeded the size cap in force.
    FrameTooLarge {
        /// Encoded frame size in bytes; `u32::MAX` when [`write_frame`] finds
        /// the real size does not fit in a `u32`.
        len: u32,
        /// The cap that was exceeded, in bytes.
        max: u32,
    },
    /// A decoded frame carried a version other than [`PROTOCOL_VERSION`].
    VersionMismatch {
        /// The version this build speaks.
        expected: u16,
        /// The version found in the frame.
        actual: u16,
    },
}
490
491impl ProtocolError {
492 #[must_use]
496 pub fn is_eof(&self) -> bool {
497 matches!(self, Self::Io(err) if err.kind() == io::ErrorKind::UnexpectedEof)
498 }
499}
500
501impl fmt::Display for ProtocolError {
502 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
503 match self {
504 Self::Io(err) => write!(f, "worker protocol I/O failed: {err}"),
505 Self::Json(err) => write!(f, "worker protocol JSON decode failed: {err}"),
506 Self::FrameTooLarge { len, max } => {
507 write!(f, "worker protocol frame too large: {len} bytes exceeds {max}")
508 }
509 Self::VersionMismatch { expected, actual } => {
510 write!(
511 f,
512 "worker protocol version mismatch: expected {expected}, received {actual}"
513 )
514 }
515 }
516 }
517}
518
519impl std::error::Error for ProtocolError {}
520
521impl From<io::Error> for ProtocolError {
522 fn from(value: io::Error) -> Self {
523 Self::Io(value)
524 }
525}
526
527impl From<serde_json::Error> for ProtocolError {
528 fn from(value: serde_json::Error) -> Self {
529 Self::Json(value)
530 }
531}
532
533pub fn write_frame(writer: &mut impl Write, message: Message, max_frame_bytes: u32) -> Result<(), ProtocolError> {
546 let bytes = serde_json::to_vec(&Frame::new(message))?;
547 let len = u32::try_from(bytes.len()).map_err(|_| ProtocolError::FrameTooLarge {
548 len: u32::MAX,
549 max: max_frame_bytes,
550 })?;
551 if len > max_frame_bytes {
552 return Err(ProtocolError::FrameTooLarge {
553 len,
554 max: max_frame_bytes,
555 });
556 }
557 writer.write_all(&len.to_be_bytes())?;
558 writer.write_all(&bytes)?;
559 writer.flush()?;
560 Ok(())
561}
562
563pub fn read_frame(reader: &mut impl Read, max_frame_bytes: u32) -> Result<Frame, ProtocolError> {
576 let mut len_bytes = [0_u8; 4];
577 reader.read_exact(&mut len_bytes)?;
578 let len = u32::from_be_bytes(len_bytes);
579 if len > max_frame_bytes {
580 return Err(ProtocolError::FrameTooLarge {
581 len,
582 max: max_frame_bytes,
583 });
584 }
585 let mut bytes = vec![0_u8; len as usize];
586 reader.read_exact(&mut bytes)?;
587 let frame: Frame = serde_json::from_slice(&bytes)?;
588 if frame.version != PROTOCOL_VERSION {
589 return Err(ProtocolError::VersionMismatch {
590 expected: PROTOCOL_VERSION,
591 actual: frame.version,
592 });
593 }
594 Ok(frame)
595}
596
// Unit tests for frame encoding/decoding, size caps, and serde wire shapes.
#[cfg(test)]
mod tests {
    // Tests use `expect` and `panic!` to fail loudly.
    #![allow(clippy::expect_used, clippy::panic)]

    use std::io::Cursor;

    use serde_json::json;
    use serde_json::value::RawValue;

    use super::{
        DataRow, DataRowEmitter, MAX_FRAME_BYTES, MAX_FRAME_BYTES_HARD_CAP, MIN_FRAME_BYTES, Message, ProtocolError,
        Request, Response, read_frame, write_frame,
    };
    use crate::types::{
        LeanWorkerElabFailure, LeanWorkerElabOptions, LeanWorkerModuleQuery, LeanWorkerModuleQueryOutcome,
        LeanWorkerModuleQueryResult, LeanWorkerModuleSourceSpan, LeanWorkerRenderedInfo, LeanWorkerTypeAtResult,
    };

    // Converts a `serde_json::Value` into the boxed raw form used by
    // `DataRow::payload`.
    fn raw_json(value: &serde_json::Value) -> Box<RawValue> {
        serde_json::value::to_raw_value(value).expect("test JSON converts to raw value")
    }

    // A data row written with `write_frame` comes back unchanged from
    // `read_frame`.
    #[test]
    fn data_row_round_trips_through_length_delimited_frame() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 7,
            payload: raw_json(&json!({ "name": "Nat.add", "score": 3 })),
        };
        let mut bytes = Vec::new();
        write_frame(&mut bytes, Message::DataRow(row.clone()), MAX_FRAME_BYTES).expect("data row writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("data row reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    // Sequence numbers are tracked independently per stream: interleaving
    // "rows" and "warnings" yields rows#0, warnings#0, rows#1.
    #[test]
    fn data_row_emitter_assigns_per_stream_sequences() {
        let mut emitter = DataRowEmitter::default();
        let mut bytes = Vec::new();
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 0 }))
            .expect("first row writes");
        emitter
            .emit(&mut bytes, "warnings", &json!({ "i": 1 }))
            .expect("second row writes");
        emitter
            .emit(&mut bytes, "rows", &json!({ "i": 2 }))
            .expect("third row writes");
        assert_eq!(emitter.count(), 3);

        // Read the three frames back in order from the shared buffer.
        let mut cursor = Cursor::new(bytes);
        let rows = [
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("first row reads"),
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("second row reads"),
            read_frame(&mut cursor, MAX_FRAME_BYTES).expect("third row reads"),
        ];
        assert_eq!(
            rows.map(|frame| frame.message),
            [
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 0 })),
                }),
                Message::DataRow(DataRow {
                    stream: "warnings".to_owned(),
                    sequence: 0,
                    payload: raw_json(&json!({ "i": 1 })),
                }),
                Message::DataRow(DataRow {
                    stream: "rows".to_owned(),
                    sequence: 1,
                    payload: raw_json(&json!({ "i": 2 })),
                }),
            ],
        );
    }

    // A payload whose blob alone is already `MAX_FRAME_BYTES` long pushes the
    // encoded frame over the cap, so `write_frame` must refuse it.
    #[test]
    fn oversized_data_row_is_rejected_before_write() {
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(MAX_FRAME_BYTES as usize) })),
        };
        let mut bytes = Vec::new();
        let err =
            write_frame(&mut bytes, Message::DataRow(row), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert!(len > max);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    // Only the 4-byte length prefix is supplied. Getting `FrameTooLarge`
    // instead of an EOF I/O error shows the size check runs before any payload
    // is read.
    #[test]
    fn oversized_data_row_is_rejected_before_read_allocation() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&(MAX_FRAME_BYTES.saturating_add(1)).to_be_bytes());
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
        match err {
            ProtocolError::FrameTooLarge { len, max } => {
                assert_eq!(len, MAX_FRAME_BYTES + 1);
                assert_eq!(max, MAX_FRAME_BYTES);
            }
            other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected FrameTooLarge, got {other:?}");
            }
        }
    }

    // With the cap raised to eight times `MAX_FRAME_BYTES`, a frame carrying
    // a blob of twice `MAX_FRAME_BYTES` both writes and reads.
    #[test]
    fn larger_cap_accepts_frame_rejected_under_default() {
        let raised = MAX_FRAME_BYTES.saturating_mul(8);
        // The blob alone exceeds `MAX_FRAME_BYTES` but fits well under
        // `raised`.
        let row = DataRow {
            stream: "rows".to_owned(),
            sequence: 0,
            payload: raw_json(&json!({ "blob": "x".repeat(2 * MAX_FRAME_BYTES as usize) })),
        };
        let mut buf = Vec::new();
        write_frame(&mut buf, Message::DataRow(row.clone()), raised).expect("oversize-under-default frame writes");
        let frame = read_frame(&mut Cursor::new(buf), raised).expect("oversize-under-default frame reads");
        assert_eq!(frame.message, Message::DataRow(row));
    }

    // Compile-time check that MIN <= MAX <= HARD_CAP.
    #[test]
    fn frame_cap_bounds_constants_are_consistent() {
        const { assert!(MIN_FRAME_BYTES <= MAX_FRAME_BYTES) };
        const { assert!(MAX_FRAME_BYTES <= MAX_FRAME_BYTES_HARD_CAP) };
    }

    // A one-byte payload containing just `{` is invalid JSON and must surface
    // as `ProtocolError::Json`.
    #[test]
    fn malformed_frame_payload_is_protocol_error() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&1_u32.to_be_bytes());
        bytes.push(b'{');
        let err = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect_err("malformed JSON is rejected");
        match err {
            ProtocolError::Json(_) => {}
            other @ (ProtocolError::Io(_)
            | ProtocolError::FrameTooLarge { .. }
            | ProtocolError::VersionMismatch { .. }) => {
                panic!("expected Json error, got {other:?}");
            }
        }
    }

    // `Response::RowsComplete` survives a write/read round trip.
    #[test]
    fn rows_complete_response_round_trips() {
        let mut bytes = Vec::new();
        write_frame(
            &mut bytes,
            Message::Response(Response::RowsComplete { count: 2 }),
            MAX_FRAME_BYTES,
        )
        .expect("rows complete writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("rows complete reads");
        assert_eq!(frame.message, Message::Response(Response::RowsComplete { count: 2 }));
    }

    // Round-trips a module query request and response, then pins the JSON
    // shape of two module-query types from `crate::types`.
    #[test]
    fn module_query_request_and_response_round_trip() {
        let request = Message::Request(Request::ProcessModuleQuery {
            source: "def x := 1\n#check x\n".to_owned(),
            query: LeanWorkerModuleQuery::TypeAt { line: 2, column: 8 },
            options: LeanWorkerElabOptions::default(),
        });
        let mut bytes = Vec::new();
        write_frame(&mut bytes, request.clone(), MAX_FRAME_BYTES).expect("module query request writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("module query request reads");
        assert_eq!(frame.message, request);

        let response = Message::Response(Response::ProcessModuleQuery {
            outcome: LeanWorkerModuleQueryOutcome::Ok {
                imports: Vec::new(),
                result: LeanWorkerModuleQueryResult::TypeAt(LeanWorkerTypeAtResult::Term {
                    span: LeanWorkerModuleSourceSpan {
                        start_line: 2,
                        start_column: 8,
                        end_line: 2,
                        end_column: 9,
                    },
                    expr: LeanWorkerRenderedInfo {
                        value: "x".to_owned(),
                        truncated: false,
                    },
                    type_str: LeanWorkerRenderedInfo {
                        value: "Nat".to_owned(),
                        truncated: false,
                    },
                    expected_type: None,
                }),
            },
        });
        let mut bytes = Vec::new();
        write_frame(&mut bytes, response.clone(), MAX_FRAME_BYTES).expect("module query response writes");
        let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("module query response reads");
        assert_eq!(frame.message, response);

        // The unit-like outcome serializes to just its status tag.
        let unsupported = LeanWorkerModuleQueryOutcome::Unsupported;
        let json = serde_json::to_value(&unsupported).expect("unsupported serializes");
        assert_eq!(json, json!({ "status": "unsupported" }));

        // The diagnostics result serializes with a "result" tag and a "body"
        // object.
        let diagnostics = LeanWorkerModuleQueryResult::Diagnostics(LeanWorkerElabFailure {
            diagnostics: Vec::new(),
            truncated: false,
        });
        let json = serde_json::to_value(&diagnostics).expect("diagnostics serializes");
        assert_eq!(
            json,
            json!({
                "result": "diagnostics",
                "body": {
                    "diagnostics": [],
                    "truncated": false
                }
            })
        );
    }
}