use std::collections::BTreeMap;
use std::sync::Arc;
use async_trait::async_trait;
use jiff::Timestamp;
use crate::{ParameterName, TrialId, Value};
use crate::trial::Trial;
use serde::{Deserialize, Serialize};
use crate::element::Element;
use crate::error::Result;
use crate::lifecycle::{LiveStatusSummary, StateTransition};
/// Context handed to [`ElementRuntime::on_trial_starting`] and
/// [`ElementRuntime::on_trial_ending`].
///
/// Cloning shares the trial through its [`Arc`] rather than copying it.
#[derive(Clone, Debug)]
pub struct TrialContext {
    /// Identifier of the trial.
    ///
    /// NOTE(review): presumably equal to the id carried by `trial`; nothing in
    /// this type enforces that — confirm with the code that builds the context.
    pub trial_id: TrialId,
    /// The trial this context refers to, shared rather than copied.
    pub trial: Arc<Trial>,
    /// Time associated with this hook invocation.
    ///
    /// TODO(review): confirm whether this is the trial's start/end time or the
    /// moment the hook is dispatched.
    pub timestamp: Timestamp,
}
/// Parameter values handed to [`ElementRuntime::materialize`], keyed by
/// parameter name.
///
/// Backed by a [`BTreeMap`], so iteration is always in ascending
/// [`ParameterName`] order. Serializes transparently as the underlying map.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ResolvedConfiguration(BTreeMap<ParameterName, Value>);

impl ResolvedConfiguration {
    /// Creates an empty configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `name`, returning the value previously stored
    /// under that name, if any.
    pub fn insert(&mut self, name: ParameterName, value: Value) -> Option<Value> {
        self.0.insert(name, value)
    }

    /// Returns the value stored under `name`, if present.
    #[must_use]
    pub fn get(&self, name: &ParameterName) -> Option<&Value> {
        self.0.get(name)
    }

    /// Iterates over `(name, value)` pairs in ascending name order.
    pub fn iter(&self) -> impl Iterator<Item = (&ParameterName, &Value)> {
        self.0.iter()
    }

    /// Number of parameters in the configuration.
    #[must_use]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when the configuration holds no parameters.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

impl FromIterator<(ParameterName, Value)> for ResolvedConfiguration {
    /// Builds a configuration from pairs; later duplicates of a name win.
    fn from_iter<I: IntoIterator<Item = (ParameterName, Value)>>(iter: I) -> Self {
        Self(iter.into_iter().collect())
    }
}

impl Extend<(ParameterName, Value)> for ResolvedConfiguration {
    /// Inserts every pair, overwriting existing values with the same name.
    fn extend<I: IntoIterator<Item = (ParameterName, Value)>>(&mut self, iter: I) {
        self.0.extend(iter);
    }
}

/// Enables `for (name, value) in &config`, mirroring [`ResolvedConfiguration::iter`].
impl<'a> IntoIterator for &'a ResolvedConfiguration {
    type Item = (&'a ParameterName, &'a Value);
    type IntoIter = std::collections::btree_map::Iter<'a, ParameterName, Value>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.iter()
    }
}

/// Consumes the configuration, yielding owned pairs in ascending name order.
impl IntoIterator for ResolvedConfiguration {
    type Item = (ParameterName, Value);
    type IntoIter = std::collections::btree_map::IntoIter<ParameterName, Value>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}
/// Values returned by [`ElementRuntime::materialize`], keyed by parameter name.
///
/// Backed by a [`BTreeMap`], so iteration is always in ascending
/// [`ParameterName`] order. Serializes transparently as the underlying map.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct MaterializationOutputs(BTreeMap<ParameterName, Value>);

impl MaterializationOutputs {
    /// Creates an empty set of outputs.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `name`, returning the value previously stored
    /// under that name, if any.
    pub fn insert(&mut self, name: ParameterName, value: Value) -> Option<Value> {
        self.0.insert(name, value)
    }

    /// Returns the value stored under `name`, if present.
    #[must_use]
    pub fn get(&self, name: &ParameterName) -> Option<&Value> {
        self.0.get(name)
    }

