/// Numeric ids of the special tokens involved in fill-in-the-middle (FIM) prompts.
///
/// NOTE(review): the field names suggest these ids correspond to the
/// `<fim_prefix>`, `<fim_middle>` and `<fim_suffix>` text markers plus a padding
/// token; the actual mapping depends on the tokenizer vocabulary — confirm
/// before relying on it.
///
/// The type is a plain bundle of four `u32` values, so it is `Copy` and `Hash`
/// in addition to being comparable and printable for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FimTokens {
    /// Id of the prefix token.
    pub prefix_id: u32,
    /// Id of the middle token.
    pub middle_id: u32,
    /// Id of the suffix token.
    pub suffix_id: u32,
    /// Id of the padding token.
    pub pad_id: u32,
}
impl Default for FimTokens {
fn default() -> Self {
Self {
prefix_id: 1,
middle_id: 2,
suffix_id: 3,
pad_id: 4,
}
}
}
/// Builds a fill-in-the-middle prompt from the text before and after the gap.
///
/// The result is `<fim_prefix>`, then `prefix`, then `<fim_suffix>`, then
/// `suffix`, then a trailing `<fim_middle>` marker. Both arguments are inserted
/// verbatim; no escaping or trimming is performed.
pub fn format_fim_prompt(prefix: &str, suffix: &str) -> String {
    let pieces = ["<fim_prefix>", prefix, "<fim_suffix>", suffix, "<fim_middle>"];
    pieces.concat()
}
/// Extracts the generated middle section from raw model output.
///
/// Returns the text that follows the first `<fim_middle>` marker, cut off at the
/// first `<|endoftext|>` that appears after that marker. If no `<|endoftext|>`
/// follows, everything after the marker is returned.
///
/// Returns `None` when `output` contains no `<fim_middle>` marker.
pub fn parse_fim_output(output: &str) -> Option<String> {
    const MIDDLE_MARKER: &str = "<fim_middle>";
    const END_OF_TEXT: &str = "<|endoftext|>";

    // Everything before the first marker (the echoed prompt) is discarded.
    let (_, completion) = output.split_once(MIDDLE_MARKER)?;
    let middle = completion
        .split_once(END_OF_TEXT)
        .map_or(completion, |(before_eot, _)| before_eot);
    Some(middle.to_owned())
}
#[cfg(test)]
mod fim_tests {
    use super::*;

    /// The default token table carries ids 1 through 4.
    #[test]
    fn test_fim_tokens_default() {
        let expected = FimTokens {
            prefix_id: 1,
            middle_id: 2,
            suffix_id: 3,
            pad_id: 4,
        };
        assert_eq!(FimTokens::default(), expected);
    }

    /// A prompt opens with the prefix section, contains the suffix section,
    /// and closes with the middle marker.
    #[test]
    fn test_format_fim_prompt_basic() {
        let rendered = format_fim_prompt("def foo():", " return 42");
        assert!(
            rendered.starts_with("<fim_prefix>def foo():"),
            "unexpected start: {rendered}"
        );
        assert!(
            rendered.contains("<fim_suffix> return 42"),
            "suffix section missing: {rendered}"
        );
        assert!(
            rendered.ends_with("<fim_middle>"),
            "unexpected end: {rendered}"
        );
    }

    /// An empty suffix still produces all three markers.
    #[test]
    fn test_format_fim_prompt_empty_suffix() {
        assert_eq!(
            format_fim_prompt("hello", ""),
            "<fim_prefix>hello<fim_suffix><fim_middle>"
        );
    }

    /// Text between the middle marker and the end-of-text marker is returned.
    #[test]
    fn test_parse_fim_output_with_eot() {
        let parsed = parse_fim_output("<fim_prefix>x<fim_suffix>z<fim_middle>y<|endoftext|>");
        assert_eq!(parsed.as_deref(), Some("y"));
    }

    /// Without an end-of-text marker, everything after the middle marker is returned.
    #[test]
    fn test_parse_fim_output_without_eot() {
        let parsed = parse_fim_output("<fim_prefix>x<fim_suffix>z<fim_middle>some middle text");
        assert_eq!(parsed.as_deref(), Some("some middle text"));
    }

    /// Output lacking the middle marker yields `None`.
    #[test]
    fn test_parse_fim_output_no_marker() {
        let parsed = parse_fim_output("just plain text with no fim tokens");
        assert_eq!(parsed, None);
    }
}