#![allow(clippy::disallowed_methods)]
use crate::server::NotificationSink;
use crate::tools::subprocess::{run_apr_cancellable, spawn_streaming, CANCEL_GRACE_MS};
use crate::types::{InputSchema, JsonRpcNotification, ToolCallResult, ToolDefinition};
use std::sync::mpsc::Receiver;
/// MCP tool name for this module; used both as the advertised
/// [`ToolDefinition`] name and as the registration key below.
pub const NAME: &str = "apr.run";
/// Builds the MCP [`ToolDefinition`] advertised for the `apr.run` tool.
///
/// Both the JSON input schema and the description are taken from the
/// generated constants in `crate::schemas`.
///
/// # Panics
///
/// Panics if `crate::schemas::APR_RUN_SCHEMA` does not deserialize into an
/// [`InputSchema`]; that means the generated code is stale or broken, not that
/// a runtime input was bad.
#[must_use]
pub fn run_tool_definition() -> ToolDefinition {
    let schema_source = crate::schemas::APR_RUN_SCHEMA;
    let parsed_schema = serde_json::from_str::<InputSchema>(schema_source).expect(
        "FALSIFY-MCP-008: apr.run codegen constant must parse as InputSchema; \
         regenerate by editing contracts/apr-mcp-tool-schemas-v1.yaml and rebuilding",
    );
    let name = NAME.to_string();
    let description = crate::schemas::APR_RUN_DESCRIPTION.to_string();
    ToolDefinition {
        name,
        description,
        input_schema: parsed_schema,
    }
}
/// Runs the `apr.run` tool without progress streaming.
///
/// Shorthand for [`call_with_sink`] with neither a notification sink nor a
/// progress token, so the non-streaming, cancellable path is taken.
#[must_use]
pub fn call(args: &serde_json::Value, cancel_rx: &Receiver<()>) -> ToolCallResult {
    let (no_sink, no_token) = (None, None);
    call_with_sink(args, cancel_rx, no_sink, no_token)
}
/// Runs `apr run` for the given MCP arguments.
///
/// Recognised keys in `args`:
/// - `model_path` (required, non-empty string): positional model argument.
/// - `prompt` (string): forwarded as `--prompt`; an empty string is ignored.
/// - `max_tokens` (unsigned integer): forwarded as `--max-tokens`.
/// - `temperature`, `top_p` (numbers): forwarded as `--temperature` / `--top-p`.
///
/// Optional keys whose JSON type does not match are skipped, not rejected.
///
/// When both `sink` and `progress_token` are supplied, `apr` is run with
/// `--stream` and its output lines are forwarded through [`stream_with_sink`].
/// Otherwise it is run with `--json` via [`run_apr_cancellable`], which
/// receives `cancel_rx`.
///
/// Returns an error result if `model_path` is missing, not a string, or empty.
#[must_use]
pub fn call_with_sink(
    args: &serde_json::Value,
    cancel_rx: &Receiver<()>,
    sink: Option<&NotificationSink>,
    progress_token: Option<serde_json::Value>,
) -> ToolCallResult {
    // An empty path is treated like a missing one: `apr run ""` cannot succeed.
    let Some(model_path) = args
        .get("model_path")
        .and_then(serde_json::Value::as_str)
        .filter(|path| !path.is_empty())
    else {
        return ToolCallResult::error("Missing required argument: model_path");
    };
    // `zip` is `Some` only when both a sink and a token are present, so there is
    // no separate "streaming" flag that could disagree with the match below.
    let stream_target = sink.zip(progress_token);
    let owned = build_run_argv(model_path, args, stream_target.is_some());
    let argv: Vec<&str> = owned.iter().map(String::as_str).collect();
    match stream_target {
        // NOTE(review): `cancel_rx` is not passed to `stream_with_sink`, so a
        // cancellation request is presumably not honoured while streaming —
        // confirm against `spawn_streaming`.
        Some((sink, token)) => stream_with_sink("apr", &argv, sink, &token),
        None => run_apr_cancellable(&argv, cancel_rx, CANCEL_GRACE_MS),
    }
}

