candle_coreml/
utils.rs

1//! Shared utilities for transformer models and multi-component architectures
2
3use candle_core::{Device, Error as CandleError, Tensor};
4use std::collections::HashMap;
5
6/// Utilities for creating attention masks used in transformer models
7pub mod mask {
8    use super::*;
9
10    /// Create a causal attention mask for autoregressive generation
11    ///
12    /// This mask prevents tokens from attending to future positions in the sequence.
13    ///
14    /// # Arguments
15    /// * `seq_len` - Sequence length for the mask
16    /// * `device` - Device to create the tensor on
17    ///
18    /// # Returns
19    /// Causal mask tensor with shape `(seq_len, seq_len)` where upper triangle is -inf
20    pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor, CandleError> {
21        let mut mask_data = vec![0.0f32; seq_len * seq_len];
22
23        // Fill upper triangle with -inf for causal masking
24        for i in 0..seq_len {
25            for j in (i + 1)..seq_len {
26                mask_data[i * seq_len + j] = f32::NEG_INFINITY;
27            }
28        }
29
30        Tensor::from_vec(mask_data, (seq_len, seq_len), device)
31    }
32
33    /// Create a causal mask for a specific position in the sequence
34    ///
35    /// This creates a mask row that allows attention to all previous positions
36    /// up to and including the current position.
37    ///
38    /// # Arguments
39    /// * `pos` - Current position in the sequence
40    /// * `context_len` - Total context length
41    /// * `device` - Device to create the tensor on
42    ///
43    /// # Returns  
44    /// Position mask tensor with shape `(1, context_len)`
45    pub fn create_position_mask(
46        pos: usize,
47        context_len: usize,
48        device: &Device,
49    ) -> Result<Tensor, CandleError> {
50        let mut mask_data = vec![f32::NEG_INFINITY; context_len];
51
52        // Allow attention to all positions up to and including current position
53        for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
54            *item = 0.0;
55        }
56
57        Tensor::from_vec(mask_data, (1, context_len), device)
58    }
59
60    /// Create a rank-4 causal mask for CoreML models that expect specific shapes
61    ///
62    /// Some CoreML models require masks with rank-4 shapes like `(1, 1, 1, seq_len)`
63    ///
64    /// # Arguments
65    /// * `pos` - Current position in the sequence  
66    /// * `context_len` - Total context length
67    /// * `device` - Device to create the tensor on
68    ///
69    /// # Returns
70    /// Rank-4 position mask tensor with shape `(1, 1, 1, context_len)`
71    pub fn create_rank4_position_mask(
72        pos: usize,
73        context_len: usize,
74        device: &Device,
75    ) -> Result<Tensor, CandleError> {
76        let mut mask_data = vec![f32::NEG_INFINITY; context_len];
77
78        // Allow attention to all positions up to and including current position
79        for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
80            *item = 0.0;
81        }
82
83        Tensor::from_vec(mask_data, (1, 1, 1, context_len), device)
84    }
85
86    /// Create an update mask for stateful models indicating which position to update
87    ///
88    /// # Arguments
89    /// * `pos` - Position to update
90    /// * `context_len` - Total context length  
91    /// * `device` - Device to create the tensor on
92    ///
93    /// # Returns
94    /// Update mask with 1.0 at the target position, 0.0 elsewhere
95    pub fn create_update_mask(
96        pos: usize,
97        context_len: usize,
98        device: &Device,
99    ) -> Result<Tensor, CandleError> {
100        let mut mask_data = vec![0.0f32; context_len];
101        if pos < context_len {
102            mask_data[pos] = 1.0;
103        }
104
105        Tensor::from_vec(mask_data, (1, 1, context_len, 1), device)
106    }
107}
108
109/// Utilities for sampling from model outputs
110pub mod sampling {
111    use super::*;
112    use rand::Rng;
113
114    /// Sample a token using temperature scaling
115    ///
116    /// Temperature controls randomness:
117    /// - temperature = 0.0: Greedy sampling (most likely token)
118    /// - temperature = 1.0: Standard sampling  
119    /// - temperature > 1.0: More random
120    /// - temperature < 1.0: More deterministic
121    ///
122    /// # Arguments
123    /// * `logits` - Model output logits tensor
124    /// * `temperature` - Temperature for scaling
125    ///
126    /// # Returns
127    /// Sampled token ID
128    pub fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<i64, CandleError> {
129        if temperature <= 0.0 {
130            // Greedy sampling - return most likely token
131            return greedy_sample(logits);
132        }
133
134        // Apply temperature scaling
135        let temp_tensor = Tensor::new(&[temperature], logits.device())?;
136        let scaled_logits = logits.broadcast_div(&temp_tensor)?;
137
138        // Convert to probabilities via softmax
139        let probs = candle_nn::ops::softmax_last_dim(&scaled_logits)?;
140        let probs_vec = probs.to_vec1::<f32>()?;
141
142        // Sample from the distribution
143        let mut rng = rand::thread_rng();
144        let random_val: f32 = rng.gen();
145
146        let mut cumulative = 0.0;
147        for (i, &prob) in probs_vec.iter().enumerate() {
148            cumulative += prob;
149            if random_val <= cumulative {
150                return Ok(i as i64);
151            }
152        }
153
154        // Fallback to last token if numerical issues
155        Ok((probs_vec.len() - 1) as i64)
156    }
157
158    /// Greedy sampling - always return the most likely token
159    pub fn greedy_sample(logits: &Tensor) -> Result<i64, CandleError> {
160        let logits_vec = logits.to_vec1::<f32>()?;
161        let max_idx = logits_vec
162            .iter()
163            .enumerate()
164            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
165            .map(|(idx, _)| idx)
166            .unwrap_or(0);
167        Ok(max_idx as i64)
168    }
169
170    /// Top-k sampling - sample from the k most likely tokens
171    pub fn sample_top_k(logits: &Tensor, k: usize, temperature: f32) -> Result<i64, CandleError> {
172        let logits_vec = logits.to_vec1::<f32>()?;
173
174        // Get indices sorted by logit value (descending)
175        let mut indexed_logits: Vec<(usize, f32)> = logits_vec
176            .iter()
177            .enumerate()
178            .map(|(i, &logit)| (i, logit))
179            .collect();
180        indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181
182        // Take top k
183        let top_k = indexed_logits.into_iter().take(k).collect::<Vec<_>>();
184
185        if top_k.is_empty() {
186            return Ok(0);
187        }
188
189        if temperature <= 0.0 {
190            // Return most likely from top-k
191            return Ok(top_k[0].0 as i64);
192        }
193
194        // Create tensor with only top-k logits
195        let mut filtered_logits = vec![f32::NEG_INFINITY; logits_vec.len()];
196        for (idx, logit) in top_k {
197            filtered_logits[idx] = logit;
198        }
199
200        let filtered_tensor = Tensor::from_vec(filtered_logits, logits.shape(), logits.device())?;
201        sample_with_temperature(&filtered_tensor, temperature)
202    }
203}
204
205/// Utilities for multi-component model orchestration
206pub mod multi_component {
207    use super::*;
208    use crate::Config as CoreMLConfig;
209    use std::path::Path;
210
211    /// Trait for models that consist of multiple CoreML components
212    pub trait MultiComponentModel {
213        /// Load all model components from a directory
214        fn load_components<P: AsRef<Path>>(path: P) -> Result<Self, CandleError>
215        where
216            Self: Sized;
217
218        /// Run inference through the complete pipeline
219        fn forward_pipeline(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError>;
220
221        /// Get information about the loaded components
222        fn component_info(&self) -> Vec<String>;
223    }
224
225    /// Builder for creating CoreML configurations for common component types
226    pub struct ComponentConfigBuilder {
227        base_config: CoreMLConfig,
228    }
229
230    impl ComponentConfigBuilder {
231        pub fn new(vocab_size: usize, max_seq_len: usize) -> Self {
232            Self {
233                base_config: CoreMLConfig {
234                    input_names: vec![],
235                    output_name: String::new(),
236                    max_sequence_length: max_seq_len,
237                    vocab_size,
238                    model_type: String::new(),
239                },
240            }
241        }
242
243        /// Create config for an embeddings component
244        pub fn embeddings_config(mut self, model_type: &str) -> CoreMLConfig {
245            self.base_config.input_names = vec!["input_ids".to_string()];
246            self.base_config.output_name = "hidden_states".to_string();
247            self.base_config.model_type = format!("{model_type}-embeddings");
248            self.base_config
249        }
250
251        /// Create config for an FFN/transformer component  
252        pub fn ffn_config(mut self, model_type: &str, include_mask: bool) -> CoreMLConfig {
253            self.base_config.input_names = vec!["hidden_states".to_string()];
254            if include_mask {
255                self.base_config.input_names.push("causal_mask".to_string());
256            }
257            self.base_config.output_name = "output_hidden_states".to_string();
258            self.base_config.model_type = format!("{model_type}-ffn");
259            self.base_config
260        }
261
262        /// Create config for an LM head component
263        pub fn lm_head_config(mut self, model_type: &str) -> CoreMLConfig {
264            self.base_config.input_names = vec!["hidden_states".to_string()];
265            self.base_config.output_name = "logits".to_string();
266            self.base_config.model_type = format!("{model_type}-lm-head");
267            self.base_config
268        }
269    }
270
271    /// Utility for combining chunked LM head outputs (e.g., from ANEMLL models)
272    pub fn combine_chunked_logits(
273        outputs: HashMap<String, Tensor>,
274        num_chunks: usize,
275    ) -> Result<Tensor, CandleError> {
276        let mut chunks = Vec::new();
277
278        for i in 1..=num_chunks {
279            let key = format!("logits{i}");
280            if let Some(chunk) = outputs.get(&key) {
281                chunks.push(chunk.clone());
282            } else {
283                return Err(CandleError::Msg(format!("Missing logits chunk: {key}")));
284            }
285        }
286
287        // Concatenate along vocabulary dimension (assumed to be last dimension)
288        let chunk_refs: Vec<&Tensor> = chunks.iter().collect();
289        Tensor::cat(&chunk_refs, chunks[0].dims().len() - 1)
290    }
291}