candle_coreml/
utils.rs

1//! Shared utilities for transformer models and multi-component architectures
2
3use candle_core::{Device, Error as CandleError, Tensor};
4use std::collections::HashMap;
5
6/// Utilities for creating attention masks used in transformer models
7pub mod mask {
8    use super::*;
9
10    /// Create a causal attention mask for autoregressive generation
11    ///
12    /// This mask prevents tokens from attending to future positions in the sequence.
13    ///
14    /// # Arguments
15    /// * `seq_len` - Sequence length for the mask
16    /// * `device` - Device to create the tensor on
17    ///
18    /// # Returns
19    /// Causal mask tensor with shape `(seq_len, seq_len)` where upper triangle is -inf
20    pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor, CandleError> {
21        let mut mask_data = vec![0.0f32; seq_len * seq_len];
22
23        // Fill upper triangle with -inf for causal masking
24        for i in 0..seq_len {
25            for j in (i + 1)..seq_len {
26                mask_data[i * seq_len + j] = f32::NEG_INFINITY;
27            }
28        }
29
30        Tensor::from_vec(mask_data, (seq_len, seq_len), device)
31    }
32
33    /// Create a causal mask for a specific position in the sequence
34    ///
35    /// This creates a mask row that allows attention to all previous positions
36    /// up to and including the current position.
37    ///
38    /// # Arguments
39    /// * `pos` - Current position in the sequence
40    /// * `context_len` - Total context length
41    /// * `device` - Device to create the tensor on
42    ///
43    /// # Returns  
44    /// Position mask tensor with shape `(1, context_len)`
45    pub fn create_position_mask(
46        pos: usize,
47        context_len: usize,
48        device: &Device,
49    ) -> Result<Tensor, CandleError> {
50        let mut mask_data = vec![f32::NEG_INFINITY; context_len];
51
52        // Allow attention to all positions up to and including current position
53        for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
54            *item = 0.0;
55        }
56
57        Tensor::from_vec(mask_data, (1, context_len), device)
58    }
59
60    /// Create a rank-4 causal mask for CoreML models that expect specific shapes
61    ///
62    /// Some CoreML models require masks with rank-4 shapes like `(1, 1, 1, seq_len)`
63    ///
64    /// # Arguments
65    /// * `pos` - Current position in the sequence  
66    /// * `context_len` - Total context length
67    /// * `device` - Device to create the tensor on
68    ///
69    /// # Returns
70    /// Rank-4 position mask tensor with shape `(1, 1, 1, context_len)`
71    pub fn create_rank4_position_mask(
72        pos: usize,
73        context_len: usize,
74        device: &Device,
75    ) -> Result<Tensor, CandleError> {
76        let mut mask_data = vec![f32::NEG_INFINITY; context_len];
77
78        // Allow attention to all positions up to and including current position
79        for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
80            *item = 0.0;
81        }
82
83        Tensor::from_vec(mask_data, (1, 1, 1, context_len), device)
84    }
85
86    /// Create an update mask for stateful models indicating which position to update
87    ///
88    /// # Arguments
89    /// * `pos` - Position to update
90    /// * `context_len` - Total context length  
91    /// * `device` - Device to create the tensor on
92    ///
93    /// # Returns
94    /// Update mask with 1.0 at the target position, 0.0 elsewhere
95    pub fn create_update_mask(
96        pos: usize,
97        context_len: usize,
98        device: &Device,
99    ) -> Result<Tensor, CandleError> {
100        let mut mask_data = vec![0.0f32; context_len];
101        if pos < context_len {
102            mask_data[pos] = 1.0;
103        }
104
105        Tensor::from_vec(mask_data, (1, 1, context_len, 1), device)
106    }
107}
108
109/// Utilities for sampling from model outputs
110pub mod sampling {
111    use super::*;
112
113    /// Sample a token using temperature scaling
114    ///
115    /// Temperature controls randomness:
116    /// - temperature = 0.0: Greedy sampling (most likely token)
117    /// - temperature = 1.0: Standard sampling  
118    /// - temperature > 1.0: More random
119    /// - temperature < 1.0: More deterministic
120    ///
121    /// # Arguments
122    /// * `logits` - Model output logits tensor
123    /// * `temperature` - Temperature for scaling
124    ///
125    /// # Returns
126    /// Sampled token ID
127    pub fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<i64, CandleError> {
128        if temperature <= 0.0 {
129            // Greedy sampling - return most likely token
130            return greedy_sample(logits);
131        }
132
133        // Apply temperature scaling
134        let temp_tensor = Tensor::new(&[temperature], logits.device())?;
135        let scaled_logits = logits.broadcast_div(&temp_tensor)?;
136
137        // Convert to probabilities via softmax
138        let probs = candle_nn::ops::softmax_last_dim(&scaled_logits)?;
139        let probs_vec = probs.to_vec1::<f32>()?;
140
141        // Sample from the distribution
142        let random_val: f32 = rand::random();
143
144        let mut cumulative = 0.0;
145        for (i, &prob) in probs_vec.iter().enumerate() {
146            cumulative += prob;
147            if random_val <= cumulative {
148                return Ok(i as i64);
149            }
150        }
151
152        // Fallback to last token if numerical issues
153        Ok((probs_vec.len() - 1) as i64)
154    }
155
156    /// Greedy sampling - always return the most likely token
157    pub fn greedy_sample(logits: &Tensor) -> Result<i64, CandleError> {
158        let logits_vec = logits.to_vec1::<f32>()?;
159        let max_idx = logits_vec
160            .iter()
161            .enumerate()
162            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
163            .map(|(idx, _)| idx)
164            .unwrap_or(0);
165        Ok(max_idx as i64)
166    }
167
168    /// Top-k sampling - sample from the k most likely tokens
169    pub fn sample_top_k(logits: &Tensor, k: usize, temperature: f32) -> Result<i64, CandleError> {
170        let logits_vec = logits.to_vec1::<f32>()?;
171
172        // Get indices sorted by logit value (descending)
173        let mut indexed_logits: Vec<(usize, f32)> = logits_vec
174            .iter()
175            .enumerate()
176            .map(|(i, &logit)| (i, logit))
177            .collect();
178        indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
179
180        // Take top k
181        let top_k = indexed_logits.into_iter().take(k).collect::<Vec<_>>();
182
183        if top_k.is_empty() {
184            return Ok(0);
185        }
186
187        if temperature <= 0.0 {
188            // Return most likely from top-k
189            return Ok(top_k[0].0 as i64);
190        }
191
192        // Create tensor with only top-k logits
193        let mut filtered_logits = vec![f32::NEG_INFINITY; logits_vec.len()];
194        for (idx, logit) in top_k {
195            filtered_logits[idx] = logit;
196        }
197
198        let filtered_tensor = Tensor::from_vec(filtered_logits, logits.shape(), logits.device())?;
199        sample_with_temperature(&filtered_tensor, temperature)
200    }
201}
202
203/// Utilities for multi-component model orchestration
204pub mod multi_component {
205    use super::*;
206    use crate::Config as CoreMLConfig;
207    use std::path::Path;
208
209    /// Trait for models that consist of multiple CoreML components
210    pub trait MultiComponentModel {
211        /// Load all model components from a directory
212        fn load_components<P: AsRef<Path>>(path: P) -> Result<Self, CandleError>
213        where
214            Self: Sized;
215
216        /// Run inference through the complete pipeline
217        fn forward_pipeline(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError>;
218
219        /// Get information about the loaded components
220        fn component_info(&self) -> Vec<String>;
221    }
222
223    /// Builder for creating CoreML configurations for common component types
224    pub struct ComponentConfigBuilder {
225        base_config: CoreMLConfig,
226    }
227
228    impl ComponentConfigBuilder {
229        pub fn new(vocab_size: usize, max_seq_len: usize) -> Self {
230            Self {
231                base_config: CoreMLConfig {
232                    input_names: vec![],
233                    output_name: String::new(),
234                    max_sequence_length: max_seq_len,
235                    vocab_size,
236                    model_type: String::new(),
237                },
238            }
239        }
240
241        /// Create config for an embeddings component
242        pub fn embeddings_config(mut self, model_type: &str) -> CoreMLConfig {
243            self.base_config.input_names = vec!["input_ids".to_string()];
244            self.base_config.output_name = "hidden_states".to_string();
245            self.base_config.model_type = format!("{model_type}-embeddings");
246            self.base_config
247        }
248
249        /// Create config for an FFN/transformer component  
250        pub fn ffn_config(mut self, model_type: &str, include_mask: bool) -> CoreMLConfig {
251            self.base_config.input_names = vec!["hidden_states".to_string()];
252            if include_mask {
253                self.base_config.input_names.push("causal_mask".to_string());
254            }
255            self.base_config.output_name = "output_hidden_states".to_string();
256            self.base_config.model_type = format!("{model_type}-ffn");
257            self.base_config
258        }
259
260        /// Create config for an LM head component
261        pub fn lm_head_config(mut self, model_type: &str) -> CoreMLConfig {
262            self.base_config.input_names = vec!["hidden_states".to_string()];
263            self.base_config.output_name = "logits".to_string();
264            self.base_config.model_type = format!("{model_type}-lm-head");
265            self.base_config
266        }
267    }
268
269    /// Utility for combining chunked LM head outputs (e.g., from ANEMLL models)
270    pub fn combine_chunked_logits(
271        outputs: HashMap<String, Tensor>,
272        num_chunks: usize,
273    ) -> Result<Tensor, CandleError> {
274        let mut chunks = Vec::new();
275
276        for i in 1..=num_chunks {
277            let key = format!("logits{i}");
278            if let Some(chunk) = outputs.get(&key) {
279                chunks.push(chunk.clone());
280            } else {
281                return Err(CandleError::Msg(format!("Missing logits chunk: {key}")));
282            }
283        }
284
285        // Concatenate along vocabulary dimension (assumed to be last dimension)
286        let chunk_refs: Vec<&Tensor> = chunks.iter().collect();
287        Tensor::cat(&chunk_refs, chunks[0].dims().len() - 1)
288    }
289}