use std::collections::HashMap;
use lean_rs_worker_parent::{LeanWorkerJsonCommand, LeanWorkerSession};
use lean_semantic_search_contract::{
CAPABILITY_SCHEMA_VERSION, CommandResponse, DECLARATION_FEATURE_COMMAND_VERSION, DeclarationFeatureRow, Diagnostic,
DiagnosticSeverity, ModuleSpec, PROOF_GOAL_FEATURE_COMMAND_VERSION, ProofGoalFeatureRequest, ProofGoalFeatureRow,
SEMANTIC_FEATURE_VERSION,
};
use lean_semantic_search_retrieval::{Anchor, SemanticIndex, retrieve_across};
use serde::Serialize;
use crate::error::{Result, ServerError};
use crate::projections::SourceRange;
/// Input for [`run_semantic_proof_search`]: the goal to match and the modules to draw candidates from.
#[derive(Debug, Clone)]
pub(crate) struct SemanticProofSearchRequest {
/// Payload sent unchanged to the proof-goal feature command.
pub(crate) goal: ProofGoalFeatureRequest,
/// Module names; each one becomes a `ModuleSpec` in the declaration feature request.
pub(crate) candidate_modules: Vec<String>,
/// Limit handed to the retrieval step when ranking candidates.
pub(crate) limit: usize,
}
/// One ranked declaration produced by the semantic proof search.
#[derive(Debug, Clone)]
pub(crate) struct SemanticProofCandidate {
/// Declaration id of the retrieved candidate.
pub(crate) name: String,
/// Source range taken from the declaration row with the same id, if that row carried one.
pub(crate) source: Option<SourceRange>,
/// Rank-derived score: `150 - 4 * rank`, computed with saturating arithmetic.
pub(crate) score: i32,
/// One `semantic:<family label>:<match count>` entry per retrieval explanation.
pub(crate) evidence: Vec<String>,
}
/// Outcome of a semantic proof search, plus bookkeeping about the feature rows it consumed.
#[derive(Debug, Clone)]
pub(crate) struct SemanticProofSearchResult {
/// Candidates in the order the retrieval step returned them.
pub(crate) candidates: Vec<SemanticProofCandidate>,
/// Non-pass diagnostics rendered as `severity:code:message`, gathered from the goal response,
/// the declaration response and the retrieval step.
pub(crate) diagnostics: Vec<String>,
/// Number of rows in the declaration feature response.
pub(crate) declaration_rows: usize,
/// Number of rows in the proof-goal feature response (only the first row is used for ranking).
pub(crate) goal_rows: usize,
}
/// Response envelope whose rows are declaration feature rows.
type DeclarationResponse = CommandResponse<DeclarationFeatureRow>;
/// Response envelope whose rows are proof-goal feature rows.
type ProofGoalResponse = CommandResponse<ProofGoalFeatureRow>;
/// JSON request body sent with the declaration feature command.
#[derive(Debug, Serialize)]
struct DeclarationFeatureCommandRequest {
/// Module specs sent with the request.
modules: Vec<ModuleSpec>,
/// Declaration ids sent with the request; the field is omitted from the serialized form
/// when the list is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
declaration_ids: Vec<String>,
}
/// Runs a semantic proof search through a Lean worker session.
///
/// Two JSON commands are issued in order: the proof-goal feature export for
/// `request.goal`, then the declaration feature export for every module named in
/// `request.candidate_modules`. Each response envelope is validated with
/// [`validate_response`] before continuing, and the two validated responses are
/// ranked by [`rank_semantic_rows`] using `request.limit`.
///
/// # Errors
///
/// Returns the mapped worker error when either command fails, and the error from
/// [`validate_response`] when a response carries an unexpected schema version,
/// command name, command version or feature version.
pub(crate) fn run_semantic_proof_search(
    session: &mut LeanWorkerSession<'_>,
    request: &SemanticProofSearchRequest,
) -> Result<SemanticProofSearchResult> {
    // Step 1: extract features for the goal itself.
    let goal_features_command = LeanWorkerJsonCommand::<ProofGoalFeatureRequest, ProofGoalResponse>::new(
        lean_semantic_search_capability::PROOF_GOAL_FEATURES_EXPORT,
    );
    let goal_features = session
        .run_json_command(&goal_features_command, &request.goal, None, None)
        .map_err(crate::error::map_worker_err)?;
    validate_response(
        &goal_features,
        lean_semantic_search_capability::PROOF_GOAL_FEATURES_COMMAND,
        PROOF_GOAL_FEATURE_COMMAND_VERSION,
    )?;

    // Step 2: extract features for every declaration in the candidate modules.
    let mut modules = Vec::with_capacity(request.candidate_modules.len());
    for module_name in &request.candidate_modules {
        modules.push(ModuleSpec {
            module: module_name.clone(),
            origin: Some("lean-host-mcp".to_owned()),
            source_root: None,
        });
    }
    let payload = DeclarationFeatureCommandRequest {
        modules,
        // An empty id list is skipped during serialization.
        declaration_ids: Vec::new(),
    };
    let declaration_features_command =
        LeanWorkerJsonCommand::<DeclarationFeatureCommandRequest, DeclarationResponse>::new(
            lean_semantic_search_capability::DECLARATION_FEATURES_EXPORT,
        );
    let declaration_features = session
        .run_json_command(&declaration_features_command, &payload, None, None)
        .map_err(crate::error::map_worker_err)?;
    validate_response(
        &declaration_features,
        lean_semantic_search_capability::DECLARATION_FEATURES_COMMAND,
        DECLARATION_FEATURE_COMMAND_VERSION,
    )?;

    // Step 3: rank the declarations against the goal.
    let ranked = rank_semantic_rows(&goal_features, &declaration_features, request.limit);
    Ok(ranked)
}
/// Checks that a command response envelope is the one the caller expects.
///
/// The fields are compared in a fixed order — schema version, command name,
/// command version, feature version — and only the first mismatch is reported.
///
/// # Errors
///
/// Returns [`ServerError::Internal`] with a message naming the mismatching field,
/// the value received and the value expected.
fn validate_response<Row>(response: &CommandResponse<Row>, command: &str, command_version: &str) -> Result<()> {
    let mismatch = if response.schema_version != CAPABILITY_SCHEMA_VERSION {
        Some(format!(
            "semantic search returned schema version {}, expected {}",
            response.schema_version, CAPABILITY_SCHEMA_VERSION
        ))
    } else if response.command != command {
        Some(format!(
            "semantic search returned command {}, expected {}",
            response.command, command
        ))
    } else if response.command_version != command_version {
        Some(format!(
            "semantic search returned command version {}, expected {}",
            response.command_version, command_version
        ))
    } else if response.feature_version != SEMANTIC_FEATURE_VERSION {
        Some(format!(
            "semantic search returned feature version {}, expected {}",
            response.feature_version, SEMANTIC_FEATURE_VERSION
        ))
    } else {
        None
    };

    match mismatch {
        Some(message) => Err(ServerError::Internal(message)),
        None => Ok(()),
    }
}
fn rank_semantic_rows(
goal_response: &ProofGoalResponse,
declaration_response: &DeclarationResponse,
limit: usize,
) -> SemanticProofSearchResult {
let mut diagnostics = Vec::new();
diagnostics.extend(diagnostic_strings(&goal_response.diagnostics));
diagnostics.extend(diagnostic_strings(&declaration_response.diagnostics));
let Some(goal) = goal_response.rows.first() else {
return SemanticProofSearchResult {
candidates: Vec::new(),
diagnostics,
declaration_rows: declaration_response.rows.len(),
goal_rows: 0,
};
};
if declaration_response.rows.is_empty() {
return SemanticProofSearchResult {
candidates: Vec::new(),
diagnostics,
declaration_rows: 0,
goal_rows: goal_response.rows.len(),
};
}
let source_by_id = declaration_response
.rows
.iter()
.map(|row| (row.declaration_id.clone(), row.source.map(source_range)))
.collect::<HashMap<_, _>>();
let index = SemanticIndex::from_declarations(&declaration_response.rows);
let anchor = Anchor::from_proof_goal(goal);
let retrieval = retrieve_across(&[&index], &anchor, limit);
diagnostics.extend(diagnostic_strings(&retrieval.diagnostics));
let candidates = retrieval
.candidates
.into_iter()
.map(|candidate| {
let rank = i32::try_from(candidate.rank).unwrap_or(i32::MAX);
let evidence = candidate
.explanations
.into_iter()
.map(|explanation| format!("semantic:{}:{}", explanation.family.label(), explanation.match_count))
.collect::<Vec<_>>();
SemanticProofCandidate {
source: source_by_id.get(&candidate.declaration_id).cloned().flatten(),
name: candidate.declaration_id,
score: 150_i32.saturating_sub(rank.saturating_mul(4)),
evidence,
}
})
.collect();
SemanticProofSearchResult {
candidates,
diagnostics,
declaration_rows: declaration_response.rows.len(),
goal_rows: goal_response.rows.len(),
}
}
/// Renders every non-pass diagnostic as `severity:code:message`, preserving input order.
///
/// Diagnostics whose severity is `DiagnosticSeverity::Pass` are dropped.
fn diagnostic_strings(diagnostics: &[Diagnostic]) -> Vec<String> {
    let mut rendered = Vec::new();
    for entry in diagnostics {
        if matches!(entry.severity, DiagnosticSeverity::Pass) {
            continue;
        }
        let label = severity_label(entry.severity);
        rendered.push(format!("{}:{}:{}", label, entry.code, entry.message));
    }
    rendered
}
/// Maps a diagnostic severity to the lowercase label used in rendered diagnostics.
///
/// The match is kept exhaustive (no wildcard arm) so that a new severity variant
/// forces this mapping to be revisited at compile time.
const fn severity_label(severity: DiagnosticSeverity) -> &'static str {
    match severity {
        DiagnosticSeverity::Error => "error",
        DiagnosticSeverity::Warning => "warning",
        DiagnosticSeverity::Pass => "pass",
    }
}
/// Converts a contract `SourceSpan` into the server's [`SourceRange`] projection.
///
/// The span carries only positions, so `file` is always the empty string here.
fn source_range(span: lean_semantic_search_contract::SourceSpan) -> SourceRange {
    let lean_semantic_search_contract::SourceSpan { start, end } = span;
    SourceRange {
        file: String::new(),
        start_line: start.line,
        start_column: start.column,
        end_line: end.line,
        end_column: end.column,
    }
}
#[cfg(test)]
mod tests {
    use lean_semantic_search_contract::{
        CAPABILITY_SCHEMA_VERSION, CommandResponse, DECLARATION_FEATURE_COMMAND_VERSION, DeclarationFeatureRow,
        Diagnostic, Fingerprints, OpaqueFeatureKey, PROOF_GOAL_FEATURE_COMMAND_VERSION, ProofGoalFeatureRow,
        RoleFeature, SEMANTIC_FEATURE_VERSION, SourcePosition, SourceSpan,
    };

