use crate::collection::types::Collection;
use crate::guardrails::{GuardRails, QueryLimits};
use crate::point::Point;
use crate::test_fixtures::fixtures::setup_collection;
use std::collections::HashMap;
use std::sync::Arc;
use tempfile::TempDir;
/// Builds the fixture shared by every test in this file.
///
/// Starts from `setup_collection(4)` and upserts ten points with ids `0..10`. Point `i` has the
/// vector `[i / 10, 0.1, 0.1, 0.1]` and a JSON payload containing:
///
/// * `idx` — the point id,
/// * `category` — `"even"` or `"odd"` depending on the parity of `i`,
/// * `alt_vec` — the array `[i / 10, 0.2, 0.2, 0.2]`.
///
/// The `TempDir` is handed back alongside the collection so the caller decides how long it lives
/// (presumably it backs the collection's on-disk storage — confirm in `setup_collection`).
///
/// # Panics
///
/// Panics if the upsert fails.
// The u64 -> float casts below only ever see values in 0..10, so no precision is lost.
#[allow(clippy::cast_precision_loss)]
fn make_collection() -> (TempDir, Collection) {
    let (tmp, collection) = setup_collection(4);

    let mut batch: Vec<Point> = Vec::with_capacity(10);
    for i in 0u64..10 {
        let parity = if i % 2 == 0 { "even" } else { "odd" };
        let payload = serde_json::json!({
            "idx": i,
            "category": parity,
            "alt_vec": [i as f64 / 10.0, 0.2, 0.2, 0.2]
        });
        let vector = vec![i as f32 / 10.0, 0.1, 0.1, 0.1];
        batch.push(Point::new(i, vector, Some(payload)));
    }

    collection.upsert(batch).expect("test: upsert");
    (tmp, collection)
}
/// A plain `SELECT * … LIMIT 5` must parse, execute, and return exactly five fixture points.
///
/// Every returned point is also required to carry an `idx` payload field inside the fixture's
/// `0..10` id range.
#[test]
fn test_execute_query_str_parses_and_executes() {
    let (_dir, col) = make_collection();
    let params = HashMap::new();

    let results = col
        .execute_query_str("SELECT * FROM col LIMIT 5;", &params)
        .expect("execute_query_str should succeed");

    assert_eq!(
        results.len(),
        5,
        "10-point collection with LIMIT 5 must return exactly 5"
    );
    for r in &results {
        let idx = r
            .point
            .payload
            .as_ref()
            .and_then(|p| p.get("idx"))
            .and_then(serde_json::Value::as_u64)
            .expect("each result must have an idx payload field");
        assert!(idx < 10, "idx {idx} out of fixture range 0..10");
    }
}
/// Running the same SQL text twice must reuse a single query-cache entry.
///
/// Checks three things: both calls return the same number of rows, the cache holds exactly one
/// entry afterwards, and the hit counter grew by at least one relative to the snapshot taken
/// before the first call.
#[test]
fn test_execute_query_str_caches_repeated_calls() {
    let (_dir, col) = make_collection();
    let params = HashMap::new();
    let sql = "SELECT * FROM col LIMIT 3;";

    // Snapshot the counters first so the assertion measures only this test's calls.
    let stats_before = col.query_cache.stats();

    let r1 = col
        .execute_query_str(sql, &params)
        .expect("first call failed");
    let r2 = col
        .execute_query_str(sql, &params)
        .expect("second call failed");

    assert_eq!(
        r1.len(),
        r2.len(),
        "repeated queries should return the same count"
    );
    assert_eq!(
        col.query_cache.len(),
        1,
        "identical SQL must yield a single cache entry"
    );

    let hits = col.query_cache.stats().hits - stats_before.hits;
    assert!(
        hits >= 1,
        "second identical call should register a cache hit, got {hits}"
    );
}
/// Text that is not valid SQL must surface as an `Err` rather than a panic or an empty result.
#[test]
fn test_execute_query_str_rejects_invalid_sql() {
    let (_dir, col) = make_collection();
    let params = HashMap::new();

    let result = col.execute_query_str("NOT VALID SQL !!!", &params);

    assert!(result.is_err(), "Invalid SQL should return an error");
}
/// A `WHERE category = 'even'` filter must never let a non-even point through.
///
/// Results without a payload are skipped rather than treated as failures.
#[test]
fn test_execute_query_str_metadata_filter() {
    let (_dir, col) = make_collection();
    let params = HashMap::new();

    let result = col
        .execute_query_str(
            "SELECT * FROM col WHERE category = 'even' LIMIT 10;",
            &params,
        )
        .expect("query failed");

    // NOTE(review): this loop passes vacuously when `result` is empty; consider also asserting
    // on the result count once the expected number of matches is confirmed.
    for r in &result {
        if let Some(payload) = &r.point.payload {
            assert_eq!(
                payload.get("category").and_then(serde_json::Value::as_str),
                Some("even"),
                "Non-even category in result: {payload:?}"
            );
        }
    }
}
/// With `max_cardinality` set to 3, a query asking for 10 rows must be rejected by the guard-rail.
///
/// The error text is matched loosely ("Guard-rail", "cardinality" or "Cardinality") so the test
/// does not depend on the exact wording of the message.
#[test]
fn test_e2e_guardrails_cardinality_respected() {
    let (_dir, mut col) = make_collection();
    let limits = QueryLimits {
        max_cardinality: 3,
        ..QueryLimits::default()
    };
    col.guard_rails = Arc::new(GuardRails::with_limits(limits));
    let params = HashMap::new();

    let err = col
        .execute_query_str("SELECT * FROM col LIMIT 10;", &params)
        .expect_err("Cardinality guardrail should fire")
        .to_string();

    assert!(
        err.contains("Guard-rail") || err.contains("cardinality") || err.contains("Cardinality"),
        "Unexpected error message: {err}"
    );
}
/// `timeout_ms: 0` is expected to switch the timeout guard-rail off, not to reject every query.
#[test]
fn test_e2e_guardrails_timeout_zero_disables_check() {
    let (_dir, mut col) = make_collection();
    col.guard_rails = Arc::new(GuardRails::with_limits(QueryLimits {
        timeout_ms: 0,
        ..QueryLimits::default()
    }));
    let params = HashMap::new();

    let result = col.execute_query_str("SELECT * FROM col LIMIT 5;", &params);

    assert!(
        result.is_ok(),
        "timeout_ms=0 should disable the guard-rail, not reject the query"
    );
}
/// After two queries run against a `max_cardinality: 1` limit with a failure threshold of 2, the
/// circuit breaker must report `Open`.
#[test]
fn test_e2e_guardrails_circuit_breaker_state() {
    let (_dir, mut col) = make_collection();
    col.guard_rails = Arc::new(GuardRails::with_limits(QueryLimits {
        max_cardinality: 1,
        circuit_failure_threshold: 2,
        circuit_recovery_seconds: 60,
        ..QueryLimits::default()
    }));
    let params = HashMap::new();
    let sql = "SELECT * FROM col LIMIT 10;";

    // The query results are deliberately discarded: only the breaker state is under test.
    let _ = col.execute_query_str(sql, &params);
    let _ = col.execute_query_str(sql, &params);

    let state = col.guard_rails.circuit_breaker.state();
    assert_eq!(
        state,
        crate::guardrails::CircuitState::Open,
        "Circuit breaker should open after 2 failures"
    );
}
/// `similarity(vector, $v) > 0.5` on the primary vector must succeed and respect the threshold.
///
/// The per-row check uses `>=`, which is slightly looser than the strict `>` in the query.
#[test]
fn test_e2e_similarity_primary_vector_field() {
    let (_dir, col) = make_collection();
    let mut params = HashMap::new();
    params.insert("v".to_string(), serde_json::json!([0.5, 0.1, 0.1, 0.1]));

    let result = col
        .execute_query_str(
            "SELECT * FROM col WHERE similarity(vector, $v) > 0.5 LIMIT 5;",
            &params,
        )
        .expect("primary vector similarity should succeed");

    for r in &result {
        assert!(r.score >= 0.5, "Score {} below threshold 0.5", r.score);
    }
}
/// `similarity(alt_vec, $v)` must score against the `alt_vec` payload array, not be rejected.
///
/// The query vector `[0.5, 0.2, 0.2, 0.2]` equals point 5's `alt_vec`, so the test expects:
///
/// * all ten points to pass the `> 0.0` threshold,
/// * point 5 to score within `1e-4` of `1.0`,
/// * point 0 to score strictly below point 5.
#[test]
fn test_e2e_similarity_named_payload_vector_field() {
    let (_dir, col) = make_collection();
    let mut params = HashMap::new();
    params.insert("v".to_string(), serde_json::json!([0.5, 0.2, 0.2, 0.2]));

    let results = col
        .execute_query_str(
            "SELECT * FROM col WHERE similarity(alt_vec, $v) > 0.0 LIMIT 10;",
            &params,
        )
        .expect("named-field similarity should succeed (multi-vector restriction removed)");

    assert_eq!(
        results.len(),
        10,
        "all 10 points should pass the > 0.0 threshold"
    );

    let p5 = results
        .iter()
        .find(|r| r.point.id == 5)
        .expect("point 5 must be in results");
    assert!(
        (p5.score - 1.0).abs() < 1e-4,
        "point 5 (alt_vec==query) should score ~1.0, got {}",
        p5.score
    );

    let p0 = results
        .iter()
        .find(|r| r.point.id == 0)
        .expect("point 0 must be in results");
    assert!(
        p0.score < p5.score,
        "point 0 must score below the exact match"
    );
}
/// Combining `vector NEAR $v` with a metadata filter must not panic.
///
/// Both outcomes are accepted: on success every payload-bearing row must have category `"even"`;
/// on failure the error text must mention "Guard-rail" or "Query".
#[test]
fn test_e2e_cbo_with_vector_and_filter_no_panic() {
    let (_dir, col) = make_collection();
    let mut params = HashMap::new();
    params.insert("v".to_string(), serde_json::json!([0.5, 0.1, 0.1, 0.1]));

    let result = col.execute_query_str(
        "SELECT * FROM col WHERE vector NEAR $v AND category = 'even' LIMIT 5;",
        &params,
    );

    match result {
        Ok(results) => {
            for r in &results {
                if let Some(payload) = &r.point.payload {
                    assert_eq!(
                        payload.get("category").and_then(serde_json::Value::as_str),
                        Some("even")
                    );
                }
            }
        }
        Err(e) => {
            let msg = e.to_string();
            assert!(
                msg.contains("Guard-rail") || msg.contains("Query"),
                "Unexpected CBO error: {msg}"
            );
        }
    }
}