Skip to main content

sgr_agent/agents/
flexible.rs

1//! FlexibleAgent — text-based agent for weak models without structured output.
2//!
3//! Puts tool descriptions in the system prompt, sends plain completion,
4//! then uses flexible_parser + coerce to extract tool calls from text.
5//!
6//! Supports retry with error feedback: if parsing fails, accumulates errors
7//! and feeds them back to the model in the next attempt (up to max_retries).
8
9use crate::agent::{Agent, AgentError, Decision};
10use crate::client::LlmClient;
11use crate::registry::ToolRegistry;
12use crate::schema_simplifier;
13use crate::types::Message;
14use crate::union_schema;
15
16/// Agent for models without native structured output or function calling.
17pub struct FlexibleAgent<C: LlmClient> {
18    client: C,
19    system_prompt: String,
20    /// Maximum parse retry attempts (1 = no retry, 5 = up to 5 attempts).
21    max_retries: usize,
22}
23
24impl<C: LlmClient> FlexibleAgent<C> {
25    pub fn new(client: C, system_prompt: impl Into<String>, max_retries: usize) -> Self {
26        Self {
27            client,
28            system_prompt: system_prompt.into(),
29            max_retries: max_retries.max(1),
30        }
31    }
32}
33
34/// Build tool descriptions for system prompt using SchemaSimplifier.
35fn tools_prompt(tools: &ToolRegistry) -> String {
36    let mut s = String::from(
37        "## Available Tools\n\nRespond with JSON: {\"situation\": \"...\", \"task\": [...], \"actions\": [{\"tool_name\": \"...\", ...args}]}\n\n",
38    );
39    for t in tools.list() {
40        s.push_str(&schema_simplifier::simplify_tool(
41            t.name(),
42            t.description(),
43            &t.parameters_schema(),
44        ));
45        s.push_str("\n\n");
46    }
47    s
48}
49
/// Build the corrective user message sent after a reply fails to parse.
///
/// Every error collected so far is listed as a 1-based numbered line. The list is placed
/// between a fixed header and a reminder to answer with bare JSON.
fn format_error_prompt(errors: &[String]) -> String {
    const HEADER: &str = "Your previous response(s) could not be parsed as valid JSON. Please fix and try again.\n\nErrors:\n";
    const FOOTER: &str =
        "\nRespond with ONLY valid JSON matching the schema. No markdown, no explanations.";

    let mut prompt = String::from(HEADER);
    prompt.extend(
        errors
            .iter()
            .zip(1usize..)
            .map(|(message, number)| format!("{}. {}\n", number, message)),
    );
    prompt.push_str(FOOTER);
    prompt
}
63
64#[async_trait::async_trait]
65impl<C: LlmClient> Agent for FlexibleAgent<C> {
66    async fn decide(
67        &self,
68        messages: &[Message],
69        tools: &ToolRegistry,
70    ) -> Result<Decision, AgentError> {
71        let defs = tools.to_defs();
72
73        // Build system prompt with tool descriptions
74        let full_system = format!("{}\n\n{}", self.system_prompt, tools_prompt(tools));
75        let mut msgs = Vec::with_capacity(messages.len() + 1);
76        let has_system = messages
77            .iter()
78            .any(|m| m.role == crate::types::Role::System);
79        if !has_system {
80            msgs.push(Message::system(&full_system));
81        }
82        msgs.extend_from_slice(messages);
83
84        let mut errors: Vec<String> = Vec::new();
85
86        for attempt in 0..self.max_retries {
87            // On retry, add error feedback
88            if attempt > 0 && !errors.is_empty() {
89                msgs.push(Message::user(format_error_prompt(&errors)));
90            }
91
92            let raw = self.client.complete(&msgs).await?;
93
94            match union_schema::parse_action(&raw, &defs) {
95                Ok((situation, tool_calls)) => {
96                    let completed = tool_calls.is_empty()
97                        || tool_calls.iter().any(|tc| tc.name == "finish_task");
98                    return Ok(Decision {
99                        situation,
100                        task: vec![],
101                        tool_calls,
102                        completed,
103                    });
104                }
105                Err(e) => {
106                    errors.push(e.to_string());
107                    // Add the raw response as assistant message for context
108                    msgs.push(Message::assistant(&raw));
109                }
110            }
111        }
112
113        // All retries exhausted — treat last raw response as completed
114        Ok(Decision {
115            situation: format!(
116                "Failed to parse after {} attempts. Errors: {}",
117                self.max_retries,
118                errors.join("; ")
119            ),
120            task: vec![],
121            tool_calls: vec![],
122            completed: true,
123        })
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::client::LlmClient;
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Client whose `complete` always returns the same canned text.
    struct MockTextClient {
        response: String,
    }

    #[async_trait::async_trait]
    impl LlmClient for MockTextClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(self.response.clone())
        }
    }

    /// Minimal tool named `search` with a single string parameter.
    struct DummyTool;

    #[async_trait::async_trait]
    impl crate::agent_tool::Tool for DummyTool {
        fn name(&self) -> &str {
            "search"
        }
        fn description(&self) -> &str {
            "search files"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"query": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[tokio::test]
    async fn flexible_agent_parses_json_from_text() {
        let client = MockTextClient {
            response: r#"Sure, let me search for that.
```json
{"situation": "searching", "task": ["find files"], "actions": [{"tool_name": "search", "query": "main.rs"}]}
```"#
            .into(),
        };
        let agent = FlexibleAgent::new(client, "You are a test agent", 1);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("find main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "search");
    }

    #[tokio::test]
    async fn flexible_agent_plain_text_completes() {
        let client = MockTextClient {
            response: "I can't find any tools to use here.".into(),
        };
        let agent = FlexibleAgent::new(client, "test", 1);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hello")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn flexible_agent_retry_succeeds() {
        /// Client that fails first, succeeds second, and records the size of every request.
        struct RetryClient {
            call_count: Arc<AtomicUsize>,
            seen_lens: Arc<Mutex<Vec<usize>>>,
        }
        #[async_trait::async_trait]
        impl LlmClient for RetryClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                Ok(vec![])
            }
            async fn complete(&self, msgs: &[Message]) -> Result<String, SgrError> {
                self.seen_lens.lock().unwrap().push(msgs.len());
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok("not valid json at all".into())
                } else {
                    Ok(
                        r#"{"situation": "found it", "task": [], "actions": [{"tool_name": "search", "query": "test"}]}"#
                            .into(),
                    )
                }
            }
        }

        // Keep handles so the test can verify the retry actually happened.
        let call_count = Arc::new(AtomicUsize::new(0));
        let seen_lens = Arc::new(Mutex::new(Vec::new()));
        let client = RetryClient {
            call_count: Arc::clone(&call_count),
            seen_lens: Arc::clone(&seen_lens),
        };
        let agent = FlexibleAgent::new(client, "test", 3);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("search")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.situation, "found it");

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        // The retry must carry the failed reply plus the error-feedback message.
        let lens = seen_lens.lock().unwrap();
        assert_eq!(lens.len(), 2);
        assert_eq!(lens[1], lens[0] + 2);
    }

    #[tokio::test]
    async fn flexible_agent_retry_exhausted() {
        let client = MockTextClient {
            response: "garbage output always".into(),
        };
        let agent = FlexibleAgent::new(client, "test", 3);
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("do something")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
        assert!(
            decision
                .situation
                .contains("Failed to parse after 3 attempts")
        );
    }

    #[test]
    fn format_error_prompt_content() {
        let errors = vec!["bad json".to_string(), "missing field".to_string()];
        let prompt = format_error_prompt(&errors);
        assert!(prompt.contains("1. bad json"));
        assert!(prompt.contains("2. missing field"));
        assert!(prompt.contains("valid JSON"));
    }

    #[test]
    fn tools_prompt_uses_simplifier() {
        let tools = ToolRegistry::new().register(DummyTool);
        let prompt = tools_prompt(&tools);
        assert!(prompt.starts_with("## Available Tools"));
        assert!(prompt.contains("### search"));
        assert!(prompt.contains("search files"));
    }
}