1use uuid::Uuid;
4
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
/// One training example: a set of context embeddings, the embedding of the
/// message being classified, its intent labels, and bookkeeping metadata.
///
/// Serde support (`Serialize` / `Deserialize`) is derived only when the
/// `serde` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TrainingExample {
    /// Identifier of this example. `TrainingExample::new` fills it with a
    /// random (v4) UUID; `TrainingExample::with_id` takes it from the caller.
    pub id: Uuid,

    /// Zero or more context embeddings. `TrainingExample::validate` requires
    /// each one to have the same length as `message_embedding`.
    // NOTE(review): presumably the embeddings of the preceding messages — confirm.
    pub context_embeddings: Vec<Vec<f32>>,

    /// Embedding of the message itself. Its length defines the embedding
    /// dimension, and `TrainingExample::validate` rejects an empty vector.
    pub message_embedding: Vec<f32>,

    /// Per-class intent scores attached to this example.
    pub labels: IntentLabels,

    /// Optional provenance information; the constructors start it at
    /// `ExampleMetadata::default()`.
    pub metadata: ExampleMetadata,
}
33
34impl TrainingExample {
35 pub fn new(
37 context_embeddings: Vec<Vec<f32>>,
38 message_embedding: Vec<f32>,
39 labels: IntentLabels,
40 ) -> Self {
41 Self {
42 id: Uuid::new_v4(),
43 context_embeddings,
44 message_embedding,
45 labels,
46 metadata: ExampleMetadata::default(),
47 }
48 }
49
50 pub fn with_id(
52 id: Uuid,
53 context_embeddings: Vec<Vec<f32>>,
54 message_embedding: Vec<f32>,
55 labels: IntentLabels,
56 ) -> Self {
57 Self {
58 id,
59 context_embeddings,
60 message_embedding,
61 labels,
62 metadata: ExampleMetadata::default(),
63 }
64 }
65
66 pub fn with_metadata(mut self, metadata: ExampleMetadata) -> Self {
68 self.metadata = metadata;
69 self
70 }
71
72 pub fn embedding_dim(&self) -> usize {
74 self.message_embedding.len()
75 }
76
77 pub fn context_size(&self) -> usize {
79 self.context_embeddings.len()
80 }
81
82 pub fn validate(&self) -> Result<(), String> {
84 if self.message_embedding.is_empty() {
85 return Err("Message embedding cannot be empty".to_string());
86 }
87
88 let dim = self.embedding_dim();
89 for (i, ctx_emb) in self.context_embeddings.iter().enumerate() {
90 if ctx_emb.len() != dim {
91 return Err(format!(
92 "Context embedding {} has dimension {} but expected {}",
93 i,
94 ctx_emb.len(),
95 dim
96 ));
97 }
98 }
99
100 self.labels.validate()?;
101 Ok(())
102 }
103
104 pub fn dominant_intent(&self) -> (&'static str, f32) {
106 self.labels.dominant()
107 }
108}
109
/// Soft intent labels: one `f32` score per intent class.
///
/// The canonical class order — shared by [`IntentLabels::to_vec`],
/// [`IntentLabels::from_vec`] and [`IntentLabels::class_names`] — is the field
/// declaration order below. [`IntentLabels::validate`] requires every score to
/// lie in `[0, 1]`; it does not require the scores to sum to 1.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IntentLabels {
    /// Score of the `continuation` class (index 0).
    pub continuation: f32,

    /// Score of the `topic_shift` class (index 1).
    pub topic_shift: f32,

    /// Score of the `explicit_query` class (index 2).
    pub explicit_query: f32,

    /// Score of the `person_lookup` class (index 3).
    pub person_lookup: f32,

    /// Score of the `health_check` class (index 4).
    pub health_check: f32,

    /// Score of the `task_status` class (index 5).
    pub task_status: f32,
}

impl IntentLabels {
    /// Number of intent classes; the length of [`Self::to_vec`] and
    /// [`Self::class_names`].
    pub const NUM_CLASSES: usize = 6;

    /// Labels with only `continuation` set to `prob`; all other scores are 0.
    pub fn continuation(prob: f32) -> Self {
        Self {
            continuation: prob,
            ..Self::default()
        }
    }

    /// Labels with only `topic_shift` set to `prob`; all other scores are 0.
    pub fn topic_shift(prob: f32) -> Self {
        Self {
            topic_shift: prob,
            ..Self::default()
        }
    }

    /// Labels with only `explicit_query` set to `prob`; all other scores are 0.
    pub fn explicit_query(prob: f32) -> Self {
        Self {
            explicit_query: prob,
            ..Self::default()
        }
    }

