use std::time::Duration;
use anyhow::{Result, anyhow};
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use super::{AsyncVectorIndex, SearchHit};
/// [`AsyncVectorIndex`] backed by a single Elasticsearch index with a
/// `dense_vector` field.
///
/// Every vector is stored as one document whose `_id` is the decimal form of
/// its external `u64` id; the id is also copied into the numeric `ext_id`
/// field so filtered searches can restrict candidates with a `terms` query.
///
/// Do not derive `Debug`: `base_url` is stored verbatim and may contain
/// `user:password@` credentials.
pub struct ElasticsearchVectorIndex {
    /// Shared HTTP client; [`Self::new`] gives it a 5 s connect / 30 s request timeout.
    client: reqwest::Client,
    /// Cluster URL with any trailing `/` removed (may embed credentials).
    base_url: String,
    /// Index name, checked by `validate_index_name` in [`Self::new`].
    index: String,
    /// Dimensionality of the `vector` field.
    dim: usize,
}
impl ElasticsearchVectorIndex {
    /// Connects to the cluster at `url` and makes sure `index` exists.
    ///
    /// A missing index is created with a `dim`-dimensional, cosine-similarity
    /// `dense_vector` field named `vector` and an `unsigned_long` `ext_id`
    /// field. An existing index is used as-is.
    ///
    /// # Errors
    /// Fails if `index` is not a valid index name, `dim` is zero, the HTTP
    /// client cannot be built, or the existence check / creation request
    /// fails. Transport error text is passed through [`redact_credentials`].
    pub async fn new(url: &str, index: &str, dim: usize) -> Result<Self> {
        validate_index_name(index)?;
        // Elasticsearch rejects a zero-dimensional `dense_vector`; fail early
        // with a clear message instead of an opaque mapping error.
        if dim == 0 {
            return Err(anyhow!("elasticsearch vector dimension must be at least 1"));
        }
        let client = reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(|e| anyhow!(redact_credentials(&e.to_string())))?;
        let idx = Self {
            client,
            base_url: url.trim_end_matches('/').to_string(),
            index: index.to_string(),
            dim,
        };
        idx.ensure_index().await?;
        Ok(idx)
    }

    /// Deletes the whole index. A missing index (HTTP 404) is not an error.
    pub async fn delete_index(&self) -> Result<()> {
        let resp = self.delete(&format!("/{}", self.index)).await?;
        let status = resp.status();
        if status.is_success() || status.as_u16() == 404 {
            return Ok(());
        }
        Err(self.status_err(resp).await)
    }

    /// Creates the index with the vector mapping unless it already exists.
    ///
    /// The mapping of an existing index (including its `dims`) is not
    /// compared against `self.dim`.
    async fn ensure_index(&self) -> Result<()> {
        let head =
            Self::send_redacted(self.client.head(self.url(&format!("/{}", self.index)))).await?;
        match head.status().as_u16() {
            200 => return Ok(()),
            404 => {}
            _ => return Err(self.status_err(head).await),
        }
        // `ext_id` mirrors the `u64` document id. `unsigned_long` covers the
        // whole u64 range; a signed `long` would reject ids above `i64::MAX`.
        let body = serde_json::json!({
            "mappings": {
                "properties": {
                    "vector": {
                        "type": "dense_vector",
                        "dims": self.dim,
                        "index": true,
                        "similarity": "cosine"
                    },
                    "ext_id": { "type": "unsigned_long" }
                }
            }
        });
        let resp = self.put(&format!("/{}", self.index), body).await?;
        let status = resp.status();
        if status.is_success() {
            return Ok(());
        }
        let text = resp.text().await.unwrap_or_default();
        // Another client may have created the index between our HEAD and PUT;
        // that is the outcome we wanted, so treat it as success.
        if status.as_u16() == 400 && text.contains("resource_already_exists_exception") {
            return Ok(());
        }
        Err(self.format_status_err(status, &text))
    }

    /// Maps a document `_id` back to the external `u64` id; `None` if it is
    /// not a plain decimal `u64`.
    fn parse_id(id: &str) -> Option<u64> {
        id.parse().ok()
    }

