use rez_next_package::{Package, PackageRequirement};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
/// A conflict recorded against a single package during the search.
///
/// The severity is stored as the raw IEEE-754 bit pattern of an `f64` in
/// `severity_bits` (written by `DependencyConflict::new`, read back by
/// `DependencyConflict::severity`). Storing bits rather than an `f64` is what
/// allows this type to derive `Eq`; as a consequence, equality of the severity
/// is bitwise (e.g. severities `0.0` and `-0.0` compare unequal), and serde
/// serializes the field as an integer rather than a float.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct DependencyConflict {
    /// Name of the package this conflict is about.
    pub package_name: String,
    /// Free-form descriptions of the requirements involved in the conflict
    /// (the tests use strings like `"pkg-1.0"`; format not enforced here).
    pub conflicting_requirements: Vec<String>,
    /// `f64::to_bits` of the severity; prefer `severity()` to read it.
    pub severity_bits: u64,
    /// Category of the conflict.
    pub conflict_type: ConflictType,
}
impl DependencyConflict {
    /// Builds a conflict record.
    ///
    /// `severity` is stored losslessly via `f64::to_bits`, so the exact value
    /// (including the sign of zero and any NaN payload) is what `severity()`
    /// later returns.
    pub fn new(
        package_name: String,
        conflicting_requirements: Vec<String>,
        severity: f64,
        conflict_type: ConflictType,
    ) -> Self {
        let severity_bits = f64::to_bits(severity);
        Self {
            conflict_type,
            severity_bits,
            conflicting_requirements,
            package_name,
        }
    }

    /// Decodes the stored bit pattern back into the `f64` severity.
    pub fn severity(&self) -> f64 {
        let raw = self.severity_bits;
        f64::from_bits(raw)
    }
}
/// Category of a `DependencyConflict`.
///
/// `SearchState::is_valid` treats `MissingPackage` and `CircularDependency`
/// as fatal to a state; `VersionConflict` and `PlatformConflict` do not by
/// themselves invalidate it. The variant also participates in the state hash.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ConflictType {
    /// Requirements on the same package cannot be satisfied together.
    VersionConflict,
    /// A dependency cycle was detected.
    CircularDependency,
    /// A required package could not be found.
    MissingPackage,
    /// Presumably a platform/OS/arch mismatch — TODO confirm with the resolver.
    PlatformConflict,
}
/// One node of the dependency-resolution search.
///
/// Identity (`PartialEq`, `Eq`, `Hash`, `Ord`) is defined by `state_hash`,
/// a content hash over the resolved packages, pending requirements and
/// conflicts. It deliberately ignores cost, depth and ids, so two paths
/// reaching the same content are treated as the same state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchState {
    /// Packages chosen so far, keyed by package name.
    pub resolved_packages: HashMap<String, Package>,
    /// Requirements still to be satisfied; the first one is processed next.
    pub pending_requirements: Vec<PackageRequirement>,
    /// Conflicts accumulated along the path to this state.
    pub conflicts: Vec<DependencyConflict>,
    /// Accumulated path cost (sum of `additional_cost` values from the root).
    pub cost_so_far: f64,
    /// Priority used by `OrdByEstimatedCost`; set by the caller (constructors leave it at 0.0).
    pub estimated_total_cost: f64,
    /// Number of expansions from the root state.
    pub depth: usize,
    /// `state_id` of the state this one was expanded from, if any.
    pub parent_id: Option<u64>,
    /// Content hash captured when the state was constructed.
    pub state_id: u64,
    // Current content hash; kept in sync by the mutating methods.
    state_hash: u64,
}
impl SearchState {
pub fn new_initial(requirements: Vec<PackageRequirement>) -> Self {
let mut state = Self {
resolved_packages: HashMap::new(),
pending_requirements: requirements,
conflicts: Vec::new(),
cost_so_far: 0.0,
estimated_total_cost: 0.0,
depth: 0,
parent_id: None,
state_id: 0,
state_hash: 0,
};
state.update_hash();
state.state_id = state.state_hash;
state
}
pub fn new_from_parent(
parent: &SearchState,
resolved_package: Package,
new_requirements: Vec<PackageRequirement>,
additional_cost: f64,
) -> Self {
let mut resolved_packages = parent.resolved_packages.clone();
resolved_packages.insert(resolved_package.name.clone(), resolved_package);
let mut pending_requirements = parent.pending_requirements.clone();
pending_requirements.extend(new_requirements);
let mut state = Self {
resolved_packages,
pending_requirements,
conflicts: parent.conflicts.clone(),
cost_so_far: parent.cost_so_far + additional_cost,
estimated_total_cost: 0.0,
depth: parent.depth + 1,
parent_id: Some(parent.state_id),
state_id: 0,
state_hash: 0,
};
state.update_hash();
state.state_id = state.state_hash;
state
}
pub fn is_goal(&self) -> bool {
self.pending_requirements.is_empty() && self.conflicts.is_empty()
}
pub fn is_valid(&self) -> bool {
for conflict in &self.conflicts {
match conflict.conflict_type {
ConflictType::MissingPackage => return false,
ConflictType::CircularDependency => return false,
_ => {}
}
}
true
}
pub fn get_next_requirement(&self) -> Option<&PackageRequirement> {
self.pending_requirements.first()
}
pub fn add_conflict(&mut self, conflict: DependencyConflict) {
self.conflicts.push(conflict);
self.update_hash();
}
pub fn remove_requirement(&mut self, requirement: &PackageRequirement) {
self.pending_requirements
.retain(|req| req.name != requirement.name);
self.update_hash();
}
fn update_hash(&mut self) {
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
let mut package_names: Vec<_> = self.resolved_packages.keys().collect();
package_names.sort();
for name in package_names {
name.hash(&mut hasher);
if let Some(pkg) = self.resolved_packages.get(name) {
if let Some(ref ver) = pkg.version {
ver.as_str().hash(&mut hasher);
}
}
}
let mut req_strings: Vec<String> = self
.pending_requirements
.iter()
.map(|req| req.to_string())
.collect();
req_strings.sort();
for req_str in &req_strings {
req_str.hash(&mut hasher);
}
for conflict in &self.conflicts {
conflict.package_name.hash(&mut hasher);
conflict.conflict_type.hash(&mut hasher);
}
self.state_hash = hasher.finish();
}
pub fn get_hash(&self) -> u64 {
self.state_hash
}
pub fn calculate_complexity(&self) -> usize {
self.resolved_packages.len() + self.pending_requirements.len() + self.conflicts.len() * 2
}
}
/// States are equal exactly when their content hashes are equal.
impl PartialEq for SearchState {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.state_hash, other.state_hash);
        lhs == rhs
    }
}

