use std::collections::BTreeMap;
use std::fmt;
use std::io::{self, Read, Write};
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::value::RawValue;
use crate::types::{
LeanWorkerCapabilityMetadata, LeanWorkerDeclarationFilter, LeanWorkerDeclarationInspectionRequest,
LeanWorkerDeclarationInspectionResult, LeanWorkerDeclarationRow, LeanWorkerDeclarationSearch,
LeanWorkerDeclarationSearchResult, LeanWorkerDeclarationType, LeanWorkerDeclarationVerificationRequest,
LeanWorkerDeclarationVerificationResult, LeanWorkerDoctorReport, LeanWorkerElabOptions, LeanWorkerElabResult,
LeanWorkerKernelResult, LeanWorkerMetaResult, LeanWorkerMetaTransparency, LeanWorkerModuleQuery,
LeanWorkerModuleQueryBatchOutcome, LeanWorkerModuleQueryOutcome, LeanWorkerModuleQuerySelector,
LeanWorkerOutputBudgets, LeanWorkerProofAttemptRequest, LeanWorkerProofAttemptResult, LeanWorkerRendered,
};
/// Wire protocol revision stamped into every [`Frame`] by [`Frame::new`] and compared by
/// [`read_frame`] against the `version` of each frame it decodes.
pub const PROTOCOL_VERSION: u16 = 8;
/// Default frame-size cap: 1 MiB of encoded JSON per frame.
pub const MAX_FRAME_BYTES: u32 = 1 << 20;
/// Smallest frame-size cap: 64 KiB.
/// NOTE(review): presumably the floor for limits negotiated via `Message::ConfigureFrameLimit`;
/// nothing in this module enforces it — confirm at the call site.
pub const MIN_FRAME_BYTES: u32 = 64 << 10;
/// Largest frame-size cap: 256 MiB. [`read_frame`] and [`write_frame`] do not consult this
/// constant themselves; they trust the `max_frame_bytes` argument they are given.
pub const MAX_FRAME_BYTES_HARD_CAP: u32 = 256 << 20;
/// One protocol unit: a [`Message`] stamped with the sender's protocol version.
///
/// On the wire, [`write_frame`] encodes this struct as JSON and prefixes it with the encoded
/// length as a 4-byte big-endian `u32`; [`read_frame`] reverses that and rejects frames whose
/// `version` differs from [`PROTOCOL_VERSION`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Frame {
/// Protocol revision of the sender; [`Frame::new`] fills in [`PROTOCOL_VERSION`].
pub version: u16,
/// The payload carried by this frame.
pub message: Message,
}
impl Frame {
#[must_use]
pub fn new(message: Message) -> Self {
Self {
version: PROTOCOL_VERSION,
message,
}
}
}
/// Top-level payload carried by a [`Frame`].
///
/// Serialized adjacently tagged: the variant name in `snake_case` goes under `"type"` and the
/// variant's data under `"body"`, e.g. `{"type": "request", "body": {"op": "health"}}` for
/// `Message::Request(Request::Health)`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "type", content = "body", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Message {
/// Version announcement.
/// NOTE(review): `protocol_version` presumably mirrors the sender's [`PROTOCOL_VERSION`] — confirm.
Handshake {
worker_version: String,
protocol_version: u16,
},
/// Proposes a new frame-size cap.
/// NOTE(review): presumably bounded by [`MIN_FRAME_BYTES`] and [`MAX_FRAME_BYTES_HARD_CAP`] by
/// whoever handles it; nothing in this module enforces those bounds.
ConfigureFrameLimit {
max_frame_bytes: u32,
},
Request(Request),
Response(Response),
Diagnostic(Diagnostic),
ProgressTick(ProgressTick),
DataRow(DataRow),
FatalExit(FatalExit),
}
/// A command addressed to the worker, carried inside [`Message::Request`].
///
/// Serialized internally tagged: the variant name in `snake_case` is stored under `"op"` next
/// to the variant's own fields, e.g. `{"op": "health"}` or
/// `{"op": "call_fixture_mul", "manifest_path": "...", "lhs": 2, "rhs": 3}`.
///
/// NOTE(review): variants carrying `progress: bool` presumably opt in to [`ProgressTick`]
/// messages while the request runs, and the `request_json: String` fields appear to carry
/// JSON as text rather than as a structured value — confirm against the worker implementation.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "op", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Request {
Health,
LoadFixtureCapability {
manifest_path: String,
},
CallFixtureMul {
manifest_path: String,
lhs: u64,
rhs: u64,
},
TriggerLeanPanic {
manifest_path: String,
},
OpenHostSession {
project_root: String,
mode: HostSessionMode,
imports: Vec<String>,
},
Elaborate {
source: String,
options: LeanWorkerElabOptions,
},
KernelCheck {
source: String,
options: LeanWorkerElabOptions,
progress: bool,
},
DeclarationKinds {
names: Vec<String>,
progress: bool,
},
DeclarationNames {
names: Vec<String>,
progress: bool,
},
RunDataStream {
export: String,
request_json: String,
progress: bool,
},
CapabilityMetadata {
export: String,
request_json: String,
},
CapabilityDoctor {
export: String,
request_json: String,
},
JsonCommand {
export: String,
request_json: String,
},
InferType {
source: String,
options: LeanWorkerElabOptions,
},
Whnf {
source: String,
options: LeanWorkerElabOptions,
},
IsDefEq {
lhs: String,
rhs: String,
transparency: LeanWorkerMetaTransparency,
options: LeanWorkerElabOptions,
},
Describe {
name: String,
},
SearchDeclarations {
search: LeanWorkerDeclarationSearch,
},
// `usize` is encoded as a plain JSON number; the decoding side must be able to hold the value.
DeclarationType {
name: String,
max_bytes: usize,
},
InspectDeclaration {
request: LeanWorkerDeclarationInspectionRequest,
},
AttemptProof {
request: LeanWorkerProofAttemptRequest,
options: LeanWorkerElabOptions,
progress: bool,
},
VerifyDeclaration {
request: LeanWorkerDeclarationVerificationRequest,
options: LeanWorkerElabOptions,
progress: bool,
},
ListDeclarationsStrings {
filter: LeanWorkerDeclarationFilter,
progress: bool,
},
DescribeBulk {
names: Vec<String>,
progress: bool,
},
ProcessModuleQuery {
source: String,
query: LeanWorkerModuleQuery,
options: LeanWorkerElabOptions,
},
ProcessModuleQueryBatch {
source: String,
selectors: Vec<LeanWorkerModuleQuerySelector>,
budgets: LeanWorkerOutputBudgets,
options: LeanWorkerElabOptions,
},
ClearModuleSnapshotCache,
// NOTE(review): the `EmitTestRows*` variants look like test hooks for the data-row stream — confirm.
EmitTestRows {
streams: Vec<String>,
},
EmitTestRowsThenExit,
EmitTestRowsThenPanic,
Terminate,
}
/// How [`Request::OpenHostSession`] should set up the host session.
///
/// Serialized internally tagged under `"kind"`, e.g. `{"kind": "shims_only"}` or
/// `{"kind": "capability", "package": "...", "lib_name": "...", "manifest_path": null}`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum HostSessionMode {
Capability {
package: String,
lib_name: String,
manifest_path: Option<String>,
},
ShimsOnly,
}
/// The worker's reply to a [`Request`], carried inside [`Message::Response`].
///
/// Serialized internally tagged: the variant name in `snake_case` is stored under `"status"`
/// next to the variant's fields, e.g. `{"status": "health_ok"}` or
/// `{"status": "rows_complete", "count": 2}`.
///
/// NOTE(review): the `*Malformed` variants and `Error { code, message }` appear to be the
/// failure replies for the corresponding requests — confirm the pairing with the worker.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
#[non_exhaustive]
pub enum Response {
HealthOk,
CapabilityLoaded,
U64 {
value: u64,
},
HostSessionOpened,
Elaboration {
outcome: LeanWorkerElabResult,
},
KernelCheck {
outcome: LeanWorkerKernelResult,
},
Strings {
values: Vec<String>,
},
StreamComplete {
summary: StreamSummary,
},
StreamExportFailed {
status_byte: u8,
},
StreamCallbackFailed {
status_byte: u8,
description: String,
},
StreamRowMalformed {
message: String,
},
CapabilityMetadata {
metadata: LeanWorkerCapabilityMetadata,
},
CapabilityDoctor {
report: LeanWorkerDoctorReport,
},
CapabilityMetadataMalformed {
message: String,
},
CapabilityDoctorMalformed {
message: String,
},
JsonCommand {
response_json: String,
},
MetaExpr {
result: LeanWorkerMetaResult<LeanWorkerRendered>,
},
MetaBool {
result: LeanWorkerMetaResult<bool>,
},
Declaration {
row: Option<LeanWorkerDeclarationRow>,
},
DeclarationSearch {
result: LeanWorkerDeclarationSearchResult,
},
DeclarationType {
row: Option<LeanWorkerDeclarationType>,
},
DeclarationInspection {
result: LeanWorkerDeclarationInspectionResult,
},
ProofAttempt {
result: LeanWorkerProofAttemptResult,
},
DeclarationVerification {
result: LeanWorkerDeclarationVerificationResult,
},
DeclarationBulk {
rows: Vec<LeanWorkerDeclarationRow>,
},
ProcessModuleQuery {
outcome: LeanWorkerModuleQueryOutcome,
},
ProcessModuleQueryBatch {
outcome: LeanWorkerModuleQueryBatchOutcome,
},
ModuleSnapshotCacheCleared {
result: crate::types::LeanWorkerModuleSnapshotCacheClearResult,
},
RowsComplete {
count: u64,
},
Terminating,
Error {
code: String,
message: String,
},
}
/// A coded, human-readable message sent as [`Message::Diagnostic`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Diagnostic {
/// Free-form identifier for the kind of diagnostic.
pub code: String,
/// Human-readable text.
pub message: String,
}
impl Diagnostic {
#[must_use]
pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
Self {
code: code.into(),
message: message.into(),
}
}
}
/// A progress report sent as [`Message::ProgressTick`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct ProgressTick {
/// Label of the phase being reported on.
pub phase: String,
/// Units of work completed so far in this phase.
pub current: u64,
/// Total units of work, when known.
pub total: Option<u64>,
}
impl ProgressTick {
#[must_use]
pub fn new(phase: impl Into<String>, current: u64, total: Option<u64>) -> Self {
Self {
phase: phase.into(),
current,
total,
}
}
}
/// One streamed row, sent as [`Message::DataRow`].
///
/// The payload is kept as raw JSON text (`RawValue`), so it is forwarded without being parsed
/// into a `Value`. `PartialEq`/`Eq` are implemented by hand because of that field.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DataRow {
/// Name of the stream this row belongs to.
pub stream: String,
/// Position of this row within its stream; [`DataRowEmitter`] numbers each stream from 0.
pub sequence: u64,
/// Row contents as raw JSON.
pub payload: Box<RawValue>,
}
/// Rows are equal when stream, sequence, and the raw JSON text of the payload all match.
///
/// The payload is compared by its text, so semantically equal JSON written with different
/// whitespace or key order compares unequal.
impl PartialEq for DataRow {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.stream.as_str(), self.sequence, self.payload.get());
        let rhs = (other.stream.as_str(), other.sequence, other.payload.get());
        lhs == rhs
    }
}
/// String, integer, and text comparison are all reflexive, so the relation above is an equivalence.
impl Eq for DataRow {}
/// Totals reported in [`Response::StreamComplete`] once a data stream finishes.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct StreamSummary {
/// Number of rows emitted across all streams.
pub total_rows: u64,
/// Number of rows emitted per stream name.
pub per_stream_counts: BTreeMap<String, u64>,
/// Elapsed wall time in whole microseconds; [`StreamSummary::new`] saturates at `u64::MAX`.
pub elapsed_micros: u64,
/// Optional extra JSON attached to the summary.
pub metadata: Option<Value>,
}
impl StreamSummary {
#[must_use]
pub fn new(
total_rows: u64,
per_stream_counts: BTreeMap<String, u64>,
elapsed: Duration,
metadata: Option<Value>,
) -> Self {
Self {
total_rows,
per_stream_counts,
elapsed_micros: elapsed.as_micros().try_into().unwrap_or(u64::MAX),
metadata,
}
}
}
/// Hands out [`DataRow`]s with per-stream sequence numbers and keeps running totals.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct DataRowEmitter {
/// Next sequence number for each stream, which is also the number of rows handed out on it so far.
sequences: BTreeMap<String, u64>,
/// Total rows handed out across all streams.
count: u64,
}
impl DataRowEmitter {
    /// Builds the next [`DataRow`] for `stream`, stamped with that stream's next sequence number.
    ///
    /// Sequence numbers start at 0 and advance independently per stream; the total reported by
    /// [`Self::count`] advances too. Both counters saturate at `u64::MAX`, after which the
    /// sequence number stops advancing.
    ///
    /// Calling this consumes a sequence number even if the returned row is never sent, which
    /// would leave a gap on the receiving side. The result is therefore `#[must_use]`.
    #[must_use]
    pub fn next(&mut self, stream: impl Into<String>, payload: Box<RawValue>) -> DataRow {
        let stream = stream.into();
        let sequence = self.sequences.entry(stream.clone()).or_insert(0);
        let row = DataRow {
            stream,
            sequence: *sequence,
            payload,
        };
        *sequence = sequence.saturating_add(1);
        self.count = self.count.saturating_add(1);
        row
    }

