use std::collections::{BTreeMap, BTreeSet};
use thiserror::Error;
/// Attribute name/value pairs of a tag, stored in a `BTreeMap` (so iteration is ordered by attribute name).
pub type Attributes = BTreeMap<String, String>;
/// Low-level event produced by `StreamingParser::feed` and `StreamingParser::finish`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParserEvent {
/// UTF-8 text found outside of any tag.
Text(String),
/// An opening tag: its name and its attributes.
/// A self-closing tag is reported as a `StartTag` immediately followed by an `EndTag`.
StartTag(String, Attributes),
/// A piece of the uninterpreted body of a raw tag; one body may arrive as several chunks.
Chunk(Vec<u8>),
/// A closing tag, carrying the tag name.
EndTag(String),
}
/// Higher-level instruction produced by `InstructionParser` from `ParserEvent`s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Instruction {
/// Text outside of any tag, forwarded unchanged from `ParserEvent::Text`.
Text(String),
/// An opening tag with its name and attributes.
StartTag {
name: String,
attributes: Attributes,
},
/// A closing tag, carrying the tag name.
EndTag(String),
/// A chunk of raw body bytes received while `write_file` is the active raw tag.
WriteChunk(Vec<u8>),
/// A chunk of raw body bytes received while any other raw tag is active;
/// `tag` is the name of that raw tag.
RawChunk {
tag: String,
bytes: Vec<u8>,
},
}
/// Errors reported by `StreamingParser` and `InstructionParser`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ParserError {
/// Text or tag bytes were not valid UTF-8.
#[error("invalid utf-8 in text")]
InvalidUtf8,
/// A complete tag could not be parsed; the payload is the offending tag text.
#[error("malformed tag: {0}")]
MalformedTag(String),
/// A closing tag was seen while no tag was open.
#[error("unexpected closing tag: </{found}>")]
UnexpectedClosingTag { found: String },
/// A closing tag did not match the innermost open tag.
#[error("mismatched closing tag: expected </{expected}> but found </{found}>")]
MismatchedClosingTag { expected: String, found: String },
/// The input ended while a tag that starts with `<` had not yet been closed by `>`.
#[error("unterminated tag")]
UnterminatedTag,
/// The input ended inside the body of the named raw tag.
#[error("unterminated raw section for <{0}>")]
UnterminatedRawSection(String),
/// The input ended with tags still open; the payload is their names joined with ", ".
#[error("unclosed tag(s): {0}")]
UnclosedTags(String),
/// `InstructionParser` received a chunk while it was tracking no active raw tag.
#[error("chunk emitted with no active raw tag")]
UnexpectedChunk,
}
/// Incremental tag parser: bytes are pushed in with `feed`, events come out as soon as
/// they can be determined, and `finish` validates the end of the input.
#[derive(Debug)]
pub struct StreamingParser {
// Bytes received so far that have not been turned into events yet.
buffer: Vec<u8>,
// Names of the currently open tags, innermost last.
open_tags: Vec<String>,
// Name of the raw tag whose body is currently being read, if any.
raw_tag: Option<String>,
// Names of the tags whose bodies are treated as raw (unparsed) bytes.
raw_tags: BTreeSet<String>,
}
impl Default for StreamingParser {
fn default() -> Self {
Self::new()
}
}
impl StreamingParser {
    /// Creates a parser that uses the built-in list of raw tags.
    pub fn new() -> Self {
        Self::with_raw_tags(default_raw_tags())
    }

