use rmcp::{
Error as McpError,
model::*, service::*,
transport::io,
};
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use tokio_util::sync::CancellationToken;
fn create_schema_object(
properties: Vec<(&str, Value)>,
required: Vec<&str>,
) -> Arc<Map<String, Value>> {
let props_map: Map<String, Value> = properties
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect();
let req_vec: Vec<Value> = required
.into_iter()
.map(|s| Value::String(s.to_string()))
.collect();
let schema = json!({
"type": "object",
"properties": props_map,
"required": req_vec
});
let map = match schema {
Value::Object(map) => map,
_ => Map::new(),
};
Arc::new(map)
}
#[derive(Debug, Clone)]
struct ShellServer {
peer: Arc<Mutex<Option<Peer<RoleServer>>>>,
tools: Arc<HashMap<String, Tool>>,
}
impl ShellServer {
fn new() -> Self {
let mut tools = HashMap::new();
let shell_schema = create_schema_object(
vec![
(
"command",
json!({ "type": "string", "description": "The shell command to execute." }),
),
(
"workdir",
json!({ "type": "string", "description": "Optional working directory." }),
),
],
vec!["command"],
);
tools.insert(
"shell".to_string(),
Tool {
name: "shell".into(),
description: "Executes a shell command.".into(),
input_schema: shell_schema,
},
);
Self {
peer: Arc::new(Mutex::new(None)),
tools: Arc::new(tools),
}
}
async fn execute_shell_command(
command: &str,
workdir: Option<&str>,
) -> Result<(Vec<Annotated<RawContent>>, bool), McpError> {
let mut cmd_expr = duct::cmd!("/bin/sh", "-c", command); if let Some(dir) = workdir {
cmd_expr = cmd_expr.dir(dir);
}
let output_result = cmd_expr.stdout_capture().stderr_capture().unchecked().run();
let (content_vec, is_error) = match output_result {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let exit_code = output.status.code().unwrap_or(-1);
let result_text = format!(
"Exit Code: {}\n--- STDOUT ---\n{}\n--- STDERR ---\n{}",
exit_code, stdout, stderr
);
let raw_content = RawContent::Text(RawTextContent { text: result_text });
let annotated = Annotated {
raw: raw_content,
annotations: None,
};
(vec![annotated], !output.status.success())
}
Err(e) => {
let error_text = format!("Failed to execute command '{}': {}", command, e);
eprintln!("Execute Error: {}", error_text);
let raw_content = RawContent::Text(RawTextContent { text: error_text });
let annotated = Annotated {
raw: raw_content,
annotations: None,
};
(vec![annotated], true)
}
};
Ok((content_vec, is_error))
}
fn handle_shell_call(
&self,
params: CallToolRequestParam,
) -> Pin<Box<dyn Future<Output = Result<CallToolResult, McpError>> + Send + '_>> {
Box::pin(async move {
let args_map: Map<String, Value> = params
.arguments
.ok_or_else(|| McpError::invalid_params("Missing arguments", None))?;
let command = args_map
.get("command")
.and_then(Value::as_str)
.ok_or_else(|| McpError::invalid_params("Missing 'command' argument", None))?;
let workdir = args_map.get("workdir").and_then(Value::as_str);
let (content_vec, is_error): (Vec<Annotated<RawContent>>, bool) =
Self::execute_shell_command(command, workdir).await?;
Ok(CallToolResult {
content: content_vec,
is_error: Some(is_error),
})
})
}
}
impl Service<RoleServer> for ShellServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::LATEST,
capabilities: ServerCapabilities {
tools: Some(ToolsCapability {
list_changed: Some(true),
}),
..Default::default()
},
server_info: Implementation {
name: "volition-shell-server".into(),
version: env!("CARGO_PKG_VERSION").into(),
},
instructions: None,
}
}
fn get_peer(&self) -> Option<Peer<RoleServer>> {
self.peer.lock().unwrap().clone()
}
fn set_peer(&mut self, peer: Peer<RoleServer>) {
*self.peer.lock().unwrap() = Some(peer);
}
#[allow(refining_impl_trait)]
fn handle_request(
&self,
request: ClientRequest,
_context: RequestContext<RoleServer>,
) -> Pin<Box<dyn Future<Output = Result<ServerResult, McpError>> + Send + '_>> {
let self_clone = self.clone();
Box::pin(async move {
match request {
ClientRequest::InitializeRequest(_params) => {
eprintln!(
"Received InitializeRequest (handled in handle_request - should not happen with current rmcp)"
); let server_info = rmcp::Service::get_info(&self_clone);
Ok(ServerResult::InitializeResult(InitializeResult {
protocol_version: server_info.protocol_version, capabilities: server_info.capabilities,
server_info: server_info.server_info,
instructions: server_info.instructions,
}))
}
ClientRequest::ListToolsRequest(Request { .. }) => {
Ok(ServerResult::ListToolsResult(ListToolsResult {
tools: self_clone.tools.values().cloned().collect(),
next_cursor: None,
}))
}
ClientRequest::CallToolRequest(Request { params, .. }) => {
if params.name == "shell" {
self_clone
.handle_shell_call(params)
.await
.map(ServerResult::CallToolResult)
} else {
Err(McpError::method_not_found::<CallToolRequestMethod>())
}
}
_ => Err(McpError::method_not_found::<InitializeResultMethod>()),
}
})
}
#[allow(refining_impl_trait)]
fn handle_notification(
&self,
_notification: ClientNotification,
) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
Box::pin(async { Ok(()) }) }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let server = ShellServer::new();
let transport = io::stdio();
let ct = CancellationToken::new();
eprintln!("Starting shell MCP server...");
if let Err(e) = server.serve_with_ct(transport, ct.clone()).await {
eprintln!("Server loop failed: {}", e);
}
ct.cancelled().await;
eprintln!("Shell MCP server stopped.");
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
fn get_text_from_result(
result: Result<(Vec<Annotated<RawContent>>, bool), McpError>,
) -> (String, bool) {
match result {
Ok((content_vec, is_error)) => {
let text = content_vec
.first()
.map_or(String::new(), |annotated| match &annotated.raw {
RawContent::Text(t) => t.text.clone(),
_ => String::new(),
});
(text, is_error)
}
Err(e) => (format!("MCP Error: {}", e), true),
}
}
#[tokio::test]
async fn test_echo_absolute_path() -> Result<()> {
let command = "/bin/echo hello world";
let (output, is_error) =
get_text_from_result(ShellServer::execute_shell_command(command, None).await);
println!("Test Output ({}): {}", command, output); assert!(!is_error, "Command '{}' should succeed", command);
assert!(
output.contains("hello world"),
"Output should contain 'hello world'"
);
assert!(output.contains("Exit Code: 0"), "Exit code should be 0");
Ok(())
}
#[tokio::test]
async fn test_git_version_absolute_path() -> Result<()> {
let command = "/usr/bin/git --version";
if !std::path::Path::new("/usr/bin/git").exists() {
println!("Skipping test_git_version_absolute_path: /usr/bin/git not found.");
return Ok(());
}
let (output, is_error) =
get_text_from_result(ShellServer::execute_shell_command(command, None).await);
println!("Test Output ({}): {}", command, output); assert!(!is_error, "Command '{}' should succeed", command);
assert!(
output.contains("git version"),
"Output should contain 'git version'"
);
assert!(output.contains("Exit Code: 0"), "Exit code should be 0");
Ok(())
}
#[tokio::test]
async fn test_git_version_path() -> Result<()> {
let command = "git --version"; let path_check = duct::cmd!("which", "git").stdout_capture().run(); if path_check.is_err() || !path_check.unwrap().status.success() {
println!("Skipping test_git_version_path: 'git' not found in PATH.");
return Ok(());
}
let (output, is_error) =
get_text_from_result(ShellServer::execute_shell_command(command, None).await);
println!("Test Output ({}): {}", command, output); assert!(
!is_error,
"Command '{}' should succeed if git is in PATH",
command
);
assert!(
output.contains("git version"),
"Output should contain 'git version'"
);
assert!(output.contains("Exit Code: 0"), "Exit code should be 0");
Ok(())
}
#[tokio::test]
async fn test_command_not_found() -> Result<()> {
let command = "this_command_should_not_exist_qwertyuiop";
let (output, is_error) =
get_text_from_result(ShellServer::execute_shell_command(command, None).await);
println!("Test Output ({}): {}", command, output); assert!(is_error, "Execution should result in an error status");
assert!(
output.contains("not found") || output.contains("No such file"),
"Error message should indicate command not found by shell"
);
Ok(())
}
#[tokio::test]
async fn test_command_with_args() -> Result<()> {
let command = "/bin/ls -l"; let (output, is_error) =
get_text_from_result(ShellServer::execute_shell_command(command, None).await);
println!("Test Output ({}): {}", command, output); assert!(!is_error, "Command '{}' should succeed", command);
assert!(
output.contains("total"),
"Output should contain typical ls -l output"
); assert!(output.contains("Exit Code: 0"), "Exit code should be 0");
Ok(())
}
}