use crate::AletheiaDB;
use crate::core::error::Result;
use crate::core::id::{EdgeId, NodeId};
use crate::core::vector::ops::cosine_similarity;
use std::collections::HashMap;
/// A similarity constraint attached to an edge label.
///
/// Each variant carries a threshold that is compared against the cosine
/// similarity of the vectors stored on the two endpoints of a matching edge.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PolygraphRule {
    /// Endpoints must be at least this similar; a strictly lower similarity is a violation.
    RequiresHighSimilarity(f32),
    /// Endpoints must be at most this similar; a strictly higher similarity is a violation.
    RequiresLowSimilarity(f32),
}
/// One edge whose endpoint vectors break the rule registered for its label.
#[derive(Clone, Debug, PartialEq)]
pub struct Contradiction {
    /// The offending edge.
    pub edge_id: EdgeId,
    /// Node the edge leaves from.
    pub source: NodeId,
    /// Node the edge points to.
    pub target: NodeId,
    /// The edge's label, rendered with `to_string()`.
    pub label: String,
    /// Cosine similarity measured between the source and target vectors.
    pub actual_similarity: f32,
    /// The rule that the measured similarity failed.
    pub violated_rule: PolygraphRule,
}
/// Audits graph edges against per-label similarity rules, using vectors stored on nodes.
pub struct Polygraph<'a> {
    /// Rules keyed by edge label; at most one rule per label.
    rules: HashMap<String, PolygraphRule>,
    /// Database that is scanned during an investigation.
    db: &'a AletheiaDB,
}
impl<'a> Polygraph<'a> {
    /// Creates a polygraph bound to `db` with no rules registered.
    pub fn new(db: &'a AletheiaDB) -> Self {
        Self {
            db,
            rules: HashMap::new(),
        }
    }

    /// Registers `rule` for edges labelled `edge_label`.
    ///
    /// Registering another rule for the same label replaces the previous one.
    pub fn add_rule(&mut self, edge_label: &str, rule: PolygraphRule) {
        self.rules.insert(edge_label.to_string(), rule);
    }

    /// Reports every edge whose endpoints violate the rule registered for its label.
    ///
    /// Every row of a full scan is visited. For each node that has a vector under
    /// `vector_property`, each outgoing edge whose label has a rule is checked. The
    /// check computes the cosine similarity between the source vector and the target
    /// node's vector under the same property.
    ///
    /// The following are skipped silently (best-effort):
    /// - scan rows that are not nodes;
    /// - source or target nodes without a vector under `vector_property`;
    /// - edges whose label has no rule;
    /// - edges or target nodes that fail to load.
    ///
    /// When no rules are registered, an empty list is returned without scanning.
    ///
    /// # Errors
    ///
    /// Propagates errors from executing the scan, from reading any scan row, and from
    /// `cosine_similarity`. NOTE(review): the latter is presumably raised for
    /// incompatible vectors, e.g. differing dimensions — confirm. A single such error
    /// aborts the whole investigation.
    pub fn investigate(&self, vector_property: &str) -> Result<Vec<Contradiction>> {
        let mut contradictions = Vec::new();
        if self.rules.is_empty() {
            return Ok(contradictions);
        }

        let scan_results = self.db.query().scan(None).execute(self.db)?;
        for row_result in scan_results {
            let row = row_result?;
            let node = match row.entity.as_node() {
                Some(node) => node,
                None => continue,
            };
            let source_vec = match node
                .get_property(vector_property)
                .and_then(|p| p.as_vector())
                .map(|v| v.to_vec())
            {
                Some(v) => v,
                None => continue,
            };

            let outgoing_edges = self.db.get_outgoing_edges(node.id);
            for edge_id in outgoing_edges {
                // Dangling or unreadable edges are skipped rather than aborting the scan.
                let edge = match self.db.get_edge(edge_id) {
                    Ok(edge) => edge,
                    Err(_) => continue,
                };
                // Render the label once: it is used both for the rule lookup and in the report.
                let label = edge.label.to_string();
                let rule = match self.rules.get(label.as_str()) {
                    Some(&rule) => rule,
                    None => continue,
                };
                let target_node = match self.db.get_node(edge.target) {
                    Ok(target_node) => target_node,
                    Err(_) => continue,
                };
                let target_vec = match target_node
                    .properties
                    .get(vector_property)
                    .and_then(|p| p.as_vector())
                {
                    Some(v) => v,
                    None => continue,
                };

                let sim = cosine_similarity(&source_vec, target_vec)?;
                if Self::rule_violated(rule, sim) {
                    contradictions.push(Contradiction {
                        edge_id,
                        source: node.id,
                        target: edge.target,
                        label,
                        actual_similarity: sim,
                        violated_rule: rule,
                    });
                }
            }
        }
        Ok(contradictions)
    }