    /// Iterates over `(name, value)` pairs in ascending name order.
    pub fn iter(&self) -> impl Iterator<Item = (&ParameterName, &Value)> {
        self.0.iter()
    }

    /// Number of outputs.
    #[must_use]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when there are no outputs.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

impl FromIterator<(ParameterName, Value)> for MaterializationOutputs {
    /// Builds outputs from pairs; later duplicates of a name win.
    fn from_iter<I: IntoIterator<Item = (ParameterName, Value)>>(iter: I) -> Self {
        Self(iter.into_iter().collect())
    }
}

impl Extend<(ParameterName, Value)> for MaterializationOutputs {
    /// Inserts every pair, overwriting existing values with the same name.
    fn extend<I: IntoIterator<Item = (ParameterName, Value)>>(&mut self, iter: I) {
        self.0.extend(iter);
    }
}

/// Enables `for (name, value) in &outputs`, mirroring [`MaterializationOutputs::iter`].
impl<'a> IntoIterator for &'a MaterializationOutputs {
    type Item = (&'a ParameterName, &'a Value);
    type IntoIter = std::collections::btree_map::Iter<'a, ParameterName, Value>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.iter()
    }
}

/// Consumes the outputs, yielding owned pairs in ascending name order.
impl IntoIterator for MaterializationOutputs {
    type Item = (ParameterName, Value);
    type IntoIter = std::collections::btree_map::IntoIter<ParameterName, Value>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}