    /// Absolute URL for `path`, which is expected to start with `/`.
    fn url(&self, path: &str) -> String {
        format!("{}{}", self.base_url, path)
    }

    /// Sends `req`, scrubbing URL credentials out of transport errors (their
    /// text may include the request URL).
    async fn send_redacted(req: reqwest::RequestBuilder) -> Result<reqwest::Response> {
        req.send()
            .await
            .map_err(|e| anyhow!(redact_credentials(&e.to_string())))
    }

    /// `PUT` a JSON body to `path`.
    async fn put(&self, path: &str, body: serde_json::Value) -> Result<reqwest::Response> {
        Self::send_redacted(self.client.put(self.url(path)).json(&body)).await
    }

    /// `DELETE` `path`.
    async fn delete(&self, path: &str) -> Result<reqwest::Response> {
        Self::send_redacted(self.client.delete(self.url(path))).await
    }

    /// `POST` a newline-delimited JSON body (as used by `_bulk`) to `path`.
    async fn ndjson(&self, path: &str, body: String) -> Result<reqwest::Response> {
        Self::send_redacted(
            self.client
                .post(self.url(path))
                .header(CONTENT_TYPE, "application/x-ndjson")
                .body(body),
        )
        .await
    }

    /// Consumes a non-success response and turns it into an error carrying
    /// the status and a redacted, truncated copy of the body.
    async fn status_err(&self, resp: reqwest::Response) -> anyhow::Error {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        self.format_status_err(status, &body)
    }

