use crate::native_bridge::NativeRuntimeHandle;
use crate::runtime::llama_token;
use crate::runtime::scheduler::{SlotPhase, SlotState};
#[cfg(test)]
#[path = "../../../tests/runtime/inference_runtime/text_tests.rs"]
mod text_tests;
#[inline]
pub(super) fn append_token_piece_to_slot(
native_runtime: &NativeRuntimeHandle,
token: llama_token,
slot: &mut SlotState,
piece_scratch: &mut Vec<u8>,
) {
if !token_to_piece_into(native_runtime, token, false, piece_scratch) {
slot.fail("Failed to convert sampled token to text piece.");
return;
}
slot.pending_utf8_bytes.extend_from_slice(piece_scratch);
let tail_len = incomplete_utf8_tail_length(&slot.pending_utf8_bytes);
let complete_len = slot.pending_utf8_bytes.len().saturating_sub(tail_len);
if complete_len > 0 {
let complete = String::from_utf8_lossy(&slot.pending_utf8_bytes[..complete_len]);
slot.pending_emission_text.push_str(&complete);
slot.pending_utf8_bytes.drain(..complete_len);
}
if !slot.pending_emission_text.is_empty() {
slot.buffered_output_text
.push_str(&std::mem::take(&mut slot.pending_emission_text));
}
}
pub(super) fn apply_stop_sequences_to_slot(slot: &mut SlotState) -> bool {
let (stop_index, max_hold) = if let Some(request) = slot.request() {
if request.stop.is_empty() {
return false;
}
let stops = &request.stop;
let stop_index =
earliest_stop_index_split(&slot.output_text, &slot.buffered_output_text, stops);
let max_hold = if stop_index.is_none() {
stops
.iter()
.map(|stop| stop.len().saturating_sub(1))
.max()
.unwrap_or(0)
} else {
0
};
(stop_index, max_hold)
} else {
return false;
};
if let Some(stop_index) = stop_index {
let output_len = slot.output_text.len();
if stop_index <= output_len {
slot.output_text.truncate(stop_index);
slot.buffered_output_text.clear();
} else {
truncate_to_char_boundary(&mut slot.buffered_output_text, stop_index - output_len);
}
slot.pending_emission_text.clear();
slot.pending_utf8_bytes.clear();
slot.phase = SlotPhase::Completed;
true
} else {
if max_hold > 0 && slot.buffered_output_text.len() > max_hold {
let raw_split = slot.buffered_output_text.len() - max_hold;
let split = floor_char_boundary(&slot.buffered_output_text, raw_split);
if split > 0 && split < slot.buffered_output_text.len() {
let held = slot.buffered_output_text.split_off(split);
slot.pending_emission_text.insert_str(0, &held);
}
}
false
}
}
pub(super) fn earliest_stop_index_split(
output: &str,
buffered: &str,
stops: &[String],
) -> Option<usize> {
stops
.iter()
.filter_map(|stop| find_stop_index_split(output, buffered, stop))
.min()
}
fn find_stop_index_split(output: &str, buffered: &str, stop: &str) -> Option<usize> {
if stop.is_empty() {
return None;
}
let output_len = output.len();
let suffix_len = stop.len().saturating_sub(1);
let suffix_start = floor_char_boundary(output, output_len.saturating_sub(suffix_len));
for (relative_start, _) in output[suffix_start..].char_indices() {
let start = suffix_start + relative_start;
if stop_matches_at_split(&output[start..], buffered, stop) {
return Some(start);
}
}
buffered.find(stop).map(|idx| output_len + idx)
}
fn stop_matches_at_split(output_tail: &str, buffered: &str, stop: &str) -> bool {
let stop_bytes = stop.as_bytes();
let output_bytes = output_tail.as_bytes();
let output_match_len = output_bytes.len().min(stop_bytes.len());
if output_bytes[..output_match_len] != stop_bytes[..output_match_len] {
return false;
}
output_match_len == stop_bytes.len()
|| buffered
.as_bytes()
.starts_with(&stop_bytes[output_match_len..])
}
pub(super) fn truncate_to_char_boundary(value: &mut String, max_len: usize) {
let boundary = floor_char_boundary(value, max_len.min(value.len()));
value.truncate(boundary);
}
pub(super) fn floor_char_boundary(value: &str, mut index: usize) -> usize {
index = index.min(value.len());
while index > 0 && !value.is_char_boundary(index) {
index -= 1;
}
index
}
pub(super) fn flush_pending_utf8(slot: &mut SlotState) {
if !slot.pending_emission_text.is_empty() {
slot.buffered_output_text
.push_str(&std::mem::take(&mut slot.pending_emission_text));
}
if slot.pending_utf8_bytes.is_empty() {
return;
}
let pending = String::from_utf8_lossy(&slot.pending_utf8_bytes);
slot.buffered_output_text.push_str(&pending);
slot.pending_utf8_bytes.clear();
}
#[inline]
pub(super) fn token_to_piece_into(
native_runtime: &NativeRuntimeHandle,
token: llama_token,
special: bool,
scratch: &mut Vec<u8>,
) -> bool {
if token < 0 {
scratch.clear();
return false;
}
native_runtime
.token_to_piece_bytes_into(token, special, scratch)
.is_ok()
}
pub(super) fn incomplete_utf8_tail_length(data: &[u8]) -> usize {
if data.is_empty() {
return 0;
}
let max_lookback = data.len().min(4);
for offset in 1..=max_lookback {
let byte = data[data.len() - offset];
if (byte & 0xC0) == 0x80 {
continue;
}
let expected = if (byte & 0x80) == 0 {
1
} else if (byte & 0xE0) == 0xC0 {
2
} else if (byte & 0xF0) == 0xE0 {
3
} else if (byte & 0xF8) == 0xF0 {
4
} else {
return 0;
};
if offset >= expected {
return 0;
}
return offset;
}
max_lookback
}