atomr_agents_parser/
basic.rs1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use atomr_agents_core::{AgentError, Result, Value};
7use serde::de::DeserializeOwned;
8
9use crate::Parser;
10
11#[derive(Default)]
16pub struct JsonParser;
17
18#[async_trait]
19impl Parser<Value> for JsonParser {
20 async fn parse(&self, raw: &str) -> Result<Value> {
21 let raw = strip_code_fence(raw);
22 serde_json::from_str(&raw).map_err(|e| AgentError::Tool(format!("json parse: {e}")))
23 }
24 fn format_instructions(&self) -> String {
25 "Respond with a single valid JSON value.".into()
26 }
27}
28
29pub struct JsonSchemaParser {
37 pub schema: Value,
38}
39
40impl JsonSchemaParser {
41 pub fn new(schema: Value) -> Self {
42 Self { schema }
43 }
44}
45
46#[async_trait]
47impl Parser<Value> for JsonSchemaParser {
48 async fn parse(&self, raw: &str) -> Result<Value> {
49 let v: Value = JsonParser.parse(raw).await?;
50 validate(&self.schema, &v)?;
51 Ok(v)
52 }
53 fn format_instructions(&self) -> String {
54 format!(
55 "Respond with JSON matching this schema:\n```\n{}\n```",
56 serde_json::to_string_pretty(&self.schema).unwrap_or_default()
57 )
58 }
59}
60
61fn validate(schema: &Value, v: &Value) -> Result<()> {
62 let want_type = schema.get("type").and_then(|t| t.as_str()).unwrap_or("");
63 if want_type == "object" {
64 if !v.is_object() {
65 return Err(AgentError::Tool("expected object".into()));
66 }
67 if let Some(req) = schema.get("required").and_then(|r| r.as_array()) {
68 for r in req {
69 let key = r.as_str().unwrap_or("");
70 if v.get(key).is_none() {
71 return Err(AgentError::Tool(format!("missing required field '{key}'")));
72 }
73 }
74 }
75 } else if want_type == "array" && !v.is_array() {
76 return Err(AgentError::Tool("expected array".into()));
77 } else if want_type == "string" && !v.is_string() {
78 return Err(AgentError::Tool("expected string".into()));
79 } else if want_type == "integer" && !v.is_i64() {
80 return Err(AgentError::Tool("expected integer".into()));
81 }
82 Ok(())
83}
84
85pub struct SchemaParser<T> {
91 pub instructions: String,
92 _marker: PhantomData<fn() -> T>,
93}
94
95impl<T> SchemaParser<T> {
96 pub fn new(instructions: impl Into<String>) -> Self {
97 Self {
98 instructions: instructions.into(),
99 _marker: PhantomData,
100 }
101 }
102}
103
104#[async_trait]
105impl<T: DeserializeOwned + Send + Sync + 'static> Parser<T> for SchemaParser<T> {
106 async fn parse(&self, raw: &str) -> Result<T> {
107 let raw = strip_code_fence(raw);
108 serde_json::from_str(&raw).map_err(|e| AgentError::Tool(format!("schema parse: {e}")))
109 }
110 fn format_instructions(&self) -> String {
111 self.instructions.clone()
112 }
113}
114
115pub struct EnumParser {
120 pub variants: Vec<String>,
121}
122
123impl EnumParser {
124 pub fn new<I: IntoIterator<Item = impl Into<String>>>(variants: I) -> Self {
125 Self {
126 variants: variants.into_iter().map(Into::into).collect(),
127 }
128 }
129}
130
131#[async_trait]
132impl Parser<String> for EnumParser {
133 async fn parse(&self, raw: &str) -> Result<String> {
134 let raw = raw.trim();
135 for v in &self.variants {
136 if v.eq_ignore_ascii_case(raw) {
137 return Ok(v.clone());
138 }
139 }
140 Err(AgentError::Tool(format!(
141 "{raw:?} not one of {:?}",
142 self.variants
143 )))
144 }
145 fn format_instructions(&self) -> String {
146 format!("Reply with exactly one of: {}", self.variants.join(", "))
147 }
148}
149
150pub struct CommaListParser;
155
156#[async_trait]
157impl Parser<Vec<String>> for CommaListParser {
158 async fn parse(&self, raw: &str) -> Result<Vec<String>> {
159 Ok(raw
160 .split(',')
161 .map(|s| s.trim().to_string())
162 .filter(|s| !s.is_empty())
163 .collect())
164 }
165 fn format_instructions(&self) -> String {
166 "Reply with a comma-separated list of values.".into()
167 }
168}
169
170pub struct XmlParser;
176
177#[async_trait]
178impl Parser<Value> for XmlParser {
179 async fn parse(&self, raw: &str) -> Result<Value> {
180 let mut out = serde_json::Map::new();
181 let mut idx = 0;
182 let bytes = raw.as_bytes();
183 while idx < bytes.len() {
184 while idx < bytes.len() && bytes[idx] != b'<' {
186 idx += 1;
187 }
188 if idx >= bytes.len() {
189 break;
190 }
191 let tag_start = idx + 1;
192 let mut tag_end = tag_start;
194 while tag_end < bytes.len() && bytes[tag_end] != b'>' {
195 tag_end += 1;
196 }
197 if tag_end >= bytes.len() {
198 break;
199 }
200 let tag = &raw[tag_start..tag_end];
201 if tag.starts_with('/') {
202 idx = tag_end + 1;
203 continue;
204 }
205 let close = format!("</{tag}>");
206 if let Some(close_pos) = raw[tag_end..].find(&close) {
207 let body_start = tag_end + 1;
208 let body_end = tag_end + close_pos;
209 let body = &raw[body_start..body_end];
210 out.insert(tag.to_string(), Value::String(body.trim().to_string()));
211 idx = body_end + close.len();
212 } else {
213 idx = tag_end + 1;
214 }
215 }
216 if out.is_empty() {
217 return Err(AgentError::Tool("xml parse: no tags found".into()));
218 }
219 Ok(Value::Object(out))
220 }
221 fn format_instructions(&self) -> String {
222 "Wrap each field in matching XML tags, e.g. <name>Alice</name>.".into()
223 }
224}
225
226pub struct YamlParser;
233
234#[async_trait]
235impl Parser<Value> for YamlParser {
236 async fn parse(&self, raw: &str) -> Result<Value> {
237 let mut out = serde_json::Map::new();
238 for line in raw.lines() {
239 let l = line.trim();
240 if l.is_empty() || l.starts_with('#') {
241 continue;
242 }
243 if let Some((k, v)) = l.split_once(':') {
244 let k = k.trim();
245 let v = v.trim();
246 if k.is_empty() {
247 continue;
248 }
249 out.insert(k.to_string(), Value::String(v.to_string()));
250 }
251 }
252 if out.is_empty() {
253 return Err(AgentError::Tool("yaml parse: no key/value pairs".into()));
254 }
255 Ok(Value::Object(out))
256 }
257 fn format_instructions(&self) -> String {
258 "Reply with one key: value pair per line.".into()
259 }
260}
261
/// Removes a Markdown code fence wrapped around model output.
///
/// The input is trimmed first. If it does not then start with ```` ``` ````, the trimmed
/// text is returned unchanged.
///
/// - Multi-line fence: the opening line, including any info string such as `json`, is
///   dropped. Content is kept up to, but not including, the first line consisting only of
///   ```` ``` ````. Anything after that closing fence (e.g. a trailing remark from the
///   model) is discarded. Without a closing fence, everything after the opening line is kept.
/// - Single-line fence such as ```` ```json {"a":1}``` ````: the text between the backticks
///   is returned. A leading alphanumeric language tag is dropped when it is followed by
///   whitespace and more content. A single line with no closing backticks is treated as a
///   bare opening fence and yields an empty string.
///
/// Lines are rejoined with `\n`, so `\r\n` line endings are normalised.
fn strip_code_fence(s: &str) -> String {
    let s = s.trim();
    let after_open = match s.strip_prefix("```") {
        Some(rest) => rest,
        None => return s.to_string(),
    };
    match after_open.split_once('\n') {
        Some((_info, body)) => body
            .lines()
            .take_while(|line| line.trim() != "```")
            .collect::<Vec<_>>()
            .join("\n"),
        None => match after_open.strip_suffix("```") {
            Some(inner) => {
                let inner = inner.trim();
                match inner.split_once(char::is_whitespace) {
                    Some((tag, rest)) if tag.chars().all(|c| c.is_ascii_alphanumeric()) => {
                        rest.trim().to_string()
                    }
                    _ => inner.to_string(),
                }
            }
            None => String::new(),
        },
    }
}
276
#[cfg(test)]
mod tests {
    //! Unit tests for the parsers in this module, including the error paths.
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Plan {
        title: String,
        steps: Vec<String>,
    }

