1use candle_core::{Device, Error as CandleError, Tensor};
4use std::collections::HashMap;
5
6pub mod mask {
8 use super::*;
9
10 pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor, CandleError> {
21 let mut mask_data = vec![0.0f32; seq_len * seq_len];
22
23 for i in 0..seq_len {
25 for j in (i + 1)..seq_len {
26 mask_data[i * seq_len + j] = f32::NEG_INFINITY;
27 }
28 }
29
30 Tensor::from_vec(mask_data, (seq_len, seq_len), device)
31 }
32
33 pub fn create_position_mask(
46 pos: usize,
47 context_len: usize,
48 device: &Device,
49 ) -> Result<Tensor, CandleError> {
50 let mut mask_data = vec![f32::NEG_INFINITY; context_len];
51
52 for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
54 *item = 0.0;
55 }
56
57 Tensor::from_vec(mask_data, (1, context_len), device)
58 }
59
60 pub fn create_rank4_position_mask(
72 pos: usize,
73 context_len: usize,
74 device: &Device,
75 ) -> Result<Tensor, CandleError> {
76 let mut mask_data = vec![f32::NEG_INFINITY; context_len];
77
78 for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
80 *item = 0.0;
81 }
82
83 Tensor::from_vec(mask_data, (1, 1, 1, context_len), device)
84 }
85
86 pub fn create_update_mask(
96 pos: usize,
97 context_len: usize,
98 device: &Device,
99 ) -> Result<Tensor, CandleError> {
100 let mut mask_data = vec![0.0f32; context_len];
101 if pos < context_len {
102 mask_data[pos] = 1.0;
103 }
104
105 Tensor::from_vec(mask_data, (1, 1, context_len, 1), device)
106 }
107}
108
109pub mod sampling {
111 use super::*;
112 use rand::Rng;
113
114 pub fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<i64, CandleError> {
129 if temperature <= 0.0 {
130 return greedy_sample(logits);
132 }
133
134 let temp_tensor = Tensor::new(&[temperature], logits.device())?;
136 let scaled_logits = logits.broadcast_div(&temp_tensor)?;
137
138 let probs = candle_nn::ops::softmax_last_dim(&scaled_logits)?;
140 let probs_vec = probs.to_vec1::<f32>()?;
141
142 let mut rng = rand::thread_rng();
144 let random_val: f32 = rng.gen();
145
146 let mut cumulative = 0.0;
147 for (i, &prob) in probs_vec.iter().enumerate() {
148 cumulative += prob;
149 if random_val <= cumulative {
150 return Ok(i as i64);
151 }
152 }
153
154 Ok((probs_vec.len() - 1) as i64)
156 }
157
158 pub fn greedy_sample(logits: &Tensor) -> Result<i64, CandleError> {
160 let logits_vec = logits.to_vec1::<f32>()?;
161 let max_idx = logits_vec
162 .iter()
163 .enumerate()
164 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
165 .map(|(idx, _)| idx)
166 .unwrap_or(0);
167 Ok(max_idx as i64)
168 }
169
170 pub fn sample_top_k(logits: &Tensor, k: usize, temperature: f32) -> Result<i64, CandleError> {
172 let logits_vec = logits.to_vec1::<f32>()?;
173
174 let mut indexed_logits: Vec<(usize, f32)> = logits_vec
176 .iter()
177 .enumerate()
178 .map(|(i, &logit)| (i, logit))
179 .collect();
180 indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181
182 let top_k = indexed_logits.into_iter().take(k).collect::<Vec<_>>();
184
185 if top_k.is_empty() {
186 return Ok(0);
187 }
188
189 if temperature <= 0.0 {
190 return Ok(top_k[0].0 as i64);
192 }
193
194 let mut filtered_logits = vec![f32::NEG_INFINITY; logits_vec.len()];
196 for (idx, logit) in top_k {
197 filtered_logits[idx] = logit;
198 }
199
200 let filtered_tensor = Tensor::from_vec(filtered_logits, logits.shape(), logits.device())?;
201 sample_with_temperature(&filtered_tensor, temperature)
202 }
203}
204
205pub mod multi_component {
207 use super::*;
208 use crate::Config as CoreMLConfig;
209 use std::path::Path;
210
/// Interface for a model that is loaded from a path and evaluated as a pipeline.
///
/// NOTE(review): the name suggests implementors are assembled from several separately
/// stored components; the trait itself does not enforce this — confirm against
/// implementors.
pub trait MultiComponentModel {
    /// Constructs the model from `path`.
    ///
    /// NOTE(review): presumably `path` is the location holding the component files —
    /// confirm against implementors.
    ///
    /// # Errors
    ///
    /// Returns a `CandleError` when construction fails.
    fn load_components<P: AsRef<Path>>(path: P) -> Result<Self, CandleError>
    where
        Self: Sized;

    /// Runs the model on `inputs` and returns a single output tensor.
    ///
    /// The number, order and meaning of `inputs` are defined by the implementor.
    ///
    /// # Errors
    ///
    /// Returns a `CandleError` when evaluation fails.
    fn forward_pipeline(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError>;

    /// Returns a list of strings describing the model.
    ///
    /// NOTE(review): presumably one entry per component — confirm against implementors.
    fn component_info(&self) -> Vec<String>;
}
224
225 pub struct ComponentConfigBuilder {
227 base_config: CoreMLConfig,
228 }
229
230 impl ComponentConfigBuilder {
231 pub fn new(vocab_size: usize, max_seq_len: usize) -> Self {
232 Self {
233 base_config: CoreMLConfig {
234 input_names: vec![],
235 output_name: String::new(),
236 max_sequence_length: max_seq_len,
237 vocab_size,
238 model_type: String::new(),
239 },
240 }
241 }
242
243 pub fn embeddings_config(mut self, model_type: &str) -> CoreMLConfig {
245 self.base_config.input_names = vec!["input_ids".to_string()];
246 self.base_config.output_name = "hidden_states".to_string();
247 self.base_config.model_type = format!("{model_type}-embeddings");
248 self.base_config
249 }
250
251 pub fn ffn_config(mut self, model_type: &str, include_mask: bool) -> CoreMLConfig {
253 self.base_config.input_names = vec!["hidden_states".to_string()];
254 if include_mask {
255 self.base_config.input_names.push("causal_mask".to_string());
256 }
257 self.base_config.output_name = "output_hidden_states".to_string();
258 self.base_config.model_type = format!("{model_type}-ffn");
259 self.base_config
260 }
261
262 pub fn lm_head_config(mut self, model_type: &str) -> CoreMLConfig {
264 self.base_config.input_names = vec!["hidden_states".to_string()];
265 self.base_config.output_name = "logits".to_string();
266 self.base_config.model_type = format!("{model_type}-lm-head");
267 self.base_config
268 }
269 }
270
271 pub fn combine_chunked_logits(
273 outputs: HashMap<String, Tensor>,
274 num_chunks: usize,
275 ) -> Result<Tensor, CandleError> {
276 let mut chunks = Vec::new();
277
278 for i in 1..=num_chunks {
279 let key = format!("logits{i}");
280 if let Some(chunk) = outputs.get(&key) {
281 chunks.push(chunk.clone());
282 } else {
283 return Err(CandleError::Msg(format!("Missing logits chunk: {key}")));
284 }
285 }
286
287 let chunk_refs: Vec<&Tensor> = chunks.iter().collect();
289 Tensor::cat(&chunk_refs, chunks[0].dims().len() - 1)
290 }
291}