// lattice_tune/data/example.rs

//! Training example types

use uuid::Uuid;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

8/// A single training example for intent classification
9///
10/// Contains the input embeddings, soft labels from teacher, and metadata
11/// for traceability and debugging.
12#[derive(Debug, Clone)]
13#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
14pub struct TrainingExample {
15    /// Unique identifier for this example
16    pub id: Uuid,
17
18    /// Context embeddings from the last N messages
19    ///
20    /// Each inner Vec represents an embedding vector for one message.
21    /// Order: oldest to newest (chronological).
22    pub context_embeddings: Vec<Vec<f32>>,
23
24    /// Embedding of the current message to classify
25    pub message_embedding: Vec<f32>,
26
27    /// Soft labels from teacher model
28    pub labels: IntentLabels,
29
30    /// Metadata about this example
31    pub metadata: ExampleMetadata,
32}
33
34impl TrainingExample {
35    /// Create a new training example with minimal metadata
36    pub fn new(
37        context_embeddings: Vec<Vec<f32>>,
38        message_embedding: Vec<f32>,
39        labels: IntentLabels,
40    ) -> Self {
41        Self {
42            id: Uuid::new_v4(),
43            context_embeddings,
44            message_embedding,
45            labels,
46            metadata: ExampleMetadata::default(),
47        }
48    }
49
50    /// Create a new training example with specific ID
51    pub fn with_id(
52        id: Uuid,
53        context_embeddings: Vec<Vec<f32>>,
54        message_embedding: Vec<f32>,
55        labels: IntentLabels,
56    ) -> Self {
57        Self {
58            id,
59            context_embeddings,
60            message_embedding,
61            labels,
62            metadata: ExampleMetadata::default(),
63        }
64    }
65
66    /// Set metadata for this example
67    pub fn with_metadata(mut self, metadata: ExampleMetadata) -> Self {
68        self.metadata = metadata;
69        self
70    }
71
72    /// Get the embedding dimension (from message embedding)
73    pub fn embedding_dim(&self) -> usize {
74        self.message_embedding.len()
75    }
76
77    /// Get the context window size (number of context messages)
78    pub fn context_size(&self) -> usize {
79        self.context_embeddings.len()
80    }
81
82    /// Validate the example structure
83    pub fn validate(&self) -> Result<(), String> {
84        if self.message_embedding.is_empty() {
85            return Err("Message embedding cannot be empty".to_string());
86        }
87
88        let dim = self.embedding_dim();
89        for (i, ctx_emb) in self.context_embeddings.iter().enumerate() {
90            if ctx_emb.len() != dim {
91                return Err(format!(
92                    "Context embedding {} has dimension {} but expected {}",
93                    i,
94                    ctx_emb.len(),
95                    dim
96                ));
97            }
98        }
99
100        self.labels.validate()?;
101        Ok(())
102    }
103
104    /// Get the dominant intent (highest probability label)
105    pub fn dominant_intent(&self) -> (&'static str, f32) {
106        self.labels.dominant()
107    }
108}
109
/// Soft labels for intent classification.
///
/// Each field represents the probability of that intent class.
/// Values should be in [0, 1] and typically sum to ~1.0 (soft labels may not
/// sum exactly).
///
/// The canonical class order, used by every vector conversion, is:
/// `[continuation, topic_shift, explicit_query, person_lookup, health_check, task_status]`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IntentLabels {
    /// Probability of continuation (continue conversation naturally)
    pub continuation: f32,

    /// Probability of topic shift (user changing subject)
    pub topic_shift: f32,

    /// Probability of explicit query (direct question/request)
    pub explicit_query: f32,

    /// Probability of person lookup (looking up contact/person info)
    pub person_lookup: f32,

    /// Probability of health check (health/wellness related)
    pub health_check: f32,

    /// Probability of task status (checking task/todo status)
    pub task_status: f32,
}

impl IntentLabels {
    /// Number of intent classes
    pub const NUM_CLASSES: usize = 6;

    /// Class names in canonical order. Sized by `NUM_CLASSES` so the name list
    /// and the class count cannot drift apart without a compile error.
    const CLASS_NAMES: [&'static str; IntentLabels::NUM_CLASSES] = [
        "continuation",
        "topic_shift",
        "explicit_query",
        "person_lookup",
        "health_check",
        "task_status",
    ];

    /// Create labels with `continuation` set to `prob` and every other class
    /// at 0.0.
    pub fn continuation(prob: f32) -> Self {
        Self {
            continuation: prob,
            ..Default::default()
        }
    }

