use std::{
collections::{HashMap, HashSet},
fs::File,
hash::Hash,
io::{Result as IoResult, Write as IoWrite},
};
use crate::set::{Collection, Poset};
/// One element of a [`Lattice`] together with its cached order relations.
///
/// Both sets are kept transitively closed by [`Lattice::add_relation`], so
/// they hold *every* strictly greater / strictly smaller element, not only
/// the immediate neighbours.
#[derive(Debug, Clone)]
pub struct LatticeNode<T> {
    /// The element this node describes (equal to its key in the owning map).
    element: T,
    /// All elements recorded as strictly greater than `element`.
    successors: HashSet<T>,
    /// All elements recorded as strictly smaller than `element`.
    predecessors: HashSet<T>,
}

/// A finite partially ordered set stored as a transitively closed relation.
///
/// Elements are added with [`Lattice::add_element`] and ordered with
/// [`Lattice::add_relation`]. Antisymmetry is not enforced: recording both
/// `a <= b` and `b <= a` for distinct elements leaves the structure in a
/// state that is not a partial order.
#[derive(Debug, Default, Clone)]
pub struct Lattice<T> {
    /// Every known element mapped to its node.
    nodes: HashMap<T, LatticeNode<T>>,
}

impl<T: Hash + Eq + Clone> Lattice<T> {
    /// Creates a lattice with no elements.
    pub fn new() -> Self {
        Self { nodes: HashMap::new() }
    }

    /// Inserts `element` with no relations; does nothing if it is already present.
    pub fn add_element(&mut self, element: T) {
        // Entry API: a single lookup instead of `contains_key` followed by `insert`.
        self.nodes.entry(element.clone()).or_insert_with(|| LatticeNode {
            element,
            successors: HashSet::new(),
            predecessors: HashSet::new(),
        });
    }

    /// Records `a <= b`, inserting either element if it is missing, and
    /// restores the transitive closure of the stored relation.
    ///
    /// `a <= a` holds for every element already, so a reflexive call only
    /// makes sure the element exists. Storing the pair as an edge would put
    /// the element into its own successor and predecessor sets, which stops
    /// it from ever being reported as minimal/maximal and makes it look like
    /// an intermediate step to each of its own successors.
    pub fn add_relation(&mut self, a: T, b: T) {
        if a == b {
            self.add_element(a);
            return;
        }
        self.add_element(a.clone());
        self.add_element(b.clone());
        if let Some(node_a) = self.nodes.get_mut(&a) {
            node_a.successors.insert(b.clone());
        }
        if let Some(node_b) = self.nodes.get_mut(&b) {
            node_b.predecessors.insert(a);
        }
        self.compute_transitive_closure();
    }

    /// Repeatedly adds `x <= z` whenever `x <= y` and `y <= z` are stored,
    /// until a pass finds nothing left to add.
    fn compute_transitive_closure(&mut self) {
        loop {
            // Collect first, apply afterwards: the map cannot be mutated
            // while it is being iterated.
            let mut missing: Vec<(T, T)> = Vec::new();
            for node in self.nodes.values() {
                for succ in &node.successors {
                    let Some(succ_node) = self.nodes.get(succ) else {
                        continue;
                    };
                    for far in &succ_node.successors {
                        // Only clone pairs that are not recorded yet.
                        if !node.successors.contains(far) {
                            missing.push((node.element.clone(), far.clone()));
                        }
                    }
                }
            }
            if missing.is_empty() {
                return;
            }
            for (lower, upper) in missing {
                if let Some(lower_node) = self.nodes.get_mut(&lower) {
                    lower_node.successors.insert(upper.clone());
                }
                if let Some(upper_node) = self.nodes.get_mut(&upper) {
                    upper_node.predecessors.insert(lower);
                }
            }
        }
    }
}
/// Treats the lattice as a plain collection of its elements, ignoring the order.
impl<T: Hash + Eq + Clone> Collection for Lattice<T> {
/// Items of the collection are the lattice elements themselves.
type Item = T;
/// Returns `true` when `point` is a key of the node map, i.e. it has been added to the lattice.
fn contains(&self, point: &Self::Item) -> bool { self.nodes.contains_key(point) }
/// Returns `true` when the node map holds no elements.
fn is_empty(&self) -> bool { self.nodes.is_empty() }
}
impl<T: Hash + Eq + Clone> Poset for Lattice<T> {
    /// Compares two elements.
    ///
    /// Returns `None` when either element is not in the lattice, otherwise
    /// `Some(true)` exactly when `a == b` or `b` is a stored successor of `a`.
    fn leq(&self, a: &T, b: &T) -> Option<bool> {
        match (self.nodes.get(a), self.nodes.contains_key(b)) {
            (Some(lower), true) => Some(a == b || lower.successors.contains(b)),
            _ => None,
        }
    }

