#![allow(clippy::disallowed_methods)]
use crate::server::NotificationSink;
use crate::tools::subprocess::{run_apr, spawn_streaming};
use crate::types::{InputSchema, JsonRpcNotification, ToolCallResult, ToolDefinition};
/// Name under which this tool is advertised in its [`ToolDefinition`].
pub const NAME: &str = "apr.finetune";
/// Builds the MCP [`ToolDefinition`] advertised for the `apr.finetune` tool.
///
/// The description and the JSON input schema both come from generated
/// constants in `crate::schemas`.
///
/// # Panics
///
/// Panics if `crate::schemas::APR_FINETUNE_SCHEMA` does not deserialize into an
/// [`InputSchema`]. That constant is code-generated, so a failure here means
/// the generated schema is broken, not that a runtime input was bad.
#[must_use]
pub fn finetune_tool_definition() -> ToolDefinition {
    let parsed = serde_json::from_str::<InputSchema>(crate::schemas::APR_FINETUNE_SCHEMA);
    let schema = parsed.expect(
        "FALSIFY-MCP-008: apr.finetune codegen constant must parse as InputSchema; \
         regenerate by editing contracts/apr-mcp-tool-schemas-v1.yaml and rebuilding",
    );
    let description = String::from(crate::schemas::APR_FINETUNE_DESCRIPTION);
    ToolDefinition {
        name: String::from(NAME),
        description,
        input_schema: schema,
    }
}
#[must_use]
pub fn call(args: &serde_json::Value) -> ToolCallResult {
call_with_sink(args, None, None)
}
#[must_use]
pub fn call_with_sink(
args: &serde_json::Value,
sink: Option<&NotificationSink>,
progress_token: Option<serde_json::Value>,
) -> ToolCallResult {
let Some(base_model) = args.get("base_model").and_then(|v| v.as_str()) else {
return ToolCallResult::error("Missing required argument: base_model");
};
let mut owned: Vec<String> = vec![
"finetune".to_string(),
base_model.to_string(),
"--json".to_string(),
];
if let Some(dataset) = args.get("dataset").and_then(|v| v.as_str()) {
if !dataset.is_empty() {
owned.push("--data".to_string());
owned.push(dataset.to_string());
}
}
if let Some(rank) = args.get("lora_rank").and_then(serde_json::Value::as_u64) {
owned.push("--rank".to_string());
owned.push(rank.to_string());
}
if let Some(epochs) = args.get("epochs").and_then(serde_json::Value::as_u64) {
owned.push("--epochs".to_string());
owned.push(epochs.to_string());
}
if let Some(method) = args.get("method").and_then(|v| v.as_str()) {
if !method.is_empty() {
owned.push("--method".to_string());
owned.push(method.to_string());
}
}
if let Some(output) = args.get("output").and_then(|v| v.as_str()) {
if !output.is_empty() {
owned.push("--output".to_string());
owned.push(output.to_string());
}
}
let argv: Vec<&str> = owned.iter().map(String::as_str).collect();
match (sink, progress_token) {
(Some(sink), Some(token)) => stream_with_sink("apr", &argv, sink, &token),
_ => run_apr(&argv),
}
}
/// Spawns `program` with `args` and forwards every non-blank output line to
/// `sink` as a progress notification tagged with `progress_token`.
///
/// The helper `spawn_streaming` invokes the closure once per line.
/// - Lines that are empty after trimming are skipped.
/// - If the trimmed text parses as JSON, the parsed value is the notification
///   payload.
/// - Otherwise the line, untrimmed, is sent as a JSON string.
///
/// Notifications are built with `JsonRpcNotification::progress`. The
/// returned [`ToolCallResult`] is whatever `spawn_streaming` produces.
#[must_use]
pub fn stream_with_sink(
    program: &str,
    args: &[&str],
    sink: &NotificationSink,
    progress_token: &serde_json::Value,
) -> ToolCallResult {
    spawn_streaming(program, args, |raw| {
        let text = raw.trim();
        if !text.is_empty() {
            let message: serde_json::Value = match serde_json::from_str(text) {
                Ok(parsed) => parsed,
                Err(_) => serde_json::Value::String(raw.to_string()),
            };
            sink(JsonRpcNotification::progress(progress_token.clone(), message));
        }
    })
}
#[cfg(test)]
// NOTE(review): mirrors the file-level `#![allow(clippy::disallowed_methods)]`;
// the tests below call `expect` on locks and on results.
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;

    #[test]
    fn finetune_tool_definition_shape() {
        let def = finetune_tool_definition();
        assert_eq!(def.name, "apr.finetune");
        assert_eq!(def.input_schema.schema_type, "object");
        assert_eq!(def.input_schema.required, vec!["base_model".to_string()]);
        for field in [
            "base_model",
            "dataset",
            "lora_rank",
            "epochs",
            "method",
            "output",
        ] {
            assert!(
                def.input_schema.properties.contains_key(field),
                "property {field} present"
            );
        }
    }

    #[test]
    fn finetune_missing_base_model_is_error() {
        let result = call(&serde_json::json!({}));
        assert_eq!(result.is_error, Some(true));
        assert!(
            result.content[0].text.contains("base_model"),
            "error message must mention base_model, got: {}",
            result.content[0].text
        );
    }

    #[test]
    fn finetune_nonstring_base_model_is_error() {
        let result = call(&serde_json::json!({ "base_model": 42 }));
        assert_eq!(result.is_error, Some(true));
        assert!(result.content[0].text.contains("base_model"));
    }

    // `printf` is a POSIX utility and is not available on Windows runners.
    #[cfg(unix)]
    #[test]
    fn stream_with_sink_emits_one_notification_per_line() {
        use std::sync::{Arc, Mutex};
        let captured: Arc<Mutex<Vec<JsonRpcNotification>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_clone = Arc::clone(&captured);
        let sink: NotificationSink = Box::new(move |n| {
            captured_clone
                .lock()
                .expect("sink mutex not poisoned")
                .push(n);
        });
        let token = serde_json::json!("progress-token-xyz");
        let result = stream_with_sink(
            "printf",
            &[r#"{"step":1}\n{"step":2}\nplain-line\n"#],
            &sink,
            &token,
        );
        assert!(result.is_error.is_none(), "printf should succeed");
        let notifs = captured.lock().expect("mutex").clone();
        assert_eq!(
            notifs.len(),
            3,
            "one notification per non-empty stdout line"
        );
        for n in &notifs {
            assert_eq!(n.method, "notifications/progress");
            assert_eq!(n.params["progressToken"], "progress-token-xyz");
        }
        assert_eq!(notifs[0].params["message"]["step"], 1);
        assert_eq!(notifs[1].params["message"]["step"], 2);
        assert_eq!(notifs[2].params["message"], "plain-line");
    }

    #[test]
    fn call_with_sink_none_sink_is_synchronous() {
        let result = call_with_sink(
            &serde_json::json!({ "base_model": "/nonexistent/model.apr" }),
            None,
            None,
        );
        assert!(!result.content.is_empty());
    }
}