use crate::map::TokenizerMap;
/// Tags a [`WatcherEvent`] with the kind of token run it carries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WatcherEventKind {
    /// Ids observed outside any start/end region, forwarded unchanged.
    Passthrough,
    /// Ids collected between a start token and its end token, with both
    /// delimiter tokens removed.
    Region,
}
/// One unit of output from the watcher: a run of token ids plus its kind.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WatcherEvent {
    /// Whether `ids` is passthrough text or the contents of a region.
    pub kind: WatcherEventKind,
    /// The token ids belonging to this event, in input order.
    pub ids: Vec<u32>,
}
#[derive(Debug, thiserror::Error)]
pub enum ToolWatcherError {
#[error("special token \"{0}\" not in map.special_tokens")]
MissingSpecial(String),
}
/// Streaming splitter that separates token ids into passthrough runs and
/// start/end-delimited regions.
///
/// The region state (`inside`, `region`) persists between calls, so a region
/// may be opened in one chunk and closed in a later one.
#[derive(Debug)]
pub struct ToolWatcher {
    /// Token id that opens a region.
    pub start_id: u32,
    /// Token id that closes a region.
    pub end_id: u32,
    /// Special-token name that `start_id` was resolved from.
    pub start_name: String,
    /// Special-token name that `end_id` was resolved from.
    pub end_name: String,
    /// `true` while a region has been opened but not yet closed.
    inside: bool,
    /// Ids buffered for the currently open region.
    region: Vec<u32>,
}
impl ToolWatcher {
    /// Creates a watcher whose regions are delimited by the special tokens
    /// `start_name` and `end_name`, resolved through `map.special_tokens`.
    ///
    /// # Errors
    ///
    /// Returns [`ToolWatcherError::MissingSpecial`] carrying the first name
    /// that cannot be resolved; `start_name` is checked before `end_name`.
    /// When `map.special_tokens` is absent entirely, `start_name` is reported.
    pub fn new(
        map: &TokenizerMap,
        start_name: &str,
        end_name: &str,
    ) -> Result<Self, ToolWatcherError> {
        let specials = map.special_tokens.as_ref();
        // A missing table and a missing entry are reported identically,
        // naming the delimiter that was being looked up.
        let lookup = |name: &str| -> Result<u32, ToolWatcherError> {
            specials
                .and_then(|table| table.get(name).copied())
                .ok_or_else(|| ToolWatcherError::MissingSpecial(name.to_string()))
        };
        let start_id = lookup(start_name)?;
        let end_id = lookup(end_name)?;
        Ok(Self {
            start_id,
            end_id,
            start_name: start_name.to_string(),
            end_name: end_name.to_string(),
            inside: false,
            region: Vec::new(),
        })
    }

    /// Returns `true` while a region is open (a start token has been seen
    /// without its matching end token, possibly in an earlier `feed` call).
    pub fn inside(&self) -> bool {
        self.inside
    }

    /// Returns to the outside state, discarding any partially collected region.
    pub fn reset(&mut self) {
        self.inside = false;
        self.region.clear();
    }

    /// Consumes one chunk of token ids and returns the events it completes.
    ///
    /// Outside a region, ids are grouped into [`WatcherEventKind::Passthrough`]
    /// events. A `start_id` token opens a region: passthrough ids preceding it
    /// in this chunk are emitted first, and the start token itself is dropped.
    /// Inside a region, ids are buffered (across calls if needed) until
    /// `end_id` arrives; then one [`WatcherEventKind::Region`] event with the
    /// buffered ids is emitted and the end token is dropped.
    ///
    /// A repeated `start_id` inside an open region is ignored. An `end_id` seen
    /// outside a region (and not equal to `start_id`) passes through like any
    /// other id. If the chunk ends with a region still open, no trailing
    /// passthrough is emitted and the buffered ids stay pending.
    pub fn feed(&mut self, ids: &[u32]) -> Vec<WatcherEvent> {
        let mut events: Vec<WatcherEvent> = Vec::new();
        // Index in `ids` of the first id not yet emitted as passthrough.
        let mut pt_start = 0usize;
        for (i, &id) in ids.iter().enumerate() {
            if !self.inside {
                if id == self.start_id {
                    if i > pt_start {
                        events.push(WatcherEvent {
                            kind: WatcherEventKind::Passthrough,
                            ids: ids[pt_start..i].to_vec(),
                        });
                    }
                    self.inside = true;
                    self.region.clear();
                }
            } else if id == self.end_id {
                events.push(WatcherEvent {
                    kind: WatcherEventKind::Region,
                    ids: std::mem::take(&mut self.region),
                });
                self.inside = false;
                pt_start = i + 1;
            } else if id != self.start_id {
                // Nested start markers are dropped; everything else is region
                // content.
                self.region.push(id);
            }
        }
        if !self.inside && pt_start < ids.len() {
            events.push(WatcherEvent {
                kind: WatcherEventKind::Passthrough,
                ids: ids[pt_start..].to_vec(),
            });
        }
        events
    }

    /// [`ToolWatcher::feed`] for `i32` token ids.
    ///
    /// Each value is reinterpreted with `as u32` (same-width cast), so negative
    /// inputs wrap rather than fail; for example `-1` becomes `u32::MAX`.
    pub fn feed_i32(&mut self, ids: &[i32]) -> Vec<WatcherEvent> {
        let converted: Vec<u32> = ids.iter().map(|&v| v as u32).collect();
        self.feed(&converted)
    }
}