1use std::sync::Arc;
13
14use rmcp::handler::server::router::tool::ToolRouter;
15use rmcp::handler::server::wrapper::Parameters;
16use rmcp::model::{CallToolResult, Content, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo};
17use rmcp::{ErrorData as McpError, ServerHandler, tool, tool_handler, tool_router};
18
19use crate::broker::ProjectBroker;
20use crate::envelope::Response;
21use crate::error::ServerError;
22use crate::tools::{self, ResponseCarrier, ToolConfig, ToolContext};
23
/// MCP server handler exposing the Lean proof-agent tools defined in the
/// `#[tool_router]` impl of this type.
///
/// Every tool method forwards `&self.ctx` to its implementation in `crate::tools`;
/// both fields are populated by [`LeanHostService::new`].
#[derive(Debug, Clone)]
pub struct LeanHostService {
    /// Shared tool state: the project broker and tool configuration
    /// (assembled as `ToolContext { broker, config }` in `new`).
    ctx: ToolContext,
    // NOTE(review): nothing in this file reads `tool_router` directly; presumably the
    // `#[tool_handler]` expansion does. The allow appears to silence a dead-code
    // warning in that setup — confirm whether it is still required.
    #[allow(dead_code)]
    /// Dispatch table returned by the `#[tool_router]`-generated `Self::tool_router()`.
    tool_router: ToolRouter<Self>,
}
36
37impl LeanHostService {
38 pub fn new(broker: Arc<ProjectBroker>, config: ToolConfig) -> Self {
39 let ctx = ToolContext { broker, config };
40 Self {
41 ctx,
42 tool_router: Self::tool_router(),
43 }
44 }
45}
46
47#[tool_router]
48impl LeanHostService {
49 #[tool(description = "Inspect one Lean declaration by name.")]
50 async fn inspect_declaration(
51 &self,
52 Parameters(req): Parameters<tools::declaration::InspectDeclarationRequest>,
53 ) -> std::result::Result<CallToolResult, McpError> {
54 tracing::debug!(tool = "inspect_declaration", "tool call");
55 self.respond(tools::declaration::inspect_declaration(&self.ctx, req).await)
56 }
57
58 #[tool(description = "Return ranked declarations for the next proof step.")]
59 async fn search_for_proof(
60 &self,
61 Parameters(req): Parameters<tools::proof_search::SearchForProofRequest>,
62 ) -> std::result::Result<CallToolResult, McpError> {
63 tracing::debug!(tool = "search_for_proof", "tool call");
64 self.respond(tools::proof_search::search_for_proof(&self.ctx, req).await)
65 }
66
67 #[tool(description = "Try proof snippets in memory. Never writes files.")]
68 async fn try_proof_step(
69 &self,
70 Parameters(req): Parameters<tools::proof_action::TryProofStepRequest>,
71 ) -> std::result::Result<CallToolResult, McpError> {
72 tracing::debug!(tool = "try_proof_step", "tool call");
73 self.respond(tools::proof_action::try_proof_step(&self.ctx, req).await)
74 }
75
76 #[tool(description = "Verify one declaration in memory. Never writes files.")]
77 async fn verify_declaration(
78 &self,
79 Parameters(req): Parameters<tools::proof_action::VerifyDeclarationRequest>,
80 ) -> std::result::Result<CallToolResult, McpError> {
81 tracing::debug!(tool = "verify_declaration", "tool call");
82 self.respond(tools::proof_action::verify_declaration(&self.ctx, req).await)
83 }
84
85 #[tool(description = "Proof context for a declaration proof position.")]
86 async fn proof_state(
87 &self,
88 Parameters(req): Parameters<tools::position::ProofStateRequest>,
89 ) -> std::result::Result<CallToolResult, McpError> {
90 tracing::debug!(tool = "proof_state", "tool call");
91 self.respond(tools::position::proof_state(&self.ctx, req).await)
92 }
93
94 #[tool(description = "Find references to a fully-qualified Lean name.")]
95 async fn find_references(
96 &self,
97 Parameters(req): Parameters<tools::position::FindReferencesRequest>,
98 ) -> std::result::Result<CallToolResult, McpError> {
99 tracing::debug!(tool = "find_references", "tool call");
100 self.respond(tools::position::find_references(&self.ctx, req).await)
101 }
102}
103
104#[tool_handler]
105impl ServerHandler for LeanHostService {
106 fn get_info(&self) -> ServerInfo {
107 let mut info = ServerInfo::default();
111 info.protocol_version = ProtocolVersion::LATEST;
112 info.capabilities = ServerCapabilities::builder().enable_tools().build();
113 info.server_info = Implementation::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))
114 .with_website_url("https://github.com/jcreinhold/lean-host-mcp");
115 info.instructions = Some(
116 "MCP server hosting Lean 4 in-process via lean-rs. \
117 Tools expose a bounded proof-agent workflow: proof context, \
118 proof retrieval, declaration inspection, non-mutating proof \
119 attempts and verification, and semantic reference lookup."
120 .to_owned(),
121 );
122 info
123 }
124}
125
126impl LeanHostService {
127 fn respond<T>(&self, result: crate::error::Result<Response<T>>) -> std::result::Result<CallToolResult, McpError>
133 where
134 T: serde::Serialize + schemars::JsonSchema,
135 {
136 let response = match result {
137 Ok(response) => response,
138 Err(ServerError::WorkerUnavailable(info)) => {
139 Response::runtime_unavailable(info.failure(), info.freshness(), info.runtime.clone())
140 }
141 Err(err) => return Err(McpError::from(err)),
142 };
143 Ok(self.finalize(response))
144 }
145
146 fn finalize<T>(&self, mut response: Response<T>) -> CallToolResult
151 where
152 T: serde::Serialize + schemars::JsonSchema,
153 {
154 response.drain_advisories();
155 if !self.ctx.config.verbosity.is_full() {
156 response.drop_telemetry();
157 }
158 carry(&response, self.ctx.config.carrier)
159 }
160}
161
162fn carry<T>(response: &Response<T>, carrier: ResponseCarrier) -> CallToolResult
167where
168 T: serde::Serialize + schemars::JsonSchema,
169{
170 let value = match serde_json::to_value(response) {
171 Ok(value) => value,
172 Err(err) => {
173 return CallToolResult::error(vec![Content::text(format!("failed to serialize response: {err}"))]);
174 }
175 };
176 match carrier {
177 ResponseCarrier::Text => CallToolResult::success(vec![Content::text(value.to_string())]),
178 ResponseCarrier::Both => CallToolResult::structured(value),
179 ResponseCarrier::Structured => {
180 let mut result = CallToolResult::structured(value);
181 result.content.clear();
182 result
183 }
184 }
185}
186
#[cfg(test)]
// In tests a panic *is* the failure signal, so `unwrap`/`unreachable!` are acceptable here.
#[allow(clippy::unwrap_used, clippy::unreachable)]
mod tests {
    use super::*;
    use crate::envelope::{RuntimeFacts, RuntimeRestartEvent};
    use crate::error::{ServerError, WorkerUnavailable};

