@@ -581,11 +581,21 @@
self.top_p.map(|v| v as f64),
);
- let mut index_pos = 0usize;
+ let mut match_len = 0;
+ for (c, t) in self.cached_tokens.iter().zip(tokens.iter()) {
+ if c == t {
+ match_len += 1;
+ } else {
+ break;
+ }
+ }
+ if match_len == tokens.len() && match_len > 0 {
+ match_len -= 1;
+ }
+ let mut index_pos = match_len;
let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
let mut finish_reason = "length".to_string();
for _ in 0..self.max_tokens {
- let ctxt: &[u32] = if index_pos == 0 {
- tokens.as_slice()
+ let ctxt: &[u32] = if index_pos == match_len && match_len < tokens.len() {
+ &tokens[match_len..]
} else {
&tokens[tokens.len() - 1..]
@@ -624,6 +634,7 @@
}
}
+ self.cached_tokens = tokens.clone();
let duration = started_at.elapsed();
let tokens_per_sec = if duration.as_secs_f64() > 0.0 {
(generated.len() as f64) / duration.as_secs_f64()