/// Builds the `apr` argument vector (without the program name) for `run`.
///
/// Valued options use the `--flag=value` form so a value that begins with `-`
/// (e.g. a prompt like `"-5 + 3 = ?"` or `"--help"`) is bound to its option
/// instead of being parsed as another flag.
///
/// NOTE(review): `model_path` is still positional, so a path starting with `-`
/// would be read as a flag by `apr`; consider rejecting such paths.
fn build_run_argv(model_path: &str, args: &serde_json::Value, streaming: bool) -> Vec<String> {
    let output_mode = if streaming { "--stream" } else { "--json" };
    let mut argv = vec![
        "run".to_string(),
        model_path.to_string(),
        output_mode.to_string(),
    ];
    if let Some(prompt) = args
        .get("prompt")
        .and_then(serde_json::Value::as_str)
        .filter(|prompt| !prompt.is_empty())
    {
        argv.push(format!("--prompt={prompt}"));
    }
    if let Some(n) = args.get("max_tokens").and_then(serde_json::Value::as_u64) {
        argv.push(format!("--max-tokens={n}"));
    }
    if let Some(t) = args.get("temperature").and_then(serde_json::Value::as_f64) {
        argv.push(format!("--temperature={t}"));
    }
    if let Some(p) = args.get("top_p").and_then(serde_json::Value::as_f64) {
        argv.push(format!("--top-p={p}"));
    }
    argv
}
/// Spawns `program` with `args` through [`spawn_streaming`] and forwards every
/// non-blank output line to `sink` as a progress notification carrying
/// `progress_token`.
///
/// If a line's trimmed text parses as JSON, that JSON value is the payload.
/// Otherwise the original (untrimmed) line is sent as a JSON string.
#[must_use]
pub fn stream_with_sink(
    program: &str,
    args: &[&str],
    sink: &NotificationSink,
    progress_token: &serde_json::Value,
) -> ToolCallResult {
    spawn_streaming(program, args, |line| {
        let text = line.trim();
        if !text.is_empty() {
            let payload: serde_json::Value = match serde_json::from_str(text) {
                Ok(parsed) => parsed,
                Err(_) => serde_json::Value::String(line.to_string()),
            };
            sink(JsonRpcNotification::progress(progress_token.clone(), payload));
        }
    })
}
/// Registry entry point for `apr.run`; forwards unchanged to [`call_with_sink`].
///
/// Marked `#[must_use]` like the other public entry points in this module,
/// because discarding the returned [`ToolCallResult`] loses the tool's output
/// or error.
#[must_use]
pub fn dispatch(
    args: &serde_json::Value,
    cancel_rx: &Receiver<()>,
    sink: Option<&NotificationSink>,
    progress_token: Option<serde_json::Value>,
) -> ToolCallResult {
    call_with_sink(args, cancel_rx, sink, progress_token)
}
// Hooks this module's tool name, definition builder and dispatcher into the
// crate's `register_mcp_tool!` macro (presumably a crate-wide tool registry —
// see the macro definition for exactly what it generates).
crate::register_mcp_tool!(
    name: NAME,
    definition: run_tool_definition,
    dispatch: dispatch,
);
#[cfg(test)]
// Mirrors the file-level allow: the crate's clippy configuration disallows some
// methods (presumably ones used by these tests — confirm in the clippy config).
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;

    #[test]
    fn definition_has_correct_name_and_required_field() {
        let def = run_tool_definition();
        assert_eq!(def.name, "apr.run");
        assert_eq!(def.name, NAME);
        assert_eq!(def.input_schema.schema_type, "object");
        assert_eq!(def.input_schema.required, vec!["model_path".to_string()]);
        for field in ["model_path", "prompt", "max_tokens", "temperature", "top_p"] {
            assert!(
                def.input_schema.properties.contains_key(field),
                "property {field} present"
            );
        }
    }

    #[test]
    fn missing_model_path_returns_error() {
        // Bound to `_tx` (not `_`) so the sender lives until the end of the test
        // and the receiver stays connected during the call.
        let (_tx, rx) = std::sync::mpsc::channel::<()>();
        let result = call(&serde_json::json!({}), &rx);
        assert_eq!(result.is_error, Some(true));
        assert!(result.content[0].text.contains("model_path"));
    }

    #[test]
    fn non_string_model_path_returns_error() {
        let (_tx, rx) = std::sync::mpsc::channel::<()>();
        let result = call(&serde_json::json!({ "model_path": 42 }), &rx);
        assert_eq!(result.is_error, Some(true));
        assert!(result.content[0].text.contains("model_path"));
    }

    #[test]
    fn dispatch_without_sink_reports_missing_model_path() {
        let (_tx, rx) = std::sync::mpsc::channel::<()>();
        let result = dispatch(&serde_json::json!({ "prompt": "hi" }), &rx, None, None);
        assert_eq!(result.is_error, Some(true));
        assert!(result.content[0].text.contains("model_path"));
    }
}