    /// Builds the envelope that `respond` produces for a `ServerError::WorkerUnavailable`,
    /// after advisories have been drained but before verbosity filtering or carrier
    /// encoding.
    fn worker_unavailable_response() -> Response<serde_json::Value> {
        let runtime = RuntimeFacts {
            worker_generation: 7,
            retry_count: 1,
            call_restart: Some(RuntimeRestartEvent {
                cause: "child_exit".to_owned(),
                reason: "worker_death".to_owned(),
                worker_generation: 7,
                planned: false,
                rss_kib: Some(42),
                limit_kib: Some(100),
            }),
            worker_lanes: 1,
            ..RuntimeFacts::default()
        };
        let error = ServerError::worker_unavailable(WorkerUnavailable {
            retryable: true,
            worker_restarted: true,
            project_root: "/tmp/project".to_owned(),
            project_hash: "hash".to_owned(),
            imports: vec!["Init".to_owned()],
            session_id: "session".to_owned(),
            lean_toolchain: "leanprover/lean4:v4.30.0".to_owned(),
            worker_generation: 7,
            reason: "worker_death".to_owned(),
            restart_cause: Some("child_exit".to_owned()),
            rss_kib: Some(42),
            limit_kib: Some(100),
            retry_after_millis: None,
            restarts_in_window: Some(1),
            window_millis: Some(60_000),
            runtime,
            toolchain_advisories: vec!["worker for v4.30.0 has no runtime smoke record".to_owned()],
        });
        let ServerError::WorkerUnavailable(info) = error else {
            unreachable!("constructed a WorkerUnavailable error")
        };
        let mut response =
            Response::<serde_json::Value>::runtime_unavailable(info.failure(), info.freshness(), info.runtime.clone());
        response.drain_advisories();
        response
    }

    #[test]
    fn worker_unavailable_is_a_structured_tool_response() {
        let response = worker_unavailable_response();

        let json = serde_json::to_value(&response).unwrap();
        assert_eq!(
            json.pointer("/status").and_then(serde_json::Value::as_str),
            Some("runtime_unavailable")
        );
        assert!(json.pointer("/result").is_none_or(serde_json::Value::is_null));
        assert_eq!(
            json.pointer("/runtime_error/retryable")
                .and_then(serde_json::Value::as_bool),
            Some(true)
        );
        assert_eq!(
            json.pointer("/runtime_error/restart_cause")
                .and_then(serde_json::Value::as_str),
            Some("child_exit")
        );
        assert_eq!(
            json.pointer("/telemetry/runtime/retry_count")
                .and_then(serde_json::Value::as_u64),
            Some(1)
        );
        // Drained toolchain advisories surface as envelope warnings.
        assert_eq!(
            json.pointer("/warnings/0").and_then(serde_json::Value::as_str),
            Some("worker for v4.30.0 has no runtime smoke record")
        );
    }

    /// A worker outage must reach the client as a *successful* call result with each
    /// carrier, not as a tool error. Otherwise clients never see the structured
    /// `runtime_unavailable` envelope.
    #[test]
    fn worker_unavailable_is_carried_as_a_successful_result() {
        let response = worker_unavailable_response();
        let json = serde_json::to_value(&response).unwrap();

        let structured = carry(&response, ResponseCarrier::Structured);
        assert_ne!(structured.is_error, Some(true));
        assert_eq!(structured.structured_content.as_ref(), Some(&json));
        assert!(structured.content.is_empty());

        let both = carry(&response, ResponseCarrier::Both);
        assert_ne!(both.is_error, Some(true));
        assert_eq!(both.structured_content.as_ref(), Some(&json));
        assert!(!both.content.is_empty());

        let text = carry(&response, ResponseCarrier::Text);
        assert_ne!(text.is_error, Some(true));
        assert!(text.structured_content.is_none());
        assert_eq!(text.content.len(), 1);
    }
}