    /// Create labels with `topic_shift` set to `prob` and every other class
    /// at 0.0.
    pub fn topic_shift(prob: f32) -> Self {
        Self {
            topic_shift: prob,
            ..Default::default()
        }
    }

    /// Create labels with `explicit_query` set to `prob` and every other class
    /// at 0.0.
    pub fn explicit_query(prob: f32) -> Self {
        Self {
            explicit_query: prob,
            ..Default::default()
        }
    }

    /// Create labels with `person_lookup` set to `prob` and every other class
    /// at 0.0.
    pub fn person_lookup(prob: f32) -> Self {
        Self {
            person_lookup: prob,
            ..Default::default()
        }
    }

    /// Create labels with `health_check` set to `prob` and every other class
    /// at 0.0.
    pub fn health_check(prob: f32) -> Self {
        Self {
            health_check: prob,
            ..Default::default()
        }
    }

    /// Create labels with `task_status` set to `prob` and every other class
    /// at 0.0.
    pub fn task_status(prob: f32) -> Self {
        Self {
            task_status: prob,
            ..Default::default()
        }
    }

    /// Create labels from a probability slice in canonical class order.
    ///
    /// Missing trailing entries default to 0.0; entries beyond
    /// `NUM_CLASSES` are ignored.
    pub fn from_vec(probs: &[f32]) -> Self {
        let at = |idx: usize| probs.get(idx).copied().unwrap_or(0.0);
        Self {
            continuation: at(0),
            topic_shift: at(1),
            explicit_query: at(2),
            person_lookup: at(3),
            health_check: at(4),
            task_status: at(5),
        }
    }

    /// Field values in canonical class order, on the stack (no allocation).
    fn to_array(&self) -> [f32; IntentLabels::NUM_CLASSES] {
        [
            self.continuation,
            self.topic_shift,
            self.explicit_query,
            self.person_lookup,
            self.health_check,
            self.task_status,
        ]
    }

    /// Convert to a probability vector in canonical class order.
    pub fn to_vec(&self) -> Vec<f32> {
        self.to_array().to_vec()
    }

    /// Get all intent names in canonical class order.
    pub fn class_names() -> &'static [&'static str] {
        &Self::CLASS_NAMES
    }

    /// Get the dominant intent as `(class name, value)`.
    ///
    /// Values are ordered with `f32::total_cmp`. When several classes share
    /// the maximum, the one latest in canonical order is returned, because
    /// `Iterator::max_by` keeps the last of several equal maxima.
    pub fn dominant(&self) -> (&'static str, f32) {
        let (idx, prob) = self
            .to_array()
            .iter()
            .copied()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            // The array is never empty; the fallback only keeps this panic-free.
            .unwrap_or((0, 0.0));
        (Self::CLASS_NAMES[idx], prob)
    }

    /// Validate that all probabilities are in [0, 1].
    ///
    /// # Errors
    ///
    /// Returns a message naming the first class (in canonical order) whose
    /// value lies outside `[0, 1]`. NaN is rejected as well, since it is not
    /// contained in any range.
    pub fn validate(&self) -> Result<(), String> {
        for (name, p) in Self::CLASS_NAMES.iter().zip(self.to_array()) {
            if !(0.0..=1.0).contains(&p) {
                return Err(format!(
                    "Invalid probability for {}: {} (must be in [0, 1])",
                    name, p
                ));
            }
        }
        Ok(())
    }

    /// Apply softmax normalization in place, treating the current field
    /// values as logits.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` untouched, if any input is
    /// non-finite (NaN or infinite).
    pub fn softmax_normalize(&mut self) -> Result<(), String> {
        let logits = self.to_array();
        if let Some(pos) = logits.iter().position(|v| !v.is_finite()) {
            return Err(format!(
                "non-finite value {} at index {} in softmax input",
                logits[pos], pos
            ));
        }

        // Subtract the maximum before exponentiating so the largest exponent
        // is exp(0) = 1 and the sum cannot overflow.
        let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps = logits.map(|p| (p - max_val).exp());
        let exp_sum: f32 = exps.iter().sum();

        let [continuation, topic_shift, explicit_query, person_lookup, health_check, task_status] =
            exps.map(|e| e / exp_sum);
        *self = Self {
            continuation,
            topic_shift,
            explicit_query,
            person_lookup,
            health_check,
            task_status,
        };
        Ok(())
    }
}

