use axum::{extract::State, http::StatusCode, Json};
use serde::Deserialize;
use std::sync::Arc;
use crate::core::{classifier::QueryClassifier, indexer::SearchQuery, registry::IndexId};
use super::routing::{compute_context_weights, RoutingMode};
use super::state::SearchAppState;
/// JSON body accepted by the global (cross-index) search endpoint.
///
/// Field order is kept stable: serde can also deserialize this struct from a
/// positional sequence, so reordering fields would change that encoding.
#[derive(Debug, Deserialize)]
pub struct GlobalSearchRequest {
    /// Free-text query. The handler rejects it with `400` when it is empty or
    /// whitespace-only.
    pub query: String,
    /// Maximum number of fused results to return (defaults via `default_global_top_k`).
    #[serde(default = "default_global_top_k")]
    pub top_k: usize,
    /// When `true`, chunks are returned with full content instead of the compact form.
    #[serde(default)]
    pub full_content: bool,
    /// Optional allow-list of index ids; when absent, every registered index is a candidate.
    #[serde(default)]
    pub indexes: Option<Vec<String>>,
    /// Routing strategy name; interpreted by `RoutingMode::from_request`.
    #[serde(default)]
    pub routing: Option<String>,
    /// Presumably the "N" for count-based routing — TODO confirm in `RoutingMode::from_request`.
    #[serde(default)]
    pub routing_n: Option<usize>,
    /// Presumably the similarity cutoff for threshold routing — TODO confirm in
    /// `RoutingMode::from_request`.
    #[serde(default)]
    pub routing_threshold: Option<f32>,
}
/// Number of fused results returned when a request omits `top_k`.
fn default_global_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
pub(super) async fn global_search_handler(
State(state): State<Arc<SearchAppState>>,
Json(req): Json<GlobalSearchRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
if req.query.trim().is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "query must not be empty" })),
));
}
use crate::core::search::rrf::{rrf_fuse, RRF_K};
let cold_indexes_skipped: usize = if let Some(requested) = req.indexes.as_ref() {
state
.cold_store
.count_matching(requested.iter().map(|s| s.as_str()))
} else {
state.cold_store.len()
};
let all_ids = state.registry.list();
let index_ids: Vec<IndexId> = if let Some(requested) = req.indexes.as_ref() {
let allow: std::collections::HashSet<&str> = requested.iter().map(|s| s.as_str()).collect();
all_ids
.into_iter()
.filter(|id| allow.contains(id.0.as_str()))
.collect()
} else {
all_ids
};
let total_indexes = index_ids.len();
if index_ids.is_empty() {
return Ok(Json(serde_json::json!({
"results": Vec::<crate::core::indexer::CodeChunk>::new(),
"indexes_searched": Vec::<String>::new(),
"total_indexes": 0_usize,
"cold_indexes_skipped": cold_indexes_skipped,
"latency_ms": 0_u64,
"intent": format!("{:?}", QueryClassifier::classify(&req.query)),
})));
}
let started = std::time::Instant::now();
let intent = QueryClassifier::classify(&req.query);
let routing_mode = RoutingMode::from_request(&req);
let weights = compute_context_weights(&state.registry, &index_ids, &req.query).await;
let (mut active_ids, mut weight_map) = routing_mode.apply(&index_ids, &weights);
let hierarchy = if req.indexes.is_none() {
let h = crate::core::search::hierarchy::IndexHierarchy::from_registry(
&state.registry,
&index_ids,
);
if matches!(routing_mode, RoutingMode::Threshold(_)) && !h.parent_of.is_empty() {
let inactive_ids: Vec<IndexId> = index_ids
.iter()
.filter(|id| !weight_map.contains_key(id))
.cloned()
.collect();
crate::core::search::hierarchy::apply_threshold_child_inclusion(
&inactive_ids,
&mut active_ids,
&mut weight_map,
&h,
);
}
h
} else {
crate::core::search::hierarchy::IndexHierarchy::default()
};
let routing_label = routing_mode.label().to_string();
let routing_decisions: Vec<serde_json::Value> = index_ids
.iter()
.map(|id| {
let w = weights.get(id).copied().unwrap_or(1.0);
let included = weight_map.contains_key(id);
serde_json::json!({
"index_id": id.0,
"cosine_similarity": w,
"included": included,
})
})
.collect();
let per_index_query = SearchQuery {
text: req.query.clone(),
top_k: req.top_k,
expand_graph: true,
compact: !req.full_content,
branch_files: None,
branch_boost: SearchQuery::default_branch_boost(),
branch: None,
mode: crate::core::indexer::SearchMode::default(),
exclude_archived: false,
stage: None,
refine_query: None,
};
let registry = state.registry.clone();
let futures = active_ids.into_iter().map(|id| {
let registry = registry.clone();
let query = per_index_query.clone();
async move {
let handle = registry.get(&id)?;
let indexer = handle.indexer.read().await;
match indexer.search(&query).await {
Ok(results) => Some((id, results)),
Err(e) => {
tracing::warn!("global search: index {} errored: {e}", id);
None
}
}
}
});
let per_index_results: Vec<(IndexId, Vec<crate::core::indexer::CodeChunk>)> =
futures::future::join_all(futures)
.await
.into_iter()
.flatten()
.collect();
let mut chunk_lookup: std::collections::HashMap<String, crate::core::indexer::CodeChunk> =
std::collections::HashMap::new();
let mut lanes: Vec<Vec<(String, f32)>> = Vec::with_capacity(per_index_results.len());
let mut indexes_searched: Vec<String> = Vec::with_capacity(per_index_results.len());
for (id, results) in per_index_results {
indexes_searched.push(id.0.clone());
let cosine_weight = weight_map.get(&id).copied().unwrap_or(1.0);
let weight = crate::core::search::hierarchy::effective_weight_for_index(
&id,
cosine_weight,
&hierarchy,
);
let mut lane: Vec<(String, f32)> = Vec::with_capacity(results.len());
for mut chunk in results {
let namespaced = format!("{}::{}", id.0, chunk.id);
chunk.index_id = Some(id.0.clone());
let weighted_score = chunk.score * weight;
lane.push((namespaced.clone(), weighted_score));
chunk_lookup.insert(namespaced, chunk);
}
lanes.push(lane);
}
let mut fused: Vec<(String, f32)> = Vec::new();
let oversample = req.top_k.saturating_mul(4).max(req.top_k).max(10);
for lane in lanes {
fused = rrf_fuse(&fused, &lane, 1.0, 1.0, RRF_K, oversample);
}
let (fused, hierarchy_dedup_count) = crate::core::search::hierarchy::dedup_nested_results(
fused,
&chunk_lookup,
&state.registry,
&hierarchy,
);
let mut fused = fused;
fused.truncate(req.top_k);
let results: Vec<crate::core::indexer::CodeChunk> = fused
.into_iter()
.filter_map(|(id, fused_score)| {
let mut chunk = chunk_lookup.remove(&id)?;
chunk.score = fused_score;
Some(chunk)
})
.collect();
let latency_ms = started.elapsed().as_millis() as u64;
Ok(Json(serde_json::json!({
"results": results,
"indexes_searched": indexes_searched,
"total_indexes": total_indexes,
"cold_indexes_skipped": cold_indexes_skipped,
"latency_ms": latency_ms,
"intent": format!("{:?}", intent),
"routing": routing_label,
"routing_decisions": routing_decisions,
"hierarchy_dedup_count": hierarchy_dedup_count,
})))
}