1use std::sync::Arc;
11
12use arrow::array::{Array, RecordBatch, StringArray};
13use rand::{Rng, SeedableRng};
14
15use super::Transform;
16use crate::error::{Error, Result};
17
/// Ordering of the pieces in a fill-in-the-middle (FIM) string.
///
/// The source text is cut into a prefix, a middle and a suffix. Both formats
/// emit the middle last; they differ in whether the prefix or the suffix
/// comes first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FimFormat {
    /// Prefix-Suffix-Middle:
    /// `<prefix token>prefix<suffix token>suffix<middle token>middle`.
    PSM,
    /// Suffix-Prefix-Middle:
    /// `<suffix token>suffix<prefix token>prefix<middle token>middle`.
    SPM,
}
26
/// Sentinel strings inserted in front of each piece of a FIM string.
#[derive(Clone, Debug)]
pub struct FimTokens {
    /// Token written immediately before the prefix piece.
    pub prefix: String,
    /// Token written immediately before the suffix piece.
    pub suffix: String,
    /// Token written immediately before the middle piece.
    pub middle: String,
}
37
38impl Default for FimTokens {
39 fn default() -> Self {
40 Self {
41 prefix: "<|fim_prefix|>".to_string(),
42 suffix: "<|fim_suffix|>".to_string(),
43 middle: "<|fim_middle|>".to_string(),
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
56pub struct Fim {
57 column: String,
59 rate: f64,
61 format: FimFormat,
63 tokens: FimTokens,
65 min_chars: usize,
67 seed: u64,
69}
70
71impl Fim {
72 pub fn new(column: impl Into<String>) -> Self {
74 Self {
75 column: column.into(),
76 rate: 0.5,
77 format: FimFormat::PSM,
78 tokens: FimTokens::default(),
79 min_chars: 10,
80 seed: 42,
81 }
82 }
83
84 #[must_use]
86 pub fn with_rate(mut self, rate: f64) -> Self {
87 self.rate = rate.clamp(0.0, 1.0);
88 self
89 }
90
91 #[must_use]
93 pub fn with_format(mut self, format: FimFormat) -> Self {
94 self.format = format;
95 self
96 }
97
98 #[must_use]
100 pub fn with_tokens(mut self, tokens: FimTokens) -> Self {
101 self.tokens = tokens;
102 self
103 }
104
105 #[must_use]
107 pub fn with_min_chars(mut self, min_chars: usize) -> Self {
108 self.min_chars = min_chars;
109 self
110 }
111
112 #[must_use]
114 pub fn with_seed(mut self, seed: u64) -> Self {
115 self.seed = seed;
116 self
117 }
118}
119
120fn apply_fim_to_text(
122 text: &str,
123 format: FimFormat,
124 tokens: &FimTokens,
125 rng: &mut impl Rng,
126) -> String {
127 let len = text.len();
128 if len < 2 {
129 return text.to_string();
130 }
131
132 let mut a = rng.gen_range(0..len);
134 let mut b = rng.gen_range(0..len);
135 if a > b {
136 std::mem::swap(&mut a, &mut b);
137 }
138
139 let a = find_char_boundary(text, a);
141 let b = find_char_boundary(text, b);
142
143 let prefix = &text[..a];
144 let middle = &text[a..b];
145 let suffix = &text[b..];
146
147 match format {
148 FimFormat::PSM => {
149 format!(
150 "{}{}{}{}{}{}",
151 tokens.prefix, prefix, tokens.suffix, suffix, tokens.middle, middle
152 )
153 }
154 FimFormat::SPM => {
155 format!(
156 "{}{}{}{}{}{}",
157 tokens.suffix, suffix, tokens.prefix, prefix, tokens.middle, middle
158 )
159 }
160 }
161}
162
/// Returns the first character boundary of `s` at or after `byte_offset`.
///
/// Offsets past the end of `s` are treated as `s.len()`, so the result is
/// always a valid slicing index for `s`.
fn find_char_boundary(s: &str, byte_offset: usize) -> usize {
    let end = s.len();
    let start = byte_offset.min(end);
    (start..end)
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(end)
}
171
172impl Transform for Fim {
173 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
174 let schema = batch.schema();
175 let col_idx = schema
176 .index_of(&self.column)
177 .map_err(|_| Error::column_not_found(&self.column))?;
178
179 let col = batch
180 .column(col_idx)
181 .as_any()
182 .downcast_ref::<StringArray>()
183 .ok_or_else(|| {
184 Error::transform(format!(
185 "Column '{}' must be Utf8 type for FIM transform",
186 self.column
187 ))
188 })?;
189
190 let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
191 let transformed: Vec<Option<String>> = (0..col.len())
192 .map(|i| {
193 if col.is_null(i) {
194 return None;
195 }
196 let text = col.value(i);
197 if text.len() < self.min_chars {
198 return Some(text.to_string());
199 }
200 let apply_fim: bool = rng.gen_bool(self.rate);
201 if apply_fim {
202 Some(apply_fim_to_text(text, self.format, &self.tokens, &mut rng))
203 } else {
204 Some(text.to_string())
205 }
206 })
207 .collect();
208
209 let new_col = StringArray::from(transformed);
210 let mut columns: Vec<Arc<dyn arrow::array::Array>> = batch.columns().to_vec();
211 columns[col_idx] = Arc::new(new_col);
212 RecordBatch::try_new(schema, columns).map_err(Error::Arrow)
213 }
214}
215
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Builds a one-column (`code`) batch with two long snippets and one short one.
    fn create_code_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("code", DataType::Utf8, false)]));
        let code = StringArray::from(vec![
            "def hello():\n    print('hello world')\n",
            "class Foo:\n    def bar(self):\n        return 42\n",
            "x = 1",
        ]);
        RecordBatch::try_new(schema, vec![Arc::new(code)]).expect("batch creation should succeed")
    }

    /// Borrows column 0 of `batch` as a `StringArray`.
    fn string_column(batch: &RecordBatch) -> &StringArray {
        batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("column 0 should be a string array")
    }

    /// Parses a FIM string produced with `format` and `tokens` and returns the
    /// pieces glued back together in source order (prefix, middle, suffix).
    ///
    /// Only valid for inputs that do not themselves contain any of the tokens.
    fn reassemble(result: &str, format: FimFormat, tokens: &FimTokens) -> String {
        let (leading, separating) = match format {
            FimFormat::PSM => (tokens.prefix.as_str(), tokens.suffix.as_str()),
            FimFormat::SPM => (tokens.suffix.as_str(), tokens.prefix.as_str()),
        };
        let rest = result
            .strip_prefix(leading)
            .expect("output should start with the leading token");
        let (first, rest) = rest
            .split_once(separating)
            .expect("output should contain the second token");
        let (second, middle) = rest
            .split_once(tokens.middle.as_str())
            .expect("output should contain the middle token");
        match format {
            FimFormat::PSM => format!("{first}{middle}{second}"),
            FimFormat::SPM => format!("{second}{middle}{first}"),
        }
    }

    #[test]
    fn test_fim_psm_format() {
        let text = "def hello():\n    print('hello')";
        let tokens = FimTokens::default();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let result = apply_fim_to_text(text, FimFormat::PSM, &tokens, &mut rng);
        assert!(result.starts_with("<|fim_prefix|>"));
        assert!(result.contains("<|fim_suffix|>"));
        assert!(result.contains("<|fim_middle|>"));
        assert_eq!(reassemble(&result, FimFormat::PSM, &tokens), text);
    }

    #[test]
    fn test_fim_spm_format() {
        let text = "def hello():\n    print('hello')";
        let tokens = FimTokens::default();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let result = apply_fim_to_text(text, FimFormat::SPM, &tokens, &mut rng);
        assert!(result.starts_with("<|fim_suffix|>"));
        assert!(result.contains("<|fim_prefix|>"));
        assert!(result.contains("<|fim_middle|>"));
        assert_eq!(reassemble(&result, FimFormat::SPM, &tokens), text);
    }

    #[test]
    fn test_fim_transform_applies_to_batch() {
        let batch = create_code_batch();
        let fim = Fim::new("code").with_rate(1.0).with_seed(42);
        let result = fim.apply(batch).expect("should succeed");
        assert_eq!(result.num_rows(), 3);

        let col = string_column(&result);
        // Long rows are always rewritten at rate 1.0 (default format is PSM).
        assert!(col.value(0).starts_with("<|fim_prefix|>"));
        assert!(col.value(1).starts_with("<|fim_prefix|>"));
        // The short row is below the default minimum length and is untouched.
        assert_eq!(col.value(2), "x = 1");
    }

    #[test]
    fn test_fim_rate_zero_leaves_unchanged() {
        let batch = create_code_batch();
        let fim = Fim::new("code").with_rate(0.0).with_seed(42);
        let result = fim.apply(batch.clone()).expect("should succeed");
        let original = string_column(&batch);
        let transformed = string_column(&result);
        assert_eq!(original.len(), transformed.len());
        for i in 0..original.len() {
            assert_eq!(original.value(i), transformed.value(i));
        }
    }

    #[test]
    fn test_fim_deterministic_with_seed() {
        let batch = create_code_batch();
        let fim1 = Fim::new("code").with_rate(1.0).with_seed(123);
        let fim2 = Fim::new("code").with_rate(1.0).with_seed(123);
        let r1 = fim1.apply(batch.clone()).expect("should succeed");
        let r2 = fim2.apply(batch).expect("should succeed");
        let c1 = string_column(&r1);
        let c2 = string_column(&r2);
        assert_eq!(c1.len(), c2.len());
        for i in 0..c1.len() {
            assert_eq!(c1.value(i), c2.value(i));
        }
    }

    #[test]
    fn test_fim_wrong_column_errors() {
        let batch = create_code_batch();
        let fim = Fim::new("nonexistent");
        assert!(fim.apply(batch).is_err());
    }

    #[test]
    fn test_fim_non_string_column_errors() {
        let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
        let numbers = arrow::array::Int32Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(numbers)])
            .expect("batch creation should succeed");
        assert!(Fim::new("n").apply(batch).is_err());
    }

    #[test]
    fn test_fim_custom_tokens() {
        let text = "def foo(): pass";
        let tokens = FimTokens {
            prefix: "<PRE>".to_string(),
            suffix: "<SUF>".to_string(),
            middle: "<MID>".to_string(),
        };
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let result = apply_fim_to_text(text, FimFormat::PSM, &tokens, &mut rng);
        assert!(result.starts_with("<PRE>"));
        assert!(result.contains("<SUF>"));
        assert!(result.contains("<MID>"));
        assert_eq!(reassemble(&result, FimFormat::PSM, &tokens), text);
    }

    #[test]
    fn test_fim_preserves_content() {
        let text = "def hello():\n    print('hello')";
        let tokens = FimTokens::default();
        // Check exact reconstruction, not just per-character containment,
        // across several cut positions and both formats.
        for seed in 0..32 {
            for format in [FimFormat::PSM, FimFormat::SPM] {
                let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
                let result = apply_fim_to_text(text, format, &tokens, &mut rng);
                assert_eq!(
                    reassemble(&result, format, &tokens),
                    text,
                    "content changed for seed {seed}, format {format:?}"
                );
            }
        }
    }

    #[test]
    fn test_fim_multibyte_text() {
        let text = "héllo wörld — ünïcödé 日本語";
        let tokens = FimTokens::default();
        // Random byte offsets frequently land inside a character here; the
        // cuts must still be valid and lose nothing.
        for seed in 0..64 {
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
            let result = apply_fim_to_text(text, FimFormat::PSM, &tokens, &mut rng);
            assert_eq!(reassemble(&result, FimFormat::PSM, &tokens), text);
        }
    }

    #[test]
    fn test_fim_short_text_unchanged() {
        let tokens = FimTokens::default();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        assert_eq!(apply_fim_to_text("", FimFormat::PSM, &tokens, &mut rng), "");
        assert_eq!(apply_fim_to_text("a", FimFormat::SPM, &tokens, &mut rng), "a");
    }

    #[test]
    fn test_find_char_boundary() {
        let s = "hello";
        assert_eq!(find_char_boundary(s, 0), 0);
        assert_eq!(find_char_boundary(s, 3), 3);
        assert_eq!(find_char_boundary(s, 10), 5);
    }

    #[test]
    fn test_find_char_boundary_multibyte() {
        // 'é' occupies bytes 1..3, so offset 2 is inside it.
        let s = "héllo";
        let boundary = find_char_boundary(s, 2);
        assert!(s.is_char_boundary(boundary));
        assert_eq!(boundary, 3);
        assert_eq!(find_char_boundary(s, 1), 1);
    }
}