use axum::{
extract::{Path, State},
http::StatusCode,
Json,
};
use serde::Deserialize;
use std::sync::Arc;
use crate::core::registry::IndexId;
use super::search::GlobalSearchRequest;
use super::state::SearchAppState;
/// Strategy for choosing which candidate indexes take part in a search.
///
/// The selection itself is performed by `RoutingMode::apply`.
#[derive(Debug, Clone, Copy)]
pub(super) enum RoutingMode {
/// Keep every candidate index.
All,
/// Keep at most the given number of highest-weighted indexes.
TopN(usize),
/// Keep only the indexes whose weight is at least the given value.
Threshold(f32),
}
impl RoutingMode {
    /// Number of indexes kept by `top_n` routing when the request has no `routing_n`.
    pub(super) const DEFAULT_TOP_N: usize = 3;
    /// Minimum weight used by `threshold` routing when the request has no
    /// `routing_threshold`.
    const DEFAULT_THRESHOLD: f32 = 0.3;

    /// Derives the routing mode from the request's `routing` field.
    ///
    /// * `"top_n"` selects [`RoutingMode::TopN`] with `routing_n` (default
    ///   [`Self::DEFAULT_TOP_N`]), raised to at least 1.
    /// * `"threshold"` selects [`RoutingMode::Threshold`] with
    ///   `routing_threshold` (default `DEFAULT_THRESHOLD`).
    /// * Any other value, or no value at all, selects [`RoutingMode::All`].
    pub(super) fn from_request(req: &GlobalSearchRequest) -> Self {
        match req.routing.as_deref() {
            Some("top_n") => Self::TopN(req.routing_n.unwrap_or(Self::DEFAULT_TOP_N).max(1)),
            Some("threshold") => {
                Self::Threshold(req.routing_threshold.unwrap_or(Self::DEFAULT_THRESHOLD))
            }
            _ => Self::All,
        }
    }

