1use candle_core::{Device, Error as CandleError, Tensor};
4use std::collections::HashMap;
5
6pub mod mask {
8 use super::*;
9
10 pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor, CandleError> {
21 let mut mask_data = vec![0.0f32; seq_len * seq_len];
22
23 for i in 0..seq_len {
25 for j in (i + 1)..seq_len {
26 mask_data[i * seq_len + j] = f32::NEG_INFINITY;
27 }
28 }
29
30 Tensor::from_vec(mask_data, (seq_len, seq_len), device)
31 }
32
33 pub fn create_position_mask(
46 pos: usize,
47 context_len: usize,
48 device: &Device,
49 ) -> Result<Tensor, CandleError> {
50 let mut mask_data = vec![f32::NEG_INFINITY; context_len];
51
52 for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
54 *item = 0.0;
55 }
56
57 Tensor::from_vec(mask_data, (1, context_len), device)
58 }
59
60 pub fn create_rank4_position_mask(
72 pos: usize,
73 context_len: usize,
74 device: &Device,
75 ) -> Result<Tensor, CandleError> {
76 let mut mask_data = vec![f32::NEG_INFINITY; context_len];
77
78 for item in mask_data.iter_mut().take(pos.min(context_len - 1) + 1) {
80 *item = 0.0;
81 }
82
83 Tensor::from_vec(mask_data, (1, 1, 1, context_len), device)
84 }
85
86 pub fn create_update_mask(
96 pos: usize,
97 context_len: usize,
98 device: &Device,
99 ) -> Result<Tensor, CandleError> {
100 let mut mask_data = vec![0.0f32; context_len];
101 if pos < context_len {
102 mask_data[pos] = 1.0;
103 }
104
105 Tensor::from_vec(mask_data, (1, 1, context_len, 1), device)
106 }
107}
108
109pub mod sampling {
111 use super::*;
112
113 pub fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<i64, CandleError> {
128 if temperature <= 0.0 {
129 return greedy_sample(logits);
131 }
132
133 let temp_tensor = Tensor::new(&[temperature], logits.device())?;
135 let scaled_logits = logits.broadcast_div(&temp_tensor)?;
136
137 let probs = candle_nn::ops::softmax_last_dim(&scaled_logits)?;
139 let probs_vec = probs.to_vec1::<f32>()?;
140
141 let random_val: f32 = rand::random();
143
144 let mut cumulative = 0.0;
145 for (i, &prob) in probs_vec.iter().enumerate() {
146 cumulative += prob;
147 if random_val <= cumulative {
148 return Ok(i as i64);
149 }
150 }
151
152 Ok((probs_vec.len() - 1) as i64)
154 }
155
156 pub fn greedy_sample(logits: &Tensor) -> Result<i64, CandleError> {
158 let logits_vec = logits.to_vec1::<f32>()?;
159 let max_idx = logits_vec
160 .iter()
161 .enumerate()
162 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
163 .map(|(idx, _)| idx)
164 .unwrap_or(0);
165 Ok(max_idx as i64)
166 }
167
168 pub fn sample_top_k(logits: &Tensor, k: usize, temperature: f32) -> Result<i64, CandleError> {
170 let logits_vec = logits.to_vec1::<f32>()?;
171
172 let mut indexed_logits: Vec<(usize, f32)> = logits_vec
174 .iter()
175 .enumerate()
176 .map(|(i, &logit)| (i, logit))
177 .collect();
178 indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
179
180 let top_k = indexed_logits.into_iter().take(k).collect::<Vec<_>>();
182
183 if top_k.is_empty() {
184 return Ok(0);
185 }
186
187 if temperature <= 0.0 {
188 return Ok(top_k[0].0 as i64);
190 }
191
192 let mut filtered_logits = vec![f32::NEG_INFINITY; logits_vec.len()];
194 for (idx, logit) in top_k {
195 filtered_logits[idx] = logit;
196 }
197
198 let filtered_tensor = Tensor::from_vec(filtered_logits, logits.shape(), logits.device())?;
199 sample_with_temperature(&filtered_tensor, temperature)
200 }
201}
202
203pub mod multi_component {
205 use super::*;
206 use crate::Config as CoreMLConfig;
207 use std::path::Path;
208
209 pub trait MultiComponentModel {
211 fn load_components<P: AsRef<Path>>(path: P) -> Result<Self, CandleError>
213 where
214 Self: Sized;
215
216 fn forward_pipeline(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError>;
218
219 fn component_info(&self) -> Vec<String>;
221 }
222
223 pub struct ComponentConfigBuilder {
225 base_config: CoreMLConfig,
226 }
227
228 impl ComponentConfigBuilder {
229 pub fn new(vocab_size: usize, max_seq_len: usize) -> Self {
230 Self {
231 base_config: CoreMLConfig {
232 input_names: vec![],
233 output_name: String::new(),
234 max_sequence_length: max_seq_len,
235 vocab_size,
236 model_type: String::new(),
237 },
238 }
239 }
240
241 pub fn embeddings_config(mut self, model_type: &str) -> CoreMLConfig {
243 self.base_config.input_names = vec!["input_ids".to_string()];
244 self.base_config.output_name = "hidden_states".to_string();
245 self.base_config.model_type = format!("{model_type}-embeddings");
246 self.base_config
247 }
248
249 pub fn ffn_config(mut self, model_type: &str, include_mask: bool) -> CoreMLConfig {
251 self.base_config.input_names = vec!["hidden_states".to_string()];
252 if include_mask {
253 self.base_config.input_names.push("causal_mask".to_string());
254 }
255 self.base_config.output_name = "output_hidden_states".to_string();
256 self.base_config.model_type = format!("{model_type}-ffn");
257 self.base_config
258 }
259
260 pub fn lm_head_config(mut self, model_type: &str) -> CoreMLConfig {
262 self.base_config.input_names = vec!["hidden_states".to_string()];
263 self.base_config.output_name = "logits".to_string();
264 self.base_config.model_type = format!("{model_type}-lm-head");
265 self.base_config
266 }
267 }
268
269 pub fn combine_chunked_logits(
271 outputs: HashMap<String, Tensor>,
272 num_chunks: usize,
273 ) -> Result<Tensor, CandleError> {
274 let mut chunks = Vec::new();
275
276 for i in 1..=num_chunks {
277 let key = format!("logits{i}");
278 if let Some(chunk) = outputs.get(&key) {
279 chunks.push(chunk.clone());
280 } else {
281 return Err(CandleError::Msg(format!("Missing logits chunk: {key}")));
282 }
283 }
284
285 let chunk_refs: Vec<&Tensor> = chunks.iter().collect();
287 Tensor::cat(&chunk_refs, chunks[0].dims().len() - 1)
288 }
289}