Skip to main content

lattice_tune/
lib.rs

1#![allow(unused_imports)]
2#![allow(clippy::needless_borrows_for_generic_args)]
3#![allow(clippy::field_reassign_with_default)]
4
5//! lattice-tune - Training infrastructure for Lattice neural models
6//!
7//! Provides a complete pipeline for training neural networks through knowledge distillation:
8//!
9//! - **Data**: Training examples, datasets, and batching
10//! - **Distill**: Knowledge distillation from teacher models (Claude, GPT, Gemini)
11//! - **Train**: Training loop, optimization, and checkpointing
//! - **Registry**: Model versioning, storage, and deployment tracking
//! - **LoRA**: Low-rank adapter types (`LoraAdapter`, `LoraConfig`, `LoraLayer`)
13//!
14//! # Architecture
15//!
16//! ```text
17//! Raw Data → Teacher (LLM) → Soft Labels → Dataset → Training → Model → Registry
18//!                                                        ↓
19//!                                                   Deployment
20//! ```
21//!
22//! # Quick Start
23//!
24//! ```rust
25//! use lattice_tune::data::{TrainingExample, IntentLabels, Dataset, DatasetConfig};
26//!
27//! // Create training examples
28//! let examples = vec![
29//!     TrainingExample::new(
30//!         vec![vec![0.1, 0.2, 0.3]],  // context embeddings
31//!         vec![0.4, 0.5, 0.6],        // message embedding
32//!         IntentLabels::continuation(0.8),
33//!     ),
34//! ];
35//!
36//! // Create a dataset
37//! let dataset = Dataset::from_examples(examples);
38//! let stats = dataset.stats();
39//! println!("Dataset has {} examples", stats.num_examples);
40//! ```
41//!
42//! # Distillation Example
43//!
44//! ```ignore
45//! use lattice_tune::distill::{TeacherConfig, DistillationPipeline, RawExample};
46//!
47//! // Configure teacher model
48//! let teacher = TeacherConfig::claude_sonnet();
49//!
50//! // Create distillation pipeline
51//! let mut pipeline = DistillationPipeline::with_teacher(teacher)?;
52//!
53//! // Create raw examples (text, not embeddings)
54//! let raw = RawExample::new(
55//!     vec!["Hello".to_string(), "How are you?".to_string()],
56//!     "What's the weather like?",
57//! );
58//!
59//! // Label with teacher
60//! let result = pipeline.label_single(&raw)?;
61//! println!("Labeled with confidence: {}", result.confidence);
62//! ```
63//!
64//! # Training Example
65//!
66//! ```ignore
67//! use lattice_tune::train::{TrainingConfig, TrainingLoop};
68//! use lattice_tune::data::Dataset;
69//!
70//! // Configure training
71//! let config = TrainingConfig::default()
72//!     .epochs(100)
73//!     .batch_size(32)
74//!     .learning_rate(0.001);
75//!
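//! // `examples` is a `Vec<TrainingExample>`, built as in Quick Start
//! let mut dataset = Dataset::from_examples(examples);
//!
//! // Train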
77//! let mut trainer = TrainingLoop::new(config)?;
78//! let metrics = trainer.train(&mut dataset)?;
79//!
80//! println!("Final loss: {:.4}", metrics.final_train_loss);
81//! ```
82//!
83//! # Registry Example
84//!
85//! ```rust
86//! use lattice_tune::registry::{ModelRegistry, RegisteredModel, ModelMetadata};
87//!
88//! // Create a registry
89//! let registry = ModelRegistry::in_memory();
90//!
91//! // Register a model
92//! let metadata = ModelMetadata::classifier(768, 6, 10000);
93//! let model = RegisteredModel::new("intent_classifier", "1.0.0")
94//!     .with_metadata(metadata)
95//!     .with_description("Intent classification model");
96//!
97//! let weights = vec![0u8; 1000]; // Model weights
98//! let id = registry.register(model, &weights).unwrap();
99//!
100//! // Retrieve the model
101//! let loaded = registry.get("intent_classifier", "1.0.0").unwrap();
102//! println!("Loaded: {}", loaded.full_name());
103//! ```
104//!
105//! # Design Principles
106//!
107//! 1. **Data-first**: Well-defined training example format with full traceability
108//! 2. **Modular**: Distillation, training, and registry are separate concerns
109//! 3. **Extensible**: Support different teacher models (Claude, GPT, Gemini)
110//! 4. **Traceable**: All models have version, training config, and metrics
111//!
112//! # Feature Flags
113//!
114//! - `std` (default): Standard library support
//! - `serde`: Serialization support for all types
//! - `gpu`: GPU training support (exports `GpuTrainer` and `GpuTrainerBuilder`)
116
117#![warn(missing_docs)]
118
119pub mod data;
120pub mod distill;
121pub mod error;
122pub mod lora;
123pub mod registry;
124pub mod train;
125
126// Re-exports for convenience
127pub use error::{Result, TuneError};
128
129// Data re-exports
130pub use data::{
131    Batch, Dataset, DatasetConfig, DatasetStats, ExampleMetadata, IntentLabels, TrainingExample,
132};
133
134// Distill re-exports
135pub use distill::{
136    DistillationConfig, DistillationPipeline, DistillationStats, EndpointSecurity, LabelingResult,
137    TeacherConfig, TeacherConfigBuilder, TeacherProvider,
138};
139
140// Train re-exports
141pub use train::{
142    Checkpoint, EarlyStopping, EpochMetrics, JitAdapter, JitConfig, JitResult, JitStrategy,
143    LRSchedule, LoggingCallback, NoOpCallback, Optimizer, OptimizerConfig, RegularizationConfig,
144    TrainingCallback, TrainingConfig, TrainingLoop, TrainingMetrics, TrainingState, freeze,
145};
146
147// GPU training re-exports (when feature enabled)
148#[cfg(feature = "gpu")]
149pub use train::{GpuTrainer, GpuTrainerBuilder};
150
151// LoRA re-exports
152pub use lora::{LoraAdapter, LoraConfig, LoraLayer};
153
154// Registry re-exports
155pub use registry::{
156    LiveModel, ModelMetadata, ModelQuery, ModelRegistry, ModelStatus, RegisteredModel,
157    RollbackController, RollbackRecord, ShadowComparison, ShadowConfig, ShadowSession, ShadowState,
158    StorageBackend,
159};
160
161/// Prelude module for common imports
162pub mod prelude {
163    pub use crate::data::{
164        Batch, Dataset, DatasetConfig, DatasetStats, ExampleMetadata, IntentLabels, TrainingExample,
165    };
166    pub use crate::distill::{
167        DistillationConfig, DistillationPipeline, DistillationStats, EndpointSecurity,
168        LabelingResult, TeacherConfig, TeacherProvider,
169    };
170    pub use crate::error::{Result, TuneError};
171    pub use crate::lora::{LoraAdapter, LoraConfig, LoraLayer};
172    pub use crate::registry::{
173        LiveModel, ModelMetadata, ModelRegistry, ModelStatus, RegisteredModel, RollbackController,
174        RollbackRecord, ShadowComparison, ShadowConfig, ShadowSession, ShadowState, StorageBackend,
175    };
176    pub use crate::train::{
177        Checkpoint, EarlyStopping, EpochMetrics, JitAdapter, JitConfig, JitResult, JitStrategy,
178        LRSchedule, Optimizer, OptimizerConfig, RegularizationConfig, TrainingCallback,
179        TrainingConfig, TrainingLoop, TrainingMetrics, TrainingState, freeze,
180    };
181
182    #[cfg(feature = "gpu")]
183    pub use crate::train::{GpuTrainer, GpuTrainerBuilder};
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives the whole pipeline: synthetic examples -> dataset -> training
    /// loop -> metadata -> in-memory registry -> lookup.
    #[test]
    fn test_end_to_end_workflow() {
        // 100 synthetic examples. Each has three identical 3-dim context vectors
        // and one 3-dim message vector; labels cycle through the six intents.
        let examples: Vec<TrainingExample> = (0..100)
            .map(|idx| {
                let intent = match idx % 6 {
                    0 => IntentLabels::continuation(0.8),
                    1 => IntentLabels::topic_shift(0.7),
                    2 => IntentLabels::explicit_query(0.9),
                    3 => IntentLabels::person_lookup(0.85),
                    4 => IntentLabels::health_check(0.75),
                    _ => IntentLabels::task_status(0.8),
                };
                let context = vec![vec![0.1, 0.2, 0.3]; 3];
                let message = vec![0.4, 0.5, 0.6];
                TrainingExample::new(context, message, intent)
            })
            .collect();

        // Wrap the examples in a shuffled, seeded dataset with batches of 16.
        let mut dataset = Dataset::from_examples(examples);
        dataset
            .set_config(DatasetConfig::with_batch_size(16).shuffle(true).seed(42))
            .unwrap();

        let dataset_stats = dataset.stats();
        assert_eq!(dataset_stats.num_examples, 100);
        assert_eq!(dataset_stats.embedding_dim, 3);

        // The quick preset must validate before it is handed to the trainer.
        let quick_config = TrainingConfig::quick();
        assert!(quick_config.validate().is_ok());

        let mut trainer = TrainingLoop::new(quick_config).unwrap();

        // The comment in the original test notes that `train` is a placeholder implementation.
        let run_metrics = trainer.train(&mut dataset).unwrap();
        assert!(run_metrics.epochs_completed > 0);

        // Describe the trained model, then store it alongside dummy weights.
        let classifier_meta = ModelMetadata::classifier(3, 6, 1000)
            .dataset("test_dataset", 100)
            .training_metrics(run_metrics);

        let candidate = RegisteredModel::new("intent_classifier", "0.1.0")
            .with_metadata(classifier_meta)
            .with_description("Test model from end-to-end workflow");

        let registry = ModelRegistry::in_memory();
        let weight_bytes = vec![0u8; 1000];
        let model_id = registry.register(candidate, &weight_bytes).unwrap();

        // Round-trip: the stored entry keeps its id and dataset size.
        let fetched = registry.get("intent_classifier", "0.1.0").unwrap();
        assert_eq!(fetched.id, model_id);
        assert_eq!(fetched.metadata.num_training_examples, 100);
    }

    /// Checks prompt rendering, teacher validation and a single labeling
    /// call through the distillation pipeline.
    #[test]
    fn test_distillation_workflow() {
        let raw = distill::RawExample::new(
            vec!["Hello".to_string(), "How are you?".to_string()],
            "What's the weather like?",
        );

        // The rendered prompt must carry both the context header and the message text.
        let prompt = raw.to_prompt();
        for needle in ["Context", "weather"] {
            assert!(prompt.contains(needle));
        }

        let teacher = TeacherConfig::claude_sonnet();
        assert!(teacher.validate().is_ok());

        let mut pipeline = DistillationPipeline::with_teacher(teacher).unwrap();

        // The comment in the original test notes that labeling is a placeholder.
        let outcome = pipeline.label_single(&raw).unwrap();
        assert!(outcome.is_success());
        assert!(outcome.confidence > 0.0);

        // Exactly one successful labeling should have been recorded.
        assert_eq!(pipeline.stats().successful, 1);
    }
}