    use super::{rank_semantic_rows, validate_response};

    /// Builds a fingerprint set whose four keys are all derived from `seed`.
    fn fingerprints(seed: &str) -> Fingerprints {
        Fingerprints {
            statement: OpaqueFeatureKey::new(format!("stmt:{seed}")),
            safe_binder_permutation: OpaqueFeatureKey::new(format!("safe:{seed}")),
            connective_shape: OpaqueFeatureKey::new(format!("conn:{seed}")),
            conclusion_shape: OpaqueFeatureKey::new(format!("concl:{seed}")),
        }
    }

    /// Builds the `conclusion_const` role feature shared by the goal and the declaration.
    fn conclusion_feature(key: OpaqueFeatureKey) -> RoleFeature {
        RoleFeature {
            role: "conclusion_const".to_owned(),
            key,
            display: Some("Target.const".to_owned()),
        }
    }

    /// Wraps `rows` in a response envelope that carries the current schema and
    /// feature versions and no diagnostics.
    fn envelope<Row>(command: &str, command_version: &str, rows: Vec<Row>) -> CommandResponse<Row> {
        CommandResponse {
            schema_version: CAPABILITY_SCHEMA_VERSION.to_owned(),
            command: command.to_owned(),
            command_version: command_version.to_owned(),
            feature_version: SEMANTIC_FEATURE_VERSION.to_owned(),
            rows,
            diagnostics: Vec::<Diagnostic>::new(),
        }
    }

