use std::sync::Arc;
use axum::{
body::Bytes,
extract::{Path, Query, State},
response::Json,
};
use serde::{Deserialize, Serialize};
use crate::core::{
bow_embedding, cluster as run_cluster, extract_doc_comments, extract_kg_from_scip,
ClusterResult, NerExtractor, ScipIngestSummary,
};
use crate::embedder::{Embedder, EmbedderKind};
use crate::service::events::{fetch_chunks, AnalyzerAppState, AnalyzerEvent, ApiError};
use crate::types::{KgGraph, KgNode, RawEntity};
/// Query-string parameters accepted by [`graph_for_index`].
#[derive(Deserialize)]
pub struct GraphQueryParams {
/// Optional language filter. When present, `graph_for_index` keeps only
/// nodes whose `language` equals this value, plus the edges whose two
/// endpoints both survive that filter.
pub language: Option<String>,
}
/// Returns the knowledge graph for the index identified by `id`.
///
/// The graph is produced by running the analyzer registry over the index's
/// chunks. If a SCIP overlay has been stored for `id`, it is merged into the
/// graph and the result is re-linked via `crate::core::link`.
///
/// When the `language` query parameter is supplied, the response is reduced
/// to nodes whose `language` matches it and to edges whose `from` and `to`
/// both refer to a kept node.
///
/// # Errors
///
/// Propagates any `ApiError` returned by `fetch_chunks`.
pub async fn graph_for_index(
    State(state): State<Arc<AnalyzerAppState>>,
    Path(id): Path<String>,
    Query(params): Query<GraphQueryParams>,
) -> Result<Json<KgGraph>, ApiError> {
    let chunks = fetch_chunks(&state, &id).await?;
    let res = state.registry.analyze(&chunks);
    let mut graph = res.graph;

    // Clone the overlay in its own statement: the read guard is a temporary
    // that is dropped at the end of this `let`. Writing the lookup directly
    // in the `if let` scrutinee would keep the lock held for the whole body,
    // i.e. across `merge` and `link`, blocking writers for no reason.
    let overlay = state.scip_overlays.read().await.get(&id).cloned();
    if let Some(overlay) = overlay {
        graph.merge(overlay);
        graph = crate::core::link(graph);
    }

    if let Some(lang) = params.language.as_deref() {
        // Ids are collected as owned strings so the set does not borrow
        // `graph.nodes` while it is being mutated by `retain` below.
        let keep_nodes: std::collections::HashSet<String> = graph
            .nodes
            .iter()
            .filter(|n| n.language == lang)
            .map(|n| n.id.clone())
            .collect();
        graph.nodes.retain(|n| keep_nodes.contains(&n.id));
        // Drop edges that would otherwise dangle after the node filter.
        graph
            .edges
            .retain(|e| keep_nodes.contains(&e.from) && keep_nodes.contains(&e.to));
    }

    Ok(Json(graph))
}
/// Query-string parameters accepted by [`entities_for_index`].
#[derive(Deserialize)]
pub struct EntitiesQueryParams {
/// Optional kind filter. `entities_for_index` compares it against the
/// `Debug` rendering (`{:?}`) of each node's `kind`, so it must match that
/// text exactly.
pub kind: Option<String>,
/// Optional language filter; only nodes whose `language` equals this value
/// are returned.
pub language: Option<String>,
}
/// Lists the knowledge-graph nodes extracted from the index identified by `id`.
///
/// Nodes come from running the analyzer registry over the index's chunks.
/// They can be narrowed with the `language` query parameter (exact match on
/// the node's `language`) and the `kind` query parameter (exact match on the
/// `Debug` rendering of the node's `kind`). The result is ordered by that
/// `Debug` rendering of `kind`, then by `name`.
///
/// # Errors
///
/// Propagates any `ApiError` returned by `fetch_chunks`.
pub async fn entities_for_index(
    State(state): State<Arc<AnalyzerAppState>>,
    Path(id): Path<String>,
    Query(params): Query<EntitiesQueryParams>,
) -> Result<Json<Vec<KgNode>>, ApiError> {
    let chunks = fetch_chunks(&state, &id).await?;
    let res = state.registry.analyze(&chunks);
    let mut nodes = res.graph.nodes;

    if let Some(lang) = params.language.as_deref() {
        nodes.retain(|n| n.language == lang);
    }
    if let Some(kind) = params.kind.as_deref() {
        nodes.retain(|n| format!("{:?}", n.kind) == kind);
    }

    // Build each sort key once per node instead of formatting `kind` twice on
    // every comparison. Tuple ordering is lexicographic, so this is the same
    // (kind, name) order as before, and `sort_by_cached_key` is stable, so
    // ties keep their relative order as well.
    nodes.sort_by_cached_key(|n| (format!("{:?}", n.kind), n.name.clone()));

    Ok(Json(nodes))
}
/// Query-string parameters accepted by [`clusters_for_index`].
#[derive(Deserialize)]
pub struct ClusterQueryParams {
/// Requested number of clusters. `clusters_for_index` uses 8 when this is
/// absent and clamps the value to the range 1..=50.
pub k: Option<usize>,
/// Requested embedding method. When absent, `clusters_for_index` falls
/// back to `EmbedderKind::default()`.
#[serde(default)]
pub method: Option<EmbedderKind>,
}
/// One cluster as serialized in a [`ClusterResponse`].
///
/// Built by `cluster_items_from`, which copies every field from the
/// corresponding cluster of a `ClusterResult`.
#[derive(Serialize)]
pub struct ClusterResponseItem {
/// Cluster id, copied from the clustering result.
pub id: usize,
/// Cluster label, copied from the clustering result.
pub label: String,
/// Cluster members, copied from the clustering result.
// NOTE(review): presumably chunk ids, since the embeddings handed to the
// clustering routine are keyed by chunk id — confirm against `core::cluster`.
pub members: Vec<String>,
/// Cohesion score, copied from the clustering result.
pub cohesion: f32,
/// Number of entries in `members`.
pub size: usize,
}
/// JSON body returned by [`clusters_for_index`].
#[derive(Serialize)]
pub struct ClusterResponse {
/// Cluster count that was requested after defaulting and clamping.
pub k: usize,
/// `as_str()` name of the embedder kind reported for this response; this is
/// the BOW kind when the neural embedder failed and the handler fell back.
pub method: String,
/// Dimensionality reported for the embedding vectors; 0 when the index has
/// no chunks.
pub dim: usize,
/// Iteration count reported by the clustering run; 0 when the index has no
/// chunks.
pub iterations: usize,
/// Number of chunks fetched for the index.
pub chunk_count: usize,
/// The clusters themselves.
pub clusters: Vec<ClusterResponseItem>,
}
/// Converts the clusters of a `ClusterResult` into their response form.
///
/// Each cluster's `id`, `label`, `members` and `cohesion` are moved across
/// unchanged; `size` is set to the number of members.
fn cluster_items_from(r: ClusterResult) -> Vec<ClusterResponseItem> {
    let mut items = Vec::new();
    for cluster in r.clusters {
        // Read the length before `members` is moved into the item.
        let size = cluster.members.len();
        items.push(ClusterResponseItem {
            id: cluster.id,
            label: cluster.label,
            members: cluster.members,
            cohesion: cluster.cohesion,
            size,
        });
    }
    items
}
pub async fn clusters_for_index(
State(state): State<Arc<AnalyzerAppState>>,
Path(id): Path<String>,
Query(params): Query<ClusterQueryParams>,
) -> Result<Json<ClusterResponse>, ApiError> {
const BOW_DIM: usize = 256;
let k = params.k.unwrap_or(8).clamp(1, 50);
let method = params.method.clone().unwrap_or_default();
let chunks = fetch_chunks(&state, &id).await?;
if chunks.is_empty() {
return Ok(Json(ClusterResponse {
k,
method: method.as_str().to_string(),
dim: 0,
iterations: 0,
chunk_count: 0,
clusters: Vec::new(),
}));
}
let neural_embedder: Arc<dyn Embedder> = state.embedder.clone();
let effective_kind_initial: EmbedderKind = match method {
EmbedderKind::Neural => neural_embedder.kind(),
EmbedderKind::Bow => EmbedderKind::Bow,
};
let owned_texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
let embed_result: anyhow::Result<(Vec<Vec<f32>>, EmbedderKind, usize)> = match method {
EmbedderKind::Neural => {
let embedder_arc = Arc::clone(&neural_embedder);
let dim = embedder_arc.dim();
let texts_for_task = owned_texts.clone();
tokio::task::spawn_blocking(move || {
let refs: Vec<&str> = texts_for_task.iter().map(String::as_str).collect();
embedder_arc.embed_batch(&refs)
})
.await
.unwrap_or_else(|e| Err(anyhow::anyhow!("embed_batch task panicked: {e}")))
.map(|v| (v, EmbedderKind::Neural, dim))
}
EmbedderKind::Bow => {
let vecs: Vec<Vec<f32>> = owned_texts
.iter()
.map(|t| bow_embedding(t, BOW_DIM))
.collect();
Ok((vecs, EmbedderKind::Bow, BOW_DIM))
}
};
let (vecs, effective_kind, dim) = match embed_result {
Ok(triple) => triple,
Err(e) => {
tracing::warn!(
"embedder ({:?}) failed ({e:#}); falling back to BOW",
effective_kind_initial
);
let fallback: Vec<Vec<f32>> = owned_texts
.iter()
.map(|t| bow_embedding(t, BOW_DIM))
.collect();
(fallback, EmbedderKind::Bow, BOW_DIM)
}
};
let embeddings: Vec<(String, Vec<f32>)> = chunks
.iter()
.zip(vecs)
.map(|(c, v)| (c.id.clone(), v))
.collect();
let result = run_cluster(&embeddings, k, 100, 42);
let iterations = result.iterations;
Ok(Json(ClusterResponse {
k,
method: effective_kind.as_str().to_string(),
dim,
iterations,
chunk_count: chunks.len(),
clusters: cluster_items_from(result),
}))
}
/// Query-string parameters accepted by [`ner_for_index`].
#[derive(Deserialize)]
pub struct NerQueryParams {
/// Maximum number of entities to return; `ner_for_index` uses 50 when this
/// is absent.
pub top_k: Option<usize>,
}
/// Extracts named entities from the doc comments of the index identified by `id`.
///
/// Chunks are visited in the order `fetch_chunks` returns them. For each
/// chunk whose extracted doc comments are non-empty, the NER extractor is run
/// on them together with the chunk's `file`. Collection stops as soon as at
/// least `top_k` entities (default 50) have been gathered, and the list is
/// then cut to at most `top_k` entries.
///
/// # Errors
///
/// Propagates any `ApiError` returned by `fetch_chunks`.
pub async fn ner_for_index(
    State(state): State<Arc<AnalyzerAppState>>,
    Path(id): Path<String>,
    Query(params): Query<NerQueryParams>,
) -> Result<Json<Vec<RawEntity>>, ApiError> {
    let chunks = fetch_chunks(&state, &id).await?;
    let limit = params.top_k.unwrap_or(50);
    let extractor = NerExtractor::try_load();

    let mut found: Vec<RawEntity> = Vec::new();
    for chunk in chunks.iter() {
        let docs = extract_doc_comments(&chunk.content);
        if !docs.is_empty() {
            found.extend(extractor.extract(&docs, &chunk.file));
            // Only re-check the limit after something was actually added.
            if found.len() >= limit {
                break;
            }
        }
    }

    // The last chunk processed may have pushed the total past the limit.
    found.truncate(limit);
    Ok(Json(found))
}
/// JSON body returned by [`ingest_scip`].
#[derive(Serialize)]
pub struct ScipIngestResponse {
/// Id of the index the SCIP data was stored under.
pub index_id: String,
/// Ingest summary; because of `flatten`, its fields are serialized at the
/// same level as `index_id` rather than under a nested `summary` key.
#[serde(flatten)]
pub summary: ScipIngestSummary,
}
/// Ingests a SCIP payload for the index identified by `id`.
///
/// The request body is handed to `extract_kg_from_scip`. On success the
/// resulting graph is stored in `scip_overlays` under `id` (replacing any
/// entry already stored for that id), an `AnalyzerEvent::ScipIngested` event
/// carrying the summary's `kg_nodes` count is emitted, and the summary is
/// returned alongside the index id.
///
/// # Errors
///
/// Returns `ApiError::bad_request` (after logging a warning) when the body
/// cannot be processed by `extract_kg_from_scip`.
pub async fn ingest_scip(
    State(state): State<Arc<AnalyzerAppState>>,
    Path(id): Path<String>,
    body: Bytes,
) -> Result<Json<ScipIngestResponse>, ApiError> {
    let (graph, summary) = match extract_kg_from_scip(&body) {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::warn!("SCIP ingest for {id} failed: {e:#}");
            return Err(ApiError::bad_request(format!("invalid SCIP protobuf: {e:#}")));
        }
    };

    {
        // Scope the write guard so it is released before the event is emitted.
        let mut overlays = state.scip_overlays.write().await;
        overlays.insert(id.clone(), graph);
    }

    state.emit(AnalyzerEvent::ScipIngested {
        index_id: id.clone(),
        symbols_ingested: summary.kg_nodes,
    });

    let response = ScipIngestResponse {
        index_id: id,
        summary,
    };
    Ok(Json(response))
}