Skip to main content

alimentar/transform/
fim.rs

//! Fill-in-the-Middle (FIM) data transform for code model training.
//!
//! Implements the PSM (Prefix-Suffix-Middle) and SPM (Suffix-Prefix-Middle)
//! FIM formats from Bavarian et al. (2022),
//! "Efficient Training of Language Models to Fill in the Middle".
//!
//! Given a code sequence, FIM randomly splits it into (prefix, middle, suffix)
//! and rearranges the pieces with sentinel tokens so the model learns to infill code.
9
10use std::sync::Arc;
11
12use arrow::array::{Array, RecordBatch, StringArray};
13use rand::{Rng, SeedableRng};
14
15use super::Transform;
16use crate::error::{Error, Result};
17
/// Ordering in which the three FIM segments are emitted.
///
/// Both variants always place the middle segment last, after the middle
/// sentinel; they differ only in whether prefix or suffix comes first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FimFormat {
    /// Prefix first: `<PRE>prefix<SUF>suffix<MID>middle`
    PSM,
    /// Suffix first: `<SUF>suffix<PRE>prefix<MID>middle`
    SPM,
}
26
/// Sentinel strings that delimit the FIM segments in the output text.
///
/// `FimTokens::default()` yields the `<|fim_prefix|>`, `<|fim_suffix|>`
/// and `<|fim_middle|>` markers.
#[derive(Debug, Clone)]
pub struct FimTokens {
    /// Marker written immediately before the prefix segment
    pub prefix: String,
    /// Marker written immediately before the suffix segment
    pub suffix: String,
    /// Marker written immediately before the middle segment
    pub middle: String,
}