    /// Labels with only `person_lookup` set to `prob`; all other scores are 0.
    pub fn person_lookup(prob: f32) -> Self {
        Self {
            person_lookup: prob,
            ..Self::default()
        }
    }

    /// Labels with only `health_check` set to `prob`; all other scores are 0.
    pub fn health_check(prob: f32) -> Self {
        Self {
            health_check: prob,
            ..Self::default()
        }
    }

    /// Labels with only `task_status` set to `prob`; all other scores are 0.
    pub fn task_status(prob: f32) -> Self {
        Self {
            task_status: prob,
            ..Self::default()
        }
    }

    /// Builds labels from a slice in canonical class order.
    ///
    /// Missing trailing entries default to `0.0`; entries beyond
    /// [`Self::NUM_CLASSES`] are ignored.
    pub fn from_vec(probs: &[f32]) -> Self {
        let score = |idx: usize| probs.get(idx).copied().unwrap_or(0.0);
        Self {
            continuation: score(0),
            topic_shift: score(1),
            explicit_query: score(2),
            person_lookup: score(3),
            health_check: score(4),
            task_status: score(5),
        }
    }

    /// Scores in canonical class order as a fixed-size stack array.
    ///
    /// The array length is typed as [`Self::NUM_CLASSES`], so the compiler
    /// checks that the number of scores listed here matches the class count.
    fn to_array(&self) -> [f32; Self::NUM_CLASSES] {
        [
            self.continuation,
            self.topic_shift,
            self.explicit_query,
            self.person_lookup,
            self.health_check,
            self.task_status,
        ]
    }

    /// Scores in canonical class order as a freshly allocated vector of
    /// length [`Self::NUM_CLASSES`].
    pub fn to_vec(&self) -> Vec<f32> {
        self.to_array().to_vec()
    }

