// mcp_git/server.rs

use std::path::PathBuf;
use std::sync::Arc;

use rmcp::handler::server::router::tool::ToolRouter;
use rmcp::handler::server::wrapper::Parameters;
use rmcp::model::*;
use rmcp::{schemars, tool, tool_handler, tool_router, ServerHandler};
use serde::Deserialize;

use crate::error::McpGitError;

/// A Git repository that the server exposes to its tools.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RepoEntry {
    /// Name clients pass as the `repo` tool argument to select this repository.
    pub name: String,
    /// Filesystem location handed to `gix::discover` when the repository is opened.
    pub path: PathBuf,
}

18#[derive(Clone)]
19pub struct McpGitServer {
20    repos: Arc<Vec<RepoEntry>>,
21    max_diff_lines: u32,
22    max_log_entries: u32,
23    tool_router: ToolRouter<Self>,
24}
25
26// -- Tool parameter types --
27
28#[derive(Debug, Deserialize, schemars::JsonSchema)]
29pub struct RepoParam {
30    #[schemars(description = "Repository name (optional if only one repo is connected)")]
31    #[serde(default)]
32    pub repo: Option<String>,
33}
34
35#[derive(Debug, Deserialize, schemars::JsonSchema)]
36pub struct LogParams {
37    #[schemars(description = "Repository name (optional if only one repo is connected)")]
38    #[serde(default)]
39    pub repo: Option<String>,
40
41    #[schemars(description = "Maximum number of commits to return")]
42    #[serde(default)]
43    pub max_count: Option<u32>,
44
45    #[schemars(description = "Branch or ref to show log for (default: HEAD)")]
46    #[serde(default)]
47    pub branch: Option<String>,
48
49    #[schemars(description = "Filter commits by author name or email")]
50    #[serde(default)]
51    pub author: Option<String>,
52}
53
54#[derive(Debug, Deserialize, schemars::JsonSchema)]
55pub struct DiffParams {
56    #[schemars(description = "Repository name (optional if only one repo is connected)")]
57    #[serde(default)]
58    pub repo: Option<String>,
59
60    #[schemars(description = "Starting ref (commit SHA, branch, or tag)")]
61    pub from_ref: String,
62
63    #[schemars(description = "Ending ref (commit SHA, branch, or tag). Default: HEAD")]
64    #[serde(default)]
65    pub to_ref: Option<String>,
66
67    #[schemars(description = "Filter diff to a specific file path")]
68    #[serde(default)]
69    pub path: Option<String>,
70}
71
72#[derive(Debug, Deserialize, schemars::JsonSchema)]
73pub struct CommitParams {
74    #[schemars(description = "Repository name (optional if only one repo is connected)")]
75    #[serde(default)]
76    pub repo: Option<String>,
77
78    #[schemars(description = "Commit SHA or ref to show")]
79    pub commit: String,
80}
81
82#[derive(Debug, Deserialize, schemars::JsonSchema)]
83pub struct SearchParams {
84    #[schemars(description = "Repository name (optional if only one repo is connected)")]
85    #[serde(default)]
86    pub repo: Option<String>,
87
88    #[schemars(description = "Search query to match against commit messages")]
89    pub query: String,
90
91    #[schemars(description = "Maximum number of results to return")]
92    #[serde(default)]
93    pub max_count: Option<u32>,
94}
95
96impl McpGitServer {
97    pub fn new(repos: Vec<RepoEntry>, max_diff_lines: u32, max_log_entries: u32) -> Self {
98        Self {
99            repos: Arc::new(repos),
100            max_diff_lines,
101            max_log_entries,
102            tool_router: Self::tool_router(),
103        }
104    }
105
106    fn resolve(&self, name: Option<&str>) -> Result<&RepoEntry, McpGitError> {
107        match name {
108            Some(n) => self
109                .repos
110                .iter()
111                .find(|r| r.name == n)
112                .ok_or_else(|| McpGitError::RepoNotFound(n.to_string())),
113            None if self.repos.len() == 1 => Ok(&self.repos[0]),
114            None => Err(McpGitError::AmbiguousRepo),
115        }
116    }
117
118    fn open_repo(&self, entry: &RepoEntry) -> Result<gix::Repository, McpGitError> {
119        gix::discover(&entry.path)
120            .map_err(|e| McpGitError::Git(format!("Cannot open repository '{}': {}", entry.name, e)))
121    }
122
123    fn err(&self, e: McpGitError) -> ErrorData {
124        e.to_mcp_error()
125    }
126}
127
128// -- Public methods for testability --
129
130impl McpGitServer {
131    pub fn do_list_repos(&self) -> Result<CallToolResult, ErrorData> {
132        let mut results = Vec::new();
133        for entry in self.repos.iter() {
134            let branch = match self.open_repo(entry) {
135                Ok(repo) => repo
136                    .head_name()
137                    .ok()
138                    .flatten()
139                    .map(|r| r.shorten().to_string())
140                    .unwrap_or_else(|| "detached".to_string()),
141                Err(_) => "unknown".to_string(),
142            };
143
144            results.push(serde_json::json!({
145                "name": entry.name,
146                "path": entry.path.display().to_string(),
147                "branch": branch,
148            }));
149        }
150
151        let text =
152            serde_json::to_string_pretty(&results).unwrap_or_else(|_| "[]".to_string());
153        Ok(CallToolResult::success(vec![Content::text(text)]))
154    }
155
156    pub fn do_log(&self, params: LogParams) -> Result<CallToolResult, ErrorData> {
157        let entry = self.resolve(params.repo.as_deref()).map_err(|e| self.err(e))?;
158        let repo = self.open_repo(entry).map_err(|e| self.err(e))?;
159
160        let max = params.max_count.unwrap_or(self.max_log_entries);
161        let rev_spec = params.branch.as_deref().unwrap_or("HEAD");
162
163        let commit_id = repo
164            .rev_parse_single(gix::bstr::BStr::new(rev_spec.as_bytes()))
165            .map_err(|e| self.err(McpGitError::InvalidRef(format!("{}: {}", rev_spec, e))))?
166            .detach();
167
168        let mut commits = Vec::new();
169        let walk = repo
170            .rev_walk([commit_id])
171            .all()
172            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
173
174        for info in walk {
175            if commits.len() >= max as usize {
176                break;
177            }
178            let info = info.map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
179            let commit = info
180                .object()
181                .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
182
183            let author = commit.author().map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
184            let author_name = author.name.to_string();
185            let author_email = author.email.to_string();
186            let message = commit.message_raw_sloppy().to_string();
187            let time = author.time.seconds;
188
189            // Apply author filter if specified
190            if let Some(ref filter) = params.author {
191                let filter_lower = filter.to_lowercase();
192                if !author_name.to_lowercase().contains(&filter_lower)
193                    && !author_email.to_lowercase().contains(&filter_lower)
194                {
195                    continue;
196                }
197            }
198
199            commits.push(serde_json::json!({
200                "sha": commit.id().to_string(),
201                "author": format!("{} <{}>", author_name, author_email),
202                "timestamp": time,
203                "message": message.trim(),
204            }));
205        }
206
207        let text = serde_json::to_string_pretty(&serde_json::json!({
208            "commits": commits,
209            "count": commits.len(),
210        }))
211        .unwrap_or_else(|_| "{}".to_string());
212        Ok(CallToolResult::success(vec![Content::text(text)]))
213    }
214
215    pub fn do_diff(&self, params: DiffParams) -> Result<CallToolResult, ErrorData> {
216        let entry = self.resolve(params.repo.as_deref()).map_err(|e| self.err(e))?;
217        let repo = self.open_repo(entry).map_err(|e| self.err(e))?;
218
219        let from = repo
220            .rev_parse_single(gix::bstr::BStr::new(params.from_ref.as_bytes()))
221            .map_err(|e| self.err(McpGitError::InvalidRef(format!("{}: {}", params.from_ref, e))))?;
222        let to_ref = params.to_ref.as_deref().unwrap_or("HEAD");
223        let to = repo
224            .rev_parse_single(gix::bstr::BStr::new(to_ref.as_bytes()))
225            .map_err(|e| self.err(McpGitError::InvalidRef(format!("{}: {}", to_ref, e))))?;
226
227        let from_commit = repo
228            .find_object(from)
229            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?
230            .try_into_commit()
231            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
232        let to_commit = repo
233            .find_object(to)
234            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?
235            .try_into_commit()
236            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
237
238        let from_tree = from_commit
239            .tree()
240            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
241        let to_tree = to_commit
242            .tree()
243            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
244
245        // Compute tree diff to find changed files
246        use gix::object::tree::diff::{Action as DiffAction, Change as DiffChange};
247        let mut changes = Vec::new();
248        let max_files = self.max_diff_lines as usize;
249
250        from_tree
251            .changes()
252            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?
253            .for_each_to_obtain_tree(&to_tree, |change: DiffChange<'_, '_, '_>| {
254                let path = change.location().to_string();
255
256                // Apply path filter if specified
257                if let Some(ref filter_path) = params.path {
258                    if !path.starts_with(filter_path.as_str()) {
259                        return Ok::<_, std::convert::Infallible>(DiffAction::Continue);
260                    }
261                }
262
263                let change_type = match &change {
264                    DiffChange::Addition { .. } => "added",
265                    DiffChange::Deletion { .. } => "deleted",
266                    DiffChange::Modification { .. } => "modified",
267                    DiffChange::Rewrite { copy: true, .. } => "copied",
268                    DiffChange::Rewrite { .. } => "renamed",
269                };
270
271                if changes.len() < max_files {
272                    changes.push(serde_json::json!({
273                        "path": path,
274                        "change": change_type,
275                    }));
276                }
277
278                Ok(DiffAction::Continue)
279            })
280            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
281
282        let text = serde_json::to_string_pretty(&serde_json::json!({
283            "from": params.from_ref,
284            "to": to_ref,
285            "from_sha": from_commit.id().to_string(),
286            "to_sha": to_commit.id().to_string(),
287            "files": changes,
288            "file_count": changes.len(),
289        }))
290        .unwrap_or_else(|_| "{}".to_string());
291        Ok(CallToolResult::success(vec![Content::text(text)]))
292    }
293
294    pub fn do_show_commit(&self, params: CommitParams) -> Result<CallToolResult, ErrorData> {
295        let entry = self.resolve(params.repo.as_deref()).map_err(|e| self.err(e))?;
296        let repo = self.open_repo(entry).map_err(|e| self.err(e))?;
297
298        let id = repo
299            .rev_parse_single(gix::bstr::BStr::new(params.commit.as_bytes()))
300            .map_err(|e| self.err(McpGitError::InvalidRef(format!("{}: {}", params.commit, e))))?;
301
302        let commit = repo
303            .find_object(id)
304            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?
305            .try_into_commit()
306            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
307
308        let author = commit.author().map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
309        let committer = commit.committer().map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
310        let message = commit.message_raw_sloppy().to_string();
311        let time = author.time.seconds;
312
313        let parent_ids: Vec<String> = commit
314            .parent_ids()
315            .map(|id| id.to_string())
316            .collect();
317
318        let text = serde_json::to_string_pretty(&serde_json::json!({
319            "sha": commit.id().to_string(),
320            "author": format!("{} <{}>", author.name, author.email),
321            "committer": format!("{} <{}>", committer.name, committer.email),
322            "timestamp": time,
323            "message": message.trim(),
324            "parents": parent_ids,
325        }))
326        .unwrap_or_else(|_| "{}".to_string());
327        Ok(CallToolResult::success(vec![Content::text(text)]))
328    }
329
330    pub fn do_list_branches(&self, params: RepoParam) -> Result<CallToolResult, ErrorData> {
331        let entry = self.resolve(params.repo.as_deref()).map_err(|e| self.err(e))?;
332        let repo = self.open_repo(entry).map_err(|e| self.err(e))?;
333
334        let head_name = repo
335            .head_name()
336            .ok()
337            .flatten()
338            .map(|r| r.shorten().to_string());
339
340        let platform = repo
341            .references()
342            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
343
344        let local = platform
345            .local_branches()
346            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
347
348        let mut branches = Vec::new();
349        for reference in local.flatten() {
350            let name = reference.name().shorten().to_string();
351            let is_current = head_name.as_deref() == Some(name.as_str());
352            branches.push(serde_json::json!({
353                "name": name,
354                "current": is_current,
355            }));
356        }
357
358        // Also list remote branches
359        let remote = platform
360            .remote_branches()
361            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
362
363        let mut remote_branches = Vec::new();
364        for reference in remote.flatten() {
365            let name = reference.name().shorten().to_string();
366            remote_branches.push(serde_json::json!({
367                "name": name,
368            }));
369        }
370
371        let text = serde_json::to_string_pretty(&serde_json::json!({
372            "local": branches,
373            "remote": remote_branches,
374        }))
375        .unwrap_or_else(|_| "{}".to_string());
376        Ok(CallToolResult::success(vec![Content::text(text)]))
377    }
378
379    pub fn do_search_commits(&self, params: SearchParams) -> Result<CallToolResult, ErrorData> {
380        let entry = self.resolve(params.repo.as_deref()).map_err(|e| self.err(e))?;
381        let repo = self.open_repo(entry).map_err(|e| self.err(e))?;
382        let max = params.max_count.unwrap_or(self.max_log_entries);
383
384        let head = repo
385            .head_id()
386            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
387
388        let walk = repo
389            .rev_walk([head.detach()])
390            .all()
391            .map_err(|e| self.err(McpGitError::Git(e.to_string())))?;
392
393        let query_lower = params.query.to_lowercase();
394        let mut matches = Vec::new();
395
396        for info in walk.flatten() {
397            let commit = match info.object() {
398                Ok(c) => c,
399                Err(_) => continue,
400            };
401
402            let message = commit.message_raw_sloppy().to_string();
403            if message.to_lowercase().contains(&query_lower) {
404                let author_str = match commit.author() {
405                    Ok(a) => format!("{} <{}>", a.name, a.email),
406                    Err(_) => "unknown".to_string(),
407                };
408                matches.push(serde_json::json!({
409                    "sha": commit.id().to_string(),
410                    "author": author_str,
411                    "message": message.trim(),
412                }));
413            }
414
415            if matches.len() >= max as usize {
416                break;
417            }
418        }
419
420        let text = serde_json::to_string_pretty(&serde_json::json!({
421            "query": params.query,
422            "matches": matches,
423            "count": matches.len(),
424        }))
425        .unwrap_or_else(|_| "{}".to_string());
426        Ok(CallToolResult::success(vec![Content::text(text)]))
427    }
428}
429
430// -- MCP tool handlers (thin wrappers) --
431
432#[tool_router]
433impl McpGitServer {
434    #[tool(
435        name = "list_repos",
436        description = "List all connected Git repositories with their paths and current branch"
437    )]
438    async fn list_repos(&self) -> Result<CallToolResult, ErrorData> {
439        self.do_list_repos()
440    }
441
442    #[tool(
443        name = "log",
444        description = "Show commit history for a repository. Returns commit SHA, author, date, and message."
445    )]
446    async fn log(
447        &self,
448        Parameters(params): Parameters<LogParams>,
449    ) -> Result<CallToolResult, ErrorData> {
450        self.do_log(params)
451    }
452
453    #[tool(
454        name = "diff",
455        description = "Show the diff between two refs (commits, branches, or tags)"
456    )]
457    async fn diff(
458        &self,
459        Parameters(params): Parameters<DiffParams>,
460    ) -> Result<CallToolResult, ErrorData> {
461        self.do_diff(params)
462    }
463
464    #[tool(
465        name = "show_commit",
466        description = "Show details of a specific commit including message, author, date, and files changed"
467    )]
468    async fn show_commit(
469        &self,
470        Parameters(params): Parameters<CommitParams>,
471    ) -> Result<CallToolResult, ErrorData> {
472        self.do_show_commit(params)
473    }
474
475    #[tool(
476        name = "list_branches",
477        description = "List all branches in the repository with current branch marked"
478    )]
479    async fn list_branches(
480        &self,
481        Parameters(params): Parameters<RepoParam>,
482    ) -> Result<CallToolResult, ErrorData> {
483        self.do_list_branches(params)
484    }
485
486    #[tool(
487        name = "search_commits",
488        description = "Search commit messages for a given query string"
489    )]
490    async fn search_commits(
491        &self,
492        Parameters(params): Parameters<SearchParams>,
493    ) -> Result<CallToolResult, ErrorData> {
494        self.do_search_commits(params)
495    }
496}
497
498#[tool_handler]
499impl ServerHandler for McpGitServer {
500    fn get_info(&self) -> ServerInfo {
501        ServerInfo {
502            protocol_version: ProtocolVersion::V_2024_11_05,
503            capabilities: ServerCapabilities::builder().enable_tools().build(),
504            server_info: Implementation {
505                name: "mcp-git".to_string(),
506                version: env!("CARGO_PKG_VERSION").to_string(),
507                ..Default::default()
508            },
509            instructions: Some(
510                "Git repository server. Use list_repos to see connected repositories, \
511                 log to view commit history, diff to compare refs, show_commit for commit details, \
512                 list_branches to see branches, and search_commits to search commit messages."
513                    .to_string(),
514            ),
515        }
516    }
517}