use crate::adapter::{Adapter, DemoData};
use crate::error::Result;
use crate::intern::{sym, Sym};
use crate::predict::FieldRange;
use crate::signature::Signature;
use crate::str_view::StrView;
use crate::types::Inputs;
use smallvec::SmallVec;
/// Formatting options consumed by [`XMLAdapter`] when it builds prompts.
#[derive(Clone, Copy)]
pub struct XMLConfig {
/// When `true`, each field element line is preceded by an indent.
pub pretty: bool,
/// Tag name of the element that wraps the output fields in the schema and in demos.
pub root_element: &'static str,
/// When `true`, the formatted prompt starts with an XML declaration line.
pub include_declaration: bool,
}
impl Default for XMLConfig {
    /// Returns the stock configuration; identical to [`XMLConfig::new`]
    /// (pretty output, `response` root element, no XML declaration).
    fn default() -> Self {
        Self::new()
    }
}
impl XMLConfig {
    /// Creates the stock configuration: pretty output, a `response` root
    /// element, and no XML declaration. Usable in `const` contexts.
    pub const fn new() -> Self {
        Self {
            include_declaration: false,
            root_element: "response",
            pretty: true,
        }
    }

    /// Returns a copy of this configuration with `pretty` replaced.
    pub const fn with_pretty(self, pretty: bool) -> Self {
        Self { pretty, ..self }
    }

    /// Returns a copy of this configuration whose root element is `root`.
    pub const fn with_root(self, root: &'static str) -> Self {
        Self {
            root_element: root,
            ..self
        }
    }