    /// Test helper: serializes `payload`, stamps it via [`Self::next`], and writes it as one
    /// frame capped at [`MAX_FRAME_BYTES`].
    #[cfg(test)]
    fn emit(
        &mut self,
        writer: &mut impl Write,
        stream: impl Into<String>,
        payload: &Value,
    ) -> Result<(), ProtocolError> {
        let row = self.next(stream, serde_json::value::to_raw_value(payload)?);
        write_frame(writer, Message::DataRow(row), MAX_FRAME_BYTES)
    }

    /// Total rows handed out across all streams (saturating).
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Rows handed out per stream name; streams that were never used are absent.
    #[must_use]
    pub fn per_stream_counts(&self) -> BTreeMap<String, u64> {
        // The next sequence number equals the number of rows already emitted on that stream.
        self.sequences.clone()
    }
}
/// Final report sent as [`Message::FatalExit`].
///
/// NOTE(review): both fields are free-form strings; presumably `status` describes the exit
/// status and `stderr` carries captured error output — confirm with the producer.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub struct FatalExit {
pub status: String,
pub stderr: String,
}
impl FatalExit {
#[must_use]
pub fn new(status: impl Into<String>, stderr: impl Into<String>) -> Self {
Self {
status: status.into(),
stderr: stderr.into(),
}
}
}
/// Failure while reading or writing a protocol frame.
#[derive(Debug)]
#[non_exhaustive]
pub enum ProtocolError {
/// The underlying reader or writer failed. A stream that ends early surfaces as
/// `UnexpectedEof`; see [`ProtocolError::is_eof`].
Io(io::Error),
/// serde_json failed, either encoding a frame ([`write_frame`]) or decoding one ([`read_frame`]).
Json(serde_json::Error),
/// A frame's length exceeds the cap in effect. [`write_frame`] reports `len: u32::MAX` when the
/// encoded frame does not even fit in a `u32`.
FrameTooLarge {
len: u32,
max: u32,
},
/// A decoded frame's `version` differs from this build's [`PROTOCOL_VERSION`] (`expected`).
VersionMismatch {
expected: u16,
actual: u16,
},
}
impl ProtocolError {
    /// Returns `true` only for an I/O error of kind `UnexpectedEof`.
    ///
    /// This is what `read_exact` (used by [`read_frame`]) reports when the stream ends before
    /// the requested bytes arrive, including a peer closing the stream between frames.
    #[must_use]
    pub fn is_eof(&self) -> bool {
        match self {
            Self::Io(err) => err.kind() == io::ErrorKind::UnexpectedEof,
            Self::Json(_) | Self::FrameTooLarge { .. } | Self::VersionMismatch { .. } => false,
        }
    }
}
/// Renders a one-line, human-readable description that embeds the underlying error, if any.
impl fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Io(ref err) => write!(f, "worker protocol I/O failed: {err}"),
            Self::Json(ref err) => write!(f, "worker protocol JSON decode failed: {err}"),
            Self::FrameTooLarge { len, max } => write!(
                f,
                "worker protocol frame too large: {len} bytes exceeds {max}"
            ),
            Self::VersionMismatch { expected, actual } => write!(
                f,
                "worker protocol version mismatch: expected {expected}, received {actual}"
            ),
        }
    }
}
impl std::error::Error for ProtocolError {}
impl From<io::Error> for ProtocolError {
fn from(value: io::Error) -> Self {
Self::Io(value)
}
}
impl From<serde_json::Error> for ProtocolError {
fn from(value: serde_json::Error) -> Self {
Self::Json(value)
}
}
/// Encodes `message` as a [`Frame`] and writes it as a 4-byte big-endian length prefix followed
/// by the JSON bytes, then flushes `writer`.
///
/// # Errors
///
/// * [`ProtocolError::Json`] if serialization fails.
/// * [`ProtocolError::FrameTooLarge`] if the encoded frame exceeds `max_frame_bytes`. Nothing
///   is written in that case. `len` is `u32::MAX` when the frame does not fit in a `u32`.
/// * [`ProtocolError::Io`] if writing or flushing fails.
pub fn write_frame(writer: &mut impl Write, message: Message, max_frame_bytes: u32) -> Result<(), ProtocolError> {
    let payload = serde_json::to_vec(&Frame::new(message))?;
    let too_large = |len| ProtocolError::FrameTooLarge {
        len,
        max: max_frame_bytes,
    };
    let len = u32::try_from(payload.len()).map_err(|_| too_large(u32::MAX))?;
    if len > max_frame_bytes {
        return Err(too_large(len));
    }
    let header = len.to_be_bytes();
    writer.write_all(&header)?;
    writer.write_all(&payload)?;
    writer.flush()?;
    Ok(())
}
/// Reads one length-prefixed frame: a 4-byte big-endian length followed by that many JSON bytes.
///
/// The length is checked against `max_frame_bytes` before the body buffer is allocated.
///
/// # Errors
///
/// * [`ProtocolError::Io`] if the stream fails or ends early; see [`ProtocolError::is_eof`].
/// * [`ProtocolError::FrameTooLarge`] if the declared length exceeds `max_frame_bytes`.
/// * [`ProtocolError::VersionMismatch`] if the frame's `version` is not [`PROTOCOL_VERSION`].
///   This is also reported when the body cannot be decoded but still carries a readable, different
///   `version`, since a peer on another revision may use message shapes this build does not know.
/// * [`ProtocolError::Json`] for any other undecodable body.
pub fn read_frame(reader: &mut impl Read, max_frame_bytes: u32) -> Result<Frame, ProtocolError> {
    let mut len_bytes = [0_u8; 4];
    reader.read_exact(&mut len_bytes)?;
    let len = u32::from_be_bytes(len_bytes);
    if len > max_frame_bytes {
        return Err(ProtocolError::FrameTooLarge {
            len,
            max: max_frame_bytes,
        });
    }
    // u32 -> usize is lossless on 32- and 64-bit targets.
    let mut bytes = vec![0_u8; len as usize];
    reader.read_exact(&mut bytes)?;
    let frame: Frame = match serde_json::from_slice(&bytes) {
        Ok(frame) => frame,
        // Only the failure path pays for the second parse.
        Err(err) => return Err(peer_version_mismatch(&bytes).unwrap_or(ProtocolError::Json(err))),
    };
    if frame.version != PROTOCOL_VERSION {
        return Err(ProtocolError::VersionMismatch {
            expected: PROTOCOL_VERSION,
            actual: frame.version,
        });
    }
    Ok(frame)
}

