Skip to main content

aiproof_parse/
json_schema.rs

1use aiproof_core::document::{Document, Kind, PromptText, Role};
2use std::path::Path;
3
4pub fn parse(path: &Path, source: &str) -> anyhow::Result<Vec<Document>> {
5    let Ok(value) = serde_json::from_str::<serde_json::Value>(source) else {
6        // Invalid JSON — fall back to plain.
7        return crate::plain::parse(path, source);
8    };
9
10    let mut descriptions = Vec::new();
11    collect_descriptions(&value, &mut descriptions);
12
13    if descriptions.is_empty() {
14        return crate::plain::parse(path, source);
15    }
16
17    Ok(descriptions
18        .into_iter()
19        .map(|d| Document {
20            path: path.to_path_buf(),
21            role: Role::Tool,
22            source: source.to_string(),
23            prompt: PromptText {
24                text: d,
25                origin_span: None,
26            },
27            kind: Kind::JsonSchema,
28        })
29        .collect())
30}
31
32fn collect_descriptions(v: &serde_json::Value, out: &mut Vec<String>) {
33    match v {
34        serde_json::Value::Object(map) => {
35            if let Some(serde_json::Value::String(d)) = map.get("description") {
36                out.push(d.clone());
37            }
38            for (_, child) in map {
39                collect_descriptions(child, out);
40            }
41        }
42        serde_json::Value::Array(arr) => {
43            for child in arr {
44                collect_descriptions(child, out);
45            }
46        }
47        _ => {}
48    }
49}
50
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse` on `src` under the given file name and returns the prompt
    /// text of every produced document, in order. Panics if parsing fails.
    fn prompt_texts(file_name: &str, src: &str) -> Vec<String> {
        parse(std::path::Path::new(file_name), src)
            .unwrap()
            .into_iter()
            .map(|doc| doc.prompt.text)
            .collect()
    }

    /// Both the tool-level and the nested property-level description are
    /// extracted as separate documents.
    #[test]
    fn extracts_mcp_tool_descriptions() {
        let src = r#"{
            "name": "search",
            "description": "Search the knowledge base for a term.",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": { "type": "string", "description": "The search term." }
                }
            }
        }"#;
        let texts = prompt_texts("t.json", src);
        assert_eq!(texts.len(), 2);
        assert!(texts
            .iter()
            .any(|t| t == "Search the knowledge base for a term."));
        assert!(texts.iter().any(|t| t == "The search term."));
    }

    /// Valid JSON with no descriptions yields a single document whose text is
    /// the whole source.
    #[test]
    fn non_mcp_json_falls_back_to_plain() {
        let src = r#"{"foo": 1}"#;
        let texts = prompt_texts("d.json", src);
        assert_eq!(texts.len(), 1);
        assert_eq!(texts[0], src);
    }

    /// Input that is not JSON is passed through as plain text.
    #[test]
    fn invalid_json_falls_back_to_plain() {
        let src = "not json at all";
        let texts = prompt_texts("x.json", src);
        assert_eq!(texts[0], "not json at all");
    }
}