
trit_vsa/dispatch.rs

//! Smart kernel dispatch for optimal ternary operations.
//!
//! This module provides intelligent routing between different ternary representations
//! and kernel implementations based on operation type, data characteristics, and hardware.
//!
//! # Architecture
//!
//! This module integrates with the modular [`kernels`](crate::kernels) backend system:
//!
//! ```text
//! TritVector                   DispatchConfig
//!     |                              |
//!     v                              v
//! +----------+                +--------------+
//! | dispatch |  ----------->  | BackendConfig|
//! +----------+                +--------------+
//!     |                              |
//!     v                              v
//! +-------------------------------------------+
//! |           TernaryBackend (trait)          |
//! +-------------------------------------------+
//!     |           |              |
//!     v           v              v
//! +------+   +--------+     +------+
//! |  CPU |   | CubeCL |     | Burn |
//! +------+   +--------+     +------+
//! ```
//!
//! # Ternary Representations
//!
//! ## Tritsliced (Implied Zero with Positive/Negative Planes)
//!
//! Two parallel bit planes where:
//! - `+plane[i] = 1` indicates trit `+1`
//! - `-plane[i] = 1` indicates trit `-1`
//! - `+plane[i] = 0` and `-plane[i] = 0` together indicate trit `0` (implied zero)
//!
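//! For intuition, a dot product in this layout reduces to four popcounts per
//! word. The sketch below is a minimal, hypothetical routine over raw `u64`
//! plane words; it is not this crate's API:
//!
//! ```rust,ignore
//! // Hypothetical sketch: +1/+1 and -1/-1 pairs add one, +1/-1 and -1/+1 pairs
//! // subtract one; zeros never set a bit, so they drop out for free.
//! fn dot_sliced(a_pos: &[u64], a_neg: &[u64], b_pos: &[u64], b_neg: &[u64]) -> i32 {
//!     let mut acc = 0i32;
//!     for i in 0..a_pos.len() {
//!         let agree = (a_pos[i] & b_pos[i]).count_ones() + (a_neg[i] & b_neg[i]).count_ones();
//!         let clash = (a_pos[i] & b_neg[i]).count_ones() + (a_neg[i] & b_pos[i]).count_ones();
//!         acc += agree as i32 - clash as i32;
//!     }
//!     acc
//! }
//! ```
//!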
//! **Optimal for:**
//! - Dot products (popcount-based)
//! - Element-wise bind/unbind
//! - Bundle (majority voting)
//! - Dense vectors (< 90% zeros)
//!
//! ## Tritpacked (2-bit per trit)
//!
//! Each trit encoded as 2 bits: `00` = -1, `01` = 0, `10` = +1
//!
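//! For intuition, a minimal sketch of the 2-bit codec (hypothetical helpers; the
//! exact word layout is assumed for illustration and is not this crate's API):
//!
//! ```rust,ignore
//! // Hypothetical codec: 0b00 = -1, 0b01 = 0, 0b10 = +1 (0b11 unused).
//! // Thirty-two trits fit in one u64 word at bit offset 2 * (idx % 32).
//! fn encode(t: i8) -> u64 {
//!     match t { -1 => 0b00, 0 => 0b01, 1 => 0b10, _ => unreachable!() }
//! }
//! fn decode(bits: u64) -> i8 {
//!     match bits & 0b11 { 0b00 => -1, 0b01 => 0, 0b10 => 1, _ => unreachable!() }
//! }
//! fn set_trit(word: &mut u64, idx: usize, t: i8) {
//!     let shift = 2 * (idx % 32);
//!     *word = (*word & !(0b11u64 << shift)) | (encode(t) << shift);
//! }
//! ```
//!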
//! **Optimal for:**
//! - Sequential access patterns
//! - Serialization/deserialization
//! - Mixed arithmetic operations
//! - Memory-constrained scenarios
//!
//! ## Sparse (COO Format)
//!
//! Separate index lists for positive and negative values.
//!
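//! For intuition, a dot product here only has to intersect the index lists. A
//! minimal sketch assuming sorted indices (hypothetical helpers, not this
//! crate's `SparseVec` API):
//!
//! ```rust,ignore
//! // Hypothetical sketch: only the four list intersections can contribute.
//! fn dot_sparse(a_pos: &[u32], a_neg: &[u32], b_pos: &[u32], b_neg: &[u32]) -> i32 {
//!     // Count indices shared by two sorted lists (two-pointer merge).
//!     fn overlap(x: &[u32], y: &[u32]) -> i32 {
//!         let (mut i, mut j, mut n) = (0, 0, 0);
//!         while i < x.len() && j < y.len() {
//!             match x[i].cmp(&y[j]) {
//!                 std::cmp::Ordering::Less => i += 1,
//!                 std::cmp::Ordering::Greater => j += 1,
//!                 std::cmp::Ordering::Equal => { n += 1; i += 1; j += 1; }
//!             }
//!         }
//!         n
//!     }
//!     overlap(a_pos, b_pos) + overlap(a_neg, b_neg)
//!         - overlap(a_pos, b_neg) - overlap(a_neg, b_pos)
//! }
//! ```
//!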
//! **Optimal for:**
//! - Very sparse vectors (> 90% zeros)
//! - Similarity between sparse vectors
//! - Memory efficiency for high-dimensional sparse data
//!
//! # Dispatch Strategy
//!
//! The dispatcher selects the optimal kernel based on:
//! 1. **Sparsity**: Sparse format for > 90% zeros
//! 2. **Operation type**: Popcount ops -> tritsliced, arithmetic -> tritpacked
//! 3. **Vector size**: GPU for large (> 4096 dims), CPU for small
//! 4. **Hardware**: SIMD availability, GPU presence
//!
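//! A simplified sketch of that heuristic (illustrative only; the authoritative
//! logic lives in `select_format` below and in the `kernels` backend selection):
//!
//! ```rust,ignore
//! fn pick(op: Operation, sparsity: f32, cfg: &DispatchConfig) -> Format {
//!     // The real code also requires *both* operands to exceed the threshold.
//!     if op.benefits_from_sparse() && sparsity > cfg.sparse_threshold {
//!         return Format::Sparse;
//!     }
//!     // Otherwise fall back to the operation's preferred layout; the CPU/GPU
//!     // choice is made separately against `cfg.gpu_threshold`.
//!     op.preferred_format()
//! }
//! ```
//!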
//! # Example
//!
//! ```rust,ignore
//! use trit_vsa::dispatch::{TritVector, DispatchConfig, Format};
//!
//! // Automatic format selection
//! let a = TritVector::from_packed(packed_vec);
//! let b = TritVector::from_packed(other_vec);
//!
//! // Dispatcher chooses optimal kernel
//! let similarity = a.cosine_similarity(&b, &DispatchConfig::auto())?;
//!
//! // Force a specific format
//! let config = DispatchConfig::auto()
//!     .with_format(Format::Tritsliced)
//!     .with_gpu_threshold(8192);
//! let result = a.bind(&b, &config)?;
//! ```
//!
//! # Using the Modular Backend System
//!
//! For more control, use the [`kernels`](crate::kernels) module directly:
//!
//! ```rust,ignore
//! use trit_vsa::kernels::{get_backend, BackendConfig, TernaryBackend};
//!
//! let config = BackendConfig::auto();
//! let backend = get_backend(&config);
//!
//! let result = backend.bind(&a, &b)?;
//! let similarity = backend.dot_similarity(&a, &b)?;
//! ```

use crate::kernels::{self, BackendConfig, BackendPreference, TernaryBackend};
use crate::{PackedTritVec, Result, SparseVec, TernaryError, Trit};

/// Preferred kernel format for operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Format {
    /// Tritsliced: two bit planes (optimal for popcount operations)
    #[default]
    Tritsliced,
    /// Tritpacked: 2 bits per trit (optimal for sequential access)
    Tritpacked,
    /// Sparse: COO format (optimal for > 90% zeros)
    Sparse,
    /// Automatic selection based on data characteristics
    Auto,
}

/// Device preference for computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DevicePreference {
    /// Automatic GPU/CPU selection based on size
    #[default]
    Auto,
    /// Force CPU execution
    Cpu,
    /// Force GPU execution (requires `cuda` feature)
    Gpu,
}

/// Configuration for kernel dispatch.
#[derive(Debug, Clone)]
pub struct DispatchConfig {
    /// Preferred format for operations
    pub format: Format,
    /// Device preference
    pub device: DevicePreference,
    /// Sparsity threshold for automatic sparse selection (default: 0.90)
    pub sparse_threshold: f32,
    /// Minimum dimensions for GPU dispatch (default: 4096)
    pub gpu_threshold: usize,
    /// Enable format caching for repeated operations
    pub cache_conversions: bool,
}

impl Default for DispatchConfig {
    fn default() -> Self {
        Self::auto()
    }
}

impl DispatchConfig {
    /// Create a new configuration with automatic settings.
    #[must_use]
    pub fn auto() -> Self {
        Self {
            format: Format::Auto,
            device: DevicePreference::Auto,
            sparse_threshold: 0.90,
            gpu_threshold: 4096,
            cache_conversions: true,
        }
    }

    /// Create a CPU-only configuration.
    #[must_use]
    pub fn cpu_only() -> Self {
        Self {
            device: DevicePreference::Cpu,
            ..Self::auto()
        }
    }

    /// Set preferred format.
    #[must_use]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set device preference.
    #[must_use]
    pub fn with_device(mut self, device: DevicePreference) -> Self {
        self.device = device;
        self
    }

    /// Set sparsity threshold for automatic sparse format selection.
    #[must_use]
    pub fn with_sparse_threshold(mut self, threshold: f32) -> Self {
        self.sparse_threshold = threshold;
        self
    }

    /// Set minimum dimensions for GPU dispatch.
    #[must_use]
    pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
        self.gpu_threshold = threshold;
        self
    }
}

/// Operation types for dispatch decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operation {
    /// Dot product (popcount-optimal)
    Dot,
    /// Cosine similarity
    Similarity,
    /// Bind operation (XOR-like composition)
    Bind,
    /// Unbind operation (inverse of bind)
    Unbind,
    /// Bundle (majority voting)
    Bundle,
    /// Element-wise negation
    Negate,
    /// Hamming distance
    Hamming,
}

impl Operation {
    /// Returns the preferred format for this operation.
    #[must_use]
    pub fn preferred_format(self) -> Format {
        match self {
            // Popcount-based operations prefer tritsliced
            Operation::Dot | Operation::Similarity | Operation::Hamming => Format::Tritsliced,
            // Element-wise operations work well with tritsliced
            Operation::Bind | Operation::Unbind | Operation::Negate => Format::Tritsliced,
            // Bundle needs to track counts, tritsliced is fine
            Operation::Bundle => Format::Tritsliced,
        }
    }

    /// Returns true if this operation can benefit from sparse representation.
    #[must_use]
    pub fn benefits_from_sparse(self) -> bool {
        matches!(self, Operation::Dot | Operation::Similarity)
    }
}

/// Unified ternary vector type with smart dispatch.
#[derive(Debug, Clone)]
pub enum TritVector {
    /// Tritsliced format ([`PackedTritVec`])
    Sliced(PackedTritVec),
    /// Sparse format ([`SparseVec`])
    Sparse(SparseVec),
}

impl TritVector {
    /// Create a new zero vector in tritsliced format.
    #[must_use]
    pub fn new(dims: usize) -> Self {
        Self::Sliced(PackedTritVec::new(dims))
    }

    /// Create from a `PackedTritVec`.
    #[must_use]
    pub fn from_packed(packed: PackedTritVec) -> Self {
        Self::Sliced(packed)
    }

    /// Create from a `SparseVec`.
    #[must_use]
    pub fn from_sparse(sparse: SparseVec) -> Self {
        Self::Sparse(sparse)
    }

    /// Get the number of dimensions.
    #[must_use]
    pub fn dims(&self) -> usize {
        match self {
            Self::Sliced(p) => p.len(),
            Self::Sparse(s) => s.num_dims(),
        }
    }

    /// Compute sparsity (fraction of zeros).
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        match self {
            Self::Sliced(p) => p.sparsity(),
            Self::Sparse(s) => s.sparsity(),
        }
    }

    /// Get value at index.
    #[must_use]
    pub fn get(&self, idx: usize) -> Trit {
        match self {
            Self::Sliced(p) => p.get(idx),
            Self::Sparse(s) => s.get(idx),
        }
    }

    /// Set value at index.
    pub fn set(&mut self, idx: usize, value: Trit) {
        match self {
            Self::Sliced(p) => p.set(idx, value),
            Self::Sparse(s) => s.set(idx, value),
        }
    }

    /// Convert to `PackedTritVec`.
    #[must_use]
    pub fn to_packed(&self) -> PackedTritVec {
        match self {
            Self::Sliced(p) => p.clone(),
            Self::Sparse(s) => s.to_packed(),
        }
    }

    /// Convert to `SparseVec`.
    #[must_use]
    pub fn to_sparse(&self) -> SparseVec {
        match self {
            Self::Sliced(p) => SparseVec::from_packed(p),
            Self::Sparse(s) => s.clone(),
        }
    }

    /// Select optimal format based on operation and data characteristics.
    fn select_format(
        &self,
        other: Option<&Self>,
        op: Operation,
        config: &DispatchConfig,
    ) -> Format {
        // Explicit format preference overrides auto-selection
        if config.format != Format::Auto {
            return config.format;
        }

        // Check if sparse format would be beneficial
        let self_sparse = self.sparsity() > config.sparse_threshold;
        let other_sparse = other.is_some_and(|o| o.sparsity() > config.sparse_threshold);

        if op.benefits_from_sparse() && self_sparse && other_sparse {
            return Format::Sparse;
        }

        // Default to operation's preferred format
        op.preferred_format()
    }

    /// Convert [`DispatchConfig`] to [`BackendConfig`] for the kernels module.
    fn to_backend_config(config: &DispatchConfig) -> BackendConfig {
        let preferred = match config.device {
            DevicePreference::Cpu => BackendPreference::Cpu,
            DevicePreference::Gpu => BackendPreference::Gpu,
            DevicePreference::Auto => BackendPreference::Auto,
        };

        BackendConfig {
            preferred,
            gpu_threshold: config.gpu_threshold,
            use_simd: true,
        }
    }

    /// Get the appropriate backend based on config and problem size.
    fn get_backend_for_config(&self, config: &DispatchConfig) -> kernels::DynamicBackend {
        let backend_config = Self::to_backend_config(config);
        kernels::get_backend_for_size(&backend_config, self.dims())
    }

    /// Compute dot product with smart dispatch.
    ///
    /// Uses the modular backend system for automatic CPU/GPU selection.
    ///
    /// # Errors
    ///
    /// Returns error if vectors have mismatched dimensions.
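    ///
    /// # Example
    ///
    /// A minimal sketch of the CPU path (illustrative values; error handling
    /// elided with `?`):
    ///
    /// ```rust,ignore
    /// let mut a = TritVector::new(8);
    /// let mut b = TritVector::new(8);
    /// a.set(0, Trit::P);
    /// b.set(0, Trit::P); // matching +1/+1 contributes +1
    /// a.set(1, Trit::P);
    /// b.set(1, Trit::N); // opposing +1/-1 contributes -1
    /// let d = a.dot(&b, &DispatchConfig::cpu_only())?;
    /// assert_eq!(d, 0);
    /// ```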
    pub fn dot(&self, other: &Self, config: &DispatchConfig) -> Result<i32> {
        if self.dims() != other.dims() {
            return Err(TernaryError::DimensionMismatch {
                expected: self.dims(),
                actual: other.dims(),
            });
        }

        let format = self.select_format(Some(other), Operation::Dot, config);

        match format {
            Format::Sparse => {
                // Sparse path: use direct sparse dot product
                let a = self.to_sparse();
                let b = other.to_sparse();
                Ok(a.dot(&b))
            }
            Format::Tritsliced | Format::Tritpacked | Format::Auto => {
                // Use modular backend system
                let a = self.to_packed();
                let b = other.to_packed();
                let backend = self.get_backend_for_config(config);
                backend.dot_similarity(&a, &b)
            }
        }
    }

    /// Compute cosine similarity with smart dispatch.
    ///
    /// Uses the modular backend system for automatic CPU/GPU selection.
    ///
    /// # Errors
    ///
    /// Returns error if vectors have mismatched dimensions.
    pub fn cosine_similarity(&self, other: &Self, config: &DispatchConfig) -> Result<f32> {
        if self.dims() != other.dims() {
            return Err(TernaryError::DimensionMismatch {
                expected: self.dims(),
                actual: other.dims(),
            });
        }

        let format = self.select_format(Some(other), Operation::Similarity, config);

        match format {
            Format::Sparse => {
                // Sparse path: use direct sparse cosine similarity
                let a = self.to_sparse();
                let b = other.to_sparse();
                Ok(crate::vsa::cosine_similarity_sparse(&a, &b))
            }
            Format::Tritsliced | Format::Tritpacked | Format::Auto => {
                // Use modular backend system
                let a = self.to_packed();
                let b = other.to_packed();
                let backend = self.get_backend_for_config(config);
                backend.cosine_similarity(&a, &b)
            }
        }
    }

    /// Bind operation with smart dispatch.
    ///
    /// Uses the modular backend system for automatic CPU/GPU selection.
    ///
    /// # Errors
    ///
    /// Returns error if vectors have mismatched dimensions.
    pub fn bind(&self, other: &Self, config: &DispatchConfig) -> Result<Self> {
        if self.dims() != other.dims() {
            return Err(TernaryError::DimensionMismatch {
                expected: self.dims(),
                actual: other.dims(),
            });
        }

        let a = self.to_packed();
        let b = other.to_packed();

        // Use modular backend system
        let backend = self.get_backend_for_config(config);
        let result = backend.bind(&a, &b)?;
        Ok(Self::Sliced(result))
    }

    /// Unbind operation with smart dispatch.
    ///
    /// Uses the modular backend system for automatic CPU/GPU selection.
    ///
    /// # Errors
    ///
    /// Returns error if vectors have mismatched dimensions.
    pub fn unbind(&self, other: &Self, config: &DispatchConfig) -> Result<Self> {
        if self.dims() != other.dims() {
            return Err(TernaryError::DimensionMismatch {
                expected: self.dims(),
                actual: other.dims(),
            });
        }

        let a = self.to_packed();
        let b = other.to_packed();

        // Use modular backend system
        let backend = self.get_backend_for_config(config);
        let result = backend.unbind(&a, &b)?;
        Ok(Self::Sliced(result))
    }

    /// Bundle (majority voting) with smart dispatch.
    ///
    /// Uses the modular backend system for automatic CPU/GPU selection.
    ///
    /// # Errors
    ///
    /// Returns error if vectors have mismatched dimensions.
    pub fn bundle(&self, other: &Self, config: &DispatchConfig) -> Result<Self> {
        if self.dims() != other.dims() {
            return Err(TernaryError::DimensionMismatch {
                expected: self.dims(),
                actual: other.dims(),
            });
        }

        let a = self.to_packed();
        let b = other.to_packed();

        // Use modular backend system
        let backend = self.get_backend_for_config(config);
        let result = backend.bundle(&[&a, &b])?;
        Ok(Self::Sliced(result))
    }

    /// Compute Hamming distance with smart dispatch.
    ///
    /// Uses the modular backend system for automatic CPU/GPU selection.
    ///
    /// # Errors
    ///
    /// Returns error if vectors have mismatched dimensions.
    pub fn hamming_distance(&self, other: &Self, config: &DispatchConfig) -> Result<usize> {
        if self.dims() != other.dims() {
            return Err(TernaryError::DimensionMismatch {
                expected: self.dims(),
                actual: other.dims(),
            });
        }

        let a = self.to_packed();
        let b = other.to_packed();

        // Use modular backend system
        let backend = self.get_backend_for_config(config);
        backend.hamming_distance(&a, &b)
    }

    /// Negate all elements.
    #[must_use]
    pub fn negate(&self) -> Self {
        match self {
            Self::Sliced(p) => Self::Sliced(p.negated()),
            Self::Sparse(s) => Self::Sparse(s.negated()),
        }
    }
}

impl From<PackedTritVec> for TritVector {
    fn from(packed: PackedTritVec) -> Self {
        Self::Sliced(packed)
    }
}

impl From<SparseVec> for TritVector {
    fn from(sparse: SparseVec) -> Self {
        Self::Sparse(sparse)
    }
}

/// Statistics about dispatch decisions for profiling.
#[derive(Debug, Default, Clone)]
pub struct DispatchStats {
    /// Number of times tritsliced format was used
    pub tritsliced_count: usize,
    /// Number of times sparse format was used
    pub sparse_count: usize,
    /// Number of GPU dispatches
    pub gpu_count: usize,
    /// Number of CPU dispatches
    pub cpu_count: usize,
    /// Number of format conversions
    pub conversion_count: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_vector(values: &[i8]) -> TritVector {
        let mut packed = PackedTritVec::new(values.len());
        for (i, &v) in values.iter().enumerate() {
            let trit = match v {
                -1 => Trit::N,
                0 => Trit::Z,
                1 => Trit::P,
                _ => panic!("Invalid trit value"),
            };
            packed.set(i, trit);
        }
        TritVector::Sliced(packed)
    }

    #[test]
    fn test_dispatch_config_default() {
        let config = DispatchConfig::auto();
        assert_eq!(config.format, Format::Auto);
        assert_eq!(config.device, DevicePreference::Auto);
        assert!((config.sparse_threshold - 0.90).abs() < f32::EPSILON);
        assert_eq!(config.gpu_threshold, 4096);
    }

    #[test]
    fn test_trit_vector_from_packed() {
        let packed = PackedTritVec::new(100);
        let tv = TritVector::from_packed(packed.clone());
        assert_eq!(tv.dims(), 100);
        assert!(matches!(tv, TritVector::Sliced(_)));
    }

    #[test]
    fn test_operation_preferred_format() {
        assert_eq!(Operation::Dot.preferred_format(), Format::Tritsliced);
        assert_eq!(Operation::Similarity.preferred_format(), Format::Tritsliced);
        assert_eq!(Operation::Bind.preferred_format(), Format::Tritsliced);
    }

    #[test]
    fn test_dot_product_dispatch() {
        let a = TritVector::new(100);
        let b = TritVector::new(100);
        let config = DispatchConfig::cpu_only();

        let result = a.dot(&b, &config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);
    }

    #[test]
    fn test_dimension_mismatch() {
        let a = TritVector::new(100);
        let b = TritVector::new(200);
        let config = DispatchConfig::cpu_only();

        let result = a.dot(&b, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_bind_unbind_with_backend() {
        let a = make_test_vector(&[1, -1, 0, 1, -1, 0, 1]);
        let b = make_test_vector(&[-1, 1, 0, -1, 1, 0, -1]);
        let config = DispatchConfig::cpu_only();

        let bound = a.bind(&b, &config).unwrap();
        let recovered = bound.unbind(&b, &config).unwrap();

        // Verify bind/unbind inverse property
        for i in 0..a.dims() {
            assert_eq!(recovered.get(i), a.get(i), "mismatch at position {i}");
        }
    }

    #[test]
    fn test_bundle_with_backend() {
        let a = make_test_vector(&[1, 1, -1, 0, 0]);
        let b = make_test_vector(&[1, -1, -1, 1, -1]);
        let config = DispatchConfig::cpu_only();

        let bundled = a.bundle(&b, &config).unwrap();

        // Position 0: 1, 1 -> 1
        assert_eq!(bundled.get(0), Trit::P);
        // Position 1: 1, -1 -> 0 (tie)
        assert_eq!(bundled.get(1), Trit::Z);
        // Position 2: -1, -1 -> -1
        assert_eq!(bundled.get(2), Trit::N);
    }

    #[test]
    fn test_cosine_similarity_with_backend() {
        let a = make_test_vector(&[1, 1, -1, -1]);
        let config = DispatchConfig::cpu_only();

        let sim = a.cosine_similarity(&a, &config).unwrap();
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_hamming_distance_with_backend() {
        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, -1, 0]);
        let config = DispatchConfig::cpu_only();

        let dist = a.hamming_distance(&b, &config).unwrap();
        // Positions 1 and 3 differ
        assert_eq!(dist, 2);
    }

    #[test]
    fn test_backend_config_conversion() {
        let config = DispatchConfig::cpu_only();
        let backend_config = TritVector::to_backend_config(&config);
        assert_eq!(backend_config.preferred, BackendPreference::Cpu);

        let config = DispatchConfig::auto().with_device(DevicePreference::Gpu);
        let backend_config = TritVector::to_backend_config(&config);
        assert_eq!(backend_config.preferred, BackendPreference::Gpu);
    }

    #[test]
    fn test_auto_backend_selection() {
        let small_vec = TritVector::new(100);
        let large_vec = TritVector::new(10000);

        let config = DispatchConfig::auto().with_gpu_threshold(5000);

        // Small vector should use CPU
        let backend = small_vec.get_backend_for_config(&config);
        assert!(backend.name().starts_with("cpu"));

        // Large vector would use GPU if available, otherwise CPU
        let backend = large_vec.get_backend_for_config(&config);
        assert!(backend.is_available());
710    }
711}