use crate::core::{
classifier::QueryClassifier,
embed::Embedder,
indexer::SearchQuery,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use axum::{
body::Body,
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Json, Redirect, Response},
routing::{delete, get, post},
Router,
};
use dashmap::DashMap;
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::{broadcast, watch, OnceCell, RwLock};
use tokio_stream::wrappers::BroadcastStream;
use trusty_common::{ChatProvider, LocalModelConfig};
use crate::service::reindex::{spawn_reindex_with_cleanup, ReindexProgress, ReindexStatus};
/// Events broadcast to `/status/stream` subscribers.
///
/// Serialized as internally tagged JSON: the variant name becomes a
/// snake_case `"type"` field, e.g. `{"type":"index_registered","id":"..."}`.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonEvent {
/// Periodic status snapshot sent by the background status ticker.
StatusChanged {
/// Number of registered indexes.
indexes: u64,
/// Sum of `chunk_count()` over all registered indexes.
total_chunks: u64,
/// Seconds since the app state was created.
uptime_secs: u64,
/// Crate version (`CARGO_PKG_VERSION`).
version: String,
},
/// An index was registered (e.g. via `POST /indexes`).
IndexRegistered { id: String },
/// An index was removed (e.g. via `DELETE /indexes/{id}`).
IndexRemoved { id: String },
}
/// Shared state for every HTTP handler; `build_router` wraps it in an `Arc`.
#[derive(Clone)]
pub struct SearchAppState {
/// Registered indexes.
pub registry: IndexRegistry,
/// Progress trackers of reindex runs per index; read by the reindex SSE stream.
pub reindex_progress: Arc<DashMap<IndexId, Arc<ReindexProgress>>>,
/// Per index, when a reindex last aborted at the memory limit; drives the reindex cooldown.
pub last_reindex_aborted_at: Arc<DashMap<IndexId, std::time::Instant>>,
/// Embedder supplied at construction time via `with_embedder`.
pub embedder: Option<Arc<dyn Embedder>>,
/// Embedder installed at runtime via `install_embedder`; read by `current_embedder`.
pub embedder_slot: Arc<RwLock<Option<Arc<dyn Embedder>>>>,
/// Flag that becomes `true` once an embedder has been provided.
pub embedder_ready: watch::Receiver<bool>,
/// Sender half of `embedder_ready`.
pub embedder_ready_tx: Arc<watch::Sender<bool>>,
/// Last embedder initialisation error; cleared by `install_embedder`.
pub embedder_error: Arc<RwLock<Option<String>>>,
/// Port set via `with_daemon_port` (presumably the listening port — confirm with caller).
pub daemon_port: Option<u16>,
/// `true` when a non-empty OpenRouter API key is configured.
pub openrouter_enabled: bool,
/// Creation time of this state; used to report uptime.
pub started_at: Instant,
/// Local chat model settings; tried first by `chat_provider`.
pub local_model: LocalModelConfig,
/// OpenRouter model id used when falling back to OpenRouter.
pub openrouter_model: String,
/// OpenRouter API key (empty when not configured).
pub openrouter_api_key: String,
/// Lazily resolved chat provider shared by all clones; resolved at most once.
pub chat_provider: Arc<OnceCell<Option<Arc<dyn ChatProvider>>>>,
/// Broadcast channel that feeds `/status/stream`.
pub events: Arc<broadcast::Sender<DaemonEvent>>,
}
impl SearchAppState {
    /// Creates state for `registry` with no embedder installed yet.
    ///
    /// OpenRouter is enabled when the `OPENROUTER_API_KEY` environment
    /// variable is set and non-empty. The event broadcast channel has
    /// capacity 128.
    pub fn new(registry: IndexRegistry) -> Self {
        let openrouter_api_key = std::env::var("OPENROUTER_API_KEY").unwrap_or_default();
        let (events_tx, _) = broadcast::channel::<DaemonEvent>(128);
        let (ready_tx, ready_rx) = watch::channel(false);
        Self {
            registry,
            reindex_progress: Arc::new(DashMap::new()),
            last_reindex_aborted_at: Arc::new(DashMap::new()),
            embedder: None,
            embedder_slot: Arc::new(RwLock::new(None)),
            embedder_ready: ready_rx,
            embedder_ready_tx: Arc::new(ready_tx),
            embedder_error: Arc::new(RwLock::new(None)),
            daemon_port: None,
            openrouter_enabled: !openrouter_api_key.is_empty(),
            started_at: Instant::now(),
            local_model: LocalModelConfig::default(),
            openrouter_model: "anthropic/claude-haiku-4.5".to_string(),
            openrouter_api_key,
            chat_provider: Arc::new(OnceCell::new()),
            events: Arc::new(events_tx),
        }
    }

    /// Broadcasts `event` to the current subscribers.
    pub fn emit(&self, event: DaemonEvent) {
        // `send` only fails when nobody is subscribed, which is not an error here.
        let _ = self.events.send(event);
    }

    /// Builder: sets the local chat model configuration.
    pub fn with_local_model(mut self, cfg: LocalModelConfig) -> Self {
        self.local_model = cfg;
        self
    }

    /// Builder: sets the OpenRouter model id.
    pub fn with_openrouter_model(mut self, model: impl Into<String>) -> Self {
        self.openrouter_model = model.into();
        self
    }

    /// Builder: sets the OpenRouter API key; an empty key disables OpenRouter.
    pub fn with_openrouter_api_key(mut self, api_key: impl Into<String>) -> Self {
        let api_key_str = api_key.into();
        self.openrouter_enabled = !api_key_str.is_empty();
        self.openrouter_api_key = api_key_str;
        self
    }

    /// Returns the chat provider, resolving it on first use.
    ///
    /// Resolution order: a detected local provider when `local_model.enabled`,
    /// then OpenRouter when an API key is set, otherwise `None`. The result,
    /// including `None`, is cached in the shared `OnceCell`, so a provider
    /// that becomes available later is not picked up.
    pub async fn chat_provider(&self) -> Option<Arc<dyn ChatProvider>> {
        self.chat_provider
            .get_or_init(|| async {
                if self.local_model.enabled {
                    if let Some(mut p) =
                        trusty_common::auto_detect_local_provider(&self.local_model.base_url).await
                    {
                        p.model = self.local_model.model.clone();
                        return Some(Arc::new(p) as Arc<dyn ChatProvider>);
                    }
                }
                if !self.openrouter_api_key.is_empty() {
                    return Some(Arc::new(trusty_common::OpenRouterProvider::new(
                        self.openrouter_api_key.clone(),
                        self.openrouter_model.clone(),
                    )) as Arc<dyn ChatProvider>);
                }
                None
            })
            .await
            .clone()
    }

    /// Builder: records the daemon port.
    pub fn with_daemon_port(mut self, port: u16) -> Self {
        self.daemon_port = Some(port);
        self
    }

