use anyhow::Result;
use rmcp::handler::server::wrapper::Parameters;
use rmcp::schemars::{self, JsonSchema};
use rmcp::{ServiceExt, tool, tool_router};
use serde::{Deserialize, Serialize};
use sr_core::git::GitRepo;
/// MCP server exposing the `sr` git tools (`sr_status`, `sr_diff`, `sr_log`, `sr_stage`,
/// `sr_commit`, `sr_branch`, `sr_config`).
///
/// Holds no state: every tool call re-discovers the repository with `GitRepo::discover()`.
#[derive(Debug, Clone)]
pub struct SrMcpServer;
/// Arguments for the `sr_diff` tool. Every field is optional.
#[derive(Deserialize, JsonSchema)]
pub struct DiffParams {
/// When true, request the staged diff (forwarded as `staged` to the repository diff helpers).
// NOTE(review): exact staged/unstaged semantics live in `GitRepo::diff_numstat` — confirm there.
#[serde(default)]
pub staged: bool,
/// Paths forwarded to the repository diff helpers. Defaults to an empty list.
// NOTE(review): presumably an empty list means "all files" — verify in GitRepo.
#[serde(default)]
pub files: Vec<String>,
/// Context-line count forwarded to the unified diff. Defaults to 0.
#[serde(default)]
pub context: usize,
/// When true, return only per-file stats and skip line-level hunks.
#[serde(default)]
pub name_only: bool,
}
/// JSON payload returned by the `sr_diff` tool.
#[derive(Serialize)]
struct DiffOutput {
/// One entry per file reported by the numstat diff, in that order.
files: Vec<FileDiff>,
/// Sum of `additions` over all files.
total_additions: usize,
/// Sum of `deletions` over all files.
total_deletions: usize,
}
/// Per-file entry in the `sr_diff` output.
#[derive(Serialize)]
struct FileDiff {
/// Path as reported by the numstat diff.
path: String,
/// Single-character status from `file_statuses`, or `'M'` when the file has no entry there.
status: char,
/// Added-line count from the numstat diff.
additions: usize,
/// Deleted-line count from the numstat diff.
deletions: usize,
/// Parsed hunks; omitted from the JSON entirely when empty (e.g. for `name_only` requests).
#[serde(skip_serializing_if = "Vec::is_empty")]
hunks: Vec<Hunk>,
}
/// One `@@` hunk of a file's unified diff.
#[derive(Serialize)]
struct Hunk {
/// Start line on the old side, from the `-start,count` part of the header.
old_start: usize,
/// Line count on the old side (1 when the header omits the count).
old_lines: usize,
/// Start line on the new side, from the `+start,count` part of the header.
new_start: usize,
/// Line count on the new side (1 when the header omits the count).
new_lines: usize,
/// The added, deleted and context lines of the hunk, in diff order.
changes: Vec<Change>,
}
/// A single line inside a hunk.
#[derive(Serialize)]
struct Change {
/// One of `"add"`, `"delete"` or `"context"`.
kind: &'static str,
/// New-file line number for `add`/`context` lines, old-file line number for `delete` lines.
line: usize,
/// Line text without its leading `+`, `-` or space marker.
content: String,
}
/// Parses `git diff` unified output into per-file hunks.
///
/// Returns one `(path, hunks)` pair per `diff ` header line, in input order.
/// - For `diff --git a/… b/…` headers the path is the text after the last ` b/`, or
///   everything after `diff --git ` when no ` b/` separator is present.
/// - For combined diffs (`diff --cc <path>` / `diff --combined <path>`) the path is the rest
///   of the header. No hunks are collected for them, because their `@@@` headers are not
///   parsed.
///
/// Lines between a file header and its first `@@` (`index`, `---`, `+++`, mode and rename
/// lines, …) are metadata and are skipped. Inside a hunk, every `+`, `-` and ` ` line is
/// content, including ones that *look* like headers: a deleted `-- comment` line arrives as
/// `--- comment`.
///
/// Added and context lines are numbered by their position in the new file; deleted lines by
/// their position in the old file. Other in-hunk lines (e.g. `\ No newline at end of file`)
/// are ignored.
fn parse_unified_diff(raw: &str) -> Vec<(String, Vec<Hunk>)> {
    /// Moves the in-progress hunk (if any), with its collected changes, into `hunks`.
    fn flush_hunk(
        header: &mut Option<(usize, usize, usize, usize)>,
        changes: &mut Vec<Change>,
        hunks: &mut Vec<Hunk>,
    ) {
        if let Some((old_start, old_lines, new_start, new_lines)) = header.take() {
            hunks.push(Hunk {
                old_start,
                old_lines,
                new_start,
                new_lines,
                changes: std::mem::take(changes),
            });
        }
    }

    let mut files: Vec<(String, Vec<Hunk>)> = Vec::new();
    let mut current_path: Option<String> = None;
    let mut hunks: Vec<Hunk> = Vec::new();
    let mut changes: Vec<Change> = Vec::new();
    // `Some` while we are inside a hunk; holds (old_start, old_lines, new_start, new_lines).
    let mut hunk_header: Option<(usize, usize, usize, usize)> = None;
    let mut old_cursor: usize = 0;
    let mut new_cursor: usize = 0;

    for line in raw.lines() {
        // Any `diff ` header (git, cc or combined) closes the previous file. Hunk content
        // lines always start with `+`, `-`, ` ` or `\`, so this cannot be a content line.
        if let Some(rest) = line.strip_prefix("diff ") {
            flush_hunk(&mut hunk_header, &mut changes, &mut hunks);
            if let Some(path) = current_path.take() {
                files.push((path, std::mem::take(&mut hunks)));
            }
            let path = match rest.strip_prefix("--git ") {
                Some(paths) => paths.rsplit_once(" b/").map_or(paths, |(_, new_path)| new_path),
                None => rest.split_once(' ').map_or(rest, |(_, path)| path),
            };
            current_path = Some(path.to_string());
            continue;
        }

        if let Some(header) = line.strip_prefix("@@ ") {
            flush_hunk(&mut hunk_header, &mut changes, &mut hunks);
            // Header shape: `-old_start[,old_lines] +new_start[,new_lines] @@ [section]`.
            let mut ranges = header.splitn(3, ' ');
            if let (Some(old), Some(new)) = (ranges.next(), ranges.next()) {
                let (os, ol) = parse_hunk_range(old.trim_start_matches('-'));
                let (ns, nl) = parse_hunk_range(new.trim_start_matches('+'));
                old_cursor = os;
                new_cursor = ns;
                hunk_header = Some((os, ol, ns, nl));
            }
            continue;
        }

        // Outside a hunk everything is file metadata; nothing to collect.
        if hunk_header.is_none() {
            continue;
        }

        if let Some(content) = line.strip_prefix('+') {
            changes.push(Change {
                kind: "add",
                line: new_cursor,
                content: content.to_string(),
            });
            new_cursor += 1;
        } else if let Some(content) = line.strip_prefix('-') {
            changes.push(Change {
                kind: "delete",
                line: old_cursor,
                content: content.to_string(),
            });
            old_cursor += 1;
        } else if let Some(content) = line.strip_prefix(' ') {
            changes.push(Change {
                kind: "context",
                line: new_cursor,
                content: content.to_string(),
            });
            old_cursor += 1;
            new_cursor += 1;
        }
    }

    flush_hunk(&mut hunk_header, &mut changes, &mut hunks);
    if let Some(path) = current_path {
        files.push((path, hunks));
    }
    files
}

