Skip to main content

aprender_mcp/tools/
run.rs

//! `apr.run` — M2 tool. Synchronous inference via subprocess wrapper.
//!
//! Wraps `apr run <model> --json [--prompt X] [--max-tokens N] [--temperature T] [--top-p P]`.
//!
//! M3 (FALSIFY-MCP-006) adds cancellation: the call accepts a cancel receiver
//! and forwards it to [`run_apr_cancellable`], which SIGTERMs the spawned
//! subprocess on signal and SIGKILLs after the grace window.
//!
//! M3 (FALSIFY-MCP-PROGRESS-002) adds streaming: when the originating
//! `tools/call` carries `params._meta.progressToken`, [`call_with_sink`]
//! invokes `apr run ... --stream` (NDJSON: one `event=token` line per
//! decoded token, then one `event=final` blob) and forwards each line as a
//! `notifications/progress` message tagged with the caller's token. When the
//! sink is absent (no progressToken), we fall back to the original
//! cancellable sync path so existing clients see no behaviour change.
16
17#![allow(clippy::disallowed_methods)] // serde_json::json! macro expands to .unwrap() internally
18
19use crate::server::NotificationSink;
20use crate::tools::subprocess::{run_apr_cancellable, spawn_streaming, CANCEL_GRACE_MS};
21use crate::types::{InputSchema, JsonRpcNotification, ToolCallResult, ToolDefinition};
22use std::sync::mpsc::Receiver;
23
/// Tool name registered with MCP clients.
///
/// Used as the `name` field of the [`ToolDefinition`] returned by
/// [`run_tool_definition`] and as the `name:` argument of this module's
/// `register_mcp_tool!` invocation.
pub const NAME: &str = "apr.run";
26
27/// Return the MCP tool definition for `apr.run`.
28///
29/// FALSIFY-MCP-008: the `inputSchema` is parsed from the build-time codegen
30/// constant `crate::schemas::APR_RUN_SCHEMA`, which `build.rs` emits from
31/// `contracts/apr-mcp-tool-schemas-v1.yaml`. The contract is the single
32/// source of truth — the live `tools/list` response and the YAML must agree
33/// byte-for-byte after JSON canonicalization (asserted by
34/// `tests/falsify_mcp_008.rs`).
35#[must_use]
36pub fn run_tool_definition() -> ToolDefinition {
37    let input_schema: InputSchema = serde_json::from_str(crate::schemas::APR_RUN_SCHEMA).expect(
38        "FALSIFY-MCP-008: apr.run codegen constant must parse as InputSchema; \
39             regenerate by editing contracts/apr-mcp-tool-schemas-v1.yaml and rebuilding",
40    );
41    ToolDefinition {
42        name: NAME.to_string(),
43        description: crate::schemas::APR_RUN_DESCRIPTION.to_string(),
44        input_schema,
45    }
46}
47
48/// Execute `apr.run` by spawning `apr run <model> --json [...flags]`.
49///
50/// `cancel_rx` is signalled by the MCP dispatcher when a matching
51/// `notifications/cancelled` arrives on the same request id (FALSIFY-MCP-006).
52/// Pass a never-firing channel for tests or direct non-MCP callers.
53///
54/// Back-compat entry point used by callers that don't opt into progress
55/// streaming. Equivalent to `call_with_sink(args, cancel_rx, None, None)` but
56/// preserves the cancellable code path for the no-stream case.
57#[must_use]
58pub fn call(args: &serde_json::Value, cancel_rx: &Receiver<()>) -> ToolCallResult {
59    call_with_sink(args, cancel_rx, None, None)
60}
61
62/// Execute `apr.run` with optional `notifications/progress` streaming.
63///
64/// FALSIFY-MCP-PROGRESS-002: when both `sink` and `progress_token` are
65/// `Some`, the subprocess is spawned with `apr run ... --stream` so each
66/// decoded token (NDJSON `event=token` line) and the terminal `event=final`
67/// blob is forwarded as a `notifications/progress` message tagged with the
68/// caller's token. When either argument is `None` (no progressToken on the
69/// originating `tools/call`) we fall back to the synchronous
70/// [`run_apr_cancellable`] path so existing clients see identical behaviour.
71///
72/// Note: the streaming path does NOT honour `cancel_rx` today (the MCP
73/// `apr.finetune` streaming path made the same trade-off in #887). Wiring
74/// SIGTERM into [`spawn_streaming`] is tracked separately — clients that
75/// require both streaming AND cancellation should not yet supply a
76/// progressToken on `apr.run`. The non-streaming path remains fully
77/// cancellable.
78#[must_use]
79pub fn call_with_sink(
80    args: &serde_json::Value,
81    cancel_rx: &Receiver<()>,
82    sink: Option<&NotificationSink>,
83    progress_token: Option<serde_json::Value>,
84) -> ToolCallResult {
85    let Some(model_path) = args.get("model_path").and_then(|v| v.as_str()) else {
86        return ToolCallResult::error("Missing required argument: model_path");
87    };
88
89    let streaming = sink.is_some() && progress_token.is_some();
90
91    let mut owned: Vec<String> = vec!["run".to_string(), model_path.to_string()];
92    // --stream emits NDJSON (one event per line); the legacy --json path
93    // emits a single pretty-printed blob. Pick whichever matches the
94    // intended consumer.
95    if streaming {
96        owned.push("--stream".to_string());
97    } else {
98        owned.push("--json".to_string());
99    }
100
101    if let Some(prompt) = args.get("prompt").and_then(|v| v.as_str()) {
102        if !prompt.is_empty() {
103            owned.push("--prompt".to_string());
104            owned.push(prompt.to_string());
105        }
106    }
107    if let Some(n) = args.get("max_tokens").and_then(serde_json::Value::as_u64) {
108        owned.push("--max-tokens".to_string());
109        owned.push(n.to_string());
110    }
111    if let Some(t) = args.get("temperature").and_then(serde_json::Value::as_f64) {
112        owned.push("--temperature".to_string());
113        owned.push(t.to_string());
114    }
115    if let Some(p) = args.get("top_p").and_then(serde_json::Value::as_f64) {
116        owned.push("--top-p".to_string());
117        owned.push(p.to_string());
118    }
119
120    let argv: Vec<&str> = owned.iter().map(String::as_str).collect();
121
122    match (streaming, sink, progress_token) {
123        (true, Some(sink), Some(token)) => stream_with_sink("apr", &argv, sink, &token),
124        _ => run_apr_cancellable(&argv, cancel_rx, CANCEL_GRACE_MS),
125    }
126}
127
128/// Test-visible: stream `program args...` and forward each stdout line as a
129/// `notifications/progress` notification through `sink`, tagged with
130/// `progress_token`. Each stdout line is JSON-parsed if possible (the
131/// `apr run --stream` NDJSON contract guarantees JSON) so downstream MCP
132/// clients receive structured `message.event = "token"` / `"final"` events;
133/// non-JSON lines fall back to a bare string. The returned `ToolCallResult`
134/// is the aggregated stdout (same shape as `run_apr_cancellable`'s success
135/// body) so non-streaming consumers get the full payload too.
136#[must_use]
137pub fn stream_with_sink(
138    program: &str,
139    args: &[&str],
140    sink: &NotificationSink,
141    progress_token: &serde_json::Value,
142) -> ToolCallResult {
143    spawn_streaming(program, args, |line| {
144        let trimmed = line.trim();
145        if trimmed.is_empty() {
146            return;
147        }
148        let payload = serde_json::from_str::<serde_json::Value>(trimmed)
149            .unwrap_or_else(|_| serde_json::Value::String(line.to_string()));
150        let notif = JsonRpcNotification::progress(progress_token.clone(), payload);
151        sink(notif);
152    })
153}
154
/// HELIX-IDEA-002 — unified-signature shim for the inventory dispatcher.
///
/// Forwards all four arguments unchanged to [`call_with_sink`] and returns
/// its result. `sink` and `progress_token` select the streaming path only
/// when both are `Some`. `cancel_rx` is consulted on the non-streaming path
/// only: when streaming is selected, [`call_with_sink`] hands the work to
/// [`stream_with_sink`], which does not receive the cancel channel.
pub fn dispatch(
    args: &serde_json::Value,
    cancel_rx: &Receiver<()>,
    sink: Option<&NotificationSink>,
    progress_token: Option<serde_json::Value>,
) -> ToolCallResult {
    call_with_sink(args, cancel_rx, sink, progress_token)
}
165
// Hand `apr.run` to the crate's `register_mcp_tool!` macro: the tool name,
// the function that builds its `ToolDefinition`, and the dispatch shim.
// NOTE(review): presumably this is what makes the tool visible to the
// inventory dispatcher — confirm against the macro definition.
crate::register_mcp_tool!(
    name: NAME,
    definition: run_tool_definition,
    dispatch: dispatch,
);
171
#[cfg(test)]
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;

    /// The advertised definition carries the expected name, is an object
    /// schema requiring only `model_path`, and lists every supported property.
    #[test]
    fn definition_has_correct_name_and_required_field() {
        let definition = run_tool_definition();
        assert_eq!(definition.name, "apr.run");

        let schema = &definition.input_schema;
        assert_eq!(schema.schema_type, "object");
        assert_eq!(schema.required, vec!["model_path".to_string()]);

        let expected_properties = ["model_path", "prompt", "max_tokens", "temperature", "top_p"];
        for field in expected_properties {
            assert!(
                schema.properties.contains_key(field),
                "property {field} present"
            );
        }
    }

    /// Calling without `model_path` yields an error result that names the
    /// missing argument.
    #[test]
    fn missing_model_path_returns_error() {
        // Keep the sender bound so the channel stays open for the duration of the call.
        let (_keep_alive, cancel_rx) = std::sync::mpsc::channel::<()>();
        let empty_args = serde_json::json!({});

        let outcome = call(&empty_args, &cancel_rx);

        assert_eq!(outcome.is_error, Some(true));
        assert!(outcome.content[0].text.contains("model_path"));
    }
}