use crate::storage::schema::Value;
use super::super::entity::{EntityData, EntityKind, UnifiedEntity};
use super::filters::{Filter, FilterOp, FilterValue};
/// Returns `true` when `entity` satisfies every filter in `filters` (logical AND).
///
/// Each filter's field is resolved with `get_entity_field` and checked with
/// `match_filter`. Evaluation stops at the first filter that does not match.
/// An empty `filters` slice matches every entity.
pub fn apply_filters(entity: &UnifiedEntity, filters: &[Filter]) -> bool {
    filters.iter().all(|filter| {
        let field_value = get_entity_field(entity, &filter.field);
        match_filter(&field_value, &filter.op, &filter.value)
    })
}
/// Looks up the value of `field` on `entity`, cloning it into an owned `Value`.
///
/// The system fields `id`, `created_at` and `updated_at` are always available.
/// They are returned as `Value::Integer` and take precedence over any
/// same-named payload field. Otherwise the lookup is delegated to the payload:
/// node/edge properties, row columns by name, or (for vectors) only the
/// `content` field. Time-series points and queue messages expose no payload
/// fields. Returns `None` when the field is absent.
pub fn get_entity_field(entity: &UnifiedEntity, field: &str) -> Option<Value> {
    // System columns shadow payload fields with the same name.
    let system_value = match field {
        "id" => Some(entity.id.raw() as i64),
        "created_at" => Some(entity.created_at as i64),
        "updated_at" => Some(entity.updated_at as i64),
        _ => None,
    };
    if let Some(number) = system_value {
        return Some(Value::Integer(number));
    }

    match &entity.data {
        EntityData::Node(node) => node.get(field).cloned(),
        EntityData::Edge(edge) => edge.get(field).cloned(),
        EntityData::Row(row) => row.get_by_name(field).cloned(),
        EntityData::Vector(vec) if field == "content" => {
            vec.content.as_ref().map(|c| Value::text(c.clone()))
        }
        EntityData::Vector(_) | EntityData::TimeSeries(_) | EntityData::QueueMessage(_) => None,
    }
}
/// Evaluates a single filter predicate against an optional field value.
///
/// * `IsNull` / `IsNotNull` test only whether a value is present; `filter_value`
///   is ignored for them.
/// * Every other operator requires a present value whose type matches the
///   operand: text with `String`, integer with `Int`, float with `Float`,
///   boolean with `Bool`. There is no cross-type coercion, so an `Integer`
///   field never matches a `Float` operand.
/// * Float equality uses an absolute tolerance of `0.0001`.
/// * `Contains` / `StartsWith` / `EndsWith` are case-sensitive substring tests.
/// * Any other combination evaluates to `false`: a missing value, a type
///   mismatch, or an unsupported operator/type pair.
pub fn match_filter(value: &Option<Value>, op: &FilterOp, filter_value: &FilterValue) -> bool {
    // Presence checks are decided before the value is inspected.
    match op {
        FilterOp::IsNull => return value.is_none(),
        FilterOp::IsNotNull => return value.is_some(),
        _ => {}
    }

    // Every remaining operator needs an actual value to compare against.
    let actual = match value {
        Some(v) => v,
        None => return false,
    };

    match (actual, op, filter_value) {
        // Text operands.
        (Value::Text(s), FilterOp::Equals, FilterValue::String(fs)) => &**s == fs.as_str(),
        (Value::Text(s), FilterOp::Contains, FilterValue::String(fs)) => {
            s.contains(fs.as_str())
        }
        (Value::Text(s), FilterOp::StartsWith, FilterValue::String(fs)) => {
            s.starts_with(fs.as_str())
        }
        (Value::Text(s), FilterOp::EndsWith, FilterValue::String(fs)) => {
            s.ends_with(fs.as_str())
        }

        // Integer operands.
        (Value::Integer(i), FilterOp::Equals, FilterValue::Int(fi)) => *i == *fi,
        (Value::Integer(i), FilterOp::GreaterThan, FilterValue::Int(fi)) => *i > *fi,
        (Value::Integer(i), FilterOp::LessThan, FilterValue::Int(fi)) => *i < *fi,

        // Float operands; equality is tolerance-based.
        (Value::Float(f), FilterOp::Equals, FilterValue::Float(ff)) => (*f - *ff).abs() < 0.0001,
        (Value::Float(f), FilterOp::GreaterThan, FilterValue::Float(ff)) => *f > *ff,
        (Value::Float(f), FilterOp::LessThan, FilterValue::Float(ff)) => *f < *ff,

        // Boolean operands.
        (Value::Boolean(b), FilterOp::Equals, FilterValue::Bool(fb)) => *b == *fb,

        // Type mismatch or unsupported operator for this value type.
        _ => false,
    }
}
/// Best (maximum) cosine similarity between `query` and any vector on `entity`.
///
/// The candidates are:
/// * each embedding yielded by `entity.embeddings()`; when `slot` is `Some`,
///   only embeddings whose `name` equals it are considered;
/// * for `EntityData::Vector` entities, the dense vector itself. The `slot`
///   filter does not apply to this vector.
///
/// The running maximum starts at `0.0`, so negative similarities never win
/// and an entity with no candidates scores `0.0`.
pub fn calculate_entity_similarity(
    entity: &UnifiedEntity,
    query: &[f32],
    slot: &Option<String>,
) -> f32 {
    let wanted_slot = slot.as_ref();

    let embedding_best = entity
        .embeddings()
        .into_iter()
        .filter(|emb| wanted_slot.map_or(true, |wanted| &emb.name == wanted))
        .map(|emb| cosine_similarity(query, &emb.vector))
        .fold(0.0f32, f32::max);

    match &entity.data {
        EntityData::Vector(vec_data) => {
            embedding_best.max(cosine_similarity(query, &vec_data.dense))
        }
        _ => embedding_best,
    }
}
/// Cosine similarity between two equal-length vectors, in `[-1.0, 1.0]`.
///
/// Returns `0.0` in these cases:
/// * the slices differ in length;
/// * the slices are empty;
/// * either vector has zero magnitude;
/// * the magnitudes overflow to a non-finite value;
/// * the input contains NaN.
///
/// The denominator is computed as `sqrt(|a|^2) * sqrt(|b|^2)`, taking each
/// norm's root separately, rather than `sqrt(|a|^2 * |b|^2)`. Multiplying the
/// squared norms first overflows `f32` to infinity for moderately large
/// vectors (components around 1e10) and underflows to zero for tiny ones
/// (around 1e-13). Either way, perfectly parallel vectors would silently
/// score `0.0`. The final ratio is clamped so rounding cannot push it
/// slightly outside `[-1.0, 1.0]`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    // `denom > 0.0` is false for zero and NaN; `is_finite` rejects overflow.
    if denom > 0.0 && denom.is_finite() {
        (dot / denom).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
pub fn extract_searchable_text(entity: &UnifiedEntity) -> String {
let mut parts = Vec::new();
match &entity.kind {
EntityKind::GraphNode(ref node) => {
parts.push(node.label.clone());
parts.push(node.node_type.clone());
}
EntityKind::GraphEdge(ref edge) => {
parts.push(edge.label.clone());
}
EntityKind::TableRow { table, .. } => {
parts.push(table.to_string());
}
EntityKind::Vector { collection } => {
parts.push(collection.clone());
}
EntityKind::TimeSeriesPoint(ref ts) => {
parts.push(ts.series.clone());
parts.push(ts.metric.clone());
}
EntityKind::QueueMessage { queue, .. } => {
parts.push(queue.clone());
}
}
match &entity.data {
EntityData::Node(node) => {
for (k, v) in &node.properties {
parts.push(k.clone());
if let Value::Text(s) = v {
parts.push(s.to_string());
}
}
}
EntityData::Edge(edge) => {
for (k, v) in &edge.properties {
parts.push(k.clone());
if let Value::Text(s) = v {
parts.push(s.to_string());
}
}
}
EntityData::Row(row) => {
if !row.columns.is_empty() {
for col in &row.columns {
if let Value::Text(s) = col {
parts.push(s.to_string());
}
}
} else if let Some(named) = &row.named {
for (k, v) in named {
parts.push(k.clone());
if let Value::Text(s) = v {
parts.push(s.to_string());
}
}
}
}
EntityData::Vector(vec) => {
if let Some(ref content) = vec.content {
parts.push(content.clone());
}
}
EntityData::TimeSeries(ts) => {
parts.push(ts.metric.clone());
}
EntityData::QueueMessage(_) => {}
}
parts.join(" ")
}