    /// Returns `true` when `similarity` falls on the wrong side of `rule`'s threshold.
    ///
    /// Both comparisons are strict, so a similarity exactly equal to the threshold
    /// satisfies the rule. A NaN similarity compares false against any threshold and
    /// is therefore never reported.
    fn rule_violated(rule: PolygraphRule, similarity: f32) -> bool {
        match rule {
            PolygraphRule::RequiresHighSimilarity(threshold) => similarity < threshold,
            PolygraphRule::RequiresLowSimilarity(threshold) => similarity > threshold,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;

    /// FRIEND edges between dissimilar nodes and ENEMY edges between similar nodes
    /// are reported; consistent edges are not.
    #[test]
    fn test_polygraph_detects_lies() {
        let db = AletheiaDB::new().unwrap();
        let a = db
            .create_node(
                "Person",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();
        let b = db
            .create_node(
                "Person",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.9, 0.1])
                    .build(),
            )
            .unwrap();
        let c = db
            .create_node(
                "Person",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[-1.0, 0.0])
                    .build(),
            )
            .unwrap();
        db.create_edge(a, b, "FRIEND", Default::default()).unwrap();
        let lie_edge = db.create_edge(a, c, "FRIEND", Default::default()).unwrap();
        let lie_edge_2 = db.create_edge(a, b, "ENEMY", Default::default()).unwrap();
        db.create_edge(a, c, "ENEMY", Default::default()).unwrap();

        let mut polygraph = Polygraph::new(&db);
        polygraph.add_rule("FRIEND", PolygraphRule::RequiresHighSimilarity(0.8));
        polygraph.add_rule("ENEMY", PolygraphRule::RequiresLowSimilarity(0.2));
        let contradictions = polygraph.investigate("vec").unwrap();
        assert_eq!(contradictions.len(), 2);

        let friend_lie = contradictions
            .iter()
            .find(|found| found.edge_id == lie_edge)
            .expect("FRIEND edge between opposite vectors must be reported");
        assert_eq!(friend_lie.label, "FRIEND");
        assert_eq!(friend_lie.source, a);
        assert_eq!(friend_lie.target, c);
        assert_eq!(
            friend_lie.violated_rule,
            PolygraphRule::RequiresHighSimilarity(0.8)
        );
        assert!(friend_lie.actual_similarity < 0.8);

        let enemy_lie = contradictions
            .iter()
            .find(|found| found.edge_id == lie_edge_2)
            .expect("ENEMY edge between near-identical vectors must be reported");
        assert_eq!(enemy_lie.label, "ENEMY");
        assert_eq!(enemy_lie.source, a);
        assert_eq!(enemy_lie.target, b);
        assert_eq!(
            enemy_lie.violated_rule,
            PolygraphRule::RequiresLowSimilarity(0.2)
        );
        assert!(enemy_lie.actual_similarity > 0.2);
    }

    /// With no rules registered, nothing is reported even when edges exist.
    #[test]
    fn test_polygraph_without_rules_reports_nothing() {
        let db = AletheiaDB::new().unwrap();
        let a = db
            .create_node(
                "Person",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();
        let b = db
            .create_node(
                "Person",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[-1.0, 0.0])
                    .build(),
            )
            .unwrap();
        db.create_edge(a, b, "FRIEND", Default::default()).unwrap();

        let polygraph = Polygraph::new(&db);
        assert!(polygraph.investigate("vec").unwrap().is_empty());
    }

    /// Edges touching a node without the vector property are skipped, not reported.
    #[test]
    fn test_polygraph_skips_nodes_without_vectors() {
        let db = AletheiaDB::new().unwrap();
        let a = db
            .create_node(
                "Person",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();
        let blank = db
            .create_node("Person", PropertyMapBuilder::new().build())
            .unwrap();
        db.create_edge(a, blank, "FRIEND", Default::default())
            .unwrap();
        db.create_edge(blank, a, "FRIEND", Default::default())
            .unwrap();

        let mut polygraph = Polygraph::new(&db);
        polygraph.add_rule("FRIEND", PolygraphRule::RequiresHighSimilarity(0.8));
        assert!(polygraph.investigate("vec").unwrap().is_empty());
        // A property name that no node carries yields no findings either.
        assert!(polygraph.investigate("missing").unwrap().is_empty());
    }
}