Skip to main content

atomr_agents_parser/
basic.rs

1//! Basic parsers.
2
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use atomr_agents_core::{AgentError, Result, Value};
7use serde::de::DeserializeOwned;
8
9use crate::Parser;
10
11// --------------------------------------------------------------------
12// JsonParser — `Value`
13// --------------------------------------------------------------------
14
15#[derive(Default)]
16pub struct JsonParser;
17
18#[async_trait]
19impl Parser<Value> for JsonParser {
20    async fn parse(&self, raw: &str) -> Result<Value> {
21        let raw = strip_code_fence(raw);
22        serde_json::from_str(&raw).map_err(|e| AgentError::Tool(format!("json parse: {e}")))
23    }
24    fn format_instructions(&self) -> String {
25        "Respond with a single valid JSON value.".into()
26    }
27}
28
29// --------------------------------------------------------------------
30// JsonSchemaParser — `Value`, validated against a JSON-Schema-shaped
31// guard. (Lightweight: only checks `type`, `required`, top-level
32// property types — enough for tests; production users plug in a real
33// JSON-Schema validator.)
34// --------------------------------------------------------------------
35
36pub struct JsonSchemaParser {
37    pub schema: Value,
38}
39
40impl JsonSchemaParser {
41    pub fn new(schema: Value) -> Self {
42        Self { schema }
43    }
44}
45
46#[async_trait]
47impl Parser<Value> for JsonSchemaParser {
48    async fn parse(&self, raw: &str) -> Result<Value> {
49        let v: Value = JsonParser.parse(raw).await?;
50        validate(&self.schema, &v)?;
51        Ok(v)
52    }
53    fn format_instructions(&self) -> String {
54        format!(
55            "Respond with JSON matching this schema:\n```\n{}\n```",
56            serde_json::to_string_pretty(&self.schema).unwrap_or_default()
57        )
58    }
59}
60
61fn validate(schema: &Value, v: &Value) -> Result<()> {
62    let want_type = schema.get("type").and_then(|t| t.as_str()).unwrap_or("");
63    if want_type == "object" {
64        if !v.is_object() {
65            return Err(AgentError::Tool("expected object".into()));
66        }
67        if let Some(req) = schema.get("required").and_then(|r| r.as_array()) {
68            for r in req {
69                let key = r.as_str().unwrap_or("");
70                if v.get(key).is_none() {
71                    return Err(AgentError::Tool(format!("missing required field '{key}'")));
72                }
73            }
74        }
75    } else if want_type == "array" && !v.is_array() {
76        return Err(AgentError::Tool("expected array".into()));
77    } else if want_type == "string" && !v.is_string() {
78        return Err(AgentError::Tool("expected string".into()));
79    } else if want_type == "integer" && !v.is_i64() {
80        return Err(AgentError::Tool("expected integer".into()));
81    }
82    Ok(())
83}
84
85// --------------------------------------------------------------------
86// SchemaParser<T> — Pydantic-style: deserialize into a typed Rust
87// struct, with format-instructions surfacing the schema description.
88// --------------------------------------------------------------------
89
90pub struct SchemaParser<T> {
91    pub instructions: String,
92    _marker: PhantomData<fn() -> T>,
93}
94
95impl<T> SchemaParser<T> {
96    pub fn new(instructions: impl Into<String>) -> Self {
97        Self {
98            instructions: instructions.into(),
99            _marker: PhantomData,
100        }
101    }
102}
103
104#[async_trait]
105impl<T: DeserializeOwned + Send + Sync + 'static> Parser<T> for SchemaParser<T> {
106    async fn parse(&self, raw: &str) -> Result<T> {
107        let raw = strip_code_fence(raw);
108        serde_json::from_str(&raw).map_err(|e| AgentError::Tool(format!("schema parse: {e}")))
109    }
110    fn format_instructions(&self) -> String {
111        self.instructions.clone()
112    }
113}
114
115// --------------------------------------------------------------------
116// EnumParser
117// --------------------------------------------------------------------
118
119pub struct EnumParser {
120    pub variants: Vec<String>,
121}
122
123impl EnumParser {
124    pub fn new<I: IntoIterator<Item = impl Into<String>>>(variants: I) -> Self {
125        Self {
126            variants: variants.into_iter().map(Into::into).collect(),
127        }
128    }
129}
130
131#[async_trait]
132impl Parser<String> for EnumParser {
133    async fn parse(&self, raw: &str) -> Result<String> {
134        let raw = raw.trim();
135        for v in &self.variants {
136            if v.eq_ignore_ascii_case(raw) {
137                return Ok(v.clone());
138            }
139        }
140        Err(AgentError::Tool(format!(
141            "{raw:?} not one of {:?}",
142            self.variants
143        )))
144    }
145    fn format_instructions(&self) -> String {
146        format!("Reply with exactly one of: {}", self.variants.join(", "))
147    }
148}
149
150// --------------------------------------------------------------------
151// CommaListParser
152// --------------------------------------------------------------------
153
154pub struct CommaListParser;
155
156#[async_trait]
157impl Parser<Vec<String>> for CommaListParser {
158    async fn parse(&self, raw: &str) -> Result<Vec<String>> {
159        Ok(raw
160            .split(',')
161            .map(|s| s.trim().to_string())
162            .filter(|s| !s.is_empty())
163            .collect())
164    }
165    fn format_instructions(&self) -> String {
166        "Reply with a comma-separated list of values.".into()
167    }
168}
169
170// --------------------------------------------------------------------
171// XmlParser — naive: extracts top-level <tag>contents</tag>
172// pairs into a flat object.
173// --------------------------------------------------------------------
174
175pub struct XmlParser;
176
177#[async_trait]
178impl Parser<Value> for XmlParser {
179    async fn parse(&self, raw: &str) -> Result<Value> {
180        let mut out = serde_json::Map::new();
181        let mut idx = 0;
182        let bytes = raw.as_bytes();
183        while idx < bytes.len() {
184            // find '<'
185            while idx < bytes.len() && bytes[idx] != b'<' {
186                idx += 1;
187            }
188            if idx >= bytes.len() {
189                break;
190            }
191            let tag_start = idx + 1;
192            // find '>'
193            let mut tag_end = tag_start;
194            while tag_end < bytes.len() && bytes[tag_end] != b'>' {
195                tag_end += 1;
196            }
197            if tag_end >= bytes.len() {
198                break;
199            }
200            let tag = &raw[tag_start..tag_end];
201            if tag.starts_with('/') {
202                idx = tag_end + 1;
203                continue;
204            }
205            let close = format!("</{tag}>");
206            if let Some(close_pos) = raw[tag_end..].find(&close) {
207                let body_start = tag_end + 1;
208                let body_end = tag_end + close_pos;
209                let body = &raw[body_start..body_end];
210                out.insert(tag.to_string(), Value::String(body.trim().to_string()));
211                idx = body_end + close.len();
212            } else {
213                idx = tag_end + 1;
214            }
215        }
216        if out.is_empty() {
217            return Err(AgentError::Tool("xml parse: no tags found".into()));
218        }
219        Ok(Value::Object(out))
220    }
221    fn format_instructions(&self) -> String {
222        "Wrap each field in matching XML tags, e.g. <name>Alice</name>.".into()
223    }
224}
225
226// --------------------------------------------------------------------
227// YamlParser — accepts a small `key: value` dialect (one pair per
228// line, no nesting). Sufficient for unit tests; users can plug in a
229// full YAML crate behind a feature flag later.
230// --------------------------------------------------------------------
231
232pub struct YamlParser;
233
234#[async_trait]
235impl Parser<Value> for YamlParser {
236    async fn parse(&self, raw: &str) -> Result<Value> {
237        let mut out = serde_json::Map::new();
238        for line in raw.lines() {
239            let l = line.trim();
240            if l.is_empty() || l.starts_with('#') {
241                continue;
242            }
243            if let Some((k, v)) = l.split_once(':') {
244                let k = k.trim();
245                let v = v.trim();
246                if k.is_empty() {
247                    continue;
248                }
249                out.insert(k.to_string(), Value::String(v.to_string()));
250            }
251        }
252        if out.is_empty() {
253            return Err(AgentError::Tool("yaml parse: no key/value pairs".into()));
254        }
255        Ok(Value::Object(out))
256    }
257    fn format_instructions(&self) -> String {
258        "Reply with one key: value pair per line.".into()
259    }
260}
261
/// Removes a Markdown code fence wrapped around `s`, returning the payload.
///
/// The input is trimmed first. If it does not start with three backticks it
/// is returned as-is (trimmed). Otherwise:
///
/// * **Multi-line input** — the opening fence line (including any info
///   string such as `json`) is dropped. A closing fence is dropped whether it
///   sits on its own last line or is glued to the end of the last content
///   line. Remaining lines are joined with `\n`, so CRLF line endings are
///   normalized.
/// * **Single-line input** — text of the form "fence, payload, fence" yields
///   the trimmed payload instead of losing it. A lone opening fence line
///   with no closing fence yields an empty string.
fn strip_code_fence(s: &str) -> String {
    const FENCE: &str = "```";

    let s = s.trim();
    let after_open = match s.strip_prefix(FENCE) {
        Some(rest) => rest,
        None => return s.to_string(),
    };

    if !after_open.contains('\n') {
        // Inline fence: keep what sits between the two fences. Without a
        // closing fence the line is just an opening fence + info string.
        return after_open
            .strip_suffix(FENCE)
            .map(|inner| inner.trim().to_string())
            .unwrap_or_default();
    }

    // Drop the opening fence line, then deal with the closing fence.
    let mut lines: Vec<&str> = s.lines().skip(1).collect();
    if let Some(last) = lines.pop() {
        match last.strip_suffix(FENCE) {
            // Closing fence on its own (possibly indented) line: drop it.
            Some(before) if before.trim().is_empty() => {}
            // Closing fence glued to content: keep only the content.
            Some(before) => lines.push(before.trim_end()),
            // No closing fence: keep the line untouched.
            None => lines.push(last),
        }
    }
    lines.join("\n")
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Typed target used to exercise `SchemaParser`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Plan {
        title: String,
        steps: Vec<String>,
    }

    /// A fenced JSON payload parses to the same value as the bare payload.
    #[tokio::test]
    async fn json_strips_fence() {
        let fenced = "```json\n{\"a\":1}\n```";
        let parsed = JsonParser.parse(fenced).await.unwrap();
        assert_eq!(parsed, serde_json::json!({"a": 1}));
    }

    /// JSON text deserializes into the typed struct.
    #[tokio::test]
    async fn schema_parser_round_trips_typed_struct() {
        let parser = SchemaParser::<Plan>::new("...");
        let plan = parser
            .parse(r#"{"title":"x","steps":["a","b"]}"#)
            .await
            .unwrap();
        assert_eq!(plan.title, "x");
        assert_eq!(plan.steps.len(), 2);
    }

    /// An object lacking a `required` key is rejected.
    #[tokio::test]
    async fn schema_validation_catches_missing_field() {
        let schema = serde_json::json!({
            "type": "object",
            "required": ["a", "b"]
        });
        let parser = JsonSchemaParser::new(schema);
        assert!(parser.parse(r#"{"a":1}"#).await.is_err());
    }

    /// Matching ignores ASCII case and yields the configured spelling.
    #[tokio::test]
    async fn enum_parser_normalizes_case() {
        let parser = EnumParser::new(["yes", "no"]);
        let accepted = parser.parse("YES").await.unwrap();
        assert_eq!(accepted, "yes");
        let rejected = parser.parse("maybe").await;
        assert!(rejected.is_err());
    }

    /// Items are trimmed and empty items are dropped.
    #[tokio::test]
    async fn comma_list_parses_with_trim() {
        let items = CommaListParser.parse("a, b,c , ").await.unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    /// Sibling top-level elements each become a key of the result.
    #[tokio::test]
    async fn xml_parser_extracts_top_level_tags() {
        let doc = "<name>Alice</name><city>NYC</city>";
        let fields = XmlParser.parse(doc).await.unwrap();
        assert_eq!(fields["name"], "Alice");
        assert_eq!(fields["city"], "NYC");
    }

    /// One `key: value` pair per line is collected into an object.
    #[tokio::test]
    async fn yaml_parser_simple_dialect() {
        let fields = YamlParser.parse("name: Alice\nrole: admin\n").await.unwrap();
        assert_eq!(fields["name"], "Alice");
    }
}