278/// Metadata about a training example
279///
280/// Provides traceability back to the source data.
281#[derive(Debug, Clone, Default)]
282#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
283pub struct ExampleMetadata {
284    /// Source conversation or session ID
285    pub source_id: Option<String>,
286
287    /// Timestamp when the original message was created
288    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
289
290    /// Teacher model that generated the labels
291    pub teacher_model: Option<String>,
292
293    /// Teacher generation timestamp
294    pub labeled_at: Option<chrono::DateTime<chrono::Utc>>,
295
296    /// Confidence of the teacher's labels (0-1)
297    pub teacher_confidence: Option<f32>,
298
299    /// Additional metadata as key-value pairs
300    #[cfg(feature = "serde")]
301    pub extra: Option<serde_json::Value>,
302
303    /// Additional metadata as key-value pairs (non-serde fallback)
304    #[cfg(not(feature = "serde"))]
305    pub extra: Option<String>,
306}
307
308impl ExampleMetadata {
309    /// Create metadata with source ID
310    pub fn with_source(source_id: impl Into<String>) -> Self {
311        Self {
312            source_id: Some(source_id.into()),
313            ..Default::default()
314        }
315    }
316
317    /// Set the teacher model
318    pub fn teacher(mut self, model: impl Into<String>) -> Self {
319        self.teacher_model = Some(model.into());
320        self
321    }
322
323    /// Set the timestamp
324    pub fn timestamp(mut self, ts: chrono::DateTime<chrono::Utc>) -> Self {
325        self.timestamp = Some(ts);
326        self
327    }
328
329    /// Set the labeled_at timestamp
330    pub fn labeled_at(mut self, ts: chrono::DateTime<chrono::Utc>) -> Self {
331        self.labeled_at = Some(ts);
332        self
333    }
334
335    /// Set teacher confidence
336    pub fn confidence(mut self, conf: f32) -> Self {
337        self.teacher_confidence = Some(conf);
338        self
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-intent constructor sets only its own field.
    #[test]
    fn test_intent_labels_creation() {
        let labels = IntentLabels::continuation(0.8);
        assert_eq!(labels.continuation, 0.8);
        assert_eq!(labels.topic_shift, 0.0);
    }

    /// `dominant` reports the class with the highest value.
    #[test]
    fn test_intent_labels_dominant() {
        let labels = IntentLabels {
            continuation: 0.1,
            topic_shift: 0.2,
            explicit_query: 0.5,
            person_lookup: 0.1,
            health_check: 0.05,
            task_status: 0.05,
        };
        let (name, prob) = labels.dominant();
        assert_eq!(name, "explicit_query");
        assert_eq!(prob, 0.5);
    }

    /// Values outside [0, 1] are rejected by label validation.
    #[test]
    fn test_intent_labels_validation() {
        let valid = IntentLabels::continuation(0.8);
        assert!(valid.validate().is_ok());

        let invalid = IntentLabels {
            continuation: 1.5,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    /// A well-formed example reports its dimensions and validates.
    #[test]
    fn test_training_example_creation() {
        let example = TrainingExample::new(
            vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
            vec![0.7, 0.8, 0.9],
            IntentLabels::explicit_query(0.9),
        );

        assert_eq!(example.embedding_dim(), 3);
        assert_eq!(example.context_size(), 2);
        assert!(example.validate().is_ok());
    }

    /// A context embedding whose length differs from the message embedding
    /// fails validation.
    #[test]
    fn test_training_example_validation() {
        // Dimension mismatch should fail
        let example = TrainingExample::new(
            vec![vec![0.1, 0.2]], // 2D
            vec![0.7, 0.8, 0.9],  // 3D - mismatch!
            IntentLabels::default(),
        );

        assert!(example.validate().is_err());
    }

    /// Softmax output sums to ~1 and preserves the ordering of the inputs.
    #[test]
    fn test_softmax_normalize() {
        let mut labels = IntentLabels {
            continuation: 2.0,
            topic_shift: 1.0,
            explicit_query: 0.5,
            person_lookup: 0.0,
            health_check: 0.0,
            task_status: 0.0,
        };
        labels.softmax_normalize().expect("test inputs are finite");

        // Sum should be approximately 1.0
        let sum: f32 = labels.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 0.001);

        // continuation should have highest probability
        let (name, _) = labels.dominant();
        assert_eq!(name, "continuation");
    }

    /// Non-finite inputs are reported as an error instead of being normalized.
    #[test]
    fn test_softmax_normalize_rejects_non_finite() {
        let mut labels = IntentLabels::continuation(f32::NAN);
        assert!(labels.softmax_normalize().is_err());
    }
}