use std::path::PathBuf;
use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::http::{Response, StatusCode, header};
use axum::response::IntoResponse;
use axum::routing::get;
use rust_embed::RustEmbed;
use tokio::sync::{RwLock, broadcast};
use tower_http::cors::CorsLayer;
use crate::graph::CodeGraph;
use super::{api, ws};
/// Shared state handed to the router's handlers and middleware.
///
/// `Clone` is required because axum's `State` extractor (used by
/// `auth_middleware`) and `from_fn_with_state` take the state by value; the
/// `Arc`-wrapped fields are shared between clones rather than copied.
#[derive(Clone)]
pub struct AppState {
// In-memory code graph; the file-watcher task spawned by `serve` takes the
// write lock to apply incremental updates.
pub graph: Arc<RwLock<CodeGraph>>,
// Root directory that was indexed at startup.
pub project_root: PathBuf,
// Broadcast sender for websocket clients; the watcher sends
// `GRAPH_UPDATED_MSG` after each processed file event.
pub ws_tx: broadcast::Sender<String>,
// Token that `auth_middleware` expects as `Authorization: Bearer <token>`.
pub auth_token: String,
// Embedding vector index; `None` when no index could be loaded at startup.
#[cfg(feature = "rag")]
pub vector_store: Arc<RwLock<Option<crate::rag::vector_store::VectorStore>>>,
// Embedding engine; `None` when it failed to initialize at startup.
#[cfg(feature = "rag")]
pub embedding_engine: Arc<Option<crate::rag::embedding::EmbeddingEngine>>,
// Chat session store. NOTE(review): presumably per-conversation history —
// confirm in `crate::rag::session`.
#[cfg(feature = "rag")]
pub session_store: Arc<tokio::sync::Mutex<crate::rag::session::SessionStore>>,
// Currently selected LLM provider (Claude API key or Ollama host/model).
#[cfg(feature = "rag")]
pub auth_state: Arc<RwLock<crate::rag::auth::AuthState>>,
// OAuth PKCE state. NOTE(review): presumably the pending verifier used by the
// `/api/auth/oauth/*` handlers — confirm in `crate::web::api::auth`.
#[cfg(feature = "rag")]
pub pkce_state: Arc<tokio::sync::Mutex<crate::web::api::auth::PkceState>>,
}
/// Front-end build output from `web/dist/`, exposed through `rust_embed` and
/// served by `serve_asset`.
#[derive(RustEmbed)]
#[folder = "web/dist/"]
struct WebAssets;
/// Websocket message broadcast after the file watcher applies a change.
const GRAPH_UPDATED_MSG: &str = r#"{"type":"graph_updated"}"#;
/// Rejects requests that do not carry the server's bearer token.
///
/// `OPTIONS` requests are passed through untouched so CORS preflights, which
/// browsers send without credentials, are never blocked here. Every other
/// request must send `Authorization: Bearer <token>`, where `<token>` equals
/// `AppState::auth_token`. Anything else gets `401 Unauthorized`.
///
/// The token is compared without short-circuiting on the first mismatching
/// byte. Response latency therefore does not reveal how much of a guess was
/// correct.
async fn auth_middleware(
    axum::extract::State(state): axum::extract::State<AppState>,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    if request.method() == axum::http::Method::OPTIONS {
        return next.run(request).await;
    }
    let authorized = request
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .is_some_and(|token| constant_time_eq(token.as_bytes(), state.auth_token.as_bytes()));
    if authorized {
        next.run(request).await
    } else {
        (StatusCode::UNAUTHORIZED, "Invalid or missing auth token").into_response()
    }
}