/// Boxed callback passed to [`ElementRuntime::observe_state`]; it is invoked
/// with a [`StateTransition`] value and must be shareable across threads.
pub type StateTransitionListener = Box<dyn Fn(StateTransition) + Sync + Send + 'static>;
/// Handle returned by [`ElementRuntime::observe_state`].
///
/// NOTE(review): `cancel` is presumably meant to stop delivery to the
/// registered listener; the exact guarantee (e.g. whether an in-flight
/// callback may still run) is left to implementations — confirm per runtime.
pub trait StateObservation: 'static + Sync + Send {
    /// Requests that this observation be cancelled.
    fn cancel(&self);
}
#[async_trait]
pub trait ElementRuntime: Send + Sync + 'static {
async fn materialize(
&self,
resolved: &ResolvedConfiguration,
) -> Result<MaterializationOutputs>;
async fn dematerialize(&self) -> Result<()>;
async fn status_check(&self) -> LiveStatusSummary;
async fn on_trial_starting(&self, _ctx: &TrialContext) -> Result<()> {
Ok(())
}
async fn on_trial_ending(&self, _ctx: &TrialContext) -> Result<()> {
Ok(())
}
fn observe_state(
&self,
listener: StateTransitionListener,
) -> Box<dyn StateObservation>;
}
/// Looks up the [`ElementRuntime`] to use for a given [`Element`].
pub trait ElementRuntimeRegistry: std::fmt::Debug + Sync + Send + 'static {
    /// Returns a shared runtime for `element`.
    ///
    /// # Errors
    /// Returns an error when no runtime can be provided for `element`; the
    /// exact conditions are implementation-defined.
    fn runtime_for(&self, element: &Element) -> Result<Arc<dyn ElementRuntime>>;
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    use ulid::Ulid;

    use super::*;
    use crate::lifecycle::OperationalState;

    /// Fixed, deterministic trial id shared by the tests.
    fn tid() -> TrialId {
        TrialId::from_ulid(Ulid::from_parts(1_700_000_000_000, 1))
    }

    /// Minimal trial with no assignments.
    fn trial() -> Trial {
        Trial::builder()
            .id(tid())
            .assignments(crate::Assignments::empty())
            .build()
    }

    /// Builds a parameter name the test knows to be valid.
    fn pname(name: &'static str) -> ParameterName {
        ParameterName::new(name).unwrap()
    }

    /// In-memory runtime that records lifecycle calls for assertions.
    /// `Default` yields a dematerialized runtime with zero trial starts.
    #[derive(Debug, Default)]
    struct MockRuntime {
        materialized: AtomicBool,
        trial_starts: AtomicUsize,
    }

    #[async_trait]
    impl ElementRuntime for MockRuntime {
        async fn materialize(
            &self,
            _resolved: &ResolvedConfiguration,
        ) -> Result<MaterializationOutputs> {
            self.materialized.store(true, Ordering::SeqCst);
            Ok(MaterializationOutputs::new())
        }

        async fn dematerialize(&self) -> Result<()> {
            self.materialized.store(false, Ordering::SeqCst);
            Ok(())
        }

        async fn status_check(&self) -> LiveStatusSummary {
            LiveStatusSummary {
                state: if self.materialized.load(Ordering::SeqCst) {
                    OperationalState::Ready
                } else {
                    OperationalState::Inactive
                },
                summary: "mock".to_owned(),
            }
        }

        async fn on_trial_starting(&self, _ctx: &TrialContext) -> Result<()> {
            self.trial_starts.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        fn observe_state(
            &self,
            _listener: StateTransitionListener,
        ) -> Box<dyn StateObservation> {
            Box::new(NoopObservation)
        }
    }

    #[derive(Debug)]
    struct NoopObservation;

    impl StateObservation for NoopObservation {
        fn cancel(&self) {}
    }

    #[tokio::test]
    async fn mock_runtime_materialize_and_status_check() {
        let rt = MockRuntime::default();
        assert_eq!(rt.status_check().await.state, OperationalState::Inactive);

        let outputs = rt.materialize(&ResolvedConfiguration::new()).await.unwrap();
        assert!(outputs.is_empty());
        assert_eq!(rt.status_check().await.state, OperationalState::Ready);

        rt.dematerialize().await.unwrap();
        // The teardown must be observable through status_check as well.
        assert_eq!(rt.status_check().await.state, OperationalState::Inactive);
    }

    #[tokio::test]
    async fn mock_runtime_trial_hooks_dispatch() {
        let rt = MockRuntime::default();
        let ctx = TrialContext {
            trial_id: tid(),
            trial: Arc::new(trial()),
            timestamp: Timestamp::from_second(0).unwrap(),
        };
        rt.on_trial_starting(&ctx).await.unwrap();
        // Uses the trait's default no-op hook; it must not touch the counter.
        rt.on_trial_ending(&ctx).await.unwrap();
        assert_eq!(rt.trial_starts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn resolved_configuration_iter_is_sorted() {
        let mut rc = ResolvedConfiguration::new();
        rc.insert(pname("zebra"), Value::integer(pname("zebra"), 1, None));
        rc.insert(pname("apple"), Value::integer(pname("apple"), 2, None));
        let names: Vec<&str> = rc.iter().map(|(n, _)| n.as_str()).collect();
        assert_eq!(names, vec!["apple", "zebra"]);
    }

    #[test]
    fn resolved_configuration_collect_insert_and_get() {
        let empty = ResolvedConfiguration::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let first = Value::integer(pname("zebra"), 1, None);
        let mut rc: ResolvedConfiguration =
            [(pname("zebra"), first.clone())].into_iter().collect();
        assert_eq!(rc.len(), 1);
        assert!(!rc.is_empty());
        assert_eq!(rc.get(&pname("zebra")), Some(&first));
        assert!(rc.get(&pname("apple")).is_none());

        // Re-inserting an existing name replaces it and hands back the old value.
        let second = Value::integer(pname("zebra"), 2, None);
        let previous = rc.insert(pname("zebra"), second.clone());
        assert_eq!(previous, Some(first));
        assert_eq!(rc.get(&pname("zebra")), Some(&second));
        assert_eq!(rc.len(), 1);
    }

    #[test]
    fn materialization_outputs_serde_roundtrip() {
        let mut o = MaterializationOutputs::new();
        o.insert(
            pname("endpoint"),
            Value::string(pname("endpoint"), "http://example:4567", None),
        );
        let json = serde_json::to_string(&o).unwrap();
        let back: MaterializationOutputs = serde_json::from_str(&json).unwrap();
        assert_eq!(o, back);
    }
}