impl Default for FimTokens {
    fn default() -> Self {
        let prefix = String::from("<|fim_prefix|>");
        let suffix = String::from("<|fim_suffix|>");
        let middle = String::from("<|fim_middle|>");
        Self {
            prefix,
            suffix,
            middle,
        }
    }
}
47
48/// Fill-in-the-Middle transform for code training data.
49///
50/// Applies FIM transformation to a text column in a RecordBatch.
51/// Each row is randomly split into (prefix, middle, suffix) and
52/// reassembled in PSM or SPM format with sentinel tokens.
53///
54/// Rows shorter than `min_chars` are left unchanged.
55#[derive(Debug, Clone)]
56pub struct Fim {
57    /// Column name containing code text
58    column: String,
59    /// Probability of applying FIM to each row (0.0-1.0)
60    rate: f64,
61    /// FIM format variant
62    format: FimFormat,
63    /// Sentinel tokens
64    tokens: FimTokens,
65    /// Minimum characters for FIM to apply
66    min_chars: usize,
67    /// Random seed for reproducibility
68    seed: u64,
69}
70
71impl Fim {
72    /// Create a new FIM transform for the given column.
73    pub fn new(column: impl Into<String>) -> Self {
74        Self {
75            column: column.into(),
76            rate: 0.5,
77            format: FimFormat::PSM,
78            tokens: FimTokens::default(),
79            min_chars: 10,
80            seed: 42,
81        }
82    }
83
84    /// Set the FIM application rate (0.0-1.0).
85    #[must_use]
86    pub fn with_rate(mut self, rate: f64) -> Self {
87        self.rate = rate.clamp(0.0, 1.0);
88        self
89    }
90
91    /// Set the FIM format variant.
92    #[must_use]
93    pub fn with_format(mut self, format: FimFormat) -> Self {
94        self.format = format;
95        self
96    }
97
98    /// Set custom sentinel tokens.
99    #[must_use]
100    pub fn with_tokens(mut self, tokens: FimTokens) -> Self {
101        self.tokens = tokens;
102        self
103    }
104
105    /// Set minimum character count for FIM to apply.
106    #[must_use]
107    pub fn with_min_chars(mut self, min_chars: usize) -> Self {
108        self.min_chars = min_chars;
109        self
110    }
111
112    /// Set random seed for reproducibility.
113    #[must_use]
114    pub fn with_seed(mut self, seed: u64) -> Self {
115        self.seed = seed;
116        self
117    }
118}
119
120/// Apply FIM transformation to a single text string.
121fn apply_fim_to_text(
122    text: &str,
123    format: FimFormat,
124    tokens: &FimTokens,
125    rng: &mut impl Rng,
126) -> String {
127    let len = text.len();
128    if len < 2 {
129        return text.to_string();
130    }
131
132    // Pick two random split points to create (prefix, middle, suffix)
133    let mut a = rng.gen_range(0..len);
134    let mut b = rng.gen_range(0..len);
135    if a > b {
136        std::mem::swap(&mut a, &mut b);
137    }
138
139    // Align to char boundaries
140    let a = find_char_boundary(text, a);
141    let b = find_char_boundary(text, b);
142
143    let prefix = &text[..a];
144    let middle = &text[a..b];
145    let suffix = &text[b..];
146
147    match format {
148        FimFormat::PSM => {
149            format!(
150                "{}{}{}{}{}{}",
151                tokens.prefix, prefix, tokens.suffix, suffix, tokens.middle, middle
152            )
153        }
154        FimFormat::SPM => {
155            format!(
156                "{}{}{}{}{}{}",
157                tokens.suffix, suffix, tokens.prefix, prefix, tokens.middle, middle
158            )
159        }
160    }
161}
162
/// Return the smallest UTF-8 char boundary in `s` that is `>= byte_offset`.
///
/// Offsets beyond the end of `s` are clamped to `s.len()`, which always
/// counts as a boundary.
fn find_char_boundary(s: &str, byte_offset: usize) -> usize {
    let end = s.len();
    (byte_offset.min(end)..end)
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(end)
}
171
172impl Transform for Fim {
173    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
174        let schema = batch.schema();
175        let col_idx = schema
176            .index_of(&self.column)
177            .map_err(|_| Error::column_not_found(&self.column))?;
178
179        let col = batch
180            .column(col_idx)
181            .as_any()
182            .downcast_ref::<StringArray>()
183            .ok_or_else(|| {
184                Error::transform(format!(
185                    "Column '{}' must be Utf8 type for FIM transform",
186                    self.column
187                ))
188            })?;
189
190        let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
191        let transformed: Vec<Option<String>> = (0..col.len())
192            .map(|i| {
193                if col.is_null(i) {
194                    return None;
195                }
196                let text = col.value(i);
197                if text.len() < self.min_chars {
198                    return Some(text.to_string());
199                }
200                let apply_fim: bool = rng.gen_bool(self.rate);
201                if apply_fim {
202                    Some(apply_fim_to_text(text, self.format, &self.tokens, &mut rng))
203                } else {
204                    Some(text.to_string())
205                }
206            })
207            .collect();
208
209        let new_col = StringArray::from(transformed);
210        let mut columns: Vec<Arc<dyn arrow::array::Array>> = batch.columns().to_vec();
211        columns[col_idx] = Arc::new(new_col);
212        RecordBatch::try_new(schema, columns).map_err(Error::Arrow)
213    }
214}
215
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    fn create_code_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("code", DataType::Utf8, false)]));
        let code = StringArray::from(vec![
            "def hello():\n    print('hello world')\n",
            "class Foo:\n    def bar(self):\n        return 42\n",
            "x = 1",
        ]);
        RecordBatch::try_new(schema, vec![Arc::new(code)]).expect("batch creation should succeed")
    }

    /// Borrow column 0 of `batch` as a `StringArray`.
    fn string_col(batch: &RecordBatch) -> &StringArray {
        batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("should be string")
    }

    /// Split FIM output back into `(prefix, middle, suffix)` for either format.
    ///
    /// Assumes the original text contains none of the sentinel tokens.
    fn split_fim<'a>(
        s: &'a str,
        format: FimFormat,
        tokens: &FimTokens,
    ) -> (&'a str, &'a str, &'a str) {
        let (first_tok, second_tok) = match format {
            FimFormat::PSM => (tokens.prefix.as_str(), tokens.suffix.as_str()),
            FimFormat::SPM => (tokens.suffix.as_str(), tokens.prefix.as_str()),
        };
        let rest = s
            .strip_prefix(first_tok)
            .expect("output should start with the first sentinel");
        let (first, rest) = rest
            .split_once(second_tok)
            .expect("output should contain the second sentinel");
        let (second, middle) = rest
            .split_once(tokens.middle.as_str())
            .expect("output should contain the middle sentinel");
        match format {
            FimFormat::PSM => (first, middle, second),
            FimFormat::SPM => (second, middle, first),
        }
    }

    #[test]
    fn test_fim_psm_format() {
        let text = "def hello():\n    print('hello')";
        let tokens = FimTokens::default();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let result = apply_fim_to_text(text, FimFormat::PSM, &tokens, &mut rng);
        // PSM starts with prefix token
        assert!(result.starts_with("<|fim_prefix|>"));
        assert!(result.contains("<|fim_suffix|>"));
        assert!(result.contains("<|fim_middle|>"));
    }

    #[test]
    fn test_fim_spm_format() {
        let text = "def hello():\n    print('hello')";
        let tokens = FimTokens::default();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let result = apply_fim_to_text(text, FimFormat::SPM, &tokens, &mut rng);
        // SPM starts with suffix token
        assert!(result.starts_with("<|fim_suffix|>"));
    }

    #[test]
    fn test_fim_transform_applies_to_batch() {
        let batch = create_code_batch();
        let fim = Fim::new("code").with_rate(1.0).with_seed(42);
        let result = fim.apply(batch);
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert_eq!(result.num_rows(), 3);

        let col = string_col(&result);
        // First two rows should have FIM tokens (long enough)
        assert!(col.value(0).contains("<|fim_prefix|>"));
        assert!(col.value(1).contains("<|fim_prefix|>"));
        // Third row too short (5 chars < 10 min_chars default)
        assert_eq!(col.value(2), "x = 1");
    }

    #[test]
    fn test_fim_rate_zero_leaves_unchanged() {
        let batch = create_code_batch();
        let fim = Fim::new("code").with_rate(0.0).with_seed(42);
        let result = fim.apply(batch.clone()).expect("should succeed");
        let original = string_col(&batch);
        let transformed = string_col(&result);
        for i in 0..original.len() {
            assert_eq!(original.value(i), transformed.value(i));
        }
    }

    #[test]
    fn test_fim_deterministic_with_seed() {
        let batch = create_code_batch();
        let fim1 = Fim::new("code").with_rate(1.0).with_seed(123);
        let fim2 = Fim::new("code").with_rate(1.0).with_seed(123);
        let r1 = fim1.apply(batch.clone()).expect("should succeed");
        let r2 = fim2.apply(batch).expect("should succeed");
        let c1 = string_col(&r1);
        let c2 = string_col(&r2);
        for i in 0..c1.len() {
            assert_eq!(c1.value(i), c2.value(i));
        }
    }

    #[test]
    fn test_fim_wrong_column_errors() {
        let batch = create_code_batch();
        let fim = Fim::new("nonexistent");
        let result = fim.apply(batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_fim_custom_tokens() {
        let text = "def foo(): pass";
        let tokens = FimTokens {
            prefix: "<PRE>".to_string(),
            suffix: "<SUF>".to_string(),
            middle: "<MID>".to_string(),
        };
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let result = apply_fim_to_text(text, FimFormat::PSM, &tokens, &mut rng);
        assert!(result.contains("<PRE>"));
        assert!(result.contains("<SUF>"));
        assert!(result.contains("<MID>"));
    }

    #[test]
    fn test_fim_preserves_content() {
        // Reassembling prefix + middle + suffix must reproduce the input
        // exactly: same chars, same order, same multiplicity. The multi-byte
        // characters check that splits land on char boundaries.
        let text = "def héllo():\n    print('✓ héllo')";
        let tokens = FimTokens::default();
        for seed in 0..64 {
            for variant in [FimFormat::PSM, FimFormat::SPM] {
                let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
                let result = apply_fim_to_text(text, variant, &tokens, &mut rng);
                let (prefix, middle, suffix) = split_fim(&result, variant, &tokens);
                assert_eq!(
                    format!("{prefix}{middle}{suffix}"),
                    text,
                    "seed {seed}, format {variant:?}"
                );
            }
        }
    }

    #[test]
    fn test_find_char_boundary() {
        let s = "hello";
        assert_eq!(find_char_boundary(s, 0), 0);
        assert_eq!(find_char_boundary(s, 3), 3);
        assert_eq!(find_char_boundary(s, 10), 5);
    }

    #[test]
    fn test_find_char_boundary_multibyte() {
        let s = "héllo"; // é is 2 bytes (offsets 1..3)
        let boundary = find_char_boundary(s, 2);
        assert!(s.is_char_boundary(boundary));
        assert_eq!(boundary, 3);
    }
}