# rust-ai-core Architecture

This document describes the design decisions, architectural patterns, and extension points in rust-ai-core, the foundation layer for the rust-ai ecosystem.

## Table of Contents

- [Overview](#overview)
- [Design Principles](#design-principles)
- [Module Architecture](#module-architecture)
- [CUDA-First Device Selection](#cuda-first-device-selection)
- [Error Handling Philosophy](#error-handling-philosophy)
- [Trait Design](#trait-design)
- [CubeCL Interop Layer](#cubecl-interop-layer)
- [Extension Points](#extension-points)
- [Performance Considerations](#performance-considerations)
- [Future Framework Integration](#future-framework-integration)

## Overview

rust-ai-core provides the shared foundation for the entire rust-ai ecosystem. It establishes:

- **Unified device selection** with CUDA-first philosophy
- **Common error types** that all crates extend
- **Trait interfaces** for configuration, quantization, and GPU dispatch
- **CubeCL interop** for Candle ↔ CubeCL tensor conversion

All rust-ai crates depend on this foundation, ensuring consistent behavior across:
- peft-rs (PEFT adapters)
- qlora-rs (4-bit quantization)
- unsloth-rs (GPU kernels)
- axolotl-rs (fine-tuning orchestration)
- trit-vsa, bitnet-quantize, vsa-optim-rs, tritter-accel (specialized components)

## Design Principles

### 1. CUDA-First Philosophy

**Rationale**: Production AI workloads require GPU acceleration. CPU execution is a compatibility fallback, not a first-class citizen.

**Implementation**:
- Default behavior: attempt CUDA, fall back to CPU with warning
- Explicit opt-in for CPU via environment variable or config
- Warning emitted exactly once per process to avoid log spam
- Clear messaging to users about performance implications

**Why this matters**: Silent CPU fallbacks hide performance issues. Explicit warnings ensure users know when they're not getting optimal performance.

### 2. Unified Error Handling

**Rationale**: Consistent error types across all crates improve ergonomics and reduce boilerplate.

**Implementation**:
- `CoreError` provides common variants (shape mismatch, device errors, etc.)
- Crates extend via domain-specific error types that wrap `CoreError`
- Helper constructors (`invalid_config()`, `shape_mismatch()`, etc.) for ergonomics
- All errors implement `thiserror::Error` for good error messages

**Why this matters**: Users work across multiple crates. Consistent error handling reduces cognitive load.

### 3. Trait-Based Extensibility

**Rationale**: Common interfaces enable interoperability and composition.

**Implementation**:
- `ValidatableConfig` for configuration validation
- `Quantize` / `Dequantize` for quantization schemes
- `GpuDispatchable` for operations with GPU/CPU implementations
- All traits are `Send + Sync` for thread safety

**Why this matters**: Crates can be composed together seamlessly. New quantization schemes integrate easily.

### 4. Zero-Cost Abstractions

**Rationale**: Performance is critical for AI/ML workloads.

**Implementation**:
- Trait methods inline where possible
- No runtime overhead for device selection
- Direct pointer access for tensor conversions
- Feature gates (`cuda`) compile out unused code

**Why this matters**: Framework overhead must be negligible compared to computation costs.

## Module Architecture

```
rust-ai-core/
├── src/
│   ├── lib.rs          # Public API surface, re-exports
│   ├── device.rs       # CUDA-first device selection
│   ├── error.rs        # Unified error types
│   ├── traits.rs       # Common trait interfaces
│   └── cubecl/
│       ├── mod.rs      # CubeCL module exports
│       └── interop.rs  # Candle ↔ CubeCL conversion
└── tests/              # Integration tests
```

### Module Responsibilities

#### `device.rs`
- **Purpose**: Centralized device selection logic
- **Exports**: `DeviceConfig`, `get_device()`, `warn_if_cpu()`
- **Key features**:
  - Builder pattern for configuration
  - Environment variable support (with legacy compatibility)
  - One-time warning mechanism
  - Tracing integration for observability
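
A minimal usage sketch of this module; `DeviceConfig::default()` and the builder method names below are assumptions for illustration, not the exact API:

```rust
use rust_ai_core::{get_device, DeviceConfig};

fn main() -> rust_ai_core::Result<()> {
    // Hypothetical builder methods; consult the crate docs for exact names.
    let config = DeviceConfig::default()
        .with_cuda_device(0)    // prefer GPU 0
        .with_force_cpu(false); // keep the CUDA-first default

    let device = get_device(&config)?; // CUDA if available, else CPU + one-time warning
    println!("selected device: {device:?}");
    Ok(())
}
```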

#### `error.rs`
- **Purpose**: Common error types shared across all crates
- **Exports**: `CoreError`, `Result<T>`
- **Key features**:
  - Structured variants for common failures
  - Helper constructors for ergonomics
  - Conversion from `std::io::Error` and `candle_core::Error`
  - Non-exhaustive enum for forward compatibility

#### `traits.rs`
- **Purpose**: Common trait interfaces for interoperability
- **Exports**: `ValidatableConfig`, `Quantize`, `Dequantize`, `GpuDispatchable`
- **Key features**:
  - Generic over quantized types (`Quantize<Q>`)
  - Default implementation for `GpuDispatchable::dispatch()`
  - Runtime GPU availability checking

#### `cubecl/interop.rs`
- **Purpose**: Candle ↔ CubeCL tensor conversion
- **Exports**: `TensorBuffer`, `candle_to_cubecl_handle()`, `cubecl_to_candle_tensor()`
- **Key features**:
  - Zero-copy where possible (contiguous tensors)
  - Support for f32, f16, bf16 dtypes
  - Memory layout validation
  - Pre-allocation helpers

## CUDA-First Device Selection

### Strategy

```rust
pub fn get_device(config: &DeviceConfig) -> Result<Device> {
    if config.force_cpu {
        return warn_and_return_cpu();
    }

    match Device::cuda_if_available(config.cuda_device) {
        Ok(Device::Cuda(cuda)) => Ok(Device::Cuda(cuda)),
        Ok(Device::Cpu) | Err(_) => warn_and_return_cpu(),
        Ok(device) => Ok(device), // Metal, etc.
    }
}
```

### Decision Tree

```
Start
  │
  ├─ force_cpu set? ──Yes──> CPU + warning
  │         │
  │         No
  │         │
  ├─ CUDA available? ──Yes──> CUDA device
  │         │
  │         No
  │         │
  └────> CPU + warning
```

### Warning Strategy

Warnings are emitted **exactly once per process** using `std::sync::Once`:

```rust
use std::sync::Once;

static WARN_ONCE: Once = Once::new();

// Inside warn_if_cpu():
WARN_ONCE.call_once(|| {
    tracing::warn!("CPU device in use...");
    eprintln!("WARNING: CPU device in use...");
});
```

**Rationale**:
- Avoids log spam in hot loops
- Dual output (tracing + stderr) ensures visibility
- Clear, actionable message that explains how to silence the warning

### Environment Variables

Priority order (first match wins):

1. `RUST_AI_FORCE_CPU` / `RUST_AI_CUDA_DEVICE` (current)
2. `AXOLOTL_FORCE_CPU` / `AXOLOTL_CUDA_DEVICE` (legacy)
3. `VSA_OPTIM_FORCE_CPU` / `VSA_OPTIM_CUDA_DEVICE` (legacy)

**Rationale**: Backwards compatibility while migrating to unified variables.
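
For concreteness, a sketch of how this priority order could be resolved (the variable names are from the list above; treating `"1"`/`"true"` as truthy is an assumption):

```rust
use std::env;

// First variable that is set wins, in current -> legacy order.
fn force_cpu_from_env() -> bool {
    ["RUST_AI_FORCE_CPU", "AXOLOTL_FORCE_CPU", "VSA_OPTIM_FORCE_CPU"]
        .iter()
        .find_map(|var| env::var(var).ok())
        .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
        .unwrap_or(false)
}
```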

## Error Handling Philosophy

### Error Hierarchy

```
CoreError (shared foundation)
    ├── InvalidConfig       - Validation failures
    ├── ShapeMismatch       - Tensor shape incompatibilities
    ├── DimensionMismatch   - Dimension count mismatches
    ├── DeviceNotAvailable  - Requested device unavailable
    ├── DeviceMismatch      - Tensors on different devices
    ├── OutOfMemory         - Memory allocation failures
    ├── KernelError         - GPU kernel failures
    ├── NotImplemented      - Unimplemented features
    ├── Io                  - I/O errors
    └── Candle              - Wrapped Candle errors

Domain-Specific Errors (per crate)
    └── wraps CoreError + adds crate-specific variants
```

### Design Rationale

**Non-exhaustive enum**: Future-proofs the API. New error variants can be added without breaking existing code.

**Structured variants with fields**: Programmatic error inspection. Callers can extract shape values, device names, etc.

**Helper constructors**: Ergonomics. Compare:

```rust
// Without helpers (verbose)
Err(CoreError::InvalidConfig("rank must be positive".to_string()))

// With helpers (concise)
Err(CoreError::invalid_config("rank must be positive"))
```

**Transparent Candle errors**: Preserves full error context from underlying tensor library.

### Crate Extension Pattern

Each crate defines its own error type that wraps `CoreError`:

```rust
// In peft-rs
#[derive(Error, Debug)]
pub enum PeftError {
    #[error("adapter '{0}' not found")]
    AdapterNotFound(String),

    #[error("rank {rank} exceeds max {max}")]
    RankTooLarge { rank: usize, max: usize },

    #[error(transparent)]
    Core(#[from] CoreError),
}
```

**Benefits**:
- Domain-specific variants for clarity
- Automatic conversion from `CoreError` via `#[from]`
- Consistent error handling across crates
- Type safety at crate boundaries
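
The payoff in practice: the `?` operator lifts `CoreError` into the crate error automatically. A sketch, where `LoraConfig` and `MAX_RANK` are hypothetical names:

```rust
const MAX_RANK: usize = 256; // hypothetical limit for illustration

fn check_rank(config: &LoraConfig) -> std::result::Result<(), PeftError> {
    if config.rank > MAX_RANK {
        return Err(PeftError::RankTooLarge { rank: config.rank, max: MAX_RANK });
    }
    config.validate()?; // Err(CoreError) auto-converts via PeftError::Core
    Ok(())
}
```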

## Trait Design

### `ValidatableConfig`

**Purpose**: Standardized configuration validation.

**Design**:
```rust
pub trait ValidatableConfig: Clone + Send + Sync {
    fn validate(&self) -> Result<()>;
}
```

**Bounds**:
- `Clone`: Configs are typically small and cheap to duplicate
- `Send + Sync`: Thread-safe for concurrent training

**Usage pattern**:
```rust
impl MyStruct {
    pub fn new(config: MyConfig) -> Result<Self> {
        config.validate()?;  // Validate in constructor
        Ok(Self { config })
    }
}
```
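
A minimal implementation sketch against this trait, using a hypothetical `LoraConfig`:

```rust
use rust_ai_core::{CoreError, Result, ValidatableConfig};

// Hypothetical config type for illustration.
#[derive(Clone)]
pub struct LoraConfig {
    pub rank: usize,
}

impl ValidatableConfig for LoraConfig {
    fn validate(&self) -> Result<()> {
        if self.rank == 0 {
            return Err(CoreError::invalid_config("rank must be positive"));
        }
        Ok(())
    }
}
```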

### `Quantize<Q>` / `Dequantize<Q>`

**Purpose**: Unified interface for quantization schemes.

**Design**:
```rust
pub trait Quantize<Q>: Send + Sync {
    fn quantize(&self, tensor: &Tensor, device: &Device) -> Result<Q>;
}

pub trait Dequantize<Q>: Send + Sync {
    fn dequantize(&self, quantized: &Q, device: &Device) -> Result<Tensor>;
}
```

**Generic parameter `Q`**: Quantized type (e.g., `Nf4Tensor`, `TernaryVector`, `Int8Tensor`).

**Separate traits**: Not all quantizers support dequantization (e.g., one-way compression).

**Device parameter**: Allows in-place quantization or device transfer.

**Examples**:
- **NF4 quantization** (qlora-rs): `Q = Nf4Tensor`
- **Ternary quantization** (trit-vsa): `Q = TernaryVector`
- **Int8 quantization** (bitnet-quantize): `Q = Int8Tensor`
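
As a concrete (toy) illustration of the trait shape, here is an absmax int8 quantizer sketch; it is not the bitnet-quantize implementation, and `Int8Tensor`'s fields are assumptions:

```rust
use candle_core::{Device, Tensor};
use rust_ai_core::{Quantize, Result};

// Hypothetical quantized type for this sketch.
pub struct Int8Tensor {
    pub data: Vec<i8>,
    pub scale: f32,
    pub shape: Vec<usize>,
}

pub struct AbsmaxQuantizer;

impl Quantize<Int8Tensor> for AbsmaxQuantizer {
    fn quantize(&self, tensor: &Tensor, _device: &Device) -> Result<Int8Tensor> {
        let values = tensor.flatten_all()?.to_vec1::<f32>()?;
        // Map the largest magnitude to 127; guard against all-zero tensors.
        let absmax = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        let scale = if absmax == 0.0 { 1.0 } else { absmax / 127.0 };
        let data = values.iter().map(|v| (v / scale).round() as i8).collect();
        Ok(Int8Tensor { data, scale, shape: tensor.dims().to_vec() })
    }
}
```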

### `GpuDispatchable`

**Purpose**: Automatic GPU/CPU dispatch for operations with both implementations.

**Design**:
```rust
pub trait GpuDispatchable: Send + Sync {
    type Input;
    type Output;

    fn dispatch_gpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output>;
    fn dispatch_cpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output>;

    fn dispatch(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
        match device {
            Device::Cuda(_) => self.dispatch_gpu(input, device),
            Device::Cpu => self.dispatch_cpu(input, device),
            _ => Err(CoreError::device_not_available(format!("{device:?}"))),
        }
    }

    fn gpu_available(&self) -> bool { /* ... */ }
}
```

**Associated types**: Flexible input/output types per operation.

**Default `dispatch()` implementation**: Implementors only write `dispatch_gpu` and `dispatch_cpu`.

**CUDA-first pattern**: GPU path first, CPU as fallback.

**Usage pattern**:
```rust
// In unsloth-rs
impl GpuDispatchable for FlashAttention {
    type Input = (Tensor, Tensor, Tensor);  // Q, K, V
    type Output = Tensor;

    fn dispatch_gpu(&self, (q, k, v): &Self::Input, device: &Device) -> Result<Tensor> {
        // Launch CubeCL flash attention kernel
    }

    fn dispatch_cpu(&self, (q, k, v): &Self::Input, device: &Device) -> Result<Tensor> {
        warn_if_cpu(device, "unsloth-rs");
        // Candle fallback implementation
    }
}

// Usage
let attn = FlashAttention::new();
let output = attn.dispatch(&(q, k, v), &device)?;  // Auto-routes
```

## CubeCL Interop Layer

### Problem Statement

Candle provides high-level tensor operations, while CubeCL provides low-level GPU kernels. We need seamless conversion between them.

### Solution Architecture

```
Candle Tensor (Device::Cuda)
    ├─> candle_to_cubecl_handle()
    │       │
    │       ├─ Ensure contiguity
    │       ├─ Extract raw bytes
    │       └─> TensorBuffer { bytes, shape, dtype }
    │               │
    │               └─> client.create(&bytes) → CubeCL Handle
    ├─> CubeCL Kernel Execution
    │       │
    │       └─> CubeCL Output Handle
    │               │
    │               └─> client.read(&handle) → bytes
    │                       │
    │                       └─> TensorBuffer { bytes, shape, dtype }
    └─> cubecl_to_candle_tensor()
            └─> Candle Tensor (Device::Cuda)
```

### Memory Management

**TensorBuffer design**:
```rust
pub struct TensorBuffer {
    pub bytes: Vec<u8>,      // Raw little-endian bytes
    pub shape: Vec<usize>,   // Tensor dimensions
    pub dtype: DType,        // Element type
}
```

**Contiguity requirement**: Candle tensors must be contiguous for raw pointer access. Non-contiguous tensors are made contiguous first, which involves a copy.

**Zero-copy path**: Contiguous tensors allow direct memory access with no extra allocation.

**Supported dtypes**: f32, f16, and bf16 (the common ML dtypes). Other dtypes return an error.
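
A sketch of the f32 host-copy path (the zero-copy path and the f16/bf16 arms are omitted; the function name is illustrative, not the exported `candle_to_cubecl_handle()` itself):

```rust
use candle_core::{DType, Tensor};
use rust_ai_core::Result;

// Illustrative: make the tensor contiguous, then serialize little-endian bytes.
fn tensor_to_buffer_f32(tensor: &Tensor) -> Result<TensorBuffer> {
    let tensor = tensor.contiguous()?; // cheap if already contiguous
    let values = tensor.flatten_all()?.to_vec1::<f32>()?;
    let bytes = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    Ok(TensorBuffer {
        bytes,
        shape: tensor.dims().to_vec(),
        dtype: DType::F32,
    })
}
```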

### Validation

```rust
pub fn cubecl_to_candle_tensor(buffer: &TensorBuffer, device: &Device) -> Result<Tensor> {
    // 1. Device validation
    if !matches!(device, Device::Cuda(_)) {
        return Err(CoreError::invalid_config("requires CUDA device"));
    }

    // 2. Size validation
    let expected = buffer.numel() * buffer.dtype.size_in_bytes();
    if buffer.bytes.len() != expected {
        return Err(CoreError::shape_mismatch(...));
    }

    // 3. Reconstruct tensor
    match buffer.dtype {
        DType::F32 => { /* ... */ }
        DType::F16 => { /* ... */ }
        DType::BF16 => { /* ... */ }
        _ => Err(CoreError::invalid_config("unsupported dtype")),
    }
}
```
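
For concreteness, the `F32` reconstruction might look like this sketch (assumes little-endian bytes, per `TensorBuffer`'s contract; the helper name is illustrative):

```rust
use candle_core::{Device, Tensor};
use rust_ai_core::Result;

// Decode little-endian f32 bytes and rebuild the tensor on the target device.
fn f32_from_buffer(buffer: &TensorBuffer, device: &Device) -> Result<Tensor> {
    let values: Vec<f32> = buffer
        .bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();
    Ok(Tensor::from_vec(values, buffer.shape.clone(), device)?)
}
```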

### Performance Considerations

**Pre-allocation**: Use `allocate_output_buffer()` to avoid allocations in hot loops.

```rust
// Outside loop
let mut output_buffer = allocate_output_buffer(&output_shape, DType::F32)?;

// In loop (reuse buffer)
for batch in batches {
    let result = launch_kernel(&input, &mut output_buffer)?;
    // ... use result
}
```

**Contiguity check cost**: Negligible compared to kernel launch overhead.

## Extension Points

### 1. New Quantization Schemes

To add a new quantization scheme (e.g., FP8):

```rust
// 1. Define quantized type
pub struct Fp8Tensor {
    data: Vec<u8>,
    scale: f32,
    shape: Vec<usize>,
}

// 2. Implement Quantize trait
pub struct Fp8Quantizer;

impl Quantize<Fp8Tensor> for Fp8Quantizer {
    fn quantize(&self, tensor: &Tensor, device: &Device) -> Result<Fp8Tensor> {
        // Quantization logic
    }
}

// 3. Implement Dequantize trait
impl Dequantize<Fp8Tensor> for Fp8Quantizer {
    fn dequantize(&self, quantized: &Fp8Tensor, device: &Device) -> Result<Tensor> {
        // Dequantization logic
    }
}
```

### 2. New Device Types

To support new device types (e.g., Metal):

```rust
pub fn get_device(config: &DeviceConfig) -> Result<Device> {
    // ... existing logic ...

    // Add Metal support
    #[cfg(feature = "metal")]
    if config.prefer_metal {
        if let Ok(device) = Device::metal_if_available(0) {
            return Ok(device);
        }
    }

    // ... fallback logic ...
}
```

### 3. New Error Variants

Add to `CoreError` (non-breaking due to `#[non_exhaustive]`):

```rust
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum CoreError {
    // ... existing variants ...

    #[error("quantization error: {0}")]
    QuantizationError(String),
}

impl CoreError {
    pub fn quantization(msg: impl Into<String>) -> Self {
        Self::QuantizationError(msg.into())
    }
}
```

### 4. Custom GPU Operations

Implement `GpuDispatchable` for new operations:

```rust
pub struct MyCustomOp {
    // ... state ...
}

impl GpuDispatchable for MyCustomOp {
    type Input = MyInput;
    type Output = MyOutput;

    fn dispatch_gpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
        // CubeCL kernel path
    }

    fn dispatch_cpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
        warn_if_cpu(device, "my-crate");
        // Candle fallback
    }
}
```

## Performance Considerations

### Device Selection Overhead

**Cost**: Negligible. Device selection happens once per operation, and the warning fires at most once per process.

**Benchmark**: `get_device()` takes < 1μs on typical hardware.

### Error Handling Overhead

**Cost**: Near zero on the happy path (errors are exceptional).

**Design**: `Result<T>` adds at most a discriminant check on the success path; error payloads are constructed only on failure.

### Trait Dispatch Overhead

**Cost**: Zero when traits are used as generic bounds, thanks to monomorphization.

**Explanation**: Generic trait usage compiles to static dispatch: no vtable lookups, no runtime overhead. Only `dyn Trait` objects would add an indirection.

### CubeCL Conversion Overhead

**Cost**: Dominated by kernel launch latency (10-50μs).

**Optimization**: Pre-allocate buffers, ensure tensor contiguity before hot loops.

**Benchmark**: Conversion adds < 1% overhead for typical kernel operations.

## Future Framework Integration

rust-ai-core is designed to support a future AI framework with these goals:

### 1. Transparency

**Foundation support**:
- Tracing integration throughout (`tracing::info`, `tracing::warn`)
- Structured errors with detailed context
- Clear naming conventions

**Future additions**:
- Operation-level tracing hooks
- Computation graph visualization
- Step-by-step execution logging

### 2. Traceability

**Foundation support**:
- Device selection logged
- Error provenance tracked (source location, call stack)
- Configuration validation recorded

**Future additions**:
- Tensor provenance tracking (what operations created this tensor?)
- Execution timeline recording
- Reproducibility metadata (seeds, versions, configurations)

### 3. Performance

**Foundation support**:
- Zero-cost abstractions
- CUDA-first philosophy
- Efficient CubeCL interop

**Future additions**:
- Kernel fusion hints
- Memory pool management
- Multi-GPU coordination

### 4. Ease of Use

**Foundation support**:
- Builder patterns for configuration
- Helpful error messages
- Sensible defaults

**Future additions**:
- High-level API wrappers
- Auto-configuration (detect hardware, tune parameters)
- Progress bars and status updates

### 5. Repeatability

**Foundation support**:
- Deterministic device selection (given same config)
- No hidden state in core utilities
- Explicit configuration over magic

**Future additions**:
- Seed management
- Checkpoint/restore
- Experiment tracking integration

### 6. Customization Depth

**Foundation support**:
- Trait-based extension points
- Feature flags for optional components
- Public internal APIs where appropriate

**Future additions**:
- Custom kernel registration
- Operation graph manipulation
- Plugin system for new backends

## Summary

rust-ai-core provides a solid, performant, and extensible foundation for the rust-ai ecosystem. Its design enables:

- **Consistency**: Unified device selection, errors, and traits across all crates
- **Performance**: Zero-cost abstractions, GPU-first execution
- **Extensibility**: Trait-based interfaces for new quantization schemes, operations, and devices
- **Developer experience**: Clear errors, helpful warnings, comprehensive documentation
- **Future-proofing**: Extension points and non-exhaustive types for evolution

The architecture balances immediate practical needs with long-term framework goals, providing a stable foundation while enabling future innovation.