Skip to main content

mcp_github/
server.rs

1use std::sync::Arc;
2
3use rmcp::handler::server::router::tool::ToolRouter;
4use rmcp::handler::server::wrapper::Parameters;
5use rmcp::model::*;
6use rmcp::{schemars, tool, tool_handler, tool_router, ServerHandler};
7use serde::Deserialize;
8
9use crate::error::McpGithubError;
10
11#[derive(Clone)]
12pub struct McpGithubServer {
13    github: Arc<octocrab::Octocrab>,
14    default_owner: Option<String>,
15    max_results: u32,
16    tool_router: ToolRouter<Self>,
17}
18
19// -- Tool parameter types --
20
21#[derive(Debug, Deserialize, schemars::JsonSchema)]
22pub struct OwnerParam {
23    #[schemars(description = "GitHub user or organization name")]
24    #[serde(default)]
25    pub owner: Option<String>,
26}
27
28#[derive(Debug, Deserialize, schemars::JsonSchema)]
29pub struct RepoParams {
30    #[schemars(description = "Repository owner (user or org)")]
31    #[serde(default)]
32    pub owner: Option<String>,
33
34    #[schemars(description = "Repository name")]
35    pub repo: String,
36}
37
38#[derive(Debug, Deserialize, schemars::JsonSchema)]
39pub struct ListIssuesParams {
40    #[schemars(description = "Repository owner (user or org)")]
41    #[serde(default)]
42    pub owner: Option<String>,
43
44    #[schemars(description = "Repository name")]
45    pub repo: String,
46
47    #[schemars(description = "Filter by state: open, closed, or all (default: open)")]
48    #[serde(default)]
49    pub state: Option<String>,
50
51    #[schemars(description = "Filter by comma-separated label names")]
52    #[serde(default)]
53    pub labels: Option<String>,
54
55    #[schemars(description = "Maximum number of results")]
56    #[serde(default)]
57    pub per_page: Option<u32>,
58}
59
60#[derive(Debug, Deserialize, schemars::JsonSchema)]
61pub struct IssueParams {
62    #[schemars(description = "Repository owner (user or org)")]
63    #[serde(default)]
64    pub owner: Option<String>,
65
66    #[schemars(description = "Repository name")]
67    pub repo: String,
68
69    #[schemars(description = "Issue number")]
70    pub issue_number: u64,
71}
72
73#[derive(Debug, Deserialize, schemars::JsonSchema)]
74pub struct ListPullsParams {
75    #[schemars(description = "Repository owner (user or org)")]
76    #[serde(default)]
77    pub owner: Option<String>,
78
79    #[schemars(description = "Repository name")]
80    pub repo: String,
81
82    #[schemars(description = "Filter by state: open, closed, or all (default: open)")]
83    #[serde(default)]
84    pub state: Option<String>,
85
86    #[schemars(description = "Maximum number of results")]
87    #[serde(default)]
88    pub per_page: Option<u32>,
89}
90
91#[derive(Debug, Deserialize, schemars::JsonSchema)]
92pub struct PullParams {
93    #[schemars(description = "Repository owner (user or org)")]
94    #[serde(default)]
95    pub owner: Option<String>,
96
97    #[schemars(description = "Repository name")]
98    pub repo: String,
99
100    #[schemars(description = "Pull request number")]
101    pub pr_number: u64,
102}
103
104#[derive(Debug, Deserialize, schemars::JsonSchema)]
105pub struct SearchCodeParams {
106    #[schemars(description = "Search query (GitHub code search syntax)")]
107    pub query: String,
108
109    #[schemars(description = "Scope search to this owner/org")]
110    #[serde(default)]
111    pub owner: Option<String>,
112
113    #[schemars(description = "Scope search to this repository")]
114    #[serde(default)]
115    pub repo: Option<String>,
116
117    #[schemars(description = "Maximum number of results")]
118    #[serde(default)]
119    pub per_page: Option<u32>,
120}
121
122#[derive(Debug, Deserialize, schemars::JsonSchema)]
123pub struct ActionsParams {
124    #[schemars(description = "Repository owner (user or org)")]
125    #[serde(default)]
126    pub owner: Option<String>,
127
128    #[schemars(description = "Repository name")]
129    pub repo: String,
130
131    #[schemars(description = "Filter by status: completed, in_progress, queued")]
132    #[serde(default)]
133    pub status: Option<String>,
134
135    #[schemars(description = "Maximum number of results")]
136    #[serde(default)]
137    pub per_page: Option<u32>,
138}
139
140impl McpGithubServer {
141    pub fn new(
142        github: octocrab::Octocrab,
143        default_owner: Option<String>,
144        max_results: u32,
145    ) -> Self {
146        Self {
147            github: Arc::new(github),
148            default_owner,
149            max_results,
150            tool_router: Self::tool_router(),
151        }
152    }
153
154    fn resolve_owner(&self, param: Option<&str>) -> Result<String, McpGithubError> {
155        param
156            .map(String::from)
157            .or_else(|| self.default_owner.clone())
158            .ok_or_else(|| {
159                McpGithubError::MissingParam(
160                    "owner is required (or set --owner default)".to_string(),
161                )
162            })
163    }
164
165    /// Cap per_page to 100 (GitHub API maximum) and safely cast to u8.
166    fn capped_per_page(&self, per_page: Option<u32>) -> u8 {
167        std::cmp::min(per_page.unwrap_or(self.max_results), 100) as u8
168    }
169
170    fn err(&self, e: McpGithubError) -> ErrorData {
171        e.to_mcp_error()
172    }
173}
174
175/// Format an issue/PR state as a lowercase string.
176fn format_state(state: &octocrab::models::IssueState) -> &'static str {
177    match state {
178        octocrab::models::IssueState::Open => "open",
179        octocrab::models::IssueState::Closed => "closed",
180        _ => "unknown",
181    }
182}
183
184/// Validate that a GitHub owner/repo name doesn't contain characters that
185/// could be used for URL injection in raw API routes.
186fn sanitize_github_name(name: &str, field: &str) -> Result<(), McpGithubError> {
187    if name.is_empty() {
188        return Err(McpGithubError::MissingParam(format!(
189            "{} must not be empty",
190            field
191        )));
192    }
193    for ch in ['/', '?', '#', '%', '\0', ' ', '\n', '\t'] {
194        if name.contains(ch) {
195            return Err(McpGithubError::MissingParam(format!(
196                "{} contains invalid character '{}'",
197                field, ch
198            )));
199        }
200    }
201    Ok(())
202}
203
// -- MCP tool handlers (each resolves parameters, calls the GitHub API, and returns JSON text) --
205
206#[tool_router]
207impl McpGithubServer {
208    #[tool(
209        name = "list_repos",
210        description = "List repositories for a user or organization"
211    )]
212    async fn list_repos(
213        &self,
214        Parameters(params): Parameters<OwnerParam>,
215    ) -> Result<CallToolResult, ErrorData> {
216        let owner = self
217            .resolve_owner(params.owner.as_deref())
218            .map_err(|e| self.err(e))?;
219
220        let per_page = self.capped_per_page(None);
221
222        let page = self
223            .github
224            .orgs(&owner)
225            .list_repos()
226            .per_page(per_page)
227            .send()
228            .await;
229
230        // If org fails, try as user
231        let repos = match page {
232            Ok(page) => page.items,
233            Err(_) => {
234                self.github
235                    .users(&owner)
236                    .repos()
237                    .per_page(per_page)
238                    .send()
239                    .await
240                    .map_err(|e| self.err(McpGithubError::GitHub(e)))?
241                    .items
242            }
243        };
244
245        let results: Vec<serde_json::Value> = repos
246            .iter()
247            .map(|r| {
248                serde_json::json!({
249                    "name": r.name,
250                    "full_name": r.full_name.as_deref().unwrap_or(""),
251                    "description": r.description.as_deref().unwrap_or(""),
252                    "language": r.language.as_ref().map(|l| l.to_string()).unwrap_or_default(),
253                    "stars": r.stargazers_count.unwrap_or(0),
254                    "forks": r.forks_count.unwrap_or(0),
255                    "private": r.private.unwrap_or(false),
256                })
257            })
258            .collect();
259
260        let text = serde_json::to_string_pretty(&serde_json::json!({
261            "owner": owner,
262            "repos": results,
263            "count": results.len(),
264        }))
265        .unwrap_or_else(|_| "{}".to_string());
266        Ok(CallToolResult::success(vec![Content::text(text)]))
267    }
268
269    #[tool(
270        name = "get_repo",
271        description = "Get repository info including description, stars, forks, language, and default branch"
272    )]
273    async fn get_repo(
274        &self,
275        Parameters(params): Parameters<RepoParams>,
276    ) -> Result<CallToolResult, ErrorData> {
277        let owner = self
278            .resolve_owner(params.owner.as_deref())
279            .map_err(|e| self.err(e))?;
280
281        let repo = self
282            .github
283            .repos(&owner, &params.repo)
284            .get()
285            .await
286            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
287
288        let text = serde_json::to_string_pretty(&serde_json::json!({
289            "name": repo.name,
290            "full_name": repo.full_name,
291            "description": repo.description,
292            "language": repo.language,
293            "default_branch": repo.default_branch,
294            "stars": repo.stargazers_count,
295            "forks": repo.forks_count,
296            "open_issues": repo.open_issues_count,
297            "private": repo.private,
298            "created_at": repo.created_at.map(|t| t.to_string()),
299            "updated_at": repo.updated_at.map(|t| t.to_string()),
300        }))
301        .unwrap_or_else(|_| "{}".to_string());
302        Ok(CallToolResult::success(vec![Content::text(text)]))
303    }
304
305    #[tool(
306        name = "list_issues",
307        description = "List issues in a repository, optionally filtered by state and labels"
308    )]
309    async fn list_issues(
310        &self,
311        Parameters(params): Parameters<ListIssuesParams>,
312    ) -> Result<CallToolResult, ErrorData> {
313        let owner = self
314            .resolve_owner(params.owner.as_deref())
315            .map_err(|e| self.err(e))?;
316
317        let per_page = self.capped_per_page(params.per_page);
318
319        let issue_handler = self.github.issues(&owner, &params.repo);
320        let mut request = issue_handler.list().per_page(per_page);
321
322        if let Some(ref state) = params.state {
323            request = match state.as_str() {
324                "open" => request.state(octocrab::params::State::Open),
325                "closed" => request.state(octocrab::params::State::Closed),
326                "all" => request.state(octocrab::params::State::All),
327                _ => request,
328            };
329        }
330
331        let label_list: Vec<String>;
332        if let Some(ref labels) = params.labels {
333            label_list = labels.split(',').map(|s| s.trim().to_string()).collect();
334            request = request.labels(&label_list);
335        }
336
337        let issues = request
338            .send()
339            .await
340            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
341
342        let results: Vec<serde_json::Value> = issues
343            .items
344            .iter()
345            .map(|i| {
346                let labels: Vec<String> = i.labels.iter().map(|l| l.name.clone()).collect();
347                serde_json::json!({
348                    "number": i.number,
349                    "title": i.title,
350                    "state": format_state(&i.state),
351                    "author": i.user.login,
352                    "labels": labels,
353                    "comments": i.comments,
354                    "created_at": i.created_at.to_string(),
355                })
356            })
357            .collect();
358
359        let text = serde_json::to_string_pretty(&serde_json::json!({
360            "repo": format!("{}/{}", owner, params.repo),
361            "issues": results,
362            "count": results.len(),
363        }))
364        .unwrap_or_else(|_| "{}".to_string());
365        Ok(CallToolResult::success(vec![Content::text(text)]))
366    }
367
368    #[tool(
369        name = "get_issue",
370        description = "Get issue details including body and comments"
371    )]
372    async fn get_issue(
373        &self,
374        Parameters(params): Parameters<IssueParams>,
375    ) -> Result<CallToolResult, ErrorData> {
376        let owner = self
377            .resolve_owner(params.owner.as_deref())
378            .map_err(|e| self.err(e))?;
379
380        let issue = self
381            .github
382            .issues(&owner, &params.repo)
383            .get(params.issue_number)
384            .await
385            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
386
387        // Fetch comments
388        let comments = self
389            .github
390            .issues(&owner, &params.repo)
391            .list_comments(params.issue_number)
392            .send()
393            .await
394            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
395
396        let comment_items: Vec<serde_json::Value> = comments
397            .items
398            .iter()
399            .map(|c| {
400                serde_json::json!({
401                    "author": c.user.login,
402                    "body": c.body.as_deref().unwrap_or(""),
403                    "created_at": c.created_at.to_string(),
404                })
405            })
406            .collect();
407
408        let labels: Vec<String> = issue.labels.iter().map(|l| l.name.clone()).collect();
409
410        let text = serde_json::to_string_pretty(&serde_json::json!({
411            "number": issue.number,
412            "title": issue.title,
413            "state": format_state(&issue.state),
414            "author": issue.user.login,
415            "labels": labels,
416            "body": issue.body.as_deref().unwrap_or(""),
417            "comments": comment_items,
418            "created_at": issue.created_at.to_string(),
419        }))
420        .unwrap_or_else(|_| "{}".to_string());
421        Ok(CallToolResult::success(vec![Content::text(text)]))
422    }
423
424    #[tool(
425        name = "list_pulls",
426        description = "List pull requests in a repository"
427    )]
428    async fn list_pulls(
429        &self,
430        Parameters(params): Parameters<ListPullsParams>,
431    ) -> Result<CallToolResult, ErrorData> {
432        let owner = self
433            .resolve_owner(params.owner.as_deref())
434            .map_err(|e| self.err(e))?;
435
436        let per_page = self.capped_per_page(params.per_page);
437
438        let pulls_handler = self.github.pulls(&owner, &params.repo);
439        let mut request = pulls_handler.list().per_page(per_page);
440
441        if let Some(ref state) = params.state {
442            request = match state.as_str() {
443                "open" => request.state(octocrab::params::State::Open),
444                "closed" => request.state(octocrab::params::State::Closed),
445                "all" => request.state(octocrab::params::State::All),
446                _ => request,
447            };
448        }
449
450        let pulls = request
451            .send()
452            .await
453            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
454
455        let results: Vec<serde_json::Value> = pulls
456            .items
457            .iter()
458            .map(|p| {
459                serde_json::json!({
460                    "number": p.number,
461                    "title": p.title.as_deref().unwrap_or(""),
462                    "state": p.state.as_ref().map(format_state).unwrap_or("unknown"),
463                    "author": p.user.as_ref().map(|u| u.login.as_str()).unwrap_or("unknown"),
464                    "head": p.head.ref_field,
465                    "base": p.base.ref_field,
466                    "draft": p.draft,
467                    "created_at": p.created_at.map(|t| t.to_string()),
468                })
469            })
470            .collect();
471
472        let text = serde_json::to_string_pretty(&serde_json::json!({
473            "repo": format!("{}/{}", owner, params.repo),
474            "pulls": results,
475            "count": results.len(),
476        }))
477        .unwrap_or_else(|_| "{}".to_string());
478        Ok(CallToolResult::success(vec![Content::text(text)]))
479    }
480
481    #[tool(
482        name = "get_pull",
483        description = "Get pull request details including review summary and changed files count"
484    )]
485    async fn get_pull(
486        &self,
487        Parameters(params): Parameters<PullParams>,
488    ) -> Result<CallToolResult, ErrorData> {
489        let owner = self
490            .resolve_owner(params.owner.as_deref())
491            .map_err(|e| self.err(e))?;
492
493        let pr = self
494            .github
495            .pulls(&owner, &params.repo)
496            .get(params.pr_number)
497            .await
498            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
499
500        let text = serde_json::to_string_pretty(&serde_json::json!({
501            "number": pr.number,
502            "title": pr.title.as_deref().unwrap_or(""),
503            "state": pr.state.as_ref().map(format_state).unwrap_or("unknown"),
504            "author": pr.user.as_ref().map(|u| u.login.as_str()).unwrap_or("unknown"),
505            "body": pr.body.as_deref().unwrap_or(""),
506            "head": pr.head.ref_field,
507            "base": pr.base.ref_field,
508            "draft": pr.draft,
509            "mergeable": pr.mergeable,
510            "additions": pr.additions,
511            "deletions": pr.deletions,
512            "changed_files": pr.changed_files,
513            "commits": pr.commits,
514            "created_at": pr.created_at.map(|t| t.to_string()),
515            "merged_at": pr.merged_at.map(|t| t.to_string()),
516        }))
517        .unwrap_or_else(|_| "{}".to_string());
518        Ok(CallToolResult::success(vec![Content::text(text)]))
519    }
520
521    #[tool(
522        name = "search_code",
523        description = "Search code across GitHub repositories using GitHub's code search syntax"
524    )]
525    async fn search_code(
526        &self,
527        Parameters(params): Parameters<SearchCodeParams>,
528    ) -> Result<CallToolResult, ErrorData> {
529        let mut query = params.query.clone();
530
531        // Scope to owner/repo if specified
532        if let Some(ref owner) = params.owner.as_ref().or(self.default_owner.as_ref()) {
533            if let Some(ref repo) = params.repo {
534                query = format!("{} repo:{}/{}", query, owner, repo);
535            } else {
536                query = format!("{} org:{}", query, owner);
537            }
538        }
539
540        let per_page = self.capped_per_page(params.per_page);
541
542        let results = self
543            .github
544            .search()
545            .code(&query)
546            .per_page(per_page)
547            .send()
548            .await
549            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
550
551        let items: Vec<serde_json::Value> = results
552            .items
553            .iter()
554            .map(|item| {
555                serde_json::json!({
556                    "name": item.name,
557                    "path": item.path,
558                    "repository": item.repository.full_name.as_deref().unwrap_or(""),
559                    "url": item.html_url,
560                })
561            })
562            .collect();
563
564        let text = serde_json::to_string_pretty(&serde_json::json!({
565            "query": params.query,
566            "results": items,
567            "count": items.len(),
568        }))
569        .unwrap_or_else(|_| "{}".to_string());
570        Ok(CallToolResult::success(vec![Content::text(text)]))
571    }
572
573    #[tool(
574        name = "list_actions_runs",
575        description = "List recent GitHub Actions workflow runs for a repository"
576    )]
577    async fn list_actions_runs(
578        &self,
579        Parameters(params): Parameters<ActionsParams>,
580    ) -> Result<CallToolResult, ErrorData> {
581        let owner = self
582            .resolve_owner(params.owner.as_deref())
583            .map_err(|e| self.err(e))?;
584
585        // Validate owner and repo to prevent URL injection in raw route
586        sanitize_github_name(&owner, "owner").map_err(|e| self.err(e))?;
587        sanitize_github_name(&params.repo, "repo").map_err(|e| self.err(e))?;
588
589        let per_page = self.capped_per_page(params.per_page);
590
591        let route = format!(
592            "/repos/{}/{}/actions/runs?per_page={}",
593            owner, params.repo, per_page
594        );
595
596        let response: serde_json::Value = self
597            .github
598            .get(route, None::<&()>)
599            .await
600            .map_err(|e| self.err(McpGithubError::GitHub(e)))?;
601
602        let runs = response
603            .get("workflow_runs")
604            .and_then(|r| r.as_array())
605            .map(|arr| {
606                arr.iter()
607                    .map(|run| {
608                        serde_json::json!({
609                            "id": run.get("id"),
610                            "name": run.get("name"),
611                            "status": run.get("status"),
612                            "conclusion": run.get("conclusion"),
613                            "branch": run.get("head_branch"),
614                            "event": run.get("event"),
615                            "created_at": run.get("created_at"),
616                        })
617                    })
618                    .collect::<Vec<_>>()
619            })
620            .unwrap_or_default();
621
622        let text = serde_json::to_string_pretty(&serde_json::json!({
623            "repo": format!("{}/{}", owner, params.repo),
624            "runs": runs,
625            "count": runs.len(),
626        }))
627        .unwrap_or_else(|_| "{}".to_string());
628        Ok(CallToolResult::success(vec![Content::text(text)]))
629    }
630}
631
632#[tool_handler]
633impl ServerHandler for McpGithubServer {
634    fn get_info(&self) -> ServerInfo {
635        ServerInfo {
636            protocol_version: ProtocolVersion::V_2024_11_05,
637            capabilities: ServerCapabilities::builder().enable_tools().build(),
638            server_info: Implementation {
639                name: "mcp-github".to_string(),
640                version: env!("CARGO_PKG_VERSION").to_string(),
641                ..Default::default()
642            },
643            instructions: Some(
644                "GitHub server. Use list_repos to see repositories, get_repo for repo details, \
645                 list_issues and get_issue for issues, list_pulls and get_pull for PRs, \
646                 search_code to search code, and list_actions_runs for CI/CD runs."
647                    .to_string(),
648            ),
649        }
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a server around a default octocrab client; no requests are sent.
    fn make_server(default_owner: Option<String>, max_results: u32) -> McpGithubServer {
        McpGithubServer::new(octocrab::Octocrab::default(), default_owner, max_results)
    }

    // Building `Octocrab::default()` needs a live Tokio runtime (it creates a
    // tower `Buffer`). Every test that constructs a server therefore uses
    // `#[tokio::test]`, even though nothing is awaited.

    #[tokio::test]
    async fn test_resolve_owner_with_param() {
        let server = make_server(None, 30);
        let owner = server.resolve_owner(Some("my-org")).unwrap();
        assert_eq!(owner, "my-org");
    }

    #[tokio::test]
    async fn test_resolve_owner_with_default() {
        let server = make_server(Some(String::from("default-org")), 30);
        let owner = server.resolve_owner(None).unwrap();
        assert_eq!(owner, "default-org");
    }

    #[tokio::test]
    async fn test_resolve_owner_param_overrides_default() {
        let server = make_server(Some(String::from("default-org")), 30);
        let owner = server.resolve_owner(Some("explicit-org")).unwrap();
        assert_eq!(owner, "explicit-org");
    }

    #[tokio::test]
    async fn test_resolve_owner_missing() {
        let server = make_server(None, 30);
        assert!(
            server.resolve_owner(None).is_err(),
            "no explicit and no default owner must be an error"
        );
    }

    #[tokio::test]
    async fn test_capped_per_page_default() {
        let server = make_server(None, 30);
        assert_eq!(server.capped_per_page(None), 30);
    }

    #[tokio::test]
    async fn test_capped_per_page_explicit() {
        let server = make_server(None, 30);
        assert_eq!(server.capped_per_page(Some(50)), 50);
    }

    #[tokio::test]
    async fn test_capped_per_page_caps_at_100() {
        let server = make_server(None, 30);
        for requested in [200, 1000] {
            assert_eq!(server.capped_per_page(Some(requested)), 100);
        }
    }

    #[tokio::test]
    async fn test_capped_per_page_max_results_capped() {
        // A large configured default is still clamped to GitHub's maximum.
        let server = make_server(None, 500);
        assert_eq!(server.capped_per_page(None), 100);
    }

    #[test]
    fn test_sanitize_github_name_valid() {
        for (name, field) in [("my-org", "owner"), ("user_name", "owner"), ("repo.name", "repo")] {
            assert!(sanitize_github_name(name, field).is_ok(), "{} should be accepted", name);
        }
    }

    #[test]
    fn test_sanitize_github_name_empty() {
        assert!(sanitize_github_name("", "owner").is_err());
    }

    #[test]
    fn test_sanitize_github_name_slash() {
        for name in ["owner/repo", "../etc"] {
            assert!(sanitize_github_name(name, "owner").is_err(), "{} should be rejected", name);
        }
    }

    #[test]
    fn test_sanitize_github_name_query() {
        for (name, field) in [("owner?evil=1", "owner"), ("repo#fragment", "repo")] {
            assert!(sanitize_github_name(name, field).is_err(), "{} should be rejected", name);
        }
    }

    #[test]
    fn test_sanitize_github_name_whitespace() {
        for name in ["my repo", "my\nrepo"] {
            assert!(sanitize_github_name(name, "repo").is_err(), "{:?} should be rejected", name);
        }
    }
}