/// Best-effort constant-time equality for two byte strings.
///
/// The function returns early only when the lengths differ. Tokens produced by
/// `generate_auth_token` all have the same length, so length is not secret.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair so all bytes are always visited.
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
pub fn build_router(state: AppState, port: u16) -> Router {
let api_router = Router::new()
.route("/api/graph", get(api::graph::handler))
.route("/api/file", get(api::file::handler))
.route("/api/search", get(api::search::handler))
.route("/api/stats", get(api::stats::handler));
#[cfg(feature = "rag")]
let api_router = api_router
.route("/api/chat", axum::routing::post(api::chat::handler))
.route(
"/api/auth/status",
axum::routing::get(api::auth::status_handler),
)
.route(
"/api/auth/key",
axum::routing::post(api::auth::set_key_handler),
)
.route(
"/api/auth/provider",
axum::routing::post(api::auth::set_provider_handler),
)
.route(
"/api/auth/oauth/start",
axum::routing::get(api::auth::oauth_start_handler),
)
.route(
"/api/auth/oauth/callback",
axum::routing::get(api::auth::oauth_callback_handler),
)
.route(
"/api/ollama/models",
axum::routing::get(api::auth::ollama_models_handler),
);
let api_router = api_router.layer(axum::middleware::from_fn_with_state(
state.clone(),
auth_middleware,
));
let router = Router::new()
.merge(api_router)
.route("/ws", get(ws::handler))
.fallback(serve_asset);
let router = router.layer(axum::middleware::from_fn(security_headers));
let origin = format!("http://127.0.0.1:{port}");
let cors = CorsLayer::new()
.allow_origin(origin.parse::<axum::http::HeaderValue>().unwrap())
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers([
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
]);
router.layer(cors).with_state(state)
}
async fn security_headers(
request: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
use axum::http::HeaderValue;
let mut response = next.run(request).await;
let headers = response.headers_mut();
headers.insert(
"Content-Security-Policy",
HeaderValue::from_static(
"default-src 'self'; style-src 'self' 'unsafe-inline'; script-src 'self'",
),
);
headers.insert(
"X-Content-Type-Options",
HeaderValue::from_static("nosniff"),
);
headers.insert("X-Frame-Options", HeaderValue::from_static("DENY"));
headers.insert("Referrer-Policy", HeaderValue::from_static("no-referrer"));
response
}
/// Returns a fresh 128-bit random token, hex-encoded as 32 lowercase characters.
///
/// Entropy is read from `/dev/urandom`, so this only works on Unix-like systems.
///
/// # Panics
/// Panics if `/dev/urandom` cannot be opened or fully read.
fn generate_auth_token() -> String {
    use std::fmt::Write as _;
    use std::io::Read as _;

    let mut entropy = [0u8; 16];
    let mut source = std::fs::File::open("/dev/urandom").expect("failed to open /dev/urandom");
    source
        .read_exact(&mut entropy)
        .expect("failed to read random bytes");

    let mut token = String::with_capacity(entropy.len() * 2);
    for byte in entropy {
        // `fmt::Write` into a `String` cannot fail.
        let _ = write!(token, "{byte:02x}");
    }
    token
}
async fn serve_asset(uri: axum::http::Uri) -> impl IntoResponse {
let path = uri.path().trim_start_matches('/');
if let Some(content) = WebAssets::get(path) {
let mime = mime_guess::from_path(path).first_or_octet_stream();
let body = match content.data {
std::borrow::Cow::Borrowed(bytes) => Body::from(bytes),
std::borrow::Cow::Owned(vec) => Body::from(vec),
};
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, mime.as_ref())
.body(body)
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
})
} else {
if let Some(index) = WebAssets::get("index.html") {
let body = match index.data {
std::borrow::Cow::Borrowed(bytes) => Body::from(bytes),
std::borrow::Cow::Owned(vec) => Body::from(vec),
};
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(body)
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
})
} else {
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not Found"))
.unwrap()
}
}
}
/// Indexes `root`, then serves the web UI and JSON API on `127.0.0.1:port`.
///
/// Startup steps:
/// 1. Build the code graph and its BM25 index.
/// 2. With the `rag` feature, load the vector index and the embedding engine,
///    and pick the LLM provider. `ollama` selects a local Ollama provider;
///    otherwise Claude is used with the resolved API key.
/// 3. Generate a random bearer token for the `/api` routes and print it.
/// 4. Start a file watcher. It incrementally updates the graph and, with
///    `rag`, re-embeds changed files. After each event it broadcasts
///    `graph_updated` to websocket clients.
/// 5. Serve until `axum::serve` returns.
///
/// # Errors
/// Returns an error if the initial index build fails, if the address cannot
/// be bound, or if the server fails.
// `ollama` is only read when the `rag` feature is enabled.
#[cfg_attr(not(feature = "rag"), allow(unused_variables))]
pub async fn serve(root: PathBuf, port: u16, ollama: bool) -> anyhow::Result<()> {
    eprintln!("Indexing {}...", root.display());
    let mut graph = crate::build_graph(&root, false)?;
    eprintln!(
        "Indexed {} files, {} symbols.",
        graph.file_count(),
        graph.symbol_count()
    );
    graph.rebuild_bm25_index();
    let (ws_tx, _ws_rx) = broadcast::channel::<String>(64);

    #[cfg(feature = "rag")]
    let (vector_store, embedding_engine, session_store, auth_state) = {
        let cache_dir = root.join(".code-graph");
        let vs = match crate::rag::vector_store::VectorStore::load(&cache_dir, 384) {
            Ok(vs) => {
                eprintln!("[rag] Loaded vector index: {} symbols", vs.len());
                Some(vs)
            }
            Err(_) => {
                eprintln!(
                    "[rag] No vector index found. Run 'code-graph index' with --features rag to build embeddings."
                );
                None
            }
        };
        let vector_store = Arc::new(RwLock::new(vs));
        let engine = match crate::rag::embedding::EmbeddingEngine::try_new() {
            Ok(e) => {
                eprintln!("[rag] Embedding engine initialized.");
                Some(e)
            }
            Err(e) => {
                eprintln!(
                    "[rag] Embedding engine unavailable (queries will use structural retrieval only): {}",
                    e
                );
                None
            }
        };
        let embedding_engine = Arc::new(engine);
        let session_store = Arc::new(tokio::sync::Mutex::new(
            crate::rag::session::SessionStore::new(100),
        ));
        let provider = if ollama {
            crate::rag::auth::LlmProvider::Ollama {
                host: "http://localhost:11434".to_string(),
                model: "llama3.2".to_string(),
            }
        } else {
            let api_key = crate::rag::auth::resolve_api_key().unwrap_or_default();
            crate::rag::auth::LlmProvider::Claude { api_key }
        };
        let auth_state = Arc::new(RwLock::new(crate::rag::auth::AuthState { provider }));
        (vector_store, embedding_engine, session_store, auth_state)
    };

    let auth_token = generate_auth_token();
    let state = AppState {
        graph: Arc::new(RwLock::new(graph)),
        project_root: root.clone(),
        ws_tx: ws_tx.clone(),
        auth_token: auth_token.clone(),
        #[cfg(feature = "rag")]
        vector_store,
        #[cfg(feature = "rag")]
        embedding_engine,
        #[cfg(feature = "rag")]
        session_store,
        #[cfg(feature = "rag")]
        auth_state,
        #[cfg(feature = "rag")]
        pkce_state: Arc::new(tokio::sync::Mutex::new(
            crate::web::api::auth::PkceState::new(),
        )),
    };

    let watcher_graph = Arc::clone(&state.graph);
    let watcher_root = root.clone();
    let watcher_tx = ws_tx.clone();
    #[cfg(feature = "rag")]
    let watcher_vector_store = Arc::clone(&state.vector_store);
    #[cfg(feature = "rag")]
    let watcher_embedding_engine = Arc::clone(&state.embedding_engine);

    // The watcher handle is bound at function scope so it stays alive until
    // `axum::serve` returns. A binding inside the match arm would be dropped
    // before the server starts.
    // NOTE(review): assumes dropping the handle stops watching (as `notify`
    // watchers do) — confirm in `crate::watcher::start_watcher`.
    let _watcher_handle = match crate::watcher::start_watcher(&watcher_root) {
        Ok((handle, std_rx)) => {
            let (bridge_tx, mut bridge_rx) =
                tokio::sync::mpsc::channel::<crate::watcher::event::WatchEvent>(256);
            // Forward events from the blocking std channel into tokio. The
            // loop stops when either end of the bridge is closed.
            tokio::task::spawn_blocking(move || {
                while let Ok(event) = std_rx.recv() {
                    if bridge_tx.blocking_send(event).is_err() {
                        return;
                    }
                }
            });
            tokio::spawn(async move {
                while let Some(event) = bridge_rx.recv().await {
                    // Capture the path before the event is applied; only
                    // modified/deleted files are re-embedded.
                    #[cfg(feature = "rag")]
                    let event_file_path: Option<String> = match &event {
                        crate::watcher::event::WatchEvent::Modified(p) => {
                            Some(p.to_string_lossy().to_string())
                        }
                        crate::watcher::event::WatchEvent::Deleted(p) => {
                            Some(p.to_string_lossy().to_string())
                        }
                        _ => None,
                    };
                    // Scope the write lock so it is released before re-embedding.
                    {
                        let mut graph = watcher_graph.write().await;
                        crate::watcher::incremental::handle_file_event(
                            &mut graph,
                            &event,
                            &watcher_root,
                        );
                    }
                    #[cfg(feature = "rag")]
                    if let Some(file_path) = event_file_path {
                        let graph = watcher_graph.read().await;
                        let mut vs_guard = watcher_vector_store.write().await;
                        if let (Some(vs), Some(engine)) =
                            (vs_guard.as_mut(), watcher_embedding_engine.as_ref())
                        {
                            match crate::watcher::incremental::re_embed_file(
                                &graph, vs, engine, &file_path,
                            )
                            .await
                            {
                                Ok(count) => {
                                    eprintln!(
                                        "[watch] re-embedded {} symbols from {}",
                                        count, file_path
                                    );
                                }
                                Err(e) => {
                                    eprintln!("[watch] re-embedding failed: {}", e);
                                }
                            }
                        }
                    }
                    // A send error only means no websocket client is subscribed.
                    let _ = watcher_tx.send(GRAPH_UPDATED_MSG.to_string());
                }
            });
            Some(handle)
        }
        Err(e) => {
            eprintln!("[watcher] failed to start: {}", e);
            None
        }
    };

    let router = build_router(state, port);
    let addr = format!("127.0.0.1:{port}");
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    println!("Serving on http://127.0.0.1:{port}");
    println!("Auth token: {auth_token}");
    axum::serve(listener, router).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Minimal state with an empty graph and the fixed token "test-token".
    fn test_state() -> AppState {
        let (ws_tx, _) = broadcast::channel::<String>(16);
        AppState {
            graph: Arc::new(RwLock::new(CodeGraph::new())),
            project_root: PathBuf::from("/tmp/test-project"),
            ws_tx,
            auth_token: "test-token".to_string(),
            #[cfg(feature = "rag")]
            vector_store: Arc::new(RwLock::new(None)),
            #[cfg(feature = "rag")]
            embedding_engine: Arc::new(None),
            #[cfg(feature = "rag")]
            session_store: Arc::new(tokio::sync::Mutex::new(
                crate::rag::session::SessionStore::new(10),
            )),
            #[cfg(feature = "rag")]
            auth_state: Arc::new(RwLock::new(crate::rag::auth::AuthState {
                provider: crate::rag::auth::LlmProvider::Claude {
                    api_key: String::new(),
                },
            })),
            #[cfg(feature = "rag")]
            pkce_state: Arc::new(tokio::sync::Mutex::new(
                crate::web::api::auth::PkceState::new(),
            )),
        }
    }

    /// Sends a CORS preflight for `uri` from `origin` and returns the
    /// `Access-Control-Allow-Origin` value of the response, if any.
    async fn preflight_allowed_origin(app: Router, uri: &str, origin: &str) -> Option<String> {
        let req = Request::builder()
            .method("OPTIONS")
            .uri(uri)
            .header("Origin", origin)
            .header("Access-Control-Request-Method", "GET")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        resp.headers()
            .get("access-control-allow-origin")
            .map(|v| v.to_str().unwrap().to_string())
    }

    #[tokio::test]
    async fn test_build_router_has_expected_routes() {
        let state = test_state();
        let app = build_router(state, 7070);
        let routes = ["/api/graph", "/api/file", "/api/search", "/api/stats"];
        for path in routes {
            let req = Request::builder()
                .uri(path)
                .header("Authorization", "Bearer test-token")
                .body(Body::empty())
                .unwrap();
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_ne!(
                resp.status(),
                StatusCode::NOT_FOUND,
                "route {path} should exist (got 404)"
            );
        }
        let req = Request::builder().uri("/ws").body(Body::empty()).unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_ne!(
            resp.status(),
            StatusCode::NOT_FOUND,
            "route /ws should exist (got 404)"
        );
    }

    #[tokio::test]
    async fn test_api_rejects_missing_or_wrong_token() {
        let app = build_router(test_state(), 7070);
        // No header, a wrong token, and the right token without the scheme.
        let bad_headers = [None, Some("Bearer wrong-token"), Some("test-token")];
        for auth in bad_headers {
            let mut builder = Request::builder().uri("/api/stats");
            if let Some(value) = auth {
                builder = builder.header("Authorization", value);
            }
            let req = builder.body(Body::empty()).unwrap();
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::UNAUTHORIZED,
                "Authorization header {auth:?} must be rejected"
            );
        }
    }

    #[tokio::test]
    async fn test_security_headers_present_on_api_responses() {
        let app = build_router(test_state(), 7070);
        // An unauthenticated request still passes through the security-header layer.
        let req = Request::builder()
            .uri("/api/stats")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let headers = resp.headers();
        let value_of = |name: &str| headers.get(name).and_then(|v| v.to_str().ok());
        assert_eq!(value_of("x-content-type-options"), Some("nosniff"));
        assert_eq!(value_of("x-frame-options"), Some("DENY"));
        assert_eq!(value_of("referrer-policy"), Some("no-referrer"));
        assert!(headers.contains_key("content-security-policy"));
    }

    #[test]
    fn test_generate_auth_token_is_32_lowercase_hex_chars() {
        let token = generate_auth_token();
        assert_eq!(token.len(), 32);
        assert!(token
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
        assert_ne!(token, generate_auth_token());
    }

    #[tokio::test]
    async fn test_cors_allows_server_origin() {
        let app = build_router(test_state(), 7070);
        let acl = preflight_allowed_origin(app, "/api/graph", "http://127.0.0.1:7070").await;
        assert_eq!(
            acl.as_deref(),
            Some("http://127.0.0.1:7070"),
            "CORS should allow the server's own origin"
        );
    }

    #[tokio::test]
    async fn test_cors_origin_is_server_port_not_3000() {
        let app = build_router(test_state(), 7070);
        let acl = preflight_allowed_origin(app, "/api/graph", "http://127.0.0.1:3000").await;
        assert_ne!(
            acl.as_deref(),
            Some("http://127.0.0.1:3000"),
            "CORS must not allow the old hardcoded port 3000"
        );
        assert_eq!(
            acl.as_deref(),
            Some("http://127.0.0.1:7070"),
            "CORS allowed origin should be the server's own port"
        );
    }

    #[tokio::test]
    async fn test_cors_custom_port() {
        let app = build_router(test_state(), 9999);
        let acl = preflight_allowed_origin(app, "/api/stats", "http://127.0.0.1:9999").await;
        assert_eq!(
            acl.as_deref(),
            Some("http://127.0.0.1:9999"),
            "CORS should reflect the port passed to build_router"
        );
    }

    #[test]
    fn test_no_mcp_references_in_web_module() {
        let web_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src/web");
        // Terms are assembled at runtime so this file's own source never
        // contains them literally.
        let forbidden: Vec<String> = vec![
            format!("r{}cp", "m"),
            format!("{}Server", "Mcp"),
            ["m", "c", "p"].join(""),
        ];
        for entry in walkdir(&web_dir) {
            let content = std::fs::read_to_string(&entry).unwrap_or_default();
            // Only production code is checked: everything before the first test module.
            let prod_content = if let Some(pos) = content.find("#[cfg(test)]") {
                &content[..pos]
            } else {
                &content
            };
            for term in &forbidden {
                assert!(
                    !prod_content.contains(term.as_str()),
                    "file {} contains forbidden MCP reference '{}'",
                    entry.display(),
                    term,
                );
            }
        }
    }

    /// Recursively collects every `.rs` file under `dir`; unreadable dirs are skipped.
    fn walkdir(dir: &std::path::Path) -> Vec<PathBuf> {
        let mut files = Vec::new();
        if let Ok(entries) = std::fs::read_dir(dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.is_dir() {
                    files.extend(walkdir(&path));
                } else if path.extension().is_some_and(|e| e == "rs") {
                    files.push(path);
                }
            }
        }
        files
    }
}