    /// Elements whose predecessor set is empty.
    fn minimal_elements(&self) -> HashSet<T> {
        let mut bottoms = HashSet::new();
        for (key, node) in &self.nodes {
            if node.predecessors.is_empty() {
                bottoms.insert(key.clone());
            }
        }
        bottoms
    }

    /// Elements whose successor set is empty.
    fn maximal_elements(&self) -> HashSet<T> {
        let mut tops = HashSet::new();
        for (key, node) in &self.nodes {
            if node.successors.is_empty() {
                tops.insert(key.clone());
            }
        }
        tops
    }

    /// Least upper bound of `a` and `b`.
    ///
    /// Returns `None` when either element is unknown, when the two share no
    /// upper bound, or when the shared upper bounds have more than one
    /// minimal element.
    fn join(&self, a: T, b: T) -> Option<T> {
        let (Some(node_a), Some(node_b)) = (self.nodes.get(&a), self.nodes.get(&b)) else {
            return None;
        };

        // Upper bounds of `a` (itself included) that are also upper bounds of `b`.
        let mut above_a: HashSet<&T> = node_a.successors.iter().collect();
        above_a.insert(&a);
        let shared: HashSet<&T> = above_a
            .into_iter()
            .filter(|&candidate| *candidate == b || node_b.successors.contains(candidate))
            .collect();

        // Keep the bounds that have no other shared bound below them.
        let mut minimal = shared.iter().copied().filter(|&lower| {
            shared
                .iter()
                .all(|&other| lower == other || !self.leq(other, lower).unwrap_or(false))
        });

        match (minimal.next(), minimal.next()) {
            (Some(least), None) => Some(least.clone()),
            _ => None,
        }
    }

    /// Greatest lower bound of `a` and `b`.
    ///
    /// Returns `None` when either element is unknown, when the two share no
    /// lower bound, or when the shared lower bounds have more than one
    /// maximal element.
    fn meet(&self, a: T, b: T) -> Option<T> {
        let (Some(node_a), Some(node_b)) = (self.nodes.get(&a), self.nodes.get(&b)) else {
            return None;
        };

        // Lower bounds of `a` (itself included) that are also lower bounds of `b`.
        let mut below_a: HashSet<&T> = node_a.predecessors.iter().collect();
        below_a.insert(&a);
        let shared: HashSet<&T> = below_a
            .into_iter()
            .filter(|&candidate| *candidate == b || node_b.predecessors.contains(candidate))
            .collect();

        // Keep the bounds that have no other shared bound above them.
        let mut maximal = shared.iter().copied().filter(|&upper| {
            shared
                .iter()
                .all(|&other| upper == other || !self.leq(upper, other).unwrap_or(false))
        });

        match (maximal.next(), maximal.next()) {
            (Some(greatest), None) => Some(greatest.clone()),
            _ => None,
        }
    }

    /// Every element `x` with `x <= a`; empty when `a` is not in the lattice.
    fn downset(&self, a: T) -> HashSet<T> {
        let mut below = HashSet::new();
        for (key, node) in &self.nodes {
            if self.leq(&node.element, &a) == Some(true) {
                below.insert(key.clone());
            }
        }
        below
    }

    /// Every element `x` with `a <= x`; empty when `a` is not in the lattice.
    fn upset(&self, a: T) -> HashSet<T> {
        let mut above = HashSet::new();
        for (key, node) in &self.nodes {
            if self.leq(&a, &node.element) == Some(true) {
                above.insert(key.clone());
            }
        }
        above
    }

    /// Immediate successors (covers) of `a`: stored successors that are not
    /// also a successor of some other stored successor of `a`.
    ///
    /// Returns an empty set when `a` is not in the lattice.
    fn successors(&self, a: T) -> HashSet<T> {
        let Some(origin) = self.nodes.get(&a) else {
            return HashSet::new();
        };
        let reachable = &origin.successors;

        let mut covers = HashSet::new();
        for candidate in reachable {
            let has_intermediate = reachable.iter().any(|mid| {
                mid != candidate
                    && self
                        .nodes
                        .get(mid)
                        .is_some_and(|mid_node| mid_node.successors.contains(candidate))
            });
            if !has_intermediate {
                covers.insert(candidate.clone());
            }
        }
        covers
    }

