Skip to main content

mcp_github/
server.rs

1use std::sync::Arc;
2
3use rmcp::handler::server::router::tool::ToolRouter;
4use rmcp::handler::server::wrapper::Parameters;
5use rmcp::model::*;
6use rmcp::{schemars, tool, tool_handler, tool_router, ServerHandler};
7use serde::Deserialize;
8
9use crate::error::McpGithubError;
10
11#[derive(Clone)]
12pub struct McpGithubServer {
13    github: Arc<octocrab::Octocrab>,
14    default_owner: Option<String>,
15    max_results: u32,
16    tool_router: ToolRouter<Self>,
17}
18
19// -- Tool parameter types --
20
21#[derive(Debug, Deserialize, schemars::JsonSchema)]
22pub struct OwnerParam {
23    #[schemars(description = "GitHub user or organization name")]
24    #[serde(default)]
25    pub owner: Option<String>,
26}
27
28#[derive(Debug, Deserialize, schemars::JsonSchema)]
29pub struct RepoParams {
30    #[schemars(description = "Repository owner (user or org)")]
31    #[serde(default)]
32    pub owner: Option<String>,
33
34    #[schemars(description = "Repository name")]
35    pub repo: String,
36}
37
38#[derive(Debug, Deserialize, schemars::JsonSchema)]
39pub struct ListIssuesParams {
40    #[schemars(description = "Repository owner (user or org)")]
41    #[serde(default)]
42    pub owner: Option<String>,
43
44    #[schemars(description = "Repository name")]
45    pub repo: String,
46
47    #[schemars(description = "Filter by state: open, closed, or all (default: open)")]
48    #[serde(default)]
49    pub state: Option<String>,
50
51    #[schemars(description = "Filter by comma-separated label names")]
52    #[serde(default)]
53    pub labels: Option<String>,
54
55    #[schemars(description = "Maximum number of results")]
56    #[serde(default)]
57    pub per_page: Option<u32>,
58}
59
60#[derive(Debug, Deserialize, schemars::JsonSchema)]
61pub struct IssueParams {
62    #[schemars(description = "Repository owner (user or org)")]
63    #[serde(default)]
64    pub owner: Option<String>,
65
66    #[schemars(description = "Repository name")]
67    pub repo: String,
68
69    #[schemars(description = "Issue number")]
70    pub issue_number: u64,
71}
72
73#[derive(Debug, Deserialize, schemars::JsonSchema)]
74pub struct ListPullsParams {
75    #[schemars(description = "Repository owner (user or org)")]
76    #[serde(default)]
77    pub owner: Option<String>,
78
79    #[schemars(description = "Repository name")]
80    pub repo: String,
81
82    #[schemars(description = "Filter by state: open, closed, or all (default: open)")]
83    #[serde(default)]
84    pub state: Option<String>,
85
86    #[schemars(description = "Maximum number of results")]
87    #[serde(default)]
88    pub per_page: Option<u32>,
89}
90
91#[derive(Debug, Deserialize, schemars::JsonSchema)]
92pub struct PullParams {
93    #[schemars(description = "Repository owner (user or org)")]
94    #[serde(default)]
95    pub owner: Option<String>,
96
97    #[schemars(description = "Repository name")]
98    pub repo: String,
99
100    #[schemars(description = "Pull request number")]
101    pub pr_number: u64,
102}
103
104#[derive(Debug, Deserialize, schemars::JsonSchema)]
105pub struct SearchCodeParams {
106    #[schemars(description = "Search query (GitHub code search syntax)")]
107    pub query: String,
108
109    #[schemars(description = "Scope search to this owner/org")]
110    #[serde(default)]
111    pub owner: Option<String>,
112
113    #[schemars(description = "Scope search to this repository")]
114    #[serde(default)]
115    pub repo: Option<String>,
116
117    #[schemars(description = "Maximum number of results")]
118    #[serde(default)]
119    pub per_page: Option<u32>,
120}
121
122#[derive(Debug, Deserialize, schemars::JsonSchema)]
123pub struct ActionsParams {
124    #[schemars(description = "Repository owner (user or org)")]
125    #[serde(default)]
126    pub owner: Option<String>,
127
128    #[schemars(description = "Repository name")]
129    pub repo: String,
130
131    #[schemars(description = "Filter by status: completed, in_progress, queued")]
132    #[serde(default)]
133    pub status: Option<String>,
134
135    #[schemars(description = "Maximum number of results")]
136    #[serde(default)]
137    pub per_page: Option<u32>,
138}
139
140#[derive(Debug, Deserialize, schemars::JsonSchema)]
141pub struct ListCommitsParams {
142    #[schemars(description = "Repository owner (user or org)")]
143    #[serde(default)]
144    pub owner: Option<String>,
145
146    #[schemars(description = "Repository name")]
147    pub repo: String,
148
149    #[schemars(description = "Branch or tag name (default: repo's default branch)")]
150    #[serde(default)]
151    pub sha: Option<String>,
152
153    #[schemars(description = "Filter commits by author (GitHub username or email)")]
154    #[serde(default)]
155    pub author: Option<String>,
156
157    #[schemars(description = "Maximum number of results")]
158    #[serde(default)]
159    pub per_page: Option<u32>,
160}
161
162#[derive(Debug, Deserialize, schemars::JsonSchema)]
163pub struct CommitRefParams {
164    #[schemars(description = "Repository owner (user or org)")]
165    #[serde(default)]
166    pub owner: Option<String>,
167
168    #[schemars(description = "Repository name")]
169    pub repo: String,
170
171    #[schemars(description = "Commit SHA, branch name, or tag")]
172    #[serde(rename = "ref")]
173    pub sha: String,
174}
175
176#[derive(Debug, Deserialize, schemars::JsonSchema)]
177pub struct RepoPageParams {
178    #[schemars(description = "Repository owner (user or org)")]
179    #[serde(default)]
180    pub owner: Option<String>,
181
182    #[schemars(description = "Repository name")]
183    pub repo: String,
184
185    #[schemars(description = "Maximum number of results")]
186    #[serde(default)]
187    pub per_page: Option<u32>,
188}
189
190#[derive(Debug, Deserialize, schemars::JsonSchema)]
191pub struct FileContentsParams {
192    #[schemars(description = "Repository owner (user or org)")]
193    #[serde(default)]
194    pub owner: Option<String>,
195
196    #[schemars(description = "Repository name")]
197    pub repo: String,
198
199    #[schemars(description = "File path within the repository")]
200    pub path: String,
201
202    #[schemars(description = "Git ref (branch, tag, or SHA). Defaults to the repo's default branch")]
203    #[serde(default, rename = "ref")]
204    pub git_ref: Option<String>,
205}
206
207impl McpGithubServer {
208    pub fn new(
209        github: octocrab::Octocrab,
210        default_owner: Option<String>,
211        max_results: u32,
212    ) -> Self {
213        Self {
214            github: Arc::new(github),
215            default_owner,
216            max_results,
217            tool_router: Self::tool_router(),
218        }
219    }
220
221    fn resolve_owner(&self, param: Option<&str>) -> Result<String, McpGithubError> {
222        param
223            .map(String::from)
224            .or_else(|| self.default_owner.clone())
225            .ok_or_else(|| {
226                McpGithubError::MissingParam(
227                    "owner is required (or set --owner default)".to_string(),
228                )
229            })
230    }
231
232    /// Cap per_page to 100 (GitHub API maximum) and safely cast to u8.
233    fn capped_per_page(&self, per_page: Option<u32>) -> u8 {
234        std::cmp::min(per_page.unwrap_or(self.max_results), 100) as u8
235    }
236
237    fn err(&self, e: McpGithubError) -> ErrorData {
238        e.to_mcp_error()
239    }
240}
241
242/// Format an issue/PR state as a lowercase string.
243fn format_state(state: &octocrab::models::IssueState) -> &'static str {
244    match state {
245        octocrab::models::IssueState::Open => "open",
246        octocrab::models::IssueState::Closed => "closed",
247        _ => "unknown",
248    }
249}
250
251/// Validate that a GitHub owner/repo name doesn't contain characters that
252/// could be used for URL injection in raw API routes.
253fn sanitize_github_name(name: &str, field: &str) -> Result<(), McpGithubError> {
254    if name.is_empty() {
255        return Err(McpGithubError::MissingParam(format!(
256            "{} must not be empty",
257            field
258        )));
259    }
260    for ch in ['/', '?', '#', '%', '\0', ' ', '\n', '\t'] {
261        if name.contains(ch) {
262            return Err(McpGithubError::MissingParam(format!(
263                "{} contains invalid character '{}'",
264                field, ch
265            )));
266        }
267    }
268    Ok(())
269}
270
271/// Validate a value for use in URL paths or query params. Unlike
272/// `sanitize_github_name`, this allows slashes (for branch names like
273/// `feature/foo` or file paths like `src/main.rs`).
274fn sanitize_url_value(value: &str, field: &str) -> Result<(), McpGithubError> {
275    if value.is_empty() {
276        return Err(McpGithubError::MissingParam(format!(
277            "{} must not be empty",
278            field
279        )));
280    }
281    for ch in ['?', '#', '&', '\0', '\n', '\r', '\t'] {
282        if value.contains(ch) {
283            return Err(McpGithubError::MissingParam(format!(
284                "{} contains invalid character",
285                field
286            )));
287        }
288    }
289    Ok(())
290}
291
// -- MCP tool handlers: each resolves its parameters, calls the GitHub API directly,
// -- and returns a pretty-printed JSON summary --
293
294#[tool_router]
295impl McpGithubServer {
296    #[tool(
297        name = "list_repos",
298        description = "List repositories for a user or organization"
299    )]
300    async fn list_repos(
301        &self,
302        Parameters(params): Parameters<OwnerParam>,
303    ) -> Result<CallToolResult, ErrorData> {
304        let owner = self
305            .resolve_owner(params.owner.as_deref())
306            .map_err(|e| self.err(e))?;
307
308        let per_page = self.capped_per_page(None);
309
310        let page = self
311            .github
312            .orgs(&owner)
313            .list_repos()
314            .per_page(per_page)
315            .send()
316            .await;
317
318        // If org fails, try as user
319        let repos = match page {
320            Ok(page) => page.items,
321            Err(_) => {
322                self.github
323                    .users(&owner)
324                    .repos()
325                    .per_page(per_page)
326                    .send()
327                    .await
328                    .map_err(|e| self.err(McpGithubError::GitHub(e)))?
329                    .items
330            }
331        };
332
333        let results: Vec<serde_json::Value> = repos
334            .iter()
335            .map(|r| {
336                serde_json::json!({
337                    "name": r.name,
338                    "full_name": r.full_name.as_deref().unwrap_or(""),
339                    "description": r.description.as_deref().unwrap_or(""),
340                    "language": r.language.as_ref().map(|l| l.to_string()).unwrap_or_default(),
341                    "stars": r.stargazers_count.unwrap_or(0),
342                    "forks": r.forks_count.unwrap_or(0),
343                    "private": r.private.unwrap_or(false),
344                })
345            })
346            .collect();
347
348        let text = serde_json::to_string_pretty(&serde_json::json!({
349            "owner": owner,
350            "repos": results,
351            "count": results.len(),
352        }))
353        .unwrap_or_else(|_| "{}".to_string());
354        Ok(CallToolResult::success(vec![Content::text(text)]))
355    }
356
357    #[tool(
358        name = "get_repo",
359        description = "Get repository info including description, stars, forks, language, and default branch"
360    )]
361    async fn get_repo(
362        &self,
363        Parameters(params): Parameters<RepoParams>,
364    ) -> Result<CallToolResult, ErrorData> {
365        let owner = self
366            .resolve_owner(params.owner.as_deref())
367            .map_err(|e| self.err(e))?;
368
369        let repo = self
370            .github
371            .repos(&owner, &params.repo)
372            .get()
373            .await
374            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
375
376        let text = serde_json::to_string_pretty(&serde_json::json!({
377            "name": repo.name,
378            "full_name": repo.full_name,
379            "description": repo.description,
380            "language": repo.language,
381            "default_branch": repo.default_branch,
382            "stars": repo.stargazers_count,
383            "forks": repo.forks_count,
384            "open_issues": repo.open_issues_count,
385            "private": repo.private,
386            "created_at": repo.created_at.map(|t| t.to_string()),
387            "updated_at": repo.updated_at.map(|t| t.to_string()),
388        }))
389        .unwrap_or_else(|_| "{}".to_string());
390        Ok(CallToolResult::success(vec![Content::text(text)]))
391    }
392
393    #[tool(
394        name = "list_issues",
395        description = "List issues in a repository, optionally filtered by state and labels"
396    )]
397    async fn list_issues(
398        &self,
399        Parameters(params): Parameters<ListIssuesParams>,
400    ) -> Result<CallToolResult, ErrorData> {
401        let owner = self
402            .resolve_owner(params.owner.as_deref())
403            .map_err(|e| self.err(e))?;
404
405        let per_page = self.capped_per_page(params.per_page);
406
407        let issue_handler = self.github.issues(&owner, &params.repo);
408        let mut request = issue_handler.list().per_page(per_page);
409
410        if let Some(ref state) = params.state {
411            request = match state.as_str() {
412                "open" => request.state(octocrab::params::State::Open),
413                "closed" => request.state(octocrab::params::State::Closed),
414                "all" => request.state(octocrab::params::State::All),
415                _ => request,
416            };
417        }
418
419        let label_list: Vec<String>;
420        if let Some(ref labels) = params.labels {
421            label_list = labels.split(',').map(|s| s.trim().to_string()).collect();
422            request = request.labels(&label_list);
423        }
424
425        let issues = request
426            .send()
427            .await
428            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
429
430        let results: Vec<serde_json::Value> = issues
431            .items
432            .iter()
433            .map(|i| {
434                let labels: Vec<String> = i.labels.iter().map(|l| l.name.clone()).collect();
435                serde_json::json!({
436                    "number": i.number,
437                    "title": i.title,
438                    "state": format_state(&i.state),
439                    "author": i.user.login,
440                    "labels": labels,
441                    "comments": i.comments,
442                    "created_at": i.created_at.to_string(),
443                })
444            })
445            .collect();
446
447        let text = serde_json::to_string_pretty(&serde_json::json!({
448            "repo": format!("{}/{}", owner, params.repo),
449            "issues": results,
450            "count": results.len(),
451        }))
452        .unwrap_or_else(|_| "{}".to_string());
453        Ok(CallToolResult::success(vec![Content::text(text)]))
454    }
455
456    #[tool(
457        name = "get_issue",
458        description = "Get issue details including body and comments"
459    )]
460    async fn get_issue(
461        &self,
462        Parameters(params): Parameters<IssueParams>,
463    ) -> Result<CallToolResult, ErrorData> {
464        let owner = self
465            .resolve_owner(params.owner.as_deref())
466            .map_err(|e| self.err(e))?;
467
468        let issue = self
469            .github
470            .issues(&owner, &params.repo)
471            .get(params.issue_number)
472            .await
473            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
474
475        // Fetch comments
476        let comments = self
477            .github
478            .issues(&owner, &params.repo)
479            .list_comments(params.issue_number)
480            .send()
481            .await
482            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
483
484        let comment_items: Vec<serde_json::Value> = comments
485            .items
486            .iter()
487            .map(|c| {
488                serde_json::json!({
489                    "author": c.user.login,
490                    "body": c.body.as_deref().unwrap_or(""),
491                    "created_at": c.created_at.to_string(),
492                })
493            })
494            .collect();
495
496        let labels: Vec<String> = issue.labels.iter().map(|l| l.name.clone()).collect();
497
498        let text = serde_json::to_string_pretty(&serde_json::json!({
499            "number": issue.number,
500            "title": issue.title,
501            "state": format_state(&issue.state),
502            "author": issue.user.login,
503            "labels": labels,
504            "body": issue.body.as_deref().unwrap_or(""),
505            "comments": comment_items,
506            "created_at": issue.created_at.to_string(),
507        }))
508        .unwrap_or_else(|_| "{}".to_string());
509        Ok(CallToolResult::success(vec![Content::text(text)]))
510    }
511
512    #[tool(
513        name = "list_pulls",
514        description = "List pull requests in a repository"
515    )]
516    async fn list_pulls(
517        &self,
518        Parameters(params): Parameters<ListPullsParams>,
519    ) -> Result<CallToolResult, ErrorData> {
520        let owner = self
521            .resolve_owner(params.owner.as_deref())
522            .map_err(|e| self.err(e))?;
523
524        let per_page = self.capped_per_page(params.per_page);
525
526        let pulls_handler = self.github.pulls(&owner, &params.repo);
527        let mut request = pulls_handler.list().per_page(per_page);
528
529        if let Some(ref state) = params.state {
530            request = match state.as_str() {
531                "open" => request.state(octocrab::params::State::Open),
532                "closed" => request.state(octocrab::params::State::Closed),
533                "all" => request.state(octocrab::params::State::All),
534                _ => request,
535            };
536        }
537
538        let pulls = request
539            .send()
540            .await
541            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
542
543        let results: Vec<serde_json::Value> = pulls
544            .items
545            .iter()
546            .map(|p| {
547                serde_json::json!({
548                    "number": p.number,
549                    "title": p.title.as_deref().unwrap_or(""),
550                    "state": p.state.as_ref().map(format_state).unwrap_or("unknown"),
551                    "author": p.user.as_ref().map(|u| u.login.as_str()).unwrap_or("unknown"),
552                    "head": p.head.ref_field,
553                    "base": p.base.ref_field,
554                    "draft": p.draft,
555                    "created_at": p.created_at.map(|t| t.to_string()),
556                })
557            })
558            .collect();
559
560        let text = serde_json::to_string_pretty(&serde_json::json!({
561            "repo": format!("{}/{}", owner, params.repo),
562            "pulls": results,
563            "count": results.len(),
564        }))
565        .unwrap_or_else(|_| "{}".to_string());
566        Ok(CallToolResult::success(vec![Content::text(text)]))
567    }
568
569    #[tool(
570        name = "get_pull",
571        description = "Get pull request details including review summary and changed files count"
572    )]
573    async fn get_pull(
574        &self,
575        Parameters(params): Parameters<PullParams>,
576    ) -> Result<CallToolResult, ErrorData> {
577        let owner = self
578            .resolve_owner(params.owner.as_deref())
579            .map_err(|e| self.err(e))?;
580
581        let pr = self
582            .github
583            .pulls(&owner, &params.repo)
584            .get(params.pr_number)
585            .await
586            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
587
588        let text = serde_json::to_string_pretty(&serde_json::json!({
589            "number": pr.number,
590            "title": pr.title.as_deref().unwrap_or(""),
591            "state": pr.state.as_ref().map(format_state).unwrap_or("unknown"),
592            "author": pr.user.as_ref().map(|u| u.login.as_str()).unwrap_or("unknown"),
593            "body": pr.body.as_deref().unwrap_or(""),
594            "head": pr.head.ref_field,
595            "base": pr.base.ref_field,
596            "draft": pr.draft,
597            "mergeable": pr.mergeable,
598            "additions": pr.additions,
599            "deletions": pr.deletions,
600            "changed_files": pr.changed_files,
601            "commits": pr.commits,
602            "created_at": pr.created_at.map(|t| t.to_string()),
603            "merged_at": pr.merged_at.map(|t| t.to_string()),
604        }))
605        .unwrap_or_else(|_| "{}".to_string());
606        Ok(CallToolResult::success(vec![Content::text(text)]))
607    }
608
609    #[tool(
610        name = "search_code",
611        description = "Search code across GitHub repositories using GitHub's code search syntax"
612    )]
613    async fn search_code(
614        &self,
615        Parameters(params): Parameters<SearchCodeParams>,
616    ) -> Result<CallToolResult, ErrorData> {
617        let mut query = params.query.clone();
618
619        // Scope to owner/repo if specified
620        if let Some(ref owner) = params.owner.as_ref().or(self.default_owner.as_ref()) {
621            if let Some(ref repo) = params.repo {
622                query = format!("{} repo:{}/{}", query, owner, repo);
623            } else {
624                query = format!("{} org:{}", query, owner);
625            }
626        }
627
628        let per_page = self.capped_per_page(params.per_page);
629
630        let results = self
631            .github
632            .search()
633            .code(&query)
634            .per_page(per_page)
635            .send()
636            .await
637            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
638
639        let items: Vec<serde_json::Value> = results
640            .items
641            .iter()
642            .map(|item| {
643                serde_json::json!({
644                    "name": item.name,
645                    "path": item.path,
646                    "repository": item.repository.full_name.as_deref().unwrap_or(""),
647                    "url": item.html_url,
648                })
649            })
650            .collect();
651
652        let text = serde_json::to_string_pretty(&serde_json::json!({
653            "query": params.query,
654            "results": items,
655            "count": items.len(),
656        }))
657        .unwrap_or_else(|_| "{}".to_string());
658        Ok(CallToolResult::success(vec![Content::text(text)]))
659    }
660
661    #[tool(
662        name = "list_actions_runs",
663        description = "List recent GitHub Actions workflow runs for a repository"
664    )]
665    async fn list_actions_runs(
666        &self,
667        Parameters(params): Parameters<ActionsParams>,
668    ) -> Result<CallToolResult, ErrorData> {
669        let owner = self
670            .resolve_owner(params.owner.as_deref())
671            .map_err(|e| self.err(e))?;
672
673        // Validate owner and repo to prevent URL injection in raw route
674        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
675        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
676
677        let per_page = self.capped_per_page(params.per_page);
678
679        let route = format!(
680            "/repos/{}/{}/actions/runs?per_page={}",
681            owner, params.repo, per_page
682        );
683
684        let response: serde_json::Value = self
685            .github
686            .get(route, None::<&()>)
687            .await
688            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
689
690        let runs = response
691            .get("workflow_runs")
692            .and_then(|r| r.as_array())
693            .map(|arr| {
694                arr.iter()
695                    .map(|run| {
696                        serde_json::json!({
697                            "id": run.get("id"),
698                            "name": run.get("name"),
699                            "status": run.get("status"),
700                            "conclusion": run.get("conclusion"),
701                            "branch": run.get("head_branch"),
702                            "event": run.get("event"),
703                            "created_at": run.get("created_at"),
704                        })
705                    })
706                    .collect::<Vec<_>>()
707            })
708            .unwrap_or_default();
709
710        let text = serde_json::to_string_pretty(&serde_json::json!({
711            "repo": format!("{}/{}", owner, params.repo),
712            "runs": runs,
713            "count": runs.len(),
714        }))
715        .unwrap_or_else(|_| "{}".to_string());
716        Ok(CallToolResult::success(vec![Content::text(text)]))
717    }
718
719    #[tool(
720        name = "list_commits",
721        description = "List commits on a branch or tag"
722    )]
723    async fn list_commits(
724        &self,
725        Parameters(params): Parameters<ListCommitsParams>,
726    ) -> Result<CallToolResult, ErrorData> {
727        let owner = self
728            .resolve_owner(params.owner.as_deref())
729            .map_err(|e| self.err(e))?;
730        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
731        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
732
733        let per_page = self.capped_per_page(params.per_page);
734        let mut route = format!(
735            "/repos/{}/{}/commits?per_page={}",
736            owner, params.repo, per_page
737        );
738        if let Some(ref sha) = params.sha {
739            sanitize_url_value(sha, "sha").map_err(|e| self.err(e))?;
740            route.push_str(&format!("&sha={}", sha));
741        }
742        if let Some(ref author) = params.author {
743            sanitize_url_value(author, "author").map_err(|e| self.err(e))?;
744            route.push_str(&format!("&author={}", author));
745        }
746
747        let response: Vec<serde_json::Value> = self
748            .github
749            .get(&route, None::<&()>)
750            .await
751            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
752
753        let commits: Vec<serde_json::Value> = response
754            .iter()
755            .map(|c| {
756                serde_json::json!({
757                    "sha": c.get("sha"),
758                    "message": c.pointer("/commit/message"),
759                    "author": c.pointer("/commit/author/name"),
760                    "author_login": c.pointer("/author/login"),
761                    "date": c.pointer("/commit/author/date"),
762                })
763            })
764            .collect();
765
766        let text = serde_json::to_string_pretty(&serde_json::json!({
767            "repo": format!("{}/{}", owner, params.repo),
768            "commits": commits,
769            "count": commits.len(),
770        }))
771        .unwrap_or_else(|_| "{}".to_string());
772        Ok(CallToolResult::success(vec![Content::text(text)]))
773    }
774
775    #[tool(
776        name = "get_commit",
777        description = "Get full commit details including changed files"
778    )]
779    async fn get_commit(
780        &self,
781        Parameters(params): Parameters<CommitRefParams>,
782    ) -> Result<CallToolResult, ErrorData> {
783        let owner = self
784            .resolve_owner(params.owner.as_deref())
785            .map_err(|e| self.err(e))?;
786        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
787        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
788        sanitize_url_value(&params.sha, "ref").map_err(|e| self.err(e))?;
789
790        let route = format!(
791            "/repos/{}/{}/commits/{}",
792            owner, params.repo, params.sha
793        );
794
795        let c: serde_json::Value = self
796            .github
797            .get(&route, None::<&()>)
798            .await
799            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
800
801        let files = c
802            .get("files")
803            .and_then(|f| f.as_array())
804            .map(|arr| {
805                arr.iter()
806                    .map(|f| {
807                        serde_json::json!({
808                            "filename": f.get("filename"),
809                            "status": f.get("status"),
810                            "additions": f.get("additions"),
811                            "deletions": f.get("deletions"),
812                            "changes": f.get("changes"),
813                        })
814                    })
815                    .collect::<Vec<_>>()
816            })
817            .unwrap_or_default();
818
819        let file_count = files.len();
820
821        let text = serde_json::to_string_pretty(&serde_json::json!({
822            "sha": c.get("sha"),
823            "message": c.pointer("/commit/message"),
824            "author": c.pointer("/commit/author/name"),
825            "author_login": c.pointer("/author/login"),
826            "date": c.pointer("/commit/author/date"),
827            "parents": c.get("parents").and_then(|p| p.as_array()).map(|arr| {
828                arr.iter().filter_map(|p| p.get("sha")).collect::<Vec<_>>()
829            }),
830            "stats": c.get("stats"),
831            "files": files,
832            "file_count": file_count,
833        }))
834        .unwrap_or_else(|_| "{}".to_string());
835        Ok(CallToolResult::success(vec![Content::text(text)]))
836    }
837
838    #[tool(
839        name = "list_branches",
840        description = "List branches in a repository"
841    )]
842    async fn list_branches(
843        &self,
844        Parameters(params): Parameters<RepoPageParams>,
845    ) -> Result<CallToolResult, ErrorData> {
846        let owner = self
847            .resolve_owner(params.owner.as_deref())
848            .map_err(|e| self.err(e))?;
849        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
850        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
851
852        let per_page = self.capped_per_page(params.per_page);
853        let route = format!(
854            "/repos/{}/{}/branches?per_page={}",
855            owner, params.repo, per_page
856        );
857
858        let response: Vec<serde_json::Value> = self
859            .github
860            .get(&route, None::<&()>)
861            .await
862            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
863
864        let branches: Vec<serde_json::Value> = response
865            .iter()
866            .map(|b| {
867                serde_json::json!({
868                    "name": b.get("name"),
869                    "sha": b.pointer("/commit/sha"),
870                    "protected": b.get("protected"),
871                })
872            })
873            .collect();
874
875        let text = serde_json::to_string_pretty(&serde_json::json!({
876            "repo": format!("{}/{}", owner, params.repo),
877            "branches": branches,
878            "count": branches.len(),
879        }))
880        .unwrap_or_else(|_| "{}".to_string());
881        Ok(CallToolResult::success(vec![Content::text(text)]))
882    }
883
884    #[tool(
885        name = "get_file_contents",
886        description = "Get file content from a repository at a specific ref"
887    )]
888    async fn get_file_contents(
889        &self,
890        Parameters(params): Parameters<FileContentsParams>,
891    ) -> Result<CallToolResult, ErrorData> {
892        let owner = self
893            .resolve_owner(params.owner.as_deref())
894            .map_err(|e| self.err(e))?;
895        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
896        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
897        sanitize_url_value(&params.path, "path").map_err(|e| self.err(e))?;
898
899        let mut route = format!(
900            "/repos/{}/{}/contents/{}",
901            owner, params.repo, params.path
902        );
903        if let Some(ref git_ref) = params.git_ref {
904            sanitize_url_value(git_ref, "ref").map_err(|e| self.err(e))?;
905            route.push_str(&format!("?ref={}", git_ref));
906        }
907
908        let response: serde_json::Value = self
909            .github
910            .get(&route, None::<&()>)
911            .await
912            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
913
914        // Decode base64 content (GitHub returns base64 with embedded newlines)
915        let content = response
916            .get("content")
917            .and_then(|c| c.as_str())
918            .map(|c| {
919                let cleaned: String = c.chars().filter(|ch| !ch.is_whitespace()).collect();
920                use base64::Engine;
921                base64::engine::general_purpose::STANDARD
922                    .decode(&cleaned)
923                    .ok()
924                    .and_then(|bytes| String::from_utf8(bytes).ok())
925                    .unwrap_or_else(|| "[binary content]".to_string())
926            })
927            .unwrap_or_default();
928
929        let text = serde_json::to_string_pretty(&serde_json::json!({
930            "path": response.get("path"),
931            "name": response.get("name"),
932            "size": response.get("size"),
933            "encoding": response.get("encoding"),
934            "content": content,
935            "sha": response.get("sha"),
936        }))
937        .unwrap_or_else(|_| "{}".to_string());
938        Ok(CallToolResult::success(vec![Content::text(text)]))
939    }
940
941    #[tool(
942        name = "list_releases",
943        description = "List releases for a repository"
944    )]
945    async fn list_releases(
946        &self,
947        Parameters(params): Parameters<RepoPageParams>,
948    ) -> Result<CallToolResult, ErrorData> {
949        let owner = self
950            .resolve_owner(params.owner.as_deref())
951            .map_err(|e| self.err(e))?;
952        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
953        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
954
955        let per_page = self.capped_per_page(params.per_page);
956        let route = format!(
957            "/repos/{}/{}/releases?per_page={}",
958            owner, params.repo, per_page
959        );
960
961        let response: Vec<serde_json::Value> = self
962            .github
963            .get(&route, None::<&()>)
964            .await
965            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
966
967        let releases: Vec<serde_json::Value> = response
968            .iter()
969            .map(|r| {
970                let assets = r
971                    .get("assets")
972                    .and_then(|a| a.as_array())
973                    .map(|a| a.len())
974                    .unwrap_or(0);
975                serde_json::json!({
976                    "tag": r.get("tag_name"),
977                    "name": r.get("name"),
978                    "author": r.pointer("/author/login"),
979                    "prerelease": r.get("prerelease"),
980                    "draft": r.get("draft"),
981                    "published_at": r.get("published_at"),
982                    "asset_count": assets,
983                })
984            })
985            .collect();
986
987        let text = serde_json::to_string_pretty(&serde_json::json!({
988            "repo": format!("{}/{}", owner, params.repo),
989            "releases": releases,
990            "count": releases.len(),
991        }))
992        .unwrap_or_else(|_| "{}".to_string());
993        Ok(CallToolResult::success(vec![Content::text(text)]))
994    }
995
996    #[tool(
997        name = "list_tags",
998        description = "List tags in a repository"
999    )]
1000    async fn list_tags(
1001        &self,
1002        Parameters(params): Parameters<RepoPageParams>,
1003    ) -> Result<CallToolResult, ErrorData> {
1004        let owner = self
1005            .resolve_owner(params.owner.as_deref())
1006            .map_err(|e| self.err(e))?;
1007        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
1008        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
1009
1010        let per_page = self.capped_per_page(params.per_page);
1011        let route = format!(
1012            "/repos/{}/{}/tags?per_page={}",
1013            owner, params.repo, per_page
1014        );
1015
1016        let response: Vec<serde_json::Value> = self
1017            .github
1018            .get(&route, None::<&()>)
1019            .await
1020            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
1021
1022        let tags: Vec<serde_json::Value> = response
1023            .iter()
1024            .map(|t| {
1025                serde_json::json!({
1026                    "name": t.get("name"),
1027                    "sha": t.pointer("/commit/sha"),
1028                })
1029            })
1030            .collect();
1031
1032        let text = serde_json::to_string_pretty(&serde_json::json!({
1033            "repo": format!("{}/{}", owner, params.repo),
1034            "tags": tags,
1035            "count": tags.len(),
1036        }))
1037        .unwrap_or_else(|_| "{}".to_string());
1038        Ok(CallToolResult::success(vec![Content::text(text)]))
1039    }
1040}
1041
1042#[tool_handler]
1043impl ServerHandler for McpGithubServer {
1044    fn get_info(&self) -> ServerInfo {
1045        ServerInfo {
1046            protocol_version: ProtocolVersion::V_2024_11_05,
1047            capabilities: ServerCapabilities::builder().enable_tools().build(),
1048            server_info: Implementation {
1049                name: "mcp-github".to_string(),
1050                version: env!("CARGO_PKG_VERSION").to_string(),
1051                ..Default::default()
1052            },
1053            instructions: Some(
1054                "GitHub server. Use list_repos to see repositories, get_repo for repo details, \
1055                 list_issues/get_issue for issues, list_pulls/get_pull for PRs, \
1056                 search_code to search code, list_actions_runs for CI/CD runs, \
1057                 list_commits/get_commit for commit history, list_branches for branches, \
1058                 get_file_contents to read files, list_releases for releases, \
1059                 and list_tags for tags."
1060                    .to_string(),
1061            ),
1062        }
1063    }
1064}
1065
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a server backed by a default (unauthenticated) Octocrab client.
    fn make_server(default_owner: Option<String>, max_results: u32) -> McpGithubServer {
        McpGithubServer::new(octocrab::Octocrab::default(), default_owner, max_results)
    }

    // `Octocrab::default()` builds a tower `Buffer`, which must be created inside
    // a Tokio runtime. Every test that constructs a server therefore runs under
    // `#[tokio::test]`, even though nothing is awaited.

    #[tokio::test]
    async fn test_resolve_owner_with_param() {
        let server = make_server(None, 30);
        assert_eq!(server.resolve_owner(Some("my-org")).unwrap(), "my-org");
    }

    #[tokio::test]
    async fn test_resolve_owner_with_default() {
        let server = make_server(Some(String::from("default-org")), 30);
        assert_eq!(server.resolve_owner(None).unwrap(), "default-org");
    }

    #[tokio::test]
    async fn test_resolve_owner_param_overrides_default() {
        let server = make_server(Some(String::from("default-org")), 30);
        assert_eq!(server.resolve_owner(Some("explicit-org")).unwrap(), "explicit-org");
    }

    #[tokio::test]
    async fn test_resolve_owner_missing() {
        let server = make_server(None, 30);
        assert!(server.resolve_owner(None).is_err());
    }

    #[tokio::test]
    async fn test_capped_per_page_default() {
        let server = make_server(None, 30);
        assert_eq!(server.capped_per_page(None), 30);
    }

    #[tokio::test]
    async fn test_capped_per_page_explicit() {
        let server = make_server(None, 30);
        assert_eq!(server.capped_per_page(Some(50)), 50);
    }

    #[tokio::test]
    async fn test_capped_per_page_caps_at_100() {
        let server = make_server(None, 30);
        for requested in [200, 1000] {
            assert_eq!(server.capped_per_page(Some(requested)), 100);
        }
    }

    #[tokio::test]
    async fn test_capped_per_page_max_results_capped() {
        // A configured max_results above the API limit must still be clamped to 100.
        let server = make_server(None, 500);
        assert_eq!(server.capped_per_page(None), 100);
    }

    #[test]
    fn test_sanitize_github_name_valid() {
        let cases = [("my-org", "owner"), ("user_name", "owner"), ("repo.name", "repo")];
        for (value, field) in cases {
            assert!(
                sanitize_github_name(value, field).is_ok(),
                "expected {value:?} to be accepted as {field}"
            );
        }
    }

    #[test]
    fn test_sanitize_github_name_empty() {
        assert!(sanitize_github_name("", "owner").is_err());
    }

    #[test]
    fn test_sanitize_github_name_slash() {
        for value in ["owner/repo", "../etc"] {
            assert!(
                sanitize_github_name(value, "owner").is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_sanitize_github_name_query() {
        let cases = [("owner?evil=1", "owner"), ("repo#fragment", "repo")];
        for (value, field) in cases {
            assert!(
                sanitize_github_name(value, field).is_err(),
                "expected {value:?} to be rejected as {field}"
            );
        }
    }

    #[test]
    fn test_sanitize_github_name_whitespace() {
        for value in ["my repo", "my\nrepo"] {
            assert!(
                sanitize_github_name(value, "repo").is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_sanitize_url_value_valid() {
        let cases = [
            ("main", "sha"),
            ("feature/foo", "sha"),
            ("src/main.rs", "path"),
            ("user@example.com", "author"),
        ];
        for (value, field) in cases {
            assert!(
                sanitize_url_value(value, field).is_ok(),
                "expected {value:?} to be accepted as {field}"
            );
        }
    }

    #[test]
    fn test_sanitize_url_value_empty() {
        assert!(sanitize_url_value("", "sha").is_err());
    }

    #[test]
    fn test_sanitize_url_value_dangerous_chars() {
        let cases = [
            ("main?evil=1", "sha"),
            ("main#frag", "sha"),
            ("val&other=1", "author"),
            ("val\0x", "path"),
            ("val\nx", "path"),
        ];
        for (value, field) in cases {
            assert!(
                sanitize_url_value(value, field).is_err(),
                "expected {value:?} to be rejected as {field}"
            );
        }
    }

    #[test]
    fn test_sanitize_url_value_allows_slashes() {
        // Unlike sanitize_github_name, slashes are allowed for branch names and file paths.
        for (value, field) in [("feature/my-branch", "sha"), ("src/lib/utils.rs", "path")] {
            assert!(
                sanitize_url_value(value, field).is_ok(),
                "expected {value:?} to be accepted as {field}"
            );
        }
    }
}