impl Eq for SearchState {}

/// Feeds only the precomputed content hash, keeping `Hash` consistent with `Eq`.
impl Hash for SearchState {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write_u64(self.state_hash);
    }
}

impl PartialOrd for SearchState {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Arbitrary but total order by content hash; it carries no cost meaning
/// (see `OrdByEstimatedCost` for priority ordering).
impl Ord for SearchState {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.state_hash, &other.state_hash)
    }
}
/// Newtype that orders `SearchState`s by `estimated_total_cost` for use in a
/// priority queue.
///
/// The ordering is reversed so that a max-heap such as `BinaryHeap` is
/// intended to yield the cheapest state first, with ties broken by
/// `state_hash`. See the `Ord` impl for exactly how costs are compared.
#[derive(Debug)]
pub struct OrdByEstimatedCost(pub SearchState);
/// Equal when both the content hash and the exact bit pattern of the
/// estimated cost match (so `0.0 != -0.0`, while identical NaNs are equal).
impl PartialEq for OrdByEstimatedCost {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.state_hash == rhs.state_hash
            && lhs.estimated_total_cost.to_bits() == rhs.estimated_total_cost.to_bits()
    }
}

impl Eq for OrdByEstimatedCost {}
impl PartialOrd for OrdByEstimatedCost {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Reverse order on `estimated_total_cost`, so a max-heap (`BinaryHeap`) pops
/// the lowest-cost state first. Ties are broken by `state_hash` to keep the
/// order total and deterministic.
///
/// Costs are compared with `f64::total_cmp`. Comparing raw `to_bits()` values
/// as `u64` only matches numeric order for non-negative numbers: a negative
/// cost (or `-0.0`) has its sign bit set and would sort above every positive
/// cost. `total_cmp` returns `Equal` exactly when the bit patterns are equal,
/// so this ordering stays consistent with the bitwise `PartialEq`.
impl Ord for OrdByEstimatedCost {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .0
            .estimated_total_cost
            .total_cmp(&self.0.estimated_total_cost)
            .then_with(|| self.0.state_hash.cmp(&other.0.state_hash))
    }
}
/// Free list of reusable `SearchState` values, holding at most `max_size`.
pub struct StatePool {
    // Recycled states, reset by `return_state`; popped by `get_state`.
    pool: Vec<SearchState>,
    // Cap on retained states; returns beyond this are dropped.
    max_size: usize,
}
impl StatePool {
    /// Creates an empty pool that retains at most `max_size` recycled states.
    /// Storage for `max_size` entries is reserved up front.
    pub fn new(max_size: usize) -> Self {
        Self {
            pool: Vec::with_capacity(max_size),
            max_size,
        }
    }

