use aptu_coder::{
CodeAnalyzer,
logging::McpLoggingLayer,
metrics::{MetricEvent, MetricsSender, MetricsWriter},
};
use rmcp::serve_server;
use rmcp::transport::stdio;
use rmcp::transport::streamable_http_server::session::never::NeverSessionManager;
use rmcp::transport::streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService};
use rustls::crypto::aws_lc_rs;
use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as TokioMutex;
use tokio_util::sync::CancellationToken;
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
mod otel;
/// Serves the MCP endpoint over streamable HTTP at `http://127.0.0.1:{port}/mcp`.
///
/// The listener binds the loopback interface only. The rmcp service is configured stateless
/// (`with_stateful_mode(false)` plus `NeverSessionManager`), answers with JSON responses, and has
/// SSE keep-alive and retry hints set to `None`. Each request gets its own clone of `analyzer`.
///
/// Shutdown is graceful: a background task waits for Ctrl-C (and, on Unix, SIGTERM) and then
/// cancels a root token. That stops `axum::serve` and, through a child token, the MCP service.
///
/// # Errors
///
/// Returns an error if the TCP listener cannot bind, or if `axum::serve` fails.
///
/// # Panics
///
/// On Unix, the spawned signal task panics if the SIGTERM handler cannot be installed.
async fn run_http(analyzer: CodeAnalyzer, port: u16) -> Result<(), Box<dyn std::error::Error>> {
    let shutdown = CancellationToken::new();

    // Signal watcher: whichever termination signal arrives first cancels the root token.
    let trigger = shutdown.clone();
    tokio::spawn(async move {
        #[cfg(not(unix))]
        {
            let _ = tokio::signal::ctrl_c().await;
        }
        #[cfg(unix)]
        {
            use tokio::signal::unix::{SignalKind, signal};
            let mut terminate =
                signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
            tokio::select! {
                _ = tokio::signal::ctrl_c() => {},
                _ = terminate.recv() => {},
            }
        }
        trigger.cancel();
    });

    // The service holds a child token, so cancelling `shutdown` propagates to it.
    let mcp_service: StreamableHttpService<CodeAnalyzer, NeverSessionManager> =
        StreamableHttpService::new(
            move || Ok(analyzer.clone()),
            Arc::new(NeverSessionManager::default()),
            StreamableHttpServerConfig::default()
                .with_stateful_mode(false)
                .with_json_response(true)
                .with_sse_keep_alive(None)
                .with_sse_retry(None)
                .with_cancellation_token(shutdown.child_token()),
        );
    let app = axum::Router::new().nest_service("/mcp", mcp_service);

    let bind_addr = format!("127.0.0.1:{port}");
    let listener = tokio::net::TcpListener::bind(bind_addr.as_str()).await?;
    eprintln!("Listening on http://{bind_addr}/mcp");

    axum::serve(listener, app)
        .with_graceful_shutdown(async move { shutdown.cancelled().await })
        .await?;
    Ok(())
}
/// Parses a TCP port number supplied via `--port` or `APTU_CODER_PORT`.
///
/// Accepts any decimal value that fits in a `u16`, except `0`.
///
/// # Errors
///
/// Returns a human-readable message if `s` is not a valid `u16`, or if it is `0`. The message
/// is meant to be appended after the option/variable name by the caller.
fn parse_port(s: &str) -> Result<u16, String> {
    let port = s
        .parse::<u16>()
        .map_err(|_| format!("{s:?} is not a valid u16 value"))?;
    // Port 0 would ask the OS for an arbitrary port; it is rejected outright.
    if port == 0 {
        return Err("must be a non-zero u16 value".to_string());
    }
    Ok(port)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
aws_lc_rs::default_provider()
.install_default()
.expect("failed to install rustls CryptoProvider");
let mut port: Option<u16> = None;
let mut args = std::env::args();
while let Some(arg) = args.next() {
match arg.as_str() {
"--version" => {
println!("{}", env!("CARGO_PKG_VERSION"));
return Ok(());
}
"--port" => match args.next() {
Some(val) => match parse_port(&val) {
Ok(p) => port = Some(p),
Err(msg) => {
eprintln!("error: --port {msg}");
std::process::exit(1);
}
},
None => {
eprintln!("error: --port requires a value");
std::process::exit(1);
}
},
_ => {}
}
}
if port.is_none()
&& let Ok(val) = std::env::var("APTU_CODER_PORT")
{
match parse_port(&val) {
Ok(p) => port = Some(p),
Err(msg) => {
eprintln!("error: APTU_CODER_PORT {msg}");
std::process::exit(1);
}
}
}
let otel_provider = otel::init_otel();
let log_provider = otel::init_log_appender();
let meter_provider = otel::init_meter();
if let Err(e) = aptu_coder::metrics::migrate_legacy_metrics_dir() {
tracing::warn!("Failed to migrate legacy metrics directory: {e}");
}
let peer = Arc::new(TokioMutex::new(None));
let log_level_filter = Arc::new(Mutex::new(LevelFilter::WARN));
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel();
let mcp_logging_layer = McpLoggingLayer::new(event_tx, log_level_filter.clone());
use opentelemetry::trace::TracerProvider as _;
let otel_trace_layer = otel_provider
.as_ref()
.map(|p| tracing_opentelemetry::layer().with_tracer(p.tracer("aptu-coder")));
let otel_log_layer = log_provider
.as_ref()
.map(opentelemetry_appender_tracing::layer::OpenTelemetryTracingBridge::new);
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
.with(mcp_logging_layer)
.with(otel_trace_layer)
.with(otel_log_layer)
.init();
let (metrics_tx, metrics_rx) = tokio::sync::mpsc::unbounded_channel::<MetricEvent>();
tokio::spawn(MetricsWriter::new(metrics_rx, None).run());
let analyzer = CodeAnalyzer::new(peer, log_level_filter, event_rx, MetricsSender(metrics_tx));
if let Some(p) = port {
run_http(analyzer, p).await?;
} else {
let (stdin, stdout) = stdio();
let service = serve_server(analyzer, (stdin, stdout)).await?;
service.waiting().await?;
}
if let Some(provider) = otel_provider
&& let Err(e) = provider.shutdown()
{
tracing::warn!("Failed to shutdown OpenTelemetry trace provider: {e}");
}
if let Some(log_prov) = log_provider
&& let Err(e) = log_prov.shutdown()
{
tracing::warn!("Failed to shutdown OpenTelemetry log provider: {e}");
}
if let Some(meter_prov) = meter_provider
&& let Err(e) = meter_prov.shutdown()
{
tracing::warn!("Failed to shutdown OpenTelemetry meter provider: {e}");
}
Ok(())
}