use std::process::Command;
use crate::{
Tool, execute_command,
serde_utils::{
deserialize_string, deserialize_string_vec, locking_mode_to_cli_flags,
output_verbosity_to_cli_flags,
},
};
use rmcp::ErrorData;
/// Arguments for running `cargo test`. All fields are optional.
#[derive(Debug, ::serde::Deserialize, ::schemars::JsonSchema)]
pub struct CargoTestRequest {
    /// Toolchain override, passed as `cargo +<toolchain> test` (for example `nightly`).
    #[serde(default, deserialize_with = "deserialize_string")]
    toolchain: Option<String>,
    /// Test-name filter, passed as the positional `TESTNAME` right after `test`.
    #[serde(default, deserialize_with = "deserialize_string")]
    testname: Option<String>,
    /// Extra arguments forwarded to the test binaries after `--` (for example `--nocapture`).
    #[serde(default, deserialize_with = "deserialize_string_vec")]
    test_args: Option<Vec<String>>,
    /// Compile, but don't run tests (`--no-run`).
    #[serde(default)]
    no_run: Option<bool>,
    /// Run all tests regardless of failure (`--no-fail-fast`).
    #[serde(default)]
    no_fail_fast: Option<bool>,
    /// Packages to test; each entry becomes a separate `--package <SPEC>` argument.
    #[serde(default, deserialize_with = "deserialize_string_vec")]
    package: Option<Vec<String>>,
    /// Test all members in the workspace (`--workspace`).
    #[serde(default)]
    workspace: Option<bool>,
    /// Packages to exclude; each entry becomes a separate `--exclude <SPEC>` argument.
    /// Cargo only accepts this together with `workspace`.
    #[serde(default, deserialize_with = "deserialize_string_vec")]
    exclude: Option<Vec<String>>,
    /// Test the package's library (`--lib`).
    #[serde(default)]
    lib: Option<bool>,
    /// Test all binary targets (`--bins`).
    #[serde(default)]
    bins: Option<bool>,
    /// Test the specified binary (`--bin <NAME>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    bin: Option<String>,
    /// Test all example targets (`--examples`).
    #[serde(default)]
    examples: Option<bool>,
    /// Test the specified example (`--example <NAME>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    example: Option<String>,
    /// Test all targets that have `test = true` set in the manifest (`--tests`).
    #[serde(default)]
    tests: Option<bool>,
    /// Test the specified integration test (`--test <NAME>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    test: Option<String>,
    /// Test all targets that have `bench = true` set in the manifest (`--benches`).
    #[serde(default)]
    benches: Option<bool>,
    /// Test the specified benchmark target (`--bench <NAME>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    bench: Option<String>,
    /// Test all targets (`--all-targets`).
    #[serde(default)]
    all_targets: Option<bool>,
    /// Test only the library's documentation (`--doc`).
    #[serde(default)]
    doc: Option<bool>,
    /// Features to activate, joined with commas into a single `--features` argument.
    /// An empty list adds nothing.
    #[serde(default, deserialize_with = "deserialize_string_vec")]
    features: Option<Vec<String>>,
    /// Activate all available features (`--all-features`).
    #[serde(default)]
    all_features: Option<bool>,
    /// Do not activate the `default` feature (`--no-default-features`).
    #[serde(default)]
    no_default_features: Option<bool>,
    /// Number of parallel jobs (`--jobs <N>`).
    #[serde(default)]
    jobs: Option<u32>,
    /// Test optimized artifacts with the release profile (`--release`).
    #[serde(default)]
    release: Option<bool>,
    /// Test with the given profile (`--profile <NAME>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    profile: Option<String>,
    /// Target triple to test for (`--target <TRIPLE>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    target: Option<String>,
    /// Directory for all generated artifacts and intermediate files (`--target-dir <DIR>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    target_dir: Option<String>,
    /// Path to the `Cargo.toml` to use (`--manifest-path <PATH>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    manifest_path: Option<String>,
    /// Path of the lockfile to use instead of the default (`--lockfile-path <PATH>`).
    #[serde(default, deserialize_with = "deserialize_string")]
    lockfile_path: Option<String>,
    /// Ignore `rust-version` specifications in packages (`--ignore-rust-version`).
    #[serde(default)]
    ignore_rust_version: Option<bool>,
    /// Lockfile locking mode, translated into the corresponding cargo locking flags.
    // NOTE(review): build_cmd hands "locked" to locking_mode_to_cli_flags as the fallback
    // mode — presumably the effective default when this is unset; confirm in serde_utils.
    #[serde(default, deserialize_with = "deserialize_string")]
    locking_mode: Option<String>,
    /// Output verbosity, translated into the corresponding cargo command-line flags.
    #[serde(default, deserialize_with = "deserialize_string")]
    output_verbosity: Option<String>,
}
impl CargoTestRequest {
    /// Builds the `cargo test` [`Command`] described by this request.
    ///
    /// Arguments are emitted in a fixed order:
    /// 1. the optional `+<toolchain>` override;
    /// 2. the `test` subcommand and optional test-name filter;
    /// 3. the target-selection, feature and build flags;
    /// 4. the locking and output-verbosity flags;
    /// 5. `-- <test_args>`, when test arguments were supplied.
    ///
    /// The command is only constructed here, not spawned.
    ///
    /// # Errors
    ///
    /// Propagates the [`ErrorData`] returned by `locking_mode_to_cli_flags` or
    /// `output_verbosity_to_cli_flags` when `locking_mode` or `output_verbosity`
    /// cannot be converted into flags.
    pub fn build_cmd(&self) -> Result<Command, ErrorData> {
        let mut cmd = Command::new("cargo");
        // The `+<toolchain>` override must come before the subcommand.
        if let Some(toolchain) = &self.toolchain {
            cmd.arg(format!("+{toolchain}"));
        }
        cmd.arg("test");
        if let Some(testname) = &self.testname {
            cmd.arg(testname);
        }

        Self::push_flag(&mut cmd, "--no-run", self.no_run);
        Self::push_flag(&mut cmd, "--no-fail-fast", self.no_fail_fast);
        Self::push_repeated(&mut cmd, "--package", self.package.as_deref());
        Self::push_flag(&mut cmd, "--workspace", self.workspace);
        Self::push_repeated(&mut cmd, "--exclude", self.exclude.as_deref());

        Self::push_flag(&mut cmd, "--lib", self.lib);
        Self::push_flag(&mut cmd, "--bins", self.bins);
        Self::push_value(&mut cmd, "--bin", self.bin.as_deref());
        Self::push_flag(&mut cmd, "--examples", self.examples);
        Self::push_value(&mut cmd, "--example", self.example.as_deref());
        Self::push_flag(&mut cmd, "--tests", self.tests);
        Self::push_value(&mut cmd, "--test", self.test.as_deref());
        Self::push_flag(&mut cmd, "--benches", self.benches);
        Self::push_value(&mut cmd, "--bench", self.bench.as_deref());
        Self::push_flag(&mut cmd, "--all-targets", self.all_targets);
        Self::push_flag(&mut cmd, "--doc", self.doc);

        // Features go out as one comma-separated value; an empty list adds nothing.
        if let Some(features) = &self.features
            && !features.is_empty()
        {
            cmd.arg("--features").arg(features.join(","));
        }
        Self::push_flag(&mut cmd, "--all-features", self.all_features);
        Self::push_flag(&mut cmd, "--no-default-features", self.no_default_features);

        if let Some(jobs) = self.jobs {
            cmd.arg("--jobs").arg(jobs.to_string());
        }
        Self::push_flag(&mut cmd, "--release", self.release);
        Self::push_value(&mut cmd, "--profile", self.profile.as_deref());
        Self::push_value(&mut cmd, "--target", self.target.as_deref());
        Self::push_value(&mut cmd, "--target-dir", self.target_dir.as_deref());
        Self::push_value(&mut cmd, "--manifest-path", self.manifest_path.as_deref());
        Self::push_value(&mut cmd, "--lockfile-path", self.lockfile_path.as_deref());
        Self::push_flag(&mut cmd, "--ignore-rust-version", self.ignore_rust_version);

        // "locked" is the fallback mode handed to the locking-flag helper.
        cmd.args(locking_mode_to_cli_flags(
            self.locking_mode.as_deref(),
            "locked",
        )?);
        cmd.args(output_verbosity_to_cli_flags(
            self.output_verbosity.as_deref(),
        )?);

        // Everything after `--` is forwarded to the test binaries instead of cargo.
        if let Some(test_args) = &self.test_args {
            cmd.arg("--").args(test_args);
        }
        Ok(cmd)
    }

