use axum::{
extract::{Path, Query, State},
http::StatusCode,
Json,
};
use serde::Deserialize;
use std::sync::Arc;
use crate::core::indexer::{
typeahead::{TypeaheadHit, TypeaheadMode, TypeaheadResponse},
SearchQuery, SearchStage,
};
use crate::core::registry::IndexId;
use super::state::SearchAppState;
/// Upper bound on the number of hits a single typeahead request may return.
/// Larger requested limits are clamped down to this value.
pub(super) const MAX_TYPEAHEAD_LIMIT: usize = 25;
/// Number of hits returned when the request does not specify `limit`.
const DEFAULT_TYPEAHEAD_LIMIT: usize = 6;

// The handler calls `usize::clamp(1, MAX_TYPEAHEAD_LIMIT)`, which panics at
// runtime if the bounds are inverted, and a default outside `1..=MAX` would be
// silently clamped. Reject either misconfiguration at compile time instead.
const _: () = assert!(MAX_TYPEAHEAD_LIMIT >= 1);
const _: () = assert!(DEFAULT_TYPEAHEAD_LIMIT >= 1 && DEFAULT_TYPEAHEAD_LIMIT <= MAX_TYPEAHEAD_LIMIT);
/// Query-string parameters accepted by the typeahead endpoint.
///
/// Every field is optional on the wire: a missing `q` becomes an empty string,
/// and a missing `limit` or `mode` becomes `None`.
#[derive(Deserialize, Debug)]
pub struct TypeaheadParams {
    /// Raw query text as sent by the client (not yet trimmed).
    #[serde(default)]
    pub q: String,
    /// Requested maximum number of hits, if the client supplied one.
    pub limit: Option<usize>,
    /// Requested search mode, if the client supplied one.
    pub mode: Option<TypeaheadMode>,
}
pub async fn typeahead_handler(
State(state): State<Arc<SearchAppState>>,
Path(id): Path<String>,
Query(params): Query<TypeaheadParams>,
) -> Result<Json<TypeaheadResponse>, (StatusCode, Json<serde_json::Value>)> {
let q = params.q.trim();
if q.is_empty() {
return Ok(Json(TypeaheadResponse {
hits: vec![],
mode: "lexical".to_string(),
latency_ms: 0,
}));
}
let limit = params
.limit
.unwrap_or(DEFAULT_TYPEAHEAD_LIMIT)
.clamp(1, MAX_TYPEAHEAD_LIMIT);
let mode = params.mode.unwrap_or_default();
let index_id = IndexId::new(id.clone());
let handle = match state.registry.get(&index_id) {
Some(h) => h,
None => {
return Err((
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": format!("unknown index: {id}") })),
))
}
};
let (stage, expand_graph) = match mode {
TypeaheadMode::Lexical => (SearchStage::Lexical, false),
TypeaheadMode::Blended => (SearchStage::Semantic, true),
};
let query = SearchQuery {
text: q.to_owned(),
top_k: limit,
expand_graph,
compact: true,
stage: Some(stage),
..SearchQuery::default()
};
let indexer = handle.indexer.read().await;
let started = std::time::Instant::now();
let chunks = indexer.search(&query).await.map_err(|e| {
tracing::warn!(index_id = %id, err = %e, "typeahead search error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "internal search error" })),
)
})?;
drop(indexer);
let latency_ms = started.elapsed().as_millis() as u64;
let mode_str = match mode {
TypeaheadMode::Lexical => "lexical",
TypeaheadMode::Blended => "blended",
};
let hits: Vec<TypeaheadHit> = chunks
.iter()
.take(limit)
.map(|chunk| {
let source = TypeaheadHit::classify_source(&chunk.match_reason);
TypeaheadHit::from_chunk(chunk, source)
})
.collect();
tracing::debug!(
index_id = %id,
q = %q,
mode = %mode_str,
hits = hits.len(),
latency_ms = latency_ms,
"typeahead"
);
Ok(Json(TypeaheadResponse {
hits,
mode: mode_str.to_owned(),
latency_ms,
}))
}