gitai/server/
serve.rs

//! MCP server implementation
//!
//! Runs `gitai` as a Model Context Protocol (MCP) server so that MCP-compatible
//! clients can call its tools directly, over either the StdIO or the SSE transport.
5
6use crate::config::Config as GitAIConfig;
7use crate::debug;
8use crate::git::GitRepo;
9use crate::server::config::{MCPServerConfig, MCPTransportType};
10use crate::server::tools::PilotHandler;
11
12use anyhow::{Context, Result};
13use rmcp::ServiceExt;
14use rmcp::transport::sse_server::SseServer;
15use std::net::SocketAddr;
16use std::sync::Arc;
17use tokio::io::{stdin, stdout};
18
19/// Serve the MCP server with the provided configuration
20pub async fn serve(config: MCPServerConfig) -> Result<()> {
21    // Configure logging based on transport type and dev mode
22    if config.dev_mode {
23        // In dev mode, set up appropriate logging
24        let log_path = format!("gitai-mcp-{}.log", std::process::id());
25        if let Err(e) = crate::logger::set_log_file(&log_path) {
26            // For non-stdio transports, we can print this error
27            if config.transport != MCPTransportType::StdIO {
28                eprintln!("Failed to set up log file: {e}");
29            }
30            // Continue without file logging
31        }
32
33        // For stdio transport, we must NEVER log to stdout
34        if config.transport == MCPTransportType::StdIO {
35            crate::logger::set_log_to_stdout(false);
36        } else {
37            crate::logger::set_log_to_stdout(true);
38        }
39
40        crate::logger::enable_logging();
41    }
42
43    debug!("Starting MCP server with config: {:?}", config);
44
45    // Display configuration info if not using stdio transport
46    if config.transport != MCPTransportType::StdIO {
47        use crate::ui;
48        ui::print_info(&format!(
49            "Starting gitai MCP server with {:?} transport",
50            config.transport
51        ));
52        if let Some(port) = config.port {
53            ui::print_info(&format!("Port: {port}"));
54        }
55        if let Some(addr) = &config.listen_address {
56            ui::print_info(&format!("Listening on: {addr}"));
57        }
58        ui::print_info(&format!(
59            "Development mode: {}",
60            if config.dev_mode {
61                "Enabled"
62            } else {
63                "Disabled"
64            }
65        ));
66    }
67
68    // Initialize GitRepo for use with tools
69    let git_repo = Arc::new(GitRepo::new_from_url(None)?);
70    debug!(
71        "Initialized Git repository at: {}",
72        git_repo.repo_path().display()
73    );
74
75    let pilot_config = GitAIConfig::load()?;
76    debug!("Loaded gitai configuration");
77
78    // Create the handler with necessary dependencies
79    let handler = PilotHandler::new(git_repo, pilot_config);
80
81    // Start the appropriate transport
82    match config.transport {
83        MCPTransportType::StdIO => serve_stdio(handler, config.dev_mode).await,
84        MCPTransportType::SSE => {
85            // Get socket address for the server
86            let socket_addr = get_socket_addr(&config)?;
87            serve_sse(handler, socket_addr).await
88        }
89    }
90}
91
92/// Start the MCP server using `StdIO` transport
93async fn serve_stdio(handler: PilotHandler, _dev_mode: bool) -> Result<()> {
94    debug!("Starting MCP server with StdIO transport");
95
96    let transport = (stdin(), stdout());
97
98    let server = handler.serve(transport).await?;
99
100    // Wait for the server to finish
101    debug!("MCP server initialized, waiting for completion");
102    let quit_reason = server.waiting().await?;
103    debug!("MCP server finished: {:?}", quit_reason);
104
105    Ok(())
106}
107
108/// Start the MCP server using SSE transport
109async fn serve_sse(handler: PilotHandler, socket_addr: SocketAddr) -> Result<()> {
110    debug!("Starting MCP server with SSE transport on {}", socket_addr);
111
112    // Create and start the SSE server
113    let server = SseServer::serve(socket_addr).await?;
114
115    // Set up the service with our handler
116    let control = server.with_service(move || {
117        // Return a clone of the handler directly as it implements ServerHandler
118        handler.clone()
119    });
120
121    // Wait for Ctrl+C signal
122    debug!("SSE server initialized, waiting for interrupt signal");
123    tokio::signal::ctrl_c()
124        .await
125        .context("Failed to listen for ctrl+c signal")?;
126
127    // Cancel the server gracefully
128    debug!("Interrupt signal received, shutting down SSE server");
129    control.cancel();
130
131    Ok(())
132}
133
134/// Helper function to get a socket address from the configuration
135fn get_socket_addr(config: &MCPServerConfig) -> Result<SocketAddr> {
136    // Get listen address, or use default
137    let listen_address = config.listen_address.as_deref().unwrap_or("127.0.0.1");
138    let port = config.port.context("Port is required for SSE transport")?;
139
140    // Parse the socket address
141    let socket_addr: SocketAddr = format!("{listen_address}:{port}")
142        .parse()
143        .context("Failed to parse socket address")?;
144
145    Ok(socket_addr)
146}