1use std::collections::VecDeque;
2
3use llama_cpp_bindings_sys::llama_pos;
4use llama_cpp_bindings_sys::llama_seq_id;
5
6use llama_cpp_bindings_types::TokenUsage;
7use llama_cpp_bindings_types::TokenUsageError;
8
9use crate::batch_add_error::BatchAddError;
10use crate::context::LlamaContext;
11use crate::error::EvalMultimodalChunksError;
12use crate::error::SampleError;
13use crate::llama_batch::LlamaBatch;
14use crate::model::LlamaModel;
15use crate::mtmd::MtmdContext;
16use crate::mtmd::MtmdInputChunks;
17use crate::sampled_token::SampledToken;
18use crate::sampling::LlamaSampler;
19use crate::streaming_json_probe::JsonProbeOutcome;
20use crate::token::LlamaToken;
21
/// Part of the model output that a sampled token is attributed to.
///
/// A classifier begins in [`SampledTokenSection::Pending`] and moves between the
/// other sections as it recognises marker token sequences in the stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SampledTokenSection {
    /// No section has been established yet; tokens are reported as undeterminable.
    Pending,
    /// Ordinary response content.
    Content,
    /// Tokens inside a reasoning section.
    Reasoning,
    /// Tokens inside a tool-call section.
    ToolCall,
}
29
/// Identifies one of the four configurable marker sequences in [`StreamingMarkers`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MarkerKind {
    /// Sequence that starts a reasoning section.
    ReasoningOpen,
    /// Sequence that ends a reasoning section.
    ReasoningClose,
    /// Sequence that starts a tool-call section.
    ToolCallOpen,
    /// Sequence that ends a tool-call section.
    ToolCallClose,
}
37
38#[derive(Clone, Debug, Default, Eq, PartialEq)]
44pub struct StreamingMarkers {
45 pub reasoning_open: Option<Vec<LlamaToken>>,
46 pub reasoning_close: Option<Vec<LlamaToken>>,
47 pub tool_call_open: Option<Vec<LlamaToken>>,
48 pub tool_call_close: Option<Vec<LlamaToken>>,
49}
50
51impl StreamingMarkers {
52 const fn has_any(&self) -> bool {
53 self.reasoning_open.is_some()
54 || self.reasoning_close.is_some()
55 || self.tool_call_open.is_some()
56 || self.tool_call_close.is_some()
57 }
58
59 fn max_token_len(&self) -> usize {
60 [
61 self.reasoning_open.as_deref(),
62 self.reasoning_close.as_deref(),
63 self.tool_call_open.as_deref(),
64 self.tool_call_close.as_deref(),
65 ]
66 .into_iter()
67 .flatten()
68 .map(<[LlamaToken]>::len)
69 .max()
70 .unwrap_or(0)
71 }
72
73 fn lookup(&self, kind: MarkerKind) -> Option<&[LlamaToken]> {
74 match kind {
75 MarkerKind::ReasoningOpen => self.reasoning_open.as_deref(),
76 MarkerKind::ReasoningClose => self.reasoning_close.as_deref(),
77 MarkerKind::ToolCallOpen => self.tool_call_open.as_deref(),
78 MarkerKind::ToolCallClose => self.tool_call_close.as_deref(),
79 }
80 }
81}
82
83#[derive(Clone, Debug)]
84pub struct IngestOutcome {
85 pub sampled_token: SampledToken,
86 pub visible_piece: String,
90 pub raw_piece: String,
94}
95
96#[derive(Clone, Debug)]
97struct PendingToken {
98 token: LlamaToken,
99 decoded: String,
100 section: SampledTokenSection,
101 is_boundary: bool,
102 is_from_prompt: bool,
103 is_held_for_probe: bool,
104}
105
/// Text accumulated by an active JSON tool-call probe.
#[derive(Clone, Debug)]
struct JsonProbeState {
    /// Concatenated decoded pieces of every token the probe currently holds.
    held_text: String,
}
110
111#[derive(Clone, Debug)]
112enum ProbeMode {
113 Idle,
114 Active(JsonProbeState),
115}
116
117pub struct SampledTokenClassifier<'model> {
118 model: &'model LlamaModel,
119 markers: StreamingMarkers,
120 decoder: encoding_rs::Decoder,
121 pending: VecDeque<PendingToken>,
122 section: SampledTokenSection,
123 pending_prompt_tokens: u64,
124 usage: TokenUsage,
125 probe_mode: ProbeMode,
126}
127
128impl<'model> SampledTokenClassifier<'model> {
129 #[must_use]
130 pub fn new(model: &'model LlamaModel, markers: StreamingMarkers) -> Self {
131 Self {
132 model,
133 markers,
134 decoder: encoding_rs::UTF_8.new_decoder(),
135 pending: VecDeque::new(),
136 section: SampledTokenSection::Pending,
137 pending_prompt_tokens: 0,
138 usage: TokenUsage::new(),
139 probe_mode: ProbeMode::Idle,
140 }
141 }
142
143 pub fn ingest(&mut self, token: LlamaToken) -> Vec<IngestOutcome> {
153 if !self.markers.has_any() {
154 self.usage.record_undeterminable_token();
155 let piece = self.decode(token);
156 return vec![IngestOutcome {
157 sampled_token: SampledToken::Undeterminable(token),
158 visible_piece: piece.clone(),
159 raw_piece: piece,
160 }];
161 }
162
163 let decoded = self.decode(token);
164 self.pending.push_back(PendingToken {
165 token,
166 decoded: decoded.clone(),
167 section: self.section,
168 is_boundary: false,
169 is_from_prompt: false,
170 is_held_for_probe: false,
171 });
172
173 self.try_consume_marker_at_tail();
174
175 let probe_was_active = matches!(self.probe_mode, ProbeMode::Active(_));
176 let mut outcomes = if probe_was_active && self.section_disengages_probe() {
177 self.abandon_probe()
178 } else {
179 self.update_probe(&decoded)
180 };
181
182 outcomes.extend(self.drain_overflow());
183 outcomes
184 }
185
186 const fn section_disengages_probe(&self) -> bool {
187 matches!(
188 self.section,
189 SampledTokenSection::ToolCall | SampledTokenSection::Reasoning
190 )
191 }
192
193 pub fn ingest_prompt_token(&mut self, token: LlamaToken) {
203 if !self.markers.has_any() {
204 return;
205 }
206
207 self.pending.push_back(PendingToken {
208 token,
209 decoded: String::new(),
210 section: self.section,
211 is_boundary: false,
212 is_from_prompt: true,
213 is_held_for_probe: false,
214 });
215
216 self.try_consume_marker_at_tail();
217 self.drain_overflow();
218 }
219
220 pub fn ingest_prompt_tokens(&mut self, tokens: &[LlamaToken]) {
221 if !self.markers.has_any() {
222 return;
223 }
224 for &token in tokens {
225 self.ingest_prompt_token(token);
226 }
227 }
228
229 pub fn flush(&mut self) -> Vec<IngestOutcome> {
233 self.probe_mode = ProbeMode::Idle;
234 let mut outcomes = Vec::with_capacity(self.pending.len());
235 while let Some(entry) = self.pending.pop_front() {
236 if entry.is_from_prompt {
237 continue;
238 }
239 outcomes.push(self.finalize_entry(entry));
240 }
241 outcomes
242 }
243
244 fn decode(&mut self, token: LlamaToken) -> String {
245 match self.model.token_to_piece(
246 &SampledToken::Content(token),
247 &mut self.decoder,
248 true,
249 None,
250 ) {
251 Ok(piece) => piece,
252 Err(detokenize_error) => {
253 tracing::debug!(
254 "token_to_piece failed during classification, dropping piece: {detokenize_error}"
255 );
256 String::new()
257 }
258 }
259 }
260
261 fn try_consume_marker_at_tail(&mut self) {
262 const PROBE_KINDS: &[MarkerKind] = &[
271 MarkerKind::ReasoningOpen,
272 MarkerKind::ReasoningClose,
273 MarkerKind::ToolCallOpen,
274 MarkerKind::ToolCallClose,
275 ];
276
277 for &kind in PROBE_KINDS {
278 let Some(marker) = self.markers.lookup(kind) else {
279 continue;
280 };
281 if marker.is_empty() || self.pending.len() < marker.len() {
282 continue;
283 }
284 let span_start = self.pending.len() - marker.len();
285 let matches = self
286 .pending
287 .iter()
288 .skip(span_start)
289 .zip(marker)
290 .all(|(entry, marker_token)| entry.token == *marker_token);
291 if matches {
292 self.mark_marker_span(span_start, kind);
293 return;
294 }
295 }
296 }
297
298 fn mark_marker_span(&mut self, span_start: usize, kind: MarkerKind) {
299 let next_section = match kind {
300 MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning,
301 MarkerKind::ReasoningClose | MarkerKind::ToolCallClose => SampledTokenSection::Content,
302 MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall,
303 };
304 let span_section = match kind {
314 MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning,
315 MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall,
316 MarkerKind::ReasoningClose => {
317 if self.section == SampledTokenSection::Reasoning {
318 SampledTokenSection::Reasoning
319 } else {
320 SampledTokenSection::Content
321 }
322 }
323 MarkerKind::ToolCallClose => {
324 if self.section == SampledTokenSection::ToolCall {
325 SampledTokenSection::ToolCall
326 } else {
327 SampledTokenSection::Content
328 }
329 }
330 };
331
332 for entry in self.pending.iter_mut().skip(span_start) {
333 entry.is_boundary = true;
334 entry.section = span_section;
335 }
336
337 self.section = next_section;
338 }
339
340 fn drain_overflow(&mut self) -> Vec<IngestOutcome> {
341 let lookback = self.markers.max_token_len().saturating_sub(1);
342 let mut outcomes = Vec::new();
343
344 loop {
345 let Some(front) = self.pending.front() else {
346 break;
347 };
348 if front.is_held_for_probe {
349 break;
350 }
351 let probe_held = self
352 .pending
353 .iter()
354 .filter(|entry| entry.is_held_for_probe)
355 .count();
356 let drainable = self.pending.len().saturating_sub(probe_held);
357 let beyond_lookback = drainable > lookback;
358 if !front.is_boundary && !beyond_lookback {
359 break;
360 }
361 let Some(entry) = self.pending.pop_front() else {
362 break;
363 };
364 if entry.is_from_prompt {
365 continue;
366 }
367 outcomes.push(self.finalize_entry(entry));
368 }
369
370 outcomes
371 }
372
373 fn update_probe(&mut self, piece: &str) -> Vec<IngestOutcome> {
374 let probe_active = matches!(self.probe_mode, ProbeMode::Active(_));
375 if !probe_active {
376 if !self.section_allows_probe_engagement() {
377 return Vec::new();
378 }
379 if !piece.trim_start().starts_with('{') {
380 return Vec::new();
381 }
382 if let Some(entry) = self.pending.back_mut() {
383 entry.is_held_for_probe = true;
384 }
385 self.probe_mode = ProbeMode::Active(JsonProbeState {
386 held_text: piece.to_owned(),
387 });
388 return self.evaluate_probe();
389 }
390
391 if let Some(entry) = self.pending.back_mut() {
392 entry.is_held_for_probe = true;
393 }
394 if let ProbeMode::Active(state) = &mut self.probe_mode {
395 state.held_text.push_str(piece);
396 }
397 self.evaluate_probe()
398 }
399
400 const fn section_allows_probe_engagement(&self) -> bool {
401 matches!(
402 self.section,
403 SampledTokenSection::Content | SampledTokenSection::Pending
404 )
405 }
406
407 fn evaluate_probe(&mut self) -> Vec<IngestOutcome> {
408 let outcome = match &self.probe_mode {
409 ProbeMode::Active(state) => JsonProbeOutcome::validate_prefix(&state.held_text),
410 ProbeMode::Idle => return Vec::new(),
411 };
412 match outcome {
413 JsonProbeOutcome::StillPossiblyValid => Vec::new(),
414 JsonProbeOutcome::CompletedValid => self.commit_probe_as_tool_call(),
415 JsonProbeOutcome::Failed => self.abandon_probe(),
416 }
417 }
418
419 fn commit_probe_as_tool_call(&mut self) -> Vec<IngestOutcome> {
420 if !matches!(self.probe_mode, ProbeMode::Active(_)) {
421 return Vec::new();
422 }
423 self.probe_mode = ProbeMode::Idle;
424 self.section = SampledTokenSection::Content;
425
426 let drained: Vec<_> = self.pending.drain(..).collect();
427 let mut outcomes = Vec::new();
428 for mut entry in drained {
429 if entry.is_held_for_probe {
430 entry.section = SampledTokenSection::ToolCall;
431 entry.is_held_for_probe = false;
432 if !entry.is_from_prompt {
433 outcomes.push(self.finalize_entry(entry));
434 }
435 } else {
436 self.pending.push_back(entry);
437 }
438 }
439 outcomes
440 }
441
442 fn abandon_probe(&mut self) -> Vec<IngestOutcome> {
443 if !matches!(self.probe_mode, ProbeMode::Active(_)) {
444 return Vec::new();
445 }
446 self.probe_mode = ProbeMode::Idle;
447
448 let drained: Vec<_> = self.pending.drain(..).collect();
449 let mut outcomes = Vec::new();
450 for mut entry in drained {
451 if entry.is_held_for_probe {
452 entry.is_held_for_probe = false;
453 if !entry.is_from_prompt {
454 outcomes.push(self.finalize_entry(entry));
455 }
456 } else {
457 self.pending.push_back(entry);
458 }
459 }
460 outcomes
461 }
462
463 fn finalize_entry(&mut self, entry: PendingToken) -> IngestOutcome {
464 let section = entry.section;
465 match section {
466 SampledTokenSection::Reasoning => self.usage.record_reasoning_token(),
467 SampledTokenSection::Content => self.usage.record_content_token(),
468 SampledTokenSection::ToolCall => self.usage.record_tool_call_token(),
469 SampledTokenSection::Pending => self.usage.record_undeterminable_token(),
470 }
471
472 let sampled_token = match section {
473 SampledTokenSection::Reasoning => SampledToken::Reasoning(entry.token),
474 SampledTokenSection::Content => SampledToken::Content(entry.token),
475 SampledTokenSection::ToolCall => SampledToken::ToolCall(entry.token),
476 SampledTokenSection::Pending => SampledToken::Undeterminable(entry.token),
477 };
478
479 let visible_piece = if entry.is_boundary {
480 String::new()
481 } else {
482 entry.decoded.clone()
483 };
484
485 IngestOutcome {
486 sampled_token,
487 visible_piece,
488 raw_piece: entry.decoded,
489 }
490 }
491
492 pub fn sample(
499 &mut self,
500 sampler: &mut LlamaSampler,
501 context: &LlamaContext,
502 idx: i32,
503 ) -> Result<(LlamaToken, Vec<IngestOutcome>), SampleError> {
504 let raw = sampler.sample(context, idx)?;
505 let outcomes = self.ingest(raw);
506
507 Ok((raw, outcomes))
508 }
509
510 pub fn feed_prompt_to_batch(
513 &mut self,
514 batch: &mut LlamaBatch,
515 token: LlamaToken,
516 position: llama_pos,
517 seq_ids: &[llama_seq_id],
518 logits: bool,
519 ) -> Result<(), BatchAddError> {
520 batch.add(&SampledToken::Content(token), position, seq_ids, logits)?;
521 self.ingest_prompt_token(token);
522 self.pending_prompt_tokens = self.pending_prompt_tokens.saturating_add(1);
523
524 Ok(())
525 }
526
527 pub fn feed_prompt_sequence_to_batch(
530 &mut self,
531 batch: &mut LlamaBatch,
532 tokens: &[LlamaToken],
533 seq_id: llama_seq_id,
534 logits_all: bool,
535 ) -> Result<(), BatchAddError> {
536 batch.add_sequence(tokens, seq_id, logits_all)?;
537 self.ingest_prompt_tokens(tokens);
538 self.pending_prompt_tokens = self
539 .pending_prompt_tokens
540 .saturating_add(tokens.len() as u64);
541
542 Ok(())
543 }
544
545 pub const fn commit_prompt_tokens(&mut self) -> u64 {
546 let promoted = self.pending_prompt_tokens;
547 self.usage.record_prompt_tokens(promoted);
548 self.pending_prompt_tokens = 0;
549
550 promoted
551 }
552
553 pub const fn discard_pending_prompt_tokens(&mut self) -> u64 {
554 let discarded = self.pending_prompt_tokens;
555 self.pending_prompt_tokens = 0;
556
557 discarded
558 }
559
560 #[must_use]
561 pub const fn pending_prompt_tokens(&self) -> u64 {
562 self.pending_prompt_tokens
563 }
564
565 #[expect(
573 clippy::too_many_arguments,
574 reason = "thin wrapper over MtmdInputChunks::eval_chunks; parameter shape mirrors the underlying API"
575 )]
576 pub fn eval_multimodal_chunks(
577 &mut self,
578 chunks: &MtmdInputChunks,
579 mtmd_ctx: &MtmdContext,
580 llama_ctx: &LlamaContext,
581 start_position: llama_pos,
582 seq_id: llama_seq_id,
583 n_batch: i32,
584 logits_last: bool,
585 ) -> Result<llama_pos, EvalMultimodalChunksError> {
586 let chunk_count = chunks.len();
587 let mut next_position = start_position;
591
592 for index in 0..chunk_count {
593 let chunk = chunks
594 .get(index)
595 .ok_or(EvalMultimodalChunksError::ChunkOutOfBounds(index))?;
596 let logits_for_this_chunk = logits_last && index + 1 == chunk_count;
597
598 next_position = chunk.eval_single(
599 mtmd_ctx,
600 llama_ctx,
601 next_position,
602 seq_id,
603 n_batch,
604 logits_for_this_chunk,
605 )?;
606 crate::ingest_prompt_chunk::ingest_prompt_chunk(self, &chunk)?;
607 }
608
609 Ok(next_position)
610 }
611
612 pub const fn record_prompt_tokens(&mut self, count: u64) {
613 self.usage.record_prompt_tokens(count);
614 }
615
616 pub const fn record_input_image_tokens(&mut self, count: u64) {
617 self.usage.record_input_image_tokens(count);
618 }
619
620 pub const fn record_input_audio_tokens(&mut self, count: u64) {
621 self.usage.record_input_audio_tokens(count);
622 }
623
624 pub const fn record_cached_prompt_tokens(&mut self, count: u64) -> Result<(), TokenUsageError> {
628 self.usage.record_cached_prompt_tokens(count)
629 }
630
631 #[must_use]
632 pub const fn usage(&self) -> &TokenUsage {
633 &self.usage
634 }
635
636 #[must_use]
637 pub fn into_usage(self) -> TokenUsage {
638 self.usage
639 }
640
641 #[must_use]
642 pub const fn current_section(&self) -> SampledTokenSection {
643 self.section
644 }
645
646 #[must_use]
647 pub const fn markers(&self) -> &StreamingMarkers {
648 &self.markers
649 }
650}
651
652#[cfg(test)]
653mod tests {
654 use super::IngestOutcome;
655 use super::PendingToken;
656 use super::ProbeMode;
657 use super::SampledTokenClassifier;
658 use super::SampledTokenSection;
659 use super::StreamingMarkers;
660 use crate::sampled_token::SampledToken;
661 use crate::token::LlamaToken;
662
663 fn token(id: i32) -> LlamaToken {
664 LlamaToken::new(id)
665 }
666
667 fn markers_with(
668 reasoning_open: Option<Vec<LlamaToken>>,
669 reasoning_close: Option<Vec<LlamaToken>>,
670 ) -> StreamingMarkers {
671 StreamingMarkers {
672 reasoning_open,
673 reasoning_close,
674 tool_call_open: None,
675 tool_call_close: None,
676 }
677 }
678
679 fn synthetic_classifier(markers: StreamingMarkers) -> SampledTokenClassifier<'static> {
683 SampledTokenClassifier {
684 model: unsafe { &*std::ptr::NonNull::<crate::model::LlamaModel>::dangling().as_ptr() },
685 markers,
686 decoder: encoding_rs::UTF_8.new_decoder(),
687 pending: std::collections::VecDeque::new(),
688 section: SampledTokenSection::Pending,
689 pending_prompt_tokens: 0,
690 usage: llama_cpp_bindings_types::TokenUsage::new(),
691 probe_mode: ProbeMode::Idle,
692 }
693 }
694
695 fn push_pending(classifier: &mut SampledTokenClassifier<'_>, token_id: i32, decoded: &str) {
696 classifier.pending.push_back(PendingToken {
697 token: token(token_id),
698 decoded: decoded.to_owned(),
699 section: classifier.section,
700 is_boundary: false,
701 is_from_prompt: false,
702 is_held_for_probe: false,
703 });
704 }
705
706 fn push_pending_from_prompt(classifier: &mut SampledTokenClassifier<'_>, token_id: i32) {
707 classifier.pending.push_back(PendingToken {
708 token: token(token_id),
709 decoded: String::new(),
710 section: classifier.section,
711 is_boundary: false,
712 is_from_prompt: true,
713 is_held_for_probe: false,
714 });
715 }
716
717 fn push_and_probe(
718 classifier: &mut SampledTokenClassifier<'_>,
719 token_id: i32,
720 decoded: &str,
721 ) -> Vec<IngestOutcome> {
722 push_pending(classifier, token_id, decoded);
723 classifier.try_consume_marker_at_tail();
724 let probe_was_active = matches!(classifier.probe_mode, ProbeMode::Active(_));
725 let mut outcomes = if probe_was_active && classifier.section_disengages_probe() {
726 classifier.abandon_probe()
727 } else {
728 classifier.update_probe(decoded)
729 };
730 outcomes.extend(classifier.drain_overflow());
731 outcomes
732 }
733
734 fn outcome_pieces(outcomes: &[IngestOutcome]) -> Vec<&str> {
735 outcomes
736 .iter()
737 .map(|outcome| outcome.visible_piece.as_str())
738 .collect()
739 }
740
741 fn outcome_sections(outcomes: &[IngestOutcome]) -> Vec<SampledTokenSection> {
742 outcomes
743 .iter()
744 .map(|outcome| match outcome.sampled_token {
745 SampledToken::Reasoning(_) => SampledTokenSection::Reasoning,
746 SampledToken::Content(_) => SampledTokenSection::Content,
747 SampledToken::ToolCall(_) => SampledTokenSection::ToolCall,
748 SampledToken::Undeterminable(_) => SampledTokenSection::Pending,
749 })
750 .collect()
751 }
752
753 #[test]
754 fn streaming_markers_with_no_markers_reports_none() {
755 let markers = StreamingMarkers::default();
756 assert!(!markers.has_any());
757 assert_eq!(markers.max_token_len(), 0);
758 }
759
760 #[test]
761 fn streaming_markers_max_token_len_takes_longest() {
762 let markers = StreamingMarkers {
763 reasoning_open: Some(vec![token(1)]),
764 reasoning_close: Some(vec![token(2), token(3), token(4)]),
765 tool_call_open: Some(vec![token(5), token(6)]),
766 tool_call_close: None,
767 };
768 assert_eq!(markers.max_token_len(), 3);
769 }
770
771 #[test]
772 fn single_token_close_marker_when_already_in_reasoning_emits_empty_piece_for_marker() {
773 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
774 let mut classifier = synthetic_classifier(markers);
775 classifier.section = SampledTokenSection::Reasoning;
776
777 push_pending(&mut classifier, 7, "step");
778 classifier.try_consume_marker_at_tail();
779 let mut outcomes = classifier.drain_overflow();
780
781 push_pending(&mut classifier, 200, "</think>");
782 classifier.try_consume_marker_at_tail();
783 outcomes.extend(classifier.drain_overflow());
784
785 push_pending(&mut classifier, 9, "Hi");
786 classifier.try_consume_marker_at_tail();
787 outcomes.extend(classifier.drain_overflow());
788
789 outcomes.extend(classifier.flush());
790
791 assert_eq!(
792 outcome_sections(&outcomes),
793 vec![
794 SampledTokenSection::Reasoning,
795 SampledTokenSection::Reasoning,
796 SampledTokenSection::Content,
797 ],
798 );
799 assert_eq!(outcome_pieces(&outcomes), vec!["step", "", "Hi"]);
800 assert_eq!(classifier.section, SampledTokenSection::Content);
801 }
802
803 #[test]
804 fn multi_token_close_marker_suppresses_every_marker_token() {
805 let markers = markers_with(
806 Some(vec![token(100)]),
807 Some(vec![token(200), token(201), token(202)]),
808 );
809 let mut classifier = synthetic_classifier(markers);
810 classifier.section = SampledTokenSection::Reasoning;
811
812 let mut outcomes = Vec::new();
813 for (id, decoded) in [(7, "r"), (200, "</"), (201, "thi"), (202, "nk>"), (9, "OK")] {
814 push_pending(&mut classifier, id, decoded);
815 classifier.try_consume_marker_at_tail();
816 outcomes.extend(classifier.drain_overflow());
817 }
818 outcomes.extend(classifier.flush());
819
820 assert_eq!(outcome_pieces(&outcomes), vec!["r", "", "", "", "OK"]);
821 assert_eq!(classifier.section, SampledTokenSection::Content);
822 }
823
824 #[test]
825 fn marker_prefix_that_diverges_does_not_suppress_buffered_tokens() {
826 let markers = markers_with(
827 Some(vec![token(100)]),
828 Some(vec![token(200), token(201), token(202)]),
829 );
830 let mut classifier = synthetic_classifier(markers);
831 classifier.section = SampledTokenSection::Reasoning;
832
833 let mut outcomes = Vec::new();
834 for (id, decoded) in [(7, "r"), (200, "a"), (201, "b"), (300, "x")] {
835 push_pending(&mut classifier, id, decoded);
836 classifier.try_consume_marker_at_tail();
837 outcomes.extend(classifier.drain_overflow());
838 }
839 outcomes.extend(classifier.flush());
840
841 assert_eq!(outcome_pieces(&outcomes), vec!["r", "a", "b", "x"]);
842 assert!(
843 outcomes
844 .iter()
845 .all(|outcome| matches!(outcome.sampled_token, SampledToken::Reasoning(_)))
846 );
847 assert_eq!(classifier.section, SampledTokenSection::Reasoning);
848 }
849
850 #[test]
851 fn open_then_close_back_to_back_emits_two_empty_pieces_around_zero_content() {
852 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
853 let mut classifier = synthetic_classifier(markers);
854 classifier.section = SampledTokenSection::Content;
855
856 let mut outcomes = Vec::new();
857 for (id, decoded) in [(100, "<think>"), (200, "</think>"), (9, "Hi")] {
858 push_pending(&mut classifier, id, decoded);
859 classifier.try_consume_marker_at_tail();
860 outcomes.extend(classifier.drain_overflow());
861 }
862 outcomes.extend(classifier.flush());
863
864 assert_eq!(
865 outcome_sections(&outcomes),
866 vec![
867 SampledTokenSection::Reasoning,
868 SampledTokenSection::Reasoning,
869 SampledTokenSection::Content,
870 ],
871 );
872 assert_eq!(outcome_pieces(&outcomes), vec!["", "", "Hi"]);
873 assert_eq!(classifier.section, SampledTokenSection::Content);
874 }
875
876 #[test]
877 fn spurious_reasoning_close_in_content_section_classifies_as_content() {
878 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
879 let mut classifier = synthetic_classifier(markers);
880 classifier.section = SampledTokenSection::Content;
881
882 push_pending(&mut classifier, 200, "</think>");
883 classifier.try_consume_marker_at_tail();
884 let outcomes = classifier.drain_overflow();
885
886 assert_eq!(
887 outcome_sections(&outcomes),
888 vec![SampledTokenSection::Content],
889 );
890 assert_eq!(classifier.section, SampledTokenSection::Content);
891 }
892
893 #[test]
894 fn spurious_tool_call_close_in_reasoning_section_classifies_as_tool_call() {
895 let markers = StreamingMarkers {
896 reasoning_open: Some(vec![token(100)]),
897 reasoning_close: Some(vec![token(200)]),
898 tool_call_open: Some(vec![token(300)]),
899 tool_call_close: Some(vec![token(400)]),
900 };
901 let mut classifier = synthetic_classifier(markers);
902 classifier.section = SampledTokenSection::ToolCall;
903
904 push_pending(&mut classifier, 400, "</tool_call>");
905 classifier.try_consume_marker_at_tail();
906 let outcomes = classifier.drain_overflow();
907
908 assert_eq!(
909 outcome_sections(&outcomes),
910 vec![SampledTokenSection::ToolCall],
911 );
912 assert_eq!(classifier.section, SampledTokenSection::Content);
913 }
914
915 #[test]
916 fn flush_drains_remaining_pending_at_eog() {
917 let markers = markers_with(
918 Some(vec![token(100)]),
919 Some(vec![token(200), token(201), token(202)]),
920 );
921 let mut classifier = synthetic_classifier(markers);
922 classifier.section = SampledTokenSection::Reasoning;
923
924 push_pending(&mut classifier, 7, "abc");
925 push_pending(&mut classifier, 200, "</");
926 push_pending(&mut classifier, 201, "th");
927
928 let outcomes = classifier.flush();
929
930 assert_eq!(outcome_pieces(&outcomes), vec!["abc", "</", "th"]);
931 assert!(classifier.pending.is_empty());
932 }
933
934 #[test]
935 fn no_markers_marks_each_token_undeterminable_with_visible_piece() {
936 let markers = StreamingMarkers::default();
937 let mut classifier = synthetic_classifier(markers);
938
939 push_pending(&mut classifier, 1, "h");
940 push_pending(&mut classifier, 2, "i");
941 let outcomes = classifier.flush();
942
943 assert_eq!(outcome_pieces(&outcomes), vec!["h", "i"]);
944 assert_eq!(
945 outcome_sections(&outcomes),
946 vec![SampledTokenSection::Pending, SampledTokenSection::Pending],
947 );
948 }
949
950 #[test]
951 fn ingest_prompt_tokens_without_markers_is_noop() {
952 let markers = StreamingMarkers::default();
953 let mut classifier = synthetic_classifier(markers);
954
955 push_pending_from_prompt(&mut classifier, 7);
956 push_pending_from_prompt(&mut classifier, 8);
957
958 assert_eq!(classifier.section, SampledTokenSection::Pending);
959 assert_eq!(classifier.usage().reasoning_tokens, 0);
960 assert_eq!(classifier.usage().content_tokens, 0);
961 assert_eq!(classifier.usage().tool_call_tokens, 0);
962 assert_eq!(classifier.usage().undeterminable_tokens, 0);
963 }
964
965 #[test]
966 fn ingest_prompt_tokens_through_open_close_pair_ends_in_content() {
967 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
968 let mut classifier = synthetic_classifier(markers);
969
970 for token_id in [100, 7, 200] {
971 push_pending_from_prompt(&mut classifier, token_id);
972 classifier.try_consume_marker_at_tail();
973 classifier.drain_overflow();
974 }
975
976 assert_eq!(classifier.section, SampledTokenSection::Content);
977 assert_eq!(classifier.usage().reasoning_tokens, 0);
978 assert_eq!(classifier.usage().content_tokens, 0);
979 assert_eq!(classifier.usage().tool_call_tokens, 0);
980 assert_eq!(classifier.usage().undeterminable_tokens, 0);
981 }
982
983 #[test]
984 fn ingest_prompt_tokens_through_open_only_ends_in_reasoning() {
985 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
986 let mut classifier = synthetic_classifier(markers);
987
988 for token_id in [100, 7] {
989 push_pending_from_prompt(&mut classifier, token_id);
990 classifier.try_consume_marker_at_tail();
991 classifier.drain_overflow();
992 }
993
994 assert_eq!(classifier.section, SampledTokenSection::Reasoning);
995 assert_eq!(classifier.usage().reasoning_tokens, 0);
996 assert_eq!(classifier.usage().content_tokens, 0);
997 }
998
999 #[test]
1000 fn ingest_prompt_tokens_does_not_record_usage() {
1001 let markers = markers_with(
1002 Some(vec![token(100)]),
1003 Some(vec![token(200), token(201), token(202)]),
1004 );
1005 let mut classifier = synthetic_classifier(markers);
1006
1007 for token_id in [100, 7, 8, 9, 200, 201, 202, 11] {
1008 push_pending_from_prompt(&mut classifier, token_id);
1009 classifier.try_consume_marker_at_tail();
1010 classifier.drain_overflow();
1011 }
1012 let drained = classifier.flush();
1013 assert!(drained.is_empty());
1014
1015 assert_eq!(classifier.usage().reasoning_tokens, 0);
1016 assert_eq!(classifier.usage().content_tokens, 0);
1017 assert_eq!(classifier.usage().tool_call_tokens, 0);
1018 assert_eq!(classifier.usage().undeterminable_tokens, 0);
1019 }
1020
1021 #[test]
1022 fn prompt_token_completing_marker_with_generated_token_is_suppressed_correctly() {
1023 let markers = markers_with(
1024 Some(vec![token(100)]),
1025 Some(vec![token(200), token(201), token(202)]),
1026 );
1027 let mut classifier = synthetic_classifier(markers);
1028 classifier.section = SampledTokenSection::Reasoning;
1029
1030 for token_id in [200, 201] {
1031 push_pending_from_prompt(&mut classifier, token_id);
1032 classifier.try_consume_marker_at_tail();
1033 classifier.drain_overflow();
1034 }
1035
1036 assert_eq!(classifier.section, SampledTokenSection::Reasoning);
1037 assert_eq!(classifier.pending.len(), 2);
1038
1039 classifier.pending.push_back(PendingToken {
1040 token: token(202),
1041 decoded: "k>".to_owned(),
1042 section: classifier.section,
1043 is_boundary: false,
1044 is_from_prompt: false,
1045 is_held_for_probe: false,
1046 });
1047 classifier.try_consume_marker_at_tail();
1048 let outcomes = classifier.drain_overflow();
1049
1050 assert_eq!(outcomes.len(), 1);
1051 assert!(matches!(
1052 outcomes[0].sampled_token,
1053 SampledToken::Reasoning(_)
1054 ));
1055 assert_eq!(outcomes[0].visible_piece, "");
1056 assert_eq!(outcomes[0].raw_piece, "k>");
1057
1058 assert_eq!(classifier.section, SampledTokenSection::Content);
1059 assert_eq!(classifier.usage().reasoning_tokens, 1);
1060 assert_eq!(classifier.usage().content_tokens, 0);
1061 }
1062
1063 #[test]
1064 fn ingest_prompt_tokens_with_multiple_round_trips_ends_in_content() {
1065 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
1066 let mut classifier = synthetic_classifier(markers);
1067
1068 for token_id in [100, 7, 200, 100, 8, 200] {
1070 push_pending_from_prompt(&mut classifier, token_id);
1071 classifier.try_consume_marker_at_tail();
1072 classifier.drain_overflow();
1073 }
1074
1075 assert_eq!(classifier.section, SampledTokenSection::Content);
1076 assert_eq!(classifier.usage().reasoning_tokens, 0);
1077 assert_eq!(classifier.usage().content_tokens, 0);
1078 assert_eq!(classifier.usage().tool_call_tokens, 0);
1079 assert_eq!(classifier.usage().undeterminable_tokens, 0);
1080 }
1081
1082 #[test]
1083 fn ingest_prompt_tokens_initial_section_is_always_pending() {
1084 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
1085 let classifier = synthetic_classifier(markers);
1086
1087 assert_eq!(classifier.section, SampledTokenSection::Pending);
1088 }
1089
1090 #[test]
1091 fn ingest_prompt_tokens_then_drain_for_generated_token_classifies_correctly() {
1092 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
1093 let mut classifier = synthetic_classifier(markers);
1094
1095 for token_id in [100, 7, 200] {
1097 push_pending_from_prompt(&mut classifier, token_id);
1098 classifier.try_consume_marker_at_tail();
1099 classifier.drain_overflow();
1100 }
1101
1102 assert_eq!(classifier.section, SampledTokenSection::Content);
1103 assert_eq!(classifier.usage().reasoning_tokens, 0);
1104 assert_eq!(classifier.usage().content_tokens, 0);
1105
1106 classifier.pending.push_back(PendingToken {
1110 token: token(50),
1111 decoded: "hi".to_owned(),
1112 section: classifier.section,
1113 is_boundary: false,
1114 is_from_prompt: false,
1115 is_held_for_probe: false,
1116 });
1117 classifier.try_consume_marker_at_tail();
1118 let outcomes = classifier.drain_overflow();
1119
1120 assert_eq!(outcomes.len(), 1);
1121 assert!(matches!(
1122 outcomes[0].sampled_token,
1123 SampledToken::Content(_)
1124 ));
1125 assert_eq!(outcomes[0].visible_piece, "hi");
1126 assert_eq!(classifier.usage().content_tokens, 1);
1127 assert_eq!(classifier.usage().reasoning_tokens, 0);
1128 assert_eq!(classifier.usage().undeterminable_tokens, 0);
1129 }
1130
1131 #[test]
1132 fn close_marker_in_content_section_is_suppressed_as_boundary() {
1133 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
1141 let mut classifier = synthetic_classifier(markers);
1142 classifier.section = SampledTokenSection::Content;
1143
1144 let mut outcomes = Vec::new();
1145 for (id, decoded) in [(7, "hi"), (200, "</think>"), (8, "ok")] {
1146 push_pending(&mut classifier, id, decoded);
1147 classifier.try_consume_marker_at_tail();
1148 outcomes.extend(classifier.drain_overflow());
1149 }
1150 outcomes.extend(classifier.flush());
1151
1152 assert_eq!(
1153 outcome_sections(&outcomes),
1154 vec![
1155 SampledTokenSection::Content,
1156 SampledTokenSection::Content,
1157 SampledTokenSection::Content,
1158 ],
1159 );
1160 assert_eq!(outcome_pieces(&outcomes), vec!["hi", "", "ok"]);
1163 assert_eq!(classifier.section, SampledTokenSection::Content);
1164 }
1165
1166 #[test]
1167 fn open_marker_in_reasoning_section_is_suppressed_as_boundary() {
1168 let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
1172 let mut classifier = synthetic_classifier(markers);
1173 classifier.section = SampledTokenSection::Reasoning;
1174
1175 let mut outcomes = Vec::new();
1176 for (id, decoded) in [(7, "step1"), (100, "<think>"), (8, "step2")] {
1177 push_pending(&mut classifier, id, decoded);
1178 classifier.try_consume_marker_at_tail();
1179 outcomes.extend(classifier.drain_overflow());
1180 }
1181 outcomes.extend(classifier.flush());
1182
1183 assert_eq!(outcome_pieces(&outcomes), vec!["step1", "", "step2"]);
1184 assert_eq!(classifier.section, SampledTokenSection::Reasoning);
1185 }
1186
1187 #[test]
1188 fn record_prompt_tokens_updates_usage() {
1189 let markers = markers_with(None, None);
1190 let mut classifier = synthetic_classifier(markers);
1191
1192 classifier.record_prompt_tokens(7);
1193
1194 assert_eq!(classifier.usage().prompt_tokens, 7);
1195 }
1196
1197 #[test]
1198 fn record_cached_prompt_tokens_updates_usage_when_under_limit() {
1199 let markers = markers_with(None, None);
1200 let mut classifier = synthetic_classifier(markers);
1201 classifier.record_prompt_tokens(10);
1202
1203 classifier.record_cached_prompt_tokens(3).unwrap();
1204
1205 assert_eq!(classifier.usage().cached_prompt_tokens, 3);
1206 }
1207
1208 #[test]
1209 fn record_cached_prompt_tokens_returns_error_when_over_prompt_total() {
1210 let markers = markers_with(None, None);
1211 let mut classifier = synthetic_classifier(markers);
1212 classifier.record_prompt_tokens(2);
1213
1214 let result = classifier.record_cached_prompt_tokens(5);
1215
1216 assert!(result.is_err());
1217 }
1218
1219 #[test]
1220 fn markers_accessor_returns_configured_markers() {
1221 let configured = markers_with(Some(vec![token(1)]), Some(vec![token(2)]));
1222 let classifier = synthetic_classifier(configured);
1223
1224 let returned = classifier.markers();
1225
1226 assert_eq!(returned.reasoning_open.as_deref(), Some(&[token(1)][..]));
1227 assert_eq!(returned.reasoning_close.as_deref(), Some(&[token(2)][..]));
1228 }
1229
1230 #[test]
1231 fn into_usage_consumes_classifier_and_yields_usage_snapshot() {
1232 let markers = markers_with(None, None);
1233 let mut classifier = synthetic_classifier(markers);
1234 classifier.record_prompt_tokens(11);
1235
1236 let usage = classifier.into_usage();
1237
1238 assert_eq!(usage.prompt_tokens, 11);
1239 }
1240
1241 #[test]
1242 fn spurious_tool_call_close_in_content_section_classifies_as_content() {
1243 let mut markers = markers_with(None, None);
1246 markers.tool_call_close = Some(vec![token(300)]);
1247 let mut classifier = synthetic_classifier(markers);
1248 classifier.section = SampledTokenSection::Content;
1249
1250 push_pending(&mut classifier, 300, "</tool_call>");
1251 classifier.try_consume_marker_at_tail();
1252 let outcomes = classifier.drain_overflow();
1253
1254 assert_eq!(
1255 outcome_sections(&outcomes),
1256 vec![SampledTokenSection::Content],
1257 );
1258 assert_eq!(classifier.section, SampledTokenSection::Content);
1259 }
1260
1261 fn markers_with_tool_call_open(tool_call_open: Vec<LlamaToken>) -> StreamingMarkers {
1262 StreamingMarkers {
1263 reasoning_open: None,
1264 reasoning_close: None,
1265 tool_call_open: Some(tool_call_open),
1266 tool_call_close: None,
1267 }
1268 }
1269
1270 fn feed_json_string(
1271 classifier: &mut SampledTokenClassifier<'_>,
1272 text: &str,
1273 starting_token_id: i32,
1274 ) -> Vec<IngestOutcome> {
1275 let mut outcomes = Vec::new();
1276 for (offset, ch) in text.char_indices() {
1277 let token_id = starting_token_id + i32::try_from(offset).unwrap_or(i32::MAX);
1278 let mut buffer = [0_u8; 4];
1279 let chunk = ch.encode_utf8(&mut buffer);
1280 outcomes.extend(push_and_probe(classifier, token_id, chunk));
1281 }
1282 outcomes
1283 }
1284
1285 #[test]
1286 fn json_probe_engages_when_first_non_whitespace_is_open_brace_in_content() {
1287 let markers = markers_with_tool_call_open(vec![token(900)]);
1288 let mut classifier = synthetic_classifier(markers);
1289 classifier.section = SampledTokenSection::Content;
1290
1291 push_and_probe(&mut classifier, 1, "{");
1292
1293 assert!(matches!(classifier.probe_mode, ProbeMode::Active(_)));
1294 }
1295
1296 #[test]
1297 fn json_probe_releases_tokens_as_tool_call_when_signature_matches() {
1298 let markers = markers_with_tool_call_open(vec![token(900)]);
1299 let mut classifier = synthetic_classifier(markers);
1300 classifier.section = SampledTokenSection::Content;
1301
1302 let outcomes = feed_json_string(&mut classifier, r#"{"name":"f","arguments":{}}"#, 100);
1303
1304 assert!(!outcomes.is_empty());
1305 assert!(
1306 outcomes
1307 .iter()
1308 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1309 "every emitted outcome should be ToolCall, got {:?}",
1310 outcome_sections(&outcomes),
1311 );
1312 assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
1313 }
1314
1315 #[test]
1316 fn json_probe_releases_tokens_as_content_when_signature_does_not_match() {
1317 let markers = markers_with_tool_call_open(vec![token(900)]);
1318 let mut classifier = synthetic_classifier(markers);
1319 classifier.section = SampledTokenSection::Content;
1320
1321 let outcomes = feed_json_string(&mut classifier, r#"{"foo":"bar"}"#, 100);
1322
1323 assert!(
1324 outcomes
1325 .iter()
1326 .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1327 "every emitted outcome should be Content, got {:?}",
1328 outcome_sections(&outcomes),
1329 );
1330 assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
1331 }
1332
1333 #[test]
1334 fn json_probe_releases_tokens_as_content_when_extra_top_level_key() {
1335 let markers = markers_with_tool_call_open(vec![token(900)]);
1336 let mut classifier = synthetic_classifier(markers);
1337 classifier.section = SampledTokenSection::Content;
1338
1339 let outcomes = feed_json_string(
1340 &mut classifier,
1341 r#"{"name":"f","arguments":{},"extra":1}"#,
1342 100,
1343 );
1344
1345 assert!(
1346 outcomes
1347 .iter()
1348 .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1349 );
1350 }
1351
1352 #[test]
1353 fn json_probe_releases_tokens_as_content_when_arguments_is_not_object() {
1354 let markers = markers_with_tool_call_open(vec![token(900)]);
1355 let mut classifier = synthetic_classifier(markers);
1356 classifier.section = SampledTokenSection::Content;
1357
1358 let outcomes = feed_json_string(&mut classifier, r#"{"name":"f","arguments":"hi"}"#, 100);
1359
1360 assert!(
1361 outcomes
1362 .iter()
1363 .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1364 );
1365 }
1366
1367 #[test]
1368 fn json_probe_handles_strings_with_quoted_braces_in_arguments() {
1369 let markers = markers_with_tool_call_open(vec![token(900)]);
1370 let mut classifier = synthetic_classifier(markers);
1371 classifier.section = SampledTokenSection::Content;
1372
1373 let outcomes = feed_json_string(
1374 &mut classifier,
1375 r#"{"name":"f","arguments":{"q":"a } b"}}"#,
1376 100,
1377 );
1378
1379 assert!(
1380 outcomes
1381 .iter()
1382 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1383 );
1384 }
1385
1386 #[test]
1387 fn json_probe_handles_escaped_quotes_in_string_values() {
1388 let markers = markers_with_tool_call_open(vec![token(900)]);
1389 let mut classifier = synthetic_classifier(markers);
1390 classifier.section = SampledTokenSection::Content;
1391
1392 let outcomes = feed_json_string(
1393 &mut classifier,
1394 r#"{"name":"f","arguments":{"q":"he said \"hi\""}}"#,
1395 100,
1396 );
1397
1398 assert!(
1399 outcomes
1400 .iter()
1401 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1402 );
1403 }
1404
1405 #[test]
1406 fn json_probe_handles_unicode_letters_in_strings() {
1407 let markers = markers_with_tool_call_open(vec![token(900)]);
1408 let mut classifier = synthetic_classifier(markers);
1409 classifier.section = SampledTokenSection::Content;
1410
1411 let outcomes = feed_json_string(
1412 &mut classifier,
1413 r#"{"name":"日本語","arguments":{"city":"パリ"}}"#,
1414 100,
1415 );
1416
1417 assert!(
1418 outcomes
1419 .iter()
1420 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1421 );
1422 }
1423
1424 #[test]
1425 fn json_probe_handles_nested_objects() {
1426 let markers = markers_with_tool_call_open(vec![token(900)]);
1427 let mut classifier = synthetic_classifier(markers);
1428 classifier.section = SampledTokenSection::Content;
1429
1430 let outcomes = feed_json_string(
1431 &mut classifier,
1432 r#"{"name":"f","arguments":{"a":{"b":{"c":1}}}}"#,
1433 100,
1434 );
1435
1436 assert!(
1437 outcomes
1438 .iter()
1439 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1440 );
1441 }
1442
1443 #[test]
1444 fn json_probe_handles_arrays_inside_arguments() {
1445 let markers = markers_with_tool_call_open(vec![token(900)]);
1446 let mut classifier = synthetic_classifier(markers);
1447 classifier.section = SampledTokenSection::Content;
1448
1449 let outcomes = feed_json_string(
1450 &mut classifier,
1451 r#"{"name":"f","arguments":{"items":[1,2,3]}}"#,
1452 100,
1453 );
1454
1455 assert!(
1456 outcomes
1457 .iter()
1458 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1459 );
1460 }
1461
1462 #[test]
1463 fn json_probe_does_not_engage_when_first_byte_is_close_brace() {
1464 let markers = markers_with_tool_call_open(vec![token(900)]);
1465 let mut classifier = synthetic_classifier(markers);
1466 classifier.section = SampledTokenSection::Content;
1467
1468 let outcomes = feed_json_string(&mut classifier, "}}", 100);
1469
1470 assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
1471 assert!(
1472 outcomes
1473 .iter()
1474 .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1475 );
1476 }
1477
1478 #[test]
1479 fn json_probe_does_not_engage_in_reasoning_section() {
1480 let markers = StreamingMarkers {
1481 reasoning_open: Some(vec![token(800)]),
1482 reasoning_close: Some(vec![token(801)]),
1483 tool_call_open: Some(vec![token(900)]),
1484 tool_call_close: None,
1485 };
1486 let mut classifier = synthetic_classifier(markers);
1487 classifier.section = SampledTokenSection::Reasoning;
1488
1489 push_and_probe(&mut classifier, 1, "{");
1490
1491 assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
1492 }
1493
1494 #[test]
1495 fn json_probe_does_not_engage_in_tool_call_section() {
1496 let markers = markers_with_tool_call_open(vec![token(900)]);
1497 let mut classifier = synthetic_classifier(markers);
1498 classifier.section = SampledTokenSection::ToolCall;
1499
1500 push_and_probe(&mut classifier, 1, "{");
1501
1502 assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
1503 }
1504
1505 #[test]
1506 fn marker_probe_takes_precedence_when_both_could_match() {
1507 let markers = markers_with_tool_call_open(vec![token(900)]);
1513 let mut classifier = synthetic_classifier(markers);
1514 classifier.section = SampledTokenSection::Content;
1515
1516 let mut outcomes = Vec::new();
1517 outcomes.extend(push_and_probe(&mut classifier, 1, "{"));
1518 outcomes.extend(push_and_probe(&mut classifier, 900, r#"""#));
1519
1520 assert_eq!(classifier.section, SampledTokenSection::ToolCall);
1521 assert_eq!(outcome_pieces(&outcomes), vec!["{", ""]);
1522 assert_eq!(
1523 outcome_sections(&outcomes),
1524 vec![SampledTokenSection::Content, SampledTokenSection::ToolCall],
1525 );
1526 }
1527
1528 #[test]
1529 fn json_probe_consumes_two_consecutive_objects_separately() {
1530 let markers = markers_with_tool_call_open(vec![token(900)]);
1531 let mut classifier = synthetic_classifier(markers);
1532 classifier.section = SampledTokenSection::Content;
1533
1534 let mut outcomes = Vec::new();
1535 outcomes.extend(feed_json_string(
1536 &mut classifier,
1537 r#"{"name":"a","arguments":{}}"#,
1538 100,
1539 ));
1540 outcomes.extend(feed_json_string(
1541 &mut classifier,
1542 r#"{"name":"b","arguments":{"x":1}}"#,
1543 200,
1544 ));
1545
1546 assert!(
1547 outcomes
1548 .iter()
1549 .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1550 "two consecutive markerless tool calls must both classify as ToolCall, got {:?}",
1551 outcome_sections(&outcomes),
1552 );
1553 }
1554
    /// Whitespace emitted before a markerless tool-call object must classify as `Content`, and
    /// the object itself as `ToolCall`. Every outcome must be one of those two kinds.
    #[test]
    fn json_probe_with_leading_whitespace_then_open_brace_classifies_whitespace_as_content_and_json_as_tool_call()
    {
        let markers = markers_with_tool_call_open(vec![token(900)]);
        let mut classifier = synthetic_classifier(markers);
        classifier.section = SampledTokenSection::Content;

        // NOTE(review): as rendered here, this literal starts with two whitespace chars (`\n` and
        // a single space), yet `content_count` is expected to be 3 below. Confirm whether the
        // literal really contains two spaces (whitespace may have been collapsed in this view) or
        // whether the expected count should be 2.
        let outcomes = feed_json_string(
            &mut classifier,
            "\n {\"name\":\"f\",\"arguments\":{}}",
            100,
        );

        let tool_call_count = outcomes
            .iter()
            .filter(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_)))
            .count();
        let content_count = outcomes
            .iter()
            .filter(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_)))
            .count();
        assert_eq!(
            content_count, 3,
            "leading `\\n ` should classify as content"
        );
        assert!(
            tool_call_count > 0,
            "the JSON object should classify as ToolCall",
        );
        // No outcome may fall outside the Content / ToolCall split.
        assert_eq!(content_count + tool_call_count, outcomes.len());
    }
1586
1587 #[test]
1588 fn json_probe_records_tool_call_token_usage_on_commit() {
1589 let markers = markers_with_tool_call_open(vec![token(900)]);
1590 let mut classifier = synthetic_classifier(markers);
1591 classifier.section = SampledTokenSection::Content;
1592
1593 let json = r#"{"name":"f","arguments":{}}"#;
1594 let outcomes = feed_json_string(&mut classifier, json, 100);
1595
1596 let emitted = outcomes.len();
1597 let usage = classifier.usage();
1598 assert_eq!(usage.tool_call_tokens, emitted as u64);
1599 assert_eq!(usage.content_tokens, 0);
1600 }
1601
1602 #[test]
1603 fn json_probe_records_content_token_usage_on_abandon() {
1604 let markers = markers_with_tool_call_open(vec![token(900)]);
1605 let mut classifier = synthetic_classifier(markers);
1606 classifier.section = SampledTokenSection::Content;
1607
1608 let json = r#"{"foo":"bar"}"#;
1609 let outcomes = feed_json_string(&mut classifier, json, 100);
1610
1611 let emitted = outcomes.len();
1612 let usage = classifier.usage();
1613 assert_eq!(usage.content_tokens, emitted as u64);
1614 assert_eq!(usage.tool_call_tokens, 0);
1615 }
1616
1617 #[test]
1618 fn flush_during_active_json_probe_releases_held_tokens_as_content() {
1619 let markers = markers_with_tool_call_open(vec![token(900)]);
1620 let mut classifier = synthetic_classifier(markers);
1621 classifier.section = SampledTokenSection::Content;
1622
1623 push_and_probe(&mut classifier, 1, "{");
1624 push_and_probe(&mut classifier, 2, r#""name""#);
1625 assert!(matches!(classifier.probe_mode, ProbeMode::Active(_)));
1626
1627 let outcomes = classifier.flush();
1628
1629 assert!(
1630 outcomes
1631 .iter()
1632 .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1633 "mid-probe flush must release held tokens as Content, got {:?}",
1634 outcome_sections(&outcomes),
1635 );
1636 assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
1637 }
1638}