    /// Class names in canonical order; index `i` names entry `i` of
    /// [`Self::to_vec`].
    pub fn class_names() -> &'static [&'static str] {
        // Length is tied to NUM_CLASSES so the list cannot silently drift.
        static NAMES: [&str; IntentLabels::NUM_CLASSES] = [
            "continuation",
            "topic_shift",
            "explicit_query",
            "person_lookup",
            "health_check",
            "task_status",
        ];
        &NAMES
    }

    /// Name and score of the highest-scoring class.
    ///
    /// Scores are ordered with `f32::total_cmp`, so a positive NaN ranks above
    /// every number. When several classes share the maximum, the one latest
    /// in canonical order is returned (the tie rule of `Iterator::max_by`).
    pub fn dominant(&self) -> (&'static str, f32) {
        let (idx, prob) = self
            .to_array()
            .iter()
            .copied()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .expect("NUM_CLASSES is non-zero, so a maximum always exists");
        (Self::class_names()[idx], prob)
    }

    /// Checks that every score lies in the closed interval `[0, 1]`.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first class (in canonical order) whose
    /// score is outside `[0, 1]`. NaN is outside the interval and is rejected.
    pub fn validate(&self) -> Result<(), String> {
        let scores = self.to_array();
        for (name, p) in Self::class_names().iter().zip(scores.iter()) {
            if !(0.0..=1.0).contains(p) {
                return Err(format!(
                    "Invalid probability for {}: {} (must be in [0, 1])",
                    name, p
                ));
            }
        }
        Ok(())
    }

    /// Replaces the scores in place with their softmax, treating the current
    /// values as logits.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` untouched, if any score is NaN or
    /// infinite; the message reports the first such value and its index.
    pub fn softmax_normalize(&mut self) -> Result<(), String> {
        let logits = self.to_array();
        if let Some(pos) = logits.iter().position(|v| !v.is_finite()) {
            return Err(format!(
                "non-finite value {} at index {} in softmax input",
                logits[pos], pos
            ));
        }

        // Subtracting the maximum keeps every exponent <= 0, so `exp` cannot
        // overflow and the sum is at least 1 (the maximum contributes exp(0)).
        let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps = logits.map(|p| (p - max_val).exp());
        let exp_sum: f32 = exps.iter().sum();

        let [continuation, topic_shift, explicit_query, person_lookup, health_check, task_status] =
            exps.map(|e| e / exp_sum);
        *self = Self {
            continuation,
            topic_shift,
            explicit_query,
            person_lookup,
            health_check,
            task_status,
        };
        Ok(())
    }
}
277
/// Optional bookkeeping attached to a training example.
///
/// Every field is optional and `Default` leaves them all as `None`. Serde
/// support is derived only when the `serde` feature is enabled.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ExampleMetadata {
    /// Identifier of the source the example came from, as set by
    /// `ExampleMetadata::with_source`.
    pub source_id: Option<String>,

    /// UTC timestamp associated with the example.
    // NOTE(review): presumably the time of the original message — confirm.
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,

    /// Name of the teacher model, as set by `ExampleMetadata::teacher`.
    pub teacher_model: Option<String>,

    /// UTC timestamp set by `ExampleMetadata::labeled_at`.
    // NOTE(review): presumably when the teacher produced the labels — confirm.
    pub labeled_at: Option<chrono::DateTime<chrono::Utc>>,

    /// Teacher confidence, as set by `ExampleMetadata::confidence`. No range
    /// is enforced by this type.
    pub teacher_confidence: Option<f32>,

    /// Free-form extra data; a JSON value when the `serde` feature is enabled.
    #[cfg(feature = "serde")]
    pub extra: Option<serde_json::Value>,

    /// Free-form extra data; a plain string when the `serde` feature is
    /// disabled.
    #[cfg(not(feature = "serde"))]
    pub extra: Option<String>,
}
307
308impl ExampleMetadata {
309 pub fn with_source(source_id: impl Into<String>) -> Self {
311 Self {
312 source_id: Some(source_id.into()),
313 ..Default::default()
314 }
315 }
316
317 pub fn teacher(mut self, model: impl Into<String>) -> Self {
319 self.teacher_model = Some(model.into());
320 self
321 }
322
323 pub fn timestamp(mut self, ts: chrono::DateTime<chrono::Utc>) -> Self {
325 self.timestamp = Some(ts);
326 self
327 }
328
329 pub fn labeled_at(mut self, ts: chrono::DateTime<chrono::Utc>) -> Self {
331 self.labeled_at = Some(ts);
332 self
333 }
334
335 pub fn confidence(mut self, conf: f32) -> Self {
337 self.teacher_confidence = Some(conf);
338 self
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-class constructor sets its own field and leaves the rest at 0.
    #[test]
    fn test_intent_labels_creation() {
        let IntentLabels {
            continuation,
            topic_shift,
            ..
        } = IntentLabels::continuation(0.8);
        assert_eq!(continuation, 0.8);
        assert_eq!(topic_shift, 0.0);
    }

    /// `dominant` reports the class with the highest score.
    #[test]
    fn test_intent_labels_dominant() {
        let scores = IntentLabels {
            continuation: 0.1,
            topic_shift: 0.2,
            explicit_query: 0.5,
            person_lookup: 0.1,
            health_check: 0.05,
            task_status: 0.05,
        };
        let (winner, score) = scores.dominant();
        assert_eq!(winner, "explicit_query");
        assert_eq!(score, 0.5);
    }

    /// Scores inside `[0, 1]` validate; a score above 1 does not.
    #[test]
    fn test_intent_labels_validation() {
        let in_range = IntentLabels::continuation(0.8);
        assert!(in_range.validate().is_ok());

        let out_of_range = IntentLabels {
            continuation: 1.5,
            ..Default::default()
        };
        assert!(out_of_range.validate().is_err());
    }

    /// A well-formed example reports its sizes and passes validation.
    #[test]
    fn test_training_example_creation() {
        let context = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let message = vec![0.7, 0.8, 0.9];
        let example = TrainingExample::new(context, message, IntentLabels::explicit_query(0.9));

        assert_eq!(example.embedding_dim(), 3);
        assert_eq!(example.context_size(), 2);
        assert!(example.validate().is_ok());
    }

    /// A context embedding of the wrong dimension fails validation.
    #[test]
    fn test_training_example_validation() {
        // Context has dimension 2 while the message embedding has dimension 3.
        let short_context = vec![vec![0.1, 0.2]];
        let message = vec![0.7, 0.8, 0.9];
        let example = TrainingExample::new(short_context, message, IntentLabels::default());

        assert!(example.validate().is_err());
    }

    /// Softmax output sums to ~1 and keeps the largest logit dominant.
    #[test]
    fn test_softmax_normalize() {
        let mut logits = IntentLabels {
            continuation: 2.0,
            topic_shift: 1.0,
            explicit_query: 0.5,
            person_lookup: 0.0,
            health_check: 0.0,
            task_status: 0.0,
        };
        logits.softmax_normalize().expect("test inputs are finite");

        let total = logits.to_vec().iter().sum::<f32>();
        assert!((total - 1.0).abs() < 0.001);

        let (winner, _) = logits.dominant();
        assert_eq!(winner, "continuation");
    }
}