// metal_candle/lib.rs
1//! metal-candle: Production-quality Rust ML for Apple Silicon
2//!
3//! `metal-candle` is a machine learning library built on [Candle](https://github.com/huggingface/candle)
4//! with Metal backend, providing `LoRA` training, model loading, and text generation
5//! for transformer models on Apple Silicon.
6//!
7//! # Features
8//!
9//! - **`LoRA` Training**: Fine-tune transformer models efficiently using Low-Rank Adaptation
10//! - **Model Loading**: Support for safetensors format with extensibility for others
11//! - **Text Generation**: High-level [`Generator`] API with streaming, repetition penalty, and stop conditions
12//! - **Sampling Strategies**: Greedy, Top-k, Top-p (nucleus), and Temperature sampling
13//! - **Metal Acceleration**: Native Metal backend for optimal Apple Silicon performance
14//!
15//! # Examples
16//!
17//! ## Text Generation
18//!
19//! ```
20//! use metal_candle::inference::{GeneratorConfig, SamplingStrategy};
21//!
22//! // Configure generation
23//! let gen_config = GeneratorConfig {
24//! max_tokens: 128,
25//! sampling: SamplingStrategy::TopP { p: 0.95 },
26//! repetition_penalty: 1.1,
27//! ..Default::default()
28//! };
29//!
30//! // With a loaded model, you would use:
31//! // let mut generator = Generator::new(Box::new(model), gen_config)?;
32//! // let output = generator.generate(&input_ids)?;
33//! ```
34//!
35//! # Project Status
36//!
37//! v1.1.0: Production-ready text generation API with comprehensive testing and documentation.
38
// Lint policy for the whole crate:
// - `unsafe` is denied everywhere; individual blocks may opt back in with
//   `#[allow(unsafe_code)]` plus an explicit justification.
#![deny(unsafe_code)]
// Every public item should carry rustdoc; warn (not deny) so CI can gate on it
// without blocking local development builds.
#![warn(missing_docs)]
// Opt in to Clippy's stricter `pedantic` lint group crate-wide.
#![warn(clippy::pedantic)]
43
// Public module tree. `graph` and `embeddings` are feature-gated —
// presumably to keep default builds lean; confirm against Cargo.toml.
pub mod backend;
pub mod error;
#[cfg(feature = "graph")]
pub mod graph;
pub mod inference;
pub mod models;
pub mod training;

#[cfg(feature = "embeddings")]
pub mod embeddings;
54
// Crate-root re-exports of the most commonly used types, so downstream code
// can write e.g. `metal_candle::Generator` instead of the full module path.
pub use backend::{Device, DeviceInfo, DeviceType, TensorExt};
pub use error::{Error, Result};
pub use inference::{
    apply_repetition_penalty, sample_token, Generator, GeneratorConfig, KVCache, KVCacheConfig,
    SamplingStrategy,
};
pub use training::{
    cross_entropy_loss, cross_entropy_loss_with_smoothing, AdamW, AdamWConfig, LRScheduler,
    LoRAAdapter, LoRAAdapterConfig, LoRAConfig, LoRALayer, StepMetrics, TargetModule, Trainer,
    TrainingConfig, TrainingStep,
};
67
/// Current version of the crate, taken from `Cargo.toml` at compile time
/// via the `CARGO_PKG_VERSION` environment variable that Cargo sets.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
70
#[cfg(test)]
mod tests {
    use super::*;

    /// `VERSION` is *defined* as `env!("CARGO_PKG_VERSION")`, so asserting
    /// equality with that macro is tautological — it can never fail. Instead,
    /// validate the version string's shape: non-empty and semver-like
    /// (`MAJOR.MINOR[.PATCH...]`, each component starting with a digit).
    #[test]
    fn test_version_exists() {
        assert!(!VERSION.is_empty(), "VERSION must not be empty");

        let components: Vec<&str> = VERSION.split('.').collect();
        assert!(
            components.len() >= 2,
            "expected a semver-like version, got `{VERSION}`"
        );
        for part in components {
            // Components such as `0-rc1` are allowed; only the leading
            // character must be numeric.
            assert!(
                part.chars().next().map_or(false, char::is_ascii_digit),
                "each version component should start with a digit: `{VERSION}`"
            );
        }
    }
}