    /// Creates a parser that treats the tags named in `raw_tags` as raw sections.
    ///
    /// The body of a raw tag is not parsed: every byte up to the literal `</name>` is
    /// reported as [`ParserEvent::Chunk`] data.
    pub fn with_raw_tags<I, S>(raw_tags: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            buffer: Vec::new(),
            open_tags: Vec::new(),
            raw_tag: None,
            raw_tags: raw_tags.into_iter().map(Into::into).collect(),
        }
    }

    /// Appends `input` to the internal buffer and returns every event that can already
    /// be determined.
    ///
    /// Bytes that cannot be interpreted yet stay buffered until more input arrives:
    /// an unfinished tag, the tail of a raw body that might be the beginning of its
    /// closing tag, or a UTF-8 character that has only partially arrived.
    /// `input` may be empty.
    ///
    /// # Errors
    ///
    /// * [`ParserError::InvalidUtf8`] if text or a tag is not valid UTF-8.
    /// * [`ParserError::MalformedTag`] if a complete tag cannot be parsed.
    /// * [`ParserError::UnexpectedClosingTag`] / [`ParserError::MismatchedClosingTag`]
    ///   if a closing tag does not match the innermost open tag.
    pub fn feed(&mut self, input: &[u8]) -> Result<Vec<ParserEvent>, ParserError> {
        self.buffer.extend_from_slice(input);
        let mut events = Vec::new();
        loop {
            if let Some(active_raw) = self.raw_tag.clone() {
                // Inside a raw section only the exact closing tag is significant.
                let needle = format!("</{active_raw}>").into_bytes();
                if let Some(idx) = find_subsequence(&self.buffer, &needle) {
                    if idx > 0 {
                        events.push(ParserEvent::Chunk(self.buffer[..idx].to_vec()));
                    }
                    self.buffer.drain(..idx + needle.len());
                    self.raw_tag = None;
                    self.pop_expected_tag(&active_raw)?;
                    events.push(ParserEvent::EndTag(active_raw));
                    continue;
                }
                // The last `needle.len() - 1` bytes could be the start of the closing
                // tag, so they are held back; everything before them is safe to emit.
                let keep_tail = needle.len().saturating_sub(1);
                if self.buffer.len() > keep_tail {
                    let emit_len = self.buffer.len() - keep_tail;
                    events.push(ParserEvent::Chunk(self.buffer.drain(..emit_len).collect()));
                }
                break;
            }

            let Some(tag_start) = self.buffer.iter().position(|b| *b == b'<') else {
                // Plain text up to the end of the buffer. A multi-byte character may be
                // split across two `feed` calls, so only the complete prefix is emitted
                // and the unfinished tail waits for the next call.
                let text = Self::complete_utf8_prefix(&self.buffer)?.to_owned();
                if !text.is_empty() {
                    self.buffer.drain(..text.len());
                    events.push(ParserEvent::Text(text));
                }
                break;
            };
            if tag_start > 0 {
                // Text followed by `<` is complete, so it must decode as a whole.
                events.push(ParserEvent::Text(to_utf8_lossless(
                    &self.buffer[..tag_start],
                )?));
                self.buffer.drain(..tag_start);
                continue;
            }

            // The buffer starts with `<`; wait until the matching `>` has arrived.
            let Some(tag_end) = find_tag_end(&self.buffer) else {
                break;
            };
            let tag_bytes: Vec<u8> = self.buffer.drain(..=tag_end).collect();
            match parse_tag(&tag_bytes)? {
                ParsedTag::Start {
                    name,
                    attributes,
                    self_closing,
                } => {
                    events.push(ParserEvent::StartTag(name.clone(), attributes));
                    if self_closing {
                        // Self-closing tags never open a scope or a raw section.
                        events.push(ParserEvent::EndTag(name));
                    } else {
                        self.open_tags.push(name.clone());
                        if self.raw_tags.contains(&name) {
                            self.raw_tag = Some(name);
                        }
                    }
                }
                ParsedTag::End { name } => {
                    self.pop_expected_tag(&name)?;
                    events.push(ParserEvent::EndTag(name));
                }
            }
        }
        Ok(events)
    }

    /// Signals the end of the input and returns any final events.
    ///
    /// # Errors
    ///
    /// In addition to the errors of [`StreamingParser::feed`]:
    ///
    /// * [`ParserError::UnterminatedRawSection`] if a raw section is still open.
    /// * [`ParserError::UnterminatedTag`] if the buffer holds a tag without its `>`.
    /// * [`ParserError::InvalidUtf8`] if the input ended in the middle of a UTF-8
    ///   character.
    /// * [`ParserError::UnclosedTags`] if tags are still open.
    pub fn finish(&mut self) -> Result<Vec<ParserEvent>, ParserError> {
        let mut events = self.feed(&[])?;
        if let Some(raw_tag) = &self.raw_tag {
            return Err(ParserError::UnterminatedRawSection(raw_tag.to_string()));
        }
        if !self.buffer.is_empty() {
            if self.buffer.first() == Some(&b'<') {
                return Err(ParserError::UnterminatedTag);
            }
            // Whatever is left is text that `feed` held back, i.e. an unfinished
            // UTF-8 character; decoding it reports `InvalidUtf8`.
            events.push(ParserEvent::Text(to_utf8_lossless(&self.buffer)?));
            self.buffer.clear();
        }
        if !self.open_tags.is_empty() {
            return Err(ParserError::UnclosedTags(self.open_tags.join(", ")));
        }
        Ok(events)
    }

    /// Pops the innermost open tag and checks that it is `closing_tag`.
    ///
    /// The innermost tag is removed from the stack even when it does not match.
    fn pop_expected_tag(&mut self, closing_tag: &str) -> Result<(), ParserError> {
        let Some(last) = self.open_tags.pop() else {
            return Err(ParserError::UnexpectedClosingTag {
                found: closing_tag.to_string(),
            });
        };
        if last != closing_tag {
            return Err(ParserError::MismatchedClosingTag {
                expected: last,
                found: closing_tag.to_string(),
            });
        }
        Ok(())
    }

    /// Returns the longest prefix of `bytes` that is complete, valid UTF-8.
    ///
    /// A truncated multi-byte character at the very end is left out of the prefix;
    /// bytes that can never become valid UTF-8 yield [`ParserError::InvalidUtf8`].
    fn complete_utf8_prefix(bytes: &[u8]) -> Result<&str, ParserError> {
        match std::str::from_utf8(bytes) {
            Ok(text) => Ok(text),
            // `error_len() == None` means the input ended inside a character, which
            // more input may still complete.
            Err(err) if err.error_len().is_none() => {
                std::str::from_utf8(&bytes[..err.valid_up_to()])
                    .map_err(|_| ParserError::InvalidUtf8)
            }
            Err(_) => Err(ParserError::InvalidUtf8),
        }
    }
}
/// Wraps a `StreamingParser` and converts its `ParserEvent`s into `Instruction`s.
#[derive(Debug, Default)]
pub struct InstructionParser {
// Underlying event parser; `Default` builds it with the built-in raw tags.
parser: StreamingParser,
// Raw tags that have been started and not yet ended, innermost last;
// its top entry decides how an incoming chunk is labelled.
raw_tag_stack: Vec<String>,
}
impl InstructionParser {
pub fn new() -> Self {
Self::default()
}
pub fn feed(&mut self, input: &[u8]) -> Result<Vec<Instruction>, ParserError> {
let events = self.parser.feed(input)?;
self.map_events(events)
}
pub fn finish(&mut self) -> Result<Vec<Instruction>, ParserError> {
let events = self.parser.finish()?;
self.map_events(events)
}
fn map_events(&mut self, events: Vec<ParserEvent>) -> Result<Vec<Instruction>, ParserError> {
let mut instructions = Vec::with_capacity(events.len());
for event in events {
match event {
ParserEvent::Text(text) => instructions.push(Instruction::Text(text)),
ParserEvent::StartTag(name, attributes) => {
if is_raw_tag(&name) {
self.raw_tag_stack.push(name.clone());
}
instructions.push(Instruction::StartTag { name, attributes });
}
ParserEvent::EndTag(name) => {
if self.raw_tag_stack.last().is_some_and(|v| v == &name) {
self.raw_tag_stack.pop();
}
instructions.push(Instruction::EndTag(name));
}
ParserEvent::Chunk(bytes) => {
let Some(active_tag) = self.raw_tag_stack.last() else {
return Err(ParserError::UnexpectedChunk);
};
if active_tag == "write_file" {
instructions.push(Instruction::WriteChunk(bytes));
} else {
instructions.push(Instruction::RawChunk {
tag: active_tag.clone(),
bytes,
});
}
}
}
}
Ok(instructions)
}
}
/// Decodes `bytes` as UTF-8 into an owned `String`.
///
/// Nothing is replaced or dropped: input that is not valid UTF-8 is rejected with
/// `ParserError::InvalidUtf8`.
fn to_utf8_lossless(bytes: &[u8]) -> Result<String, ParserError> {
    match std::str::from_utf8(bytes) {
        Ok(text) => Ok(text.to_owned()),
        Err(_) => Err(ParserError::InvalidUtf8),
    }
}
/// Names of the tags whose bodies are treated as raw by default.
fn default_raw_tags() -> [&'static str; 5] {
    const DEFAULT_RAW_TAGS: [&str; 5] =
        ["write_file", "apply_edit", "search", "replace", "terminal"];
    DEFAULT_RAW_TAGS
}
/// Returns `true` when `tag` is one of the names listed by `default_raw_tags`.
fn is_raw_tag(tag: &str) -> bool {
    default_raw_tags().iter().any(|known| *known == tag)
}
/// Result of parsing the text of one complete tag.
#[derive(Debug)]
enum ParsedTag {
/// An opening tag; `self_closing` is set when the tag text ended with `/` before `>`.
Start {
name: String,
attributes: Attributes,
self_closing: bool,
},
/// A closing tag (`</name>`).
End {
name: String,
},
}
/// Parses the bytes of one complete tag, from `<` through `>`, into a `ParsedTag`.
///
/// Recognized forms are `</name>`, `<name key="value" ...>` and the self-closing
/// `<name ... />`. Every attribute must be written as `key=value` with the value in
/// matching single or double quotes; when a key is repeated the last value is kept.
///
/// # Errors
///
/// * `ParserError::InvalidUtf8` if `tag_bytes` is not valid UTF-8.
/// * `ParserError::MalformedTag` (carrying the tag text) for anything else that does
///   not fit the forms above.
fn parse_tag(tag_bytes: &[u8]) -> Result<ParsedTag, ParserError> {
    let tag = std::str::from_utf8(tag_bytes).map_err(|_| ParserError::InvalidUtf8)?;
    let malformed = || ParserError::MalformedTag(tag.to_string());

    // Peel off the angle brackets and surrounding whitespace.
    let Some(body) = tag
        .strip_prefix('<')
        .and_then(|rest| rest.strip_suffix('>'))
    else {
        return Err(malformed());
    };
    let body = body.trim();
    if body.is_empty() {
        return Err(malformed());
    }

    // Closing tag: nothing but a name may follow the slash.
    if let Some(closing) = body.strip_prefix('/') {
        let name = closing.trim();
        return if is_valid_name(name) {
            Ok(ParsedTag::End {
                name: name.to_string(),
            })
        } else {
            Err(malformed())
        };
    }

    let (body, self_closing) = match body.strip_suffix('/') {
        Some(rest) => (rest.trim_end(), true),
        None => (body, false),
    };

    // The name runs up to the first whitespace; the remainder holds the attributes.
    let (name, mut rest) = match body.split_once(char::is_whitespace) {
        Some((name, rest)) => (name, rest.trim_start()),
        None => (body, ""),
    };
    if !is_valid_name(name) {
        return Err(malformed());
    }

    let mut attributes = Attributes::new();
    while !rest.is_empty() {
        let Some((raw_key, after_eq)) = rest.split_once('=') else {
            return Err(malformed());
        };
        let key = raw_key.trim();
        if !is_valid_name(key) {
            return Err(malformed());
        }

        let after_eq = after_eq.trim_start();
        let quote = match after_eq.chars().next() {
            Some(q @ ('"' | '\'')) => q,
            _ => return Err(malformed()),
        };
        // The value ends at the next occurrence of the quote that opened it.
        let Some((value, remainder)) = after_eq[quote.len_utf8()..].split_once(quote) else {
            return Err(malformed());
        };
        attributes.insert(key.to_string(), value.to_string());
        rest = remainder.trim_start();
    }

    Ok(ParsedTag::Start {
        name: name.to_string(),
        attributes,
        self_closing,
    })
}
/// Returns `true` when `name` is usable as a tag or attribute name.
///
/// A name must be non-empty and must not contain whitespace or any character that
/// has a structural meaning inside a tag: `/`, `<`, `>`, `=`, `"` or `'`. Rejecting
/// `=` and the quote characters keeps junk such as `<a="b">` from being accepted as
/// a tag literally named `a="b"`.
fn is_valid_name(name: &str) -> bool {
    !name.is_empty()
        && !name
            .chars()
            .any(|ch| ch.is_whitespace() || matches!(ch, '/' | '<' | '>' | '=' | '"' | '\''))
}
/// Finds the index of the `>` that closes the tag starting at `bytes[0]`.
///
/// The first byte is skipped. A `>` inside a single- or double-quoted section does
/// not count; a quoted section ends at the next occurrence of the quote that opened
/// it. Returns `None` when no closing `>` is present.
fn find_tag_end(bytes: &[u8]) -> Option<usize> {
    let mut open_quote: Option<u8> = None;
    for (offset, &current) in bytes.iter().enumerate().skip(1) {
        if let Some(quote) = open_quote {
            // Inside quotes only the matching quote character matters.
            if current == quote {
                open_quote = None;
            }
            continue;
        }
        match current {
            b'"' | b'\'' => open_quote = Some(current),
            b'>' => return Some(offset),
            _ => {}
        }
    }
    None
}
/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// Returns `None` when `needle` is empty, longer than `haystack`, or absent.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    let width = needle.len();
    if width == 0 || width > haystack.len() {
        return None;
    }
    let last_start = haystack.len() - width;
    (0..=last_start).find(|&start| &haystack[start..start + width] == needle)
}
// Unit tests for `StreamingParser` and `InstructionParser`.
#[cfg(test)]
mod tests {
use super::*;
// A self-closing tag yields a StartTag (with all attributes) directly followed by an EndTag.
#[test]
fn parses_self_closing_tag() {
let mut parser = StreamingParser::new();
let events = parser
.feed(br#"<read_file path="src/lib.rs" start_line="1" end_line="5" />"#)
.expect("parse should succeed");
let finished = parser.finish().expect("finish should succeed");
assert_eq!(
events,
vec![
ParserEvent::StartTag(
"read_file".to_string(),
BTreeMap::from([
("end_line".to_string(), "5".to_string()),
("path".to_string(), "src/lib.rs".to_string()),
("start_line".to_string(), "1".to_string()),
])
),
ParserEvent::EndTag("read_file".to_string())
]
);
assert!(finished.is_empty());
}
// Inside a raw tag, text that looks like a tag is passed through as chunk bytes.
#[test]
fn parses_raw_write_file_body_even_with_pseudo_tags_inside() {
let mut parser = StreamingParser::new();
let first = parser
.feed(b"<write_file path=\"src/main.rs\">fn main() {<not_a_tag>")
.expect("first parse should succeed");
assert!(matches!(
first.first(),
Some(ParserEvent::StartTag(name, _)) if name == "write_file"
));
let second = parser
.feed(b" println!(\"ok\"); }</write_file>")
.expect("second parse should succeed");
assert!(matches!(
second.last(),
Some(ParserEvent::EndTag(name)) if name == "write_file"
));
// The body may be split into several chunks; only the concatenation is checked.
let full_body = collect_chunks(&[first.clone(), second.clone()]);
assert_eq!(full_body, b"fn main() {<not_a_tag> println!(\"ok\"); }");
assert!(parser.finish().expect("finish should succeed").is_empty());
}
// `apply_edit` is a raw tag, so the nested `search`/`replace` tags are not parsed.
#[test]
fn treats_apply_edit_body_as_raw_payload() {
let mut parser = StreamingParser::new();
let events = parser
.feed(
b"<apply_edit path=\"src/lib.rs\"><search>old</search><replace>new</replace></apply_edit>",
)
.expect("parse should succeed");
let finished = parser.finish().expect("finish should succeed");
assert!(finished.is_empty());
assert!(matches!(
events.first(),
Some(ParserEvent::StartTag(name, _)) if name == "apply_edit"
));
assert!(matches!(
events.last(),
Some(ParserEvent::EndTag(name)) if name == "apply_edit"
));
let body = collect_chunks(&[events]);
assert_eq!(body, b"<search>old</search><replace>new</replace>");
}
// Chunks of a raw tag other than `write_file` are reported as RawChunk with the tag name.
#[test]
fn instruction_parser_emits_raw_chunk_for_apply_edit_body() {
let mut parser = InstructionParser::new();
let instructions = parser
.feed(b"<apply_edit path=\"src/lib.rs\">@@ -1 +1 @@\n-old\n+new\n</apply_edit>")
.expect("instruction parse should succeed");
let final_batch = parser.finish().expect("finish should succeed");
assert!(final_batch.is_empty());
assert!(matches!(
instructions.first(),
Some(Instruction::StartTag { name, .. }) if name == "apply_edit"
));
assert!(matches!(
instructions.last(),
Some(Instruction::EndTag(name)) if name == "apply_edit"
));
assert!(instructions.iter().any(|instruction| matches!(
instruction,
Instruction::RawChunk { tag, .. } if tag == "apply_edit"
)));
}
// The closing tag of a raw section may be split across `feed` calls.
#[test]
fn handles_raw_tag_close_across_chunk_boundaries() {
let mut parser = StreamingParser::new();
let events_1 = parser
.feed(b"<terminal>cargo t")
.expect("feed should succeed");
// "cargo t" is shorter than the tail held back for a possible closing tag,
// so the first batch carries no chunk yet.
assert_eq!(
events_1,
vec![ParserEvent::StartTag(
"terminal".to_string(),
BTreeMap::new()
)]
);
let events_2 = parser.feed(b"est</ter").expect("feed should succeed");
assert!(!events_2.is_empty());
let events_3 = parser.feed(b"minal>").expect("feed should succeed");
assert!(matches!(
events_3.last(),
Some(ParserEvent::EndTag(name)) if name == "terminal"
));
let full_body = collect_chunks(&[events_1, events_2, events_3]);
assert_eq!(full_body, b"cargo test");
assert!(parser.finish().expect("finish should succeed").is_empty());
}
// A closing tag that differs from the innermost open tag is an error naming both tags.
#[test]
fn returns_mismatched_closing_tag_error() {
let mut parser = StreamingParser::new();
let _ = parser
.feed(b"<read_file path=\"src/lib.rs\">")
.expect("opening tag should parse");
let err = parser.feed(b"</write_file>").expect_err("should fail");
assert_eq!(
err,
ParserError::MismatchedClosingTag {
expected: "read_file".to_string(),
found: "write_file".to_string(),
}
);
}
// A tag that never receives its `>` is accepted by `feed` but rejected by `finish`.
#[test]
fn returns_unterminated_tag_error_on_finish() {
let mut parser = StreamingParser::new();
let _ = parser
.feed(b"<read_file path=\"a")
.expect("feed should work");
let err = parser.finish().expect_err("finish should fail");
assert_eq!(err, ParserError::UnterminatedTag);
}
// Ending the input inside a raw section is reported with the raw tag's name.
#[test]
fn returns_unterminated_raw_section_error_on_finish() {
let mut parser = StreamingParser::new();
let _ = parser
.feed(b"<write_file path=\"x\">partial")
.expect("feed should work");
let err = parser.finish().expect_err("finish should fail");
assert_eq!(
err,
ParserError::UnterminatedRawSection("write_file".to_string())
);
}
// Chunks of a `write_file` body are reported as WriteChunk and add up to the full body.
#[test]
fn instruction_parser_emits_write_chunk_for_write_file_content() {
let mut parser = InstructionParser::new();
let batch_1 = parser
.feed(b"<write_file path=\"src/main.rs\">hello")
.expect("instruction batch 1 should parse");
let batch_2 = parser
.feed(b" world</write_file>")
.expect("instruction batch 2 should parse");
let final_batch = parser.finish().expect("finish should parse");
assert!(matches!(
batch_1.first(),
Some(Instruction::StartTag { name, .. }) if name == "write_file"
));
let mut body = Vec::new();
for batch in [&batch_1, &batch_2, &final_batch] {
for instruction in batch.iter() {
if let Instruction::WriteChunk(bytes) = instruction {
body.extend_from_slice(bytes);
}
}
}
assert_eq!(body, b"hello world");
assert!(matches!(
batch_2.last(),
Some(Instruction::EndTag(name)) if name == "write_file"
));
}
// Test helper: concatenates the bytes of every Chunk event, in order, across all batches.
fn collect_chunks(batches: &[Vec<ParserEvent>]) -> Vec<u8> {
let mut out = Vec::new();
for batch in batches {
for event in batch {
if let ParserEvent::Chunk(bytes) = event {
out.extend_from_slice(bytes);
}
}
}
out
}
}