    #[tokio::test]
    async fn json_strips_fence() {
        let p = JsonParser;
        let v = p.parse("```json\n{\"a\":1}\n```").await.unwrap();
        assert_eq!(v, serde_json::json!({"a": 1}));
    }

    #[tokio::test]
    async fn schema_parser_round_trips_typed_struct() {
        let p: SchemaParser<Plan> = SchemaParser::new("...");
        let v = p.parse(r#"{"title":"x","steps":["a","b"]}"#).await.unwrap();
        assert_eq!(v.title, "x");
        assert_eq!(v.steps.len(), 2);
    }

    #[tokio::test]
    async fn schema_validation_catches_missing_field() {
        let p = JsonSchemaParser::new(serde_json::json!({
            "type": "object",
            "required": ["a", "b"]
        }));
        let r = p.parse(r#"{"a":1}"#).await;
        assert!(r.is_err());
    }

    #[tokio::test]
    async fn schema_validation_rejects_wrong_type() {
        let p = JsonSchemaParser::new(serde_json::json!({"type": "array"}));
        assert!(p.parse(r#"{"a":1}"#).await.is_err());
    }

    #[tokio::test]
    async fn schema_validation_accepts_conforming_value() {
        let p = JsonSchemaParser::new(serde_json::json!({
            "type": "object",
            "required": ["a"]
        }));
        let v = p.parse(r#"{"a":1,"b":2}"#).await.unwrap();
        assert_eq!(v, serde_json::json!({"a": 1, "b": 2}));
    }

    #[tokio::test]
    async fn enum_parser_normalizes_case() {
        let p = EnumParser::new(["yes", "no"]);
        assert_eq!(p.parse("YES").await.unwrap(), "yes");
        assert!(p.parse("maybe").await.is_err());
    }

    #[tokio::test]
    async fn enum_instructions_list_variants() {
        let p = EnumParser::new(["yes", "no"]);
        assert_eq!(p.format_instructions(), "Reply with exactly one of: yes, no");
    }

    #[tokio::test]
    async fn comma_list_parses_with_trim() {
        let p = CommaListParser;
        assert_eq!(p.parse("a, b,c , ").await.unwrap(), vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn comma_list_of_empty_input_is_empty() {
        assert!(CommaListParser.parse("").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn xml_parser_extracts_top_level_tags() {
        let p = XmlParser;
        let v = p.parse("<name>Alice</name><city>NYC</city>").await.unwrap();
        assert_eq!(v["name"], "Alice");
        assert_eq!(v["city"], "NYC");
    }

    #[tokio::test]
    async fn xml_parser_errors_without_tags() {
        assert!(XmlParser.parse("no markup here").await.is_err());
    }

    #[tokio::test]
    async fn yaml_parser_simple_dialect() {
        let p = YamlParser;
        let v = p.parse("name: Alice\nrole: admin\n").await.unwrap();
        assert_eq!(v["name"], "Alice");
    }

    #[tokio::test]
    async fn yaml_parser_skips_comments_and_blank_lines() {
        let v = YamlParser.parse("# header\n\nname: Alice\n").await.unwrap();
        assert_eq!(v, serde_json::json!({"name": "Alice"}));
    }
}