    /// Returns a copy of this configuration with `include_declaration` set to `include`.
    pub const fn with_declaration(self, include: bool) -> Self {
        Self {
            include_declaration: include,
            ..self
        }
    }
}
/// Adapter that renders prompts as XML elements and locates output fields
/// inside XML-shaped responses.
#[derive(Clone, Copy)]
pub struct XMLAdapter {
/// Formatting options applied when building prompts.
config: XMLConfig,
}
impl XMLAdapter {
pub const fn new(config: XMLConfig) -> Self {
Self { config }
}
pub const fn default() -> Self {
Self::new(XMLConfig::new())
}
pub const fn config(&self) -> &XMLConfig {
&self.config
}
fn write_schema(&self, buffer: &mut Vec<u8>, signature: &Signature<'_>) {
buffer.push(b'<');
buffer.extend_from_slice(self.config.root_element.as_bytes());
buffer.extend_from_slice(b">\n");
for field in &signature.output_fields {
let name = &field.name;
if self.config.pretty {
buffer.extend_from_slice(b" ");
}
buffer.push(b'<');
buffer.extend_from_slice(name.as_bytes());
buffer.push(b'>');
buffer.extend_from_slice(b"...");
buffer.extend_from_slice(b"</");
buffer.extend_from_slice(name.as_bytes());
buffer.extend_from_slice(b">\n");
}
buffer.extend_from_slice(b"</");
buffer.extend_from_slice(self.config.root_element.as_bytes());
buffer.push(b'>');
}
fn write_input_xml(
&self,
buffer: &mut Vec<u8>,
inputs: &Inputs<'_>,
signature: &Signature<'_>,
) {
buffer.extend_from_slice(b"<input>\n");
for field in &signature.input_fields {
let name = &field.name;
if let Some(value) = inputs.get(name.as_ref()) {
if self.config.pretty {
buffer.extend_from_slice(b" ");
}
buffer.push(b'<');
buffer.extend_from_slice(name.as_bytes());
buffer.push(b'>');
write_escaped_xml(buffer, value);
buffer.extend_from_slice(b"</");
buffer.extend_from_slice(name.as_bytes());
buffer.extend_from_slice(b">\n");
}
}
buffer.extend_from_slice(b"</input>");
}
fn write_demo_xml(&self, buffer: &mut Vec<u8>, demo: &DemoData<'_>, signature: &Signature<'_>) {
use crate::intern::sym;
buffer.extend_from_slice(b"Input:\n<input>\n");
for field in &signature.input_fields {
let field_sym = sym(&field.name);
if let Some(value) = demo.get_input(field_sym) {
let name = &field.name;
if self.config.pretty {
buffer.extend_from_slice(b" ");
}
buffer.push(b'<');
buffer.extend_from_slice(name.as_bytes());
buffer.push(b'>');
write_escaped_xml(buffer, value.as_str());
buffer.extend_from_slice(b"</");
buffer.extend_from_slice(name.as_bytes());
buffer.extend_from_slice(b">\n");
}
}
buffer.extend_from_slice(b"</input>\n\n");
buffer.extend_from_slice(b"Output:\n<");
buffer.extend_from_slice(self.config.root_element.as_bytes());
buffer.extend_from_slice(b">\n");
for field in &signature.output_fields {
let field_sym = sym(&field.name);
if let Some(value) = demo.get_output(field_sym) {
let name = &field.name;
if self.config.pretty {
buffer.extend_from_slice(b" ");
}
buffer.push(b'<');
buffer.extend_from_slice(name.as_bytes());
buffer.push(b'>');
write_escaped_xml(buffer, value.as_str());
buffer.extend_from_slice(b"</");
buffer.extend_from_slice(name.as_bytes());
buffer.extend_from_slice(b">\n");
}
}
buffer.extend_from_slice(b"</");
buffer.extend_from_slice(self.config.root_element.as_bytes());
buffer.extend_from_slice(b">\n");
}
}
/// Appends `s` to `buffer`, replacing the five characters that are special in
/// XML with their predefined entities:
///
/// | char | entity   |
/// |------|----------|
/// | `<`  | `&lt;`   |
/// | `>`  | `&gt;`   |
/// | `&`  | `&amp;`  |
/// | `"`  | `&quot;` |
/// | `'`  | `&apos;` |
///
/// All other bytes are copied through unchanged.
fn write_escaped_xml(buffer: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    // Start of the pending run of bytes that need no escaping.
    let mut run_start = 0;

    // Scanning bytes is sound for UTF-8: every character we escape is ASCII,
    // and ASCII byte values never occur inside a multi-byte sequence.
    for (i, &byte) in bytes.iter().enumerate() {
        let entity: &[u8] = match byte {
            b'<' => b"&lt;",
            b'>' => b"&gt;",
            b'&' => b"&amp;",
            b'"' => b"&quot;",
            b'\'' => b"&apos;",
            _ => continue,
        };
        // Flush the clean run in one copy, then emit the entity.
        buffer.extend_from_slice(&bytes[run_start..i]);
        buffer.extend_from_slice(entity);
        run_start = i + 1;
    }
    buffer.extend_from_slice(&bytes[run_start..]);
}
/// Decodes the five predefined XML entities (`&lt;`, `&gt;`, `&amp;`,
/// `&quot;`, `&apos;`) in `s` and returns the result.
///
/// The input is scanned once, left to right, so text produced by decoding is
/// never decoded again: `&amp;lt;` becomes `&lt;`, not `<`. An `&` that does
/// not start one of the five entities is copied through unchanged.
// Not called by non-test code in this file; kept as the inverse of the escaper.
#[allow(dead_code)]
fn unescape_xml(s: &str) -> String {
    const ENTITIES: [(&str, char); 5] = [
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&amp;", '&'),
        ("&quot;", '"'),
        ("&apos;", '\''),
    ];

    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(amp_at) = rest.find('&') {
        out.push_str(&rest[..amp_at]);
        rest = &rest[amp_at..];
        match ENTITIES.iter().find(|(entity, _)| rest.starts_with(*entity)) {
            Some(&(entity, decoded)) => {
                out.push(decoded);
                rest = &rest[entity.len()..];
            }
            None => {
                // Bare or unknown `&...`: keep the ampersand and move on.
                out.push('&');
                rest = &rest[1..];
            }
        }
    }
    out.push_str(rest);
    out
}
/// Locates the content of the first attribute-less `<element_name>` element.
///
/// Returns the byte range of `text` between the first `<element_name>` tag
/// and the first `</element_name>` tag that follows it, or `None` when either
/// tag is missing.
fn find_xml_element(text: &str, element_name: &str) -> Option<std::ops::Range<usize>> {
    let opening = ["<", element_name, ">"].concat();
    let closing = ["</", element_name, ">"].concat();

    text.find(&opening).and_then(|tag_at| {
        let body_start = tag_at + opening.len();
        text[body_start..]
            .find(&closing)
            .map(|body_len| body_start..body_start + body_len)
    })
}
/// Locates the content of an `<element_name ...>` element, with or without
/// attributes on the opening tag.
///
/// An attribute-less `<element_name>` element found by [`find_xml_element`]
/// takes priority. Otherwise the text is scanned for an opening tag whose
/// name is exactly `element_name`, i.e. the name is followed by whitespace,
/// `>` or `/`. Tags that merely start with the name (such as `<answers>` when
/// looking for `answer`) are skipped.
///
/// Returns the byte range between the end of that opening tag and the first
/// `</element_name>` after it, or `None` if no such element is found.
fn find_xml_element_with_attrs(text: &str, element_name: &str) -> Option<std::ops::Range<usize>> {
    if let Some(range) = find_xml_element(text, element_name) {
        return Some(range);
    }

    let open_prefix = format!("<{}", element_name);
    let close_tag = format!("</{}>", element_name);

    let mut search_from = 0;
    while let Some(offset) = text[search_from..].find(&open_prefix) {
        let after_name = search_from + offset + open_prefix.len();
        let name_ends_here = matches!(
            text[after_name..].chars().next(),
            Some(c) if c.is_whitespace() || c == '>' || c == '/'
        );
        if name_ends_here {
            let tag_end = text[after_name..].find('>')?;
            let content_start = after_name + tag_end + 1;
            let content_len = text[content_start..].find(&close_tag)?;
            return Some(content_start..content_start + content_len);
        }
        // The match was only a prefix of a longer tag name; keep scanning.
        search_from = after_name;
    }
    None
}
impl Adapter for XMLAdapter {
    /// Builds the prompt in `buffer` (clearing it first) and returns a view
    /// over the bytes written.
    ///
    /// Layout, in order: optional XML declaration, the signature's
    /// instructions (if non-empty), the response template, each demo followed
    /// by a `---` separator, the actual inputs, and a trailing `Output:` cue.
    fn format<'a>(
        &self,
        buffer: &'a mut Vec<u8>,
        signature: &Signature<'_>,
        inputs: &Inputs<'_>,
        demos: &[DemoData<'_>],
    ) -> StrView<'a> {
        buffer.clear();

