use crate::agent::Agent;
use crate::error::{ReactError, Result};
use super::ReactAgent;
/// A [`ReactAgent`] wrapper whose textual replies are deserialized from JSON into `T`.
///
/// `T` only appears in the return types of this wrapper's methods; no `T` is ever
/// stored. The marker is therefore `PhantomData<fn() -> T>` ("produces a `T`")
/// rather than `PhantomData<T>` ("owns a `T`"). This keeps the wrapper's
/// `Send`/`Sync` status and drop-check independent of `T`, while remaining
/// covariant in `T`.
pub struct StructuredAgent<T> {
    inner: ReactAgent,
    _phantom: std::marker::PhantomData<fn() -> T>,
}
impl<T> StructuredAgent<T>
where
    T: serde::de::DeserializeOwned + Send + 'static,
{
    /// Wraps `agent` so that its replies are deserialized into `T`.
    pub fn new(agent: ReactAgent) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            inner: agent,
        }
    }

    /// Borrows the wrapped [`ReactAgent`].
    pub fn inner(&self) -> &ReactAgent {
        &self.inner
    }

    /// Mutably borrows the wrapped [`ReactAgent`].
    pub fn inner_mut(&mut self) -> &mut ReactAgent {
        &mut self.inner
    }

    /// Runs `task` through the wrapped agent and deserializes the returned text as `T`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped agent. Returns an error when the
    /// reply cannot be parsed as `T`.
    pub async fn execute(&mut self, task: &str) -> Result<T> {
        let reply = self.inner.execute(task).await?;
        parse_json_output::<T>(&reply)
    }

    /// Sends `message` to the wrapped agent's `chat` and deserializes the returned text as `T`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped agent. Returns an error when the
    /// reply cannot be parsed as `T`.
    pub async fn chat(&mut self, message: &str) -> Result<T> {
        let reply = self.inner.chat(message).await?;
        parse_json_output::<T>(&reply)
    }
}
/// Deserializes an LLM reply into `T`, tolerating common formatting noise.
///
/// Candidates are tried in order, and the first one that deserializes wins:
/// 1. the whole reply, trimmed;
/// 2. the body of the first fenced markdown code block (see `extract_json_from_markdown`);
/// 3. the span from the first `{`/`[` through the last `}`/`]`, for replies that
///    surround bare JSON with prose.
///
/// # Errors
///
/// Returns [`ReactError::Other`] when no candidate deserializes. The message
/// includes the last deserializer error and the raw reply.
fn parse_json_output<T: serde::de::DeserializeOwned>(text: &str) -> Result<T> {
    let trimmed = text.trim();
    let mut last_err = match serde_json::from_str::<T>(trimmed) {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };

    let fallbacks = [
        extract_json_from_markdown(trimmed),
        extract_bare_json_span(trimmed),
    ];
    for candidate in fallbacks.into_iter().flatten() {
        match serde_json::from_str::<T>(candidate) {
            Ok(value) => return Ok(value),
            // Keep the most recent error: later candidates are narrower, so their
            // error usually points closer to the real problem.
            Err(err) => last_err = err,
        }
    }

    Err(ReactError::Other(format!(
        "Failed to parse LLM output as target type: {last_err}. Raw output:\n{text}"
    )))
}

/// Returns the substring from the first `{` or `[` through the last `}` or `]`.
/// Returns `None` when either delimiter is missing or they are out of order.
fn extract_bare_json_span(text: &str) -> Option<&str> {
    let start = text.find(['{', '['])?;
    let end = text.rfind(['}', ']'])?;
    // Lazy so the slice is only taken when the range is valid. Both delimiters
    // are ASCII, so `start` and `end + 1` are char boundaries.
    (start < end).then(|| &text[start..=end])
}
/// Returns the trimmed body of the first fenced markdown code block in `text`.
///
/// A fence opened with `` ```json `` is preferred; otherwise the first bare
/// `` ``` `` fence is used. In a multi-line block, the rest of the opening line
/// (a language tag such as `json`, `JSON` or `jsonc`) is skipped. In a
/// single-line block, the text between the two fences is returned. Returns
/// `None` when there is no opening fence or it is never closed.
fn extract_json_from_markdown(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let after_open = match text.find(JSON_FENCE) {
        Some(pos) => &text[pos + JSON_FENCE.len()..],
        None => &text[text.find(FENCE)? + FENCE.len()..],
    };

    // Find the closing fence before deciding where the body starts. Otherwise a
    // newline that only appears after the closing fence (a single-line block
    // followed by more prose) would be mistaken for the end of the opening line.
    let block = &after_open[..after_open.find(FENCE)?];
    let body = match block.find('\n') {
        Some(newline) => &block[newline + 1..],
        None => block,
    };
    Some(body.trim())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Minimal target type used to exercise deserialization.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Person {
        name: String,
        age: u32,
    }

    #[test]
    fn test_parse_json_direct() {
        let result: Person = parse_json_output(r#"{"name": "Alice", "age": 30}"#).unwrap();
        assert_eq!(
            result,
            Person {
                name: "Alice".to_string(),
                age: 30
            }
        );
    }

    #[test]
    fn test_parse_json_with_whitespace() {
        let result: Person = parse_json_output(" \n{\"name\": \"Bob\", \"age\": 25}\n ").unwrap();
        assert_eq!(result.name, "Bob");
    }

    #[test]
    fn test_parse_json_from_markdown() {
        let text = "Here is the result:\n```json\n{\"name\": \"Charlie\", \"age\": 35}\n```\n";
        let result: Person = parse_json_output(text).unwrap();
        assert_eq!(result.name, "Charlie");
        assert_eq!(result.age, 35);
    }

    #[test]
    fn test_parse_json_from_unlabeled_fence() {
        let text = "```\n{\"name\": \"Dana\", \"age\": 41}\n```";
        let result: Person = parse_json_output(text).unwrap();
        assert_eq!(
            result,
            Person {
                name: "Dana".to_string(),
                age: 41
            }
        );
    }

    #[test]
    fn test_parse_json_failure() {
        let result = parse_json_output::<Person>("not json at all");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_json_failure_reports_raw_output() {
        match parse_json_output::<Person>("not json at all") {
            Err(ReactError::Other(msg)) => assert!(msg.contains("not json at all")),
            other => panic!("expected ReactError::Other, got {other:?}"),
        }
    }

    #[test]
    fn test_extract_json_from_markdown() {
        let text = "```json\n{\"a\": 1}\n```";
        assert_eq!(extract_json_from_markdown(text), Some("{\"a\": 1}"));
    }

    #[test]
    fn test_extract_json_prefers_json_fence() {
        let text = "```text\nhello\n```\n```json\n{\"a\": 1}\n```";
        assert_eq!(extract_json_from_markdown(text), Some("{\"a\": 1}"));
    }

    #[test]
    fn test_extract_json_unterminated_fence() {
        assert_eq!(extract_json_from_markdown("```json\n{\"a\": 1}"), None);
    }

    #[test]
    fn test_extract_json_no_markdown() {
        assert_eq!(extract_json_from_markdown("{\"a\": 1}"), None);
    }
}