    /// Builder: installs `embedder` at construction time and marks the embedder ready.
    pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
        self.embedder = Some(Arc::clone(&embedder));
        // The slot is normally uncontended while building. If another clone
        // happens to hold it, `current_embedder` still falls back to
        // `self.embedder`, so a busy slot only warrants a warning.
        match self.embedder_slot.try_write() {
            Ok(mut slot) => *slot = Some(embedder),
            Err(_) => tracing::warn!(
                "with_embedder: embedder slot busy; relying on the builder-time embedder"
            ),
        }
        let _ = self.embedder_ready_tx.send(true);
        self
    }

    /// Installs an embedder at runtime, clears any recorded init error and
    /// marks the embedder ready.
    pub async fn install_embedder(&self, embedder: Arc<dyn Embedder>) {
        {
            let mut slot = self.embedder_slot.write().await;
            *slot = Some(embedder);
        }
        {
            let mut err = self.embedder_error.write().await;
            *err = None;
        }
        let _ = self.embedder_ready_tx.send(true);
    }

    /// Records (and logs) an embedder initialisation failure.
    pub async fn install_embedder_error(&self, message: impl Into<String>) {
        let msg = message.into();
        tracing::error!("embedder init failed: {msg}");
        let mut err = self.embedder_error.write().await;
        *err = Some(msg);
    }

    /// Best-effort, non-blocking read of the last init error.
    ///
    /// Returns `None` when the lock is momentarily held by a writer.
    pub fn current_embedder_error(&self) -> Option<String> {
        self.embedder_error.try_read().ok().and_then(|g| g.clone())
    }

    /// Returns the active embedder: the runtime-installed one if present,
    /// otherwise the one supplied via `with_embedder`.
    pub async fn current_embedder(&self) -> Option<Arc<dyn Embedder>> {
        let installed = self.embedder_slot.read().await.clone();
        installed.or_else(|| self.embedder.clone())
    }

    /// Whether the ready flag has been set by `with_embedder` or `install_embedder`.
    pub fn is_embedder_ready(&self) -> bool {
        *self.embedder_ready.borrow()
    }
}
/// JSON body of `GET /health`.
#[derive(Serialize)]
struct HealthResponse {
/// Always `"ok"` when the handler responds.
status: &'static str,
/// Crate version.
version: &'static str,
/// Number of registered indexes.
indexes: usize,
/// Seconds since the app state was created.
uptime_secs: u64,
/// Embedder state: `"ready"`, `"error"` or `"initializing"`.
embedder: &'static str,
/// Last embedder init error; omitted from the JSON when there is none.
#[serde(skip_serializing_if = "Option::is_none")]
embedder_error: Option<String>,
}
/// JSON body of `GET /indexes`.
#[derive(Serialize)]
struct IndexListResponse {
/// Ids of all registered indexes.
indexes: Vec<String>,
}
/// JSON body of `POST /indexes`.
#[derive(Deserialize)]
pub struct CreateIndexRequest {
/// Index id; creation is refused when an index with this id already exists.
pub id: String,
/// Directory the index is rooted at.
pub root_path: std::path::PathBuf,
/// Sub-paths relative to `root_path`; blank entries and `"."` are ignored.
#[serde(default)]
pub include_paths: Option<Vec<String>>,
/// Exclusion globs, stored as given.
#[serde(default)]
pub exclude_globs: Option<Vec<String>>,
/// File extensions to index; a leading `.` is optional.
#[serde(default)]
pub extensions: Option<Vec<String>>,
/// Project-specific terms given to the indexer and the query classifier.
#[serde(default)]
pub domain_terms: Option<Vec<String>>,
/// Path filter entries; blank entries are dropped.
#[serde(default)]
pub path_filter: Option<Vec<String>>,
}
/// JSON body of `POST /indexes/{id}/index-file`.
#[derive(Deserialize)]
pub struct IndexFileRequest {
/// Path the content is indexed under.
pub path: String,
/// File content to index.
pub content: String,
}
/// JSON body of `POST /indexes/{id}/remove-file`.
#[derive(Deserialize)]
pub struct RemoveFileRequest {
/// Path whose chunks should be removed from the index.
pub path: String,
}
pub fn build_router(state: SearchAppState) -> Router {
use crate::service::ui::{
chat_handler, list_chat_providers, ui_asset_handler, ui_index_handler,
};
let state_arc = Arc::new(state);
spawn_status_ticker(Arc::clone(&state_arc));
let router = Router::new()
.route("/", get(|| async { Redirect::permanent("/ui/") }))
.route("/health", get(health_handler))
.route("/status/stream", get(status_stream_handler))
.route(
"/indexes",
get(list_indexes_handler).post(create_index_handler),
)
.route("/indexes/{id}", delete(delete_index_handler))
.route("/ui", get(|| async { Redirect::permanent("/ui/") }))
.route("/ui/", get(ui_index_handler))
.route("/ui/{*path}", get(ui_asset_handler))
.route("/chat", post(chat_handler))
.route("/api/chat/providers", get(list_chat_providers))
.route("/search", post(global_search_handler))
.route("/indexes/{id}/search", post(search_handler))
.route("/indexes/{id}/search_similar", post(search_similar_handler))
.route("/indexes/{id}/status", get(index_status_handler))
.route("/indexes/{id}/graph", get(graph_handler))
.route("/indexes/{id}/index-file", post(index_file_handler))
.route("/indexes/{id}/remove-file", post(remove_file_handler))
.route("/indexes/{id}/reindex", post(reindex_handler))
.route("/indexes/{id}/reindex/stream", get(reindex_stream_handler))
.route("/indexes/{id}/chunks", get(get_index_chunks_handler))
.route(
"/config",
get(get_config_handler).patch(patch_config_handler),
)
.with_state(Arc::clone(&state_arc));
trusty_common::server::with_standard_middleware(router)
}
fn spawn_status_ticker(state: Arc<SearchAppState>) {
let weak = Arc::downgrade(&state);
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(2));
interval.tick().await;
loop {
interval.tick().await;
let Some(state) = weak.upgrade() else {
break;
};
let (indexes, total_chunks) = collect_status_counts(&state).await;
state.emit(DaemonEvent::StatusChanged {
indexes: indexes as u64,
total_chunks: total_chunks as u64,
uptime_secs: state.started_at.elapsed().as_secs(),
version: env!("CARGO_PKG_VERSION").to_string(),
});
}
});
}
/// `GET /health`: liveness plus a coarse embedder state.
///
/// `embedder` is `"ready"` when the ready flag is set or an embedder is
/// present (builder-time or installed), `"error"` when an init failure was
/// recorded, and `"initializing"` otherwise.
async fn health_handler(State(state): State<Arc<SearchAppState>>) -> Json<HealthResponse> {
    let embedder_error = state.current_embedder_error();
    // Only consulted when the ready flag is not set (short-circuit below).
    let has_embedder = || {
        state.embedder.is_some()
            || state
                .embedder_slot
                .try_read()
                .map(|slot| slot.is_some())
                .unwrap_or(false)
    };
    let embedder = if state.is_embedder_ready() || has_embedder() {
        "ready"
    } else if embedder_error.is_some() {
        "error"
    } else {
        "initializing"
    };
    Json(HealthResponse {
        status: "ok",
        version: env!("CARGO_PKG_VERSION"),
        indexes: state.registry.list().len(),
        uptime_secs: state.started_at.elapsed().as_secs(),
        embedder,
        embedder_error,
    })
}
/// JSON body of `PATCH /config`.
///
/// Each field is tri-state: absent → `None` (leave unchanged, via
/// `#[serde(default)]`), `null` → `Some(None)` (reported as "unlimited"),
/// a number → `Some(Some(mb))`.
#[derive(Debug, Deserialize, Default)]
struct PatchConfigRequest {
#[serde(default, deserialize_with = "deserialize_optional_option_u64")]
memory_limit_mb: Option<Option<u64>>,
#[serde(default, deserialize_with = "deserialize_optional_option_u64")]
index_memory_limit_mb: Option<Option<u64>>,
}
/// JSON body of `GET`/`PATCH /config`: current memory limits in MB
/// (`null` when no limit is set).
#[derive(Debug, Serialize)]
struct ConfigResponse {
memory_limit_mb: Option<u64>,
index_memory_limit_mb: Option<u64>,
}
/// Deserializes a present field as `Some(inner)`, so JSON `null` becomes
/// `Some(None)`.
///
/// Paired with `#[serde(default)]`, this lets callers tell an absent field
/// (`None`, from the default) apart from an explicit `null`.
fn deserialize_optional_option_u64<'de, D>(deserializer: D) -> Result<Option<Option<u64>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    Option::<u64>::deserialize(deserializer).map(Some)
}
async fn get_config_handler(State(_state): State<Arc<SearchAppState>>) -> Json<ConfigResponse> {
use crate::core::memguard::{index_memory_limit_mb, memory_limit_mb};
Json(ConfigResponse {
memory_limit_mb: memory_limit_mb(),
index_memory_limit_mb: index_memory_limit_mb(),
})
}
/// `PATCH /config`: updates whichever memory limits are present in the body.
///
/// An omitted field is left unchanged and an explicit `null` sets the limit
/// to `None` ("unlimited"). Each applied change is logged with its
/// before/after value. The response reports the limits as they stand afterwards.
async fn patch_config_handler(
    State(_state): State<Arc<SearchAppState>>,
    Json(req): Json<PatchConfigRequest>,
) -> Json<ConfigResponse> {
    use crate::core::memguard::{
        index_memory_limit_mb, memory_limit_mb, set_index_memory_limit_mb, set_memory_limit_mb,
    };

    /// Human-readable form of a limit for the audit log.
    fn describe(limit: Option<u64>) -> String {
        limit.map_or_else(|| "unlimited".to_string(), |mb| mb.to_string())
    }

    if let Some(requested) = req.memory_limit_mb {
        let previous = memory_limit_mb();
        set_memory_limit_mb(requested);
        let current = memory_limit_mb();
        tracing::info!(
            "config updated: memory_limit_mb {} → {}",
            describe(previous),
            describe(current)
        );
    }

    if let Some(requested) = req.index_memory_limit_mb {
        let previous = index_memory_limit_mb();
        set_index_memory_limit_mb(requested);
        let current = index_memory_limit_mb();
        tracing::info!(
            "config updated: index_memory_limit_mb {} → {}",
            describe(previous),
            describe(current)
        );
    }

    let response = ConfigResponse {
        memory_limit_mb: memory_limit_mb(),
        index_memory_limit_mb: index_memory_limit_mb(),
    };
    Json(response)
}
/// Returns `(index_count, total_chunks)` across the registry.
///
/// An index that disappears between listing and lookup still counts towards
/// `index_count` but contributes no chunks. Each indexer is read-locked in
/// turn, so this waits behind any writer holding an index.
async fn collect_status_counts(state: &SearchAppState) -> (usize, usize) {
    let registered = state.registry.list();
    let mut chunks: usize = 0;
    for handle in registered.iter().filter_map(|id| state.registry.get(id)) {
        let indexer = handle.indexer.read().await;
        chunks = chunks.saturating_add(indexer.chunk_count());
    }
    (registered.len(), chunks)
}
async fn status_stream_handler(State(state): State<Arc<SearchAppState>>) -> impl IntoResponse {
let rx = state.events.subscribe();
let initial = stream::once(async {
Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(
"data: {\"type\":\"connected\"}\n\n",
))
});
let events = BroadcastStream::new(rx).map(|res| {
let frame = match res {
Ok(event) => match serde_json::to_string(&event) {
Ok(json) => format!("data: {json}\n\n"),
Err(e) => format!("data: {{\"type\":\"error\",\"message\":\"{e}\"}}\n\n"),
},
Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n")
}
};
Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(frame))
});
let stream = initial.chain(events);
Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("X-Accel-Buffering", "no")
.body(Body::from_stream(stream))
.expect("valid SSE response")
}
/// `GET /indexes`: ids of every registered index.
async fn list_indexes_handler(State(state): State<Arc<SearchAppState>>) -> Json<IndexListResponse> {
    let registered = state.registry.list();
    let mut indexes = Vec::with_capacity(registered.len());
    for index_id in registered {
        indexes.push(index_id.0);
    }
    Json(IndexListResponse { indexes })
}
/// `POST /indexes`: registers a new index.
///
/// Responds `{"created": false, "reason": "already exists"}` when the id is
/// taken, and 503 while the embedder is unavailable (initializing or failed).
/// Otherwise it builds the indexer (restoring persisted state), persists the
/// registry entry (best-effort: failures are logged), registers the handle
/// and broadcasts `IndexRegistered`.
async fn create_index_handler(
    State(state): State<Arc<SearchAppState>>,
    Json(req): Json<CreateIndexRequest>,
) -> Response {
    let id = IndexId::new(req.id.clone());
    // NOTE(review): check-then-register is not atomic; two concurrent
    // requests for the same id can both get past this check.
    if state.registry.get(&id).is_some() {
        return Json(serde_json::json!({
            "id": req.id,
            "created": false,
            "reason": "already exists",
        }))
        .into_response();
    }
    let Some(embedder) = state.current_embedder().await else {
        return match state.current_embedder_error() {
            Some(err) => embedder_error_response(&err),
            None => embedder_initializing_response(),
        };
    };

    let mut indexer = crate::service::persistence_loader::build_indexer_with_persisted_state(
        &req.id,
        req.root_path.clone(),
        &embedder,
    )
    .await;

    let include_paths = normalize_include_paths(&req.root_path, req.include_paths.as_deref());
    let exclude_globs: Vec<String> = req.exclude_globs.clone().unwrap_or_default();
    let extensions = normalize_extensions(req.extensions.as_deref());
    let domain_terms: Vec<String> = req.domain_terms.clone().unwrap_or_default();
    let path_filter = normalize_path_filter(req.path_filter.as_deref());
    indexer.set_domain_terms(domain_terms.clone());

    // The registry file keeps the include paths as supplied (not joined to
    // the root); the in-memory handle below uses the resolved paths.
    if let Err(e) = crate::service::persistence::upsert_index_registry_entry(
        crate::service::persistence::PersistedIndex {
            id: req.id.clone(),
            root_path: req.root_path.clone(),
            include_paths: req.include_paths.clone().unwrap_or_default(),
            exclude_globs: exclude_globs.clone(),
            extensions: extensions.clone(),
            domain_terms: domain_terms.clone(),
            path_filter: path_filter.clone(),
        },
    ) {
        tracing::warn!("could not persist index registry for {}: {e}", req.id);
    }

    let handle = IndexHandle {
        id: id.clone(),
        indexer: Arc::new(tokio::sync::RwLock::new(indexer)),
        root_path: req.root_path,
        include_paths,
        exclude_globs,
        extensions,
        domain_terms,
        path_filter,
        context_embedding: Arc::new(tokio::sync::RwLock::new(None)),
        context_summary: Arc::new(tokio::sync::RwLock::new(None)),
    };
    state.registry.register(handle);
    state.emit(DaemonEvent::IndexRegistered { id: req.id.clone() });
    Json(serde_json::json!({ "id": req.id, "created": true })).into_response()
}