    /// Takes a recycled state if one is available, otherwise builds a fresh
    /// `SearchState::new_initial(Vec::new())`.
    ///
    /// Either way the returned state is empty: no requirements, resolved
    /// packages or conflicts, and a `state_hash`/`state_id` matching that empty
    /// content. So pooled and fresh states compare equal.
    pub fn get_state(&mut self) -> SearchState {
        self.pool
            .pop()
            .unwrap_or_else(|| SearchState::new_initial(Vec::new()))
    }

    /// Resets `state` and keeps it for reuse, or drops it if the pool is full.
    ///
    /// Collections are cleared in place rather than replaced, so their heap
    /// allocations are kept for the next user.
    pub fn return_state(&mut self, mut state: SearchState) {
        if self.pool.len() >= self.max_size {
            return;
        }
        state.resolved_packages.clear();
        state.pending_requirements.clear();
        state.conflicts.clear();
        state.cost_so_far = 0.0;
        state.estimated_total_cost = 0.0;
        state.depth = 0;
        state.parent_id = None;
        // Recompute rather than zero the hash/id. A fresh empty state from
        // `new_initial` carries the real hash of empty content. Zeroing would
        // make a pooled state compare unequal to an identical fresh one.
        state.update_hash();
        state.state_id = state.state_hash;
        self.pool.push(state);
    }

    /// Number of states currently held for reuse.
    pub fn size(&self) -> usize {
        self.pool.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rez_next_package::PackageRequirement;
    use std::collections::BinaryHeap;

    /// Builds a requirement from a bare package name.
    fn make_req(name: &str) -> PackageRequirement {
        PackageRequirement::new(String::from(name))
    }

    #[test]
    fn test_search_state_creation() {
        let initial = SearchState::new_initial(vec![make_req("test_package")]);

        assert_eq!(initial.depth, 0);
        assert!(initial.resolved_packages.is_empty());
        assert_eq!(initial.pending_requirements.len(), 1);
        assert!(!initial.is_goal());
    }

    #[test]
    fn test_state_hash_consistency() {
        let requirement = make_req("test_package");
        let first = SearchState::new_initial(vec![requirement.clone()]);
        let second = SearchState::new_initial(vec![requirement]);

        assert_eq!(first.get_hash(), second.get_hash());
        assert_eq!(first, second);
    }

    #[test]
    fn test_goal_state_empty_requirements() {
        let empty = SearchState::new_initial(Vec::new());
        assert!(
            empty.is_goal(),
            "Empty requirements with no conflicts is a goal"
        );
    }

    #[test]
    fn test_state_validity_missing_package_conflict() {
        let missing = DependencyConflict::new(
            "missing_pkg".to_string(),
            Vec::new(),
            1.0,
            ConflictType::MissingPackage,
        );
        let mut state = SearchState::new_initial(Vec::new());
        state.add_conflict(missing);

        assert!(
            !state.is_valid(),
            "MissingPackage conflict makes state invalid"
        );
    }

    #[test]
    fn test_state_validity_version_conflict_still_valid() {
        let clashing = vec!["pkg-1.0".to_string(), "pkg-2.0".to_string()];
        let mut state = SearchState::new_initial(Vec::new());
        state.add_conflict(DependencyConflict::new(
            "pkg".to_string(),
            clashing,
            0.5,
            ConflictType::VersionConflict,
        ));

        assert!(
            state.is_valid(),
            "VersionConflict alone does not invalidate state"
        );
    }

    #[test]
    fn test_state_pool() {
        let mut pool = StatePool::new(5);
        assert_eq!(pool.size(), 0);

        let borrowed = pool.get_state();
        pool.return_state(borrowed);
        assert_eq!(pool.size(), 1);

        let _reused = pool.get_state();
        assert_eq!(pool.size(), 0);
    }

    #[test]
    fn test_state_ordering_lower_cost_higher_priority() {
        let mut cheap = SearchState::new_initial(Vec::new());
        cheap.estimated_total_cost = 5.0;
        let mut expensive = SearchState::new_initial(vec![make_req("x")]);
        expensive.estimated_total_cost = 10.0;

        let mut heap = BinaryHeap::new();
        heap.push(OrdByEstimatedCost(cheap));
        heap.push(OrdByEstimatedCost(expensive));

        let OrdByEstimatedCost(top) = heap.pop().unwrap();
        assert_eq!(
            top.estimated_total_cost, 5.0,
            "Lower cost should have higher priority"
        );
    }

    #[test]
    fn test_remove_requirement() {
        let first = make_req("pkg_a");
        let second = make_req("pkg_b");
        let mut state = SearchState::new_initial(vec![first.clone(), second]);
        assert_eq!(state.pending_requirements.len(), 2);

        state.remove_requirement(&first);

        assert_eq!(state.pending_requirements.len(), 1);
        assert_eq!(state.pending_requirements[0].name, "pkg_b");
    }
}