/// Returns a [`ProtocolError::VersionMismatch`] if `bytes` is a JSON object whose `version`
/// field is a `u16` other than [`PROTOCOL_VERSION`]. Otherwise returns `None`.
fn peer_version_mismatch(bytes: &[u8]) -> Option<ProtocolError> {
    let value: serde_json::Value = serde_json::from_slice(bytes).ok()?;
    let actual = u16::try_from(value.get("version")?.as_u64()?).ok()?;
    (actual != PROTOCOL_VERSION).then_some(ProtocolError::VersionMismatch {
        expected: PROTOCOL_VERSION,
        actual,
    })
}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used, clippy::panic)]
use std::io::Cursor;
use serde_json::json;
use serde_json::value::RawValue;
use super::{
DataRow, DataRowEmitter, MAX_FRAME_BYTES, MAX_FRAME_BYTES_HARD_CAP, MIN_FRAME_BYTES, Message, ProtocolError,
Request, Response, read_frame, write_frame,
};
use crate::types::{
LeanWorkerDeclarationFilter, LeanWorkerDeclarationFlags, LeanWorkerDeclarationInspection,
LeanWorkerDeclarationInspectionFields, LeanWorkerDeclarationInspectionRequest,
LeanWorkerDeclarationInspectionResult, LeanWorkerDeclarationNameMatch, LeanWorkerDeclarationProofSearchFacts,
LeanWorkerDeclarationSearch, LeanWorkerDeclarationSearchBias, LeanWorkerDeclarationSearchFacts,
LeanWorkerDeclarationSearchPruning, LeanWorkerDeclarationSearchResult, LeanWorkerDeclarationSearchRow,
LeanWorkerDeclarationSearchScope, LeanWorkerDeclarationSearchTimings, LeanWorkerDeclarationTargetInfo,
LeanWorkerDeclarationVerificationFacts, LeanWorkerDeclarationVerificationRequest,
LeanWorkerDeclarationVerificationResult, LeanWorkerDeclarationVerificationStatus,
LeanWorkerDeclarationVerificationTarget, LeanWorkerElabFailure, LeanWorkerElabOptions,
LeanWorkerModuleCacheStatus, LeanWorkerModuleQuery, LeanWorkerModuleQueryBatchEnvelope,
LeanWorkerModuleQueryBatchItem, LeanWorkerModuleQueryBatchOutcome, LeanWorkerModuleQueryBatchResult,
LeanWorkerModuleQueryCacheFacts, LeanWorkerModuleQueryOutcome, LeanWorkerModuleQueryResult,
LeanWorkerModuleQuerySelector, LeanWorkerModuleQueryTimings, LeanWorkerModuleSourceSpan,
LeanWorkerOutputBudgets, LeanWorkerProofAttemptEnvelope, LeanWorkerProofAttemptRequest,
LeanWorkerProofAttemptResult, LeanWorkerProofAttemptRow, LeanWorkerProofAttemptStatus,
LeanWorkerProofCandidate, LeanWorkerProofEditTarget, LeanWorkerProofPositionSelector,
LeanWorkerProofPositionSummary, LeanWorkerProofStateResult, LeanWorkerRenderedInfo, LeanWorkerRendering,
LeanWorkerSorryPolicy, LeanWorkerSourceRange, LeanWorkerTypeAtResult,
};
/// Converts a JSON value into the boxed raw form that `DataRow` payloads carry.
fn raw_json(value: &serde_json::Value) -> Box<RawValue> {
    let encoded = serde_json::value::to_raw_value(value);
    encoded.expect("test JSON converts to raw value")
}
/// Writes `message` as a frame under the default cap, reads it back, and asserts it is unchanged.
fn assert_frame_round_trips(message: &Message) {
    let mut encoded = Vec::new();
    write_frame(&mut encoded, message.clone(), MAX_FRAME_BYTES).expect("frame writes");
    let mut cursor = Cursor::new(encoded);
    let decoded = read_frame(&mut cursor, MAX_FRAME_BYTES).expect("frame reads");
    assert_eq!(&decoded.message, message);
}
/// Builds a theorem target for `declaration_name`, split at its last `.` into namespace and
/// short name, with every span set to line 1, columns 1..10.
fn declaration_target_info_fixture(declaration_name: &str) -> LeanWorkerDeclarationTargetInfo {
    let (namespace_name, short_name) = declaration_name
        .rsplit_once('.')
        .unwrap_or(("", declaration_name));
    let span = LeanWorkerModuleSourceSpan {
        start_line: 1,
        start_column: 1,
        end_line: 1,
        end_column: 10,
    };
    LeanWorkerDeclarationTargetInfo {
        short_name: short_name.to_owned(),
        declaration_name: declaration_name.to_owned(),
        namespace_name: namespace_name.to_owned(),
        declaration_kind: "theorem".to_owned(),
        declaration_span: span.clone(),
        name_span: span.clone(),
        body_span: span,
    }
}
/// Builds verification facts with no target, diagnostics, goals, sorries, or axioms, varying
/// only `candidates` and `axioms_available`.
fn verification_facts_fixture(
    candidates: Vec<LeanWorkerDeclarationTargetInfo>,
    axioms_available: bool,
) -> LeanWorkerDeclarationVerificationFacts {
    let no_diagnostics = LeanWorkerElabFailure {
        diagnostics: Vec::new(),
        truncated: false,
    };
    LeanWorkerDeclarationVerificationFacts {
        candidates,
        axioms_available,
        target: None,
        diagnostics: no_diagnostics,
        unresolved_goals: Vec::new(),
        axioms: Vec::new(),
        contains_sorry: false,
        contains_admit: false,
        contains_sorry_ax: false,
        axioms_truncated: false,
        output_truncated: false,
    }
}
#[test]
fn data_row_round_trips_through_length_delimited_frame() {
    let row = DataRow {
        stream: "rows".to_owned(),
        sequence: 7,
        payload: raw_json(&json!({ "name": "Nat.add", "score": 3 })),
    };
    let mut bytes = Vec::new();
    write_frame(&mut bytes, Message::DataRow(row.clone()), MAX_FRAME_BYTES).expect("data row writes");
    // The wire format is a 4-byte big-endian length prefix followed by exactly that many JSON
    // bytes; check the prefix agrees with the body actually written.
    let (prefix, body) = bytes
        .split_first_chunk::<4>()
        .expect("frame starts with a length prefix");
    let declared = u32::from_be_bytes(*prefix);
    assert_eq!(usize::try_from(declared).expect("frame length fits in usize"), body.len());
    let frame = read_frame(&mut Cursor::new(bytes), MAX_FRAME_BYTES).expect("data row reads");
    assert_eq!(frame.message, Message::DataRow(row));
}
#[test]
fn data_row_emitter_assigns_per_stream_sequences() {
    let mut emitter = DataRowEmitter::default();
    let mut bytes = Vec::new();
    emitter
        .emit(&mut bytes, "rows", &json!({ "i": 0 }))
        .expect("first row writes");
    emitter
        .emit(&mut bytes, "warnings", &json!({ "i": 1 }))
        .expect("second row writes");
    emitter
        .emit(&mut bytes, "rows", &json!({ "i": 2 }))
        .expect("third row writes");
    assert_eq!(emitter.count(), 3);
    // Per-stream totals must agree with the sequence numbers handed out below.
    assert_eq!(
        emitter.per_stream_counts(),
        std::collections::BTreeMap::from([("rows".to_owned(), 2_u64), ("warnings".to_owned(), 1_u64)]),
    );
    let mut cursor = Cursor::new(bytes);
    let rows = [
        read_frame(&mut cursor, MAX_FRAME_BYTES).expect("first row reads"),
        read_frame(&mut cursor, MAX_FRAME_BYTES).expect("second row reads"),
        read_frame(&mut cursor, MAX_FRAME_BYTES).expect("third row reads"),
    ];
    assert_eq!(
        rows.map(|frame| frame.message),
        [
            Message::DataRow(DataRow {
                stream: "rows".to_owned(),
                sequence: 0,
                payload: raw_json(&json!({ "i": 0 })),
            }),
            Message::DataRow(DataRow {
                stream: "warnings".to_owned(),
                sequence: 0,
                payload: raw_json(&json!({ "i": 1 })),
            }),
            Message::DataRow(DataRow {
                stream: "rows".to_owned(),
                sequence: 1,
                payload: raw_json(&json!({ "i": 2 })),
            }),
        ],
    );
}
#[test]
fn oversized_data_row_is_rejected_before_write() {
    let row = DataRow {
        stream: "rows".to_owned(),
        sequence: 0,
        payload: raw_json(&json!({ "blob": "x".repeat(MAX_FRAME_BYTES as usize) })),
    };
    let mut bytes = Vec::new();
    let err =
        write_frame(&mut bytes, Message::DataRow(row), MAX_FRAME_BYTES).expect_err("oversized frame is rejected");
    // "Before write": not even the length prefix may reach the writer.
    assert!(bytes.is_empty(), "rejected frame wrote {} bytes", bytes.len());
    match err {
        ProtocolError::FrameTooLarge { len, max } => {
            assert!(len > max);
            assert_eq!(max, MAX_FRAME_BYTES);
        }
        other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
            panic!("expected FrameTooLarge, got {other:?}");
        }
    }
}
#[test]
fn oversized_data_row_is_rejected_before_read_allocation() {
    // Only a length prefix, one byte over the cap, and no body at all: the reader must bail
    // out on the prefix alone.
    let declared_len = MAX_FRAME_BYTES.saturating_add(1);
    let header_only = declared_len.to_be_bytes().to_vec();
    let outcome = read_frame(&mut Cursor::new(header_only), MAX_FRAME_BYTES);
    match outcome.expect_err("oversized frame is rejected") {
        ProtocolError::FrameTooLarge { len, max } => {
            assert_eq!(len, declared_len);
            assert_eq!(max, MAX_FRAME_BYTES);
        }
        other @ (ProtocolError::Io(_) | ProtocolError::Json(_) | ProtocolError::VersionMismatch { .. }) => {
            panic!("expected FrameTooLarge, got {other:?}");
        }
    }
}
#[test]
fn larger_cap_accepts_frame_rejected_under_default() {
    let raised = MAX_FRAME_BYTES.saturating_mul(8);
    let row = DataRow {
        stream: "rows".to_owned(),
        sequence: 0,
        payload: raw_json(&json!({ "blob": "x".repeat(2 * MAX_FRAME_BYTES as usize) })),
    };
    // Under the default cap the same row must be refused...
    let mut rejected = Vec::<u8>::new();
    let err = write_frame(&mut rejected, Message::DataRow(row.clone()), MAX_FRAME_BYTES)
        .expect_err("frame exceeding the default cap is rejected");
    assert!(
        matches!(err, ProtocolError::FrameTooLarge { max, .. } if max == MAX_FRAME_BYTES),
        "expected FrameTooLarge under the default cap, got {err:?}"
    );
    // ...while the raised cap accepts it in both directions.
    let mut buf = Vec::new();
    write_frame(&mut buf, Message::DataRow(row.clone()), raised).expect("oversize-under-default frame writes");
    let frame = read_frame(&mut Cursor::new(buf), raised).expect("oversize-under-default frame reads");
    assert_eq!(frame.message, Message::DataRow(row));
}
#[test]
fn frame_cap_bounds_constants_are_consistent() {
    // Evaluated at compile time: MIN <= default <= hard cap.
    const { assert!(MIN_FRAME_BYTES <= MAX_FRAME_BYTES && MAX_FRAME_BYTES <= MAX_FRAME_BYTES_HARD_CAP) };
}
#[test]
fn malformed_frame_payload_is_protocol_error() {
    // A well-formed one-byte frame whose body is an unterminated JSON object.
    let mut truncated = 1_u32.to_be_bytes().to_vec();
    truncated.push(b'{');
    let outcome = read_frame(&mut Cursor::new(truncated), MAX_FRAME_BYTES);
    match outcome.expect_err("malformed JSON is rejected") {
        ProtocolError::Json(_) => {}
        other @ (ProtocolError::Io(_) | ProtocolError::FrameTooLarge { .. } | ProtocolError::VersionMismatch { .. }) => {
            panic!("expected Json error, got {other:?}");
        }
    }
}
#[test]
fn rows_complete_response_round_trips() {
    assert_frame_round_trips(&Message::Response(Response::RowsComplete { count: 2 }));
}
#[test]
fn declaration_search_request_and_response_round_trip() {
    let request = Message::Request(Request::SearchDeclarations {
        search: LeanWorkerDeclarationSearch {
            name_fragment: Some("map".to_owned()),
            name_match: LeanWorkerDeclarationNameMatch::Suffix,
            kind: Some("theorem".to_owned()),
            required_constants: vec!["List.map".to_owned()],
            conclusion_head: Some("Eq".to_owned()),
            scope_biases: vec![LeanWorkerDeclarationSearchBias {
                scope: LeanWorkerDeclarationSearchScope::Namespace,
                prefix: "List".to_owned(),
                strict: false,
                weight: 7,
            }],
            limit: 3,
            filter: LeanWorkerDeclarationFilter {
                include_private: false,
                include_generated: false,
                include_internal: false,
            },
            include_source: false,
        },
    });
    assert_frame_round_trips(&request);
    let response = Message::Response(Response::DeclarationSearch {
        result: LeanWorkerDeclarationSearchResult {
            declarations: vec![LeanWorkerDeclarationSearchRow {
                name: "List.map_map".to_owned(),
                kind: "theorem".to_owned(),
                module: Some("Init.Data.List.Lemmas".to_owned()),
                source: None,
                match_reason: "name,kind,required_constants,conclusion_head".to_owned(),
                score: 127,
                rank: 1,
                flags: LeanWorkerDeclarationFlags::default(),
            }],
            truncated: true,
            facts: LeanWorkerDeclarationSearchFacts {
                declarations_scanned: 100,
                after_name_filter: 10,
                after_kind_filter: 8,
                after_required_constants_filter: 4,
                after_conclusion_filter: 2,
                after_scope_filter: 2,
                source_lookups: 0,
                broad_pruning: vec![LeanWorkerDeclarationSearchPruning {
                    stage: "limit".to_owned(),
                    reason: "broad_search_limit".to_owned(),
                    count: 1,
                }],
                truncated: true,
                timings: LeanWorkerDeclarationSearchTimings {
                    scan_micros: 1000,
                    rank_micros: 50,
                    source_micros: 0,
                },
            },
        },
    });
    assert_frame_round_trips(&response);
}
#[test]
fn declaration_inspection_request_and_response_round_trip() {
    let request = Message::Request(Request::InspectDeclaration {
        request: LeanWorkerDeclarationInspectionRequest {
            name: "List.map_map".to_owned(),
            fields: LeanWorkerDeclarationInspectionFields {
                source: true,
                statement: true,
                docstring: true,
                attributes: true,
                flags: true,
                rendering: LeanWorkerRendering::Pretty,
            },
            budgets: LeanWorkerOutputBudgets {
                per_field_bytes: 128,
                total_bytes: 512,
            },
        },
    });
    assert_frame_round_trips(&request);
    let response = Message::Response(Response::DeclarationInspection {
        result: LeanWorkerDeclarationInspectionResult::Found {
            declaration: Box::new(LeanWorkerDeclarationInspection {
                name: "List.map_map".to_owned(),
                kind: "theorem".to_owned(),
                module: Some("Init.Data.List.Lemmas".to_owned()),
                source: Some(LeanWorkerSourceRange {
                    file: "Init/Data/List/Lemmas.lean".to_owned(),
                    start_line: 1,
                    start_column: 1,
                    end_line: 1,
                    end_column: 10,
                }),
                statement: Some(LeanWorkerRenderedInfo {
                    value: "forall ...".to_owned(),
                    truncated: true,
                }),
                docstring: Some(LeanWorkerRenderedInfo {
                    value: "doc".to_owned(),
                    truncated: false,
                }),
                attributes: vec!["simp".to_owned(), "rw".to_owned()],
                proof_search: LeanWorkerDeclarationProofSearchFacts {
                    is_simp: true,
                    is_rw_candidate: true,
                    is_instance: false,
                    is_class: false,
                    class_name: None,
                },
                flags: LeanWorkerDeclarationFlags::default(),
                statement_rendering: Some(LeanWorkerRendering::Pretty),
            }),
        },
    });
    assert_frame_round_trips(&response);
    let not_found = Message::Response(Response::DeclarationInspection {
        result: LeanWorkerDeclarationInspectionResult::NotFound {
            name: "Missing.name".to_owned(),
        },
    });
    assert_frame_round_trips(&not_found);
    let unsupported = Message::Response(Response::DeclarationInspection {
        result: LeanWorkerDeclarationInspectionResult::Unsupported,
    });
    assert_frame_round_trips(&unsupported);
}
#[test]
fn module_query_request_and_response_round_trip() {
    let request = Message::Request(Request::ProcessModuleQuery {
        source: "def x := 1\n#check x\n".to_owned(),
        query: LeanWorkerModuleQuery::TypeAt { line: 2, column: 8 },
        options: LeanWorkerElabOptions::default(),
    });
    assert_frame_round_trips(&request);
    let response = Message::Response(Response::ProcessModuleQuery {
        outcome: LeanWorkerModuleQueryOutcome::Ok {
            imports: Vec::new(),
            result: LeanWorkerModuleQueryResult::TypeAt(LeanWorkerTypeAtResult::Term {
                span: LeanWorkerModuleSourceSpan {
                    start_line: 2,
                    start_column: 8,
                    end_line: 2,
                    end_column: 9,
                },
                expr: LeanWorkerRenderedInfo {
                    value: "x".to_owned(),
                    truncated: false,
                },
                type_str: LeanWorkerRenderedInfo {
                    value: "Nat".to_owned(),
                    truncated: false,
                },
                expected_type: None,
            }),
        },
    });
    assert_frame_round_trips(&response);
    // Pin the wire tags of the outcome/result enums as well as their round trip.
    let unsupported = LeanWorkerModuleQueryOutcome::Unsupported;
    let json = serde_json::to_value(&unsupported).expect("unsupported serializes");
    assert_eq!(json, json!({ "status": "unsupported" }));
    let diagnostics = LeanWorkerModuleQueryResult::Diagnostics(LeanWorkerElabFailure {
        diagnostics: Vec::new(),
        truncated: false,
    });
    let json = serde_json::to_value(&diagnostics).expect("diagnostics serializes");
    assert_eq!(
        json,
        json!({
            "result": "diagnostics",
            "body": {
                "diagnostics": [],
                "truncated": false
            }
        })
    );
}
#[test]
fn module_query_batch_request_and_response_round_trip() {
    let request = Message::Request(Request::ProcessModuleQueryBatch {
        source: "theorem t : True := by\n  trivial\n".to_owned(),
        selectors: vec![
            LeanWorkerModuleQuerySelector::Diagnostics {
                id: "diagnostics".to_owned(),
            },
            LeanWorkerModuleQuerySelector::ProofState {
                id: "state".to_owned(),
                line: 2,
                column: 4,
            },
        ],
        budgets: LeanWorkerOutputBudgets::default(),
        options: LeanWorkerElabOptions::default(),
    });
    assert_frame_round_trips(&request);
    let response = Message::Response(Response::ProcessModuleQueryBatch {
        outcome: LeanWorkerModuleQueryBatchOutcome::Ok {
            imports: Vec::new(),
            result: LeanWorkerModuleQueryBatchEnvelope {
                items: vec![LeanWorkerModuleQueryBatchItem::Ok {
                    id: "diagnostics".to_owned(),
                    result: Box::new(LeanWorkerModuleQueryBatchResult::Diagnostics(LeanWorkerElabFailure {
                        diagnostics: Vec::new(),
                        truncated: false,
                    })),
                }],
                total_truncated: false,
            },
            facts: LeanWorkerModuleQueryCacheFacts {
                cache_status: LeanWorkerModuleCacheStatus::Miss,
                timings: LeanWorkerModuleQueryTimings::zero(),
                output_bytes: 0,
                cache_entry_count: Some(1),
                cache_approx_bytes: Some(1024),
            },
        },
    });
    assert_frame_round_trips(&response);
}
#[test]
fn proof_attempt_request_and_response_round_trip() {
    let span = LeanWorkerModuleSourceSpan {
        start_line: 1,
        start_column: 22,
        end_line: 2,
        end_column: 7,
    };
    let request = Message::Request(Request::AttemptProof {
        request: LeanWorkerProofAttemptRequest {
            source: "theorem t : True := by\n  trivial\n".to_owned(),
            edit: LeanWorkerProofEditTarget::Declaration {
                name: "t".to_owned(),
                position: LeanWorkerProofPositionSelector::Default,
            },
            candidates: vec![LeanWorkerProofCandidate {
                id: "rfl".to_owned(),
                text: "by trivial".to_owned(),
            }],
            budgets: LeanWorkerOutputBudgets::default(),
        },
        options: LeanWorkerElabOptions::default(),
        progress: true,
    });
    assert_frame_round_trips(&request);
    let response = Message::Response(Response::ProofAttempt {
        result: LeanWorkerProofAttemptResult::Ok {
            imports: Vec::new(),
            result: LeanWorkerProofAttemptEnvelope {
                candidates: vec![LeanWorkerProofAttemptRow {
                    id: "rfl".to_owned(),
                    status: LeanWorkerProofAttemptStatus::Closed,
                    candidate_text: LeanWorkerRenderedInfo {
                        value: "rfl".to_owned(),
                        truncated: false,
                    },
                    diagnostics: LeanWorkerElabFailure {
                        diagnostics: Vec::new(),
                        truncated: false,
                    },
                    downstream_diagnostics: LeanWorkerElabFailure {
                        diagnostics: Vec::new(),
                        truncated: false,
                    },
                    goals: Vec::new(),
                    declaration: Some(LeanWorkerDeclarationTargetInfo {
                        short_name: "t".to_owned(),
                        declaration_name: "t".to_owned(),
                        namespace_name: String::new(),
                        declaration_kind: "theorem".to_owned(),
                        declaration_span: span.clone(),
                        name_span: span.clone(),
                        body_span: span,
                    }),
                    proof_position: Some(LeanWorkerProofPositionSummary {
                        index: 0,
                        tactic: LeanWorkerRenderedInfo {
                            value: "trivial".to_owned(),
                            truncated: false,
                        },
                    }),
                    output_truncated: false,
                }],
                candidate_limit: 8,
                candidates_truncated: false,
            },
        },
    });
    assert_frame_round_trips(&response);
}
#[test]
fn proof_position_selector_tags_are_stable_and_round_trip() {
    // Each selector paired with the exact JSON it must produce and accept.
    let expectations = [
        (LeanWorkerProofPositionSelector::Default, json!({"kind": "default"})),
        (
            LeanWorkerProofPositionSelector::Index { index: 3 },
            json!({"kind": "index", "index": 3}),
        ),
        (
            LeanWorkerProofPositionSelector::AfterText {
                text: "intro x".to_owned(),
                occurrence: Some(1),
            },
            json!({"kind": "after_text", "text": "intro x", "occurrence": 1}),
        ),
        (LeanWorkerProofPositionSelector::Entry, json!({"kind": "entry"})),
    ];
    for (selector, wire) in expectations {
        let encoded = serde_json::to_value(&selector).expect("selector serializes");
        assert_eq!(encoded, wire, "selector tag must be stable: {selector:?}");
        let decoded: LeanWorkerProofPositionSelector =
            serde_json::from_value(wire).expect("selector deserializes");
        assert_eq!(decoded, selector, "selector must round-trip through JSON");
    }
}
#[test]
fn declaration_verification_request_and_response_round_trip() {
    let request = Message::Request(Request::VerifyDeclaration {
        request: LeanWorkerDeclarationVerificationRequest {
            source: "theorem t : True := by\n  trivial\n".to_owned(),
            target: LeanWorkerDeclarationVerificationTarget::Name { name: "t".to_owned() },
            sorry_policy: LeanWorkerSorryPolicy::Deny,
            report_axioms: true,
            budgets: LeanWorkerOutputBudgets::default(),
        },
        options: LeanWorkerElabOptions::default(),
        progress: false,
    });
    assert_frame_round_trips(&request);
    // Same as the shared empty fixture, except that two axioms are reported.
    let mut facts = verification_facts_fixture(Vec::new(), true);
    facts.axioms = vec!["propext".to_owned(), "Classical.choice".to_owned()];
    let response = Message::Response(Response::DeclarationVerification {
        result: LeanWorkerDeclarationVerificationResult::Ok {
            verification_status: LeanWorkerDeclarationVerificationStatus::Accepted,
            facts: Box::new(facts),
            imports: Vec::new(),
        },
    });
    assert_frame_round_trips(&response);
}
#[test]
fn verification_needs_build_and_ambiguous_round_trip() {
    let needs_build = LeanWorkerDeclarationVerificationResult::MissingImports {
        verification_status: LeanWorkerDeclarationVerificationStatus::NeedsBuild,
        facts: Box::new(verification_facts_fixture(Vec::new(), false)),
        imports: vec!["Mathlib.Tactic".to_owned()],
        missing: vec!["Mathlib.Unbuilt.Dep".to_owned()],
    };
    let ambiguous_candidates = vec![
        declaration_target_info_fixture("A.foo"),
        declaration_target_info_fixture("B.foo"),
    ];
    let ambiguous = LeanWorkerDeclarationVerificationResult::Ok {
        verification_status: LeanWorkerDeclarationVerificationStatus::Ambiguous,
        facts: Box::new(verification_facts_fixture(ambiguous_candidates, false)),
        imports: Vec::new(),
    };
    for result in [needs_build, ambiguous] {
        assert_frame_round_trips(&Message::Response(Response::DeclarationVerification { result }));
    }
}
#[test]
fn proof_state_ambiguous_and_needs_build_round_trip() {
    // Wraps a single proof-state answer in an otherwise fixed batch response.
    let batch_with_state = |state: LeanWorkerProofStateResult| {
        Message::Response(Response::ProcessModuleQueryBatch {
            outcome: LeanWorkerModuleQueryBatchOutcome::Ok {
                result: LeanWorkerModuleQueryBatchEnvelope {
                    items: vec![LeanWorkerModuleQueryBatchItem::Ok {
                        id: "state".to_owned(),
                        result: Box::new(LeanWorkerModuleQueryBatchResult::ProofState(state)),
                    }],
                    total_truncated: false,
                },
                imports: Vec::new(),
                facts: LeanWorkerModuleQueryCacheFacts::uncached(0),
            },
        })
    };
    assert_frame_round_trips(&batch_with_state(LeanWorkerProofStateResult::Ambiguous {
        candidates: vec![
            declaration_target_info_fixture("A.foo"),
            declaration_target_info_fixture("B.foo"),
        ],
    }));
    assert_frame_round_trips(&batch_with_state(LeanWorkerProofStateResult::NeedsBuild {
        missing: vec!["Mathlib.Unbuilt.Dep".to_owned()],
    }));
}
#[test]
fn inspection_fields_default_rendering_is_pretty() {
    // `rendering` is deliberately omitted; decoding must still succeed and pick Pretty.
    let without_rendering = json!({
        "source": true,
        "statement": true,
        "docstring": false,
        "attributes": false,
        "flags": false,
    });
    let parsed: LeanWorkerDeclarationInspectionFields =
        serde_json::from_value(without_rendering).expect("fields without rendering deserialize");
    assert_eq!(parsed.rendering, LeanWorkerRendering::Pretty);
}
}