    /// Immediate predecessors of `a`: stored predecessors that are not also
    /// a predecessor of some other stored predecessor of `a`.
    ///
    /// Returns an empty set when `a` is not in the lattice.
    fn predecessors(&self, a: T) -> HashSet<T> {
        let Some(origin) = self.nodes.get(&a) else {
            return HashSet::new();
        };
        let reachable = &origin.predecessors;

        let mut covered = HashSet::new();
        for candidate in reachable {
            let has_intermediate = reachable.iter().any(|mid| {
                mid != candidate
                    && self
                        .nodes
                        .get(mid)
                        .is_some_and(|mid_node| mid_node.predecessors.contains(candidate))
            });
            if !has_intermediate {
                covered.insert(candidate.clone());
            }
        }
        covered
    }
}
/// Escapes `label` so it can be placed between double quotes in a DOT file.
///
/// Both `"` and `\` get a backslash prefix. Escaping only the quote is not
/// enough: a label ending in `\` would turn the closing quote into an escaped
/// one, and sequences such as `\n` would be reinterpreted by Graphviz instead
/// of being shown literally.
fn escape_dot_label(label: &str) -> String {
    let mut escaped = String::with_capacity(label.len());
    for ch in label.chars() {
        if matches!(ch, '"' | '\\') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
impl<T: Hash + Eq + Clone + std::fmt::Display + Ord> Lattice<T> {
    /// Writes the lattice to `filename` as a Graphviz `digraph`.
    ///
    /// Nodes are emitted in ascending order, followed by one edge per
    /// immediate successor (an edge `a -> c` is skipped when another stored
    /// successor `b` of `a` has `c` among its own successors). An empty
    /// lattice produces a graph that only carries the label `Empty Lattice`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating, writing or flushing the file.
    pub fn save_to_dot_file(&self, filename: &str) -> IoResult<()> {
        // Buffer the many small writes; flushed explicitly below so that
        // write errors are reported instead of being lost on drop.
        let mut out = std::io::BufWriter::new(File::create(filename)?);

        if self.nodes.is_empty() {
            writeln!(out, "digraph Lattice {{\n label=\"Empty Lattice\";\n}}")?;
            return out.flush();
        }

        writeln!(out, "digraph Lattice {{")?;
        writeln!(out, " rankdir=\"BT\";")?;
        writeln!(out, " node [shape=plaintext];")?;

        // Sort once so the output is deterministic despite hash-map ordering.
        let mut sorted_nodes: Vec<_> = self.nodes.iter().collect();
        sorted_nodes.sort_unstable_by(|left, right| left.0.cmp(right.0));

        for (key, _) in &sorted_nodes {
            writeln!(out, " \"{}\";", escape_dot_label(&key.to_string()))?;
        }
        writeln!(out)?;

        for (source_key, node) in &sorted_nodes {
            let source_label = escape_dot_label(&source_key.to_string());
            let mut sorted_successors: Vec<&T> = node.successors.iter().collect();
            sorted_successors.sort_unstable();

            for succ_key in sorted_successors {
                // Existence check only, so the successors need no ordering here.
                let has_intermediate = node.successors.iter().any(|mid| {
                    mid != succ_key
                        && self
                            .nodes
                            .get(mid)
                            .is_some_and(|mid_node| mid_node.successors.contains(succ_key))
                });
                if !has_intermediate {
                    writeln!(
                        out,
                        " \"{}\" -> \"{}\";",
                        source_label,
                        escape_dot_label(&succ_key.to_string())
                    )?;
                }
            }
        }

        writeln!(out, "}}")?;
        out.flush()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// "M"-shaped poset: 1 < 5, 2 < 5, 2 < 6, 3 < 6.
    fn m_lattice() -> Lattice<i32> {
        let mut poset: Lattice<i32> = Lattice::new();
        for element in [1, 2, 3, 5, 6] {
            poset.add_element(element);
        }
        for (lower, upper) in [(1, 5), (2, 5), (2, 6), (3, 6)] {
            poset.add_relation(lower, upper);
        }
        poset
    }

    /// Diamond: 1 at the bottom, 2 and 3 in the middle, 4 on top.
    fn diamond_lattice() -> Lattice<i32> {
        let mut poset: Lattice<i32> = Lattice::new();
        for element in [1, 2, 3, 4] {
            poset.add_element(element);
        }
        for (lower, upper) in [(1, 2), (1, 3), (2, 4), (3, 4)] {
            poset.add_relation(lower, upper);
        }
        poset
    }

    #[test]
    fn test_basic_lattice() {
        let mut chain = Lattice::new();
        chain.add_relation(1, 2);
        chain.add_relation(2, 3);

        assert_eq!(chain.leq(&1, &2), Some(true));
        assert_eq!(chain.leq(&2, &3), Some(true));
        assert_eq!(chain.leq(&1, &3), Some(true));
        assert_ne!(chain.leq(&2, &1), Some(true));

        assert_eq!(chain.minimal_elements(), HashSet::from([1]));
        assert_eq!(chain.maximal_elements(), HashSet::from([3]));
    }

    #[test]
    fn test_diamond_lattice() {
        let diamond = diamond_lattice();

        assert_eq!(diamond.leq(&1, &4), Some(true));
        assert_ne!(diamond.leq(&2, &3), Some(true));
        assert_ne!(diamond.leq(&3, &2), Some(true));

        assert_eq!(diamond.minimal_elements(), HashSet::from([1]));
        assert_eq!(diamond.maximal_elements(), HashSet::from([4]));
    }

    #[test]
    fn test_lattice_operations_diamond() {
        let mut lattice = Lattice::new();
        for (lower, upper) in [(1, 2), (1, 3), (2, 4), (3, 4)] {
            lattice.add_relation(lower, upper);
        }

        let joined = lattice.join(2, 3);
        println!("join(2, 3): {:?}", joined);
        assert_eq!(joined, Some(4));

        let joined = lattice.join(1, 4);
        println!("join(1, 4): {:?}", joined);
        assert_eq!(joined, Some(4));

        let joined = lattice.join(1, 2);
        println!("join(1, 2): {:?}", joined);
        assert_eq!(joined, Some(2));

        let joined = lattice.join(1, 1);
        println!("join(1, 1): {:?}", joined);
        assert_eq!(joined, Some(1));

        let met = lattice.meet(2, 3);
        println!("meet(2, 3): {:?}", met);
        assert_eq!(met, Some(1));

        let met = lattice.meet(1, 4);
        println!("meet(1, 4): {:?}", met);
        assert_eq!(met, Some(1));

        let met = lattice.meet(2, 4);
        println!("meet(2, 4): {:?}", met);
        assert_eq!(met, Some(2));

        let met = lattice.meet(4, 4);
        println!("meet(4, 4): {:?}", met);
        assert_eq!(met, Some(4));
    }

    #[test]
    fn test_lattice_operations_non_lattice_examples() {
        let m_lattice = m_lattice();

        let joined = m_lattice.join(1, 2);
        println!("join(1, 2) for M-shape: {:?}", joined);
        assert_eq!(joined, Some(5));

        let joined = m_lattice.join(2, 3);
        println!("join(2, 3) for M-shape: {:?}", joined);
        assert_eq!(joined, Some(6));

        let joined = m_lattice.join(1, 3);
        println!("join(1, 3) for M-shape: {:?}", joined);
        assert_eq!(joined, None);

        // Two incomparable upper bounds: neither a join nor a meet exists.
        let mut non_join_lattice: Lattice<&str> = Lattice::new();
        for (lower, upper) in [("a", "c"), ("a", "d"), ("b", "c"), ("b", "d")] {
            non_join_lattice.add_relation(lower, upper);
        }
        assert_eq!(non_join_lattice.join("a", "b"), None);
        assert_eq!(non_join_lattice.meet("c", "d"), None);

        // The same shape turned upside down.
        let mut non_meet_lattice: Lattice<&str> = Lattice::new();
        for (lower, upper) in [("c", "a"), ("d", "a"), ("c", "b"), ("d", "b")] {
            non_meet_lattice.add_relation(lower, upper);
        }
        assert_eq!(non_meet_lattice.meet("a", "b"), None);
        assert_eq!(non_meet_lattice.join("c", "d"), None);
    }

    #[test]
    fn test_graphviz_output() {
        let scratch = tempfile::tempdir().unwrap();
        let dot_path = scratch.path().join("test_m_shape_lattice.dot");
        let filename = dot_path.to_str().unwrap();
        println!("--- M-shape Example - Saving to {filename} ---");

        let poset = m_lattice();
        poset.save_to_dot_file(filename).expect("Failed to save M-shape lattice");

        assert!(dot_path.exists());
        drop(scratch);
    }

    #[test]
    #[ignore = "Manual test to see output of M-shape lattice"]
    fn test_graphviz_output_manual() {
        let filename = "test_m_shape_lattice.dot";
        println!("--- M-shape Example - Saving to {filename} ---");

        let poset = m_lattice();
        poset.save_to_dot_file(filename).expect("Failed to save M-shape lattice");
    }

    #[test]
    fn test_downset() {
        let diamond = diamond_lattice();
        assert_eq!(diamond.downset(4), HashSet::from([1, 2, 3, 4]));
        assert_eq!(diamond.downset(2), HashSet::from([1, 2]));
        assert_eq!(diamond.downset(1), HashSet::from([1]));
    }

    #[test]
    fn test_upset() {
        let diamond = diamond_lattice();
        assert_eq!(diamond.upset(4), HashSet::from([4]));
        assert_eq!(diamond.upset(2), HashSet::from([2, 4]));
        assert_eq!(diamond.upset(1), HashSet::from([1, 2, 3, 4]));
    }

    #[test]
    fn test_successors() {
        let diamond = diamond_lattice();
        assert_eq!(diamond.successors(1), HashSet::from([2, 3]));
    }

    #[test]
    fn test_predecessors() {
        let diamond = diamond_lattice();
        assert_eq!(diamond.predecessors(4), HashSet::from([2, 3]));
    }
}