/// Resolves requested include paths against `root`.
///
/// Entries are trimmed; blank entries and `"."` (the whole root) are skipped.
fn normalize_include_paths(
    root: &std::path::Path,
    requested: Option<&[String]>,
) -> Vec<std::path::PathBuf> {
    requested
        .unwrap_or_default()
        .iter()
        .map(|p| p.trim())
        .filter(|p| !p.is_empty() && *p != ".")
        .map(|p| root.join(p))
        .collect()
}

/// Normalises extensions to their bare form (`" .rs"` → `"rs"`), dropping empties.
fn normalize_extensions(requested: Option<&[String]>) -> Vec<String> {
    requested
        .unwrap_or_default()
        .iter()
        .map(|e| e.trim().trim_start_matches('.'))
        .filter(|e| !e.is_empty())
        .map(str::to_string)
        .collect()
}

/// Trims path-filter entries and drops blank ones.
///
/// Entries were previously stored untrimmed, so `" src/"` could never match.
fn normalize_path_filter(requested: Option<&[String]>) -> Vec<String> {
    requested
        .unwrap_or_default()
        .iter()
        .map(|p| p.trim())
        .filter(|p| !p.is_empty())
        .map(str::to_string)
        .collect()
}
/// 503 response returned while the embedder is still loading.
fn embedder_initializing_response() -> Response {
    let body = serde_json::json!({
        "error": "embedder initializing, retry in a few seconds"
    });
    (StatusCode::SERVICE_UNAVAILABLE, Json(body)).into_response()
}
fn embedder_error_response(message: &str) -> Response {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": format!("embedder init failed: {message}"),
})),
)
.into_response()
}
/// `DELETE /indexes/{id}`: unregisters an index and removes its persisted state.
///
/// Any reindex progress tracker for the id is dropped unconditionally. The
/// `indexes.toml` entry and on-disk data are removed, and `IndexRemoved` is
/// broadcast, only when the index was actually registered. Removal failures
/// are logged, not returned.
async fn delete_index_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
) -> Json<serde_json::Value> {
    let index_id = IndexId::new(id.clone());
    let removed = state.registry.unregister(&index_id);
    state.reindex_progress.remove(&index_id);
    if !removed {
        return Json(serde_json::json!({ "id": id, "removed": false }));
    }

    if let Err(e) = crate::service::persistence::remove_index_registry_entry(&id) {
        tracing::warn!("could not remove '{id}' from indexes.toml: {e}");
    }
    if let Err(e) = crate::service::persistence::remove_index_data_dir(&id) {
        tracing::warn!("could not remove on-disk data for '{id}': {e}");
    }
    state.emit(DaemonEvent::IndexRemoved { id: id.clone() });
    Json(serde_json::json!({ "id": id, "removed": true }))
}
/// `POST /indexes/{id}/search`: runs `query` against a single index.
///
/// Responds with the results, the classified query intent and the search
/// latency. Returns 404 for an unknown index. Returns 500 when the search
/// fails; the underlying error is logged.
async fn search_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(query): Json<SearchQuery>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
    let intent = QueryClassifier::classify_with_domain(&query.text, &handle.domain_terms);
    let started = std::time::Instant::now();
    let indexer = handle.indexer.read().await;
    let results = indexer.search(&query).await.map_err(|e| {
        tracing::warn!("search on index {index_id} failed: {e}");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    let latency_ms = started.elapsed().as_millis() as u64;
    tracing::info!(
        index_id = %index_id,
        intent = %format!("{intent:?}"),
        latency_ms = latency_ms,
        results = results.len(),
        query = %query_log_preview(&query.text, 80),
        "search"
    );
    Ok(Json(serde_json::json!({
        "results": results,
        "intent": format!("{:?}", intent),
        "latency_ms": latency_ms,
    })))
}

/// Returns at most the first `max_chars` characters of `text`.
///
/// Always cuts on a UTF-8 character boundary; byte slicing such as
/// `&text[..80]` panics when byte 80 falls inside a multi-byte character.
fn query_log_preview(text: &str, max_chars: usize) -> &str {
    match text.char_indices().nth(max_chars) {
        Some((end, _)) => &text[..end],
        None => text,
    }
}
/// JSON body of `POST /search`.
#[derive(Deserialize)]
pub struct GlobalSearchRequest {
/// Query text.
pub query: String,
/// Number of fused results to return (default 10).
#[serde(default = "default_global_top_k")]
pub top_k: usize,
/// Return full chunk content instead of compact results.
#[serde(default)]
pub full_content: bool,
/// Restrict the search to these index ids; `None` searches every registered index.
#[serde(default)]
pub indexes: Option<Vec<String>>,
/// Routing mode: `"top_n"` or `"threshold"`; anything else (or absent) searches every candidate.
#[serde(default)]
pub routing: Option<String>,
/// Number of indexes kept by `"top_n"` routing (default 3, minimum 1).
#[serde(default)]
pub routing_n: Option<usize>,
/// Minimum context weight kept by `"threshold"` routing (default 0.3).
#[serde(default)]
pub routing_threshold: Option<f32>,
}
/// Serde default for `GlobalSearchRequest::top_k`.
fn default_global_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
/// `POST /search`: searches several indexes concurrently and fuses the results.
///
/// Flow:
/// 1. Pick the candidate indexes: all registered ones, or those named in `indexes`.
/// 2. Score each candidate against the query with `compute_context_weights`.
/// 3. Let the routing mode choose which indexes are searched and with what weight.
/// 4. Run the per-index searches concurrently.
/// 5. Merge the per-index result lists with reciprocal rank fusion (`rrf_fuse`) and keep `top_k`.
///
/// During fusion, chunk ids are namespaced as `"<index_id>::<chunk_id>"` so
/// equal ids from different indexes do not collide. Each returned chunk has
/// `index_id` set and `score` replaced by its fused score. An index whose
/// search fails is logged and skipped rather than failing the whole request.
async fn global_search_handler(
State(state): State<Arc<SearchAppState>>,
Json(req): Json<GlobalSearchRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
use crate::core::search::rrf::{rrf_fuse, RRF_K};
let all_ids = state.registry.list();
// Restrict to the requested indexes; names that are not registered are silently ignored.
let index_ids: Vec<IndexId> = if let Some(requested) = req.indexes.as_ref() {
let allow: std::collections::HashSet<&str> = requested.iter().map(|s| s.as_str()).collect();
all_ids
.into_iter()
.filter(|id| allow.contains(id.0.as_str()))
.collect()
} else {
all_ids
};
let total_indexes = index_ids.len();
// NOTE(review): this early return omits the "routing" and "routing_decisions"
// keys that the normal response carries.
if index_ids.is_empty() {
return Ok(Json(serde_json::json!({
"results": Vec::<crate::core::indexer::CodeChunk>::new(),
"indexes_searched": Vec::<String>::new(),
"total_indexes": 0_usize,
"latency_ms": 0_u64,
"intent": format!("{:?}", QueryClassifier::classify(&req.query)),
})));
}
let started = std::time::Instant::now();
let intent = QueryClassifier::classify(&req.query);
let routing_mode = RoutingMode::from_request(&req);
let weights = compute_context_weights(&state.registry, &index_ids, &req.query).await;
let (active_ids, weight_map) = routing_mode.apply(&index_ids, &weights);
let routing_label = routing_mode.label().to_string();
// Reported for every candidate, including indexes the routing mode excluded;
// the similarity falls back to 1.0 when no weight was computed.
let routing_decisions: Vec<serde_json::Value> = index_ids
.iter()
.map(|id| {
let w = weights.get(id).copied().unwrap_or(1.0);
let included = weight_map.contains_key(id);
serde_json::json!({
"index_id": id.0,
"cosine_similarity": w,
"included": included,
})
})
.collect();
// Same query for every index: graph expansion on, compact output unless
// full content was requested.
let per_index_query = SearchQuery {
text: req.query.clone(),
top_k: req.top_k,
expand_graph: true,
compact: !req.full_content,
branch_files: None,
branch_boost: SearchQuery::default_branch_boost(),
branch: None,
};
let registry = state.registry.clone();
// One future per active index, awaited together by join_all below. Each
// takes its own read lock; a missing or failing index yields None and is
// dropped by flatten().
let futures = active_ids.into_iter().map(|id| {
let registry = registry.clone();
let query = per_index_query.clone();
async move {
let handle = registry.get(&id)?;
let indexer = handle.indexer.read().await;
match indexer.search(&query).await {
Ok(results) => Some((id, results)),
Err(e) => {
tracing::warn!("global search: index {} errored: {e}", id);
None
}
}
}
});
let per_index_results: Vec<(IndexId, Vec<crate::core::indexer::CodeChunk>)> =
futures::future::join_all(futures)
.await
.into_iter()
.flatten()
.collect();
let mut chunk_lookup: std::collections::HashMap<String, crate::core::indexer::CodeChunk> =
std::collections::HashMap::new();
let mut lanes: Vec<Vec<(String, f32)>> = Vec::with_capacity(per_index_results.len());
let mut indexes_searched: Vec<String> = Vec::with_capacity(per_index_results.len());
// Build one ranked "lane" per index.
for (id, results) in per_index_results {
indexes_searched.push(id.0.clone());
let weight = weight_map.get(&id).copied().unwrap_or(1.0);
let mut lane: Vec<(String, f32)> = Vec::with_capacity(results.len());
for mut chunk in results {
let namespaced = format!("{}::{}", id.0, chunk.id);
chunk.index_id = Some(id.0.clone());
// NOTE(review): every chunk in a lane gets the same weight, so this
// scaling cannot change the order within the lane. It only affects
// fusion if rrf_fuse uses the scores and not just the ranks — confirm.
let weighted_score = chunk.score * weight;
lane.push((namespaced.clone(), weighted_score));
chunk_lookup.insert(namespaced, chunk);
}
lanes.push(lane);
}
let mut fused: Vec<(String, f32)> = Vec::new();
// Intermediate cap of roughly 4x top_k (at least 10), presumably the
// output-length limit of rrf_fuse — confirm.
let oversample = req.top_k.saturating_mul(4).max(req.top_k).max(10);
// NOTE(review): lanes are fused pairwise into a running list. Unless
// rrf_fuse treats its first argument specially, earlier lanes are re-ranked
// at every step, so the outcome may depend on lane order, and the cap
// above can drop tail items at each step.
for lane in lanes {
fused = rrf_fuse(&fused, &lane, 1.0, 1.0, RRF_K, oversample);
}
fused.truncate(req.top_k);
// Materialise the fused ids back into chunks carrying the fused score.
let results: Vec<crate::core::indexer::CodeChunk> = fused
.into_iter()
.filter_map(|(id, fused_score)| {
let mut chunk = chunk_lookup.remove(&id)?;
chunk.score = fused_score;
Some(chunk)
})
.collect();
let latency_ms = started.elapsed().as_millis() as u64;
Ok(Json(serde_json::json!({
"results": results,
"indexes_searched": indexes_searched,
"total_indexes": total_indexes,
"latency_ms": latency_ms,
"intent": format!("{:?}", intent),
"routing": routing_label,
"routing_decisions": routing_decisions,
})))
}
/// How `POST /search` chooses which indexes to query.
#[derive(Debug, Clone, Copy)]
enum RoutingMode {
/// Search every candidate, weighting each by its context similarity.
All,
/// Search only the N highest-weighted candidates, with equal weight.
TopN(usize),
/// Search candidates whose weight is at least the threshold, with equal weight.
Threshold(f32),
}
impl RoutingMode {
    /// Indexes searched by `top_n` routing when `routing_n` is not given.
    const DEFAULT_TOP_N: usize = 3;
    /// Minimum weight for `threshold` routing when `routing_threshold` is not given.
    const DEFAULT_THRESHOLD: f32 = 0.3;

    /// Chooses the mode from the request.
    ///
    /// A missing or unrecognised `routing` value means "search everything".
    /// `top_n` always keeps at least one index.
    fn from_request(req: &GlobalSearchRequest) -> Self {
        let Some(requested) = req.routing.as_deref() else {
            return Self::All;
        };
        match requested {
            "top_n" => {
                let n = req.routing_n.unwrap_or(Self::DEFAULT_TOP_N);
                Self::TopN(n.max(1))
            }
            "threshold" => {
                let threshold = req.routing_threshold.unwrap_or(Self::DEFAULT_THRESHOLD);
                Self::Threshold(threshold)
            }
            _ => Self::All,
        }
    }

    /// Name reported in the `routing` field of the search response.
    fn label(self) -> &'static str {
        match self {
            Self::TopN(_) => "top_n",
            Self::Threshold(_) => "threshold",
            Self::All => "all",
        }
    }

    /// Returns the indexes to search and the weight applied to each one's scores.
    ///
    /// Indexes without a computed weight count as `1.0`. `All` keeps every
    /// index with its own weight. `TopN` and `Threshold` keep a subset and
    /// weight each survivor `1.0`.
    fn apply(
        self,
        index_ids: &[IndexId],
        weights: &std::collections::HashMap<IndexId, f32>,
    ) -> (Vec<IndexId>, std::collections::HashMap<IndexId, f32>) {
        let weight_of = |id: &IndexId| weights.get(id).copied().unwrap_or(1.0);
        let uniform = |active: &[IndexId]| -> std::collections::HashMap<IndexId, f32> {
            active.iter().cloned().map(|id| (id, 1.0)).collect()
        };
        match self {
            Self::All => {
                let map: std::collections::HashMap<IndexId, f32> = index_ids
                    .iter()
                    .map(|id| (id.clone(), weight_of(id)))
                    .collect();
                (index_ids.to_vec(), map)
            }
            Self::TopN(n) => {
                let mut ranked: Vec<(&IndexId, f32)> =
                    index_ids.iter().map(|id| (id, weight_of(id))).collect();
                // Stable sort, highest weight first: ties keep their input order.
                ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                let active: Vec<IndexId> = ranked
                    .into_iter()
                    .take(n)
                    .map(|(id, _)| id.clone())
                    .collect();
                let map = uniform(&active);
                (active, map)
            }
            Self::Threshold(t) => {
                let active: Vec<IndexId> = index_ids
                    .iter()
                    .filter(|id| weight_of(*id) >= t)
                    .cloned()
                    .collect();
                let map = uniform(&active);
                (active, map)
            }
        }
    }
}
/// Scores each index by how well its context embedding matches `query`.
///
/// The query is embedded by the first index (in `index_ids` order) whose
/// indexer produces an embedding. An index's weight is the cosine similarity
/// between that query vector and its context embedding, clamped at 0. An
/// index gets weight `1.0` when it is missing from the registry, has no
/// context embedding, or has an embedding of a different length. If no index
/// can embed the query, every index gets `1.0`.
async fn compute_context_weights(
    registry: &crate::core::registry::IndexRegistry,
    index_ids: &[IndexId],
    query: &str,
) -> std::collections::HashMap<IndexId, f32> {
    use crate::core::mmr::cosine_similarity;

    let neutral = || -> std::collections::HashMap<IndexId, f32> {
        index_ids.iter().map(|id| (id.clone(), 1.0)).collect()
    };

    let mut query_vec: Option<Vec<f32>> = None;
    for id in index_ids {
        let Some(handle) = registry.get(id) else {
            continue;
        };
        let indexer = handle.indexer.read().await;
        match indexer.embed_text(query).await {
            Ok(Some(vec)) => {
                query_vec = Some(vec);
                break;
            }
            Ok(None) => {}
            Err(e) => {
                tracing::debug!("context_routing: embed_text failed on {}: {e}", id.0);
            }
        }
    }
    let Some(q) = query_vec else {
        return neutral();
    };

    let mut scored = std::collections::HashMap::with_capacity(index_ids.len());
    for id in index_ids {
        let weight = match registry.get(id) {
            None => 1.0,
            Some(handle) => {
                let context = handle.context_embedding.read().await;
                match context.as_ref() {
                    Some(ctx) if ctx.len() == q.len() => cosine_similarity(&q, ctx).max(0.0),
                    _ => 1.0,
                }
            }
        };
        scored.insert(id.clone(), weight);
    }
    scored
}
/// JSON body of `POST /indexes/{id}/search_similar`.
#[derive(Deserialize)]
pub struct SearchSimilarRequest {
/// File containing the seed chunk.
pub file: String,
/// Optional function name, used with `file` to locate the seed chunk.
#[serde(default)]
pub function: Option<String>,
/// Number of similar chunks to return (default 10).
#[serde(default = "default_similar_top_k")]
pub top_k: usize,
}
/// Serde default for `SearchSimilarRequest::top_k`.
fn default_similar_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
/// `POST /indexes/{id}/search_similar`: finds chunks similar to a seed chunk.
///
/// The seed is located by `file` (and optionally `function`), and its stored
/// embedding is used as the query. The seed id is also passed to
/// `similar_by_embedding` (presumably to exclude it from the results).
/// Returns 404 when the index, the seed chunk or its embedding is missing,
/// and 500 when the similarity search fails.
async fn search_similar_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<SearchSimilarRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let started = std::time::Instant::now();
    let indexer = handle.indexer.read().await;
    let Some(seed) = indexer
        .find_chunk_id(&req.file, req.function.as_deref())
        .await
    else {
        return Err(StatusCode::NOT_FOUND);
    };
    let Some(seed_embedding) = indexer.get_embedding(&seed) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let results = match indexer
        .similar_by_embedding(&seed_embedding, req.top_k, Some(&seed))
        .await
    {
        Ok(found) => found,
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };
    let latency_ms = started.elapsed().as_millis() as u64;
    Ok(Json(serde_json::json!({
        "results": results,
        "seed_chunk_id": seed,
        "latency_ms": latency_ms,
    })))
}
/// `GET /indexes/{id}/status`: summary of a single index.
///
/// Reports the root path, the chunk count, the path filter (`null` when
/// empty), whether a context embedding exists and the context summary
/// (`null` when absent). Returns 404 for an unknown index.
async fn index_status_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let indexer = handle.indexer.read().await;
    let path_filter = if handle.path_filter.is_empty() {
        serde_json::Value::Null
    } else {
        serde_json::json!(handle.path_filter)
    };
    let has_context_embedding = handle.context_embedding.read().await.is_some();
    let summary: Option<String> = handle.context_summary.read().await.clone();
    Ok(Json(serde_json::json!({
        "index_id": index_id.0,
        "root_path": handle.root_path,
        "chunk_count": indexer.chunk_count(),
        "path_filter": path_filter,
        "has_context_embedding": has_context_embedding,
        "context_summary": summary,
    })))
}
/// Query parameters of `GET /indexes/{id}/graph`; all optional.
#[derive(Debug, Default, serde::Deserialize)]
struct GraphQueryParams {
/// Comma-separated node types to keep (`file`, `symbol`; case-insensitive).
types: Option<String>,
/// Comma-separated edge kinds to keep, matched case-insensitively against the kind's `Debug` name.
edge_types: Option<String>,
/// Drop edges whose weight (`score_multiplier()`) is below this value.
min_weight: Option<f32>,
}
/// Parses a comma-separated filter list into a set of lowercase entries.
///
/// Entries are trimmed and ASCII-lowercased, and empty entries are ignored.
/// Returns `None` ("no filter") when `raw` is absent or yields no entries.
fn parse_filter_set(raw: Option<&str>) -> Option<std::collections::HashSet<String>> {
    let mut wanted = std::collections::HashSet::new();
    for part in raw?.split(',') {
        let normalized = part.trim().to_ascii_lowercase();
        if !normalized.is_empty() {
            wanted.insert(normalized);
        }
    }
    (!wanted.is_empty()).then_some(wanted)
}
/// Classifies a graph symbol as `"File"` or `"Symbol"`.
///
/// A symbol counts as a file path when it contains a `/` and its last
/// component has a non-empty extension (e.g. `src/main.rs`).
fn node_type_for_symbol(symbol: &str) -> &'static str {
    if !symbol.contains('/') {
        return "Symbol";
    }
    match std::path::Path::new(symbol).extension() {
        Some(ext) if !ext.is_empty() => "File",
        _ => "Symbol",
    }
}
/// `GET /indexes/{id}/graph`: exports the index's symbol graph as JSON.
///
/// Optional filters (see `GraphQueryParams`):
/// - `types` keeps matching node types. Edges are then kept only when both
///   endpoints survived the node filter.
/// - `edge_types` keeps matching edge kinds.
/// - `min_weight` drops lighter edges.
///
/// `stats` reports the counts of the full graph snapshot, not of the
/// filtered view. Returns 404 for an unknown index.
///
/// NOTE(review): nodes use the chunk id as `"id"`, while edges reference
/// symbols in `"source"`/`"target"`. A client that links edges to nodes by
/// `id` will not resolve them — confirm what the UI expects.
/// NOTE(review): the response may be cached for an hour (`max-age=3600`), so
/// clients can see a stale graph after a reindex.
async fn graph_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
Query(params): Query<GraphQueryParams>,
) -> Result<Response, StatusCode> {
let index_id = IndexId::new(id);
let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
// Take a snapshot and release the indexer lock before building the response.
let graph = {
let indexer = handle.indexer.read().await;
indexer.snapshot_symbol_graph().await
};
let type_filter = parse_filter_set(params.types.as_deref());
let edge_filter = parse_filter_set(params.edge_types.as_deref());
// No min_weight means no edge is dropped for its weight.
let min_weight = params.min_weight.unwrap_or(f32::MIN);
// Symbols of the kept nodes, used to prune edges when a type filter is active.
let mut kept_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut nodes: Vec<serde_json::Value> = Vec::new();
for (symbol, chunk_id, file) in graph.all_nodes() {
let node_type = node_type_for_symbol(&symbol);
if let Some(ref filter) = type_filter {
if !filter.contains(&node_type.to_ascii_lowercase()) {
continue;
}
}
kept_symbols.insert(symbol.clone());
nodes.push(serde_json::json!({
"id": chunk_id,
"type": node_type,
"label": symbol,
"metadata": { "file": file, "symbol": symbol },
}));
}
let mut edges: Vec<serde_json::Value> = Vec::new();
for (source, target, kind) in graph.all_edges() {
if type_filter.is_some()
&& (!kept_symbols.contains(&source) || !kept_symbols.contains(&target))
{
continue;
}
// The edge kind's Debug name is both the filter key and the reported type.
let kind_name = format!("{kind:?}");
if let Some(ref filter) = edge_filter {
if !filter.contains(&kind_name.to_ascii_lowercase()) {
continue;
}
}
let weight = kind.score_multiplier();
if weight < min_weight {
continue;
}
edges.push(serde_json::json!({
"source": source,
"target": target,
"type": kind_name,
"weight": weight,
}));
}
let body = serde_json::json!({
"nodes": nodes,
"edges": edges,
"stats": {
"node_count": graph.node_count(),
"edge_count": graph.edge_count(),
},
"generated_at": chrono::Utc::now().to_rfc3339(),
});
let mut response = Json(body).into_response();
response.headers_mut().insert(
axum::http::header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("max-age=3600"),
);
Ok(response)
}
/// `POST /indexes/{id}/index-file`: indexes one file from the supplied content.
///
/// Returns 404 for an unknown index and 500 when indexing fails.
async fn index_file_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<IndexFileRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let indexer = handle.indexer.read().await;
    if indexer.index_file(&req.path, &req.content).await.is_err() {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }
    Ok(Json(serde_json::json!({
        "index_id": index_id.0,
        "path": req.path,
        "indexed": true,
    })))
}
/// `POST /indexes/{id}/remove-file`: removes a file's chunks from the index.
///
/// Responds with the value returned by `remove_file` as `removed_chunks`.
/// Returns 404 for an unknown index and 500 when removal fails.
async fn remove_file_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<RemoveFileRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let indexer = handle.indexer.read().await;
    let removed_chunks = match indexer.remove_file(&req.path).await {
        Ok(count) => count,
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };
    Ok(Json(serde_json::json!({
        "index_id": index_id.0,
        "path": req.path,
        "removed_chunks": removed_chunks,
    })))
}
/// Query parameters of `GET /indexes/{id}/chunks`.
#[derive(Deserialize)]
pub struct ChunksParams {
/// Number of chunks to skip (default 0).
#[serde(default)]
pub offset: usize,
/// Page size (default 100; the handler caps it at `MAX_CHUNKS_LIMIT`).
#[serde(default = "default_chunks_limit")]
pub limit: usize,
}
/// Page size used when `limit` is omitted from `GET /indexes/{id}/chunks`.
fn default_chunks_limit() -> usize {
    const DEFAULT_PAGE: usize = 100;
    DEFAULT_PAGE
}
/// Upper bound applied to any requested `limit` on `GET /indexes/{id}/chunks`.
const MAX_CHUNKS_LIMIT: usize = 1000;
async fn get_index_chunks_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
Query(params): Query<ChunksParams>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let index_id = IndexId::new(id);
let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
let limit = params.limit.min(MAX_CHUNKS_LIMIT);
let indexer = handle.indexer.read().await;
let (total, chunks) = indexer.enumerate_chunks(params.offset, limit).await;
Ok(Json(serde_json::json!({
"index_id": index_id.0,
"total": total,
"offset": params.offset,
"limit": limit,
"chunks": chunks,
})))
}
/// Optional JSON body of `POST /indexes/{id}/reindex`.
#[derive(Deserialize, Default)]
pub struct ReindexRequest {
/// New root directory. The index handle is re-registered with it when it
/// differs from the current root (or the current root is empty).
#[serde(default)]
pub root_path: Option<std::path::PathBuf>,
/// Passed to the reindex task as its `force` flag (default `false`).
#[serde(default)]
pub force: Option<bool>,
}
/// `POST /indexes/{id}/reindex`: queues a background reindex.
///
/// The optional JSON body may set `root_path`, which re-roots the index
/// before reindexing, and `force`, which is passed to the reindex task.
///
/// Error responses:
/// - 404 for an unknown index.
/// - 429 while the cooldown after a memory-limit abort is still running
///   (`TRUSTY_REINDEX_COOLDOWN_SECS`, default 300s).
///
/// Progress can be followed at the returned `stream_url`.
async fn reindex_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    body: Option<Json<ReindexRequest>>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let index_id = IndexId::new(id);
    let mut handle = state.registry.get(&index_id).ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": format!("unknown index: {}", index_id.0),
            })),
        )
    })?;

    // Copy the timestamp out so the DashMap guard is released before the
    // entry is removed below; removing while holding it would deadlock.
    let last_abort = state
        .last_reindex_aborted_at
        .get(&index_id)
        .map(|entry| *entry.value());
    if let Some(aborted_at) = last_abort {
        let elapsed = aborted_at.elapsed();
        let cooldown = reindex_cooldown();
        if elapsed < cooldown {
            let remaining_secs = (cooldown - elapsed).as_secs();
            tracing::warn!(
                "reindex_handler: refusing reindex for index {} — last run \
                 aborted at memory limit {}s ago, cooldown {}s remaining",
                index_id.0,
                elapsed.as_secs(),
                remaining_secs,
            );
            return Err((
                StatusCode::TOO_MANY_REQUESTS,
                Json(serde_json::json!({
                    "error": "reindex cooldown active after memory-limit abort",
                    "index_id": index_id.0,
                    "retry_after_secs": remaining_secs,
                    "cooldown_secs": cooldown.as_secs(),
                    "hint": "lower TRUSTY_MAX_BATCH_SIZE or raise TRUSTY_MEMORY_LIMIT_MB before retrying",
                })),
            ));
        }
        state.last_reindex_aborted_at.remove(&index_id);
    }

    let mut force = false;
    if let Some(Json(req)) = body {
        force = req.force.unwrap_or(false);
        if let Some(new_root) = req.root_path {
            if handle.root_path.as_os_str().is_empty() || handle.root_path != new_root {
                // Include paths created by `create_index_handler` are already
                // joined to the old root. Carry them over to the new root so
                // the reindex does not keep walking the old directory tree.
                let include_paths =
                    rebase_include_paths(&handle.include_paths, &handle.root_path, &new_root);
                // NOTE(review): the new root is not written to the persisted
                // registry here — confirm whether it should survive a restart.
                let new_handle = IndexHandle {
                    id: index_id.clone(),
                    indexer: Arc::clone(&handle.indexer),
                    root_path: new_root,
                    include_paths,
                    exclude_globs: handle.exclude_globs.clone(),
                    extensions: handle.extensions.clone(),
                    domain_terms: handle.domain_terms.clone(),
                    path_filter: handle.path_filter.clone(),
                    context_embedding: Arc::clone(&handle.context_embedding),
                    context_summary: Arc::clone(&handle.context_summary),
                };
                handle = state.registry.register(new_handle);
            }
        }
    }

    // NOTE(review): a reindex already running for this index is not detected;
    // its progress entry is simply replaced by the new one.
    let progress = Arc::new(ReindexProgress::new());
    state
        .reindex_progress
        .insert(index_id.clone(), Arc::clone(&progress));
    spawn_reindex_with_cleanup(
        handle,
        progress,
        force,
        Some(Arc::clone(&state.reindex_progress)),
        Some(Arc::clone(&state.last_reindex_aborted_at)),
    );
    Ok(Json(serde_json::json!({
        "index_id": index_id.0,
        "queued": true,
        "stream_url": format!("/indexes/{}/reindex/stream", index_id.0),
    })))
}

