use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
/// Identifier of a trial: a newtype around a `u64`.
///
/// Equality, ordering and hashing are all derived from the wrapped integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TrialId(u64);
impl TrialId {
pub const fn new(id: u64) -> Self {
Self(id)
}
pub const fn get(self) -> u64 {
self.0
}
}
/// A trial awaiting evaluation; it is turned into an [`EvaluatedTrial`] through
/// `NextTrial::evaluated` or `NextTrial::unevaluable`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NextTrial {
/// Identifier of this trial.
pub id: TrialId,
/// Parameter values of this trial.
pub params: Params,
/// Optional step; `NextTrial::unevaluable` reports it (or `0` when `None`) as the current step.
/// NOTE(review): presumably the step the trial should be evaluated up to next, and the
/// meaning of `None` is not visible in this file — confirm against the callers.
pub next_step: Option<u64>,
}
impl NextTrial {
    /// Builds the [`EvaluatedTrial`] for this trial from `values` and `current_step`.
    ///
    /// The identifier is copied from `self`; both arguments are stored as given.
    /// `params` and `next_step` are not carried over.
    pub fn evaluated(&self, values: Values, current_step: u64) -> EvaluatedTrial {
        let id = self.id;
        EvaluatedTrial {
            id,
            values,
            current_step,
        }
    }

    /// Builds an [`EvaluatedTrial`] that carries no values at all.
    ///
    /// The reported step is `next_step`, falling back to `0` when `next_step` is `None`.
    pub fn unevaluable(&self) -> EvaluatedTrial {
        let step = match self.next_step {
            Some(step) => step,
            None => 0,
        };
        let no_values = Values::new(Vec::new());
        self.evaluated(no_values, step)
    }
}
/// Result record for a trial: the trial's identifier, the resulting values and a step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluatedTrial {
/// Identifier of the trial these results belong to.
pub id: TrialId,
/// Values attached to the trial; `NextTrial::unevaluable` leaves this empty.
pub values: Values,
/// Step associated with `values`.
/// NOTE(review): presumably the step the evaluation has reached — confirm against the callers.
pub current_step: u64,
}
/// Generator of sequential [`TrialId`]s.
///
/// Each call to [`IdGen::generate`] returns the current identifier and advances the
/// internal counter by one. `IdGen::default()` is the same as [`IdGen::new`].
#[derive(Debug, Default)]
pub struct IdGen {
    /// Raw value of the identifier that the next `generate` call will return.
    next: u64,
}

impl IdGen {
    /// Creates a generator whose first identifier is `0`.
    pub const fn new() -> Self {
        Self { next: 0 }
    }

    /// Creates a generator whose first identifier has the raw value `next`.
    pub const fn from_next_id(next: u64) -> Self {
        Self { next }
    }

    /// Returns the next identifier and advances the counter by one.
    pub fn generate(&mut self) -> TrialId {
        let id = self.peek_id();
        self.next += 1;
        id
    }

    /// Returns the identifier that the next `generate` call will produce, without advancing.
    pub fn peek_id(&self) -> TrialId {
        TrialId(self.next)
    }
}
/// Parameter values of a trial, stored as a `Vec<f64>`.
///
/// Serialization and deserialization of the vector are delegated to the `nullable_f64_vec`
/// module through `#[serde(with = ...)]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Params(#[serde(with = "nullable_f64_vec")] Vec<f64>);
impl Params {
pub const fn new(params: Vec<f64>) -> Self {
Self(params)
}
pub fn into_vec(self) -> Vec<f64> {
self.0
}
pub fn get(&self) -> &[f64] {
&self.0
}
fn ordered_floats(&self) -> impl '_ + Iterator<Item = OrderedFloat<f64>> {
self.0.iter().copied().map(OrderedFloat)
}
}
impl PartialEq for Params {
    /// Compares the two value sequences element by element through `ordered_floats`, so
    /// element equality follows `OrderedFloat` rather than plain `f64` `==`.
    ///
    /// Sequences of different lengths are never equal.
    fn eq(&self, other: &Self) -> bool {
        Iterator::eq(self.ordered_floats(), other.ordered_floats())
    }
}