/// Parses one side of a hunk header range, `start[,count]`, with its `-`/`+` sign removed.
///
/// A missing count means a single line (the unified diff format omits `,1`). Unparsable
/// numbers become 0.
fn parse_hunk_range(s: &str) -> (usize, usize) {
    match s.split_once(',') {
        Some((start, count)) => (start.parse().unwrap_or(0), count.parse().unwrap_or(0)),
        None => (s.parse().unwrap_or(0), 1),
    }
}
/// Arguments for the `sr_log` tool.
#[derive(Deserialize, JsonSchema)]
pub struct LogParams {
/// Number of recent commits to return when `range` is not given (default 10).
#[serde(default = "default_log_count")]
pub count: usize,
/// Revision range to log; when set, `count` is ignored.
pub range: Option<String>,
}
/// Serde default for `LogParams::count`.
fn default_log_count() -> usize {
10
}
/// Arguments for the `sr_stage` tool.
#[derive(Deserialize, JsonSchema)]
pub struct StageParams {
/// Paths to stage; must not be empty. The entry `"."` stages everything via `git add -A`.
pub files: Vec<String>,
}
/// Arguments for the `sr_commit` tool. The header is rendered as
/// `type(scope): description`, or `type: description` without a scope.
#[derive(Deserialize, JsonSchema)]
pub struct CommitParams {
/// Conventional-commit type (e.g. the `feat` in `feat: …`).
pub r#type: String,
/// Optional scope, rendered in parentheses after the type.
pub scope: Option<String>,
/// Header text after the `: ` separator.
pub description: String,
/// Optional body, separated from the header by a blank line.
pub body: Option<String>,
/// Optional footer, appended last after a blank line.
pub footer: Option<String>,
/// Paths to stage before committing. Defaults to an empty list.
#[serde(default)]
pub files: Vec<String>,
}
/// Arguments for the `sr_branch` tool.
#[derive(Deserialize, JsonSchema)]
pub struct BranchParams {
/// Branch to create and switch to; when omitted, the current branch name is returned.
pub name: Option<String>,
}
#[tool_router(server_handler)]
impl SrMcpServer {
#[tool(
name = "sr_status",
description = "Get repository status with file fingerprints. Call this first to see what changed."
)]
async fn status(&self) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("error: {e}"),
};
let status = match repo.status_porcelain() {
Ok(s) => s,
Err(e) => return format!("error: {e}"),
};
let statuses = match repo.file_statuses() {
Ok(s) => s,
Err(e) => return format!("error: {e}"),
};
let mut result = String::from("# Repository Status\n\n");
if status.trim().is_empty() {
result.push_str("No changes.\n");
return result;
}
for line in status.lines() {
if !line.is_empty() {
result.push_str(line);
result.push('\n');
}
}
result.push_str(&format!("\n{} file(s) changed\n", statuses.len()));
result
}
#[tool(
name = "sr_diff",
description = "Get structured diff: per-file stats + line-level changes as JSON. Use name_only for file list, then drill into specific files."
)]
async fn diff(&self, Parameters(params): Parameters<DiffParams>) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("{{\"error\":\"{e}\"}}"),
};
let stats = match repo.diff_numstat(params.staged, ¶ms.files) {
Ok(s) => s,
Err(e) => return format!("{{\"error\":\"{e}\"}}"),
};
if stats.is_empty() {
return "{\"files\":[],\"total_additions\":0,\"total_deletions\":0}".to_string();
}
let statuses = repo.file_statuses().unwrap_or_default();
let parsed_hunks = if params.name_only {
Vec::new()
} else {
match repo.diff_unified(params.staged, params.context, ¶ms.files) {
Ok(raw) => parse_unified_diff(&raw),
Err(_) => Vec::new(),
}
};
let hunk_map: std::collections::HashMap<&str, &Vec<Hunk>> =
parsed_hunks.iter().map(|(p, h)| (p.as_str(), h)).collect();
let mut total_add = 0;
let mut total_del = 0;
let mut file_diffs = Vec::new();
for (add, del, path) in &stats {
total_add += add;
total_del += del;
let status = statuses.get(path.as_str()).copied().unwrap_or('M');
let hunks = if params.name_only {
Vec::new()
} else if let Some(h) = hunk_map.get(path.as_str()) {
h.iter()
.map(|hunk| Hunk {
old_start: hunk.old_start,
old_lines: hunk.old_lines,
new_start: hunk.new_start,
new_lines: hunk.new_lines,
changes: hunk
.changes
.iter()
.map(|c| Change {
kind: c.kind,
line: c.line,
content: c.content.clone(),
})
.collect(),
})
.collect()
} else {
Vec::new()
};
file_diffs.push(FileDiff {
path: path.clone(),
status,
additions: *add,
deletions: *del,
hunks,
});
}
let output = DiffOutput {
files: file_diffs,
total_additions: total_add,
total_deletions: total_del,
};
serde_json::to_string(&output).unwrap_or_else(|e| format!("{{\"error\":\"{e}\"}}"))
}
#[tool(
name = "sr_log",
description = "Get commit log. Use range for PR commits or count for recent history."
)]
async fn log(&self, Parameters(params): Parameters<LogParams>) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("error: {e}"),
};
if let Some(range) = ¶ms.range {
match repo.log_range(range, None) {
Ok(log) => log,
Err(e) => format!("error: {e}"),
}
} else {
match repo.recent_commits(params.count) {
Ok(log) => log,
Err(e) => format!("error: {e}"),
}
}
}
#[tool(
name = "sr_stage",
description = "Stage files for commit. Use [\".\"] for all changes. Modifies the index."
)]
async fn stage(&self, Parameters(params): Parameters<StageParams>) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("error: {e}"),
};
if params.files.is_empty() {
return "error: no files specified".to_string();
}
let mut staged = Vec::new();
let mut failed = Vec::new();
for file in ¶ms.files {
if file == "." {
let s = std::process::Command::new("git")
.args(["-C", &repo.root().to_string_lossy()])
.args(["add", "-A"])
.status();
match s {
Ok(s) if s.success() => staged.push("all files".to_string()),
_ => failed.push("all files".to_string()),
}
} else {
match repo.stage_file(file) {
Ok(true) => staged.push(file.clone()),
_ => failed.push(file.clone()),
}
}
}
let mut result = String::new();
if !staged.is_empty() {
result.push_str(&format!("staged: {}\n", staged.join(", ")));
}
if !failed.is_empty() {
result.push_str(&format!("failed: {}\n", failed.join(", ")));
}
result
}
#[tool(
name = "sr_commit",
description = "Create a conventional commit. Stage files first with sr_stage."
)]
async fn commit(&self, Parameters(params): Parameters<CommitParams>) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("error: {e}"),
};
if !params.files.is_empty() {
for file in ¶ms.files {
let _ = repo.stage_file(file);
}
}
match repo.has_staged_changes() {
Ok(false) => return "error: no staged changes to commit".to_string(),
Err(e) => return format!("error: {e}"),
_ => {}
}
let header = match ¶ms.scope {
Some(scope) => format!("{}({}): {}", params.r#type, scope, params.description),
None => format!("{}: {}", params.r#type, params.description),
};
let mut message = header.clone();
if let Some(body) = ¶ms.body {
message.push_str("\n\n");
message.push_str(body);
}
if let Some(footer) = ¶ms.footer {
message.push_str("\n\n");
message.push_str(footer);
}
match repo.commit(&message) {
Ok(()) => {
let sha = repo.head_short().unwrap_or_else(|_| "???".to_string());
format!("{sha} {header}")
}
Err(e) => format!("error: {e}"),
}
}
#[tool(
name = "sr_branch",
description = "Get current branch or create a new one."
)]
async fn branch(&self, Parameters(params): Parameters<BranchParams>) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("error: {e}"),
};
match params.name {
None => match repo.current_branch() {
Ok(b) => b,
Err(e) => format!("error: {e}"),
},
Some(name) => {
let status = std::process::Command::new("git")
.args(["-C", &repo.root().to_string_lossy()])
.args(["checkout", "-b", &name])
.status();
match status {
Ok(s) if s.success() => format!("created and switched to branch: {name}"),
_ => format!("error: failed to create branch {name}"),
}
}
}
}
#[tool(
name = "sr_config",
description = "Read sr.yaml config (commit types, release settings, etc.)"
)]
async fn config(&self) -> String {
let repo = match GitRepo::discover() {
Ok(r) => r,
Err(e) => return format!("error: {e}"),
};
match sr_core::config::Config::find_config(repo.root().as_path()) {
Some((path, _)) => match sr_core::config::Config::load(&path) {
Ok(config) => {
serde_json::to_string_pretty(&config).unwrap_or_else(|e| format!("error: {e}"))
}
Err(e) => format!("error loading config: {e}"),
},
None => "no sr.yaml found (using defaults)".to_string(),
}
}
}
pub fn config() -> Result<()> {
let repo = GitRepo::discover()?;
let mcp_path = repo.root().join(".mcp.json");
let config = serde_json::json!({
"mcpServers": {
"sr": {
"command": "sr",
"args": ["mcp", "serve"]
}
}
});
let content = serde_json::to_string_pretty(&config)?;
std::fs::write(&mcp_path, &content)?;
println!("{}", mcp_path.display());
Ok(())
}
/// Serves the `sr` MCP tools over this process's stdin/stdout.
///
/// Returns once the running service finishes. Errors from starting the service or from
/// waiting on it are propagated.
pub async fn run() -> Result<()> {
    let transport = (tokio::io::stdin(), tokio::io::stdout());
    let running = SrMcpServer.serve(transport).await?;
    running.waiting().await?;
    Ok(())
}