codec_rs/tool_watcher.rs
1// SPDX-License-Identifier: MIT
2//! Tool-call / region watcher.
3//!
4//! Mirrors `libcodec`'s `codec_tool_watcher`, the .NET `ToolWatcher`,
5//! and `@codecai/web`'s `ToolWatcher` — same state-machine semantics.
6//! Detects delimited regions (tool calls, reasoning blocks, vision
7//! spans, sandbox regions, channel headers) in a token-ID stream
8//! without ever decoding. The hot loop is a `u32` compare against two
9//! cached IDs; no vocab read, no detokenize call, no string allocation.
10//!
11//! State survives across [`ToolWatcher::feed`] calls — a region split
12//! between network frames buffers internally until the end marker
13//! arrives.
14
15use crate::map::TokenizerMap;
16
/// Classifies a [`WatcherEvent`] produced by [`ToolWatcher::feed`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WatcherEventKind {
    /// A run of token IDs that lies outside every watched region; the
    /// caller can forward it unchanged.
    Passthrough,
    /// The contents of one finished region. The start and end marker
    /// tokens themselves are not part of the payload.
    Region,
}
25
26/// One event from [`ToolWatcher::feed`].
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct WatcherEvent {
29 pub kind: WatcherEventKind,
30 pub ids: Vec<u32>,
31}
32
33/// Errors raised when constructing a [`ToolWatcher`].
34#[derive(Debug, thiserror::Error)]
35pub enum ToolWatcherError {
36 #[error("special token \"{0}\" not in map.special_tokens")]
37 MissingSpecial(String),
38}
39
/// Stateful watcher for delimited regions in a token-ID stream.
///
/// Construct with a map and the names of the start/end specials. The
/// watcher resolves them to IDs once and caches them — no further map
/// access happens during [`ToolWatcher::feed`].
///
/// Implements `Debug` so a watcher can be logged or embedded in other
/// `Debug` types.
#[derive(Debug)]
pub struct ToolWatcher {
    /// Token ID that opens a region.
    pub start_id: u32,
    /// Token ID that closes a region.
    pub end_id: u32,
    /// Name of the start special token, as passed to the constructor.
    pub start_name: String,
    /// Name of the end special token, as passed to the constructor.
    pub end_name: String,
    /// True after a start marker has been seen and before its end marker.
    inside: bool,
    /// Token IDs collected for the currently open region (markers excluded).
    region: Vec<u32>,
}
53
54impl ToolWatcher {
55 pub fn new(
56 map: &TokenizerMap,
57 start_name: &str,
58 end_name: &str,
59 ) -> Result<Self, ToolWatcherError> {
60 let specials = map.special_tokens.as_ref().ok_or_else(|| {
61 ToolWatcherError::MissingSpecial(start_name.to_string())
62 })?;
63 let start_id = specials
64 .get(start_name)
65 .copied()
66 .ok_or_else(|| ToolWatcherError::MissingSpecial(start_name.to_string()))?;
67 let end_id = specials
68 .get(end_name)
69 .copied()
70 .ok_or_else(|| ToolWatcherError::MissingSpecial(end_name.to_string()))?;
71
72 Ok(Self {
73 start_id,
74 end_id,
75 start_name: start_name.to_string(),
76 end_name: end_name.to_string(),
77 inside: false,
78 region: Vec::new(),
79 })
80 }
81
82 /// True while a region is open (start seen, end not yet).
83 pub fn inside(&self) -> bool {
84 self.inside
85 }
86
87 /// Drop any in-flight region buffer. Call between conversations so
88 /// a leftover unclosed region from session N doesn't spill into N+1.
89 pub fn reset(&mut self) {
90 self.inside = false;
91 self.region.clear();
92 }
93
94 /// Feed a chunk of token IDs and receive a flat list of events.
95 pub fn feed(&mut self, ids: &[u32]) -> Vec<WatcherEvent> {
96 let mut events: Vec<WatcherEvent> = Vec::new();
97 let n = ids.len();
98 let mut pt_start = 0usize;
99
100 // Single-pass scan. Identical state machine to the C / .NET / TS
101 // implementations — keep them in sync if you change one.
102 for i in 0..n {
103 let id = ids[i];
104 if !self.inside {
105 if id == self.start_id {
106 if i > pt_start {
107 events.push(WatcherEvent {
108 kind: WatcherEventKind::Passthrough,
109 ids: ids[pt_start..i].to_vec(),
110 });
111 }
112 self.inside = true;
113 self.region.clear();
114 // pt_start re-anchors when the region closes.
115 }
116 // else: token continues passthrough run; no action.
117 } else if id == self.end_id {
118 events.push(WatcherEvent {
119 kind: WatcherEventKind::Region,
120 ids: std::mem::take(&mut self.region),
121 });
122 self.inside = false;
123 pt_start = i + 1;
124 } else if id == self.start_id {
125 // Nested start — ignore. Most models don't nest these markers.
126 } else {
127 self.region.push(id);
128 }
129 }
130
131 if !self.inside && pt_start < n {
132 events.push(WatcherEvent {
133 kind: WatcherEventKind::Passthrough,
134 ids: ids[pt_start..n].to_vec(),
135 });
136 }
137
138 events
139 }
140
141 /// Convenience: feed an `i32` slice (the wire frame's natural type
142 /// from .NET surface familiarity). Internally upcast to `u32`.
143 pub fn feed_i32(&mut self, ids: &[i32]) -> Vec<WatcherEvent> {
144 let copy: Vec<u32> = ids.iter().map(|&v| v as u32).collect();
145 self.feed(©)
146 }
147}