1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use hyper_util::{
6 rt::{TokioExecutor, TokioIo},
7 server::conn::auto::Builder,
8 service::TowerToHyperService,
9};
10use nu_protocol::{ShellError, engine::EngineState};
11use rmcp::{
12 ServiceExt,
13 transport::{
14 stdio,
15 streamable_http_server::{
16 StreamableHttpServerConfig, StreamableHttpService,
17 session::local::{LocalSessionManager, SessionConfig},
18 },
19 },
20};
21use server::NushellMcpServer;
22use tokio::runtime::Runtime;
23use tokio::sync::RwLock;
24use tokio_util::sync::CancellationToken;
25use tracing_subscriber::EnvFilter;
26
27mod evaluation;
28mod history;
29mod server;
30
/// Transport over which the MCP server talks to its client.
///
/// [`McpTransport::Stdio`] is the default variant.
#[derive(Clone, Debug, Default)]
pub enum McpTransport {
    /// Serve over the process's standard input/output streams.
    #[default]
    Stdio,
    /// Serve over HTTP, listening on TCP port `port`.
    Http { port: u16 },
}
43
44pub fn initialize_mcp_server(
45 mut engine_state: EngineState,
46 transport: McpTransport,
47) -> Result<(), ShellError> {
48 tracing_subscriber::fmt()
49 .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
50 .with_writer(std::io::stderr)
51 .with_ansi(false)
52 .init();
53
54 #[cfg(unix)]
63 {
64 let _ = nix::unistd::setsid();
66 }
67
68 engine_state.is_mcp = true;
71
72 tracing::info!(?transport, "Starting MCP server");
73 let runtime = Runtime::new().map_err(|e| ShellError::GenericError {
74 error: format!("Could not instantiate tokio: {e}"),
75 msg: "".into(),
76 span: None,
77 help: None,
78 inner: vec![],
79 })?;
80
81 runtime.block_on(async {
82 let result = match transport {
83 McpTransport::Stdio => run_stdio_server(engine_state).await,
84 McpTransport::Http { port } => run_http_server(engine_state, port).await,
85 };
86 if let Err(e) = result {
87 tracing::error!("Error running MCP server: {:?}", e);
88 }
89 });
90 Ok(())
91}
92
93async fn run_stdio_server(engine_state: EngineState) -> Result<(), Box<dyn std::error::Error>> {
94 NushellMcpServer::new(engine_state)
95 .serve(stdio())
96 .await
97 .inspect_err(|e| {
98 tracing::error!("serving error: {:?}", e);
99 })?
100 .waiting()
101 .await?;
102 Ok(())
103}
104
// --- Session manager settings (used to build `SessionConfig`) ---

/// Value used for the session config's `channel_capacity`.
const SESSION_CHANNEL_CAPACITY: usize = 16;

/// Value used for the session config's `keep_alive`: 30 minutes.
const SESSION_KEEP_ALIVE: Duration = Duration::from_secs(30 * 60);

// --- Streamable HTTP settings (used to build `StreamableHttpServerConfig`) ---

/// Value used for the HTTP server config's `sse_retry`: 3 seconds.
const SSE_RETRY: Duration = Duration::from_secs(3);

/// Value used for the HTTP server config's `sse_keep_alive`: 15 seconds.
const SSE_KEEP_ALIVE: Duration = Duration::from_secs(15);
116
117async fn run_http_server(
118 engine_state: EngineState,
119 port: u16,
120) -> Result<(), Box<dyn std::error::Error>> {
121 let engine_state = Arc::new(engine_state);
122
123 let cancellation_token = CancellationToken::new();
125
126 let session_manager = Arc::new(LocalSessionManager {
127 sessions: RwLock::new(HashMap::new()),
128 session_config: SessionConfig {
129 channel_capacity: SESSION_CHANNEL_CAPACITY,
130 keep_alive: Some(SESSION_KEEP_ALIVE),
131 },
132 });
133
134 let service = TowerToHyperService::new(StreamableHttpService::new(
135 {
136 let engine_state = engine_state.clone();
137 move || Ok(NushellMcpServer::new((*engine_state).clone()))
138 },
139 session_manager,
140 StreamableHttpServerConfig {
141 sse_keep_alive: Some(SSE_KEEP_ALIVE),
142 sse_retry: Some(SSE_RETRY),
143 stateful_mode: true,
144 cancellation_token: cancellation_token.clone(),
145 },
146 ));
147
148 let addr = format!("0.0.0.0:{port}");
149 let listener = tokio::net::TcpListener::bind(&addr).await?;
150 tracing::info!("MCP HTTP server listening on http://{addr}");
151 eprintln!("MCP HTTP server listening on http://{addr}");
152
153 loop {
154 let io = tokio::select! {
155 _ = tokio::signal::ctrl_c() => {
156 tracing::info!("Received Ctrl-C, shutting down...");
157 cancellation_token.cancel();
158 break;
159 }
160 accept = listener.accept() => {
161 TokioIo::new(accept?.0)
162 }
163 };
164 let service = service.clone();
165 tokio::spawn(async move {
166 let _ = Builder::new(TokioExecutor::new())
167 .serve_connection(io, service)
168 .await;
169 });
170 }
171 Ok(())
172}