Skip to main content

sgr_agent/agents/
hybrid.rs

1//! HybridAgent — 2-phase agent (reasoning + action).
2//!
//! Phase 1: Send a function-calling (FC) request that offers only the
//! `reasoning` tool, to get the agent's reasoning about what to do next.
//! Phase 2: Send an FC request offering the full toolkit, with that reasoning
//! appended as context, to get back concrete tool calls.
7//!
8//! Inspired by Python SGRToolCallingAgent — separates "thinking" from "acting"
9//! so the model doesn't get overwhelmed by a large tool set during reasoning.
10
11use crate::agent::{Agent, AgentError, Decision};
12use crate::client::LlmClient;
13use crate::registry::ToolRegistry;
14use crate::types::Message;
15
16/// 2-phase hybrid agent.
17pub struct HybridAgent<C: LlmClient> {
18    client: C,
19    system_prompt: String,
20}
21
22impl<C: LlmClient> HybridAgent<C> {
23    pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
24        Self {
25            client,
26            system_prompt: system_prompt.into(),
27        }
28    }
29}
30
31/// Internal reasoning tool definition for phase 1.
32fn reasoning_tool_def() -> crate::tool::ToolDef {
33    crate::tool::ToolDef {
34        name: "reasoning".to_string(),
35        description: "Analyze the situation and decide what tools to use next. Describe your reasoning, the current situation, and which tools you plan to call.".to_string(),
36        parameters: serde_json::json!({
37            "type": "object",
38            "properties": {
39                "situation": {
40                    "type": "string",
41                    "description": "Your assessment of the current situation"
42                },
43                "plan": {
44                    "type": "array",
45                    "items": { "type": "string" },
46                    "description": "Step-by-step plan of what to do next"
47                },
48                "done": {
49                    "type": "boolean",
50                    "description": "Set to true if the task is fully complete"
51                }
52            },
53            "required": ["situation", "plan", "done"]
54        }),
55    }
56}
57
58#[async_trait::async_trait]
59impl<C: LlmClient> Agent for HybridAgent<C> {
60    async fn decide(
61        &self,
62        messages: &[Message],
63        tools: &ToolRegistry,
64    ) -> Result<Decision, AgentError> {
65        // Prepare messages with system prompt
66        let mut msgs = Vec::with_capacity(messages.len() + 1);
67        let has_system = messages
68            .iter()
69            .any(|m| m.role == crate::types::Role::System);
70        if !has_system && !self.system_prompt.is_empty() {
71            msgs.push(Message::system(&self.system_prompt));
72        }
73        msgs.extend_from_slice(messages);
74
75        // Phase 1: Reasoning — FC call with only the reasoning tool
76        let reasoning_defs = vec![reasoning_tool_def()];
77        let reasoning_calls = self.client.tools_call(&msgs, &reasoning_defs).await?;
78
79        // Extract reasoning from phase 1
80        let (situation, plan, done) = if let Some(rc) = reasoning_calls.first() {
81            let sit = rc
82                .arguments
83                .get("situation")
84                .and_then(|s| s.as_str())
85                .unwrap_or("")
86                .to_string();
87            let plan: Vec<String> = rc
88                .arguments
89                .get("plan")
90                .and_then(|p| p.as_array())
91                .map(|arr| {
92                    arr.iter()
93                        .filter_map(|v| v.as_str().map(String::from))
94                        .collect()
95                })
96                .unwrap_or_default();
97            let done = rc
98                .arguments
99                .get("done")
100                .and_then(|d| d.as_bool())
101                .unwrap_or(false);
102            (sit, plan, done)
103        } else {
104            // No reasoning call — treat as completed
105            return Ok(Decision {
106                situation: String::new(),
107                task: vec![],
108                tool_calls: vec![],
109                completed: true,
110            });
111        };
112
113        // If reasoning says done, complete without phase 2
114        if done {
115            return Ok(Decision {
116                situation,
117                task: plan,
118                tool_calls: vec![],
119                completed: true,
120            });
121        }
122
123        // Phase 2: Action — FC call with full toolkit + reasoning context
124        let mut action_msgs = msgs.clone();
125        // Add reasoning as assistant context
126        let reasoning_context = format!("Reasoning: {}\nPlan: {}", situation, plan.join(", "));
127        action_msgs.push(Message::assistant(&reasoning_context));
128        // Prompt to execute
129        action_msgs.push(Message::user(
130            "Now execute the next step from your plan using the available tools.",
131        ));
132
133        let defs = tools.to_defs();
134        let tool_calls = self.client.tools_call(&action_msgs, &defs).await?;
135
136        let completed =
137            tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
138
139        Ok(Decision {
140            situation,
141            task: plan,
142            tool_calls,
143            completed,
144        })
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::{json, Value};
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// One recorded `tools_call`: the messages sent and the names of the tools offered.
    type Request = (Vec<Message>, Vec<String>);

    /// Mock client that replays scripted `tools_call` responses in order and
    /// records every request. Once the script runs out it answers with no calls.
    struct ScriptedClient {
        responses: Mutex<VecDeque<Vec<ToolCall>>>,
        requests: Mutex<Vec<Request>>,
    }

    impl ScriptedClient {
        fn new(script: Vec<Vec<ToolCall>>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(script)),
                requests: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of the requests seen so far, in call order.
        fn requests(&self) -> Vec<Request> {
            self.requests.lock().unwrap().clone()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for ScriptedClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let offered: Vec<String> = tools.iter().map(|t| t.name.clone()).collect();
            self.requests
                .lock()
                .unwrap()
                .push((messages.to_vec(), offered));
            Ok(self.responses.lock().unwrap().pop_front().unwrap_or_default())
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    fn call(id: &str, name: &str, arguments: Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments,
        }
    }

    /// A phase-1 `reasoning` call with the given fields.
    fn reasoning(situation: &str, plan: &[&str], done: bool) -> ToolCall {
        call(
            "r1",
            "reasoning",
            json!({ "situation": situation, "plan": plan, "done": done }),
        )
    }

    fn system_count(messages: &[Message]) -> usize {
        messages.iter().filter(|m| m.role == Role::System).count()
    }

    struct DummyTool;
    #[async_trait::async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "read_file"
        }
        fn description(&self) -> &str {
            "read a file"
        }
        fn parameters_schema(&self) -> Value {
            json!({"type": "object", "properties": {"path": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("file contents"))
        }
    }

    #[tokio::test]
    async fn hybrid_two_phases() {
        let client = ScriptedClient::new(vec![
            vec![reasoning(
                "Need to read a file",
                &["read main.rs", "analyze contents"],
                false,
            )],
            vec![call("a1", "read_file", json!({"path": "main.rs"}))],
        ]);
        let agent = HybridAgent::new(client, "test agent");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "Need to read a file");
        assert_eq!(decision.task, vec!["read main.rs", "analyze contents"]);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);

        let requests = agent.client.requests();
        assert_eq!(requests.len(), 2);
        // Phase 1: injected system prompt + user message, reasoning tool only.
        let (phase1_msgs, phase1_tools) = &requests[0];
        assert_eq!(phase1_msgs.len(), 2);
        assert!(phase1_msgs[0].role == Role::System);
        assert_eq!(phase1_tools, &vec!["reasoning".to_string()]);
        // Phase 2: same conversation + reasoning context + execute prompt.
        let (phase2_msgs, phase2_tools) = &requests[1];
        assert_eq!(phase2_msgs.len(), 4);
        assert!(phase2_tools.iter().any(|t| t == "read_file"));
    }

    #[tokio::test]
    async fn hybrid_done_in_reasoning() {
        let client =
            ScriptedClient::new(vec![vec![reasoning("Task is already complete", &[], true)]]);
        let agent = HybridAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
        assert_eq!(decision.situation, "Task is already complete");
        // Phase 2 must be skipped entirely.
        assert_eq!(agent.client.requests().len(), 1);
    }

    #[tokio::test]
    async fn hybrid_no_reasoning_completes() {
        let agent = HybridAgent::new(ScriptedClient::new(vec![vec![]]), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hello")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
        assert_eq!(agent.client.requests().len(), 1);
    }

    #[tokio::test]
    async fn hybrid_finish_task_completes() {
        let client = ScriptedClient::new(vec![
            vec![reasoning("Wrapping up", &["finish"], false)],
            vec![call("a1", "finish_task", json!({}))],
        ]);
        let agent = HybridAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("finish")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "finish_task");
    }

    #[tokio::test]
    async fn hybrid_empty_action_completes() {
        let client = ScriptedClient::new(vec![
            vec![reasoning("Nothing left to call", &["stop"], false)],
            vec![],
        ]);
        let agent = HybridAgent::new(client, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("anything else?")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
        assert_eq!(decision.situation, "Nothing left to call");
        assert_eq!(agent.client.requests().len(), 2);
    }

    #[tokio::test]
    async fn hybrid_keeps_caller_system_message() {
        let agent = HybridAgent::new(ScriptedClient::new(vec![vec![]]), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::system("custom"), Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        let requests = agent.client.requests();
        // The agent's own prompt must not be added on top of the caller's.
        assert_eq!(requests[0].0.len(), 2);
        assert_eq!(system_count(&requests[0].0), 1);
    }

    #[tokio::test]
    async fn hybrid_empty_prompt_injects_nothing() {
        let agent = HybridAgent::new(ScriptedClient::new(vec![vec![]]), "");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hi")];

        agent.decide(&msgs, &tools).await.unwrap();
        let requests = agent.client.requests();
        assert_eq!(requests[0].0.len(), 1);
        assert_eq!(system_count(&requests[0].0), 0);
    }
}