/// Cooldown after a memory-limit abort, in whole seconds, read from
/// `TRUSTY_REINDEX_COOLDOWN_SECS`; defaults to 300 when unset or unparsable.
fn reindex_cooldown() -> std::time::Duration {
    let secs = std::env::var("TRUSTY_REINDEX_COOLDOWN_SECS")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(300);
    std::time::Duration::from_secs(secs)
}

/// Moves include paths under `old_root` to the same relative location under
/// `new_root`; paths outside `old_root` are kept unchanged.
fn rebase_include_paths(
    include_paths: &[std::path::PathBuf],
    old_root: &std::path::Path,
    new_root: &std::path::Path,
) -> Vec<std::path::PathBuf> {
    include_paths
        .iter()
        .map(|path| {
            path.strip_prefix(old_root)
                .map(|relative| new_root.join(relative))
                .unwrap_or_else(|_| path.clone())
        })
        .collect()
}
/// Streams reindex progress for one index as Server-Sent Events.
///
/// Looks up the [`ReindexProgress`] registered for `id` in
/// `state.reindex_progress` and returns a `text/event-stream` response that:
///
/// * first replays every event already recorded in `progress.events`;
/// * then, only if the run was `Running` when the request arrived, forwards
///   events broadcast on `progress.sender` after subscription. A receiver that
///   falls behind the broadcast buffer gets a synthetic
///   `{"type":"lag","skipped":N}` frame in place of the dropped events.
///
/// # Errors
///
/// Returns `404 Not Found` when no progress entry exists for `id`.
async fn reindex_stream_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
) -> Result<Response, StatusCode> {
    /// Encodes one event as an SSE `data:` frame terminated by a blank line.
    ///
    /// SSE treats CR, LF and CRLF as line terminators, so a payload containing
    /// any of them is emitted as several `data:` lines (which the client
    /// re-joins with `\n`). A single-line payload yields exactly
    /// `data: {line}\n\n`.
    fn frame(line: String) -> Result<axum::body::Bytes, std::io::Error> {
        let mut out = String::with_capacity(line.len() + 8);
        for segment in line.split('\n') {
            let segment = segment.strip_suffix('\r').unwrap_or(segment);
            for part in segment.split('\r') {
                out.push_str("data: ");
                out.push_str(part);
                out.push('\n');
            }
        }
        out.push('\n');
        Ok(axum::body::Bytes::from(out))
    }

    let index_id = IndexId::new(id);
    let progress = state
        .reindex_progress
        .get(&index_id)
        .map(|r| Arc::clone(r.value()))
        .ok_or(StatusCode::NOT_FOUND)?;

    // Read the status *before* snapshotting. If the run had already finished
    // here, its events are already recorded, so the replay alone is complete.
    // Reading it after the snapshot could observe a terminal status whose
    // final events were recorded after the snapshot, silently dropping them.
    // NOTE(review): if a freshly created `ReindexProgress` starts in a
    // non-`Running` state (e.g. queued before the spawned task starts), a
    // client connecting in that window gets only the replay; confirm against
    // `ReindexStatus`.
    let initial_status = progress.status.load();

    // Subscribe while holding the replay lock so nothing can slip between the
    // snapshot and the start of the live feed: a broadcast receiver only sees
    // values sent after `subscribe()`.
    // NOTE(review): if the producer records to `events` and broadcasts in
    // separate critical sections, an event may show up in both the replay
    // and the live feed; a duplicate is preferred over a lost event.
    let (replay, rx) = {
        let events = progress.events.lock().await;
        let rx = progress.sender.subscribe();
        let replay = events.clone();
        (replay, rx)
    };

    let replay_stream = stream::iter(replay).map(frame);
    let body = if initial_status != ReindexStatus::Running {
        // Run already over: the replay is the whole story and the stream ends
        // once it has been sent.
        Body::from_stream(replay_stream)
    } else {
        let live = BroadcastStream::new(rx).map(|res| match res {
            Ok(line) => frame(line),
            Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => Ok(
                axum::body::Bytes::from(format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n")),
            ),
        });
        Body::from_stream(replay_stream.chain(live))
    };

    Ok(Response::builder()
        .header("Content-Type", "text/event-stream")
        .header("Cache-Control", "no-cache")
        // Stops reverse proxies such as nginx from buffering the event stream.
        .header("X-Accel-Buffering", "no")
        .body(body)
        .expect("valid SSE response"))
}
#[cfg(test)]
mod tests {
    // Unit tests for the HTTP handlers and routing helpers of this module.
    // `use super::*` also brings in the parent's imports (`Arc`, `RwLock`,
    // `IndexHandle`, `IndexId`, `IndexRegistry`, `Path`, `State`, `Json`, ...),
    // so only names the parent does not import are imported here.
    use super::*;
    use crate::core::indexer::CodeIndexer;

    /// Creates a `CodeIndexer` rooted at `/tmp/{name}`, indexes `source` as the
    /// file `{name}/lib.rs`, and registers it in `registry` under id `name`.
    async fn register_indexed(registry: &IndexRegistry, name: &str, source: &str) {
        let root = format!("/tmp/{name}");
        let indexer = CodeIndexer::new(name, root.clone());
        indexer
            .index_file(&format!("{name}/lib.rs"), source)
            .await
            .expect("index_file ok");
        registry.register(IndexHandle::bare(
            IndexId::new(name),
            Arc::new(RwLock::new(indexer)),
            root.into(),
        ));
    }

    /// `/health` reports status, version, index count and embedder state.
    #[tokio::test]
    async fn health_handler_reports_indexes_and_uptime() {
        let registry = IndexRegistry::new();
        registry.register(IndexHandle::bare(
            IndexId::new("health-test"),
            Arc::new(RwLock::new(CodeIndexer::new(
                "health-test",
                "/tmp/health-test",
            ))),
            "/tmp/health-test".into(),
        ));
        let state = Arc::new(SearchAppState::new(registry));
        let Json(resp) = health_handler(State(state)).await;
        assert_eq!(resp.status, "ok");
        assert_eq!(resp.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(resp.indexes, 1);
        // Uptime depends on the wall clock; touching the field only checks
        // that the response still carries it.
        let _ = resp.uptime_secs;
        // No embedder has been installed, so it reports as initializing.
        assert_eq!(resp.embedder, "initializing");
    }

    /// The graph export contains both function symbols, the call edge between
    /// them, summary stats and a cache header.
    #[tokio::test]
    async fn graph_handler_exports_nodes_and_edges() {
        let registry = IndexRegistry::new();
        register_indexed(
            &registry,
            "graph-test",
            "fn callee() {}\nfn caller() { callee(); }\n",
        )
        .await;
        let state = Arc::new(SearchAppState::new(registry));
        let response = graph_handler(
            State(state),
            Path("graph-test".to_string()),
            Query(GraphQueryParams::default()),
        )
        .await
        .expect("handler ok");
        assert_eq!(
            response
                .headers()
                .get(axum::http::header::CACHE_CONTROL)
                .and_then(|v| v.to_str().ok()),
            Some("max-age=3600"),
        );
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        let value: serde_json::Value = serde_json::from_slice(&bytes).expect("json body");
        let nodes = value["nodes"].as_array().expect("nodes array");
        assert_eq!(nodes.len(), 2, "two function symbols expected");
        for node in nodes {
            assert_eq!(node["type"].as_str(), Some("Symbol"));
            assert!(node["id"].is_string());
            assert!(node["label"].is_string());
            assert!(node["metadata"]["file"].is_string());
        }
        let edges = value["edges"].as_array().expect("edges array");
        assert_eq!(edges.len(), 1, "one CallsFunction edge expected");
        assert_eq!(edges[0]["source"].as_str(), Some("caller"));
        assert_eq!(edges[0]["target"].as_str(), Some("callee"));
        assert_eq!(edges[0]["type"].as_str(), Some("CallsFunction"));
        assert!(edges[0]["weight"].as_f64().is_some());
        assert_eq!(value["stats"]["node_count"].as_u64(), Some(2));
        assert_eq!(value["stats"]["edge_count"].as_u64(), Some(1));
        assert!(value["generated_at"].is_string());
    }

    /// Requesting the graph of an unregistered index yields 404.
    #[tokio::test]
    async fn graph_handler_unknown_index_returns_404() {
        let state = Arc::new(SearchAppState::new(IndexRegistry::new()));
        let err = graph_handler(
            State(state),
            Path("does-not-exist".to_string()),
            Query(GraphQueryParams::default()),
        )
        .await
        .expect_err("missing index must 404");
        assert_eq!(err, StatusCode::NOT_FOUND);
    }

    /// An `edge_types` filter drops non-matching edges but keeps all nodes.
    #[tokio::test]
    async fn graph_handler_filters_by_edge_type() {
        let registry = IndexRegistry::new();
        register_indexed(
            &registry,
            "graph-filter",
            "fn callee() {}\nfn caller() { callee(); }\n",
        )
        .await;
        let state = Arc::new(SearchAppState::new(registry));
        let response = graph_handler(
            State(state),
            Path("graph-filter".to_string()),
            Query(GraphQueryParams {
                types: None,
                edge_types: Some("Implements".to_string()),
                min_weight: None,
            }),
        )
        .await
        .expect("handler ok");
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        let value: serde_json::Value = serde_json::from_slice(&bytes).expect("json body");
        assert!(
            value["edges"].as_array().expect("edges").is_empty(),
            "CallsFunction edge must be filtered out",
        );
        assert_eq!(value["nodes"].as_array().expect("nodes").len(), 2);
    }

    /// Global search queries every registered index and tags each hit with
    /// the index it came from.
    #[tokio::test]
    async fn global_search_fans_out_and_merges() {
        let registry = IndexRegistry::new();
        for name in ["proj-a", "proj-b"] {
            register_indexed(
                &registry,
                name,
                &format!("fn alpha_{name}() {{ println!(\"alpha hit\"); }}"),
            )
            .await;
        }
        let state = Arc::new(SearchAppState::new(registry));
        let Json(value) = global_search_handler(
            State(state),
            Json(GlobalSearchRequest {
                query: "alpha".into(),
                top_k: 10,
                full_content: false,
                indexes: None,
                routing: None,
                routing_n: None,
                routing_threshold: None,
            }),
        )
        .await
        .expect("handler ok");
        let total = value["total_indexes"].as_u64().expect("total_indexes");
        assert_eq!(total, 2, "both indexes counted");
        let searched: Vec<String> = value["indexes_searched"]
            .as_array()
            .expect("indexes_searched array")
            .iter()
            .filter_map(|v| v.as_str().map(str::to_owned))
            .collect();
        assert_eq!(searched.len(), 2);
        assert!(searched.contains(&"proj-a".to_string()));
        assert!(searched.contains(&"proj-b".to_string()));
        let results = value["results"].as_array().expect("results array");
        assert!(!results.is_empty(), "expected at least one hit");
        let mut from_a = false;
        let mut from_b = false;
        for r in results {
            let idx = r["index_id"]
                .as_str()
                .expect("each result must be tagged with index_id");
            assert!(
                idx == "proj-a" || idx == "proj-b",
                "unexpected index_id: {idx}"
            );
            from_a |= idx == "proj-a";
            from_b |= idx == "proj-b";
        }
        assert!(from_a, "expected a result tagged with proj-a");
        assert!(from_b, "expected a result tagged with proj-b");
    }

    /// With nothing registered, global search succeeds with empty output.
    #[tokio::test]
    async fn global_search_empty_registry_returns_empty_results() {
        let state = Arc::new(SearchAppState::new(IndexRegistry::new()));
        let Json(value) = global_search_handler(
            State(state),
            Json(GlobalSearchRequest {
                query: "anything".into(),
                top_k: 5,
                full_content: false,
                indexes: None,
                routing: None,
                routing_n: None,
                routing_threshold: None,
            }),
        )
        .await
        .expect("handler ok");
        assert_eq!(value["total_indexes"].as_u64(), Some(0));
        assert!(value["results"].as_array().unwrap().is_empty());
        assert!(value["indexes_searched"].as_array().unwrap().is_empty());
    }

    /// An explicit `indexes` list restricts the fan-out to those indexes.
    #[tokio::test]
    async fn global_search_restricts_to_named_indexes() {
        let registry = IndexRegistry::new();
        for name in ["proj-a", "proj-b", "proj-c"] {
            register_indexed(
                &registry,
                name,
                &format!("fn alpha_{name}() {{ println!(\"alpha hit\"); }}"),
            )
            .await;
        }
        let state = Arc::new(SearchAppState::new(registry));
        let Json(value) = global_search_handler(
            State(state),
            Json(GlobalSearchRequest {
                query: "alpha".into(),
                top_k: 10,
                full_content: false,
                indexes: Some(vec!["proj-a".into(), "proj-c".into()]),
                routing: None,
                routing_n: None,
                routing_threshold: None,
            }),
        )
        .await
        .expect("handler ok");
        assert_eq!(value["total_indexes"].as_u64(), Some(2));
        let searched: std::collections::HashSet<String> = value["indexes_searched"]
            .as_array()
            .expect("array")
            .iter()
            .filter_map(|v| v.as_str().map(str::to_owned))
            .collect();
        assert!(searched.contains("proj-a"));
        assert!(searched.contains("proj-c"));
        assert!(!searched.contains("proj-b"), "proj-b must be excluded");
        for r in value["results"].as_array().unwrap() {
            let idx = r["index_id"].as_str().unwrap();
            assert_ne!(idx, "proj-b", "no result may come from excluded index");
        }
    }

    /// `RoutingMode::All` keeps every index; ids without a weight get 1.0.
    #[test]
    fn routing_mode_all_preserves_every_index_with_weights() {
        let ids = vec![IndexId::new("a"), IndexId::new("b"), IndexId::new("c")];
        let weights: std::collections::HashMap<IndexId, f32> = [
            (IndexId::new("a"), 0.9_f32),
            (IndexId::new("b"), 0.2),
        ]
        .into_iter()
        .collect();
        let (active, map) = RoutingMode::All.apply(&ids, &weights);
        assert_eq!(active.len(), 3, "all routing keeps every index");
        assert!((map.get(&IndexId::new("a")).copied().unwrap() - 0.9).abs() < 1e-6);
        assert!((map.get(&IndexId::new("b")).copied().unwrap() - 0.2).abs() < 1e-6);
        assert!((map.get(&IndexId::new("c")).copied().unwrap() - 1.0).abs() < 1e-6);
    }

    /// `RoutingMode::TopN` keeps the N most similar indexes at weight 1.0.
    #[test]
    fn routing_mode_top_n_keeps_only_highest_similarity() {
        let ids = vec![IndexId::new("low"), IndexId::new("hi"), IndexId::new("mid")];
        let weights: std::collections::HashMap<IndexId, f32> = [
            (IndexId::new("low"), 0.1_f32),
            (IndexId::new("hi"), 0.95),
            (IndexId::new("mid"), 0.5),
        ]
        .into_iter()
        .collect();
        let (active, map) = RoutingMode::TopN(2).apply(&ids, &weights);
        assert_eq!(active.len(), 2);
        let active_set: std::collections::HashSet<&str> =
            active.iter().map(|id| id.0.as_str()).collect();
        assert!(active_set.contains("hi"));
        assert!(active_set.contains("mid"));
        assert!(!active_set.contains("low"));
        assert!((map.get(&IndexId::new("hi")).copied().unwrap() - 1.0).abs() < 1e-6);
        assert!((map.get(&IndexId::new("mid")).copied().unwrap() - 1.0).abs() < 1e-6);
        assert!(!map.contains_key(&IndexId::new("low")));
    }

    /// `RoutingMode::Threshold` drops indexes whose weight is below the cutoff.
    #[test]
    fn routing_mode_threshold_drops_below_cutoff() {
        let ids = vec![IndexId::new("a"), IndexId::new("b"), IndexId::new("c")];
        let weights: std::collections::HashMap<IndexId, f32> = [
            (IndexId::new("a"), 0.1_f32),
            (IndexId::new("b"), 0.5),
            (IndexId::new("c"), 0.8),
        ]
        .into_iter()
        .collect();
        let (active, map) = RoutingMode::Threshold(0.4).apply(&ids, &weights);
        let active_set: std::collections::HashSet<&str> =
            active.iter().map(|id| id.0.as_str()).collect();
        assert!(!active_set.contains("a"), "0.1 < 0.4 must drop");
        assert!(active_set.contains("b"), "0.5 >= 0.4 must keep");
        assert!(active_set.contains("c"));
        assert!(!map.contains_key(&IndexId::new("a")));
    }

    /// Indexes with no weight entry are treated as neutral (1.0) and survive
    /// any threshold at or below 1.0.
    #[test]
    fn routing_threshold_keeps_neutral_indexes() {
        let ids = vec![IndexId::new("known"), IndexId::new("missing")];
        let weights: std::collections::HashMap<IndexId, f32> =
            [(IndexId::new("known"), 0.05_f32)].into_iter().collect();
        let (active, _map) = RoutingMode::Threshold(0.5).apply(&ids, &weights);
        let active_set: std::collections::HashSet<&str> =
            active.iter().map(|id| id.0.as_str()).collect();
        assert!(!active_set.contains("known"), "0.05 < 0.5 dropped");
        assert!(
            active_set.contains("missing"),
            "indexes without a context embedding must use neutral 1.0 weight"
        );
    }

    /// Request fields map onto the expected `RoutingMode`, with unknown
    /// strategies falling back to `All` and missing parameters to defaults.
    #[test]
    fn routing_mode_from_request_resolves_strategy() {
        let base =
            |routing: Option<&str>, n: Option<usize>, t: Option<f32>| -> GlobalSearchRequest {
                GlobalSearchRequest {
                    query: "x".into(),
                    top_k: 1,
                    full_content: false,
                    indexes: None,
                    routing: routing.map(|s| s.to_string()),
                    routing_n: n,
                    routing_threshold: t,
                }
            };
        assert!(matches!(
            RoutingMode::from_request(&base(None, None, None)),
            RoutingMode::All
        ));
        assert!(matches!(
            RoutingMode::from_request(&base(Some("garbage"), None, None)),
            RoutingMode::All
        ));
        match RoutingMode::from_request(&base(Some("top_n"), Some(5), None)) {
            RoutingMode::TopN(n) => assert_eq!(n, 5),
            _ => panic!("expected TopN"),
        }
        match RoutingMode::from_request(&base(Some("top_n"), None, None)) {
            RoutingMode::TopN(n) => assert_eq!(n, RoutingMode::DEFAULT_TOP_N),
            _ => panic!("expected TopN default"),
        }
        match RoutingMode::from_request(&base(Some("threshold"), None, Some(0.7))) {
            RoutingMode::Threshold(t) => assert!((t - 0.7).abs() < 1e-6),
            _ => panic!("expected Threshold"),
        }
    }

    /// A recorded embedder init failure is surfaced by `/health`.
    #[tokio::test]
    async fn install_embedder_error_surfaces_in_health() {
        let state = SearchAppState::new(IndexRegistry::new());
        state
            .install_embedder_error("init timed out after 60s")
            .await;
        let state_arc = Arc::new(state);
        let Json(resp) = health_handler(State(state_arc)).await;
        assert_eq!(resp.embedder, "error");
        assert_eq!(
            resp.embedder_error.as_deref(),
            Some("init timed out after 60s"),
        );
    }

    /// Creating an index after an embedder init failure returns 503 with the
    /// recorded failure reason in the JSON `error` field.
    #[tokio::test]
    async fn create_index_returns_503_with_error_when_embedder_failed() {
        use axum::body::to_bytes;
        let state = SearchAppState::new(IndexRegistry::new());
        state
            .install_embedder_error("init timed out after 60s")
            .await;
        let state_arc = Arc::new(state);
        let resp = create_index_handler(
            State(state_arc),
            Json(CreateIndexRequest {
                id: "demo".to_string(),
                root_path: std::path::PathBuf::from("/tmp/demo"),
                include_paths: None,
                exclude_globs: None,
                extensions: None,
                domain_terms: None,
                path_filter: None,
            }),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body_bytes = to_bytes(resp.into_body(), 64 * 1024)
            .await
            .expect("read body");
        let body: serde_json::Value = serde_json::from_slice(&body_bytes).expect("valid json");
        let err_str = body
            .get("error")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        assert!(
            err_str.contains("embedder init failed"),
            "expected error message to mention init failure, got: {err_str}",
        );
        assert!(
            err_str.contains("init timed out after 60s"),
            "expected recorded timeout message to be surfaced, got: {err_str}",
        );
    }

    /// Installing a working embedder clears a previously recorded error.
    #[tokio::test]
    async fn install_embedder_clears_previous_error() {
        use crate::core::embed::MockEmbedder;
        let state = SearchAppState::new(IndexRegistry::new());
        state.install_embedder_error("transient hang").await;
        assert!(state.current_embedder_error().is_some());
        let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder::new(8));
        state.install_embedder(embedder).await;
        assert!(state.current_embedder_error().is_none());
        assert!(state.is_embedder_ready());
        let state_arc = Arc::new(state);
        let Json(resp) = health_handler(State(state_arc)).await;
        assert_eq!(resp.embedder, "ready");
        assert!(resp.embedder_error.is_none());
    }

    /// A reindex requested right after a memory-limit abort is refused with
    /// 429. Once the abort marker is cleared, the same request is queued.
    ///
    /// NOTE(review): relies on `TRUSTY_REINDEX_COOLDOWN_SECS` being unset or
    /// non-zero in the test environment; a value of 0 disables the cooldown
    /// and makes the first assertion fail.
    #[tokio::test]
    async fn reindex_handler_rejects_within_cooldown() {
        let registry = IndexRegistry::new();
        let id = IndexId::new("cooldown-test");
        let tmp = tempfile::tempdir().expect("tempdir");
        registry.register(IndexHandle::bare(
            id.clone(),
            Arc::new(RwLock::new(CodeIndexer::new("cooldown-test", tmp.path()))),
            tmp.path().to_path_buf(),
        ));
        let state = Arc::new(SearchAppState::new(registry));
        state
            .last_reindex_aborted_at
            .insert(id.clone(), std::time::Instant::now());
        let result = reindex_handler(
            State(Arc::clone(&state)),
            Path("cooldown-test".to_string()),
            None,
        )
        .await;
        let err = result.expect_err("expected 429 inside cooldown window");
        assert_eq!(err.0, StatusCode::TOO_MANY_REQUESTS);
        let body = err.1 .0;
        assert!(body.get("retry_after_secs").is_some());
        assert!(body.get("hint").is_some());
        assert_eq!(body["index_id"], "cooldown-test");
        state.last_reindex_aborted_at.remove(&id);
        let ok = reindex_handler(
            State(Arc::clone(&state)),
            Path("cooldown-test".to_string()),
            None,
        )
        .await
        .expect("queued");
        assert_eq!(ok.0["queued"], serde_json::Value::Bool(true));
    }

    /// `ReindexStatus::AbortedMemory` serializes as the lowercase string
    /// `"abortedmemory"`. The test is purely synchronous, so it needs no async
    /// runtime.
    #[test]
    fn reindex_status_aborted_memory_serializes_lowercase() {
        let status = crate::service::reindex::ReindexStatus::AbortedMemory;
        let json = serde_json::to_string(&status).expect("serialize");
        assert_eq!(json, "\"abortedmemory\"");
    }
}