gitai/server/tools/
mod.rs

//! MCP tools module
//!
//! Implementations of the MCP tools that expose Pilot functionality to MCP
//! clients, plus the server handler that lists and dispatches them.
5
6pub mod changelog;
7pub mod codereview;
8pub mod commit;
9pub mod pr;
10pub mod releasenotes;
11pub mod utils;
12
13use crate::config::Config as PilotConfig;
14use crate::debug;
15use crate::git::GitRepo;
16use crate::server::tools::utils::PilotTool;
17
18use rmcp::ErrorData as Error;
19use rmcp::RoleServer;
20use rmcp::model::{
21    CallToolRequestParam, CallToolResult, ListToolsResult, PaginatedRequestParam,
22    ServerCapabilities, Tool,
23};
24use rmcp::service::NotificationContext;
25use rmcp::service::RequestContext;
26use rmcp::{ServerHandler, model::ServerInfo};
27
28use serde_json::{Map, Value};
29use std::future::Future;
30use std::path::PathBuf;
31use std::sync::Arc;
32use std::sync::Mutex;
33
34// Re-export all tools for easy importing
35pub use self::changelog::ChangelogTool;
36pub use self::codereview::CodeReviewTool;
37pub use self::commit::CommitTool;
38pub use self::pr::PrTool;
39pub use self::releasenotes::ReleaseNotesTool;
40
41// Define tools for Pilot toolbox
42#[derive(Debug)]
43pub enum PilotTools {
44    ReleaseNotesTool(ReleaseNotesTool),
45    ChangelogTool(ChangelogTool),
46    CommitTool(CommitTool),
47    CodeReviewTool(CodeReviewTool),
48    PrTool(PrTool),
49}
50
51impl PilotTools {
52    /// Get all tools available
53    pub fn get_tools() -> Vec<Tool> {
54        vec![
55            ReleaseNotesTool::get_tool_definition(),
56            ChangelogTool::get_tool_definition(),
57            CommitTool::get_tool_definition(),
58            CodeReviewTool::get_tool_definition(),
59            PrTool::get_tool_definition(),
60        ]
61    }
62
63    /// Try to convert a parameter map into a `PilotTools` enum
64    pub fn try_from(params: Map<String, Value>) -> Result<Self, Error> {
65        // Check the tool name and convert to the appropriate variant
66        let tool_name = params
67            .get("name")
68            .and_then(|v| v.as_str())
69            .ok_or_else(|| Error::invalid_params("Tool name not specified", None))?;
70
71        match tool_name {
72            "gitai_release_notes" => {
73                // Convert params to ReleaseNotesTool
74                let tool: ReleaseNotesTool = serde_json::from_value(Value::Object(params))
75                    .map_err(|e| Error::invalid_params(format!("Invalid parameters: {e}"), None))?;
76                Ok(PilotTools::ReleaseNotesTool(tool))
77            }
78            "gitai_changelog" => {
79                // Convert params to ChangelogTool
80                let tool: ChangelogTool = serde_json::from_value(Value::Object(params))
81                    .map_err(|e| Error::invalid_params(format!("Invalid parameters: {e}"), None))?;
82                Ok(PilotTools::ChangelogTool(tool))
83            }
84            "gitai_commit" => {
85                // Convert params to CommitTool
86                let tool: CommitTool = serde_json::from_value(Value::Object(params))
87                    .map_err(|e| Error::invalid_params(format!("Invalid parameters: {e}"), None))?;
88                Ok(PilotTools::CommitTool(tool))
89            }
90            "gitai_review" => {
91                // Convert params to CodeReviewTool
92                let tool: CodeReviewTool = serde_json::from_value(Value::Object(params))
93                    .map_err(|e| Error::invalid_params(format!("Invalid parameters: {e}"), None))?;
94                Ok(PilotTools::CodeReviewTool(tool))
95            }
96            "gitai_pr" => {
97                // Convert params to PrTool
98                let tool: PrTool = serde_json::from_value(Value::Object(params))
99                    .map_err(|e| Error::invalid_params(format!("Invalid parameters: {e}"), None))?;
100                Ok(PilotTools::PrTool(tool))
101            }
102            _ => Err(Error::invalid_params(
103                format!("Unknown tool: {tool_name}"),
104                None,
105            )),
106        }
107    }
108}
109
110pub fn handle_tool_error(e: &anyhow::Error) -> Error {
111    Error::invalid_params(format!("Tool execution failed: {e}"), None)
112}
113
114/// The main handler for Pilot, providing all MCP tools
115#[derive(Clone)]
116pub struct PilotHandler {
117    /// Git repository instance
118    pub git_repo: Arc<GitRepo>,
119    /// Pilot configuration
120    pub config: PilotConfig,
121    /// Workspace roots registered by the client
122    pub workspace_roots: Arc<Mutex<Vec<PathBuf>>>,
123}
124
125impl PilotHandler {
126    /// Create new handler with the provided dependencies
127    pub fn new(git_repo: Arc<GitRepo>, config: PilotConfig) -> Self {
128        Self {
129            git_repo,
130            config,
131            workspace_roots: Arc::new(Mutex::new(Vec::new())),
132        }
133    }
134
135    /// Get the current workspace root, if available
136    pub fn get_workspace_root(&self) -> Option<PathBuf> {
137        let roots = self
138            .workspace_roots
139            .lock()
140            .expect("Failed to lock workspace roots mutex");
141        // Use the first workspace root if available
142        roots.first().cloned()
143    }
144}
145
146impl ServerHandler for PilotHandler {
147    fn get_info(&self) -> ServerInfo {
148        ServerInfo {
149            instructions: Some("`gitai` is an AI-powered Git workflow assistant. You can use it to generate commit messages, review code, create changelogs and release notes.".to_string()),
150            capabilities: ServerCapabilities::builder()
151                .enable_tools()
152                .build(),
153            ..Default::default()
154        }
155    }
156
157    // Handle notification when client workspace roots change
158    fn on_roots_list_changed(
159        &self,
160        _context: NotificationContext<RoleServer>,
161    ) -> impl Future<Output = ()> + Send + '_ {
162        debug!("Client workspace roots changed");
163        async move {
164            // Access and update workspace roots
165            let roots = self
166                .workspace_roots
167                .lock()
168                .expect("Failed to lock workspace roots mutex");
169
170            // If we have a workspace root, log it
171            if let Some(root) = roots.first() {
172                debug!("Primary workspace root: {}", root.display());
173            } else {
174                debug!("No workspace roots provided by client");
175            }
176
177            // If this is a development log, print more information
178            if roots.len() > 1 {
179                for (i, root) in roots.iter().skip(1).enumerate() {
180                    debug!("Additional workspace root {}: {}", i + 1, root.display());
181                }
182            }
183        }
184    }
185
186    async fn list_tools(
187        &self,
188        _: Option<PaginatedRequestParam>,
189        _: RequestContext<RoleServer>,
190    ) -> Result<ListToolsResult, Error> {
191        // Use our custom method to get all tools
192        let tools = PilotTools::get_tools();
193
194        Ok(ListToolsResult {
195            next_cursor: None,
196            tools,
197        })
198    }
199
200    async fn call_tool(
201        &self,
202        request: CallToolRequestParam,
203        _: RequestContext<RoleServer>,
204    ) -> Result<CallToolResult, Error> {
205        // Get the arguments as a Map
206        let args = match &request.arguments {
207            Some(args) => args.clone(),
208            None => {
209                return Err(Error::invalid_params(
210                    String::from("Missing arguments"),
211                    None,
212                ));
213            }
214        };
215
216        // Add the tool name to the parameters
217        let mut params = args.clone();
218        params.insert("name".to_string(), Value::String(request.name.to_string()));
219
220        let tool_params = PilotTools::try_from(params)?;
221
222        // Make a clone of the repository path before executing the tool
223        // This prevents git2 objects from crossing async boundaries
224        let git_repo_path = self.git_repo.repo_path().clone();
225
226        // Clone config to avoid sharing it across async boundaries
227        let config = self.config.clone();
228
229        // Create a new git repo instance - handle any errors here before async code
230        let git_repo = match GitRepo::new(&git_repo_path) {
231            Ok(repo) => Arc::new(repo),
232            Err(e) => return Err(handle_tool_error(&e)),
233        };
234
235        // Use the PilotTool trait to execute any tool without matching on specific types
236        match tool_params {
237            PilotTools::ReleaseNotesTool(tool) => tool
238                .execute(git_repo.clone(), config.clone())
239                .await
240                .map_err(|e| handle_tool_error(&e)),
241            PilotTools::ChangelogTool(tool) => tool
242                .execute(git_repo.clone(), config.clone())
243                .await
244                .map_err(|e| handle_tool_error(&e)),
245            PilotTools::CommitTool(tool) => tool
246                .execute(git_repo.clone(), config.clone())
247                .await
248                .map_err(|e| handle_tool_error(&e)),
249            PilotTools::CodeReviewTool(tool) => tool
250                .execute(git_repo.clone(), config)
251                .await
252                .map_err(|e| handle_tool_error(&e)),
253            PilotTools::PrTool(tool) => tool
254                .execute(git_repo, config)
255                .await
256                .map_err(|e| handle_tool_error(&e)),
257        }
258    }
259}