1use converge_pack::{ExecutionIdentity, FactPayload};
10use serde::{Deserialize, Serialize};
11
/// Payload carrying the numeric feature vector for a classification.
///
/// Serializes through serde; because of `deny_unknown_fields`, deserialization
/// fails on any input that contains a field other than `features`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ClassificationFeaturesPayload {
    /// Feature values as `f64`s. No length or range validation is performed
    /// by this type.
    pub features: Vec<f64>,
}
25
26impl ClassificationFeaturesPayload {
27 #[must_use]
28 pub fn new(features: Vec<f64>) -> Self {
29 Self { features }
30 }
31
32 #[must_use]
33 pub fn features(&self) -> &[f64] {
34 &self.features
35 }
36}
37
38impl FactPayload for ClassificationFeaturesPayload {
39 const FAMILY: &'static str = "crucible.classification.features";
40 const VERSION: u16 = 1;
41}
42
/// Payload carrying the result of a classification together with the
/// identity of the execution that produced it.
///
/// Serializes through serde; because of `deny_unknown_fields`, deserialization
/// fails on any input that contains a field not declared here.
///
/// No invariant between the fields is enforced by this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ClassPredictionPayload {
    /// The predicted class.
    /// NOTE(review): presumably an index into `class_probabilities` — confirm
    /// with the producers; it is not bounds-checked here.
    pub predicted_class: usize,
    /// Per-class probability values.
    /// NOTE(review): presumably one entry per class summing to 1.0 — not
    /// validated here; confirm with the producers.
    pub class_probabilities: Vec<f64>,
    /// Identity of the execution that produced this prediction.
    pub execution_identity: ExecutionIdentity,
}
63
64impl ClassPredictionPayload {
65 #[must_use]
66 pub fn new(
67 predicted_class: usize,
68 class_probabilities: Vec<f64>,
69 execution_identity: ExecutionIdentity,
70 ) -> Self {
71 Self {
72 predicted_class,
73 class_probabilities,
74 execution_identity,
75 }
76 }
77
78 #[must_use]
79 pub fn predicted_class(&self) -> usize {
80 self.predicted_class
81 }
82
83 #[must_use]
84 pub fn class_probabilities(&self) -> &[f64] {
85 &self.class_probabilities
86 }
87
88 #[must_use]
89 pub fn execution_identity(&self) -> &ExecutionIdentity {
90 &self.execution_identity
91 }
92}
93
94impl FactPayload for ClassPredictionPayload {
95 const FAMILY: &'static str = "crucible.classification.prediction";
96 const VERSION: u16 = 2;
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixed non-native execution identity shared by the
    /// prediction tests.
    fn fixture_identity() -> ExecutionIdentity {
        ExecutionIdentity::non_native(
            "test-producer",
            "0.0.0",
            "fixture-backend",
            "{\"hyperparam\":42}",
        )
    }

    /// A features payload exposes what it was built from and survives a
    /// JSON serialize/deserialize cycle unchanged.
    #[test]
    fn features_payload_round_trips() {
        let values = vec![1.0, 2.0, 3.0];
        let original = ClassificationFeaturesPayload::new(values.clone());
        assert_eq!(original.features(), values.as_slice());

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<ClassificationFeaturesPayload>(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    /// A prediction payload exposes every part it was built from and
    /// survives a JSON serialize/deserialize cycle unchanged.
    #[test]
    fn prediction_payload_round_trips() {
        let probabilities = vec![0.2, 0.8];
        let original = ClassPredictionPayload::new(1, probabilities.clone(), fixture_identity());
        assert_eq!(original.predicted_class(), 1);
        assert_eq!(original.class_probabilities(), probabilities.as_slice());
        assert_eq!(original.execution_identity().backend, "fixture-backend");

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<ClassPredictionPayload>(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    /// Pins the family identifiers and schema versions so that an accidental
    /// change shows up as a test failure.
    #[test]
    fn payload_family_strings_are_stable() {
        // Features payload.
        assert_eq!(
            ClassificationFeaturesPayload::FAMILY,
            "crucible.classification.features"
        );
        assert_eq!(ClassificationFeaturesPayload::VERSION, 1);

        // Prediction payload.
        assert_eq!(
            ClassPredictionPayload::FAMILY,
            "crucible.classification.prediction"
        );
        assert_eq!(ClassPredictionPayload::VERSION, 2);
    }
}