    /// Returns the short name of this mode (`"all"`, `"top_n"` or `"threshold"`).
    pub(super) fn label(self) -> &'static str {
        match self {
            Self::All => "all",
            Self::TopN(_) => "top_n",
            Self::Threshold(_) => "threshold",
        }
    }

    /// Selects the active indexes out of `index_ids` according to this mode.
    ///
    /// `weights` supplies a weight per index; an index without an entry is
    /// treated as having weight `1.0`.
    ///
    /// Returns the active index ids together with the weight map to use for them:
    ///
    /// * `All` keeps every id, in input order, with its looked-up weight.
    /// * `TopN(n)` keeps the `n` highest-weighted ids (ties keep input order,
    ///   NaN weights rank last) and gives each of them weight `1.0`.
    /// * `Threshold(t)` keeps, in input order, the ids whose weight is `>= t`
    ///   and gives each of them weight `1.0`.
    pub(super) fn apply(
        self,
        index_ids: &[IndexId],
        weights: &std::collections::HashMap<IndexId, f32>,
    ) -> (Vec<IndexId>, std::collections::HashMap<IndexId, f32>) {
        use std::cmp::Ordering;

        // Missing entries count as neutral weight.
        let weight_of = |id: &IndexId| weights.get(id).copied().unwrap_or(1.0);
        // Once routing has picked the survivors, they are all weighted equally.
        let uniform = |ids: &[IndexId]| -> std::collections::HashMap<IndexId, f32> {
            ids.iter().map(|id| (id.clone(), 1.0)).collect()
        };

        match self {
            Self::All => {
                let map: std::collections::HashMap<IndexId, f32> = index_ids
                    .iter()
                    .map(|id| (id.clone(), weight_of(id)))
                    .collect();
                (index_ids.to_vec(), map)
            }
            Self::TopN(n) => {
                let mut ranked: Vec<(&IndexId, f32)> =
                    index_ids.iter().map(|id| (id, weight_of(id))).collect();
                // Descending by weight. `partial_cmp` alone is not a total order
                // when a weight is NaN, and `sort_by` may panic or produce an
                // unspecified order in that case, so NaN is ranked explicitly
                // after every real number. For non-NaN weights this is the plain
                // descending comparison, and the stable sort keeps ties in input
                // order.
                ranked.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
                    (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
                    (true, true) => Ordering::Equal,
                    (true, false) => Ordering::Greater,
                    (false, true) => Ordering::Less,
                });
                let active: Vec<IndexId> =
                    ranked.iter().take(n).map(|(id, _)| (*id).clone()).collect();
                let map = uniform(&active);
                (active, map)
            }
            Self::Threshold(t) => {
                // A NaN weight never satisfies `>= t`, so it is filtered out.
                let active: Vec<IndexId> = index_ids
                    .iter()
                    .filter(|&id| weight_of(id) >= t)
                    .cloned()
                    .collect();
                let map = uniform(&active);
                (active, map)
            }
        }
    }
}
/// Computes a per-index relevance weight for `query`.
///
/// The query is embedded once, using the first index in `index_ids` whose
/// indexer yields an embedding. Each index is then weighted by the cosine
/// similarity between that query embedding and the index's stored context
/// embedding, with negative similarities raised to `0.0`.
///
/// An index gets the neutral weight `1.0` when it is not in the registry, has
/// no context embedding, or its context embedding has a different length than
/// the query embedding. If no index could embed the query, every index gets
/// `1.0`.
pub(super) async fn compute_context_weights(
    registry: &crate::core::registry::IndexRegistry,
    index_ids: &[IndexId],
    query: &str,
) -> std::collections::HashMap<IndexId, f32> {
    use crate::core::mmr::cosine_similarity;

    // Step 1: obtain one query embedding from the first indexer able to produce it.
    let mut query_vec: Option<Vec<f32>> = None;
    for index_id in index_ids {
        let Some(entry) = registry.get(index_id) else {
            continue;
        };
        let embedder = entry.indexer.read().await;
        match embedder.embed_text(query).await {
            Ok(Some(embedding)) => {
                query_vec = Some(embedding);
                break;
            }
            // This indexer has no embedding to offer; try the next one.
            Ok(None) => {}
            Err(e) => {
                tracing::debug!("context_routing: embed_text failed on {}: {e}", index_id.0)
            }
        }
    }

    // Without a query embedding there is nothing to compare: weight everything equally.
    let Some(query_vec) = query_vec else {
        return index_ids
            .iter()
            .map(|index_id| (index_id.clone(), 1.0))
            .collect();
    };

    // Step 2: score each index against its stored context embedding.
    let mut weights = std::collections::HashMap::with_capacity(index_ids.len());
    for index_id in index_ids {
        let weight = if let Some(entry) = registry.get(index_id) {
            let context = entry.context_embedding.read().await;
            let similarity = match context.as_ref() {
                Some(ctx) if ctx.len() == query_vec.len() => {
                    cosine_similarity(&query_vec, ctx).max(0.0)
                }
                // No context embedding, or one of a different dimension.
                _ => 1.0,
            };
            similarity
        } else {
            1.0
        };
        weights.insert(index_id.clone(), weight);
    }
    weights
}
/// Request body deserialized by `search_similar_handler`.
#[derive(Deserialize)]
pub struct SearchSimilarRequest {
/// File passed to the indexer's chunk lookup to locate the seed chunk.
pub file: String,
/// Optional function name passed alongside `file` to the chunk lookup.
#[serde(default)]
pub function: Option<String>,
/// Result count passed to the similarity search; filled in by
/// `default_similar_top_k` when the field is omitted.
#[serde(default = "default_similar_top_k")]
pub top_k: usize,
}
/// Serde default for `SearchSimilarRequest::top_k` when the request omits it.
fn default_similar_top_k() -> usize {
    const FALLBACK_TOP_K: usize = 10;
    FALLBACK_TOP_K
}
/// Finds chunks similar to a seed chunk within the index named by the path.
///
/// The seed chunk is located from the request's `file` and optional
/// `function`. Its embedding is taken from the indexer's cache when present;
/// otherwise the chunk content is embedded on the fly. The seed chunk id is
/// passed to the similarity search as the third argument.
///
/// On success the JSON response carries `results`, `seed_chunk_id` and
/// `latency_ms` (time measured from just after the index lookup).
///
/// # Errors
///
/// * `404 NOT_FOUND` when the index is unknown, no seed chunk matches, the
///   chunk has no content, or no embedding could be produced for it.
/// * `500 INTERNAL_SERVER_ERROR` when embedding or the similarity search fails.
pub(super) async fn search_similar_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<SearchSimilarRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };

    let timer = std::time::Instant::now();
    let indexer = handle.indexer.read().await;

    // Resolve the seed chunk from the file (and optional function) in the request.
    let lookup = indexer
        .find_chunk_id(&req.file, req.function.as_deref())
        .await;
    let Some(seed_id) = lookup else {
        return Err(StatusCode::NOT_FOUND);
    };

    // Prefer the cached embedding; fall back to embedding the chunk text.
    let cached = indexer.get_embedding(&seed_id);
    let seed_vector = match cached {
        Some(vector) => vector,
        None => {
            let Some(text) = indexer.chunk_content_by_id(&seed_id).await else {
                return Err(StatusCode::NOT_FOUND);
            };
            match indexer.embed_text(&text).await {
                Ok(Some(vector)) => vector,
                Ok(None) => return Err(StatusCode::NOT_FOUND),
                Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
            }
        }
    };

    let search = indexer
        .similar_by_embedding(&seed_vector, req.top_k, Some(&seed_id))
        .await;
    let Ok(neighbors) = search else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    let elapsed_ms = timer.elapsed().as_millis() as u64;
    let body = serde_json::json!({
        "results": neighbors,
        "seed_chunk_id": seed_id,
        "latency_ms": elapsed_ms,
    });
    Ok(Json(body))
}