    /// Formats the error for `status`/`body`. It never includes the base URL,
    /// and the body is credential-redacted and length-capped.
    fn format_status_err(&self, status: reqwest::StatusCode, body: &str) -> anyhow::Error {
        let body = truncate_error_body(&redact_credentials(body));
        anyhow!(
            "elasticsearch returned status {status} for index `{}` [url redacted]: {}",
            self.index,
            body
        )
    }
}
#[async_trait::async_trait]
impl AsyncVectorIndex for ElasticsearchVectorIndex {
async fn add(&self, vectors: &[Vec<f32>], ids: &[u64]) -> Result<()> {
if vectors.len() != ids.len() {
return Err(anyhow!(
"vectors.len() ({}) must equal ids.len() ({})",
vectors.len(),
ids.len()
));
}
if vectors.is_empty() {
return Ok(());
}
let mut body = String::new();
for (v, &id) in vectors.iter().zip(ids.iter()) {
body.push_str(
&serde_json::to_string(&serde_json::json!({
"index": { "_index": &self.index, "_id": id.to_string() }
}))
.map_err(|e| anyhow!("bulk encode: {e}"))?,
);
body.push('\n');
body.push_str(
&serde_json::to_string(&serde_json::json!({
"ext_id": id,
"vector": v
}))
.map_err(|e| anyhow!("bulk encode: {e}"))?,
);
body.push('\n');
}
let resp = self.ndjson("/_bulk?refresh=wait_for", body).await?;
if !resp.status().is_success() {
return Err(self.status_err(resp).await);
}
let parsed: BulkResponse = decode(resp).await?;
if parsed.errors {
return Err(anyhow!(
"elasticsearch bulk upsert reported per-item errors [url redacted]: {}",
first_failing_bulk_item(&parsed.items)
));
}
Ok(())
}
async fn remove(&self, ids: &[u64]) -> Result<()> {
if ids.is_empty() {
return Ok(());
}
let mut body = String::new();
for &id in ids {
body.push_str(
&serde_json::to_string(&serde_json::json!({
"delete": { "_index": &self.index, "_id": id.to_string() }
}))
.map_err(|e| anyhow!("bulk encode: {e}"))?,
);
body.push('\n');
}
let resp = self.ndjson("/_bulk?refresh=wait_for", body).await?;
if !resp.status().is_success() {
return Err(self.status_err(resp).await);
}
let parsed: BulkResponse = decode(resp).await?;
if parsed.errors {
return Err(anyhow!(
"elasticsearch bulk delete reported per-item errors [url redacted]: {}",
first_failing_bulk_item(&parsed.items)
));
}
Ok(())
}
async fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchHit>> {
let num_candidates = knn_num_candidates(k);
let body = serde_json::json!({
"knn": {
"field": "vector",
"query_vector": query,
"k": k,
"num_candidates": num_candidates
},
"_source": false,
"size": k
});
let resp = self
.client
.post(format!("{}/{}/_search", &self.base_url, &self.index))
.json(&body)
.send()
.await
.map_err(|e| anyhow!(redact_credentials(&e.to_string())))?;
if !resp.status().is_success() {
return Err(self.status_err(resp).await);
}
let parsed: SearchResponse = decode(resp).await?;
Ok(parsed
.hits
.hits
.into_iter()
.filter_map(|h| {
Self::parse_id(&h._id).map(|id| SearchHit {
id,
score: h._score,
})
})
.collect())
}
async fn search_filtered(
&self,
query: &[f32],
k: usize,
allowlist: &[u64],
) -> Result<Vec<SearchHit>> {
if allowlist.is_empty() {
return Ok(vec![]);
}
let num_candidates = knn_num_candidates(k);
let allowlist: Vec<u64> = allowlist.to_vec();
let body = serde_json::json!({
"knn": {
"field": "vector",
"query_vector": query,
"k": k,
"num_candidates": num_candidates,
"filter": [{ "terms": { "ext_id": allowlist } }]
},
"_source": false,
"size": k
});
let resp = self
.client
.post(format!("{}/{}/_search", &self.base_url, &self.index))
.json(&body)
.send()
.await
.map_err(|e| anyhow!(redact_credentials(&e.to_string())))?;
if !resp.status().is_success() {
return Err(self.status_err(resp).await);
}
let parsed: SearchResponse = decode(resp).await?;
Ok(parsed
.hits
.hits
.into_iter()
.filter_map(|h| {
Self::parse_id(&h._id).map(|id| SearchHit {
id,
score: h._score,
})
})
.collect())
}
async fn len(&self) -> Result<usize> {
let resp = self
.client
.post(format!("{}/{}/_count", &self.base_url, &self.index))
.json(&serde_json::json!({}))
.send()
.await
.map_err(|e| anyhow!(redact_credentials(&e.to_string())))?;
if !resp.status().is_success() {
return Err(self.status_err(resp).await);
}
let parsed: CountResponse = decode(resp).await?;
Ok(parsed.count as usize)
}
fn dim(&self) -> usize {
self.dim
}
}
/// Replaces the userinfo (`user:password@`) of every `scheme://` URL in `s`
/// with `<redacted>@`, leaving all other text untouched.
///
/// The authority is taken to run from just after `://` up to the first `/`,
/// `?` or `#` (or the end of the string). Everything up to the *last* `@` in
/// that span is treated as userinfo, so a password that itself contains `@`
/// is hidden completely.
pub fn redact_credentials(s: &str) -> String {
    const SCHEME_SEP: &str = "://";
    let mut redacted = String::with_capacity(s.len());
    let mut remaining = s;
    while let Some((head, tail)) = remaining.split_once(SCHEME_SEP) {
        redacted.push_str(head);
        redacted.push_str(SCHEME_SEP);
        let authority_len = tail.find(['/', '?', '#']).unwrap_or(tail.len());
        let (authority, rest) = tail.split_at(authority_len);
        match authority.rsplit_once('@') {
            Some((_userinfo, host)) => {
                redacted.push_str("<redacted>@");
                redacted.push_str(host);
            }
            None => redacted.push_str(authority),
        }
        remaining = rest;
    }
    redacted.push_str(remaining);
    redacted
}
/// Upper bound on the `num_candidates` requested per kNN search.
const MAX_KNN_CANDIDATES: usize = 1_000;
/// Candidate-pool size for a kNN search returning `k` hits.
///
/// Oversamples by 10x (treating `k == 0` as 1) and caps the result at
/// [`MAX_KNN_CANDIDATES`]. It never returns less than `k`, since
/// Elasticsearch requires `num_candidates >= k`.
fn knn_num_candidates(k: usize) -> usize {
    let oversampled = std::cmp::max(k, 1).saturating_mul(10);
    let capped = std::cmp::min(oversampled, MAX_KNN_CANDIDATES);
    std::cmp::max(capped, k)
}
/// Maximum number of characters of a server error body kept in error messages.
const ERROR_BODY_MAX_CHARS: usize = 1024;
/// Limits `s` to [`ERROR_BODY_MAX_CHARS`] characters (not bytes), appending
/// `... [truncated]` when anything was cut.
///
/// The cut always lands on a `char` boundary, so multi-byte UTF-8 text
/// cannot cause a slicing panic.
fn truncate_error_body(s: &str) -> String {
    match s.char_indices().nth(ERROR_BODY_MAX_CHARS) {
        // At most ERROR_BODY_MAX_CHARS characters: nothing to cut.
        None => s.to_owned(),
        Some((cut, _)) => format!("{}... [truncated]", &s[..cut]),
    }
}
/// Checks `index` before it is interpolated into request paths.
///
/// A name is accepted only if it:
/// - is non-empty and at most 255 bytes long;
/// - is not `.` or `..`;
/// - does not start with `_`, `-` or `+`;
/// - consists only of `a-z`, `0-9`, `_`, `-` and `.`.
///
/// All of those bytes are URL-unreserved, so an accepted name never needs
/// escaping. NOTE(review): this allowlist is presumably stricter than what
/// the server itself permits.
///
/// # Errors
/// Returns a message naming the first violated rule, in the order listed
/// above.
fn validate_index_name(index: &str) -> Result<()> {
    let first = match index.as_bytes().first() {
        Some(&byte) => byte,
        None => return Err(anyhow!("elasticsearch index name must not be empty")),
    };
    if matches!(index, "." | "..") {
        return Err(anyhow!(
            "elasticsearch index name must not be `.` or `..` (reserved): `{}`",
            index
        ));
    }
    let byte_len = index.len();
    if byte_len > 255 {
        return Err(anyhow!(
            "elasticsearch index name exceeds 255 bytes ({} bytes)",
            byte_len
        ));
    }
    if matches!(first, b'_' | b'-' | b'+') {
        return Err(anyhow!(
            "elasticsearch index name must not start with `_`, `-`, or `+`: `{}`",
            index
        ));
    }
    let allowed = |c: u8| {
        c.is_ascii_lowercase() || c.is_ascii_digit() || matches!(c, b'_' | b'-' | b'.')
    };
    match index.bytes().find(|&c| !allowed(c)) {
        None => Ok(()),
        Some(bad) => Err(anyhow!(
            "elasticsearch index name contains an illegal byte 0x{bad:02x} (`{}`): \
             only lowercase a-z, 0-9, `_`, `-`, `.` are allowed",
            index
        )),
    }
}
/// Renders the first failed item of a `_bulk` response for an error message,
/// with URL credentials redacted.
///
/// Each item looks like `{ "<op>": { "status": ..., "error": {...}, ... } }`.
/// Items carrying an `error` object are preferred. Only if none has one is
/// the first item with `status >= 400` used.
///
/// The preference matters because a `delete` of a missing document reports
/// status 404 without being a failure. Matching on status alone would
/// surface that harmless item instead of the real culprit.
fn first_failing_bulk_item(items: &[serde_json::Value]) -> String {
    /// The per-operation detail object inside a bulk item, if any.
    fn detail(item: &serde_json::Value) -> Option<&serde_json::Value> {
        item.as_object().and_then(|o| o.values().next())
    }
    let with_error = items
        .iter()
        .find(|item| detail(item).map_or(false, |d| d.get("error").is_some()));
    let with_bad_status = || {
        items.iter().find(|item| {
            detail(item)
                .and_then(|d| d.get("status"))
                .and_then(serde_json::Value::as_i64)
                .map_or(false, |status| status >= 400)
        })
    };
    match with_error.or_else(with_bad_status) {
        Some(item) => redact_credentials(&item.to_string()),
        None => "(no failing item found)".into(),
    }
}
/// Deserializes a JSON response body into `T`.
///
/// Any decode error is passed through `redact_credentials` before being
/// returned, since its text may include the request URL.
async fn decode<T>(resp: reqwest::Response) -> Result<T>
where
    T: serde::de::DeserializeOwned,
{
    match resp.json::<T>().await {
        Ok(value) => Ok(value),
        Err(err) => Err(anyhow!(redact_credentials(&err.to_string()))),
    }
}
/// Subset of a `_search` response body that this module reads.
#[derive(Deserialize)]
struct SearchResponse {
    hits: SearchHits,
}
/// The outer `hits` object of a `_search` response.
#[derive(Deserialize)]
struct SearchHits {
    hits: Vec<SearchInnerHit>,
}
/// One search hit. Field names mirror the JSON keys `_id` / `_score`
/// (no serde renames), hence the leading underscores.
#[derive(Deserialize)]
struct SearchInnerHit {
    /// Document id; for documents written by `add` this is the decimal external id.
    _id: String,
    /// Hit score; with no `#[serde(default)]`, deserialization fails if it is absent or null.
    _score: f32,
}
/// Body of a `_count` response.
#[derive(Deserialize)]
struct CountResponse {
    count: u64,
}
/// Subset of a `_bulk` response.
#[derive(Deserialize)]
struct BulkResponse {
    /// Set by the server when at least one item in the batch failed.
    errors: bool,
    /// Per-item results, each shaped `{ "<op>": { ... } }`. They are kept as raw
    /// JSON because they are only rendered into error messages; missing means empty.
    #[serde(default)]
    items: Vec<serde_json::Value>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::AsyncVectorIndex;
    const DIM: usize = 4;
    /// Index name for the live test.
    ///
    /// Combines the pid with a nanosecond timestamp so that two cases do not
    /// collide: concurrent runs on different hosts sharing one cluster, and a
    /// reused pid after a crashed run left its index behind. The latter would
    /// otherwise fail with "not empty at start".
    fn unique_index() -> String {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        format!("llm_kernel_test_{}_{nanos}", std::process::id())
    }
    /// Builds an index value without any network I/O (no `ensure_index`).
    fn offline_index(base_url: &str, dim: usize) -> ElasticsearchVectorIndex {
        ElasticsearchVectorIndex {
            client: reqwest::Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            index: "llm_kernel_test_offline".to_string(),
            dim,
        }
    }
    #[test]
    fn parse_id_accepts_numeric_and_drops_rest() {
        assert_eq!(ElasticsearchVectorIndex::parse_id("42"), Some(42));
        assert_eq!(ElasticsearchVectorIndex::parse_id("0"), Some(0));
        assert_eq!(
            ElasticsearchVectorIndex::parse_id("18446744073709551615"),
            Some(u64::MAX)
        );
        assert_eq!(ElasticsearchVectorIndex::parse_id("abc"), None);
        assert_eq!(ElasticsearchVectorIndex::parse_id(""), None);
        assert_eq!(ElasticsearchVectorIndex::parse_id("1.5"), None);
    }
    #[test]
    fn redact_credentials_strips_userinfo() {
        let cases = [
            ("http://u:pw@host:9200", "http://<redacted>@host:9200"),
            (
                "https://elastic:secret@es.local/x",
                "https://<redacted>@es.local/x",
            ),
            ("http://localhost:9200", "http://localhost:9200"),
            ("http://user@host", "http://<redacted>@host"),
            ("https://u:p@ss@host:9200", "https://<redacted>@host:9200"),
            ("no url here", "no url here"),
            (
                "index 中文 — http://u:pw@h:9200",
                "index 中文 — http://<redacted>@h:9200",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(redact_credentials(input), expected, "input = {input:?}");
        }
        assert!(!redact_credentials("https://u:secret@host").contains("secret"));
        let leaked = redact_credentials("https://u:p@ss@host:9200");
        assert!(!leaked.contains("p@ss"), "password tail leaked: {leaked}");
        assert!(!leaked.contains("ss@"), "password tail leaked: {leaked}");
    }
    #[test]
    fn validate_index_name_accepts_and_rejects() {
        for ok in ["docs", "docs_v2", "my-index", "idx.2026", "a", "a.b-c_d"] {
            assert!(
                validate_index_name(ok).is_ok(),
                "{ok:?} should be a valid index name"
            );
        }
        for name in [
            "",
            "Docs",
            "with space",
            "comma,idx",
            "_underscore",
            "-dash",
            "+plus",
            "bad/slash",
            "한글",
            ".",
            "..",
        ] {
            assert!(
                validate_index_name(name).is_err(),
                "{name:?} should be rejected"
            );
        }
        assert!(validate_index_name(&"a".repeat(255)).is_ok());
        assert!(validate_index_name(&"a".repeat(256)).is_err());
    }
    #[test]
    fn unique_index_is_a_valid_index_name() {
        let name = unique_index();
        assert!(
            validate_index_name(&name).is_ok(),
            "{name:?} must be a valid index name"
        );
        assert!(name.starts_with("llm_kernel_test_"));
    }
    #[test]
    fn knn_num_candidates_scales_caps_and_floors() {
        assert_eq!(knn_num_candidates(1), 10);
        assert_eq!(knn_num_candidates(5), 50);
        assert_eq!(knn_num_candidates(50), 500);
        assert_eq!(knn_num_candidates(100), MAX_KNN_CANDIDATES);
        assert_eq!(knn_num_candidates(200), MAX_KNN_CANDIDATES);
        assert!(knn_num_candidates(200) >= 200);
        assert_eq!(knn_num_candidates(0), 10);
    }
    #[test]
    fn truncate_error_body_leaves_short_body_unchanged() {
        assert_eq!(truncate_error_body(""), "");
        assert_eq!(truncate_error_body("short error"), "short error");
        let at_cap: String = "a".repeat(ERROR_BODY_MAX_CHARS);
        let out = truncate_error_body(&at_cap);
        assert_eq!(out.chars().count(), ERROR_BODY_MAX_CHARS);
        assert!(!out.contains("[truncated]"));
    }
    #[test]
    fn truncate_error_body_caps_huge_body_with_marker() {
        let huge: String = "a".repeat(ERROR_BODY_MAX_CHARS + 500);
        let out = truncate_error_body(&huge);
        assert!(out.ends_with("... [truncated]"));
        let kept = out.strip_suffix("... [truncated]").unwrap();
        assert_eq!(kept.chars().count(), ERROR_BODY_MAX_CHARS);
        let cjk: String = "중".repeat(ERROR_BODY_MAX_CHARS + 10);
        let out_cjk = truncate_error_body(&cjk);
        assert!(out_cjk.contains("[truncated]"));
    }
    #[test]
    fn truncate_error_body_keeps_credentials_redacted() {
        let with_cred = "error: see https://u:super-secret@host/idx for details";
        let out = truncate_error_body(&redact_credentials(with_cred));
        assert!(!out.contains("super-secret"), "credential leaked: {out}");
        assert!(out.contains("<redacted>"));
        let padding: String = "x".repeat(ERROR_BODY_MAX_CHARS + 50);
        let long_cred = format!("{padding} then https://u:p@ss@host:9200");
        let redacted = redact_credentials(&long_cred);
        let out2 = truncate_error_body(&redacted);
        assert!(
            !out2.contains("p@ss") && !out2.contains("super-secret"),
            "credential tail leaked: {out2}"
        );
    }
    #[test]
    fn first_failing_bulk_item_picks_failing_and_redacts() {
        let items = vec![
            serde_json::json!({ "index": { "_id": "1", "status": 200 } }),
            serde_json::json!({
                "index": { "_id": "2", "status": 400, "error": { "type": "mapper", "reason": "bad" } }
            }),
        ];
        let s = first_failing_bulk_item(&items);
        assert!(
            s.contains("\"_id\":\"2\""),
            "should name the failing item: {s}"
        );
        assert!(s.contains("400"));
        let err_only = vec![serde_json::json!({
            "delete": { "_id": "9", "error": { "type": "x", "reason": "y" } }
        })];
        assert!(first_failing_bulk_item(&err_only).contains("\"_id\":\"9\""));
        let with_url = vec![serde_json::json!({
            "index": { "_id": "3", "status": 500, "error": { "reason": "see https://u:secret@host" } }
        })];
        let leaked = first_failing_bulk_item(&with_url);
        assert!(!leaked.contains("secret"), "credential leaked: {leaked}");
        assert!(leaked.contains("<redacted>"));
        let none = vec![serde_json::json!({ "index": { "_id": "1", "status": 200 } })];
        assert_eq!(first_failing_bulk_item(&none), "(no failing item found)");
    }
    #[test]
    fn credentialed_url_error_redacts_password() {
        let credentialed = "https://elastic:super-secret-pw@es.internal:9200/idx";
        let raw = format!("error sending request for url ({credentialed}): connection refused");
        let redacted = redact_credentials(&raw);
        assert!(
            !redacted.contains("super-secret-pw"),
            "password leaked in redacted error: {redacted}"
        );
        assert!(redacted.contains("<redacted>"));
    }
    #[tokio::test]
    async fn empty_allowlist_returns_empty_without_network() {
        let idx = offline_index("http://0.0.0.0:1", DIM);
        let res = idx.search_filtered(&[1.0, 0.0, 0.0, 0.0], 5, &[]).await;
        assert!(res.is_ok(), "empty allowlist must not error: {res:?}");
        assert!(res.unwrap().is_empty());
    }
    /// End-to-end checks against a live cluster: add, search, filtered
    /// search, re-add (upsert), remove.
    async fn run_live_conformance(idx: &ElasticsearchVectorIndex) -> Result<()> {
        if idx.dim() != DIM {
            return Err(anyhow!("dim mismatch"));
        }
        if !idx.is_empty().await? {
            return Err(anyhow!("not empty at start"));
        }
        idx.add(
            &[vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]],
            &[1, 2],
        )
        .await?;
        if idx.len().await? != 2 {
            return Err(anyhow!("len != 2 after add"));
        }
        let hits = idx.search(&[1.0, 0.0, 0.0, 0.0], 1).await?;
        if hits.len() != 1 || hits[0].id != 1 {
            return Err(anyhow!("nearest neighbor != id 1"));
        }
        let filtered = idx.search_filtered(&[1.0, 0.0, 0.0, 0.0], 2, &[2]).await?;
        if filtered.len() != 1 || filtered[0].id != 2 {
            return Err(anyhow!("filtered search != id 2"));
        }
        idx.add(&[vec![0.9, 0.1, 0.0, 0.0]], &[1]).await?;
        if idx.len().await? != 2 {
            return Err(anyhow!("len != 2 after re-add"));
        }
        idx.remove(&[1]).await?;
        if idx.len().await? != 1 {
            return Err(anyhow!("len != 1 after remove"));
        }
        let after = idx.search(&[1.0, 0.0, 0.0, 0.0], 5).await?;
        if after.iter().any(|h| h.id == 1) {
            return Err(anyhow!("id 1 still present after remove"));
        }
        Ok(())
    }
    #[tokio::test]
    async fn live_elastic_conformance() {
        let url = match std::env::var("LLMKERNEL_ELASTIC_URL") {
            Ok(u) => u,
            Err(_) => {
                eprintln!("skipped: LLMKERNEL_ELASTIC_URL unset (no live Elasticsearch)");
                return;
            }
        };
        let index = unique_index();
        let idx = match ElasticsearchVectorIndex::new(&url, &index, DIM).await {
            Ok(i) => i,
            Err(e) => panic!("connect + create index: {e:?}"),
        };
        let result = run_live_conformance(&idx).await;
        // Always clean up, even when conformance failed, before reporting.
        let _ = idx.delete_index().await;
        result.expect("elasticsearch conformance failed");
    }
}