use crate::core::{
classifier::QueryClassifier,
embed::Embedder,
indexer::SearchQuery,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use axum::{
body::Body,
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Json, Redirect, Response},
routing::{delete, get, post},
Router,
};
use dashmap::DashMap;
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::{broadcast, watch, OnceCell, RwLock};
use tokio_stream::wrappers::BroadcastStream;
use trusty_common::{ChatProvider, LocalModelConfig};
use crate::service::reindex::{spawn_reindex_with_cleanup, ReindexProgress, ReindexStatus};
/// Event published on the daemon's broadcast channel and streamed to
/// `/status/stream` subscribers.
///
/// Serialized as internally tagged JSON: the variant name, converted to
/// snake_case, is written to a `type` field next to the variant's own fields.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonEvent {
/// Periodic status snapshot, emitted by the status ticker.
StatusChanged {
/// Number of registered indexes.
indexes: u64,
/// Sum of the chunk counts of all registered indexes.
total_chunks: u64,
/// Seconds elapsed since the application state was created.
uptime_secs: u64,
/// Crate version (`CARGO_PKG_VERSION`).
version: String,
},
/// An index was registered (emitted by the create-index handler).
IndexRegistered { id: String },
/// An index was unregistered (emitted by the delete-index handler).
IndexRemoved { id: String },
}
/// Shared application state handed to every HTTP handler.
///
/// The struct is `Clone`; most fields are `Arc`-wrapped, so clones share the
/// same underlying slots, caches and channels.
#[derive(Clone)]
pub struct SearchAppState {
/// Registry the handlers use to look up, register and unregister indexes.
pub registry: IndexRegistry,
/// Per-index reindex progress; the entry is removed when the index is deleted.
pub reindex_progress: Arc<DashMap<IndexId, Arc<ReindexProgress>>>,
/// Presumably the time of the last aborted reindex per index — not written
/// in the visible part of this file; TODO confirm.
pub last_reindex_aborted_at: Arc<DashMap<IndexId, std::time::Instant>>,
/// Embedder supplied at construction time via `with_embedder`; `None` otherwise.
pub embedder: Option<Arc<dyn Embedder>>,
/// Embedder slot read by `current_embedder` and filled by `with_embedder`
/// or `install_embedder`.
pub embedder_slot: Arc<RwLock<Option<Arc<dyn Embedder>>>>,
/// Receiver side of the readiness flag; read by `is_embedder_ready`.
pub embedder_ready: watch::Receiver<bool>,
/// Sender side of the readiness flag; set to `true` once an embedder is installed.
pub embedder_ready_tx: Arc<watch::Sender<bool>>,
/// Message recorded by `install_embedder_error`; cleared by `install_embedder`.
pub embedder_error: Arc<RwLock<Option<String>>>,
/// Port set through `with_daemon_port`; `None` by default.
pub daemon_port: Option<u16>,
/// `true` when the OpenRouter API key is non-empty.
pub openrouter_enabled: bool,
/// Instant captured in `new`; used to compute uptime.
pub started_at: Instant,
/// Local chat-model configuration consulted first by `chat_provider`.
pub local_model: LocalModelConfig,
/// Model name passed to the OpenRouter provider.
pub openrouter_model: String,
/// API key passed to the OpenRouter provider; empty means "not configured".
pub openrouter_api_key: String,
/// Chat provider resolved at most once, on the first `chat_provider` call.
pub chat_provider: Arc<OnceCell<Option<Arc<dyn ChatProvider>>>>,
/// Broadcast sender for [`DaemonEvent`]s; `/status/stream` subscribes to it.
pub events: Arc<broadcast::Sender<DaemonEvent>>,
/// Log buffer served by `/logs/tail`.
pub log_buffer: trusty_common::log_buffer::LogBuffer,
/// Data-directory size in bytes, refreshed by the disk-size ticker and
/// reported by `/health`.
pub disk_bytes: Arc<std::sync::atomic::AtomicU64>,
/// System metrics sampler used by `/health`.
pub sys_metrics: Arc<tokio::sync::Mutex<trusty_common::sys_metrics::SysMetrics>>,
/// Optional embed pool, set by `with_embed_pool` or `install_embed_pool`.
pub embed_pool: Arc<RwLock<Option<Arc<crate::service::embed_pool::EmbedPool>>>>,
/// When `Some`, `build_router` mounts `/metrics` with this state.
pub metrics: Option<crate::service::metrics::MetricsState>,
/// Cache of graph scorers keyed by index; filled lazily by `graph_scorer`.
pub graph_scorers: Arc<DashMap<IndexId, Arc<crate::core::indexer::graph_score::GraphScorer>>>,
}
impl SearchAppState {
pub fn new(registry: IndexRegistry) -> Self {
let openrouter_api_key = std::env::var("OPENROUTER_API_KEY").unwrap_or_default();
let (events_tx, _) = broadcast::channel::<DaemonEvent>(128);
let (ready_tx, ready_rx) = watch::channel(false);
Self {
registry,
reindex_progress: Arc::new(DashMap::new()),
last_reindex_aborted_at: Arc::new(DashMap::new()),
embedder: None,
embedder_slot: Arc::new(RwLock::new(None)),
embedder_ready: ready_rx,
embedder_ready_tx: Arc::new(ready_tx),
embedder_error: Arc::new(RwLock::new(None)),
daemon_port: None,
openrouter_enabled: !openrouter_api_key.is_empty(),
started_at: Instant::now(),
local_model: LocalModelConfig::default(),
openrouter_model: "anthropic/claude-haiku-4.5".to_string(),
openrouter_api_key,
chat_provider: Arc::new(OnceCell::new()),
events: Arc::new(events_tx),
log_buffer: trusty_common::log_buffer::LogBuffer::new(
trusty_common::log_buffer::DEFAULT_LOG_CAPACITY,
),
disk_bytes: Arc::new(std::sync::atomic::AtomicU64::new(0)),
sys_metrics: Arc::new(tokio::sync::Mutex::new(
trusty_common::sys_metrics::SysMetrics::new(),
)),
embed_pool: Arc::new(RwLock::new(None)),
metrics: None,
graph_scorers: Arc::new(DashMap::new()),
}
}
pub async fn graph_scorer(
&self,
index_id: &IndexId,
) -> Option<Arc<crate::core::indexer::graph_score::GraphScorer>> {
if let Some(scorer) = self.graph_scorers.get(index_id) {
return Some(scorer.clone());
}
let handle = self.registry.get(index_id)?;
let indexer = handle.indexer.read().await;
let graph = indexer.symbol_graph().await;
if graph.node_count() == 0 {
return None;
}
let corpus = indexer.corpus_arc()?;
drop(indexer);
let communities = match corpus.load_communities() {
Ok(rows) => rows
.into_iter()
.filter_map(|(_, bytes)| {
serde_json::from_slice::<crate::core::community::CommunityRecord>(&bytes).ok()
})
.collect::<Vec<_>>(),
Err(e) => {
tracing::debug!("graph_scorer: failed to load communities for '{index_id}': {e}");
Vec::new()
}
};
let scorer = Arc::new(crate::core::indexer::graph_score::GraphScorer::build(
&graph,
&communities,
));
self.graph_scorers
.insert(index_id.clone(), Arc::clone(&scorer));
Some(scorer)
}
pub fn invalidate_graph_scorer(&self, index_id: &IndexId) {
self.graph_scorers.remove(index_id);
}
#[must_use]
pub fn with_embed_pool(self, pool: Arc<crate::service::embed_pool::EmbedPool>) -> Self {
if let Ok(mut slot) = self.embed_pool.try_write() {
*slot = Some(pool);
}
self
}
#[must_use]
pub fn with_metrics(mut self, metrics: crate::service::metrics::MetricsState) -> Self {
self.metrics = Some(metrics);
self
}
pub async fn install_embed_pool(&self, pool: Arc<crate::service::embed_pool::EmbedPool>) {
let mut slot = self.embed_pool.write().await;
*slot = Some(pool);
}
pub async fn current_embed_pool(&self) -> Option<Arc<crate::service::embed_pool::EmbedPool>> {
self.embed_pool.read().await.clone()
}
#[must_use]
pub fn with_log_buffer(mut self, buffer: trusty_common::log_buffer::LogBuffer) -> Self {
self.log_buffer = buffer;
self
}
pub fn emit(&self, event: DaemonEvent) {
let _ = self.events.send(event);
}
pub fn with_local_model(mut self, cfg: LocalModelConfig) -> Self {
self.local_model = cfg;
self
}
pub fn with_openrouter_model(mut self, model: impl Into<String>) -> Self {
self.openrouter_model = model.into();
self
}
pub fn with_openrouter_api_key(mut self, api_key: impl Into<String>) -> Self {
let api_key_str = api_key.into();
self.openrouter_enabled = !api_key_str.is_empty();
self.openrouter_api_key = api_key_str;
self
}
pub async fn chat_provider(&self) -> Option<Arc<dyn ChatProvider>> {
self.chat_provider
.get_or_init(|| async {
if self.local_model.enabled {
if let Some(mut p) =
trusty_common::auto_detect_local_provider(&self.local_model.base_url).await
{
p.model = self.local_model.model.clone();
return Some(Arc::new(p) as Arc<dyn ChatProvider>);
}
}
if !self.openrouter_api_key.is_empty() {
return Some(Arc::new(trusty_common::OpenRouterProvider::new(
self.openrouter_api_key.clone(),
self.openrouter_model.clone(),
)) as Arc<dyn ChatProvider>);
}
None
})
.await
.clone()
}
pub fn with_daemon_port(mut self, port: u16) -> Self {
self.daemon_port = Some(port);
self
}
pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
self.embedder = Some(Arc::clone(&embedder));
if let Ok(mut slot) = self.embedder_slot.try_write() {
*slot = Some(embedder);
}
let _ = self.embedder_ready_tx.send(true);
self
}
pub async fn install_embedder(&self, embedder: Arc<dyn Embedder>) {
let mut slot = self.embedder_slot.write().await;
*slot = Some(embedder);
drop(slot);
{
let mut err = self.embedder_error.write().await;
*err = None;
}
let _ = self.embedder_ready_tx.send(true);
}
pub async fn install_embedder_error(&self, message: impl Into<String>) {
let msg = message.into();
tracing::error!("embedder init failed: {msg}");
let mut err = self.embedder_error.write().await;
*err = Some(msg);
}
pub fn current_embedder_error(&self) -> Option<String> {
self.embedder_error.try_read().ok().and_then(|g| g.clone())
}
pub async fn current_embedder(&self) -> Option<Arc<dyn Embedder>> {
let slot = self.embedder_slot.read().await;
slot.clone()
}
pub fn is_embedder_ready(&self) -> bool {
*self.embedder_ready.borrow()
}
}
/// JSON body returned by `GET /health`.
#[derive(Serialize)]
struct HealthResponse {
/// Always `"ok"` as built by the health handler.
status: &'static str,
/// Crate version (`CARGO_PKG_VERSION`).
version: &'static str,
/// Number of registered indexes.
indexes: usize,
/// Seconds since the application state was created.
uptime_secs: u64,
/// Embedder status: `"ready"`, `"error"` or `"initializing"`.
embedder: &'static str,
/// Recorded embedder error; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
embedder_error: Option<String>,
/// First value of the system-metrics sample (presumably resident memory in MiB).
rss_mb: u64,
/// Configured memory limit, or `0` when no limit is set.
rss_limit_mb: u64,
/// Last data-directory size measured by the disk-size ticker.
disk_bytes: u64,
/// Second value of the system-metrics sample (presumably CPU usage in percent).
cpu_pct: f32,
/// Details of the installed embedder; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
embedder_info: Option<EmbedderInfo>,
}
/// Description of the installed embedder, reported inside [`HealthResponse`].
#[derive(Serialize)]
struct EmbedderInfo {
/// Value of the embedder's `dimension()`.
dimension: usize,
/// String form of the embedder's `provider()`.
provider: String,
/// Set by the health handler to `dimension == trusty_common::embedder::EMBED_DIM`.
quantized: bool,
}
/// JSON body returned by `GET /indexes`.
#[derive(Serialize)]
struct IndexListResponse {
/// Ids of all registered indexes.
indexes: Vec<String>,
}
/// JSON body accepted by `POST /indexes`.
///
/// All optional fields default to `None` when absent.
#[derive(Deserialize)]
pub struct CreateIndexRequest {
/// Id of the index to create.
pub id: String,
/// Root directory of the index; must be a non-empty, absolute path to an
/// existing directory.
pub root_path: std::path::PathBuf,
/// Sub-paths to include; entries are trimmed, blank and `"."` entries are
/// dropped, and the rest are joined onto `root_path`.
#[serde(default)]
pub include_paths: Option<Vec<String>>,
/// Exclusion globs, stored as given.
#[serde(default)]
pub exclude_globs: Option<Vec<String>>,
/// File extensions; leading dots are stripped and empty entries dropped.
#[serde(default)]
pub extensions: Option<Vec<String>>,
/// Domain terms handed to the indexer and later used for query classification.
#[serde(default)]
pub domain_terms: Option<Vec<String>>,
/// Path filter entries; blank entries are dropped.
#[serde(default)]
pub path_filter: Option<Vec<String>>,
}
/// Request carrying a file path and its content.
///
/// NOTE(review): presumably the body of `POST /indexes/{id}/index-file`; the
/// handler is not visible in this part of the file — confirm.
#[derive(Deserialize)]
pub struct IndexFileRequest {
/// Path of the file.
pub path: String,
/// Content of the file.
pub content: String,
}
/// Request naming a single file.
///
/// NOTE(review): presumably the body of `POST /indexes/{id}/remove-file`; the
/// handler is not visible in this part of the file — confirm.
#[derive(Deserialize)]
pub struct RemoveFileRequest {
/// Path of the file.
pub path: String,
}
/// Builds the HTTP router for the daemon and starts its background tickers.
///
/// The state is wrapped in an `Arc` shared by all routes. Two background
/// tasks (status ticker and disk-size ticker) are spawned before the routes
/// are assembled; they are started with `tokio::spawn`, so this function is
/// expected to run inside a Tokio runtime.
///
/// Routes are split into two groups: `limited` (search, file indexing and
/// reindex endpoints), which pass through the concurrency limiter, and
/// `free` (everything else), which do not.
pub fn build_router(state: SearchAppState) -> Router {
use crate::service::ui::{
chat_handler, list_chat_providers, ui_asset_handler, ui_index_handler,
};
let state_arc = Arc::new(state);
// The tickers only keep weak references to the state (see their bodies).
spawn_status_ticker(Arc::clone(&state_arc));
spawn_disk_size_ticker(Arc::clone(&state_arc));
let limiter = crate::service::concurrency::ConcurrencyLimiter::from_env();
// Routes guarded by the concurrency limiter.
let limited = Router::new()
.route("/search", post(global_search_handler))
.route("/indexes/{id}/search", post(search_handler))
.route("/indexes/{id}/search_similar", post(search_similar_handler))
.route("/indexes/{id}/index-file", post(index_file_handler))
.route("/indexes/{id}/remove-file", post(remove_file_handler))
.route("/indexes/{id}/reindex", post(reindex_handler))
// `route_layer` attaches the limiter middleware to the routes added above.
.route_layer(axum::middleware::from_fn(
crate::service::concurrency::apply_limiter,
))
// The limiter instance is made available to the middleware as an extension.
.layer(axum::Extension(Arc::clone(&limiter)))
.with_state(Arc::clone(&state_arc));
// Routes that bypass the concurrency limiter.
let free = Router::new()
.route("/", get(|| async { Redirect::permanent("/ui/") }))
.route("/health", get(health_handler))
.route("/logs/tail", get(logs_tail_handler))
.route("/admin/stop", post(admin_stop_handler))
.route("/status/stream", get(status_stream_handler))
.route(
"/indexes",
get(list_indexes_handler).post(create_index_handler),
)
.route("/indexes/{id}", delete(delete_index_handler))
.route("/ui", get(|| async { Redirect::permanent("/ui/") }))
.route("/ui/", get(ui_index_handler))
.route("/ui/{*path}", get(ui_asset_handler))
.route("/chat", post(chat_handler))
.route("/api/chat/providers", get(list_chat_providers))
.route("/indexes/{id}/status", get(index_status_handler))
.route("/indexes/{id}/graph", get(graph_handler))
.route("/indexes/{id}/graph/stats", get(graph_stats_handler))
.route("/indexes/{id}/communities", get(communities_handler))
.route(
"/indexes/{id}/communities/{symbol}",
get(community_for_symbol_handler),
)
.route("/indexes/{id}/reindex/stream", get(reindex_stream_handler))
.route("/indexes/{id}/chunks", get(get_index_chunks_handler))
.route(
"/config",
get(get_config_handler).patch(patch_config_handler),
)
.with_state(Arc::clone(&state_arc));
let mut router = free.merge(limited);
// `/metrics` is only mounted when a metrics state was configured.
if let Some(metrics_state) = state_arc.metrics.clone() {
router = router
.route("/metrics", get(crate::service::metrics::metrics_handler))
.layer(axum::Extension(metrics_state));
}
// Request metrics middleware wraps every route registered so far.
router = router.layer(axum::middleware::from_fn(
crate::service::metrics::request_metrics_middleware,
));
trusty_common::server::with_standard_middleware(router)
}
fn spawn_status_ticker(state: Arc<SearchAppState>) {
let weak = Arc::downgrade(&state);
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(2));
interval.tick().await;
loop {
interval.tick().await;
let Some(state) = weak.upgrade() else {
break;
};
let (indexes, total_chunks) = collect_status_counts(&state).await;
state.emit(DaemonEvent::StatusChanged {
indexes: indexes as u64,
total_chunks: total_chunks as u64,
uptime_secs: state.started_at.elapsed().as_secs(),
version: env!("CARGO_PKG_VERSION").to_string(),
});
}
});
}
/// Spawns a background task that measures the data directory every ten
/// seconds and stores the result in `disk_bytes`.
///
/// The measurement runs on the blocking thread pool. A data directory that
/// cannot be resolved, or a failed blocking task, is recorded as `0`. The
/// task holds only a weak reference to the state and stops once the state
/// can no longer be upgraded.
fn spawn_disk_size_ticker(state: Arc<SearchAppState>) {
    let weak_state = Arc::downgrade(&state);
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(Duration::from_secs(10));
        loop {
            ticker.tick().await;
            let app = match weak_state.upgrade() {
                Some(app) => app,
                None => break,
            };
            let measure = || match crate::service::persistence::data_dir() {
                Ok(dir) => trusty_common::sys_metrics::dir_size_bytes(&dir),
                Err(e) => {
                    tracing::debug!("disk_size_ticker: could not resolve data dir: {e}");
                    0
                }
            };
            let measured = tokio::task::spawn_blocking(measure).await.unwrap_or(0);
            app.disk_bytes
                .store(measured, std::sync::atomic::Ordering::Relaxed);
        }
    });
}
/// Handles `GET /health`: reports version, uptime, index count, embedder
/// status and basic resource usage.
///
/// The embedder is reported as `"ready"` when the readiness flag is set, an
/// embedder was supplied at construction, or the shared slot is filled (the
/// slot is checked without blocking); otherwise `"error"` when an embedder
/// error is recorded, else `"initializing"`.
async fn health_handler(State(state): State<Arc<SearchAppState>>) -> Json<HealthResponse> {
    let embedder_error = state.current_embedder_error();

    // Non-blocking probe of the shared slot; a locked slot counts as empty.
    let slot_filled = || {
        state
            .embedder_slot
            .try_read()
            .map(|slot| slot.is_some())
            .unwrap_or(false)
    };
    let embedder =
        if state.is_embedder_ready() || state.embedder.is_some() || slot_filled() {
            "ready"
        } else if embedder_error.is_some() {
            "error"
        } else {
            "initializing"
        };

    let (rss_mb, cpu_pct) = state.sys_metrics.lock().await.sample();
    let rss_limit_mb = crate::core::memguard::memory_limit_mb().unwrap_or(0);
    let disk_bytes = state.disk_bytes.load(std::sync::atomic::Ordering::Relaxed);

    let embedder_info = match state.current_embedder().await {
        Some(active) => {
            let dimension = active.dimension();
            Some(EmbedderInfo {
                dimension,
                provider: active.provider().as_str().to_string(),
                quantized: dimension == trusty_common::embedder::EMBED_DIM,
            })
        }
        None => None,
    };

    let response = HealthResponse {
        status: "ok",
        version: env!("CARGO_PKG_VERSION"),
        indexes: state.registry.list().len(),
        uptime_secs: state.started_at.elapsed().as_secs(),
        embedder,
        embedder_error,
        rss_mb,
        rss_limit_mb,
        disk_bytes,
        cpu_pct,
        embedder_info,
    };
    Json(response)
}
/// Query parameters accepted by `GET /logs/tail`.
#[derive(Deserialize)]
pub struct LogsTailParams {
/// Number of lines requested; defaults to `DEFAULT_LOGS_TAIL_N` when absent.
#[serde(default = "default_logs_tail_n")]
pub n: usize,
}
/// Number of log lines returned by `/logs/tail` when `n` is not supplied.
const DEFAULT_LOGS_TAIL_N: usize = 100;
/// Upper bound applied to the requested `n`; equal to the log buffer's
/// default capacity.
const MAX_LOGS_TAIL_N: usize = trusty_common::log_buffer::DEFAULT_LOG_CAPACITY;
/// Serde default for `LogsTailParams::n`.
fn default_logs_tail_n() -> usize {
DEFAULT_LOGS_TAIL_N
}
async fn logs_tail_handler(
State(state): State<Arc<SearchAppState>>,
Query(params): Query<LogsTailParams>,
) -> Json<serde_json::Value> {
let n = params.n.clamp(1, MAX_LOGS_TAIL_N);
let lines = state.log_buffer.tail(n);
Json(serde_json::json!({
"lines": lines,
"total": state.log_buffer.len(),
}))
}
/// Handles `POST /admin/stop`: acknowledges the request and exits the
/// process with status 0 about 200 ms later.
///
/// The exit is scheduled on a separate task so that this handler can still
/// return its acknowledgement first.
async fn admin_stop_handler(State(_state): State<Arc<SearchAppState>>) -> Json<serde_json::Value> {
    tracing::warn!("admin_stop: shutdown requested via POST /admin/stop");
    let grace = Duration::from_millis(200);
    tokio::spawn(async move {
        tokio::time::sleep(grace).await;
        std::process::exit(0);
    });
    let ack = serde_json::json!({ "ok": true, "message": "shutting down" });
    Json(ack)
}
/// JSON body accepted by `PATCH /config`.
///
/// Each field is a double `Option`: the outer level says whether the field
/// was present in the request, the inner level is the new value (`None`
/// when the request sent `null`). Absent fields leave the setting untouched.
#[derive(Debug, Deserialize, Default)]
struct PatchConfigRequest {
/// New overall memory limit, if the field was sent.
#[serde(default, deserialize_with = "deserialize_optional_option_u64")]
memory_limit_mb: Option<Option<u64>>,
/// New index memory limit, if the field was sent.
#[serde(default, deserialize_with = "deserialize_optional_option_u64")]
index_memory_limit_mb: Option<Option<u64>>,
}
/// JSON body returned by `GET /config` and `PATCH /config`.
#[derive(Debug, Serialize)]
struct ConfigResponse {
/// Current overall memory limit; `None` is logged as "unlimited" by the
/// patch handler.
memory_limit_mb: Option<u64>,
/// Current index memory limit; `None` is logged as "unlimited" by the
/// patch handler.
index_memory_limit_mb: Option<u64>,
}
/// Deserializes a field that was present in the input into `Some(inner)`,
/// where `inner` is `None` for `null` and `Some(n)` for a number.
///
/// Combined with `#[serde(default)]`, an absent field stays `None`, which
/// lets callers tell "not sent" apart from "sent as null".
fn deserialize_optional_option_u64<'de, D>(deserializer: D) -> Result<Option<Option<u64>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    Option::<u64>::deserialize(deserializer).map(Some)
}
/// Handles `GET /config`: returns the current memory limits.
async fn get_config_handler(State(_state): State<Arc<SearchAppState>>) -> Json<ConfigResponse> {
    let memory_limit_mb = crate::core::memguard::memory_limit_mb();
    let index_memory_limit_mb = crate::core::memguard::index_memory_limit_mb();
    Json(ConfigResponse {
        memory_limit_mb,
        index_memory_limit_mb,
    })
}
async fn patch_config_handler(
State(_state): State<Arc<SearchAppState>>,
Json(req): Json<PatchConfigRequest>,
) -> Json<ConfigResponse> {
use crate::core::memguard::{
index_memory_limit_mb, memory_limit_mb, set_index_memory_limit_mb, set_memory_limit_mb,
};
let fmt = |v: Option<u64>| match v {
Some(mb) => mb.to_string(),
None => "unlimited".to_string(),
};
if let Some(new) = req.memory_limit_mb {
let before = memory_limit_mb();
set_memory_limit_mb(new);
let after = memory_limit_mb();
tracing::info!(
"config updated: memory_limit_mb {} → {}",
fmt(before),
fmt(after)
);
}
if let Some(new) = req.index_memory_limit_mb {
let before = index_memory_limit_mb();
set_index_memory_limit_mb(new);
let after = index_memory_limit_mb();
tracing::info!(
"config updated: index_memory_limit_mb {} → {}",
fmt(before),
fmt(after)
);
}
Json(ConfigResponse {
memory_limit_mb: memory_limit_mb(),
index_memory_limit_mb: index_memory_limit_mb(),
})
}
/// Returns `(number of listed indexes, total chunk count)`.
///
/// Ids that are listed but can no longer be resolved to a handle still
/// count as an index but contribute no chunks. The chunk total saturates
/// instead of overflowing.
async fn collect_status_counts(state: &SearchAppState) -> (usize, usize) {
    let listed = state.registry.list();
    let index_count = listed.len();
    let mut chunk_total = 0_usize;
    for index_id in &listed {
        let Some(handle) = state.registry.get(index_id) else {
            continue;
        };
        let chunks = handle.indexer.read().await.chunk_count();
        chunk_total = chunk_total.saturating_add(chunks);
    }
    (index_count, chunk_total)
}
/// Handles `GET /status/stream`: a server-sent-events stream of daemon
/// events.
///
/// The stream opens with a `{"type":"connected"}` frame and then forwards
/// every broadcast [`DaemonEvent`] as one `data:` frame. When the subscriber
/// falls behind, a `{"type":"lag","skipped":n}` frame is sent instead of the
/// missed events. If an event cannot be serialized, an error frame is sent;
/// its payload is built with `serde_json` so the message is always escaped
/// into valid JSON.
async fn status_stream_handler(State(state): State<Arc<SearchAppState>>) -> impl IntoResponse {
    let rx = state.events.subscribe();
    let initial = stream::once(async {
        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(
            "data: {\"type\":\"connected\"}\n\n",
        ))
    });
    let events = BroadcastStream::new(rx).map(|res| {
        let frame = match res {
            Ok(event) => match serde_json::to_string(&event) {
                Ok(json) => format!("data: {json}\n\n"),
                Err(e) => {
                    // Build the payload as a JSON value instead of
                    // interpolating the message into a string literal:
                    // quotes or backslashes in the message would otherwise
                    // produce a frame that is not valid JSON.
                    let payload = serde_json::json!({
                        "type": "error",
                        "message": e.to_string(),
                    });
                    format!("data: {payload}\n\n")
                }
            },
            Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
                format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n")
            }
        };
        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(frame))
    });
    let stream = initial.chain(events);
    Response::builder()
        .header("Content-Type", "text/event-stream")
        .header("Cache-Control", "no-cache")
        // Presumably disables proxy buffering (nginx convention) — TODO confirm.
        .header("X-Accel-Buffering", "no")
        .body(Body::from_stream(stream))
        .expect("valid SSE response")
}
/// Handles `GET /indexes`: returns the ids of all registered indexes.
async fn list_indexes_handler(State(state): State<Arc<SearchAppState>>) -> Json<IndexListResponse> {
    let listed = state.registry.list();
    let mut indexes = Vec::with_capacity(listed.len());
    indexes.extend(listed.into_iter().map(|index_id| index_id.0));
    Json(IndexListResponse { indexes })
}
/// Handles `POST /indexes`: validates the request, builds an indexer and
/// registers a new index.
///
/// Responses:
/// - `400` when `root_path` is empty, relative, or not an existing directory.
/// - `200` with `"created": false` when an index with this id already exists.
/// - `503` when no embedder is installed yet (initializing or failed).
/// - `200` with `"created": true` once the index is registered.
///
/// A failure to persist the registry entry is logged and does not abort the
/// request.
async fn create_index_handler(
State(state): State<Arc<SearchAppState>>,
Json(req): Json<CreateIndexRequest>,
) -> Response {
let id = IndexId::new(req.id.clone());
// The root path is validated before the duplicate check, so an invalid
// path is rejected even when the id already exists.
if let Err(resp) = validate_root_path(&req.root_path) {
return resp;
}
if state.registry.get(&id).is_some() {
return Json(serde_json::json!({
"id": req.id,
"created": false,
"reason": "already exists",
}))
.into_response();
}
// Without an embedder the indexer cannot be built; report why.
let Some(embedder) = state.current_embedder().await else {
if let Some(err) = state.current_embedder_error() {
return embedder_error_response(&err);
}
return embedder_initializing_response();
};
let mut indexer = crate::service::persistence_loader::build_indexer_with_persisted_state(
&req.id,
req.root_path.clone(),
&embedder,
)
.await;
// Normalize include paths: drop blank and "." entries, then resolve the
// rest against the root path.
// NOTE(review): entries are joined as given; an absolute entry or one
// containing `..` would resolve outside `root_path` — confirm intended.
let include_paths: Vec<std::path::PathBuf> = req
.include_paths
.clone()
.unwrap_or_default()
.into_iter()
.filter(|p| !p.trim().is_empty() && p.trim() != ".")
.map(|p| req.root_path.join(p.trim()))
.collect();
let exclude_globs: Vec<String> = req.exclude_globs.clone().unwrap_or_default();
// Normalize extensions: strip leading dots, drop empty entries.
let extensions: Vec<String> = req
.extensions
.clone()
.unwrap_or_default()
.into_iter()
.map(|e| e.trim_start_matches('.').to_string())
.filter(|e| !e.is_empty())
.collect();
let domain_terms: Vec<String> = req.domain_terms.clone().unwrap_or_default();
let path_filter: Vec<String> = req
.path_filter
.clone()
.unwrap_or_default()
.into_iter()
.filter(|p| !p.trim().is_empty())
.collect();
indexer.set_domain_terms(domain_terms.clone());
// Persist the registry entry. Note that `include_paths` is persisted as
// sent in the request, not in the normalized form computed above.
if let Err(e) = crate::service::persistence::upsert_index_registry_entry(
crate::service::persistence::PersistedIndex {
id: req.id.clone(),
root_path: req.root_path.clone(),
include_paths: req.include_paths.clone().unwrap_or_default(),
exclude_globs: exclude_globs.clone(),
extensions: extensions.clone(),
domain_terms: domain_terms.clone(),
path_filter: path_filter.clone(),
},
) {
tracing::warn!("could not persist index registry for {}: {e}", req.id);
}
let handle = IndexHandle {
id: id.clone(),
indexer: Arc::new(tokio::sync::RwLock::new(indexer)),
root_path: req.root_path,
include_paths,
exclude_globs,
extensions,
domain_terms,
path_filter,
// Context embedding and summary start out empty.
context_embedding: Arc::new(tokio::sync::RwLock::new(None)),
context_summary: Arc::new(tokio::sync::RwLock::new(None)),
};
state.registry.register(handle);
// Update the index-count metric and notify stream subscribers.
crate::service::metrics::set_index_count(state.registry.list().len());
state.emit(DaemonEvent::IndexRegistered { id: req.id.clone() });
Json(serde_json::json!({ "id": req.id, "created": true })).into_response()
}
fn validate_root_path(path: &std::path::Path) -> Result<(), Response> {
if path.as_os_str().is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "root_path is required and must not be empty"
})),
)
.into_response());
}
if !path.is_absolute() {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": format!(
"root_path must be absolute (got {:?}); relative paths \
would be resolved against the daemon's CWD which is \
not the caller's CWD",
path.display().to_string()
),
})),
)
.into_response());
}
if !path.is_dir() {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": format!(
"root_path {:?} does not exist or is not a directory",
path.display().to_string()
),
})),
)
.into_response());
}
Ok(())
}
/// Reports whether `file` lies inside `root`, judged purely lexically (the
/// file system is not consulted and symlinks are not resolved).
///
/// - An empty `file` is never inside the root.
/// - A relative `file` is accepted unless it contains a `..` component.
/// - An absolute `file` has its `.` and `..` components resolved first and
///   must then start with `root`. Resolving first matters: a path such as
///   `/root/../elsewhere/x` begins with `/root` component-wise but actually
///   points outside it.
fn file_is_within_root(file: &str, root: &std::path::Path) -> bool {
    if file.is_empty() {
        return false;
    }
    let path = std::path::Path::new(file);
    if !path.is_absolute() {
        return !path
            .components()
            .any(|c| matches!(c, std::path::Component::ParentDir));
    }
    let mut resolved = std::path::PathBuf::new();
    for component in path.components() {
        match component {
            // `pop` is a no-op at the file-system root, so `/..` stays `/`.
            std::path::Component::ParentDir => {
                resolved.pop();
            }
            std::path::Component::CurDir => {}
            other => resolved.push(other.as_os_str()),
        }
    }
    resolved.starts_with(root)
}
/// Builds the `503 Service Unavailable` response used while the embedder
/// has not been installed yet.
fn embedder_initializing_response() -> Response {
    let body = serde_json::json!({
        "error": "embedder initializing, retry in a few seconds"
    });
    (StatusCode::SERVICE_UNAVAILABLE, Json(body)).into_response()
}
fn embedder_error_response(message: &str) -> Response {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": format!("embedder init failed: {message}"),
})),
)
.into_response()
}
/// Handles `DELETE /indexes/{id}`: unregisters the index and cleans up.
///
/// The reindex-progress entry and the cached graph scorer are dropped
/// whether or not the index existed. Only when an index was actually
/// removed are the persisted registry entry and on-disk data deleted
/// (failures are logged, not returned), an `IndexRemoved` event emitted and
/// the index-count metric refreshed. The response reports `removed`.
async fn delete_index_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
) -> Json<serde_json::Value> {
    let index_id = IndexId::new(id.clone());
    let removed = state.registry.unregister(&index_id);
    state.reindex_progress.remove(&index_id);
    state.invalidate_graph_scorer(&index_id);

    if !removed {
        return Json(serde_json::json!({ "id": id, "removed": removed }));
    }

    if let Err(e) = crate::service::persistence::remove_index_registry_entry(&id) {
        tracing::warn!("could not remove '{id}' from indexes.toml: {e}");
    }
    if let Err(e) = crate::service::persistence::remove_index_data_dir(&id) {
        tracing::warn!("could not remove on-disk data for '{id}': {e}");
    }
    state.emit(DaemonEvent::IndexRemoved { id: id.clone() });
    let remaining = state.registry.list().len();
    crate::service::metrics::set_index_count(remaining);

    Json(serde_json::json!({ "id": id, "removed": removed }))
}
/// Handles `POST /indexes/{id}/search`: runs `query` against one index.
///
/// Returns `404` when the index is not registered and `500` when the search
/// itself fails. Results whose file path falls outside the index root are
/// dropped (with a warning). When a graph scorer is available, each result
/// gets the scorer's bonus added to its score and the list is re-sorted by
/// descending score; `meta.community_cohesion` is then the fraction of the
/// top ten results that share a community with the top result.
async fn search_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(query): Json<SearchQuery>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
    let intent = QueryClassifier::classify_with_domain(&query.text, &handle.domain_terms);
    let started = std::time::Instant::now();
    let indexer = handle.indexer.read().await;
    let mut results = indexer
        .search(&query)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let root = handle.root_path.clone();
    let before = results.len();
    results.retain(|r| file_is_within_root(&r.file, &root));
    let filtered_out = before.saturating_sub(results.len());
    if filtered_out > 0 {
        tracing::warn!(
            index_id = %index_id,
            root = %root.display(),
            dropped = filtered_out,
            "search_handler: dropped {} result(s) whose file path falls outside the \
             index root (likely stale data from a misregistered index — see #63/#64)",
            filtered_out,
        );
    }
    let graph_snapshot = indexer.symbol_graph().await;
    // Release the indexer lock before asking for the graph scorer, which
    // takes the same lock itself when it has to build a scorer.
    drop(indexer);
    let (graph_scoring, community_cohesion) = match state.graph_scorer(&index_id).await {
        Some(scorer) => {
            for result in results.iter_mut() {
                if let Some(sym) = graph_snapshot.symbol_for_chunk(&result.id) {
                    result.score += scorer.bonus(sym);
                }
            }
            // Scores that cannot be compared (NaN) are treated as equal.
            results.sort_by(|a, b| {
                b.score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let top_n = results.iter().take(10).collect::<Vec<_>>();
            let cohesion = if let Some(head) = top_n.first() {
                if let Some(head_sym) = graph_snapshot.symbol_for_chunk(&head.id) {
                    let total = top_n.len() as f32;
                    let matches = top_n
                        .iter()
                        .filter(|r| {
                            graph_snapshot
                                .symbol_for_chunk(&r.id)
                                .map(|s| scorer.same_community(head_sym, s))
                                .unwrap_or(false)
                        })
                        .count() as f32;
                    if total > 0.0 {
                        matches / total
                    } else {
                        0.0
                    }
                } else {
                    0.0
                }
            } else {
                0.0
            };
            (true, cohesion)
        }
        None => (false, 0.0_f32),
    };
    let latency_ms = started.elapsed().as_millis() as u64;
    tracing::info!(
        index_id = %index_id,
        intent = %format!("{intent:?}"),
        latency_ms = latency_ms,
        results = results.len(),
        // Truncated on a character boundary: slicing at a fixed byte offset
        // would panic on multi-byte UTF-8 query text.
        query = %search_query_log_preview(&query.text, 80),
        "search"
    );
    Ok(Json(serde_json::json!({
        "results": results,
        "intent": format!("{:?}", intent),
        "latency_ms": latency_ms,
        "meta": {
            "graph_scoring": graph_scoring,
            "community_cohesion": community_cohesion,
        },
    })))
}

/// Returns the longest prefix of `text` that is at most `max_bytes` bytes
/// long and ends on a UTF-8 character boundary.
fn search_query_log_preview(text: &str, max_bytes: usize) -> &str {
    let mut end = text.len().min(max_bytes);
    // Offset 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}
/// JSON body accepted by `POST /search` (search across several indexes).
#[derive(Deserialize)]
pub struct GlobalSearchRequest {
/// Query text.
pub query: String,
/// Maximum number of fused results to return; defaults to 10.
#[serde(default = "default_global_top_k")]
pub top_k: usize,
/// When `false` (the default), per-index searches run in compact mode.
#[serde(default)]
pub full_content: bool,
/// Optional allow-list of index ids; ids that are not registered are ignored.
#[serde(default)]
pub indexes: Option<Vec<String>>,
/// Routing mode: `"top_n"` or `"threshold"`; anything else searches all
/// selected indexes.
#[serde(default)]
pub routing: Option<String>,
/// Number of indexes kept in `"top_n"` mode (default 3, minimum 1).
#[serde(default)]
pub routing_n: Option<usize>,
/// Minimum weight in `"threshold"` mode (default 0.3).
#[serde(default)]
pub routing_threshold: Option<f32>,
}
/// Serde default for `GlobalSearchRequest::top_k`.
fn default_global_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
/// Handles `POST /search`: searches several indexes concurrently and fuses
/// the per-index rankings into one list.
///
/// Steps: select the indexes (optionally restricted by `req.indexes`),
/// compute a context weight per index, apply the routing mode to decide
/// which indexes are searched, run the searches concurrently, fuse the
/// lanes with `rrf_fuse`, and return the top `req.top_k` chunks with their
/// fused scores. An index whose search fails is logged and skipped.
async fn global_search_handler(
State(state): State<Arc<SearchAppState>>,
Json(req): Json<GlobalSearchRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
use crate::core::search::rrf::{rrf_fuse, RRF_K};
let all_ids = state.registry.list();
// Restrict to the requested ids when an allow-list was sent; requested
// ids that are not registered simply do not match.
let index_ids: Vec<IndexId> = if let Some(requested) = req.indexes.as_ref() {
let allow: std::collections::HashSet<&str> = requested.iter().map(|s| s.as_str()).collect();
all_ids
.into_iter()
.filter(|id| allow.contains(id.0.as_str()))
.collect()
} else {
all_ids
};
let total_indexes = index_ids.len();
// Nothing to search: answer immediately.
// NOTE(review): this early response has no `routing` /
// `routing_decisions` fields, unlike the normal response below.
if index_ids.is_empty() {
return Ok(Json(serde_json::json!({
"results": Vec::<crate::core::indexer::CodeChunk>::new(),
"indexes_searched": Vec::<String>::new(),
"total_indexes": 0_usize,
"latency_ms": 0_u64,
"intent": format!("{:?}", QueryClassifier::classify(&req.query)),
})));
}
let started = std::time::Instant::now();
let intent = QueryClassifier::classify(&req.query);
let routing_mode = RoutingMode::from_request(&req);
let weights = compute_context_weights(&state.registry, &index_ids, &req.query).await;
let (active_ids, weight_map) = routing_mode.apply(&index_ids, &weights);
let routing_label = routing_mode.label().to_string();
// One decision record per candidate index, whether or not it was searched.
let routing_decisions: Vec<serde_json::Value> = index_ids
.iter()
.map(|id| {
let w = weights.get(id).copied().unwrap_or(1.0);
let included = weight_map.contains_key(id);
serde_json::json!({
"index_id": id.0,
"cosine_similarity": w,
"included": included,
})
})
.collect();
// The same query is sent to every active index.
let per_index_query = SearchQuery {
text: req.query.clone(),
top_k: req.top_k,
expand_graph: true,
compact: !req.full_content,
branch_files: None,
branch_boost: SearchQuery::default_branch_boost(),
branch: None,
};
let registry = state.registry.clone();
// One future per active index; an index that vanished or errored yields `None`.
let futures = active_ids.into_iter().map(|id| {
let registry = registry.clone();
let query = per_index_query.clone();
async move {
let handle = registry.get(&id)?;
let indexer = handle.indexer.read().await;
match indexer.search(&query).await {
Ok(results) => Some((id, results)),
Err(e) => {
tracing::warn!("global search: index {} errored: {e}", id);
None
}
}
}
});
let per_index_results: Vec<(IndexId, Vec<crate::core::indexer::CodeChunk>)> =
futures::future::join_all(futures)
.await
.into_iter()
.flatten()
.collect();
// Chunks are keyed by "<index id>::<chunk id>" so that equal chunk ids
// from different indexes do not collide.
let mut chunk_lookup: std::collections::HashMap<String, crate::core::indexer::CodeChunk> =
std::collections::HashMap::new();
let mut lanes: Vec<Vec<(String, f32)>> = Vec::with_capacity(per_index_results.len());
let mut indexes_searched: Vec<String> = Vec::with_capacity(per_index_results.len());
for (id, results) in per_index_results {
indexes_searched.push(id.0.clone());
let weight = weight_map.get(&id).copied().unwrap_or(1.0);
let mut lane: Vec<(String, f32)> = Vec::with_capacity(results.len());
for mut chunk in results {
let namespaced = format!("{}::{}", id.0, chunk.id);
chunk.index_id = Some(id.0.clone());
let weighted_score = chunk.score * weight;
lane.push((namespaced.clone(), weighted_score));
chunk_lookup.insert(namespaced, chunk);
}
lanes.push(lane);
}
// Fold the lanes into the fused ranking one at a time, keeping an
// oversampled candidate list (at least 10, normally 4 × top_k) until the
// final truncation.
let mut fused: Vec<(String, f32)> = Vec::new();
let oversample = req.top_k.saturating_mul(4).max(req.top_k).max(10);
for lane in lanes {
fused = rrf_fuse(&fused, &lane, 1.0, 1.0, RRF_K, oversample);
}
fused.truncate(req.top_k);
// Map fused ids back to chunks and replace each score with the fused score.
let results: Vec<crate::core::indexer::CodeChunk> = fused
.into_iter()
.filter_map(|(id, fused_score)| {
let mut chunk = chunk_lookup.remove(&id)?;
chunk.score = fused_score;
Some(chunk)
})
.collect();
let latency_ms = started.elapsed().as_millis() as u64;
Ok(Json(serde_json::json!({
"results": results,
"indexes_searched": indexes_searched,
"total_indexes": total_indexes,
"latency_ms": latency_ms,
"intent": format!("{:?}", intent),
"routing": routing_label,
"routing_decisions": routing_decisions,
})))
}
/// Strategy for choosing which indexes a global search is sent to.
#[derive(Debug, Clone, Copy)]
enum RoutingMode {
/// Search every candidate index, keeping each index's context weight.
All,
/// Search only the given number of highest-weighted indexes.
TopN(usize),
/// Search only indexes whose weight is at least the given value.
Threshold(f32),
}
impl RoutingMode {
    /// Number of indexes kept by `"top_n"` routing when the request gives none.
    const DEFAULT_TOP_N: usize = 3;
    /// Minimum weight used by `"threshold"` routing when the request gives none.
    const DEFAULT_THRESHOLD: f32 = 0.3;

    /// Derives the routing mode from the request's `routing` field.
    ///
    /// `"top_n"` keeps at least one index; any value other than `"top_n"`
    /// or `"threshold"` (including no value) selects `All`.
    fn from_request(req: &GlobalSearchRequest) -> Self {
        let Some(mode) = req.routing.as_deref() else {
            return Self::All;
        };
        if mode == "top_n" {
            let requested = req.routing_n.unwrap_or(Self::DEFAULT_TOP_N);
            Self::TopN(requested.max(1))
        } else if mode == "threshold" {
            let cutoff = req.routing_threshold.unwrap_or(Self::DEFAULT_THRESHOLD);
            Self::Threshold(cutoff)
        } else {
            Self::All
        }
    }

    /// Name of the mode as reported in the search response.
    fn label(self) -> &'static str {
        match self {
            Self::Threshold(_) => "threshold",
            Self::TopN(_) => "top_n",
            Self::All => "all",
        }
    }

    /// Selects the indexes to search and the weight applied to each.
    ///
    /// Indexes without an entry in `weights` count as weight `1.0`. `All`
    /// keeps every index with its own weight; `TopN` and `Threshold` keep a
    /// subset and give every kept index the weight `1.0`.
    fn apply(
        self,
        index_ids: &[IndexId],
        weights: &std::collections::HashMap<IndexId, f32>,
    ) -> (Vec<IndexId>, std::collections::HashMap<IndexId, f32>) {
        let weight_of = |id: &IndexId| -> f32 { weights.get(id).copied().unwrap_or(1.0) };
        // Every selected index gets the neutral weight 1.0.
        let neutral = |selected: &[IndexId]| -> std::collections::HashMap<IndexId, f32> {
            selected.iter().cloned().map(|id| (id, 1.0)).collect()
        };

        match self {
            Self::All => {
                let weighted: std::collections::HashMap<IndexId, f32> = index_ids
                    .iter()
                    .map(|id| (id.clone(), weight_of(id)))
                    .collect();
                (index_ids.to_vec(), weighted)
            }
            Self::TopN(n) => {
                let mut ranked: Vec<(&IndexId, f32)> =
                    index_ids.iter().map(|id| (id, weight_of(id))).collect();
                // Descending by weight; incomparable weights count as equal.
                ranked.sort_by(|lhs, rhs| {
                    rhs.1
                        .partial_cmp(&lhs.1)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
                let selected: Vec<IndexId> = ranked
                    .into_iter()
                    .take(n)
                    .map(|(id, _)| id.clone())
                    .collect();
                let unit = neutral(&selected);
                (selected, unit)
            }
            Self::Threshold(t) => {
                let selected: Vec<IndexId> = index_ids
                    .iter()
                    .filter(|&id| weight_of(id) >= t)
                    .cloned()
                    .collect();
                let unit = neutral(&selected);
                (selected, unit)
            }
        }
    }
}
/// Computes a routing weight for each index in `index_ids`.
///
/// The query is embedded once, using the first index whose indexer returns
/// an embedding. Each index's weight is then the cosine similarity between
/// the query embedding and the index's context embedding, floored at `0.0`.
/// The weight is `1.0` when the query could not be embedded at all, when the
/// index is no longer registered, or when the index has no context embedding
/// of matching length.
async fn compute_context_weights(
    registry: &crate::core::registry::IndexRegistry,
    index_ids: &[IndexId],
    query: &str,
) -> std::collections::HashMap<IndexId, f32> {
    use crate::core::mmr::cosine_similarity;

    // Ask the indexes in order until one of them embeds the query.
    let mut query_vec: Option<Vec<f32>> = None;
    for id in index_ids {
        let Some(handle) = registry.get(id) else {
            continue;
        };
        let indexer = handle.indexer.read().await;
        match indexer.embed_text(query).await {
            Ok(Some(embedding)) => {
                query_vec = Some(embedding);
                break;
            }
            Ok(None) => {}
            Err(e) => {
                tracing::debug!("context_routing: embed_text failed on {}: {e}", id.0);
            }
        }
    }

    // Without a query embedding every index gets the neutral weight.
    let Some(q) = query_vec else {
        return index_ids.iter().map(|id| (id.clone(), 1.0)).collect();
    };

    let mut weights = std::collections::HashMap::with_capacity(index_ids.len());
    for id in index_ids {
        let weight = match registry.get(id) {
            None => 1.0,
            Some(handle) => {
                let ctx_guard = handle.context_embedding.read().await;
                let similarity = match ctx_guard.as_ref() {
                    Some(ctx) if ctx.len() == q.len() => cosine_similarity(&q, ctx).max(0.0),
                    _ => 1.0,
                };
                similarity
            }
        };
        weights.insert(id.clone(), weight);
    }
    weights
}
/// JSON body accepted by `POST /indexes/{id}/search_similar`.
#[derive(Deserialize)]
pub struct SearchSimilarRequest {
/// File used to locate the seed chunk.
pub file: String,
/// Optional function name used, together with `file`, to locate the seed chunk.
#[serde(default)]
pub function: Option<String>,
/// Maximum number of similar chunks to return; defaults to 10.
#[serde(default = "default_similar_top_k")]
pub top_k: usize,
}
/// Serde default for `SearchSimilarRequest::top_k`.
fn default_similar_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
/// Handles `POST /indexes/{id}/search_similar`: finds chunks similar to a
/// seed chunk identified by file and optional function name.
///
/// Returns `404` when the index, the seed chunk, or the seed's embedding
/// cannot be found, and `500` when the similarity search fails. The seed
/// chunk id is passed to the similarity search as its third argument and is
/// echoed back as `seed_chunk_id`.
async fn search_similar_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<SearchSimilarRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let timer = std::time::Instant::now();
    let indexer = handle.indexer.read().await;

    let seed = match indexer
        .find_chunk_id(&req.file, req.function.as_deref())
        .await
    {
        Some(seed) => seed,
        None => return Err(StatusCode::NOT_FOUND),
    };
    let Some(seed_embedding) = indexer.get_embedding(&seed) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let neighbours = match indexer
        .similar_by_embedding(&seed_embedding, req.top_k, Some(&seed))
        .await
    {
        Ok(found) => found,
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    let latency_ms = timer.elapsed().as_millis() as u64;
    Ok(Json(serde_json::json!({
        "results": neighbours,
        "seed_chunk_id": seed,
        "latency_ms": latency_ms,
    })))
}
/// Returns `(disk_bytes, last_indexed)` for the on-disk data of `index_id`.
///
/// Both values are `None` when the index data directory cannot be resolved
/// or does not exist. Otherwise `disk_bytes` is the directory size reported
/// by `dir_size_bytes`, and `last_indexed` is the modification time of the
/// chunks file as an RFC 3339 UTC timestamp (`None` if the path, metadata,
/// or mtime is unavailable).
///
/// This performs synchronous filesystem access.
fn index_disk_and_mtime(index_id: &str) -> (Option<u64>, Option<String>) {
    let data_dir = match crate::service::persistence::index_data_dir(index_id) {
        Ok(dir) if dir.exists() => dir,
        _ => return (None, None),
    };

    let size_bytes = trusty_common::sys_metrics::dir_size_bytes(&data_dir);

    let modified = crate::service::persistence::chunks_path(index_id)
        .ok()
        .and_then(|p| std::fs::metadata(&p).ok())
        .and_then(|m| m.modified().ok());
    let last_indexed =
        modified.map(|t| chrono::DateTime::<chrono::Utc>::from(t).to_rfc3339());

    (Some(size_bytes), last_indexed)
}
async fn index_status_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let index_id = IndexId::new(id);
let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
let indexer = handle.indexer.read().await;
let path_filter = if handle.path_filter.is_empty() {
serde_json::Value::Null
} else {
serde_json::Value::Array(
handle
.path_filter
.iter()
.map(|s| serde_json::Value::String(s.clone()))
.collect(),
)
};
let has_context_embedding = handle.context_embedding.read().await.is_some();
let context_summary = handle
.context_summary
.read()
.await
.clone()
.map(serde_json::Value::String)
.unwrap_or(serde_json::Value::Null);
let (disk_bytes, last_indexed) = index_disk_and_mtime(&index_id.0);
Ok(Json(serde_json::json!({
"index_id": index_id.0,
"root_path": handle.root_path,
"chunk_count": indexer.chunk_count(),
"path_filter": path_filter,
"has_context_embedding": has_context_embedding,
"context_summary": context_summary,
"disk_bytes": disk_bytes,
"last_indexed": last_indexed,
})))
}
/// Optional query-string filters accepted by `graph_handler`.
#[derive(Debug, Default, serde::Deserialize)]
struct GraphQueryParams {
/// Comma-separated node types to keep; parsed with `parse_filter_set`.
/// `None` keeps every node.
types: Option<String>,
/// Comma-separated edge kinds to keep; parsed with `parse_filter_set`.
/// `None` keeps every edge kind.
edge_types: Option<String>,
/// Edges whose weight is below this value are dropped. The handler falls
/// back to `f32::MIN` when this is `None`.
min_weight: Option<f32>,
}
/// Parses a comma-separated filter string into a set of lowercase tokens.
///
/// Each token is trimmed and ASCII-lowercased; blank tokens are discarded.
/// Returns `None` when `raw` is `None` or when no non-blank token remains,
/// so callers can treat `None` as "no filter".
fn parse_filter_set(raw: Option<&str>) -> Option<std::collections::HashSet<String>> {
    let mut tokens = std::collections::HashSet::new();
    for piece in raw?.split(',') {
        let token = piece.trim();
        if !token.is_empty() {
            tokens.insert(token.to_ascii_lowercase());
        }
    }
    (!tokens.is_empty()).then_some(tokens)
}
/// Classifies a graph node label as `"File"` or `"Symbol"`.
///
/// A label counts as a file only when it contains a `/` *and* its final path
/// component has a non-empty extension; everything else is a symbol.
fn node_type_for_symbol(symbol: &str) -> &'static str {
    if !symbol.contains('/') {
        return "Symbol";
    }
    match std::path::Path::new(symbol).extension() {
        Some(ext) if !ext.is_empty() => "File",
        _ => "Symbol",
    }
}
async fn graph_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
Query(params): Query<GraphQueryParams>,
) -> Result<Response, StatusCode> {
let index_id = IndexId::new(id);
let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
let graph = {
let indexer = handle.indexer.read().await;
indexer.snapshot_symbol_graph().await
};
let type_filter = parse_filter_set(params.types.as_deref());
let edge_filter = parse_filter_set(params.edge_types.as_deref());
let min_weight = params.min_weight.unwrap_or(f32::MIN);
let mut kept_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut nodes: Vec<serde_json::Value> = Vec::new();
for (symbol, chunk_id, file) in graph.all_nodes() {
let node_type = node_type_for_symbol(&symbol);
if let Some(ref filter) = type_filter {
if !filter.contains(&node_type.to_ascii_lowercase()) {
continue;
}
}
kept_symbols.insert(symbol.clone());
nodes.push(serde_json::json!({
"id": chunk_id,
"type": node_type,
"label": symbol,
"metadata": { "file": file, "symbol": symbol },
}));
}
let mut edges: Vec<serde_json::Value> = Vec::new();
for (source, target, kind) in graph.all_edges() {
if type_filter.is_some()
&& (!kept_symbols.contains(&source) || !kept_symbols.contains(&target))
{
continue;
}
let kind_name = format!("{kind:?}");
if let Some(ref filter) = edge_filter {
if !filter.contains(&kind_name.to_ascii_lowercase()) {
continue;
}
}
let weight = kind.score_multiplier();
if weight < min_weight {
continue;
}
edges.push(serde_json::json!({
"source": source,
"target": target,
"type": kind_name,
"weight": weight,
}));
}
let body = serde_json::json!({
"nodes": nodes,
"edges": edges,
"stats": {
"node_count": graph.node_count(),
"edge_count": graph.edge_count(),
},
"generated_at": chrono::Utc::now().to_rfc3339(),
});
let mut response = Json(body).into_response();
response.headers_mut().insert(
axum::http::header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("max-age=3600"),
);
Ok(response)
}
/// Returns summary statistics for the symbol graph of index `id`.
///
/// Responds `404 Not Found` when the index is unknown. The body contains
/// `node_count`, `edge_count`, a per-kind `edge_kinds` breakdown,
/// `community_count`, and `modularity` (the sum of each community's
/// `modularity_contribution`).
///
/// Community figures are best-effort: they stay at `0` / `0.0` when the
/// index has no corpus store or when loading communities fails.
async fn graph_stats_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };

    // Grab both the graph snapshot and the corpus store under one lock hold.
    let (graph, corpus) = {
        let guard = handle.indexer.read().await;
        let snapshot = guard.snapshot_symbol_graph().await;
        let store = guard.corpus_store();
        (snapshot, store)
    };

    let edge_kinds: serde_json::Map<String, serde_json::Value> = graph
        .edge_kind_breakdown()
        .into_iter()
        .map(|(tag, count)| (tag, serde_json::Value::from(count)))
        .collect();

    let mut community_count: u64 = 0;
    let mut modularity: f64 = 0.0;
    if let Some(corpus) = corpus {
        // Loading communities is run on the blocking thread pool.
        let loaded = tokio::task::spawn_blocking(move || {
            crate::core::SymbolGraph::load_communities(&corpus)
        })
        .await;
        if let Ok(Ok(records)) = loaded {
            community_count = records.len() as u64;
            modularity = records.iter().map(|r| r.modularity_contribution).sum();
        }
    }

    let payload = serde_json::json!({
        "node_count": graph.node_count(),
        "edge_count": graph.edge_count(),
        "edge_kinds": serde_json::Value::Object(edge_kinds),
        "community_count": community_count,
        "modularity": modularity,
    });
    Ok(Json(payload))
}
/// Maximum number of member symbols listed per community in the
/// `communities_handler` response; `member_count` still reports the full size.
const COMMUNITIES_HTTP_MEMBER_CAP: usize = 50;
async fn communities_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let index_id = IndexId::new(id);
let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
let corpus = {
let indexer = handle.indexer.read().await;
indexer.corpus_store()
};
let records = match corpus {
Some(corpus) => {
tokio::task::spawn_blocking(move || crate::core::SymbolGraph::load_communities(&corpus))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
}
None => Vec::new(),
};
let modularity: f64 = records.iter().map(|r| r.modularity_contribution).sum();
let communities: Vec<serde_json::Value> = records
.iter()
.map(|r| {
let members_truncated: Vec<&String> =
r.members.iter().take(COMMUNITIES_HTTP_MEMBER_CAP).collect();
serde_json::json!({
"id": r.id,
"member_count": r.member_count,
"centroid_symbol": r.centroid_symbol,
"dominant_files": r.dominant_files,
"members": members_truncated,
"modularity_contribution": r.modularity_contribution,
})
})
.collect();
Ok(Json(serde_json::json!({
"community_count": records.len(),
"modularity": modularity,
"communities": communities,
})))
}
/// Returns the community that contains a given symbol in index `id`.
///
/// Responds `404 Not Found` when the index is unknown, the index has no
/// corpus store, the symbol belongs to no community, or no community record
/// matches the symbol's community id. Responds `500 Internal Server Error`
/// when either corpus lookup fails or the blocking task cannot be joined.
///
/// The success body contains `community_id`, `community_size`, the queried
/// `symbol`, its `siblings` (all other members), and `centroid_symbol`.
///
/// Both corpus lookups run against the *same* corpus store inside a single
/// blocking task, so the community id and the community records always come
/// from one consistent store and the indexer lock is taken only once.
async fn community_for_symbol_handler(
    State(state): State<Arc<SearchAppState>>,
    Path((id, symbol_encoded)): Path<(String, String)>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;

    // The path segment is used verbatim; no further decoding happens here.
    // NOTE(review): axum's `Path` extractor is documented to percent-decode
    // parameters already — confirm against the axum version in use.
    let symbol = symbol_encoded;

    let corpus = {
        let indexer = handle.indexer.read().await;
        indexer.corpus_store().ok_or(StatusCode::NOT_FOUND)?
    };

    // One blocking hop: resolve the community id, then find its record.
    // `Ok(None)` covers both "symbol has no community" and "no matching record".
    let lookup_symbol = symbol.clone();
    let record = tokio::task::spawn_blocking(move || {
        let Some(cid) = corpus.symbol_community(&lookup_symbol).map_err(|_| ())? else {
            return Ok(None);
        };
        let records = crate::core::SymbolGraph::load_communities(&corpus).map_err(|_| ())?;
        // NOTE(review): `as u64` assumes record ids fit in u64 without
        // wrapping (e.g. are non-negative) — confirm against the record type.
        Ok::<_, ()>(records.into_iter().find(|r| r.id as u64 == cid))
    })
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    .ok_or(StatusCode::NOT_FOUND)?;

    let siblings: Vec<&String> = record.members.iter().filter(|m| **m != symbol).collect();
    Ok(Json(serde_json::json!({
        "community_id": record.id,
        "community_size": record.member_count,
        "symbol": symbol,
        "siblings": siblings,
        "centroid_symbol": record.centroid_symbol,
    })))
}
/// Indexes a single file's content into index `id`.
///
/// Responds `404 Not Found` when the index is unknown and
/// `500 Internal Server Error` when the indexer rejects the file. On success
/// the body echoes `index_id` and `path` with `"indexed": true`.
///
/// Only the indexer's *read* lock is taken; `index_file` is invoked through
/// a shared reference.
async fn index_file_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<IndexFileRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };

    let guard = handle.indexer.read().await;
    if guard.index_file(&req.path, &req.content).await.is_err() {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }

    let payload = serde_json::json!({
        "index_id": index_id.0,
        "path": req.path,
        "indexed": true,
    });
    Ok(Json(payload))
}
/// Removes a file's chunks from index `id`.
///
/// Responds `404 Not Found` when the index is unknown and
/// `500 Internal Server Error` when removal fails. On success the body
/// echoes `index_id` and `path`, and reports the indexer's return value as
/// `removed_chunks`.
async fn remove_file_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
    Json(req): Json<RemoveFileRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let index_id = IndexId::new(id);
    let Some(handle) = state.registry.get(&index_id) else {
        return Err(StatusCode::NOT_FOUND);
    };

    let guard = handle.indexer.read().await;
    let removed = match guard.remove_file(&req.path).await {
        Ok(count) => count,
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    let payload = serde_json::json!({
        "index_id": index_id.0,
        "path": req.path,
        "removed_chunks": removed,
    });
    Ok(Json(payload))
}
/// Pagination query parameters for `get_index_chunks_handler`.
#[derive(Deserialize)]
pub struct ChunksParams {
/// Offset passed to the indexer's `enumerate_chunks`; `0` when omitted.
#[serde(default)]
pub offset: usize,
/// Requested page size; `default_chunks_limit()` supplies the value when
/// omitted, and the handler caps it at `MAX_CHUNKS_LIMIT`.
#[serde(default = "default_chunks_limit")]
pub limit: usize,
}
/// Serde default for `ChunksParams::limit`.
fn default_chunks_limit() -> usize {
    const DEFAULT_CHUNKS_LIMIT: usize = 100;
    DEFAULT_CHUNKS_LIMIT
}
/// Upper bound applied to the `limit` query parameter by
/// `get_index_chunks_handler`; larger requests are clamped to this value.
const MAX_CHUNKS_LIMIT: usize = 1_000;
async fn get_index_chunks_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
Query(params): Query<ChunksParams>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let index_id = IndexId::new(id);
let handle = state.registry.get(&index_id).ok_or(StatusCode::NOT_FOUND)?;
let limit = params.limit.min(MAX_CHUNKS_LIMIT);
let indexer = handle.indexer.read().await;
let (total, chunks) = indexer.enumerate_chunks(params.offset, limit).await;
Ok(Json(serde_json::json!({
"index_id": index_id.0,
"total": total,
"offset": params.offset,
"limit": limit,
"chunks": chunks,
})))
}
/// Optional JSON body accepted by `reindex_handler`.
#[derive(Deserialize, Default)]
pub struct ReindexRequest {
/// Replacement root path for the index. The handler validates it with
/// `validate_root_path` and re-registers the index handle when it differs
/// from the stored root.
#[serde(default)]
pub root_path: Option<std::path::PathBuf>,
/// Forwarded to `spawn_reindex_with_cleanup`; treated as `false` when absent.
#[serde(default)]
pub force: Option<bool>,
}
/// Queues a background reindex for index `id`.
///
/// Flow, in order:
/// 1. Unknown index: `404 Not Found` with a JSON `error` message.
/// 2. Cooldown: if `last_reindex_aborted_at` holds an entry for this index
///    that is younger than the cooldown (`TRUSTY_REINDEX_COOLDOWN_SECS`,
///    default 300 seconds), respond `429 Too Many Requests` with
///    `retry_after_secs`, `cooldown_secs`, and a `hint`. An expired entry is
///    removed and the request proceeds.
/// 3. Optional body: reads `force` (default `false`) and an optional
///    `root_path`. A path rejected by `validate_root_path` is returned to the
///    caller with that validator's status and JSON body. A path that differs
///    from the stored root (or replaces an empty one) re-registers the handle
///    with the new root while sharing the existing indexer and context state.
/// 4. Stores a fresh `ReindexProgress` in `reindex_progress`, invalidates the
///    graph scorer for the index, and hands the job to
///    `spawn_reindex_with_cleanup`.
///
/// On success the body is `{"index_id", "queued": true, "stream_url"}`.
async fn reindex_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
body: Option<Json<ReindexRequest>>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
let index_id = IndexId::new(id.clone());
// `mut` because a root-path override below may swap in a re-registered handle.
let mut handle = state.registry.get(&index_id).ok_or((
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": format!("unknown index: {}", index_id.0),
})),
))?;
// Cooldown gate: only applies when a previous abort timestamp is recorded.
if let Some(aborted_at) = state.last_reindex_aborted_at.get(&index_id) {
let elapsed = aborted_at.elapsed();
// The cooldown is re-read from the environment on every request; a missing
// or unparsable value falls back to 300 seconds.
let cooldown = std::time::Duration::from_secs(
std::env::var("TRUSTY_REINDEX_COOLDOWN_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(300),
);
if elapsed < cooldown {
// `as_secs` truncates, so a sub-second remainder is reported as 0.
let remaining_secs = (cooldown - elapsed).as_secs();
tracing::warn!(
"reindex_handler: refusing reindex for index {} — last run \
aborted at memory limit {}s ago, cooldown {}s remaining",
index_id.0,
elapsed.as_secs(),
remaining_secs,
);
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({
"error": "reindex cooldown active after memory-limit abort",
"index_id": index_id.0,
"retry_after_secs": remaining_secs,
"cooldown_secs": cooldown.as_secs(),
"hint": "lower TRUSTY_MAX_BATCH_SIZE or raise TRUSTY_MEMORY_LIMIT_MB before retrying",
})),
));
}
// Cooldown has expired. `aborted_at` is a guard into the map, so it is
// released explicitly before the same key is removed.
drop(aborted_at);
state.last_reindex_aborted_at.remove(&index_id);
}
let mut force = false;
if let Some(Json(req)) = body {
force = req.force.unwrap_or(false);
if let Some(new_root) = req.root_path {
if let Err(resp) = validate_root_path(&new_root) {
// Convert the validator's `Response` into this handler's
// `(status, Json)` error shape. At most 4096 body bytes are read; a read
// or parse failure degrades to an empty JSON object.
let (parts, body) = resp.into_parts();
let status = parts.status;
let body_bytes = axum::body::to_bytes(body, 4096).await.unwrap_or_default();
let json: serde_json::Value =
serde_json::from_slice(&body_bytes).unwrap_or_else(|_| serde_json::json!({}));
return Err((status, Json(json)));
}
// Re-register only when the root actually changes or was previously empty.
if handle.root_path.as_os_str().is_empty() || handle.root_path != new_root {
// The indexer and context state are shared (`Arc::clone`), not rebuilt.
let indexer = Arc::clone(&handle.indexer);
let new_handle = IndexHandle {
id: index_id.clone(),
indexer,
root_path: new_root,
include_paths: handle.include_paths.clone(),
exclude_globs: handle.exclude_globs.clone(),
extensions: handle.extensions.clone(),
domain_terms: handle.domain_terms.clone(),
path_filter: handle.path_filter.clone(),
context_embedding: Arc::clone(&handle.context_embedding),
context_summary: Arc::clone(&handle.context_summary),
};
handle = state.registry.register(new_handle);
}
}
}
// Publish the progress tracker before spawning so `reindex_stream_handler`
// can find it under this index id.
let progress = Arc::new(ReindexProgress::new());
state
.reindex_progress
.insert(index_id.clone(), Arc::clone(&progress));
state.invalidate_graph_scorer(&index_id);
// The spawned job also receives the progress map and the abort-timestamp map.
spawn_reindex_with_cleanup(
handle,
progress,
force,
Some(Arc::clone(&state.reindex_progress)),
Some(Arc::clone(&state.last_reindex_aborted_at)),
);
Ok(Json(serde_json::json!({
"index_id": index_id.0,
"queued": true,
"stream_url": format!("/indexes/{}/reindex/stream", index_id.0),
})))
}
/// Streams reindex progress for index `id` as Server-Sent Events.
///
/// Responds `404 Not Found` when no progress tracker is registered for the
/// index. Otherwise the response first replays every event recorded so far.
/// If the reindex was still `Running` when the status was sampled, live
/// events from the broadcast channel follow the replay. A lagging receiver
/// is told how many events it missed via a `{"type":"lag","skipped":N}` frame.
///
/// NOTE(review): the replay snapshot is taken before the status is sampled
/// and before the live subscription is created; whether events emitted in
/// between can be missed depends on how the producer orders its writes —
/// confirm against the reindex task.
async fn reindex_stream_handler(
    State(state): State<Arc<SearchAppState>>,
    Path(id): Path<String>,
) -> Result<Response, StatusCode> {
    /// Wraps one event line in an SSE `data:` frame.
    fn frame(line: String) -> Result<axum::body::Bytes, std::io::Error> {
        let framed = format!("data: {line}\n\n");
        Ok(axum::body::Bytes::from(framed))
    }

    let index_id = IndexId::new(id);
    // The map entry guard only lives inside the closure, so it is gone
    // before the first `.await` below.
    let Some(progress) = state
        .reindex_progress
        .get(&index_id)
        .map(|entry| Arc::clone(entry.value()))
    else {
        return Err(StatusCode::NOT_FOUND);
    };

    let replay = progress.events.lock().await.clone();
    let status_at_connect = progress.status.load();
    let live_rx = progress.sender.subscribe();

    let backlog = stream::iter(replay).map(frame);
    let body = if status_at_connect == ReindexStatus::Running {
        let live = BroadcastStream::new(live_rx).map(|item| match item {
            Ok(line) => frame(line),
            Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
                let notice = format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n");
                Ok(axum::body::Bytes::from(notice))
            }
        });
        Body::from_stream(backlog.chain(live))
    } else {
        // The run is no longer active: the recorded events are the whole stream.
        Body::from_stream(backlog)
    };

    let response = Response::builder()
        .header("Content-Type", "text/event-stream")
        .header("Cache-Control", "no-cache")
        .header("X-Accel-Buffering", "no")
        .body(body)
        .expect("valid SSE response");
    Ok(response)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn health_handler_reports_indexes_and_uptime() {
use crate::core::{
indexer::CodeIndexer,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use std::sync::Arc;
use tokio::sync::RwLock;
let registry = IndexRegistry::new();
let id = IndexId::new("health-test");
registry.register(IndexHandle::bare(
id.clone(),
Arc::new(RwLock::new(CodeIndexer::new(
"health-test",
"/tmp/health-test",
))),
"/tmp/health-test".into(),
));
let state = Arc::new(SearchAppState::new(registry));
let Json(resp) = health_handler(State(state)).await;
assert_eq!(resp.status, "ok");
assert_eq!(resp.version, env!("CARGO_PKG_VERSION"));
assert_eq!(resp.indexes, 1);
let _ = resp.uptime_secs;
assert_eq!(resp.embedder, "initializing");
}
#[tokio::test]
async fn graph_handler_exports_nodes_and_edges() {
use crate::core::{
indexer::CodeIndexer,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use std::sync::Arc;
use tokio::sync::RwLock;
let registry = IndexRegistry::new();
let id = IndexId::new("graph-test");
let indexer = CodeIndexer::new("graph-test", "/tmp/graph-test");
indexer
.index_file(
"graph-test/lib.rs",
"fn callee() {}\nfn caller() { callee(); }\n",
)
.await
.expect("index_file ok");
registry.register(IndexHandle::bare(
id.clone(),
Arc::new(RwLock::new(indexer)),
"/tmp/graph-test".into(),
));
let state = Arc::new(SearchAppState::new(registry));
let response = graph_handler(
State(state),
Path("graph-test".to_string()),
Query(GraphQueryParams::default()),
)
.await
.expect("handler ok");
assert_eq!(
response
.headers()
.get(axum::http::header::CACHE_CONTROL)
.and_then(|v| v.to_str().ok()),
Some("max-age=3600"),
);
let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.expect("body bytes");
let value: serde_json::Value = serde_json::from_slice(&bytes).expect("json body");
let nodes = value["nodes"].as_array().expect("nodes array");
assert_eq!(nodes.len(), 2, "two function symbols expected");
for node in nodes {
assert_eq!(node["type"].as_str(), Some("Symbol"));
assert!(node["id"].is_string());
assert!(node["label"].is_string());
assert!(node["metadata"]["file"].is_string());
}
let edges = value["edges"].as_array().expect("edges array");
assert_eq!(edges.len(), 1, "one CallsFunction edge expected");
assert_eq!(edges[0]["source"].as_str(), Some("caller"));
assert_eq!(edges[0]["target"].as_str(), Some("callee"));
assert_eq!(edges[0]["type"].as_str(), Some("CallsFunction"));
assert!(edges[0]["weight"].as_f64().is_some());
assert_eq!(value["stats"]["node_count"].as_u64(), Some(2));
assert_eq!(value["stats"]["edge_count"].as_u64(), Some(1));
assert!(value["generated_at"].is_string());
}
#[tokio::test]
async fn graph_handler_unknown_index_returns_404() {
use crate::core::registry::IndexRegistry;
let state = Arc::new(SearchAppState::new(IndexRegistry::new()));
let err = graph_handler(
State(state),
Path("does-not-exist".to_string()),
Query(GraphQueryParams::default()),
)
.await
.expect_err("missing index must 404");
assert_eq!(err, StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn graph_handler_filters_by_edge_type() {
use crate::core::{
indexer::CodeIndexer,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use std::sync::Arc;
use tokio::sync::RwLock;
let registry = IndexRegistry::new();
let id = IndexId::new("graph-filter");
let indexer = CodeIndexer::new("graph-filter", "/tmp/graph-filter");
indexer
.index_file(
"graph-filter/lib.rs",
"fn callee() {}\nfn caller() { callee(); }\n",
)
.await
.expect("index_file ok");
registry.register(IndexHandle::bare(
id.clone(),
Arc::new(RwLock::new(indexer)),
"/tmp/graph-filter".into(),
));
let state = Arc::new(SearchAppState::new(registry));
let response = graph_handler(
State(state),
Path("graph-filter".to_string()),
Query(GraphQueryParams {
types: None,
edge_types: Some("Implements".to_string()),
min_weight: None,
}),
)
.await
.expect("handler ok");
let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.expect("body bytes");
let value: serde_json::Value = serde_json::from_slice(&bytes).expect("json body");
assert!(
value["edges"].as_array().expect("edges").is_empty(),
"CallsFunction edge must be filtered out",
);
assert_eq!(value["nodes"].as_array().expect("nodes").len(), 2);
}
#[tokio::test]
async fn global_search_fans_out_and_merges() {
use crate::core::{
indexer::CodeIndexer,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use std::sync::Arc;
use tokio::sync::RwLock;
let registry = IndexRegistry::new();
for name in ["proj-a", "proj-b"] {
let id = IndexId::new(name);
let indexer = CodeIndexer::new(name, format!("/tmp/{name}"));
indexer
.index_file(
&format!("{name}/lib.rs"),
&format!("fn alpha_{name}() {{ println!(\"alpha hit\"); }}"),
)
.await
.expect("index_file ok");
registry.register(IndexHandle::bare(
id.clone(),
Arc::new(RwLock::new(indexer)),
format!("/tmp/{name}").into(),
));
}
let state = Arc::new(SearchAppState::new(registry));
let Json(value) = global_search_handler(
State(state),
Json(GlobalSearchRequest {
query: "alpha".into(),
top_k: 10,
full_content: false,
indexes: None,
routing: None,
routing_n: None,
routing_threshold: None,
}),
)
.await
.expect("handler ok");
let total = value["total_indexes"].as_u64().expect("total_indexes");
assert_eq!(total, 2, "both indexes counted");
let searched: Vec<String> = value["indexes_searched"]
.as_array()
.expect("indexes_searched array")
.iter()
.filter_map(|v| v.as_str().map(str::to_owned))
.collect();
assert_eq!(searched.len(), 2);
assert!(searched.contains(&"proj-a".to_string()));
assert!(searched.contains(&"proj-b".to_string()));
let results = value["results"].as_array().expect("results array");
assert!(!results.is_empty(), "expected at least one hit");
let mut from_a = false;
let mut from_b = false;
for r in results {
let idx = r["index_id"]
.as_str()
.expect("each result must be tagged with index_id");
assert!(
idx == "proj-a" || idx == "proj-b",
"unexpected index_id: {idx}"
);
from_a |= idx == "proj-a";
from_b |= idx == "proj-b";
}
assert!(from_a, "expected a result tagged with proj-a");
assert!(from_b, "expected a result tagged with proj-b");
}
#[tokio::test]
async fn global_search_empty_registry_returns_empty_results() {
use crate::core::registry::IndexRegistry;
let state = Arc::new(SearchAppState::new(IndexRegistry::new()));
let Json(value) = global_search_handler(
State(state),
Json(GlobalSearchRequest {
query: "anything".into(),
top_k: 5,
full_content: false,
indexes: None,
routing: None,
routing_n: None,
routing_threshold: None,
}),
)
.await
.expect("handler ok");
assert_eq!(value["total_indexes"].as_u64(), Some(0));
assert!(value["results"].as_array().unwrap().is_empty());
assert!(value["indexes_searched"].as_array().unwrap().is_empty());
}
#[tokio::test]
async fn global_search_restricts_to_named_indexes() {
use crate::core::{
indexer::CodeIndexer,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use std::sync::Arc;
use tokio::sync::RwLock;
let registry = IndexRegistry::new();
for name in ["proj-a", "proj-b", "proj-c"] {
let id = IndexId::new(name);
let indexer = CodeIndexer::new(name, format!("/tmp/{name}"));
indexer
.index_file(
&format!("{name}/lib.rs"),
&format!("fn alpha_{name}() {{ println!(\"alpha hit\"); }}"),
)
.await
.expect("index_file ok");
registry.register(IndexHandle::bare(
id.clone(),
Arc::new(RwLock::new(indexer)),
format!("/tmp/{name}").into(),
));
}
let state = Arc::new(SearchAppState::new(registry));
let Json(value) = global_search_handler(
State(state),
Json(GlobalSearchRequest {
query: "alpha".into(),
top_k: 10,
full_content: false,
indexes: Some(vec!["proj-a".into(), "proj-c".into()]),
routing: None,
routing_n: None,
routing_threshold: None,
}),
)
.await
.expect("handler ok");
assert_eq!(value["total_indexes"].as_u64(), Some(2));
let searched: std::collections::HashSet<String> = value["indexes_searched"]
.as_array()
.expect("array")
.iter()
.filter_map(|v| v.as_str().map(str::to_owned))
.collect();
assert!(searched.contains("proj-a"));
assert!(searched.contains("proj-c"));
assert!(!searched.contains("proj-b"), "proj-b must be excluded");
for r in value["results"].as_array().unwrap() {
let idx = r["index_id"].as_str().unwrap();
assert_ne!(idx, "proj-b", "no result may come from excluded index");
}
}
#[test]
fn routing_mode_all_preserves_every_index_with_weights() {
let ids = vec![IndexId::new("a"), IndexId::new("b"), IndexId::new("c")];
let weights: std::collections::HashMap<IndexId, f32> = [
(IndexId::new("a"), 0.9_f32),
(IndexId::new("b"), 0.2),
]
.into_iter()
.collect();
let (active, map) = RoutingMode::All.apply(&ids, &weights);
assert_eq!(active.len(), 3, "all routing keeps every index");
assert!((map.get(&IndexId::new("a")).copied().unwrap() - 0.9).abs() < 1e-6);
assert!((map.get(&IndexId::new("b")).copied().unwrap() - 0.2).abs() < 1e-6);
assert!((map.get(&IndexId::new("c")).copied().unwrap() - 1.0).abs() < 1e-6);
}
#[test]
fn routing_mode_top_n_keeps_only_highest_similarity() {
let ids = vec![IndexId::new("low"), IndexId::new("hi"), IndexId::new("mid")];
let weights: std::collections::HashMap<IndexId, f32> = [
(IndexId::new("low"), 0.1_f32),
(IndexId::new("hi"), 0.95),
(IndexId::new("mid"), 0.5),
]
.into_iter()
.collect();
let (active, map) = RoutingMode::TopN(2).apply(&ids, &weights);
assert_eq!(active.len(), 2);
let active_set: std::collections::HashSet<&str> =
active.iter().map(|id| id.0.as_str()).collect();
assert!(active_set.contains("hi"));
assert!(active_set.contains("mid"));
assert!(!active_set.contains("low"));
assert!((map.get(&IndexId::new("hi")).copied().unwrap() - 1.0).abs() < 1e-6);
assert!((map.get(&IndexId::new("mid")).copied().unwrap() - 1.0).abs() < 1e-6);
assert!(!map.contains_key(&IndexId::new("low")));
}
#[test]
fn routing_mode_threshold_drops_below_cutoff() {
let ids = vec![IndexId::new("a"), IndexId::new("b"), IndexId::new("c")];
let weights: std::collections::HashMap<IndexId, f32> = [
(IndexId::new("a"), 0.1_f32),
(IndexId::new("b"), 0.5),
(IndexId::new("c"), 0.8),
]
.into_iter()
.collect();
let (active, map) = RoutingMode::Threshold(0.4).apply(&ids, &weights);
let active_set: std::collections::HashSet<&str> =
active.iter().map(|id| id.0.as_str()).collect();
assert!(!active_set.contains("a"), "0.1 < 0.4 must drop");
assert!(active_set.contains("b"), "0.5 >= 0.4 must keep");
assert!(active_set.contains("c"));
assert!(!map.contains_key(&IndexId::new("a")));
}
#[test]
fn routing_threshold_keeps_neutral_indexes() {
let ids = vec![IndexId::new("known"), IndexId::new("missing")];
let weights: std::collections::HashMap<IndexId, f32> =
[(IndexId::new("known"), 0.05_f32)].into_iter().collect();
let (active, _map) = RoutingMode::Threshold(0.5).apply(&ids, &weights);
let active_set: std::collections::HashSet<&str> =
active.iter().map(|id| id.0.as_str()).collect();
assert!(!active_set.contains("known"), "0.05 < 0.5 dropped");
assert!(
active_set.contains("missing"),
"indexes without a context embedding must use neutral 1.0 weight"
);
}
#[test]
fn routing_mode_from_request_resolves_strategy() {
let base =
|routing: Option<&str>, n: Option<usize>, t: Option<f32>| -> GlobalSearchRequest {
GlobalSearchRequest {
query: "x".into(),
top_k: 1,
full_content: false,
indexes: None,
routing: routing.map(|s| s.to_string()),
routing_n: n,
routing_threshold: t,
}
};
assert!(matches!(
RoutingMode::from_request(&base(None, None, None)),
RoutingMode::All
));
assert!(matches!(
RoutingMode::from_request(&base(Some("garbage"), None, None)),
RoutingMode::All
));
match RoutingMode::from_request(&base(Some("top_n"), Some(5), None)) {
RoutingMode::TopN(n) => assert_eq!(n, 5),
_ => panic!("expected TopN"),
}
match RoutingMode::from_request(&base(Some("top_n"), None, None)) {
RoutingMode::TopN(n) => assert_eq!(n, RoutingMode::DEFAULT_TOP_N),
_ => panic!("expected TopN default"),
}
match RoutingMode::from_request(&base(Some("threshold"), None, Some(0.7))) {
RoutingMode::Threshold(t) => assert!((t - 0.7).abs() < 1e-6),
_ => panic!("expected Threshold"),
}
}
#[tokio::test]
async fn install_embedder_error_surfaces_in_health() {
use crate::core::registry::IndexRegistry;
let state = SearchAppState::new(IndexRegistry::new());
state
.install_embedder_error("init timed out after 60s")
.await;
let state_arc = Arc::new(state);
let Json(resp) = health_handler(State(state_arc)).await;
assert_eq!(resp.embedder, "error");
assert_eq!(
resp.embedder_error.as_deref(),
Some("init timed out after 60s"),
);
}
#[tokio::test]
async fn create_index_returns_503_with_error_when_embedder_failed() {
use crate::core::registry::IndexRegistry;
use axum::body::to_bytes;
let state = SearchAppState::new(IndexRegistry::new());
state
.install_embedder_error("init timed out after 60s")
.await;
let state_arc = Arc::new(state);
let tmp = tempfile::tempdir().expect("tempdir");
let resp = create_index_handler(
State(state_arc),
Json(CreateIndexRequest {
id: "demo".to_string(),
root_path: tmp.path().to_path_buf(),
include_paths: None,
exclude_globs: None,
extensions: None,
domain_terms: None,
path_filter: None,
}),
)
.await;
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
let body_bytes = to_bytes(resp.into_body(), 64 * 1024)
.await
.expect("read body");
let body: serde_json::Value = serde_json::from_slice(&body_bytes).expect("valid json");
let err_str = body
.get("error")
.and_then(|v| v.as_str())
.unwrap_or_default();
assert!(
err_str.contains("embedder init failed"),
"expected error message to mention init failure, got: {err_str}",
);
assert!(
err_str.contains("init timed out after 60s"),
"expected recorded timeout message to be surfaced, got: {err_str}",
);
}
#[tokio::test]
async fn install_embedder_clears_previous_error() {
use crate::core::embed::MockEmbedder;
use crate::core::registry::IndexRegistry;
let state = SearchAppState::new(IndexRegistry::new());
state.install_embedder_error("transient hang").await;
assert!(state.current_embedder_error().is_some());
let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder::new(8));
state.install_embedder(embedder).await;
assert!(state.current_embedder_error().is_none());
assert!(state.is_embedder_ready());
let state_arc = Arc::new(state);
let Json(resp) = health_handler(State(state_arc)).await;
assert_eq!(resp.embedder, "ready");
assert!(resp.embedder_error.is_none());
}
#[tokio::test]
async fn reindex_handler_rejects_within_cooldown() {
use crate::core::{
indexer::CodeIndexer,
registry::{IndexHandle, IndexId, IndexRegistry},
};
use std::sync::Arc;
use tokio::sync::RwLock;
let registry = IndexRegistry::new();
let id = IndexId::new("cooldown-test");
let tmp = tempfile::tempdir().expect("tempdir");
registry.register(IndexHandle::bare(
id.clone(),
Arc::new(RwLock::new(CodeIndexer::new("cooldown-test", tmp.path()))),
tmp.path().to_path_buf(),
));
let state = Arc::new(SearchAppState::new(registry));
state
.last_reindex_aborted_at
.insert(id.clone(), std::time::Instant::now());
let result = reindex_handler(
State(Arc::clone(&state)),
axum::extract::Path("cooldown-test".to_string()),
None,
)
.await;
let err = result.expect_err("expected 429 inside cooldown window");
assert_eq!(err.0, StatusCode::TOO_MANY_REQUESTS);
let body = err.1 .0;
assert!(body.get("retry_after_secs").is_some());
assert!(body.get("hint").is_some());
assert_eq!(body["index_id"], "cooldown-test");
state.last_reindex_aborted_at.remove(&id);
let ok = reindex_handler(
State(Arc::clone(&state)),
axum::extract::Path("cooldown-test".to_string()),
None,
)
.await
.expect("queued");
assert_eq!(ok.0["queued"], serde_json::Value::Bool(true));
}
#[tokio::test]
async fn reindex_status_aborted_memory_serializes_lowercase() {
let status = crate::service::reindex::ReindexStatus::AbortedMemory;
let json = serde_json::to_string(&status).expect("serialize");
assert_eq!(json, "\"abortedmemory\"");
}
#[tokio::test]
async fn health_includes_resource_fields() {
let state = Arc::new(SearchAppState::new(IndexRegistry::new()));
let Json(resp) = health_handler(State(state)).await;
assert!(resp.rss_mb < 1024 * 1024, "rss_mb unit must be MB");
assert!(resp.cpu_pct >= 0.0, "cpu_pct must be non-negative");
assert_eq!(resp.disk_bytes, 0, "disk ticker has not ticked yet");
let _ = resp.rss_limit_mb;
}
#[test]
fn index_disk_and_mtime_handles_missing_dir() {
let id = format!("nonexistent-index-{}", std::process::id());
let (disk, mtime) = index_disk_and_mtime(&id);
assert!(disk.is_none(), "missing dir yields no disk_bytes");
assert!(mtime.is_none(), "missing dir yields no last_indexed");
}
#[tokio::test]
async fn health_omits_embedder_info_when_bm25_only() {
let state = Arc::new(SearchAppState::new(IndexRegistry::new()));
let Json(resp) = health_handler(State(state)).await;
assert!(
resp.embedder_info.is_none(),
"BM25-only daemon must omit embedder_info"
);
}
/// With three buffered lines and `n = 2`, the tail handler must return the
/// two most recent lines in insertion order and report `total` as 3.
#[tokio::test]
async fn logs_tail_returns_recent_lines() {
    let log = trusty_common::log_buffer::LogBuffer::new(100);
    for text in ["line one", "line two", "line three"] {
        log.push(text.to_string());
    }
    let app = Arc::new(SearchAppState::new(IndexRegistry::new()).with_log_buffer(log));

    let reply = logs_tail_handler(State(app), Query(LogsTailParams { n: 2 }))
        .await
        .0;

    let tail = reply["lines"].as_array().expect("lines array");
    assert_eq!(tail.len(), 2, "n=2 must return two lines");

    // Oldest of the two comes first, newest last.
    let texts: Vec<Option<&str>> = tail.iter().map(|entry| entry.as_str()).collect();
    assert_eq!(texts, [Some("line two"), Some("line three")]);

    assert_eq!(reply["total"].as_u64(), Some(3), "total counts all buffered");
}
#[tokio::test]
async fn logs_tail_clamps_n() {
let buffer = trusty_common::log_buffer::LogBuffer::new(100);
for i in 0..5 {
buffer.push(format!("l{i}"));
}
let state = Arc::new(SearchAppState::new(IndexRegistry::new()).with_log_buffer(buffer));
let Json(zero) =
logs_tail_handler(State(Arc::clone(&state)), Query(LogsTailParams { n: 0 })).await;
assert_eq!(zero["lines"].as_array().expect("lines").len(), 1);
let Json(big) = logs_tail_handler(
State(state),
Query(LogsTailParams {
n: MAX_LOGS_TAIL_N * 10,
}),
)
.await;
assert_eq!(big["lines"].as_array().expect("lines").len(), 5);
}
/// The admin stop handler must answer with `ok: true` and the message
/// `"shutting down"`.
#[tokio::test]
async fn admin_stop_returns_ok() {
    let app = Arc::new(SearchAppState::new(IndexRegistry::new()));
    let reply = admin_stop_handler(State(app)).await.0;

    assert_eq!(reply["ok"].as_bool(), Some(true));
    assert_eq!(reply["message"].as_str(), Some("shutting down"));
}
/// Creating an index with a relative `root_path` must be rejected with
/// `400 Bad Request` and an `error` message mentioning "absolute".
#[tokio::test]
async fn create_index_rejects_relative_root_path() {
    use crate::core::registry::IndexRegistry;
    use axum::body::to_bytes;

    // A mock embedder is installed before the handler is called.
    let app = SearchAppState::new(IndexRegistry::new());
    app.install_embedder(
        Arc::new(crate::core::embed::MockEmbedder::new(8)) as Arc<dyn Embedder>
    )
    .await;

    let request = CreateIndexRequest {
        id: "rel-bad".into(),
        root_path: std::path::PathBuf::from("claude-mpm"),
        include_paths: None,
        exclude_globs: None,
        extensions: None,
        domain_terms: None,
        path_filter: None,
    };
    let response = create_index_handler(State(Arc::new(app)), Json(request)).await;
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);

    let raw = to_bytes(response.into_body(), 4096).await.expect("body");
    let payload: serde_json::Value = serde_json::from_slice(&raw).expect("json");
    // A missing or non-string "error" field falls back to the empty string.
    let message = payload["error"].as_str().unwrap_or("");
    assert!(message.contains("absolute"), "got: {message}");
}
/// Creating an index whose absolute `root_path` is not present on disk must
/// be rejected with `400 Bad Request` and an `error` message containing
/// "does not exist".
#[tokio::test]
async fn create_index_rejects_nonexistent_root_path() {
    use crate::core::registry::IndexRegistry;
    use axum::body::to_bytes;

    // A mock embedder is installed before the handler is called.
    let app = SearchAppState::new(IndexRegistry::new());
    app.install_embedder(
        Arc::new(crate::core::embed::MockEmbedder::new(8)) as Arc<dyn Embedder>
    )
    .await;

    let ghost_root =
        std::path::PathBuf::from("/this/path/should/never/exist/trusty-search-test-xyz");
    let request = CreateIndexRequest {
        id: "ghost".into(),
        root_path: ghost_root,
        include_paths: None,
        exclude_globs: None,
        extensions: None,
        domain_terms: None,
        path_filter: None,
    };
    let response = create_index_handler(State(Arc::new(app)), Json(request)).await;
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);

    let raw = to_bytes(response.into_body(), 4096).await.expect("body");
    let payload: serde_json::Value = serde_json::from_slice(&raw).expect("json");
    // A missing or non-string "error" field falls back to the empty string.
    let message = payload["error"].as_str().unwrap_or("");
    assert!(message.contains("does not exist"), "got: {message}");
}
/// Creating an index rooted at an existing temporary directory must succeed
/// with `200 OK`.
#[tokio::test]
async fn create_index_accepts_valid_absolute_root_path() {
    use crate::core::registry::IndexRegistry;

    // A mock embedder is installed before the handler is called.
    let app = SearchAppState::new(IndexRegistry::new());
    app.install_embedder(
        Arc::new(crate::core::embed::MockEmbedder::new(8)) as Arc<dyn Embedder>
    )
    .await;
    // The test keeps its own handle on the state until the end of the function.
    let shared = Arc::new(app);

    // The temporary directory lives until the end of the test.
    let scratch = tempfile::tempdir().expect("tempdir");
    let request = CreateIndexRequest {
        id: "valid-abs".into(),
        root_path: scratch.path().to_path_buf(),
        include_paths: None,
        exclude_globs: None,
        extensions: None,
        domain_terms: None,
        path_filter: None,
    };

    let response = create_index_handler(State(Arc::clone(&shared)), Json(request)).await;
    assert_eq!(response.status(), StatusCode::OK);
}
/// Plain relative paths, with or without a leading `./`, are accepted as
/// being inside the root.
#[test]
fn file_is_within_root_relative_ok() {
    let root = std::path::Path::new("/Users/me/proj");
    let inside = |candidate: &str| file_is_within_root(candidate, root);

    assert!(inside("src/auth.rs"));
    assert!(inside("./src/auth.rs"));
    assert!(inside("Cargo.toml"));
}
/// Relative paths containing `..` are rejected, whether the `..` is leading
/// or appears after a normal component.
#[test]
fn file_is_within_root_rejects_dotdot() {
    let root = std::path::Path::new("/Users/me/proj");
    let inside = |candidate: &str| file_is_within_root(candidate, root);

    assert!(!inside("../other/file.rs"));
    assert!(!inside("src/../../leak.rs"));
}
/// An absolute path is accepted when it lies under the root, and rejected
/// when it points into a different project directory or an unrelated
/// system location.
#[test]
fn file_is_within_root_absolute_must_start_with_root() {
    let root = std::path::Path::new("/Users/me/proj");
    let inside = |candidate: &str| file_is_within_root(candidate, root);

    assert!(inside("/Users/me/proj/src/auth.rs"));
    assert!(!inside("/Users/me/other-proj/src/auth.rs"));
    assert!(!inside("/etc/passwd"));
}
/// The empty string is never considered to be inside the root.
#[test]
fn file_is_within_root_rejects_empty() {
    let root = std::path::Path::new("/Users/me/proj");
    let empty_is_inside = file_is_within_root("", root);
    assert!(!empty_is_inside);
}
}