Skip to main content

sgr_agent/agents/
hybrid.rs

1//! HybridAgent — 2-phase agent (reasoning + action).
2//!
3//! Phase 1: Send a minimal ReasoningTool-only FC call to get the agent's
4//! reasoning about what to do next.
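//! Phase 2: Send a toolkit FC call with the reasoning as context, getting back
//! concrete tool calls. The toolkit is filtered to the tools the reasoning
//! mentions, plus answer/finish and core read/write/search tools; the full
//! toolkit is sent only when nothing matches.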
7//!
8//! Inspired by Python SGRToolCallingAgent — separates "thinking" from "acting"
9//! so the model doesn't get overwhelmed by a large tool set during reasoning.
10
11use crate::agent::{Agent, AgentError, Decision};
12use crate::client::LlmClient;
13use crate::registry::ToolRegistry;
14use crate::types::Message;
15
/// 2-phase hybrid agent: a reasoning-only LLM call followed by an action call.
///
/// `client` is used for both phases. `system_prompt` is prepended to the
/// conversation when it is non-empty and the caller supplied no system message.
///
/// The `LlmClient` bound is placed on the `impl` blocks rather than on the
/// struct itself (Rust API guidelines, C-STRUCT-BOUNDS), so naming the type
/// does not force the bound on every mention.
pub struct HybridAgent<C> {
    client: C,
    system_prompt: String,
}
21
22impl<C: LlmClient> HybridAgent<C> {
23    pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
24        Self {
25            client,
26            system_prompt: system_prompt.into(),
27        }
28    }
29}
30
31/// Internal reasoning tool definition for phase 1.
32fn reasoning_tool_def() -> crate::tool::ToolDef {
33    crate::tool::ToolDef {
34        name: "reasoning".to_string(),
35        description: "Analyze the situation and decide what tools to use next. Describe your reasoning, the current situation, and which tools you plan to call.".to_string(),
36        parameters: serde_json::json!({
37            "type": "object",
38            "properties": {
39                "situation": {
40                    "type": "string",
41                    "description": "Your assessment of the current situation"
42                },
43                "plan": {
44                    "type": "array",
45                    "items": { "type": "string" },
46                    "description": "Step-by-step plan of what to do next"
47                },
48                "done": {
49                    "type": "boolean",
50                    "description": "Set to true if the task is fully complete"
51                }
52            },
53            "required": ["situation", "plan", "done"]
54        }),
55    }
56}
57
58#[async_trait::async_trait]
59impl<C: LlmClient> Agent for HybridAgent<C> {
60    async fn decide(
61        &self,
62        messages: &[Message],
63        tools: &ToolRegistry,
64    ) -> Result<Decision, AgentError> {
65        self.decide_stateful(messages, tools, None)
66            .await
67            .map(|(d, _)| d)
68    }
69
70    async fn decide_stateful(
71        &self,
72        messages: &[Message],
73        tools: &ToolRegistry,
74        previous_response_id: Option<&str>,
75    ) -> Result<(Decision, Option<String>), AgentError> {
76        // Prepare messages with system prompt
77        let mut msgs = Vec::with_capacity(messages.len() + 1);
78        let has_system = messages
79            .iter()
80            .any(|m| m.role == crate::types::Role::System);
81        if !has_system && !self.system_prompt.is_empty() {
82            msgs.push(Message::system(&self.system_prompt));
83        }
84        msgs.extend_from_slice(messages);
85
86        // Phase 1: Reasoning — stateless (fresh context each time)
87        let reasoning_defs = vec![reasoning_tool_def()];
88        let reasoning_calls = self.client.tools_call(&msgs, &reasoning_defs).await?;
89
90        // Extract reasoning from phase 1
91        let (situation, plan, done) = if let Some(rc) = reasoning_calls.first() {
92            let sit = rc
93                .arguments
94                .get("situation")
95                .and_then(|s| s.as_str())
96                .unwrap_or("")
97                .to_string();
98            let plan: Vec<String> = rc
99                .arguments
100                .get("plan")
101                .and_then(|p| p.as_array())
102                .map(|arr| {
103                    arr.iter()
104                        .filter_map(|v| v.as_str().map(String::from))
105                        .collect()
106                })
107                .unwrap_or_default();
108            let done = rc
109                .arguments
110                .get("done")
111                .and_then(|d| d.as_bool())
112                .unwrap_or(false);
113            (sit, plan, done)
114        } else {
115            return Ok((
116                Decision {
117                    situation: String::new(),
118                    task: vec![],
119                    tool_calls: vec![],
120                    completed: true,
121                },
122                None,
123            ));
124        };
125
126        // Phase 2: Action — STATEFUL (chain from previous step for token caching)
127        let mut action_msgs = msgs.clone();
128        let reasoning_context = if done {
129            format!(
130                "Reasoning: {}\nStatus: Task appears complete. Call the answer/finish tool with the final result.",
131                situation
132            )
133        } else {
134            format!("Reasoning: {}\nPlan: {}", situation, plan.join(", "))
135        };
136        action_msgs.push(Message::assistant(&reasoning_context));
137        action_msgs.push(Message::user(
138            "Now execute the next step from your plan using the available tools.",
139        ));
140
141        // Progressive tool discovery: filter tools by reasoning context.
142        // Send only tools mentioned in situation/plan + answer/finish (always needed).
143        let context_lower = format!("{} {}", situation, plan.join(" ")).to_lowercase();
144        let filtered: Vec<_> = tools
145            .to_defs()
146            .into_iter()
147            .filter(|t| {
148                // Always include answer/finish tools
149                t.name == "answer"
150                    || t.name == "finish_task"
151                    || t.name.contains("answer")
152                    // Include if tool name appears in reasoning context
153                    || context_lower.contains(&t.name.to_lowercase())
154                    // Include read/write/search as core tools (almost always needed)
155                    || matches!(t.name.as_str(), "read" | "write" | "search")
156            })
157            .collect();
158        let defs = if filtered.is_empty() {
159            tools.to_defs()
160        } else {
161            filtered
162        };
163
164        let (tool_calls, new_response_id) = self
165            .client
166            .tools_call_stateful(&action_msgs, &defs, previous_response_id)
167            .await?;
168
169        let completed =
170            tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
171
172        Ok((
173            Decision {
174                situation,
175                task: plan,
176                tool_calls,
177                completed,
178            },
179            new_response_id,
180        ))
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Mock client that returns reasoning in phase 1, tool call in phase 2.
    struct MockHybridClient {
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for MockHybridClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n == 0 {
                // Phase 1: reasoning
                Ok(vec![ToolCall {
                    id: "r1".into(),
                    name: "reasoning".into(),
                    arguments: serde_json::json!({
                        "situation": "Need to read a file",
                        "plan": ["read main.rs", "analyze contents"],
                        "done": false
                    }),
                }])
            } else {
                // Phase 2: action
                Ok(vec![ToolCall {
                    id: "a1".into(),
                    name: "read_file".into(),
                    arguments: serde_json::json!({"path": "main.rs"}),
                }])
            }
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    struct DummyTool;
    #[async_trait::async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "read_file"
        }
        fn description(&self) -> &str {
            "read a file"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("file contents"))
        }
    }

    /// Second tool, used to check that phase 2 filters the toolkit.
    struct OtherTool;
    #[async_trait::async_trait]
    impl Tool for OtherTool {
        fn name(&self) -> &str {
            "delete_file"
        }
        fn description(&self) -> &str {
            "delete a file"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("deleted"))
        }
    }

    /// Mock that returns a fixed plan in phase 1 and records the names of the
    /// tools it is offered in phase 2.
    struct RecordingClient {
        call_count: AtomicUsize,
        plan: &'static [&'static str],
        phase2_tools: Mutex<Vec<String>>,
    }

    #[async_trait::async_trait]
    impl LlmClient for RecordingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.call_count.fetch_add(1, Ordering::SeqCst) == 0 {
                return Ok(vec![ToolCall {
                    id: "r1".into(),
                    name: "reasoning".into(),
                    arguments: serde_json::json!({
                        "situation": "Deciding which tool to use",
                        "plan": self.plan,
                        "done": false
                    }),
                }]);
            }
            *self.phase2_tools.lock().unwrap() = tools.iter().map(|t| t.name.clone()).collect();
            Ok(vec![])
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// Runs one decision with `plan` as the phase-1 plan and returns the tool
    /// names offered in phase 2 (registry: `read_file` + `delete_file`).
    async fn phase2_tools_for(plan: &'static [&'static str]) -> Vec<String> {
        let agent = HybridAgent::new(
            RecordingClient {
                call_count: AtomicUsize::new(0),
                plan,
                phase2_tools: Mutex::new(Vec::new()),
            },
            "test",
        );
        let tools = ToolRegistry::new().register(DummyTool).register(OtherTool);
        let msgs = vec![Message::user("go")];
        agent.decide(&msgs, &tools).await.unwrap();
        let names = agent.client.phase2_tools.lock().unwrap().clone();
        names
    }

    #[tokio::test]
    async fn hybrid_two_phases() {
        let client = MockHybridClient {
            call_count: Arc::new(AtomicUsize::new(0)),
        };
        let agent = HybridAgent::new(client, "test agent");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "Need to read a file");
        assert_eq!(decision.task.len(), 2);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn hybrid_done_still_runs_phase2() {
        // Even when reasoning says done, phase 2 runs to let the model call answer/finish
        struct DoneClient {
            call_count: Arc<AtomicUsize>,
        }
        #[async_trait::async_trait]
        impl LlmClient for DoneClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok(vec![ToolCall {
                        id: "r1".into(),
                        name: "reasoning".into(),
                        arguments: serde_json::json!({
                            "situation": "Task is already complete",
                            "plan": [],
                            "done": true
                        }),
                    }])
                } else {
                    // Phase 2 — model calls finish
                    Ok(vec![ToolCall {
                        id: "a1".into(),
                        name: "finish_task".into(),
                        arguments: serde_json::json!({"summary": "done"}),
                    }])
                }
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let agent = HybridAgent::new(
            DoneClient {
                call_count: Arc::new(AtomicUsize::new(0)),
            },
            "test",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        // Phase 2 ran and returned finish_task
        assert!(decision.completed);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "finish_task");
    }

    #[tokio::test]
    async fn hybrid_no_reasoning_completes() {
        struct EmptyClient;
        #[async_trait::async_trait]
        impl LlmClient for EmptyClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                _: &[Message],
                _: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                Ok(vec![])
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let agent = HybridAgent::new(EmptyClient, "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hello")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
    }

    #[tokio::test]
    async fn hybrid_two_phases_independent() {
        // Verify that phase 1 and phase 2 are separate invocations sharing no
        // hidden state. Phase 1 goes through `tools_call`. Phase 2 goes through
        // `tools_call_stateful`, which this mock does not override, so it
        // presumably reaches `tools_call` via the trait's default implementation.
        // The call-count assertion below relies on that.
        struct PhaseTrackingClient {
            call_count: Arc<AtomicUsize>,
        }

        #[async_trait::async_trait]
        impl LlmClient for PhaseTrackingClient {
            async fn structured_call(
                &self,
                _: &[Message],
                _: &Value,
            ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
                Ok((None, vec![], String::new()))
            }
            async fn tools_call(
                &self,
                msgs: &[Message],
                tools: &[ToolDef],
            ) -> Result<Vec<ToolCall>, SgrError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    // Phase 1: reasoning — only gets reasoning tool
                    assert_eq!(tools.len(), 1, "Phase 1 should only have reasoning tool");
                    assert_eq!(tools[0].name, "reasoning");
                    Ok(vec![ToolCall {
                        id: "r1".into(),
                        name: "reasoning".into(),
                        arguments: serde_json::json!({
                            "situation": "Testing phase independence",
                            "plan": ["call read_file"],
                            "done": false
                        }),
                    }])
                } else {
                    // Phase 2: action — gets the real (filtered) toolkit, never the
                    // reasoning tool.
                    assert!(
                        !tools.is_empty() && tools.iter().all(|t| t.name != "reasoning"),
                        "Phase 2 should have the real tools, not just reasoning"
                    );
                    // Verify that messages don't contain any implicit state from phase 1
                    // (they will have reasoning context added explicitly as assistant message)
                    let last_msg = msgs.last().unwrap();
                    assert_eq!(
                        last_msg.role,
                        crate::types::Role::User,
                        "Last message in phase 2 should be the action prompt"
                    );
                    Ok(vec![ToolCall {
                        id: "a1".into(),
                        name: "read_file".into(),
                        arguments: serde_json::json!({"path": "test.rs"}),
                    }])
                }
            }
            async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
                Ok(String::new())
            }
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let agent = HybridAgent::new(
            PhaseTrackingClient {
                call_count: call_count.clone(),
            },
            "test agent",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read test.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();

        // Both phases ran
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        // Phase 2 returned the action
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);
    }

    #[tokio::test]
    async fn hybrid_phase2_offers_only_tools_named_in_plan() {
        let names = phase2_tools_for(&["call read_file"]).await;
        assert!(names.iter().any(|n| n == "read_file"), "got {names:?}");
        assert!(!names.iter().any(|n| n == "delete_file"), "got {names:?}");
    }

    #[tokio::test]
    async fn hybrid_phase2_falls_back_to_all_tools() {
        // Nothing in the reasoning names a registered tool → full toolkit.
        let names = phase2_tools_for(&["think it over"]).await;
        assert!(names.iter().any(|n| n == "read_file"), "got {names:?}");
        assert!(names.iter().any(|n| n == "delete_file"), "got {names:?}");
    }
}