Skip to main content

llama_cpp_bindings/
sampled_token_classifier.rs

1use std::collections::VecDeque;
2
3use llama_cpp_bindings_sys::llama_pos;
4use llama_cpp_bindings_sys::llama_seq_id;
5
6use llama_cpp_bindings_types::TokenUsage;
7use llama_cpp_bindings_types::TokenUsageError;
8
9use crate::batch_add_error::BatchAddError;
10use crate::context::LlamaContext;
11use crate::error::EvalMultimodalChunksError;
12use crate::error::SampleError;
13use crate::llama_batch::LlamaBatch;
14use crate::model::LlamaModel;
15use crate::mtmd::MtmdContext;
16use crate::mtmd::MtmdInputChunks;
17use crate::sampled_token::SampledToken;
18use crate::sampling::LlamaSampler;
19use crate::streaming_json_probe::JsonProbeOutcome;
20use crate::token::LlamaToken;
21
/// Section of the model output that a sampled token belongs to.
///
/// A classifier starts out in `Pending` and moves between sections whenever it
/// recognises a marker token sequence, or, for `ToolCall`, when its JSON probe
/// accepts a complete object.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SampledTokenSection {
    /// No marker has been observed yet; tokens finalised while here are
    /// reported as undeterminable.
    Pending,
    /// Ordinary response text outside reasoning and tool-call blocks.
    Content,
    /// Text inside a reasoning block (after a reasoning-open marker).
    Reasoning,
    /// Tool-call payload: after a tool-call-open marker, or a bare JSON object
    /// committed by the probe.
    ToolCall,
}
29
/// Which of the four configured marker sequences matched at the buffer tail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MarkerKind {
    /// Opens a reasoning block; the section becomes `Reasoning`.
    ReasoningOpen,
    /// Closes a reasoning block; the section becomes `Content`.
    ReasoningClose,
    /// Opens a tool call; the section becomes `ToolCall`.
    ToolCallOpen,
    /// Closes a tool call; the section becomes `Content`.
    ToolCallClose,
}
37
38/// Tokenized marker sequences (token IDs, not strings).
39///
40/// Each marker is a `Vec<LlamaToken>` of length `>= 1`; absent markers are
41/// `None`. Sequence matching at every `ingest()` is by token-ID equality,
42/// never by substring scanning of decoded text.
43#[derive(Clone, Debug, Default, Eq, PartialEq)]
44pub struct StreamingMarkers {
45    pub reasoning_open: Option<Vec<LlamaToken>>,
46    pub reasoning_close: Option<Vec<LlamaToken>>,
47    pub tool_call_open: Option<Vec<LlamaToken>>,
48    pub tool_call_close: Option<Vec<LlamaToken>>,
49}
50
51impl StreamingMarkers {
52    const fn has_any(&self) -> bool {
53        self.reasoning_open.is_some()
54            || self.reasoning_close.is_some()
55            || self.tool_call_open.is_some()
56            || self.tool_call_close.is_some()
57    }
58
59    fn max_token_len(&self) -> usize {
60        [
61            self.reasoning_open.as_deref(),
62            self.reasoning_close.as_deref(),
63            self.tool_call_open.as_deref(),
64            self.tool_call_close.as_deref(),
65        ]
66        .into_iter()
67        .flatten()
68        .map(<[LlamaToken]>::len)
69        .max()
70        .unwrap_or(0)
71    }
72
73    fn lookup(&self, kind: MarkerKind) -> Option<&[LlamaToken]> {
74        match kind {
75            MarkerKind::ReasoningOpen => self.reasoning_open.as_deref(),
76            MarkerKind::ReasoningClose => self.reasoning_close.as_deref(),
77            MarkerKind::ToolCallOpen => self.tool_call_open.as_deref(),
78            MarkerKind::ToolCallClose => self.tool_call_close.as_deref(),
79        }
80    }
81}
82
83#[derive(Clone, Debug)]
84pub struct IngestOutcome {
85    pub sampled_token: SampledToken,
86    /// Empty when the token is part of a recognised marker boundary; otherwise
87    /// the decoded UTF-8 piece. Callers should stream `visible_piece` and skip
88    /// emission when it is empty.
89    pub visible_piece: String,
90    /// Always the decoded UTF-8 piece, even for marker-boundary tokens. Useful
91    /// for accumulating the full raw model output (e.g. for downstream parser
92    /// cross-checks) without losing marker bytes.
93    pub raw_piece: String,
94}
95
96#[derive(Clone, Debug)]
97struct PendingToken {
98    token: LlamaToken,
99    decoded: String,
100    section: SampledTokenSection,
101    is_boundary: bool,
102    is_from_prompt: bool,
103    is_held_for_probe: bool,
104}
105
/// Text accumulated by an in-progress JSON tool-call probe.
#[derive(Debug, Clone)]
struct JsonProbeState {
    /// Decoded pieces of every token held since the probe engaged, joined in
    /// order.
    held_text: String,
}
110
111#[derive(Clone, Debug)]
112enum ProbeMode {
113    Idle,
114    Active(JsonProbeState),
115}
116
117pub struct SampledTokenClassifier<'model> {
118    model: &'model LlamaModel,
119    markers: StreamingMarkers,
120    decoder: encoding_rs::Decoder,
121    pending: VecDeque<PendingToken>,
122    section: SampledTokenSection,
123    pending_prompt_tokens: u64,
124    usage: TokenUsage,
125    probe_mode: ProbeMode,
126}
127
128impl<'model> SampledTokenClassifier<'model> {
129    #[must_use]
130    pub fn new(model: &'model LlamaModel, markers: StreamingMarkers) -> Self {
131        Self {
132            model,
133            markers,
134            decoder: encoding_rs::UTF_8.new_decoder(),
135            pending: VecDeque::new(),
136            section: SampledTokenSection::Pending,
137            pending_prompt_tokens: 0,
138            usage: TokenUsage::new(),
139            probe_mode: ProbeMode::Idle,
140        }
141    }
142
143    /// Ingest one sampled token. Returns the outcomes that have finalised this
144    /// turn — typically a single outcome, occasionally zero (the classifier is
145    /// holding back tokens that may yet form a marker), or several when a
146    /// buffered marker prefix diverges and the held-back tokens flush.
147    ///
148    /// Each [`IngestOutcome`] carries both the [`SampledToken`] variant for
149    /// classification and the decoded `visible_piece` for streaming. Marker
150    /// boundaries get an empty `visible_piece` so their text never reaches
151    /// user-visible streams.
152    pub fn ingest(&mut self, token: LlamaToken) -> Vec<IngestOutcome> {
153        if !self.markers.has_any() {
154            self.usage.record_undeterminable_token();
155            let piece = self.decode(token);
156            return vec![IngestOutcome {
157                sampled_token: SampledToken::Undeterminable(token),
158                visible_piece: piece.clone(),
159                raw_piece: piece,
160            }];
161        }
162
163        let decoded = self.decode(token);
164        self.pending.push_back(PendingToken {
165            token,
166            decoded: decoded.clone(),
167            section: self.section,
168            is_boundary: false,
169            is_from_prompt: false,
170            is_held_for_probe: false,
171        });
172
173        self.try_consume_marker_at_tail();
174
175        let probe_was_active = matches!(self.probe_mode, ProbeMode::Active(_));
176        let mut outcomes = if probe_was_active && self.section_disengages_probe() {
177            self.abandon_probe()
178        } else {
179            self.update_probe(&decoded)
180        };
181
182        outcomes.extend(self.drain_overflow());
183        outcomes
184    }
185
186    const fn section_disengages_probe(&self) -> bool {
187        matches!(
188            self.section,
189            SampledTokenSection::ToolCall | SampledTokenSection::Reasoning
190        )
191    }
192
193    /// Replay one prompt token through the marker state machine so that the
194    /// section at end-of-prompt reflects the chat template's rendered tail
195    /// (e.g. for Qwen3.5/3.6 with `enable_thinking=false` the prompt ends with
196    /// a closed empty `<think>...</think>` block, leaving the section in
197    /// `Content`; with `enable_thinking=true` it ends inside an open `<think>`,
198    /// leaving the section in `Reasoning`).
199    ///
200    /// Prompt tokens never produce [`IngestOutcome`]s and never increment usage
201    /// counters — they are not generated content.
202    pub fn ingest_prompt_token(&mut self, token: LlamaToken) {
203        if !self.markers.has_any() {
204            return;
205        }
206
207        self.pending.push_back(PendingToken {
208            token,
209            decoded: String::new(),
210            section: self.section,
211            is_boundary: false,
212            is_from_prompt: true,
213            is_held_for_probe: false,
214        });
215
216        self.try_consume_marker_at_tail();
217        self.drain_overflow();
218    }
219
220    pub fn ingest_prompt_tokens(&mut self, tokens: &[LlamaToken]) {
221        if !self.markers.has_any() {
222            return;
223        }
224        for &token in tokens {
225            self.ingest_prompt_token(token);
226        }
227    }
228
229    /// Drain every still-buffered token. Call once at end of generation (EOG)
230    /// to make sure no decoded text is silently dropped. After `flush()` the
231    /// classifier behaves as if freshly constructed in terms of buffer state.
232    pub fn flush(&mut self) -> Vec<IngestOutcome> {
233        self.probe_mode = ProbeMode::Idle;
234        let mut outcomes = Vec::with_capacity(self.pending.len());
235        while let Some(entry) = self.pending.pop_front() {
236            if entry.is_from_prompt {
237                continue;
238            }
239            outcomes.push(self.finalize_entry(entry));
240        }
241        outcomes
242    }
243
244    fn decode(&mut self, token: LlamaToken) -> String {
245        match self.model.token_to_piece(
246            &SampledToken::Content(token),
247            &mut self.decoder,
248            true,
249            None,
250        ) {
251            Ok(piece) => piece,
252            Err(detokenize_error) => {
253                tracing::debug!(
254                    "token_to_piece failed during classification, dropping piece: {detokenize_error}"
255                );
256                String::new()
257            }
258        }
259    }
260
261    fn try_consume_marker_at_tail(&mut self) {
262        // Probe every marker in every section so the user-visible streams stay
263        // free of marker text even when the model misbehaves: a stray
264        // `</think>` / `<channel|>` / `[/THINK]` while in `Content` is
265        // suppressed (close markers transition to Content — a no-op when
266        // already there); a nested `<think>` while in `Reasoning` is also
267        // suppressed (open markers keep the section in Reasoning). Without
268        // this, models like Gemma 4 E4B that emit close markers without ever
269        // opening leak the literal marker text into `content_stream`.
270        const PROBE_KINDS: &[MarkerKind] = &[
271            MarkerKind::ReasoningOpen,
272            MarkerKind::ReasoningClose,
273            MarkerKind::ToolCallOpen,
274            MarkerKind::ToolCallClose,
275        ];
276
277        for &kind in PROBE_KINDS {
278            let Some(marker) = self.markers.lookup(kind) else {
279                continue;
280            };
281            if marker.is_empty() || self.pending.len() < marker.len() {
282                continue;
283            }
284            let span_start = self.pending.len() - marker.len();
285            let matches = self
286                .pending
287                .iter()
288                .skip(span_start)
289                .zip(marker)
290                .all(|(entry, marker_token)| entry.token == *marker_token);
291            if matches {
292                self.mark_marker_span(span_start, kind);
293                return;
294            }
295        }
296    }
297
298    fn mark_marker_span(&mut self, span_start: usize, kind: MarkerKind) {
299        let next_section = match kind {
300            MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning,
301            MarkerKind::ReasoningClose | MarkerKind::ToolCallClose => SampledTokenSection::Content,
302            MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall,
303        };
304        // For open markers, the boundary tokens are classified as the destination
305        // section — they are the marker itself (`<think>` is part of reasoning,
306        // `<tool_call>` is part of the tool-call protocol). For close markers,
307        // the boundary tokens are classified as the section the model was in:
308        // a normal `</think>` while in `Reasoning` is still reasoning, but a
309        // spurious `</think>` while in `Content` (e.g. some Gemma variants
310        // re-emit close markers without ever opening) is just noise in the
311        // content section — counting it as `Reasoning` would inflate
312        // `observed_reasoning` and falsely indicate the model thought.
313        let span_section = match kind {
314            MarkerKind::ReasoningOpen => SampledTokenSection::Reasoning,
315            MarkerKind::ToolCallOpen => SampledTokenSection::ToolCall,
316            MarkerKind::ReasoningClose => {
317                if self.section == SampledTokenSection::Reasoning {
318                    SampledTokenSection::Reasoning
319                } else {
320                    SampledTokenSection::Content
321                }
322            }
323            MarkerKind::ToolCallClose => {
324                if self.section == SampledTokenSection::ToolCall {
325                    SampledTokenSection::ToolCall
326                } else {
327                    SampledTokenSection::Content
328                }
329            }
330        };
331
332        for entry in self.pending.iter_mut().skip(span_start) {
333            entry.is_boundary = true;
334            entry.section = span_section;
335        }
336
337        self.section = next_section;
338    }
339
340    fn drain_overflow(&mut self) -> Vec<IngestOutcome> {
341        let lookback = self.markers.max_token_len().saturating_sub(1);
342        let mut outcomes = Vec::new();
343
344        loop {
345            let Some(front) = self.pending.front() else {
346                break;
347            };
348            if front.is_held_for_probe {
349                break;
350            }
351            let probe_held = self
352                .pending
353                .iter()
354                .filter(|entry| entry.is_held_for_probe)
355                .count();
356            let drainable = self.pending.len().saturating_sub(probe_held);
357            let beyond_lookback = drainable > lookback;
358            if !front.is_boundary && !beyond_lookback {
359                break;
360            }
361            let Some(entry) = self.pending.pop_front() else {
362                break;
363            };
364            if entry.is_from_prompt {
365                continue;
366            }
367            outcomes.push(self.finalize_entry(entry));
368        }
369
370        outcomes
371    }
372
373    fn update_probe(&mut self, piece: &str) -> Vec<IngestOutcome> {
374        let probe_active = matches!(self.probe_mode, ProbeMode::Active(_));
375        if !probe_active {
376            if !self.section_allows_probe_engagement() {
377                return Vec::new();
378            }
379            if !piece.trim_start().starts_with('{') {
380                return Vec::new();
381            }
382            if let Some(entry) = self.pending.back_mut() {
383                entry.is_held_for_probe = true;
384            }
385            self.probe_mode = ProbeMode::Active(JsonProbeState {
386                held_text: piece.to_owned(),
387            });
388            return self.evaluate_probe();
389        }
390
391        if let Some(entry) = self.pending.back_mut() {
392            entry.is_held_for_probe = true;
393        }
394        if let ProbeMode::Active(state) = &mut self.probe_mode {
395            state.held_text.push_str(piece);
396        }
397        self.evaluate_probe()
398    }
399
400    const fn section_allows_probe_engagement(&self) -> bool {
401        matches!(
402            self.section,
403            SampledTokenSection::Content | SampledTokenSection::Pending
404        )
405    }
406
407    fn evaluate_probe(&mut self) -> Vec<IngestOutcome> {
408        let outcome = match &self.probe_mode {
409            ProbeMode::Active(state) => JsonProbeOutcome::validate_prefix(&state.held_text),
410            ProbeMode::Idle => return Vec::new(),
411        };
412        match outcome {
413            JsonProbeOutcome::StillPossiblyValid => Vec::new(),
414            JsonProbeOutcome::CompletedValid => self.commit_probe_as_tool_call(),
415            JsonProbeOutcome::Failed => self.abandon_probe(),
416        }
417    }
418
419    fn commit_probe_as_tool_call(&mut self) -> Vec<IngestOutcome> {
420        if !matches!(self.probe_mode, ProbeMode::Active(_)) {
421            return Vec::new();
422        }
423        self.probe_mode = ProbeMode::Idle;
424        self.section = SampledTokenSection::Content;
425
426        let drained: Vec<_> = self.pending.drain(..).collect();
427        let mut outcomes = Vec::new();
428        for mut entry in drained {
429            if entry.is_held_for_probe {
430                entry.section = SampledTokenSection::ToolCall;
431                entry.is_held_for_probe = false;
432                if !entry.is_from_prompt {
433                    outcomes.push(self.finalize_entry(entry));
434                }
435            } else {
436                self.pending.push_back(entry);
437            }
438        }
439        outcomes
440    }
441
442    fn abandon_probe(&mut self) -> Vec<IngestOutcome> {
443        if !matches!(self.probe_mode, ProbeMode::Active(_)) {
444            return Vec::new();
445        }
446        self.probe_mode = ProbeMode::Idle;
447
448        let drained: Vec<_> = self.pending.drain(..).collect();
449        let mut outcomes = Vec::new();
450        for mut entry in drained {
451            if entry.is_held_for_probe {
452                entry.is_held_for_probe = false;
453                if !entry.is_from_prompt {
454                    outcomes.push(self.finalize_entry(entry));
455                }
456            } else {
457                self.pending.push_back(entry);
458            }
459        }
460        outcomes
461    }
462
463    fn finalize_entry(&mut self, entry: PendingToken) -> IngestOutcome {
464        let section = entry.section;
465        match section {
466            SampledTokenSection::Reasoning => self.usage.record_reasoning_token(),
467            SampledTokenSection::Content => self.usage.record_content_token(),
468            SampledTokenSection::ToolCall => self.usage.record_tool_call_token(),
469            SampledTokenSection::Pending => self.usage.record_undeterminable_token(),
470        }
471
472        let sampled_token = match section {
473            SampledTokenSection::Reasoning => SampledToken::Reasoning(entry.token),
474            SampledTokenSection::Content => SampledToken::Content(entry.token),
475            SampledTokenSection::ToolCall => SampledToken::ToolCall(entry.token),
476            SampledTokenSection::Pending => SampledToken::Undeterminable(entry.token),
477        };
478
479        let visible_piece = if entry.is_boundary {
480            String::new()
481        } else {
482            entry.decoded.clone()
483        };
484
485        IngestOutcome {
486            sampled_token,
487            visible_piece,
488            raw_piece: entry.decoded,
489        }
490    }
491
492    /// # Errors
493    /// Forwards [`LlamaSampler::sample`] errors verbatim. Nothing is recorded on failure.
494    ///
495    /// Returns the raw sampled token (for downstream `batch.add` / `is_eog_token`
496    /// calls) alongside the outcomes that finalised this turn — see
497    /// [`Self::ingest`] for buffering semantics.
498    pub fn sample(
499        &mut self,
500        sampler: &mut LlamaSampler,
501        context: &LlamaContext,
502        idx: i32,
503    ) -> Result<(LlamaToken, Vec<IngestOutcome>), SampleError> {
504        let raw = sampler.sample(context, idx)?;
505        let outcomes = self.ingest(raw);
506
507        Ok((raw, outcomes))
508    }
509
510    /// # Errors
511    /// Forwards [`LlamaBatch::add`] errors verbatim. Nothing is staged on failure.
512    pub fn feed_prompt_to_batch(
513        &mut self,
514        batch: &mut LlamaBatch,
515        token: LlamaToken,
516        position: llama_pos,
517        seq_ids: &[llama_seq_id],
518        logits: bool,
519    ) -> Result<(), BatchAddError> {
520        batch.add(&SampledToken::Content(token), position, seq_ids, logits)?;
521        self.ingest_prompt_token(token);
522        self.pending_prompt_tokens = self.pending_prompt_tokens.saturating_add(1);
523
524        Ok(())
525    }
526
527    /// # Errors
528    /// Forwards [`LlamaBatch::add_sequence`] errors verbatim. Nothing is staged on failure.
529    pub fn feed_prompt_sequence_to_batch(
530        &mut self,
531        batch: &mut LlamaBatch,
532        tokens: &[LlamaToken],
533        seq_id: llama_seq_id,
534        logits_all: bool,
535    ) -> Result<(), BatchAddError> {
536        batch.add_sequence(tokens, seq_id, logits_all)?;
537        self.ingest_prompt_tokens(tokens);
538        self.pending_prompt_tokens = self
539            .pending_prompt_tokens
540            .saturating_add(tokens.len() as u64);
541
542        Ok(())
543    }
544
545    pub const fn commit_prompt_tokens(&mut self) -> u64 {
546        let promoted = self.pending_prompt_tokens;
547        self.usage.record_prompt_tokens(promoted);
548        self.pending_prompt_tokens = 0;
549
550        promoted
551    }
552
553    pub const fn discard_pending_prompt_tokens(&mut self) -> u64 {
554        let discarded = self.pending_prompt_tokens;
555        self.pending_prompt_tokens = 0;
556
557        discarded
558    }
559
560    #[must_use]
561    pub const fn pending_prompt_tokens(&self) -> u64 {
562        self.pending_prompt_tokens
563    }
564
565    /// # Errors
566    /// Returns [`EvalMultimodalChunksError::EvalFailed`] when the underlying
567    /// `eval_chunks` call fails (no counters move),
568    /// [`EvalMultimodalChunksError::UnknownChunkType`] when a chunk reports a
569    /// type unknown to this binding, or
570    /// [`EvalMultimodalChunksError::ChunkOutOfBounds`] when a valid index returns
571    /// `None` from `chunks.get`.
572    #[expect(
573        clippy::too_many_arguments,
574        reason = "thin wrapper over MtmdInputChunks::eval_chunks; parameter shape mirrors the underlying API"
575    )]
576    pub fn eval_multimodal_chunks(
577        &mut self,
578        chunks: &MtmdInputChunks,
579        mtmd_ctx: &MtmdContext,
580        llama_ctx: &LlamaContext,
581        start_position: llama_pos,
582        seq_id: llama_seq_id,
583        n_batch: i32,
584        logits_last: bool,
585    ) -> Result<llama_pos, EvalMultimodalChunksError> {
586        let chunk_count = chunks.len();
587        // `start_position` stays read-only; `next_position` is the loop
588        // accumulator that walks forward chunk-by-chunk and is the function's
589        // return value. Two locals, single responsibility each.
590        let mut next_position = start_position;
591
592        for index in 0..chunk_count {
593            let chunk = chunks
594                .get(index)
595                .ok_or(EvalMultimodalChunksError::ChunkOutOfBounds(index))?;
596            let logits_for_this_chunk = logits_last && index + 1 == chunk_count;
597
598            next_position = chunk.eval_single(
599                mtmd_ctx,
600                llama_ctx,
601                next_position,
602                seq_id,
603                n_batch,
604                logits_for_this_chunk,
605            )?;
606            crate::ingest_prompt_chunk::ingest_prompt_chunk(self, &chunk)?;
607        }
608
609        Ok(next_position)
610    }
611
612    pub const fn record_prompt_tokens(&mut self, count: u64) {
613        self.usage.record_prompt_tokens(count);
614    }
615
616    pub const fn record_input_image_tokens(&mut self, count: u64) {
617        self.usage.record_input_image_tokens(count);
618    }
619
620    pub const fn record_input_audio_tokens(&mut self, count: u64) {
621        self.usage.record_input_audio_tokens(count);
622    }
623
624    /// # Errors
625    /// Forwards [`TokenUsageError::CachedExceedsPrompt`] when the running cached total would
626    /// exceed the prompt total.
627    pub const fn record_cached_prompt_tokens(&mut self, count: u64) -> Result<(), TokenUsageError> {
628        self.usage.record_cached_prompt_tokens(count)
629    }
630
631    #[must_use]
632    pub const fn usage(&self) -> &TokenUsage {
633        &self.usage
634    }
635
636    #[must_use]
637    pub fn into_usage(self) -> TokenUsage {
638        self.usage
639    }
640
641    #[must_use]
642    pub const fn current_section(&self) -> SampledTokenSection {
643        self.section
644    }
645
646    #[must_use]
647    pub const fn markers(&self) -> &StreamingMarkers {
648        &self.markers
649    }
650}
651
652#[cfg(test)]
653mod tests {
654    use super::IngestOutcome;
655    use super::PendingToken;
656    use super::ProbeMode;
657    use super::SampledTokenClassifier;
658    use super::SampledTokenSection;
659    use super::StreamingMarkers;
660    use crate::sampled_token::SampledToken;
661    use crate::token::LlamaToken;
662
663    fn token(id: i32) -> LlamaToken {
664        LlamaToken::new(id)
665    }
666
667    fn markers_with(
668        reasoning_open: Option<Vec<LlamaToken>>,
669        reasoning_close: Option<Vec<LlamaToken>>,
670    ) -> StreamingMarkers {
671        StreamingMarkers {
672            reasoning_open,
673            reasoning_close,
674            tool_call_open: None,
675            tool_call_close: None,
676        }
677    }
678
679    /// Builds a classifier without a real model — only safe for tests that go
680    /// through `try_consume_marker_at_tail` / `drain_overflow` directly, never
681    /// through `ingest()` (which calls `model.token_to_piece`).
682    fn synthetic_classifier(markers: StreamingMarkers) -> SampledTokenClassifier<'static> {
683        SampledTokenClassifier {
684            model: unsafe { &*std::ptr::NonNull::<crate::model::LlamaModel>::dangling().as_ptr() },
685            markers,
686            decoder: encoding_rs::UTF_8.new_decoder(),
687            pending: std::collections::VecDeque::new(),
688            section: SampledTokenSection::Pending,
689            pending_prompt_tokens: 0,
690            usage: llama_cpp_bindings_types::TokenUsage::new(),
691            probe_mode: ProbeMode::Idle,
692        }
693    }
694
695    fn push_pending(classifier: &mut SampledTokenClassifier<'_>, token_id: i32, decoded: &str) {
696        classifier.pending.push_back(PendingToken {
697            token: token(token_id),
698            decoded: decoded.to_owned(),
699            section: classifier.section,
700            is_boundary: false,
701            is_from_prompt: false,
702            is_held_for_probe: false,
703        });
704    }
705
706    fn push_pending_from_prompt(classifier: &mut SampledTokenClassifier<'_>, token_id: i32) {
707        classifier.pending.push_back(PendingToken {
708            token: token(token_id),
709            decoded: String::new(),
710            section: classifier.section,
711            is_boundary: false,
712            is_from_prompt: true,
713            is_held_for_probe: false,
714        });
715    }
716
717    fn push_and_probe(
718        classifier: &mut SampledTokenClassifier<'_>,
719        token_id: i32,
720        decoded: &str,
721    ) -> Vec<IngestOutcome> {
722        push_pending(classifier, token_id, decoded);
723        classifier.try_consume_marker_at_tail();
724        let probe_was_active = matches!(classifier.probe_mode, ProbeMode::Active(_));
725        let mut outcomes = if probe_was_active && classifier.section_disengages_probe() {
726            classifier.abandon_probe()
727        } else {
728            classifier.update_probe(decoded)
729        };
730        outcomes.extend(classifier.drain_overflow());
731        outcomes
732    }
733
734    fn outcome_pieces(outcomes: &[IngestOutcome]) -> Vec<&str> {
735        outcomes
736            .iter()
737            .map(|outcome| outcome.visible_piece.as_str())
738            .collect()
739    }
740
741    fn outcome_sections(outcomes: &[IngestOutcome]) -> Vec<SampledTokenSection> {
742        outcomes
743            .iter()
744            .map(|outcome| match outcome.sampled_token {
745                SampledToken::Reasoning(_) => SampledTokenSection::Reasoning,
746                SampledToken::Content(_) => SampledTokenSection::Content,
747                SampledToken::ToolCall(_) => SampledTokenSection::ToolCall,
748                SampledToken::Undeterminable(_) => SampledTokenSection::Pending,
749            })
750            .collect()
751    }
752
753    #[test]
754    fn streaming_markers_with_no_markers_reports_none() {
755        let markers = StreamingMarkers::default();
756        assert!(!markers.has_any());
757        assert_eq!(markers.max_token_len(), 0);
758    }
759
760    #[test]
761    fn streaming_markers_max_token_len_takes_longest() {
762        let markers = StreamingMarkers {
763            reasoning_open: Some(vec![token(1)]),
764            reasoning_close: Some(vec![token(2), token(3), token(4)]),
765            tool_call_open: Some(vec![token(5), token(6)]),
766            tool_call_close: None,
767        };
768        assert_eq!(markers.max_token_len(), 3);
769    }
770
771    #[test]
772    fn single_token_close_marker_when_already_in_reasoning_emits_empty_piece_for_marker() {
773        let markers = markers_with(Some(vec![token(100)]), Some(vec![token(200)]));
774        let mut classifier = synthetic_classifier(markers);
775        classifier.section = SampledTokenSection::Reasoning;
776
777        push_pending(&mut classifier, 7, "step");
778        classifier.try_consume_marker_at_tail();
779        let mut outcomes = classifier.drain_overflow();
780
781        push_pending(&mut classifier, 200, "</think>");
782        classifier.try_consume_marker_at_tail();
783        outcomes.extend(classifier.drain_overflow());
784
785        push_pending(&mut classifier, 9, "Hi");
786        classifier.try_consume_marker_at_tail();
787        outcomes.extend(classifier.drain_overflow());
788
789        outcomes.extend(classifier.flush());
790
791        assert_eq!(
792            outcome_sections(&outcomes),
793            vec![
794                SampledTokenSection::Reasoning,
795                SampledTokenSection::Reasoning,
796                SampledTokenSection::Content,
797            ],
798        );
799        assert_eq!(outcome_pieces(&outcomes), vec!["step", "", "Hi"]);
800        assert_eq!(classifier.section, SampledTokenSection::Content);
801    }
802
803    #[test]
804    fn multi_token_close_marker_suppresses_every_marker_token() {
805        let markers = markers_with(
806            Some(vec![token(100)]),
807            Some(vec![token(200), token(201), token(202)]),
808        );
809        let mut classifier = synthetic_classifier(markers);
810        classifier.section = SampledTokenSection::Reasoning;
811
812        let mut outcomes = Vec::new();
813        for (id, decoded) in [(7, "r"), (200, "</"), (201, "thi"), (202, "nk>"), (9, "OK")] {
814            push_pending(&mut classifier, id, decoded);
815            classifier.try_consume_marker_at_tail();
816            outcomes.extend(classifier.drain_overflow());
817        }
818        outcomes.extend(classifier.flush());
819
820        assert_eq!(outcome_pieces(&outcomes), vec!["r", "", "", "", "OK"]);
821        assert_eq!(classifier.section, SampledTokenSection::Content);
822    }
823
/// A partial close-marker prefix (`200`, `201`) followed by an unrelated token
/// is released verbatim as Reasoning text. Nothing is suppressed and the
/// section never leaves Reasoning.
#[test]
fn marker_prefix_that_diverges_does_not_suppress_buffered_tokens() {
    let mut classifier = synthetic_classifier(markers_with(
        Some(vec![token(100)]),
        Some(vec![token(200), token(201), token(202)]),
    ));
    classifier.section = SampledTokenSection::Reasoning;

    let mut emitted: Vec<_> = [(7, "r"), (200, "a"), (201, "b"), (300, "x")]
        .iter()
        .flat_map(|&(token_id, text)| {
            push_pending(&mut classifier, token_id, text);
            classifier.try_consume_marker_at_tail();
            classifier.drain_overflow()
        })
        .collect();
    emitted.extend(classifier.flush());

    assert_eq!(outcome_pieces(&emitted), vec!["r", "a", "b", "x"]);
    let all_reasoning = emitted
        .iter()
        .all(|outcome| matches!(outcome.sampled_token, SampledToken::Reasoning(_)));
    assert!(all_reasoning);
    assert_eq!(classifier.section, SampledTokenSection::Reasoning);
}
849
/// An open marker immediately followed by a close marker yields two suppressed
/// (empty) Reasoning pieces. The next token is ordinary Content.
#[test]
fn open_then_close_back_to_back_emits_two_empty_pieces_around_zero_content() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));
    classifier.section = SampledTokenSection::Content;

    let script = [(100, "<think>"), (200, "</think>"), (9, "Hi")];
    let mut emitted = Vec::new();
    for (token_id, text) in script {
        push_pending(&mut classifier, token_id, text);
        classifier.try_consume_marker_at_tail();
        emitted.extend(classifier.drain_overflow());
    }
    emitted.extend(classifier.flush());

    let expected_sections = vec![
        SampledTokenSection::Reasoning,
        SampledTokenSection::Reasoning,
        SampledTokenSection::Content,
    ];
    assert_eq!(outcome_sections(&emitted), expected_sections);
    assert_eq!(outcome_pieces(&emitted), vec!["", "", "Hi"]);
    assert_eq!(classifier.section, SampledTokenSection::Content);
}
875
/// A stray reasoning-close marker while already in Content is classified as
/// Content and does not move the section.
#[test]
fn spurious_reasoning_close_in_content_section_classifies_as_content() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));
    classifier.section = SampledTokenSection::Content;

    push_pending(&mut classifier, 200, "</think>");
    classifier.try_consume_marker_at_tail();
    let drained = classifier.drain_overflow();

    let sections = outcome_sections(&drained);
    assert_eq!(sections, vec![SampledTokenSection::Content]);
    assert_eq!(classifier.section, SampledTokenSection::Content);
}
892
/// A tool-call close marker that arrives while the classifier is inside a tool
/// call is classified as ToolCall, as the final token of the call. It then
/// returns the section to Content.
///
/// The name states the actual scenario. The starting section is ToolCall,
/// not Reasoning, and the close marker is the expected terminator rather
/// than a spurious one.
#[test]
fn tool_call_close_in_tool_call_section_classifies_as_tool_call_and_returns_to_content() {
    let markers = StreamingMarkers {
        reasoning_open: Some(vec![token(100)]),
        reasoning_close: Some(vec![token(200)]),
        tool_call_open: Some(vec![token(300)]),
        tool_call_close: Some(vec![token(400)]),
    };
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::ToolCall;

    push_pending(&mut classifier, 400, "</tool_call>");
    classifier.try_consume_marker_at_tail();
    let outcomes = classifier.drain_overflow();

    assert_eq!(
        outcome_sections(&outcomes),
        vec![SampledTokenSection::ToolCall],
    );
    assert_eq!(classifier.section, SampledTokenSection::Content);
}
914
/// At end of generation, `flush` releases every buffered token verbatim. That
/// includes an incomplete close-marker prefix, and nothing is left pending.
#[test]
fn flush_drains_remaining_pending_at_eog() {
    let reasoning_close = vec![token(200), token(201), token(202)];
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(reasoning_close)));
    classifier.section = SampledTokenSection::Reasoning;

    for (token_id, text) in [(7, "abc"), (200, "</"), (201, "th")] {
        push_pending(&mut classifier, token_id, text);
    }

    let released = classifier.flush();

    assert_eq!(outcome_pieces(&released), vec!["abc", "</", "th"]);
    assert!(classifier.pending.is_empty());
}
933
/// With no markers configured, tokens pass through with their visible text
/// intact and are tagged with the Pending section.
#[test]
fn no_markers_marks_each_token_undeterminable_with_visible_piece() {
    let mut classifier = synthetic_classifier(StreamingMarkers::default());

    push_pending(&mut classifier, 1, "h");
    push_pending(&mut classifier, 2, "i");
    let released = classifier.flush();

    assert_eq!(outcome_pieces(&released), vec!["h", "i"]);
    assert_eq!(
        outcome_sections(&released),
        vec![SampledTokenSection::Pending; 2],
    );
}
949
/// Prompt tokens pushed into a marker-less classifier change nothing. The
/// section stays Pending and every usage counter stays at zero.
#[test]
fn ingest_prompt_tokens_without_markers_is_noop() {
    let mut classifier = synthetic_classifier(StreamingMarkers::default());

    for prompt_token in [7, 8] {
        push_pending_from_prompt(&mut classifier, prompt_token);
    }

    assert_eq!(classifier.section, SampledTokenSection::Pending);
    let usage = classifier.usage();
    assert_eq!(usage.reasoning_tokens, 0);
    assert_eq!(usage.content_tokens, 0);
    assert_eq!(usage.tool_call_tokens, 0);
    assert_eq!(usage.undeterminable_tokens, 0);
}
964
/// A closed `<think> … </think>` pair in the prompt leaves the classifier in
/// Content. None of those prompt tokens are counted in usage.
#[test]
fn ingest_prompt_tokens_through_open_close_pair_ends_in_content() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));

    // open marker, one body token, close marker
    [100, 7, 200].iter().for_each(|&prompt_token| {
        push_pending_from_prompt(&mut classifier, prompt_token);
        classifier.try_consume_marker_at_tail();
        classifier.drain_overflow();
    });

    assert_eq!(classifier.section, SampledTokenSection::Content);
    let usage = classifier.usage();
    assert_eq!(usage.reasoning_tokens, 0);
    assert_eq!(usage.content_tokens, 0);
    assert_eq!(usage.tool_call_tokens, 0);
    assert_eq!(usage.undeterminable_tokens, 0);
}
982
/// An unterminated `<think>` in the prompt leaves the classifier in Reasoning,
/// still without counting any prompt token in usage.
#[test]
fn ingest_prompt_tokens_through_open_only_ends_in_reasoning() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));

    let prompt = [100, 7];
    for &prompt_token in &prompt {
        push_pending_from_prompt(&mut classifier, prompt_token);
        classifier.try_consume_marker_at_tail();
        classifier.drain_overflow();
    }

    assert_eq!(classifier.section, SampledTokenSection::Reasoning);
    let usage = classifier.usage();
    assert_eq!(usage.reasoning_tokens, 0);
    assert_eq!(usage.content_tokens, 0);
}
998
/// A prompt containing a full reasoning block plus trailing text is consumed
/// silently. Nothing is left for `flush` to release, and no usage counter
/// moves.
#[test]
fn ingest_prompt_tokens_does_not_record_usage() {
    let mut classifier = synthetic_classifier(markers_with(
        Some(vec![token(100)]),
        Some(vec![token(200), token(201), token(202)]),
    ));

    // <think> 7 8 9 </ thi nk> 11
    let prompt = [100, 7, 8, 9, 200, 201, 202, 11];
    for &prompt_token in &prompt {
        push_pending_from_prompt(&mut classifier, prompt_token);
        classifier.try_consume_marker_at_tail();
        classifier.drain_overflow();
    }
    assert!(classifier.flush().is_empty());

    let usage = classifier.usage();
    assert_eq!(usage.reasoning_tokens, 0);
    assert_eq!(usage.content_tokens, 0);
    assert_eq!(usage.tool_call_tokens, 0);
    assert_eq!(usage.undeterminable_tokens, 0);
}
1020
/// A close marker whose first two tokens come from the prompt and whose last
/// token is generated. Only the generated token is emitted: it is suppressed
/// as a Reasoning boundary, it is the only token counted in usage, and it
/// moves the section to Content.
#[test]
fn prompt_token_completing_marker_with_generated_token_is_suppressed_correctly() {
    let mut classifier = synthetic_classifier(markers_with(
        Some(vec![token(100)]),
        Some(vec![token(200), token(201), token(202)]),
    ));
    classifier.section = SampledTokenSection::Reasoning;

    for prompt_token in [200, 201] {
        push_pending_from_prompt(&mut classifier, prompt_token);
        classifier.try_consume_marker_at_tail();
        classifier.drain_overflow();
    }

    // The two-token prefix is still buffered, waiting for the marker's tail.
    assert_eq!(classifier.section, SampledTokenSection::Reasoning);
    assert_eq!(classifier.pending.len(), 2);

    let generated_tail = PendingToken {
        token: token(202),
        decoded: String::from("k>"),
        section: classifier.section,
        is_boundary: false,
        is_from_prompt: false,
        is_held_for_probe: false,
    };
    classifier.pending.push_back(generated_tail);
    classifier.try_consume_marker_at_tail();
    let emitted = classifier.drain_overflow();

    assert_eq!(emitted.len(), 1);
    let boundary = &emitted[0];
    assert!(matches!(boundary.sampled_token, SampledToken::Reasoning(_)));
    assert_eq!(boundary.visible_piece, "");
    assert_eq!(boundary.raw_piece, "k>");

    assert_eq!(classifier.section, SampledTokenSection::Content);
    let usage = classifier.usage();
    assert_eq!(usage.reasoning_tokens, 1);
    assert_eq!(usage.content_tokens, 0);
}
1062
/// Two complete reasoning blocks in the prompt still end in Content with all
/// usage counters untouched.
#[test]
fn ingest_prompt_tokens_with_multiple_round_trips_ends_in_content() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));

    let first_block = [100, 7, 200];
    let second_block = [100, 8, 200];
    for &prompt_token in first_block.iter().chain(second_block.iter()) {
        push_pending_from_prompt(&mut classifier, prompt_token);
        classifier.try_consume_marker_at_tail();
        classifier.drain_overflow();
    }

    assert_eq!(classifier.section, SampledTokenSection::Content);
    let usage = classifier.usage();
    assert_eq!(usage.reasoning_tokens, 0);
    assert_eq!(usage.content_tokens, 0);
    assert_eq!(usage.tool_call_tokens, 0);
    assert_eq!(usage.undeterminable_tokens, 0);
}
1081
/// A freshly built classifier starts in Pending, before any prompt token has
/// been pushed.
#[test]
fn ingest_prompt_tokens_initial_section_is_always_pending() {
    let fresh =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));

    assert_eq!(fresh.section, SampledTokenSection::Pending);
}
1089
/// After a closed-think prompt leaves the classifier in Content, a generated
/// token is finalised as `SampledToken::Content`. It is counted only under
/// `content_tokens`.
#[test]
fn ingest_prompt_tokens_then_drain_for_generated_token_classifies_correctly() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));

    // Prompt: <think> body </think>
    for prompt_token in [100, 7, 200] {
        push_pending_from_prompt(&mut classifier, prompt_token);
        classifier.try_consume_marker_at_tail();
        classifier.drain_overflow();
    }

    assert_eq!(classifier.section, SampledTokenSection::Content);
    assert_eq!(classifier.usage().reasoning_tokens, 0);
    assert_eq!(classifier.usage().content_tokens, 0);

    // A generated (non-prompt) token tagged with the current section.
    let current_section = classifier.section;
    classifier.pending.push_back(PendingToken {
        token: token(50),
        decoded: String::from("hi"),
        section: current_section,
        is_boundary: false,
        is_from_prompt: false,
        is_held_for_probe: false,
    });
    classifier.try_consume_marker_at_tail();
    let emitted = classifier.drain_overflow();

    assert_eq!(emitted.len(), 1);
    let first = &emitted[0];
    assert!(matches!(first.sampled_token, SampledToken::Content(_)));
    assert_eq!(first.visible_piece, "hi");

    let usage = classifier.usage();
    assert_eq!(usage.content_tokens, 1);
    assert_eq!(usage.reasoning_tokens, 0);
    assert_eq!(usage.undeterminable_tokens, 0);
}
1130
/// A close marker (`</think>`) emitted while already in Content, with no
/// reasoning open, is still suppressed as a boundary. This keeps the marker
/// text out of the visible content stream when a model re-emits a close
/// marker without a matching open.
///
/// The boundary token is classified as Content rather than Reasoning, since
/// there is nothing to close, and the section stays Content.
#[test]
fn close_marker_in_content_section_is_suppressed_as_boundary() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));
    classifier.section = SampledTokenSection::Content;

    let mut emitted = Vec::new();
    for (token_id, text) in [(7, "hi"), (200, "</think>"), (8, "ok")] {
        push_pending(&mut classifier, token_id, text);
        classifier.try_consume_marker_at_tail();
        emitted.extend(classifier.drain_overflow());
    }
    emitted.extend(classifier.flush());

    assert_eq!(
        outcome_sections(&emitted),
        vec![SampledTokenSection::Content; 3],
    );
    // Visible stream: "hi" + "" + "ok". The marker contributes no text.
    assert_eq!(outcome_pieces(&emitted), vec!["hi", "", "ok"]);
    assert_eq!(classifier.section, SampledTokenSection::Content);
}
1165
/// A nested `<think>` while already in Reasoning is suppressed (empty visible
/// piece), and the section stays Reasoning.
#[test]
fn open_marker_in_reasoning_section_is_suppressed_as_boundary() {
    let mut classifier =
        synthetic_classifier(markers_with(Some(vec![token(100)]), Some(vec![token(200)])));
    classifier.section = SampledTokenSection::Reasoning;

    let script = [(7, "step1"), (100, "<think>"), (8, "step2")];
    let mut emitted = Vec::new();
    for &(token_id, text) in &script {
        push_pending(&mut classifier, token_id, text);
        classifier.try_consume_marker_at_tail();
        emitted.extend(classifier.drain_overflow());
    }
    emitted.extend(classifier.flush());

    assert_eq!(outcome_pieces(&emitted), vec!["step1", "", "step2"]);
    assert_eq!(classifier.section, SampledTokenSection::Reasoning);
}
1186
/// `record_prompt_tokens` stores the prompt length in the usage snapshot.
#[test]
fn record_prompt_tokens_updates_usage() {
    let mut classifier = synthetic_classifier(markers_with(None, None));

    classifier.record_prompt_tokens(7);

    let recorded = classifier.usage().prompt_tokens;
    assert_eq!(recorded, 7);
}
1196
/// A cached-token count within the recorded prompt total is accepted and
/// stored.
#[test]
fn record_cached_prompt_tokens_updates_usage_when_under_limit() {
    let mut classifier = synthetic_classifier(markers_with(None, None));
    classifier.record_prompt_tokens(10);

    classifier
        .record_cached_prompt_tokens(3)
        .unwrap();

    let cached = classifier.usage().cached_prompt_tokens;
    assert_eq!(cached, 3);
}
1207
/// A cached-token count larger than the recorded prompt total is rejected.
#[test]
fn record_cached_prompt_tokens_returns_error_when_over_prompt_total() {
    let mut classifier = synthetic_classifier(markers_with(None, None));
    classifier.record_prompt_tokens(2);

    let over_limit = classifier.record_cached_prompt_tokens(5);

    assert!(over_limit.is_err());
}
1218
/// `markers()` exposes exactly the marker token sequences the classifier was
/// built with.
#[test]
fn markers_accessor_returns_configured_markers() {
    let classifier =
        synthetic_classifier(markers_with(Some(vec![token(1)]), Some(vec![token(2)])));

    let markers = classifier.markers();

    let expected_open = [token(1)];
    let expected_close = [token(2)];
    assert_eq!(markers.reasoning_open.as_deref(), Some(&expected_open[..]));
    assert_eq!(markers.reasoning_close.as_deref(), Some(&expected_close[..]));
}
1229
/// `into_usage` consumes the classifier and returns the usage accumulated so
/// far.
#[test]
fn into_usage_consumes_classifier_and_yields_usage_snapshot() {
    let mut classifier = synthetic_classifier(markers_with(None, None));
    classifier.record_prompt_tokens(11);

    assert_eq!(classifier.into_usage().prompt_tokens, 11);
}
1240
/// A stray `</tool_call>` while in Content is classified as Content, not
/// ToolCall, so observed tool calls are not inflated by a misbehaving model.
#[test]
fn spurious_tool_call_close_in_content_section_classifies_as_content() {
    let markers = StreamingMarkers {
        tool_call_close: Some(vec![token(300)]),
        ..markers_with(None, None)
    };
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    push_pending(&mut classifier, 300, "</tool_call>");
    classifier.try_consume_marker_at_tail();
    let drained = classifier.drain_overflow();

    let sections = outcome_sections(&drained);
    assert_eq!(sections, vec![SampledTokenSection::Content]);
    assert_eq!(classifier.section, SampledTokenSection::Content);
}
1260
1261    fn markers_with_tool_call_open(tool_call_open: Vec<LlamaToken>) -> StreamingMarkers {
1262        StreamingMarkers {
1263            reasoning_open: None,
1264            reasoning_close: None,
1265            tool_call_open: Some(tool_call_open),
1266            tool_call_close: None,
1267        }
1268    }
1269
1270    fn feed_json_string(
1271        classifier: &mut SampledTokenClassifier<'_>,
1272        text: &str,
1273        starting_token_id: i32,
1274    ) -> Vec<IngestOutcome> {
1275        let mut outcomes = Vec::new();
1276        for (offset, ch) in text.char_indices() {
1277            let token_id = starting_token_id + i32::try_from(offset).unwrap_or(i32::MAX);
1278            let mut buffer = [0_u8; 4];
1279            let chunk = ch.encode_utf8(&mut buffer);
1280            outcomes.extend(push_and_probe(classifier, token_id, chunk));
1281        }
1282        outcomes
1283    }
1284
/// A `{` as the first non-whitespace byte in Content activates the JSON probe.
#[test]
fn json_probe_engages_when_first_non_whitespace_is_open_brace_in_content() {
    let mut classifier = synthetic_classifier(markers_with_tool_call_open(vec![token(900)]));
    classifier.section = SampledTokenSection::Content;

    push_and_probe(&mut classifier, 1, "{");

    let probe_is_active = matches!(classifier.probe_mode, ProbeMode::Active(_));
    assert!(probe_is_active);
}
1295
/// An object shaped like `{"name":…,"arguments":{…}}` is released entirely as
/// ToolCall, and the probe returns to idle afterwards.
#[test]
fn json_probe_releases_tokens_as_tool_call_when_signature_matches() {
    let mut classifier = synthetic_classifier(markers_with_tool_call_open(vec![token(900)]));
    classifier.section = SampledTokenSection::Content;

    let call = r#"{"name":"f","arguments":{}}"#;
    let emitted = feed_json_string(&mut classifier, call, 100);

    assert!(!emitted.is_empty());
    let every_tool_call = emitted
        .iter()
        .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_)));
    assert!(
        every_tool_call,
        "every emitted outcome should be ToolCall, got {:?}",
        outcome_sections(&emitted),
    );
    assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
}
1314
/// An object that does not match the tool-call signature is released entirely
/// as Content, and the probe returns to idle.
#[test]
fn json_probe_releases_tokens_as_content_when_signature_does_not_match() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(&mut classifier, r#"{"foo":"bar"}"#, 100);

    // `Iterator::all` is vacuously true on an empty iterator. Without this
    // guard the test would also pass if nothing were released at all.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
        "every emitted outcome should be Content, got {:?}",
        outcome_sections(&outcomes),
    );
    assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
}
1332
/// A tool-call-shaped object with an extra top-level key does not match the
/// signature, so every token is released as Content.
#[test]
fn json_probe_releases_tokens_as_content_when_extra_top_level_key() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(
        &mut classifier,
        r#"{"name":"f","arguments":{},"extra":1}"#,
        100,
    );

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
        "every emitted outcome should be Content, got {:?}",
        outcome_sections(&outcomes),
    );
}
1351
/// When `arguments` is not an object (here a string), the signature does not
/// match and every token is released as Content.
#[test]
fn json_probe_releases_tokens_as_content_when_arguments_is_not_object() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(&mut classifier, r#"{"name":"f","arguments":"hi"}"#, 100);

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
        "every emitted outcome should be Content, got {:?}",
        outcome_sections(&outcomes),
    );
}
1366
/// A `}` inside a quoted string in `arguments` must not be taken as the end of
/// the object, so the whole call still classifies as ToolCall.
#[test]
fn json_probe_handles_strings_with_quoted_braces_in_arguments() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(
        &mut classifier,
        r#"{"name":"f","arguments":{"q":"a } b"}}"#,
        100,
    );

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
        "every emitted outcome should be ToolCall, got {:?}",
        outcome_sections(&outcomes),
    );
}
1385
/// Escaped quotes (`\"`) inside a string value must not end the string, so
/// the whole call still classifies as ToolCall.
#[test]
fn json_probe_handles_escaped_quotes_in_string_values() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(
        &mut classifier,
        r#"{"name":"f","arguments":{"q":"he said \"hi\""}}"#,
        100,
    );

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
        "every emitted outcome should be ToolCall, got {:?}",
        outcome_sections(&outcomes),
    );
}
1404
/// Multi-byte UTF-8 characters in names and values do not disturb signature
/// matching, so the whole call classifies as ToolCall.
#[test]
fn json_probe_handles_unicode_letters_in_strings() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(
        &mut classifier,
        r#"{"name":"日本語","arguments":{"city":"パリ"}}"#,
        100,
    );

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
        "every emitted outcome should be ToolCall, got {:?}",
        outcome_sections(&outcomes),
    );
}
1423
/// Deeply nested objects inside `arguments` are tracked to the matching
/// closing brace, so the whole call classifies as ToolCall.
#[test]
fn json_probe_handles_nested_objects() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(
        &mut classifier,
        r#"{"name":"f","arguments":{"a":{"b":{"c":1}}}}"#,
        100,
    );

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
        "every emitted outcome should be ToolCall, got {:?}",
        outcome_sections(&outcomes),
    );
}
1442
/// Arrays inside `arguments` are accepted, so the whole call classifies as
/// ToolCall.
#[test]
fn json_probe_handles_arrays_inside_arguments() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(
        &mut classifier,
        r#"{"name":"f","arguments":{"items":[1,2,3]}}"#,
        100,
    );

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
        "every emitted outcome should be ToolCall, got {:?}",
        outcome_sections(&outcomes),
    );
}
1461
/// A leading `}` never activates the probe, and the tokens pass through as
/// Content.
#[test]
fn json_probe_does_not_engage_when_first_byte_is_close_brace() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let outcomes = feed_json_string(&mut classifier, "}}", 100);

    assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
        "every emitted outcome should be Content, got {:?}",
        outcome_sections(&outcomes),
    );
}
1477
/// A `{` seen while in Reasoning leaves the JSON probe idle.
#[test]
fn json_probe_does_not_engage_in_reasoning_section() {
    let mut classifier = synthetic_classifier(StreamingMarkers {
        reasoning_open: Some(vec![token(800)]),
        reasoning_close: Some(vec![token(801)]),
        tool_call_open: Some(vec![token(900)]),
        tool_call_close: None,
    });
    classifier.section = SampledTokenSection::Reasoning;

    push_and_probe(&mut classifier, 1, "{");

    let probe_is_idle = matches!(classifier.probe_mode, ProbeMode::Idle);
    assert!(probe_is_idle);
}
1493
/// A `{` seen while already inside a tool call leaves the JSON probe idle.
#[test]
fn json_probe_does_not_engage_in_tool_call_section() {
    let mut classifier = synthetic_classifier(markers_with_tool_call_open(vec![token(900)]));
    classifier.section = SampledTokenSection::ToolCall;

    push_and_probe(&mut classifier, 1, "{");

    let probe_is_idle = matches!(classifier.probe_mode, ProbeMode::Idle);
    assert!(probe_is_idle);
}
1504
/// When both the JSON probe and a marker could claim the stream, the marker
/// wins.
///
/// The tool-call open marker (token 900) decodes to `"`, which is also a byte
/// the JSON probe would accept after `{`. Once the marker matches, the
/// section becomes ToolCall and the probe gives up. The held `{` is released
/// as Content, and the marker surfaces as a suppressed ToolCall boundary.
#[test]
fn marker_probe_takes_precedence_when_both_could_match() {
    let mut classifier = synthetic_classifier(markers_with_tool_call_open(vec![token(900)]));
    classifier.section = SampledTokenSection::Content;

    let mut emitted: Vec<IngestOutcome> = Vec::new();
    for (token_id, text) in [(1, "{"), (900, r#"""#)] {
        emitted.extend(push_and_probe(&mut classifier, token_id, text));
    }

    assert_eq!(classifier.section, SampledTokenSection::ToolCall);
    assert_eq!(outcome_pieces(&emitted), vec!["{", ""]);
    let expected_sections = vec![SampledTokenSection::Content, SampledTokenSection::ToolCall];
    assert_eq!(outcome_sections(&emitted), expected_sections);
}
1527
/// Two markerless tool-call objects back to back are each recognised
/// independently, and both classify as ToolCall.
#[test]
fn json_probe_consumes_two_consecutive_objects_separately() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let mut outcomes = Vec::new();
    outcomes.extend(feed_json_string(
        &mut classifier,
        r#"{"name":"a","arguments":{}}"#,
        100,
    ));
    outcomes.extend(feed_json_string(
        &mut classifier,
        r#"{"name":"b","arguments":{"x":1}}"#,
        200,
    ));

    // Guard against a vacuous pass: `all` is true on an empty iterator.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::ToolCall(_))),
        "two consecutive markerless tool calls must both classify as ToolCall, got {:?}",
        outcome_sections(&outcomes),
    );
}
1554
/// Whitespace ahead of a tool-call object is ordinary Content. Only the JSON
/// object itself is classified as ToolCall.
#[test]
fn json_probe_with_leading_whitespace_then_open_brace_classifies_whitespace_as_content_and_json_as_tool_call()
{
    let mut classifier = synthetic_classifier(markers_with_tool_call_open(vec![token(900)]));
    classifier.section = SampledTokenSection::Content;

    let emitted = feed_json_string(
        &mut classifier,
        "\n  {\"name\":\"f\",\"arguments\":{}}",
        100,
    );

    let (mut content_count, mut tool_call_count) = (0, 0);
    for outcome in &emitted {
        match outcome.sampled_token {
            SampledToken::Content(_) => content_count += 1,
            SampledToken::ToolCall(_) => tool_call_count += 1,
            _ => {}
        }
    }

    assert_eq!(
        content_count, 3,
        "leading `\\n  ` should classify as content"
    );
    assert!(
        tool_call_count > 0,
        "the JSON object should classify as ToolCall",
    );
    assert_eq!(content_count + tool_call_count, emitted.len());
}
1586
/// When the probe commits to a tool call, every emitted token is counted under
/// `tool_call_tokens` and none under `content_tokens`.
#[test]
fn json_probe_records_tool_call_token_usage_on_commit() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let json = r#"{"name":"f","arguments":{}}"#;
    let outcomes = feed_json_string(&mut classifier, json, 100);

    let emitted = u64::try_from(outcomes.len()).expect("outcome count fits in u64");
    // Without this guard, `usage == emitted` would hold trivially at zero.
    assert!(emitted > 0, "the tool call should have emitted outcomes");
    let usage = classifier.usage();
    assert_eq!(usage.tool_call_tokens, emitted);
    assert_eq!(usage.content_tokens, 0);
}
1601
/// When the probe abandons a non-matching object, every emitted token is
/// counted under `content_tokens` and none under `tool_call_tokens`.
#[test]
fn json_probe_records_content_token_usage_on_abandon() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    let json = r#"{"foo":"bar"}"#;
    let outcomes = feed_json_string(&mut classifier, json, 100);

    let emitted = u64::try_from(outcomes.len()).expect("outcome count fits in u64");
    // Without this guard, `usage == emitted` would hold trivially at zero.
    assert!(emitted > 0, "the abandoned object should have emitted outcomes");
    let usage = classifier.usage();
    assert_eq!(usage.content_tokens, emitted);
    assert_eq!(usage.tool_call_tokens, 0);
}
1616
/// Flushing while the JSON probe is still undecided releases the held tokens
/// as Content and resets the probe to idle.
#[test]
fn flush_during_active_json_probe_releases_held_tokens_as_content() {
    let markers = markers_with_tool_call_open(vec![token(900)]);
    let mut classifier = synthetic_classifier(markers);
    classifier.section = SampledTokenSection::Content;

    push_and_probe(&mut classifier, 1, "{");
    push_and_probe(&mut classifier, 2, r#""name""#);
    assert!(matches!(classifier.probe_mode, ProbeMode::Active(_)));

    let outcomes = classifier.flush();

    // The probe was holding tokens, so flush must release something. This
    // also keeps the `all` check below from passing vacuously.
    assert!(!outcomes.is_empty());
    assert!(
        outcomes
            .iter()
            .all(|outcome| matches!(outcome.sampled_token, SampledToken::Content(_))),
        "mid-probe flush must release held tokens as Content, got {:?}",
        outcome_sections(&outcomes),
    );
    assert!(matches!(classifier.probe_mode, ProbeMode::Idle));
}
1638}