Skip to main content

codec_rs/
tool_watcher.rs

// SPDX-License-Identifier: MIT
//! Tool-call / region watcher.
//!
//! Mirrors `libcodec`'s `codec_tool_watcher`, the .NET `ToolWatcher`,
//! and `@codecai/web`'s `ToolWatcher` — same state-machine semantics.
//! Detects delimited regions (tool calls, reasoning blocks, vision
//! spans, sandbox regions, channel headers) in a token-ID stream
//! without ever decoding. The hot loop is a `u32` compare against two
//! cached IDs; no vocab read, no detokenize call, no string allocation.
//!
//! State survives across [`ToolWatcher::feed`] calls — a region split
//! between network frames is buffered internally until the end marker
//! arrives.

15use crate::map::TokenizerMap;
16
/// Classifies a [`WatcherEvent`] produced by [`ToolWatcher::feed`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WatcherEventKind {
    /// A contiguous run of IDs that lies outside every watched region;
    /// callers should forward these unchanged.
    Passthrough,
    /// The body of one closed start..end region. The start and end marker
    /// IDs that delimited it are not included.
    Region,
}
25
26/// One event from [`ToolWatcher::feed`].
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct WatcherEvent {
29    pub kind: WatcherEventKind,
30    pub ids: Vec<u32>,
31}
32
33/// Errors raised when constructing a [`ToolWatcher`].
34#[derive(Debug, thiserror::Error)]
35pub enum ToolWatcherError {
36    #[error("special token \"{0}\" not in map.special_tokens")]
37    MissingSpecial(String),
38}
39
/// Stateful watcher for delimited regions in a token-ID stream.
///
/// Construct with a map and the names of the start/end specials. The
/// watcher resolves them to IDs once and caches them — no further map
/// access happens during [`ToolWatcher::feed`].
///
/// `Debug` and `Clone` are derived so the watcher can be logged, embedded
/// in other `Debug` types, and snapshotted mid-stream. A clone carries any
/// partially buffered region with it.
#[derive(Debug, Clone)]
pub struct ToolWatcher {
    /// ID of the special token that opens a region.
    pub start_id: u32,
    /// ID of the special token that closes a region.
    pub end_id: u32,
    /// Special-token name that `start_id` was resolved from.
    pub start_name: String,
    /// Special-token name that `end_id` was resolved from.
    pub end_name: String,
    /// True between a start marker and its matching end marker.
    inside: bool,
    /// IDs collected so far for the currently open region. Kept across
    /// `feed` calls until the end marker arrives.
    region: Vec<u32>,
}
53
54impl ToolWatcher {
55    pub fn new(
56        map: &TokenizerMap,
57        start_name: &str,
58        end_name: &str,
59    ) -> Result<Self, ToolWatcherError> {
60        let specials = map.special_tokens.as_ref().ok_or_else(|| {
61            ToolWatcherError::MissingSpecial(start_name.to_string())
62        })?;
63        let start_id = specials
64            .get(start_name)
65            .copied()
66            .ok_or_else(|| ToolWatcherError::MissingSpecial(start_name.to_string()))?;
67        let end_id = specials
68            .get(end_name)
69            .copied()
70            .ok_or_else(|| ToolWatcherError::MissingSpecial(end_name.to_string()))?;
71
72        Ok(Self {
73            start_id,
74            end_id,
75            start_name: start_name.to_string(),
76            end_name: end_name.to_string(),
77            inside: false,
78            region: Vec::new(),
79        })
80    }
81
82    /// True while a region is open (start seen, end not yet).
83    pub fn inside(&self) -> bool {
84        self.inside
85    }
86
87    /// Drop any in-flight region buffer. Call between conversations so
88    /// a leftover unclosed region from session N doesn't spill into N+1.
89    pub fn reset(&mut self) {
90        self.inside = false;
91        self.region.clear();
92    }
93
94    /// Feed a chunk of token IDs and receive a flat list of events.
95    pub fn feed(&mut self, ids: &[u32]) -> Vec<WatcherEvent> {
96        let mut events: Vec<WatcherEvent> = Vec::new();
97        let n = ids.len();
98        let mut pt_start = 0usize;
99
100        // Single-pass scan. Identical state machine to the C / .NET / TS
101        // implementations — keep them in sync if you change one.
102        for i in 0..n {
103            let id = ids[i];
104            if !self.inside {
105                if id == self.start_id {
106                    if i > pt_start {
107                        events.push(WatcherEvent {
108                            kind: WatcherEventKind::Passthrough,
109                            ids: ids[pt_start..i].to_vec(),
110                        });
111                    }
112                    self.inside = true;
113                    self.region.clear();
114                    // pt_start re-anchors when the region closes.
115                }
116                // else: token continues passthrough run; no action.
117            } else if id == self.end_id {
118                events.push(WatcherEvent {
119                    kind: WatcherEventKind::Region,
120                    ids: std::mem::take(&mut self.region),
121                });
122                self.inside = false;
123                pt_start = i + 1;
124            } else if id == self.start_id {
125                // Nested start — ignore. Most models don't nest these markers.
126            } else {
127                self.region.push(id);
128            }
129        }
130
131        if !self.inside && pt_start < n {
132            events.push(WatcherEvent {
133                kind: WatcherEventKind::Passthrough,
134                ids: ids[pt_start..n].to_vec(),
135            });
136        }
137
138        events
139    }
140
141    /// Convenience: feed an `i32` slice (the wire frame's natural type
142    /// from .NET surface familiarity). Internally upcast to `u32`.
143    pub fn feed_i32(&mut self, ids: &[i32]) -> Vec<WatcherEvent> {
144        let copy: Vec<u32> = ids.iter().map(|&v| v as u32).collect();
145        self.feed(&copy)
146    }
147}