Skip to main content

llama_cpp_bindings/
sampled_token_classifier.rs

1use std::collections::VecDeque;
2
3use llama_cpp_bindings_sys::llama_pos;
4use llama_cpp_bindings_sys::llama_seq_id;
5
6use llama_cpp_bindings_types::TokenUsage;
7use llama_cpp_bindings_types::TokenUsageError;
8
9use crate::batch_add_error::BatchAddError;
10use crate::context::LlamaContext;
11use crate::error::EvalMultimodalChunksError;
12use crate::error::SampleError;
13use crate::llama_batch::LlamaBatch;
14use crate::model::LlamaModel;
15use crate::mtmd::MtmdContext;
16use crate::mtmd::MtmdInputChunks;
17use crate::sampled_token::SampledToken;
18use crate::sampling::LlamaSampler;
19use crate::streaming_json_probe::JsonProbeOutcome;
20use crate::streaming_markers::{MarkerKind, StreamingMarkers};
21use crate::token::LlamaToken;
22
23pub use crate::ingest_outcome::IngestOutcome;
24pub use crate::sampled_token_section::SampledTokenSection;
25
26#[derive(Clone, Debug)]
27struct PendingToken {
28    token: LlamaToken,
29    decoded: String,
30    section: SampledTokenSection,
31    is_boundary: bool,
32    is_from_prompt: bool,
33    is_held_for_probe: bool,
34}
35
/// Text accumulated by an engaged JSON probe, beginning with the piece that
/// opened it.
#[derive(Debug, Clone, PartialEq, Eq)]
struct JsonProbeState {
    /// Concatenation of every held token piece, in ingestion order.
    held_text: String,
}
40
41#[derive(Clone, Debug, Eq, PartialEq)]
42enum ProbeMode {
43    Idle,
44    Active(JsonProbeState),
45}
46
/// Streams sampled tokens and labels each one as reasoning, content, tool-call,
/// or undeterminable output, while accumulating [`TokenUsage`] counters.
///
/// Marker token sequences from [`StreamingMarkers`] switch the current section.
/// A short lookback buffer delays emission so multi-token markers can be
/// recognised (and their visible text suppressed) before any of their tokens
/// are returned. A streaming JSON probe can additionally report a `{`-initiated
/// run of content as tool-call tokens when the probe judges it complete and valid.
pub struct SampledTokenClassifier<'model> {
    // Model used to detokenise sampled tokens.
    model: &'model LlamaModel,
    // Reasoning / tool-call open and close marker sequences.
    markers: StreamingMarkers,
    // UTF-8 decoder passed to every `token_to_piece` call, so it persists across tokens.
    decoder: encoding_rs::Decoder,
    // Lookback buffer of tokens not yet finalised, oldest first.
    pending: VecDeque<PendingToken>,
    // Section assigned to the next ingested token.
    section: SampledTokenSection,
    // Prompt tokens added to a batch but not yet committed into `usage`.
    pending_prompt_tokens: u64,
    // Accumulated token counters.
    usage: TokenUsage,
    // State of the streaming JSON tool-call probe.
    probe_mode: ProbeMode,
}
57
58impl<'model> SampledTokenClassifier<'model> {
59    #[must_use]
60    pub fn new(model: &'model LlamaModel, markers: StreamingMarkers) -> Self {
61        Self {
62            model,
63            markers,
64            decoder: encoding_rs::UTF_8.new_decoder(),
65            pending: VecDeque::new(),
66            section: SampledTokenSection::Pending,
67            pending_prompt_tokens: 0,
68            usage: TokenUsage::new(),
69            probe_mode: ProbeMode::Idle,
70        }
71    }
72
73    pub fn ingest(&mut self, token: LlamaToken) -> Vec<IngestOutcome> {
74        if !self.markers.has_any() {
75            self.usage.record_undeterminable_token();
76            let piece = self.decode(token);
77            return vec![IngestOutcome {
78                sampled_token: SampledToken::Undeterminable(token),
79                visible_piece: piece.clone(),
80                raw_piece: piece,
81            }];
82        }
83
84        let decoded = self.decode(token);
85        self.pending.push_back(PendingToken {
86            token,
87            decoded: decoded.clone(),
88            section: self.section,
89            is_boundary: false,
90            is_from_prompt: false,
91            is_held_for_probe: false,
92        });
93
94        self.try_consume_marker_at_tail();
95
96        let probe_was_active = matches!(self.probe_mode, ProbeMode::Active(_));
97        let mut outcomes = if probe_was_active && self.section_disengages_probe() {
98            self.abandon_probe()
99        } else {
100            self.update_probe(&decoded)
101        };
102
103        outcomes.extend(self.drain_overflow());
104        outcomes
105    }
106
107    const fn section_disengages_probe(&self) -> bool {
108        matches!(
109            self.section,
110            SampledTokenSection::ToolCall | SampledTokenSection::Reasoning
111        )
112    }
113
114    pub fn ingest_prompt_token(&mut self, token: LlamaToken) {
115        if !self.markers.has_any() {
116            return;
117        }
118
119        self.pending.push_back(PendingToken {
120            token,
121            decoded: String::new(),
122            section: self.section,
123            is_boundary: false,
124            is_from_prompt: true,
125            is_held_for_probe: false,
126        });
127
128        self.try_consume_marker_at_tail();
129        self.drain_overflow();
130    }
131
132    pub fn ingest_prompt_tokens(&mut self, tokens: &[LlamaToken]) {
133        if !self.markers.has_any() {
134            return;
135        }
136        for &token in tokens {
137            self.ingest_prompt_token(token);
138        }
139    }
140
141    pub fn flush(&mut self) -> Vec<IngestOutcome> {
142        self.probe_mode = ProbeMode::Idle;
143        let mut outcomes = Vec::with_capacity(self.pending.len());
144        while let Some(entry) = self.pending.pop_front() {
145            if entry.is_from_prompt {
146                continue;
147            }
148            outcomes.push(self.finalize_entry(entry));
149        }
150        outcomes
151    }
152
153    fn decode(&mut self, token: LlamaToken) -> String {
154        match self.model.token_to_piece(
155            &SampledToken::Content(token),
156            &mut self.decoder,
157            true,
158            None,
159        ) {
160            Ok(piece) => piece,
161            Err(detokenize_error) => {
162                log::debug!(
163                    "token_to_piece failed during classification, dropping piece: {detokenize_error}",
164                );
165                String::new()
166            }
167        }
168    }
169
170    fn try_consume_marker_at_tail(&mut self) {
171        const PROBE_KINDS: &[MarkerKind] = &[
172            MarkerKind::ReasoningOpen,
173            MarkerKind::ReasoningClose,
174            MarkerKind::ToolCallOpen,
175            MarkerKind::ToolCallClose,
176        ];
177
178        for &kind in PROBE_KINDS {
179            let Some(marker) = self.markers.lookup(kind) else {
180                continue;
181            };
182            if marker.is_empty() || self.pending.len() < marker.len() {
183                continue;
184            }
185            let span_start = self.pending.len() - marker.len();
186            let matches = self
187                .pending
188                .iter()
189                .skip(span_start)
190                .zip(marker)
191                .all(|(entry, marker_token)| entry.token == *marker_token);
192            if matches {
193                self.mark_marker_span(span_start, kind);
194                return;
195            }
196        }
197    }
198
199    fn mark_marker_span(&mut self, span_start: usize, kind: MarkerKind) {
200        let next_section = match kind {
201            MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning,
202            MarkerKind::ReasoningClose | MarkerKind::ToolCallClose => SampledTokenSection::Content,
203            MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall,
204        };
205        let span_section = match kind {
206            MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning,
207            MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall,
208            MarkerKind::ReasoningClose => {
209                if self.section == SampledTokenSection::Reasoning {
210                    SampledTokenSection::Reasoning
211                } else {
212                    SampledTokenSection::Content
213                }
214            }
215            MarkerKind::ToolCallClose => {
216                if self.section == SampledTokenSection::ToolCall {
217                    SampledTokenSection::ToolCall
218                } else {
219                    SampledTokenSection::Content
220                }
221            }
222        };
223
224        for entry in self.pending.iter_mut().skip(span_start) {
225            entry.is_boundary = true;
226            entry.section = span_section;
227        }
228
229        self.section = next_section;
230    }
231
232    fn drain_overflow(&mut self) -> Vec<IngestOutcome> {
233        let lookback = self.markers.max_token_len().saturating_sub(1);
234        let mut outcomes = Vec::new();
235
236        while let Some(front) = self.pending.front() {
237            if front.is_held_for_probe {
238                break;
239            }
240            let probe_held = self
241                .pending
242                .iter()
243                .filter(|entry| entry.is_held_for_probe)
244                .count();
245            let drainable = self.pending.len().saturating_sub(probe_held);
246            let beyond_lookback = drainable > lookback;
247            if !front.is_boundary && !beyond_lookback {
248                break;
249            }
250            let Some(entry) = self.pending.pop_front() else {
251                break;
252            };
253            if entry.is_from_prompt {
254                continue;
255            }
256            outcomes.push(self.finalize_entry(entry));
257        }
258
259        outcomes
260    }
261
262    fn update_probe(&mut self, piece: &str) -> Vec<IngestOutcome> {
263        let probe_active = matches!(self.probe_mode, ProbeMode::Active(_));
264        if !probe_active {
265            if !self.section_allows_probe_engagement() {
266                return Vec::new();
267            }
268            if !piece.trim_start().starts_with('{') {
269                return Vec::new();
270            }
271            if let Some(entry) = self.pending.back_mut() {
272                entry.is_held_for_probe = true;
273            }
274            self.probe_mode = ProbeMode::Active(JsonProbeState {
275                held_text: piece.to_owned(),
276            });
277            return self.evaluate_probe();
278        }
279
280        if let Some(entry) = self.pending.back_mut() {
281            entry.is_held_for_probe = true;
282        }
283        if let ProbeMode::Active(state) = &mut self.probe_mode {
284            state.held_text.push_str(piece);
285        }
286        self.evaluate_probe()
287    }
288
289    const fn section_allows_probe_engagement(&self) -> bool {
290        matches!(
291            self.section,
292            SampledTokenSection::Content | SampledTokenSection::Pending
293        )
294    }
295
296    fn evaluate_probe(&mut self) -> Vec<IngestOutcome> {
297        let outcome = match &self.probe_mode {
298            ProbeMode::Active(state) => JsonProbeOutcome::validate_prefix(&state.held_text),
299            ProbeMode::Idle => return Vec::new(),
300        };
301        match outcome {
302            JsonProbeOutcome::StillPossiblyValid => Vec::new(),
303            JsonProbeOutcome::CompletedValid => self.commit_probe_as_tool_call(),
304            JsonProbeOutcome::Failed => self.abandon_probe(),
305        }
306    }
307
308    fn commit_probe_as_tool_call(&mut self) -> Vec<IngestOutcome> {
309        if !matches!(self.probe_mode, ProbeMode::Active(_)) {
310            return Vec::new();
311        }
312        self.probe_mode = ProbeMode::Idle;
313        self.section = SampledTokenSection::Content;
314
315        let drained: Vec<_> = self.pending.drain(..).collect();
316        let mut outcomes = Vec::new();
317        for mut entry in drained {
318            if entry.is_held_for_probe {
319                entry.section = SampledTokenSection::ToolCall;
320                entry.is_held_for_probe = false;
321                if !entry.is_from_prompt {
322                    outcomes.push(self.finalize_entry(entry));
323                }
324            } else {
325                self.pending.push_back(entry);
326            }
327        }
328        outcomes
329    }
330
331    fn abandon_probe(&mut self) -> Vec<IngestOutcome> {
332        if !matches!(self.probe_mode, ProbeMode::Active(_)) {
333            return Vec::new();
334        }
335        self.probe_mode = ProbeMode::Idle;
336
337        let drained: Vec<_> = self.pending.drain(..).collect();
338        let mut outcomes = Vec::new();
339        for mut entry in drained {
340            if entry.is_held_for_probe {
341                entry.is_held_for_probe = false;
342                if !entry.is_from_prompt {
343                    outcomes.push(self.finalize_entry(entry));
344                }
345            } else {
346                self.pending.push_back(entry);
347            }
348        }
349        outcomes
350    }
351
352    fn finalize_entry(&mut self, entry: PendingToken) -> IngestOutcome {
353        let section = entry.section;
354        match section {
355            SampledTokenSection::Reasoning => self.usage.record_reasoning_token(),
356            SampledTokenSection::Content => self.usage.record_content_token(),
357            SampledTokenSection::ToolCall => self.usage.record_tool_call_token(),
358            SampledTokenSection::Pending => self.usage.record_undeterminable_token(),
359        }
360
361        let sampled_token = match section {
362            SampledTokenSection::Reasoning => SampledToken::Reasoning(entry.token),
363            SampledTokenSection::Content => SampledToken::Content(entry.token),
364            SampledTokenSection::ToolCall => SampledToken::ToolCall(entry.token),
365            SampledTokenSection::Pending => SampledToken::Undeterminable(entry.token),
366        };
367
368        let visible_piece = if entry.is_boundary {
369            String::new()
370        } else {
371            entry.decoded.clone()
372        };
373
374        IngestOutcome {
375            sampled_token,
376            visible_piece,
377            raw_piece: entry.decoded,
378        }
379    }
380
381    /// # Errors
382    /// Forwards [`LlamaSampler::sample`] errors verbatim. Nothing is recorded on failure.
383    ///
384    /// Returns the raw sampled token (for downstream `batch.add` / `is_eog_token`
385    /// calls) alongside the outcomes that finalised this turn — see
386    /// [`Self::ingest`] for buffering semantics.
387    pub fn sample(
388        &mut self,
389        sampler: &mut LlamaSampler,
390        context: &LlamaContext,
391        idx: i32,
392    ) -> Result<(LlamaToken, Vec<IngestOutcome>), SampleError> {
393        let raw = sampler.sample(context, idx)?;
394        let outcomes = self.ingest(raw);
395
396        Ok((raw, outcomes))
397    }
398
399    /// # Errors
400    /// Forwards [`LlamaBatch::add`] errors verbatim. Nothing is staged on failure.
401    pub fn feed_prompt_to_batch(
402        &mut self,
403        batch: &mut LlamaBatch,
404        token: LlamaToken,
405        position: llama_pos,
406        seq_ids: &[llama_seq_id],
407        logits: bool,
408    ) -> Result<(), BatchAddError> {
409        batch.add(&SampledToken::Content(token), position, seq_ids, logits)?;
410        self.ingest_prompt_token(token);
411        self.pending_prompt_tokens = self.pending_prompt_tokens.saturating_add(1);
412
413        Ok(())
414    }
415
416    /// # Errors
417    /// Forwards [`LlamaBatch::add_sequence`] errors verbatim. Nothing is staged on failure.
418    pub fn feed_prompt_sequence_to_batch(
419        &mut self,
420        batch: &mut LlamaBatch,
421        tokens: &[LlamaToken],
422        seq_id: llama_seq_id,
423        logits_all: bool,
424    ) -> Result<(), BatchAddError> {
425        batch.add_sequence(tokens, seq_id, logits_all)?;
426        self.ingest_prompt_tokens(tokens);
427        self.pending_prompt_tokens = self
428            .pending_prompt_tokens
429            .saturating_add(tokens.len() as u64);
430
431        Ok(())
432    }
433
434    pub const fn commit_prompt_tokens(&mut self) -> u64 {
435        let promoted = self.pending_prompt_tokens;
436        self.usage.record_prompt_tokens(promoted);
437        self.pending_prompt_tokens = 0;
438
439        promoted
440    }
441
442    pub const fn discard_pending_prompt_tokens(&mut self) -> u64 {
443        let discarded = self.pending_prompt_tokens;
444        self.pending_prompt_tokens = 0;
445
446        discarded
447    }
448
449    #[must_use]
450    pub const fn pending_prompt_tokens(&self) -> u64 {
451        self.pending_prompt_tokens
452    }
453
454    /// # Errors
455    /// Returns [`EvalMultimodalChunksError::EvalFailed`] when the underlying
456    /// `eval_chunks` call fails (no counters move),
457    /// [`EvalMultimodalChunksError::UnknownChunkType`] when a chunk reports a
458    /// type unknown to this binding, or
459    /// [`EvalMultimodalChunksError::ChunkOutOfBounds`] when a valid index returns
460    /// `None` from `chunks.get`.
461    #[expect(
462        clippy::too_many_arguments,
463        reason = "thin wrapper over MtmdInputChunks::eval_chunks; parameter shape mirrors the underlying API"
464    )]
465    pub fn eval_multimodal_chunks(
466        &mut self,
467        chunks: &MtmdInputChunks,
468        mtmd_ctx: &MtmdContext,
469        llama_ctx: &LlamaContext,
470        start_position: llama_pos,
471        seq_id: llama_seq_id,
472        n_batch: i32,
473        logits_last: bool,
474    ) -> Result<llama_pos, EvalMultimodalChunksError> {
475        let chunk_count = chunks.len();
476        let mut next_position = start_position;
477
478        for index in 0..chunk_count {
479            let chunk = chunks
480                .get(index)
481                .ok_or(EvalMultimodalChunksError::ChunkOutOfBounds(index))?;
482            let logits_for_this_chunk = logits_last && index + 1 == chunk_count;
483
484            next_position = chunk.eval_single(
485                mtmd_ctx,
486                llama_ctx,
487                next_position,
488                seq_id,
489                n_batch,
490                logits_for_this_chunk,
491            )?;
492            crate::ingest_prompt_chunk::ingest_prompt_chunk(self, &chunk)?;
493        }
494
495        Ok(next_position)
496    }
497
498    pub const fn record_prompt_tokens(&mut self, count: u64) {
499        self.usage.record_prompt_tokens(count);
500    }
501
502    pub const fn record_input_image_tokens(&mut self, count: u64) {
503        self.usage.record_input_image_tokens(count);
504    }
505
506    pub const fn record_input_audio_tokens(&mut self, count: u64) {
507        self.usage.record_input_audio_tokens(count);
508    }
509
510    /// # Errors
511    /// Forwards [`TokenUsageError::CachedExceedsPrompt`] when the running cached total would
512    /// exceed the prompt total.
513    pub const fn record_cached_prompt_tokens(&mut self, count: u64) -> Result<(), TokenUsageError> {
514        self.usage.record_cached_prompt_tokens(count)
515    }
516
517    #[must_use]
518    pub const fn usage(&self) -> &TokenUsage {
519        &self.usage
520    }
521
522    #[must_use]
523    pub fn into_usage(self) -> TokenUsage {
524        self.usage
525    }
526
527    #[must_use]
528    pub const fn current_section(&self) -> SampledTokenSection {
529        self.section
530    }
531
532    #[must_use]
533    pub const fn markers(&self) -> &StreamingMarkers {
534        &self.markers
535    }
536}
537
538#[cfg(test)]
539mod tests {
540    use super::PendingToken;
541    use super::ProbeMode;
542    use super::SampledTokenClassifier;
543    use crate::ingest_outcome::IngestOutcome;
544    use crate::sampled_token::SampledToken;
545    use crate::sampled_token_section::SampledTokenSection;
546    use crate::streaming_markers::StreamingMarkers;
547    use crate::token::LlamaToken;
548
549    fn token(id: i32) -> LlamaToken {
550        LlamaToken::new(id)
551    }
552
553    fn markers_with(
554        reasoning_open: Option<Vec<LlamaToken>>,
555        reasoning_close: Option<Vec<LlamaToken>>,
556    ) -> StreamingMarkers {
557        StreamingMarkers {
558            reasoning_open,
559            reasoning_close,
560            tool_call_open: None,
561            tool_call_close: None,
562        }
563    }
564
565    fn synthetic_classifier(markers: StreamingMarkers) -> SampledTokenClassifier<'static> {
566        SampledTokenClassifier {
567            model: unsafe { &*std::ptr::NonNull::<crate::model::LlamaModel>::dangling().as_ptr() },
568            markers,
569            decoder: encoding_rs::UTF_8.new_decoder(),
570            pending: std::collections::VecDeque::new(),
571            section: SampledTokenSection::Pending,
572            pending_prompt_tokens: 0,
573            usage: llama_cpp_bindings_types::TokenUsage::new(),
574            probe_mode: ProbeMode::Idle,
575        }
576    }
577
578    fn push_pending(classifier: &mut SampledTokenClassifier<'_>, token_id: i32, decoded: &str) {
579        classifier.pending.push_back(PendingToken {
580            token: token(token_id),
581            decoded: decoded.to_owned(),
582            section: classifier.section,
583            is_boundary: false,
584            is_from_prompt: false,
585            is_held_for_probe: false,
586        });
587    }
588
589    fn push_pending_from_prompt(classifier: &mut SampledTokenClassifier<'_>, token_id: i32) {
590        classifier.pending.push_back(PendingToken {
591            token: token(token_id),
592            decoded: String::new(),
593            section: classifier.section,
594            is_boundary: false,
595            is_from_prompt: true,
596            is_held_for_probe: false,
597        });
598    }
599
600    fn push_and_probe(
601        classifier: &mut SampledTokenClassifier<'_>,
602        token_id: i32,
603        decoded: &str,
604    ) -> Vec<IngestOutcome> {
605        push_pending(classifier, token_id, decoded);
606        classifier.try_consume_marker_at_tail();
607        let probe_was_active = matches!(classifier.probe_mode, ProbeMode::Active(_));
608        let mut outcomes = if probe_was_active && classifier.section_disengages_probe() {
609            classifier.abandon_probe()
610        } else {
611            classifier.update_probe(decoded)
612        };
613        outcomes.extend(classifier.drain_overflow());
614        outcomes
615    }
616
617    fn outcome_pieces(outcomes: &[IngestOutcome]) -> Vec<&str> {
618        outcomes
619            .iter()
620            .map(|outcome| outcome.visible_piece.as_str())
621            .collect()
622    }
623
624    fn outcome_sections(outcomes: &[IngestOutcome]) -> Vec<SampledTokenSection> {
625        outcomes
626            .iter()
627            .map(|outcome| match outcome.sampled_token {
628                SampledToken::Reasoning(_) => SampledTokenSection::Reasoning,
629                SampledToken::Content(_) => SampledTokenSection::Content,
630                SampledToken::ToolCall(_) => SampledTokenSection::ToolCall,
631                SampledToken::Undeterminable(_) => SampledTokenSection::Pending,
632            })
633            .collect()
634    }
635
636    #[test]
637    fn single_token_close_marker_when_already_in_reasoning_emits_empty_piece_for_marker() {
638        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
639        let mut classifier = synthetic_classifier(markers);
640        classifier.section = SampledTokenSection::Reasoning;
641
642        push_pending(&mut classifier, 7, "step");
643        classifier.try_consume_marker_at_tail();
644        let mut outcomes = classifier.drain_overflow();
645
646        push_pending(&mut classifier, 200, "</think>");
647        classifier.try_consume_marker_at_tail();
648        outcomes.extend(classifier.drain_overflow());
649
650        push_pending(&mut classifier, 9, "Hi");
651        classifier.try_consume_marker_at_tail();
652        outcomes.extend(classifier.drain_overflow());
653
654        outcomes.extend(classifier.flush());
655
656        assert_eq!(
657            outcome_sections(&outcomes),
658            vec![
659                SampledTokenSection::Reasoning,
660                SampledTokenSection::Reasoning,
661                SampledTokenSection::Content,
662            ],
663        );
664        assert_eq!(outcome_pieces(&outcomes), vec!["step", "", "Hi"]);
665        assert_eq!(classifier.section, SampledTokenSection::Content);
666    }
667
668    #[test]
669    fn multi_token_close_marker_suppresses_every_marker_token() {
670        let markers = markers_with(
671            Some(vec![token(100)]),
672            Some(vec![token(200), token(201), token(202)]),
673        );
674        let mut classifier = synthetic_classifier(markers);
675        classifier.section = SampledTokenSection::Reasoning;
676
677        let mut outcomes = Vec::new();
678        for (id, decoded) in [(7, "r"), (200, "</"), (201, "thi"), (202, "nk>"), (9, "OK")] {
679            push_pending(&mut classifier, id, decoded);
680            classifier.try_consume_marker_at_tail();
681            outcomes.extend(classifier.drain_overflow());
682        }
683        outcomes.extend(classifier.flush());
684
685        assert_eq!(outcome_pieces(&outcomes), vec!["r", "", "", "", "OK"]);
686        assert_eq!(classifier.section, SampledTokenSection::Content);
687    }
688
689    #[test]
690    fn marker_prefix_that_diverges_does_not_suppress_buffered_tokens() {
691        let markers = markers_with(
692            Some(vec![token(100)]),
693            Some(vec![token(200), token(201), token(202)]),
694        );
695        let mut classifier = synthetic_classifier(markers);
696        classifier.section = SampledTokenSection::Reasoning;
697
698        let mut outcomes = Vec::new();
699        for (id, decoded) in [(7, "r"), (200, "a"), (201, "b"), (300, "x")] {
700            push_pending(&mut classifier, id, decoded);
701            classifier.try_consume_marker_at_tail();
702            outcomes.extend(classifier.drain_overflow());
703        }
704        outcomes.extend(classifier.flush());
705
706        assert_eq!(outcome_pieces(&outcomes), vec!["r", "a", "b", "x"]);
707        assert!(
708            outcomes
709                .iter()
710                .all(|outcome| matches!(outcome.sampled_token, SampledToken::Reasoning(_)))
711        );
712        assert_eq!(classifier.section, SampledTokenSection::Reasoning);
713    }
714
715    #[test]
716    fn open_then_close_back_to_back_emits_two_empty_pieces_around_zero_content() {
717        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
718        let mut classifier = synthetic_classifier(markers);
719        classifier.section = SampledTokenSection::Content;
720
721        let mut outcomes = Vec::new();
722        for (id, decoded) in [(100, "<think>"), (200, "</think>"), (9, "Hi")] {
723            push_pending(&mut classifier, id, decoded);
724            classifier.try_consume_marker_at_tail();
725            outcomes.extend(classifier.drain_overflow());
726        }
727        outcomes.extend(classifier.flush());
728
729        assert_eq!(
730            outcome_sections(&outcomes),
731            vec![
732                SampledTokenSection::Reasoning,
733                SampledTokenSection::Reasoning,
734                SampledTokenSection::Content,
735            ],
736        );
737        assert_eq!(outcome_pieces(&outcomes), vec!["", "", "Hi"]);
738        assert_eq!(classifier.section, SampledTokenSection::Content);
739    }
740
741    #[test]
742    fn spurious_reasoning_close_in_content_section_classifies_as_content() {
743        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
744        let mut classifier = synthetic_classifier(markers);
745        classifier.section = SampledTokenSection::Content;
746
747        push_pending(&mut classifier, 200, "</think>");
748        classifier.try_consume_marker_at_tail();
749        let outcomes = classifier.drain_overflow();
750
751        assert_eq!(
752            outcome_sections(&outcomes),
753            vec![SampledTokenSection::Content],
754        );
755        assert_eq!(classifier.section, SampledTokenSection::Content);
756    }
757
758    #[test]
759    fn spurious_tool_call_close_in_reasoning_section_classifies_as_tool_call() {
760        let markers = StreamingMarkers {
761            reasoning_open: Some(vec![token(100)]),
762            reasoning_close: Some(vec![token(200)]),
763            tool_call_open: Some(vec![token(300)]),
764            tool_call_close: Some(vec![token(400)]),
765        };
766        let mut classifier = synthetic_classifier(markers);
767        classifier.section = SampledTokenSection::ToolCall;
768
769        push_pending(&mut classifier, 400, "</tool_call>");
770        classifier.try_consume_marker_at_tail();
771        let outcomes = classifier.drain_overflow();
772
773        assert_eq!(
774            outcome_sections(&outcomes),
775            vec![SampledTokenSection::ToolCall],
776        );
777        assert_eq!(classifier.section, SampledTokenSection::Content);
778    }
779
780    #[test]
781    fn flush_drains_remaining_pending_at_eog() {
782        let markers = markers_with(
783            Some(vec![token(100)]),
784            Some(vec![token(200), token(201), token(202)]),
785        );
786        let mut classifier = synthetic_classifier(markers);
787        classifier.section = SampledTokenSection::Reasoning;
788
789        push_pending(&mut classifier, 7, "abc");
790        push_pending(&mut classifier, 200, "</");
791        push_pending(&mut classifier, 201, "th");
792
793        let outcomes = classifier.flush();
794
795        assert_eq!(outcome_pieces(&outcomes), vec!["abc", "</", "th"]);
796        assert!(classifier.pending.is_empty());
797    }
798
799    #[test]
800    fn no_markers_marks_each_token_undeterminable_with_visible_piece() {
801        let markers = StreamingMarkers::default();
802        let mut classifier = synthetic_classifier(markers);
803
804        push_pending(&mut classifier, 1, "h");
805        push_pending(&mut classifier, 2, "i");
806        let outcomes = classifier.flush();
807
808        assert_eq!(outcome_pieces(&outcomes), vec!["h", "i"]);
809        assert_eq!(
810            outcome_sections(&outcomes),
811            vec![SampledTokenSection::Pending, SampledTokenSection::Pending],
812        );
813    }
814
815    #[test]
816    fn ingest_prompt_tokens_without_markers_is_noop() {
817        let markers = StreamingMarkers::default();
818        let mut classifier = synthetic_classifier(markers);
819
820        push_pending_from_prompt(&mut classifier, 7);
821        push_pending_from_prompt(&mut classifier, 8);
822
823        assert_eq!(classifier.section, SampledTokenSection::Pending);
824        assert_eq!(classifier.usage().reasoning_tokens, 0);
825        assert_eq!(classifier.usage().content_tokens, 0);
826        assert_eq!(classifier.usage().tool_call_tokens, 0);
827        assert_eq!(classifier.usage().undeterminable_tokens, 0);
828    }
829
830    #[test]
831    fn ingest_prompt_tokens_through_open_close_pair_ends_in_content() {
832        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
833        let mut classifier = synthetic_classifier(markers);
834
835        for token_id in [100, 7, 200] {
836            push_pending_from_prompt(&mut classifier, token_id);
837            classifier.try_consume_marker_at_tail();
838            classifier.drain_overflow();
839        }
840
841        assert_eq!(classifier.section, SampledTokenSection::Content);
842        assert_eq!(classifier.usage().reasoning_tokens, 0);
843        assert_eq!(classifier.usage().content_tokens, 0);
844        assert_eq!(classifier.usage().tool_call_tokens, 0);
845        assert_eq!(classifier.usage().undeterminable_tokens, 0);
846    }
847
848    #[test]
849    fn ingest_prompt_tokens_through_open_only_ends_in_reasoning() {
850        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
851        let mut classifier = synthetic_classifier(markers);
852
853        for token_id in [100, 7] {
854            push_pending_from_prompt(&mut classifier, token_id);
855            classifier.try_consume_marker_at_tail();
856            classifier.drain_overflow();
857        }
858
859        assert_eq!(classifier.section, SampledTokenSection::Reasoning);
860        assert_eq!(classifier.usage().reasoning_tokens, 0);
861        assert_eq!(classifier.usage().content_tokens, 0);
862    }
863
864    #[test]
865    fn ingest_prompt_tokens_does_not_record_usage() {
866        let markers = markers_with(
867            Some(vec![token(100)]),
868            Some(vec![token(200), token(201), token(202)]),
869        );
870        let mut classifier = synthetic_classifier(markers);
871
872        for token_id in [100, 7, 8, 9, 200, 201, 202, 11] {
873            push_pending_from_prompt(&mut classifier, token_id);
874            classifier.try_consume_marker_at_tail();
875            classifier.drain_overflow();
876        }
877        let drained = classifier.flush();
878        assert!(drained.is_empty());
879
880        assert_eq!(classifier.usage().reasoning_tokens, 0);
881        assert_eq!(classifier.usage().content_tokens, 0);
882        assert_eq!(classifier.usage().tool_call_tokens, 0);
883        assert_eq!(classifier.usage().undeterminable_tokens, 0);
884    }
885
886    #[test]
887    fn prompt_token_completing_marker_with_generated_token_is_suppressed_correctly() {
888        let markers = markers_with(
889            Some(vec![token(100)]),
890            Some(vec![token(200), token(201), token(202)]),
891        );
892        let mut classifier = synthetic_classifier(markers);
893        classifier.section = SampledTokenSection::Reasoning;
894
895        for token_id in [200, 201] {
896            push_pending_from_prompt(&mut classifier, token_id);
897            classifier.try_consume_marker_at_tail();
898            classifier.drain_overflow();
899        }
900
901        assert_eq!(classifier.section, SampledTokenSection::Reasoning);
902        assert_eq!(classifier.pending.len(), 2);
903
904        classifier.pending.push_back(PendingToken {
905            token: token(202),
906            decoded: "k>".to_owned(),
907            section: classifier.section,
908            is_boundary: false,
909            is_from_prompt: false,
910            is_held_for_probe: false,
911        });
912        classifier.try_consume_marker_at_tail();
913        let outcomes = classifier.drain_overflow();
914
915        assert_eq!(outcomes.len(), 1);
916        assert_eq!(
917            std::mem::discriminant(&outcomes[0].sampled_token),
918            std::mem::discriminant(&SampledToken::Reasoning(LlamaToken::new(0)))
919        );
920        assert_eq!(outcomes[0].visible_piece, "");
921        assert_eq!(outcomes[0].raw_piece, "k>");
922
923        assert_eq!(classifier.section, SampledTokenSection::Content);
924        assert_eq!(classifier.usage().reasoning_tokens, 1);
925        assert_eq!(classifier.usage().content_tokens, 0);
926    }
927
928    #[test]
929    fn ingest_prompt_tokens_with_multiple_round_trips_ends_in_content() {
930        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
931        let mut classifier = synthetic_classifier(markers);
932
933        for token_id in [100, 7, 200, 100, 8, 200] {
934            push_pending_from_prompt(&mut classifier, token_id);
935            classifier.try_consume_marker_at_tail();
936            classifier.drain_overflow();
937        }
938
939        assert_eq!(classifier.section, SampledTokenSection::Content);
940        assert_eq!(classifier.usage().reasoning_tokens, 0);
941        assert_eq!(classifier.usage().content_tokens, 0);
942        assert_eq!(classifier.usage().tool_call_tokens, 0);
943        assert_eq!(classifier.usage().undeterminable_tokens, 0);
944    }
945
946    #[test]
947    fn ingest_prompt_tokens_initial_section_is_always_pending() {
948        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
949        let classifier = synthetic_classifier(markers);
950
951        assert_eq!(classifier.section, SampledTokenSection::Pending);
952    }
953
954    #[test]
955    fn ingest_prompt_tokens_then_drain_for_generated_token_classifies_correctly() {
956        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
957        let mut classifier = synthetic_classifier(markers);
958
959        for token_id in [100, 7, 200] {
960            push_pending_from_prompt(&mut classifier, token_id);
961            classifier.try_consume_marker_at_tail();
962            classifier.drain_overflow();
963        }
964
965        assert_eq!(classifier.section, SampledTokenSection::Content);
966        assert_eq!(classifier.usage().reasoning_tokens, 0);
967        assert_eq!(classifier.usage().content_tokens, 0);
968
969        classifier.pending.push_back(PendingToken {
970            token: token(50),
971            decoded: "hi".to_owned(),
972            section: classifier.section,
973            is_boundary: false,
974            is_from_prompt: false,
975            is_held_for_probe: false,
976        });
977        classifier.try_consume_marker_at_tail();
978        let outcomes = classifier.drain_overflow();
979
980        assert_eq!(outcomes.len(), 1);
981        assert_eq!(
982            std::mem::discriminant(&outcomes[0].sampled_token),
983            std::mem::discriminant(&SampledToken::Content(LlamaToken::new(0)))
984        );
985        assert_eq!(outcomes[0].visible_piece, "hi");
986        assert_eq!(classifier.usage().content_tokens, 1);
987        assert_eq!(classifier.usage().reasoning_tokens, 0);
988        assert_eq!(classifier.usage().undeterminable_tokens, 0);
989    }
990
991    #[test]
992    fn close_marker_in_content_section_is_suppressed_as_boundary() {
993        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
994        let mut classifier = synthetic_classifier(markers);
995        classifier.section = SampledTokenSection::Content;
996
997        let mut outcomes = Vec::new();
998        for (id, decoded) in [(7, "hi"), (200, "</think>"), (8, "ok")] {
999            push_pending(&mut classifier, id, decoded);
1000            classifier.try_consume_marker_at_tail();
1001            outcomes.extend(classifier.drain_overflow());
1002        }
1003        outcomes.extend(classifier.flush());
1004
1005        assert_eq!(
1006            outcome_sections(&outcomes),
1007            vec![
1008                SampledTokenSection::Content,
1009                SampledTokenSection::Content,
1010                SampledTokenSection::Content,
1011            ],
1012        );
1013        assert_eq!(outcome_pieces(&outcomes), vec!["hi", "", "ok"]);
1014        assert_eq!(classifier.section, SampledTokenSection::Content);
1015    }
1016
1017    #[test]
1018    fn open_marker_in_reasoning_section_is_suppressed_as_boundary() {
1019        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
1020        let mut classifier = synthetic_classifier(markers);
1021        classifier.section = SampledTokenSection::Reasoning;
1022
1023        let mut outcomes = Vec::new();
1024        for (id, decoded) in [(7, "step1"), (100, "<think>"), (8, "step2")] {
1025            push_pending(&mut classifier, id, decoded);
1026            classifier.try_consume_marker_at_tail();
1027            outcomes.extend(classifier.drain_overflow());
1028        }
1029        outcomes.extend(classifier.flush());
1030
1031        assert_eq!(outcome_pieces(&outcomes), vec!["step1", "", "step2"]);
1032        assert_eq!(classifier.section, SampledTokenSection::Reasoning);
1033    }
1034
1035    #[test]
1036    fn record_prompt_tokens_updates_usage() {
1037        let markers = markers_with(None, None);
1038        let mut classifier = synthetic_classifier(markers);
1039
1040        classifier.record_prompt_tokens(7);
1041
1042        assert_eq!(classifier.usage().prompt_tokens, 7);
1043    }
1044
1045    #[test]
1046    fn record_cached_prompt_tokens_updates_usage_when_under_limit() {
1047        let markers = markers_with(None, None);
1048        let mut classifier = synthetic_classifier(markers);
1049        classifier.record_prompt_tokens(10);
1050
1051        classifier.record_cached_prompt_tokens(3).unwrap();
1052
1053        assert_eq!(classifier.usage().cached_prompt_tokens, 3);
1054    }
1055
1056    #[test]
1057    fn record_cached_prompt_tokens_returns_error_when_over_prompt_total() {
1058        let markers = markers_with(None, None);
1059        let mut classifier = synthetic_classifier(markers);
1060        classifier.record_prompt_tokens(2);
1061
1062        let result = classifier.record_cached_prompt_tokens(5);
1063
1064        assert!(result.is_err());
1065    }
1066
1067    #[test]
1068    fn markers_accessor_returns_configured_markers() {
1069        let configured = markers_with(Some(vec![token(1)]), Some(vec![token(2)]));
1070        let classifier = synthetic_classifier(configured);
1071
1072        let returned = classifier.markers();
1073
1074        assert_eq!(returned.reasoning_open.as_deref(), Some(&[token(1)][..]));
1075        assert_eq!(returned.reasoning_close.as_deref(), Some(&[token(2)][..]));
1076    }
1077
1078    #[test]
1079    fn into_usage_consumes_classifier_and_yields_usage_snapshot() {
1080        let markers = markers_with(None, None);
1081        let mut classifier = synthetic_classifier(markers);
1082        classifier.record_prompt_tokens(11);
1083
1084        let usage = classifier.into_usage();
1085
1086        assert_eq!(usage.prompt_tokens, 11);
1087    }
1088
1089    #[test]
1090    fn spurious_tool_call_close_in_content_section_classifies_as_content() {
1091        let mut markers = markers_with(None, None);
1092        markers.tool_call_close = Some(vec![token(300)]);
1093        let mut classifier = synthetic_classifier(markers);
1094        classifier.section = SampledTokenSection::Content;
1095
1096        push_pending(&mut classifier, 300, "</tool_call>");
1097        classifier.try_consume_marker_at_tail();
1098        let outcomes = classifier.drain_overflow();
1099
1100        assert_eq!(
1101            outcome_sections(&outcomes),
1102            vec![SampledTokenSection::Content],
1103        );
1104        assert_eq!(classifier.section, SampledTokenSection::Content);
1105    }
1106
1107    fn markers_with_tool_call_open(tool_call_open: Vec<LlamaToken>) -> StreamingMarkers {
1108        StreamingMarkers {
1109            reasoning_open: None,
1110            reasoning_close: None,
1111            tool_call_open: Some(tool_call_open),
1112            tool_call_close: None,
1113        }
1114    }
1115
1116    fn feed_json_string(
1117        classifier: &mut SampledTokenClassifier<'_>,
1118        text: &str,
1119        starting_token_id: i32,
1120    ) -> Vec<IngestOutcome> {
1121        let mut outcomes = Vec::new();
1122        for (offset, ch) in text.char_indices() {
1123            let token_id = starting_token_id + i32::try_from(offset).unwrap_or(i32::MAX);
1124            let mut buffer = [0_u8; 4];
1125            let chunk = ch.encode_utf8(&mut buffer);
1126            outcomes.extend(push_and_probe(classifier, token_id, chunk));
1127        }
1128        outcomes
1129    }
1130
1131    #[test]
1132    fn json_probe_engages_when_first_non_whitespace_is_open_brace_in_content() {
1133        let markers = markers_with_tool_call_open(vec![token(900)]);
1134        let mut classifier = synthetic_classifier(markers);
1135        classifier.section = SampledTokenSection::Content;
1136
1137        push_and_probe(&mut classifier, 1, "{");
1138
1139        assert_ne!(classifier.probe_mode, ProbeMode::Idle);
1140    }
1141
1142    #[test]
1143    fn json_probe_releases_tokens_as_tool_call_when_signature_matches() {
1144        let markers = markers_with_tool_call_open(vec![token(900)]);
1145        let mut classifier = synthetic_classifier(markers);
1146        classifier.section = SampledTokenSection::Content;
1147
1148        let outcomes = feed_json_string(&mut classifier, r#"{"name":"f","arguments":{}}"#, 100);
1149
1150        assert!(!outcomes.is_empty());
1151        assert!(
1152            outcomes
1153                .iter()
1154                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1155            "every emitted outcome should be ToolCall, got {:?}",
1156            outcome_sections(&outcomes),
1157        );
1158        assert_eq!(classifier.probe_mode, ProbeMode::Idle);
1159    }
1160
1161    #[test]
1162    fn json_probe_releases_tokens_as_content_when_signature_does_not_match() {
1163        let markers = markers_with_tool_call_open(vec![token(900)]);
1164        let mut classifier = synthetic_classifier(markers);
1165        classifier.section = SampledTokenSection::Content;
1166
1167        let outcomes = feed_json_string(&mut classifier, r#"{"foo":"bar"}"#, 100);
1168
1169        assert!(
1170            outcomes
1171                .iter()
1172                .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1173            "every emitted outcome should be Content, got {:?}",
1174            outcome_sections(&outcomes),
1175        );
1176        assert_eq!(classifier.probe_mode, ProbeMode::Idle);
1177    }
1178
1179    #[test]
1180    fn json_probe_releases_tokens_as_content_when_extra_top_level_key() {
1181        let markers = markers_with_tool_call_open(vec![token(900)]);
1182        let mut classifier = synthetic_classifier(markers);
1183        classifier.section = SampledTokenSection::Content;
1184
1185        let outcomes = feed_json_string(
1186            &mut classifier,
1187            r#"{"name":"f","arguments":{},"extra":1}"#,
1188            100,
1189        );
1190
1191        assert!(
1192            outcomes
1193                .iter()
1194                .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1195        );
1196    }
1197
1198    #[test]
1199    fn json_probe_releases_tokens_as_content_when_arguments_is_not_object() {
1200        let markers = markers_with_tool_call_open(vec![token(900)]);
1201        let mut classifier = synthetic_classifier(markers);
1202        classifier.section = SampledTokenSection::Content;
1203
1204        let outcomes = feed_json_string(&mut classifier, r#"{"name":"f","arguments":"hi"}"#, 100);
1205
1206        assert!(
1207            outcomes
1208                .iter()
1209                .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1210        );
1211    }
1212
1213    #[test]
1214    fn json_probe_handles_strings_with_quoted_braces_in_arguments() {
1215        let markers = markers_with_tool_call_open(vec![token(900)]);
1216        let mut classifier = synthetic_classifier(markers);
1217        classifier.section = SampledTokenSection::Content;
1218
1219        let outcomes = feed_json_string(
1220            &mut classifier,
1221            r#"{"name":"f","arguments":{"q":"a } b"}}"#,
1222            100,
1223        );
1224
1225        assert!(
1226            outcomes
1227                .iter()
1228                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1229        );
1230    }
1231
1232    #[test]
1233    fn json_probe_handles_escaped_quotes_in_string_values() {
1234        let markers = markers_with_tool_call_open(vec![token(900)]);
1235        let mut classifier = synthetic_classifier(markers);
1236        classifier.section = SampledTokenSection::Content;
1237
1238        let outcomes = feed_json_string(
1239            &mut classifier,
1240            r#"{"name":"f","arguments":{"q":"he said \"hi\""}}"#,
1241            100,
1242        );
1243
1244        assert!(
1245            outcomes
1246                .iter()
1247                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1248        );
1249    }
1250
1251    #[test]
1252    fn json_probe_handles_unicode_letters_in_strings() {
1253        let markers = markers_with_tool_call_open(vec![token(900)]);
1254        let mut classifier = synthetic_classifier(markers);
1255        classifier.section = SampledTokenSection::Content;
1256
1257        let outcomes = feed_json_string(
1258            &mut classifier,
1259            r#"{"name":"日本語","arguments":{"city":"パリ"}}"#,
1260            100,
1261        );
1262
1263        assert!(
1264            outcomes
1265                .iter()
1266                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1267        );
1268    }
1269
1270    #[test]
1271    fn json_probe_handles_nested_objects() {
1272        let markers = markers_with_tool_call_open(vec![token(900)]);
1273        let mut classifier = synthetic_classifier(markers);
1274        classifier.section = SampledTokenSection::Content;
1275
1276        let outcomes = feed_json_string(
1277            &mut classifier,
1278            r#"{"name":"f","arguments":{"a":{"b":{"c":1}}}}"#,
1279            100,
1280        );
1281
1282        assert!(
1283            outcomes
1284                .iter()
1285                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1286        );
1287    }
1288
1289    #[test]
1290    fn json_probe_handles_arrays_inside_arguments() {
1291        let markers = markers_with_tool_call_open(vec![token(900)]);
1292        let mut classifier = synthetic_classifier(markers);
1293        classifier.section = SampledTokenSection::Content;
1294
1295        let outcomes = feed_json_string(
1296            &mut classifier,
1297            r#"{"name":"f","arguments":{"items":[1,2,3]}}"#,
1298            100,
1299        );
1300
1301        assert!(
1302            outcomes
1303                .iter()
1304                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1305        );
1306    }
1307
1308    #[test]
1309    fn json_probe_does_not_engage_when_first_byte_is_close_brace() {
1310        let markers = markers_with_tool_call_open(vec![token(900)]);
1311        let mut classifier = synthetic_classifier(markers);
1312        classifier.section = SampledTokenSection::Content;
1313
1314        let outcomes = feed_json_string(&mut classifier, "}}", 100);
1315
1316        assert_eq!(classifier.probe_mode, ProbeMode::Idle);
1317        assert!(
1318            outcomes
1319                .iter()
1320                .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1321        );
1322    }
1323
1324    #[test]
1325    fn json_probe_does_not_engage_in_reasoning_section() {
1326        let markers = StreamingMarkers {
1327            reasoning_open: Some(vec![token(800)]),
1328            reasoning_close: Some(vec![token(801)]),
1329            tool_call_open: Some(vec![token(900)]),
1330            tool_call_close: None,
1331        };
1332        let mut classifier = synthetic_classifier(markers);
1333        classifier.section = SampledTokenSection::Reasoning;
1334
1335        push_and_probe(&mut classifier, 1, "{");
1336
1337        assert_eq!(classifier.probe_mode, ProbeMode::Idle);
1338    }
1339
1340    #[test]
1341    fn json_probe_does_not_engage_in_tool_call_section() {
1342        let markers = markers_with_tool_call_open(vec![token(900)]);
1343        let mut classifier = synthetic_classifier(markers);
1344        classifier.section = SampledTokenSection::ToolCall;
1345
1346        push_and_probe(&mut classifier, 1, "{");
1347
1348        assert_eq!(classifier.probe_mode, ProbeMode::Idle);
1349    }
1350
1351    #[test]
1352    fn marker_probe_takes_precedence_when_both_could_match() {
1353        let markers = markers_with_tool_call_open(vec![token(900)]);
1354        let mut classifier = synthetic_classifier(markers);
1355        classifier.section = SampledTokenSection::Content;
1356
1357        let mut outcomes = Vec::new();
1358        outcomes.extend(push_and_probe(&mut classifier, 1, "{"));
1359        outcomes.extend(push_and_probe(&mut classifier, 900, r#"""#));
1360
1361        assert_eq!(classifier.section, SampledTokenSection::ToolCall);
1362        assert_eq!(outcome_pieces(&outcomes), vec!["{", ""]);
1363        assert_eq!(
1364            outcome_sections(&outcomes),
1365            vec![SampledTokenSection::Content, SampledTokenSection::ToolCall],
1366        );
1367    }
1368
1369    #[test]
1370    fn json_probe_consumes_two_consecutive_objects_separately() {
1371        let markers = markers_with_tool_call_open(vec![token(900)]);
1372        let mut classifier = synthetic_classifier(markers);
1373        classifier.section = SampledTokenSection::Content;
1374
1375        let mut outcomes = Vec::new();
1376        outcomes.extend(feed_json_string(
1377            &mut classifier,
1378            r#"{"name":"a","arguments":{}}"#,
1379            100,
1380        ));
1381        outcomes.extend(feed_json_string(
1382            &mut classifier,
1383            r#"{"name":"b","arguments":{"x":1}}"#,
1384            200,
1385        ));
1386
1387        assert!(
1388            outcomes
1389                .iter()
1390                .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
1391            "two consecutive markerless tool calls must both classify as ToolCall, got {:?}",
1392            outcome_sections(&outcomes),
1393        );
1394    }
1395
1396    #[test]
1397    fn json_probe_with_leading_whitespace_then_open_brace_classifies_whitespace_as_content_and_json_as_tool_call()
1398     {
1399        let markers = markers_with_tool_call_open(vec![token(900)]);
1400        let mut classifier = synthetic_classifier(markers);
1401        classifier.section = SampledTokenSection::Content;
1402
1403        let outcomes = feed_json_string(
1404            &mut classifier,
1405            "\n  {\"name\":\"f\",\"arguments\":{}}",
1406            100,
1407        );
1408
1409        let tool_call_count = outcomes
1410            .iter()
1411            .filter(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_)))
1412            .count();
1413        let content_count = outcomes
1414            .iter()
1415            .filter(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_)))
1416            .count();
1417        assert_eq!(
1418            content_count, 3,
1419            "leading `\\n  ` should classify as content"
1420        );
1421        assert!(
1422            tool_call_count > 0,
1423            "the JSON object should classify as ToolCall",
1424        );
1425        assert_eq!(content_count + tool_call_count, outcomes.len());
1426    }
1427
1428    #[test]
1429    fn json_probe_records_tool_call_token_usage_on_commit() {
1430        let markers = markers_with_tool_call_open(vec![token(900)]);
1431        let mut classifier = synthetic_classifier(markers);
1432        classifier.section = SampledTokenSection::Content;
1433
1434        let json = r#"{"name":"f","arguments":{}}"#;
1435        let outcomes = feed_json_string(&mut classifier, json, 100);
1436
1437        let emitted = outcomes.len();
1438        let usage = classifier.usage();
1439        assert_eq!(usage.tool_call_tokens, emitted as u64);
1440        assert_eq!(usage.content_tokens, 0);
1441    }
1442
1443    #[test]
1444    fn json_probe_records_content_token_usage_on_abandon() {
1445        let markers = markers_with_tool_call_open(vec![token(900)]);
1446        let mut classifier = synthetic_classifier(markers);
1447        classifier.section = SampledTokenSection::Content;
1448
1449        let json = r#"{"foo":"bar"}"#;
1450        let outcomes = feed_json_string(&mut classifier, json, 100);
1451
1452        let emitted = outcomes.len();
1453        let usage = classifier.usage();
1454        assert_eq!(usage.content_tokens, emitted as u64);
1455        assert_eq!(usage.tool_call_tokens, 0);
1456    }
1457
1458    #[test]
1459    fn flush_during_active_json_probe_releases_held_tokens_as_content() {
1460        let markers = markers_with_tool_call_open(vec![token(900)]);
1461        let mut classifier = synthetic_classifier(markers);
1462        classifier.section = SampledTokenSection::Content;
1463
1464        push_and_probe(&mut classifier, 1, "{");
1465        push_and_probe(&mut classifier, 2, r#""name""#);
1466        assert_ne!(classifier.probe_mode, ProbeMode::Idle);
1467
1468        let outcomes = classifier.flush();
1469
1470        assert!(
1471            outcomes
1472                .iter()
1473                .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
1474            "mid-probe flush must release held tokens as Content, got {:?}",
1475            outcome_sections(&outcomes),
1476        );
1477        assert_eq!(classifier.probe_mode, ProbeMode::Idle);
1478    }
1479}