    /// Appends the switch `flag` when `enabled` is `Some(true)`.
    fn push_flag(cmd: &mut Command, flag: &str, enabled: Option<bool>) {
        if enabled.unwrap_or(false) {
            cmd.arg(flag);
        }
    }

    /// Appends `flag <value>` when `value` is present.
    fn push_value(cmd: &mut Command, flag: &str, value: Option<&str>) {
        if let Some(value) = value {
            cmd.arg(flag).arg(value);
        }
    }

    /// Appends `flag <value>` once for each entry of `values`, preserving their order.
    fn push_repeated(cmd: &mut Command, flag: &str, values: Option<&[String]>) {
        for value in values.unwrap_or_default() {
            cmd.arg(flag).arg(value);
        }
    }
}
/// MCP tool exposing `cargo test`, driven by a [`CargoTestRequest`].
pub struct CargoTestRmcpTool;

impl Tool for CargoTestRmcpTool {
    const NAME: &'static str = "cargo-test";
    const TITLE: &'static str = "cargo test";
    const DESCRIPTION: &'static str =
        "Run `cargo test` to execute Rust tests in the current project.";
    type RequestArgs = CargoTestRequest;

    /// Turns `request` into a `cargo test` command and runs it through `execute_command`,
    /// tagged with [`Self::NAME`], then converts the outcome into a [`crate::Response`].
    ///
    /// # Errors
    ///
    /// Returns the [`ErrorData`] produced while building or executing the command.
    fn call_rmcp_tool(&self, request: Self::RequestArgs) -> Result<crate::Response, ErrorData> {
        let command = request.build_cmd()?;
        let outcome = execute_command(command, Self::NAME)?;
        Ok(outcome.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Collects the arguments of `cmd` (program name excluded) as owned strings.
    fn args_of(cmd: &std::process::Command) -> Vec<String> {
        cmd.get_args()
            .map(|arg| arg.to_string_lossy().into_owned())
            .collect()
    }

    #[test]
    fn test_deserialize_with_missing_package_field() {
        let input = json!({
            "toolchain": null,
            "workspace": true,
            "all_features": true,
            "no_default_features": false,
            "release": false,
            "all_targets": true
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool
            .expect("Deserialization should succeed even if `package` is missing (it's Option)");
        assert_eq!(tool.package, None);
        assert_eq!(tool.workspace, Some(true));
        assert_eq!(tool.all_features, Some(true));
        assert_eq!(tool.all_targets, Some(true));
    }

    #[test]
    fn test_deserialize_with_package_field_array() {
        let input = json!({
            "package": ["my_package", "another_package"],
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool.expect("Deserialization should succeed with package array");
        assert_eq!(
            tool.package.unwrap(),
            ["my_package".to_owned(), "another_package".to_owned()]
        );
        assert_eq!(tool.workspace, None);
        assert_eq!(tool.all_features, None);
    }

    #[test]
    fn test_deserialize_with_single_package_array() {
        let input = json!({
            "package": ["single_package"],
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool.expect("Deserialization should succeed with single-item package array");
        assert_eq!(tool.package.unwrap(), ["single_package".to_owned()]);
    }

    #[test]
    fn test_deserialize_with_single_package() {
        let input = json!({
            "package": "single_package",
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool.expect("Deserialization should succeed with a single package string");
        assert_eq!(tool.package.unwrap(), ["single_package".to_owned()]);
    }

    #[test]
    fn test_deserialize_with_features_array() {
        let input = json!({
            "features": ["serde", "tokio"],
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool.expect("Deserialization should succeed with features array");
        assert_eq!(
            tool.features.unwrap(),
            ["serde".to_owned(), "tokio".to_owned()]
        );
    }

    #[test]
    fn test_deserialize_with_single_feature_string() {
        let input = json!({
            "features": "serde",
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool.expect("Deserialization should succeed with single feature string");
        assert_eq!(tool.features.unwrap(), ["serde".to_owned()]);
    }

    #[test]
    fn test_deserialize_with_features_string_array() {
        let input = json!({
            "features": "[\"serde\",\"tokio\"]",
        });
        let tool: Result<CargoTestRequest, _> = serde_json::from_value(input);
        let tool = tool
            .expect("Deserialization should succeed with features string that looks like array");
        assert_eq!(tool.features.unwrap(), ["[\"serde\",\"tokio\"]".to_owned()]);
    }

    #[test]
    fn test_build_cmd_places_toolchain_before_subcommand() {
        let request: CargoTestRequest = serde_json::from_value(json!({
            "toolchain": "nightly",
            "testname": "my_test",
        }))
        .expect("Deserialization should succeed with toolchain and testname");
        let cmd = request.build_cmd().expect("Request should build a command");
        assert_eq!(cmd.get_program(), "cargo");
        let args = args_of(&cmd);
        assert_eq!(args[..3], ["+nightly", "test", "my_test"]);
    }

    #[test]
    fn test_build_cmd_repeats_packages_and_forwards_test_args() {
        let request: CargoTestRequest = serde_json::from_value(json!({
            "package": ["first", "second"],
            "features": ["serde", "tokio"],
            "test_args": ["--nocapture"],
        }))
        .expect("Deserialization should succeed with packages, features and test args");
        let args = args_of(&request.build_cmd().expect("Request should build a command"));

        let package_at = args
            .iter()
            .position(|a| a == "--package")
            .expect("--package should be present");
        assert_eq!(
            args[package_at..package_at + 4],
            ["--package", "first", "--package", "second"]
        );

        let features_at = args
            .iter()
            .position(|a| a == "--features")
            .expect("--features should be present");
        assert_eq!(args[features_at + 1], "serde,tokio");

        // Test-binary arguments must come last, after the `--` separator.
        assert_eq!(args[args.len() - 2..], ["--", "--nocapture"]);
    }

    #[test]
    fn test_build_cmd_skips_empty_features() {
        let request: CargoTestRequest = serde_json::from_value(json!({ "features": [] }))
            .expect("Deserialization should succeed with an empty features array");
        let args = args_of(&request.build_cmd().expect("Request should build a command"));
        assert!(!args.iter().any(|a| a == "--features"));
        assert!(!args.iter().any(|a| a == "--"));
    }
}