        if self.config.include_declaration {
            buffer.extend_from_slice(b"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\n");
        }

        if !signature.instructions.is_empty() {
            buffer.extend_from_slice(signature.instructions.as_bytes());
            buffer.extend_from_slice(b"\n\n");
        }

        buffer.extend_from_slice(b"Respond with XML in this format:\n");
        self.write_schema(buffer, signature);
        buffer.extend_from_slice(b"\n\n");

        // One separator after every demo: between consecutive demos and
        // between the last demo and the real input.
        for demo in demos {
            self.write_demo_xml(buffer, demo, signature);
            buffer.extend_from_slice(b"\n---\n\n");
        }

        buffer.extend_from_slice(b"Input:\n");
        self.write_input_xml(buffer, inputs, signature);
        buffer.extend_from_slice(b"\n\nOutput:\n");

        // SAFETY: the pointer and length describe the initialized contents of
        // `buffer`, which stays borrowed for `'a`. Everything appended above
        // is either an ASCII literal or the bytes of a `str`.
        // NOTE(review): assumes `StrView::from_raw_parts` requires exactly
        // "valid UTF-8, live for `'a`" — confirm against its contract.
        unsafe { StrView::from_raw_parts(buffer.as_ptr(), buffer.len()) }
    }

    /// Finds each output field's element in `response` and returns the byte
    /// range of its content, keyed by the interned field name. Fields with no
    /// matching element are omitted; this implementation never returns `Err`.
    fn parse<'a>(
        &self,
        response: StrView<'a>,
        signature: &Signature<'_>,
    ) -> Result<SmallVec<[(Sym, FieldRange); 4]>> {
        let text = response.as_str();
        let mut found = SmallVec::new();

        for field in &signature.output_fields {
            let name = &field.name;
            if let Some(span) = find_xml_element_with_attrs(text, name) {
                let key = sym(name);
                // Offsets are narrowed with `as`, which truncates for
                // responses larger than `u32::MAX` bytes.
                let range = FieldRange::new(span.start as u32, span.end as u32);
                found.push((key, range));
            }
        }

        Ok(found)
    }

    /// Returns the adapter's display name, `"XML"`.
    fn name(&self) -> &'static str {
        "XML"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::sym;

    /// The const constructor yields a pretty-printing adapter named "XML".
    #[test]
    fn test_xml_adapter_creation() {
        let adapter = XMLAdapter::default();
        assert_eq!(adapter.name(), "XML");
        assert!(adapter.config().pretty);
    }

    /// Each builder method overrides exactly the field it names.
    #[test]
    fn test_xml_config() {
        let config = XMLConfig::new()
            .with_pretty(false)
            .with_root("output")
            .with_declaration(true);
        assert!(!config.pretty);
        assert_eq!(config.root_element, "output");
        assert!(config.include_declaration);
    }

    /// Markup characters must come out as entities, not as raw text.
    #[test]
    fn test_escape_xml() {
        let mut buffer = Vec::new();
        write_escaped_xml(&mut buffer, "<script>alert('xss')</script>");
        assert_eq!(
            String::from_utf8(buffer).unwrap(),
            "&lt;script&gt;alert(&apos;xss&apos;)&lt;/script&gt;"
        );
    }

    /// Entities must decode back to the characters they stand for.
    #[test]
    fn test_unescape_xml() {
        assert_eq!(
            unescape_xml("&lt;script&gt;alert(&apos;xss&apos;)&lt;/script&gt;"),
            "<script>alert('xss')</script>"
        );
    }

    #[test]
    fn test_find_xml_element() {
        let xml = "<response><answer>42</answer></response>";
        let range = find_xml_element(xml, "answer");
        assert!(range.is_some());
        assert_eq!(&xml[range.unwrap()], "42");
    }

    #[test]
    fn test_find_xml_element_with_attrs() {
        let xml = r#"<response><answer type="number">42</answer></response>"#;
        let range = find_xml_element_with_attrs(xml, "answer");
        assert!(range.is_some());
        assert_eq!(&xml[range.unwrap()], "42");
    }

    #[test]
    fn test_format_basic() {
        let adapter = XMLAdapter::default();
        let sig = Signature::parse("question -> answer").unwrap();
        let mut buffer = Vec::new();
        let mut inputs = Inputs::new();
        inputs.insert("question", "What is 2+2?");
        let prompt = adapter.format(&mut buffer, &sig, &inputs, &[]);
        assert!(prompt
            .as_str()
            .contains("<question>What is 2+2?</question>"));
        assert!(prompt.as_str().contains("<answer>...</answer>"));
        assert!(prompt.as_str().contains("Output:"));
    }

    #[test]
    fn test_parse_response() {
        let adapter = XMLAdapter::default();
        let sig = Signature::parse("question -> answer").unwrap();
        let response = StrView::new("<response><answer>4</answer></response>");
        let fields = adapter.parse(response, &sig).unwrap();
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].0, sym("answer"));
        let range = fields[0].1.as_range();
        assert_eq!(&response.as_str()[range], "4");
    }

    #[test]
    fn test_format_with_demo() {
        let adapter = XMLAdapter::default();
        let sig = Signature::parse("question -> answer").unwrap();
        let q_sym = sym("question");
        let a_sym = sym("answer");
        let demo_inputs = [(q_sym, StrView::new("What is 1+1?"))];
        let demo_outputs = [(a_sym, StrView::new("2"))];
        let demo = DemoData::new(&demo_inputs, &demo_outputs);
        let mut buffer = Vec::new();
        let mut inputs = Inputs::new();
        inputs.insert("question", "What is 2+2?");
        let prompt = adapter.format(&mut buffer, &sig, &inputs, &[demo]);
        assert!(prompt
            .as_str()
            .contains("<question>What is 1+1?</question>"));
        assert!(prompt.as_str().contains("<answer>2</answer>"));
        assert!(prompt
            .as_str()
            .contains("<question>What is 2+2?</question>"));
    }
}