    /// A response with a stale command version must be rejected with a message
    /// that names the command version.
    #[test]
    fn validation_rejects_wrong_command_version() -> Result<(), String> {
        let stale = envelope::<DeclarationFeatureRow>("declaration_features", "old", Vec::new());
        let outcome = validate_response(&stale, "declaration_features", DECLARATION_FEATURE_COMMAND_VERSION);
        let Err(err) = outcome else {
            return Err("wrong command version must fail".to_owned());
        };
        assert!(err.to_string().contains("command version"));
        Ok(())
    }

    /// Candidate evidence exposes the role family label but never the opaque key text.
    #[test]
    fn retrieval_mapping_hides_raw_keys() -> Result<(), String> {
        let key = OpaqueFeatureKey::new("opaque-key-that-must-not-leak");

        let goal = ProofGoalFeatureRow {
            goal_id: "g".to_owned(),
            feature_version: SEMANTIC_FEATURE_VERSION.to_owned(),
            fingerprints: fingerprints("goal"),
            role_features: vec![conclusion_feature(key.clone())],
            low_signal_markers: Vec::new(),
        };
        let declaration = DeclarationFeatureRow {
            declaration_id: "Fixture.target_helper".to_owned(),
            feature_version: SEMANTIC_FEATURE_VERSION.to_owned(),
            fingerprints: fingerprints("row"),
            role_features: vec![conclusion_feature(key)],
            binder_count: 0,
            low_signal_markers: Vec::new(),
            source: Some(SourceSpan {
                start: SourcePosition { line: 2, column: 3 },
                end: SourcePosition { line: 2, column: 30 },
            }),
        };

        let goal_response = envelope("proof_goal_features", PROOF_GOAL_FEATURE_COMMAND_VERSION, vec![goal]);
        let declaration_response = envelope(
            "declaration_features",
            DECLARATION_FEATURE_COMMAND_VERSION,
            vec![declaration],
        );

        let result = rank_semantic_rows(&goal_response, &declaration_response, 5);
        assert_eq!(result.candidates.len(), 1);
        let Some(candidate) = result.candidates.first() else {
            return Err("expected a semantic candidate".to_owned());
        };

        let evidence = candidate.evidence.join(",");
        assert_eq!(candidate.name, "Fixture.target_helper");
        assert!(evidence.contains("semantic:role_conclusion_const"));
        assert!(!evidence.contains("opaque-key-that-must-not-leak"));
        assert!(candidate.source.is_some());
        Ok(())
    }
}