impl Eq for Params {}
impl Hash for Params {
    /// Feeds every value, as produced by `ordered_floats`, into `hasher` in order.
    ///
    /// No length prefix is written, so an empty `Params` leaves `hasher` untouched.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        self.ordered_floats().for_each(|value| value.hash(hasher));
    }
}
impl Deref for Params {
    type Target = [f64];

    /// Exposes the stored values as a slice, so slice methods can be called on `Params`.
    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
/// Values attached to an evaluated trial (see `EvaluatedTrial::values`), stored as a `Vec<f64>`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Values(Vec<f64>);
impl Values {
pub const fn new(values: Vec<f64>) -> Self {
Self(values)
}
pub fn into_vec(self) -> Vec<f64> {
self.0
}
fn ordered_floats(&self) -> impl '_ + Iterator<Item = OrderedFloat<f64>> {
self.0.iter().copied().map(OrderedFloat)
}
}
impl PartialEq for Values {
    /// Compares the two value sequences element by element through `ordered_floats`, so
    /// element equality follows `OrderedFloat` rather than plain `f64` `==`.
    ///
    /// Sequences of different lengths are never equal.
    fn eq(&self, other: &Self) -> bool {
        Iterator::eq(self.ordered_floats(), other.ordered_floats())
    }
}
impl PartialOrd for Values {
    /// Coordinate-wise partial order over the value vectors.
    ///
    /// Returns `Some(ordering)` only when *every* pair of corresponding values compares as
    /// that same `ordering`; as soon as two coordinates disagree (including `Equal` versus
    /// `Less`/`Greater`) the vectors are incomparable and `None` is returned. Two empty
    /// vectors are `Equal`.
    ///
    /// To stay consistent with this type's `PartialEq` (which `PartialOrd` requires):
    /// - vectors of different lengths are never `==`, so they are reported as incomparable
    ///   instead of being compared on their common prefix;
    /// - two NaNs in the same position count as `Equal`, matching the `OrderedFloat`-based
    ///   equality. A NaN paired with a non-NaN value still makes the vectors incomparable.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // `zip` would silently drop the unmatched tail, so rule out length mismatches first.
        if self.0.len() != other.0.len() {
            return None;
        }

        let mut overall: Option<Ordering> = None;
        for (a, b) in self.0.iter().zip(other.0.iter()) {
            let current = if a.is_nan() && b.is_nan() {
                Ordering::Equal
            } else {
                // `None` here means exactly one side is NaN: incomparable.
                a.partial_cmp(b)?
            };
            match overall {
                // First coordinate fixes the candidate ordering.
                None => overall = Some(current),
                // Any later coordinate that disagrees makes the vectors incomparable.
                Some(previous) if previous != current => return None,
                Some(_) => {}
            }
        }

        // No coordinates at all: both vectors are empty and therefore equal.
        Some(overall.unwrap_or(Ordering::Equal))
    }
}
impl Eq for Values {}

impl Hash for Values {
    /// Feeds every value, as produced by `ordered_floats`, into `hasher` in order.
    ///
    /// No length prefix is written, so an empty `Values` leaves `hasher` untouched.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        self.ordered_floats().for_each(|value| value.hash(hasher));
    }
}
impl Deref for Values {
    type Target = [f64];

    /// Exposes the stored values as a slice, so slice methods can be called on `Values`.
    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
/// Serde adapter (for `#[serde(with = "nullable_f64_vec")]`) that maps between a sequence of
/// `f64` and a sequence of optional numbers.
///
/// Non-finite values (NaN and both infinities) are written as `None`; when reading, every
/// `None` becomes NaN. An infinity therefore comes back as NaN after a round trip.
mod nullable_f64_vec {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::f64::NAN;

    /// Reads a sequence of optional numbers, substituting NaN for every missing entry.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let entries: Vec<Option<f64>> = Deserialize::deserialize(deserializer)?;
        let values = entries
            .into_iter()
            .map(|entry| entry.unwrap_or(NAN))
            .collect();
        Ok(values)
    }

    /// Writes `v` as a sequence in which every non-finite number is replaced by `None`.
    pub fn serialize<S>(v: &[f64], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let entries: Vec<Option<f64>> = v
            .iter()
            .map(|&value| Some(value).filter(|x| x.is_finite()))
            .collect();
        entries.serialize(serializer)
    }
}