use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use solverforge_core::Score;
/// Per-constraint weight overrides, keyed by constraint name.
///
/// Stores at most one weight of score type `Sc` per name. Lookups that miss can
/// fall back to a caller-supplied default via `get_or_default`.
///
/// The `Score` requirement is expressed on the `impl` blocks rather than on the
/// struct itself. Generic code can therefore name this type without repeating
/// the bound; every operation still requires `Sc: Score`.
#[derive(Clone)]
pub struct ConstraintWeightOverrides<Sc> {
    /// Override weight stored for each constraint name.
    weights: HashMap<String, Sc>,
}
/// Debug output reports only how many overrides are stored, not the weights
/// themselves.
impl<Sc> Debug for ConstraintWeightOverrides<Sc>
where
    Sc: Score,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.weights.len();
        let mut out = f.debug_struct("ConstraintWeightOverrides");
        out.field("count", &count);
        out.finish()
    }
}
impl<Sc: Score> Default for ConstraintWeightOverrides<Sc> {
fn default() -> Self {
Self::new()
}
}
impl<Sc> ConstraintWeightOverrides<Sc>
where
    Sc: Score,
{
    /// Creates an empty set of overrides.
    pub fn new() -> Self {
        let weights = HashMap::new();
        Self { weights }
    }

    /// Builds overrides from `(name, weight)` pairs.
    ///
    /// Names are converted with `Into<String>`. When a name occurs more than
    /// once, the weight from its last occurrence is kept.
    pub fn from_pairs<I, N>(iter: I) -> Self
    where
        I: IntoIterator<Item = (N, Sc)>,
        N: Into<String>,
    {
        let mut weights = HashMap::new();
        for (name, weight) in iter {
            weights.insert(name.into(), weight);
        }
        Self { weights }
    }

    /// Sets the override for `name`, replacing any weight already stored.
    pub fn put<N>(&mut self, name: N, weight: Sc)
    where
        N: Into<String>,
    {
        let key: String = name.into();
        self.weights.insert(key, weight);
    }

    /// Removes the override for `name`. Returns its weight if one was stored.
    pub fn remove(&mut self, name: &str) -> Option<Sc> {
        self.weights.remove(name)
    }

    /// Returns a clone of the override for `name`, or `default` when none is
    /// stored.
    pub fn get_or_default(&self, name: &str, default: Sc) -> Sc {
        match self.weights.get(name) {
            Some(weight) => weight.clone(),
            None => default,
        }
    }

    /// Borrows the override stored for `name`, if any.
    pub fn get(&self, name: &str) -> Option<&Sc> {
        self.weights.get(name)
    }

    /// Reports whether an override is stored for `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.weights.contains_key(name)
    }

    /// Number of stored overrides.
    pub fn len(&self) -> usize {
        self.weights.len()
    }

    /// Returns `true` when no overrides are stored.
    pub fn is_empty(&self) -> bool {
        self.weights.is_empty()
    }

    /// Removes every stored override.
    pub fn clear(&mut self) {
        self.weights.clear();
    }

    /// Moves these overrides into a new [`Arc`] so several owners can share
    /// them.
    pub fn into_arc(self) -> Arc<Self> {
        Arc::from(self)
    }
}
/// Source of constraint weights, looked up by constraint name.
///
/// Implementors must be `Send + Sync`.
pub trait WeightProvider<Sc>: Send + Sync
where
    Sc: Score,
{
    /// Looks up the weight for constraint `name`. `None` means this provider
    /// has no weight for it.
    fn weight(&self, name: &str) -> Option<Sc>;

    /// Returns the weight for `name`. Falls back to `default` when
    /// [`weight`](Self::weight) yields `None`.
    fn weight_or_default(&self, name: &str, default: Sc) -> Sc {
        let configured = self.weight(name);
        configured.unwrap_or(default)
    }
}
/// Serves weights straight from the override map. A matching entry is cloned;
/// an unknown name yields `None`.
impl<Sc> WeightProvider<Sc> for ConstraintWeightOverrides<Sc>
where
    Sc: Score,
{
    fn weight(&self, name: &str) -> Option<Sc> {
        let stored = self.weights.get(name);
        stored.cloned()
    }
}
/// Lets a shared `Arc<ConstraintWeightOverrides<Sc>>` be used wherever a
/// [`WeightProvider`] is expected. Lookups go to the wrapped overrides.
impl<Sc> WeightProvider<Sc> for Arc<ConstraintWeightOverrides<Sc>>
where
    Sc: Score,
{
    fn weight(&self, name: &str) -> Option<Sc> {
        // Deref through the Arc and reuse the inner value's provider impl.
        let inner: &ConstraintWeightOverrides<Sc> = self;
        inner.weight(name)
    }
}