use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// The normalized shape of a model emission, as produced by [`normalize_emission`].
///
/// Derives `Serialize`/`Deserialize`, so variant names and order are part of
/// its serialized form.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Emission {
/// One or more complete files: each key is the path taken from a `FILE:`
/// header line and each value is that file's body with trailing newlines
/// removed. `BTreeMap` keeps the entries ordered by path.
WholeFiles(BTreeMap<String, String>),
/// Text recognized as a unified diff, kept verbatim.
UnifiedDiff(String),
/// Fallback for anything that is neither whole files nor a diff.
Prose(String),
}
impl Emission {
pub fn shape_label(&self) -> &'static str {
match self {
Self::WholeFiles(_) => plugins_protocol::emission_shape::WHOLE_FILES,
Self::UnifiedDiff(_) => plugins_protocol::emission_shape::UNIFIED_DIFF,
Self::Prose(_) => plugins_protocol::emission_shape::PROSE,
}
}
}
/// Classifies a raw model emission into an [`Emission`].
///
/// First, any outer Markdown code fence is peeled off and the text is
/// trimmed (see `strip_outer_fences`). The result is then checked in this
/// priority order:
///
/// 1. `FILE: <path>` blocks that yield at least one file become
///    [`Emission::WholeFiles`];
/// 2. text that looks like a unified diff becomes [`Emission::UnifiedDiff`];
/// 3. everything else becomes [`Emission::Prose`] holding the stripped text.
///
/// # Errors
///
/// Every path in the current implementation returns `Ok`.
pub fn normalize_emission(raw: &str) -> Result<Emission> {
    let body = strip_outer_fences(raw);

    let emission = match try_parse_whole_files(&body) {
        // An empty map is not a usable whole-file emission; fall through.
        Some(files) if !files.is_empty() => Emission::WholeFiles(files),
        _ => match try_parse_unified_diff(&body) {
            Some(diff) => Emission::UnifiedDiff(diff),
            None => Emission::Prose(body),
        },
    };

    Ok(emission)
}
/// Removes one outer Markdown code fence wrapping a model emission, if present.
///
/// Leading and trailing whitespace is always trimmed. If the trimmed text does
/// not start with at least three backticks, it is returned otherwise unchanged.
///
/// When it does start with a fence, the rest of the opening line (the optional
/// info string, e.g. `rust`) is dropped. A trailing run of three or more
/// backticks is treated as the closing fence and dropped too. Fences of four or
/// more backticks, which a model uses when the payload itself contains fenced
/// blocks, are therefore peeled completely rather than leaving stray backticks.
/// A missing closing fence is tolerated: everything after the opening line is
/// kept. For single-line input, the text after the opening backticks is the body.
///
/// Trailing `\n` and `\r\n` line terminators of the fenced body are removed.
fn strip_outer_fences(raw: &str) -> String {
    let trimmed = raw.trim();

    // Backticks are ASCII, so this count is also a valid byte offset.
    let open_len = trimmed.bytes().take_while(|&b| b == b'`').count();
    if open_len < 3 {
        return trimmed.to_string();
    }

    let rest = &trimmed[open_len..];
    // Skip the remainder of the opening fence line (the info string). A
    // single-line input has no separate body line, so keep what follows.
    let after_tag = match rest.find('\n') {
        Some(nl) => &rest[nl + 1..],
        None => rest,
    };

    // `trimmed` has no trailing whitespace, so a closing fence, if present, is
    // the final run of backticks. Any run of three or more counts, so longer
    // or mismatched fences are removed completely.
    let without_close = after_tag.trim_end_matches('`');
    let body = if after_tag.len() - without_close.len() >= 3 {
        without_close
    } else {
        after_tag
    };

    // Strip `\r` too, so CRLF output does not leave a dangling carriage return.
    body.trim_end_matches(|c: char| c == '\r' || c == '\n')
        .to_string()
}
/// Parses `FILE: <path>` … `END-FILE` blocks into a path → contents map.
///
/// * A line beginning with exactly `FILE: ` opens a block; the path is the
///   trimmed remainder of that line.
/// * A line whose trimmed text is `END-FILE` closes the open block.
/// * A block still open when input ends is kept, so a missing final
///   `END-FILE` is tolerated.
/// * If a new `FILE:` header arrives while the open block has seen only blank
///   lines, the open block is discarded instead of committed. This recovers
///   from a header the model restated at the top of the body.
/// * Lines outside any block are ignored.
///
/// Each committed body has its trailing newlines removed. A repeated path
/// overwrites the earlier entry.
///
/// Returns `None` if no `FILE: ` header was seen at all.
fn try_parse_whole_files(body: &str) -> Option<BTreeMap<String, String>> {
    // Records a finished block, dropping the newline(s) after its last line.
    fn commit(out: &mut BTreeMap<String, String>, path: String, contents: &str) {
        out.insert(path, contents.trim_end_matches('\n').to_owned());
    }

    let mut out = BTreeMap::new();
    let mut open: Option<String> = None;
    let mut contents = String::new();
    let mut only_blank_so_far = true;
    let mut any_header = false;

    for line in body.lines() {
        if let Some(header) = line.strip_prefix("FILE: ") {
            any_header = true;
            if let Some(prev) = open.take() {
                // A block with no real content yet is a leaked marker: drop it.
                if !only_blank_so_far {
                    commit(&mut out, prev, &contents);
                }
            }
            contents.clear();
            open = Some(header.trim().to_owned());
            only_blank_so_far = true;
        } else if line.trim() == "END-FILE" {
            if let Some(prev) = open.take() {
                commit(&mut out, prev, &contents);
                contents.clear();
            }
            only_blank_so_far = true;
        } else if open.is_some() {
            if !line.trim().is_empty() {
                only_blank_so_far = false;
            }
            contents.push_str(line);
            contents.push('\n');
        }
    }

    if let Some(prev) = open {
        commit(&mut out, prev, &contents);
    }

    if any_header {
        Some(out)
    } else {
        None
    }
}
/// Returns `body` unchanged if it looks like a unified diff, otherwise `None`.
///
/// The check is heuristic. The text must contain:
/// * a `--- ` line, either at the very start or after a newline;
/// * a `+++ ` line that follows a newline;
/// * a hunk marker: `@@ ` at the start of a line, or `@@ -` anywhere.
fn try_parse_unified_diff(body: &str) -> Option<String> {
    let old_header = body.starts_with("--- ") || body.contains("\n--- ");
    let new_header = body.contains("\n+++ ");
    let hunk = body.contains("\n@@ ") || body.contains("@@ -");

    (old_header && new_header && hunk).then(|| body.to_owned())
}
// Unit tests for `normalize_emission`: whole-file block parsing, outer code
// fence stripping, unified-diff detection, the prose fallback, wire labels,
// and recovery from a `FILE:` header the model restated inside a block.
#[cfg(test)]
mod tests {
use super::*;
// A single terminated block yields exactly one file with its body.
#[test]
fn parses_single_whole_file_block() {
let raw = "FILE: src/lib.rs\npub fn hello() {}\nEND-FILE\n";
let em = normalize_emission(raw).unwrap();
match em {
Emission::WholeFiles(files) => {
assert_eq!(files.len(), 1);
assert_eq!(files.get("src/lib.rs").unwrap(), "pub fn hello() {}");
}
other => panic!("expected WholeFiles, got {other:?}"),
}
}
// Consecutive terminated blocks each become their own entry.
#[test]
fn parses_multi_file_whole_file_block() {
let raw = "\
FILE: a.rs
pub fn a() {}
END-FILE
FILE: b.rs
pub fn b() {}
END-FILE
";
let em = normalize_emission(raw).unwrap();
match em {
Emission::WholeFiles(files) => {
assert_eq!(files.len(), 2);
assert_eq!(files.get("a.rs").unwrap(), "pub fn a() {}");
assert_eq!(files.get("b.rs").unwrap(), "pub fn b() {}");
}
other => panic!("expected WholeFiles, got {other:?}"),
}
}
// A bare outer fence is peeled before block parsing.
#[test]
fn handles_outer_code_fence_around_whole_files() {
let raw = "```\nFILE: a.rs\npub fn x() {}\nEND-FILE\n```";
let em = normalize_emission(raw).unwrap();
if let Emission::WholeFiles(files) = em {
assert_eq!(files.get("a.rs").unwrap(), "pub fn x() {}");
} else {
panic!("expected whole files");
}
}
// The fence's info string (language tag) must not leak into the content.
#[test]
fn handles_outer_code_fence_with_language_tag() {
let raw = "```rust\nFILE: a.rs\npub fn x() {}\nEND-FILE\n```";
let em = normalize_emission(raw).unwrap();
if let Emission::WholeFiles(files) = em {
assert_eq!(files.get("a.rs").unwrap(), "pub fn x() {}");
} else {
panic!("expected whole files");
}
}
// A block left open at end of input is still committed.
#[test]
fn tolerates_missing_trailing_end_file() {
let raw = "FILE: src/lib.rs\npub fn hello() {}\n";
let em = normalize_emission(raw).unwrap();
if let Emission::WholeFiles(files) = em {
assert_eq!(files.get("src/lib.rs").unwrap(), "pub fn hello() {}");
} else {
panic!("expected whole files");
}
}
// With no `FILE:` headers, diff-shaped text is classified as a diff.
#[test]
fn parses_unified_diff_when_no_whole_files() {
let raw = "\
--- a/foo.rs
+++ b/foo.rs
@@ -1 +1 @@
-old
+new
";
let em = normalize_emission(raw).unwrap();
assert!(matches!(em, Emission::UnifiedDiff(_)));
}
// Text matching neither shape falls back to prose.
#[test]
fn falls_back_to_prose_on_plain_text() {
let raw = "I've updated the file successfully.";
let em = normalize_emission(raw).unwrap();
assert!(matches!(em, Emission::Prose(_)));
}
// Pins the wire strings exposed through `plugins_protocol::emission_shape`.
#[test]
fn shape_labels_match_wire_constants() {
let whole = Emission::WholeFiles(BTreeMap::new());
assert_eq!(whole.shape_label(), "whole_files");
let diff = Emission::UnifiedDiff(String::new());
assert_eq!(diff.shape_label(), "unified_diff");
let prose = Emission::Prose(String::new());
assert_eq!(prose.shape_label(), "prose");
}
// Empty input is prose with empty content.
#[test]
fn empty_input_is_prose() {
let em = normalize_emission("").unwrap();
match em {
Emission::Prose(s) => assert!(s.is_empty()),
other => panic!("expected empty prose, got {other:?}"),
}
}
// Whole-file blocks take priority over a diff in the same emission.
#[test]
fn whole_files_preferred_over_diff_when_both_present() {
let raw = "\
FILE: src/lib.rs
pub fn hello() {}
END-FILE
--- a/foo
+++ b/foo
@@ -1 +1 @@
-x
+y
";
let em = normalize_emission(raw).unwrap();
assert!(matches!(em, Emission::WholeFiles(_)));
}
// A header restated right after the real one must not appear in the body.
#[test]
fn strips_leaked_file_marker_restated_in_body() {
let raw = "FILE: src/lib.rs\nFILE: src/lib.rs\npub fn add(a: i32, b: i32) -> i32 { a + b }\nEND-FILE\n";
let em = normalize_emission(raw).unwrap();
match em {
Emission::WholeFiles(files) => {
assert_eq!(files.len(), 1);
assert_eq!(
files.get("src/lib.rs").unwrap(),
"pub fn add(a: i32, b: i32) -> i32 { a + b }"
);
}
other => panic!("expected WholeFiles, got {other:?}"),
}
}
// Same recovery after fence peeling, with no END-FILE terminator.
#[test]
fn strips_leaked_marker_inside_peeled_fence() {
let raw = "```rust\nFILE: src/lib.rs\nFILE: src/lib.rs\npub fn a() {}\n```";
let em = normalize_emission(raw).unwrap();
if let Emission::WholeFiles(files) = em {
assert_eq!(files.get("src/lib.rs").unwrap(), "pub fn a() {}");
} else {
panic!("expected whole files");
}
}
// Blank lines before the restated header still count as an empty block.
#[test]
fn strips_leaked_marker_after_leading_blank() {
let raw = "FILE: src/lib.rs\n\nFILE: src/lib.rs\npub fn a() {}\nEND-FILE\n";
let em = normalize_emission(raw).unwrap();
if let Emission::WholeFiles(files) = em {
assert_eq!(files.get("src/lib.rs").unwrap(), "pub fn a() {}");
} else {
panic!("expected whole files");
}
}
// A header after real content starts a new file rather than being discarded.
#[test]
fn does_not_strip_second_file_block_as_leaked_marker() {
let raw = "\
FILE: a.rs
pub fn a() {}
FILE: b.rs
pub fn b() {}
";
let em = normalize_emission(raw).unwrap();
match em {
Emission::WholeFiles(files) => {
assert_eq!(files.len(), 2);
assert_eq!(files.get("a.rs").unwrap(), "pub fn a() {}");
assert_eq!(files.get("b.rs").unwrap(), "pub fn b() {}");
}
other => panic!("expected two WholeFiles, got {other:?}"),
}
}
// The first non-blank line of recovered content must not be a `FILE:` marker.
#[test]
fn parsed_leaked_marker_body_is_applyable() {
let raw = "FILE: src/lib.rs\nFILE: src/lib.rs\npub fn add() {}\nEND-FILE\n";
let em = normalize_emission(raw).unwrap();
if let Emission::WholeFiles(files) = em {
let contents = files.get("src/lib.rs").unwrap();
let first = contents.lines().find(|l| !l.trim().is_empty()).unwrap();
assert!(
!first.trim_start().starts_with("FILE:"),
"marker leaked: {first}"
);
} else {
panic!("expected whole files");
}
}
}