1use crate::agent::{Agent, AgentError, Decision};
12use crate::client::LlmClient;
13use crate::registry::ToolRegistry;
14use crate::types::Message;
15
/// Agent that splits every decision into two LLM round-trips: a call that may
/// only use the `reasoning` tool, followed by an action call offered the real
/// tools from the registry.
///
/// The `LlmClient` bound lives on the `impl` blocks that need it rather than on
/// the struct definition, so naming `HybridAgent<C>` never requires the bound.
pub struct HybridAgent<C> {
    /// LLM backend used for both phases.
    client: C,
    /// Prepended as a system message when the incoming conversation has none;
    /// an empty string disables this.
    system_prompt: String,
}
21
22impl<C: LlmClient> HybridAgent<C> {
23 pub fn new(client: C, system_prompt: impl Into<String>) -> Self {
24 Self {
25 client,
26 system_prompt: system_prompt.into(),
27 }
28 }
29}
30
31fn reasoning_tool_def() -> crate::tool::ToolDef {
33 crate::tool::ToolDef {
34 name: "reasoning".to_string(),
35 description: "Analyze the situation and decide what tools to use next. Describe your reasoning, the current situation, and which tools you plan to call.".to_string(),
36 parameters: serde_json::json!({
37 "type": "object",
38 "properties": {
39 "situation": {
40 "type": "string",
41 "description": "Your assessment of the current situation"
42 },
43 "plan": {
44 "type": "array",
45 "items": { "type": "string" },
46 "description": "Step-by-step plan of what to do next"
47 },
48 "done": {
49 "type": "boolean",
50 "description": "Set to true if the task is fully complete"
51 }
52 },
53 "required": ["situation", "plan", "done"]
54 }),
55 }
56}
57
58#[async_trait::async_trait]
59impl<C: LlmClient> Agent for HybridAgent<C> {
60 async fn decide(
61 &self,
62 messages: &[Message],
63 tools: &ToolRegistry,
64 ) -> Result<Decision, AgentError> {
65 self.decide_stateful(messages, tools, None)
66 .await
67 .map(|(d, _)| d)
68 }
69
70 async fn decide_stateful(
71 &self,
72 messages: &[Message],
73 tools: &ToolRegistry,
74 previous_response_id: Option<&str>,
75 ) -> Result<(Decision, Option<String>), AgentError> {
76 let mut msgs = Vec::with_capacity(messages.len() + 1);
78 let has_system = messages
79 .iter()
80 .any(|m| m.role == crate::types::Role::System);
81 if !has_system && !self.system_prompt.is_empty() {
82 msgs.push(Message::system(&self.system_prompt));
83 }
84 msgs.extend_from_slice(messages);
85
86 let reasoning_defs = vec![reasoning_tool_def()];
88 let reasoning_calls = self.client.tools_call(&msgs, &reasoning_defs).await?;
89
90 let (situation, plan, done) = if let Some(rc) = reasoning_calls.first() {
92 let sit = rc
93 .arguments
94 .get("situation")
95 .and_then(|s| s.as_str())
96 .unwrap_or("")
97 .to_string();
98 let plan: Vec<String> = rc
99 .arguments
100 .get("plan")
101 .and_then(|p| p.as_array())
102 .map(|arr| {
103 arr.iter()
104 .filter_map(|v| v.as_str().map(String::from))
105 .collect()
106 })
107 .unwrap_or_default();
108 let done = rc
109 .arguments
110 .get("done")
111 .and_then(|d| d.as_bool())
112 .unwrap_or(false);
113 (sit, plan, done)
114 } else {
115 return Ok((
116 Decision {
117 situation: String::new(),
118 task: vec![],
119 tool_calls: vec![],
120 completed: true,
121 },
122 None,
123 ));
124 };
125
126 let mut action_msgs = msgs.clone();
128 let reasoning_context = if done {
129 format!(
130 "Reasoning: {}\nStatus: Task appears complete. Call the answer/finish tool with the final result.",
131 situation
132 )
133 } else {
134 format!("Reasoning: {}\nPlan: {}", situation, plan.join(", "))
135 };
136 action_msgs.push(Message::assistant(&reasoning_context));
137 action_msgs.push(Message::user(
138 "Now execute the next step from your plan using the available tools.",
139 ));
140
141 let defs = tools.to_defs();
142 let (tool_calls, new_response_id) = self
143 .client
144 .tools_call_stateful(&action_msgs, &defs, previous_response_id)
145 .await?;
146
147 let completed =
148 tool_calls.is_empty() || tool_calls.iter().any(|tc| tc.name == "finish_task");
149
150 Ok((
151 Decision {
152 situation,
153 task: plan,
154 tool_calls,
155 completed,
156 },
157 new_response_id,
158 ))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::tool::ToolDef;
    use crate::types::{Role, SgrError, ToolCall};
    use serde_json::{json, Value};
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// What the agent sent on one `tools_call` invocation.
    struct RecordedCall {
        /// Names of the tool definitions offered to the model.
        tool_names: Vec<String>,
        /// Whether the final message of the request had the `User` role.
        last_is_user: bool,
    }

    /// Mock LLM that replays scripted `tools_call` responses in order and
    /// records every request it receives. Once the script is exhausted it
    /// answers with no tool calls.
    struct ScriptedClient {
        responses: Mutex<VecDeque<Vec<ToolCall>>>,
        calls: Mutex<Vec<RecordedCall>>,
    }

    impl ScriptedClient {
        fn new(responses: Vec<Vec<ToolCall>>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
                calls: Mutex::new(Vec::new()),
            }
        }

        /// Number of `tools_call` invocations so far.
        fn call_count(&self) -> usize {
            self.calls.lock().expect("calls lock poisoned").len()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for ScriptedClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }

        async fn tools_call(
            &self,
            msgs: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let record = RecordedCall {
                tool_names: tools.iter().map(|t| t.name.clone()).collect(),
                last_is_user: matches!(msgs.last(), Some(m) if m.role == Role::User),
            };
            self.calls.lock().expect("calls lock poisoned").push(record);
            let next = self
                .responses
                .lock()
                .expect("responses lock poisoned")
                .pop_front()
                .unwrap_or_default();
            Ok(next)
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// Builds a scripted tool call.
    fn tool_call(id: &'static str, name: &'static str, arguments: Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments,
        }
    }

    /// Builds the phase-1 `reasoning` call the model would return.
    fn reasoning(situation: &'static str, plan: &[&str], done: bool) -> ToolCall {
        tool_call(
            "r1",
            "reasoning",
            json!({ "situation": situation, "plan": plan, "done": done }),
        )
    }

    struct DummyTool;
    #[async_trait::async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "read_file"
        }
        fn description(&self) -> &str {
            "read a file"
        }
        fn parameters_schema(&self) -> Value {
            json!({"type": "object", "properties": {"path": {"type": "string"}}})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("file contents"))
        }
    }

    #[tokio::test]
    async fn hybrid_two_phases() {
        let agent = HybridAgent::new(
            ScriptedClient::new(vec![
                vec![reasoning(
                    "Need to read a file",
                    &["read main.rs", "analyze contents"],
                    false,
                )],
                vec![tool_call("a1", "read_file", json!({"path": "main.rs"}))],
            ]),
            "test agent",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read main.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert_eq!(decision.situation, "Need to read a file");
        assert_eq!(decision.task, vec!["read main.rs", "analyze contents"]);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);
        assert_eq!(agent.client.call_count(), 2);
    }

    #[tokio::test]
    async fn hybrid_done_still_runs_phase2() {
        let agent = HybridAgent::new(
            ScriptedClient::new(vec![
                vec![reasoning("Task is already complete", &[], true)],
                vec![tool_call("a1", "finish_task", json!({"summary": "done"}))],
            ]),
            "test",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("done")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        // `done: true` only changes the prompt; phase 2 still runs and its
        // `finish_task` call is what marks the decision completed.
        assert!(decision.completed);
        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "finish_task");
        assert_eq!(agent.client.call_count(), 2);
    }

    #[tokio::test]
    async fn hybrid_no_reasoning_completes() {
        let agent = HybridAgent::new(ScriptedClient::new(Vec::new()), "test");
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("hello")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();
        assert!(decision.completed);
        assert!(decision.tool_calls.is_empty());
        assert!(decision.task.is_empty());
        // Phase 2 must be skipped when phase 1 returns nothing.
        assert_eq!(agent.client.call_count(), 1);
    }

    #[tokio::test]
    async fn hybrid_two_phases_independent() {
        let agent = HybridAgent::new(
            ScriptedClient::new(vec![
                vec![reasoning(
                    "Testing phase independence",
                    &["call read_file"],
                    false,
                )],
                vec![tool_call("a1", "read_file", json!({"path": "test.rs"}))],
            ]),
            "test agent",
        );
        let tools = ToolRegistry::new().register(DummyTool);
        let msgs = vec![Message::user("read test.rs")];

        let decision = agent.decide(&msgs, &tools).await.unwrap();

        let calls = agent.client.calls.lock().expect("calls lock poisoned");
        assert_eq!(calls.len(), 2);
        assert_eq!(
            calls[0].tool_names,
            ["reasoning"],
            "Phase 1 should only have reasoning tool"
        );
        // `any` instead of indexing so an empty tool list fails with a message
        // rather than an index-out-of-bounds panic.
        assert!(
            calls[1]
                .tool_names
                .iter()
                .any(|name| name.as_str() != "reasoning"),
            "Phase 2 should have the real tools, not just reasoning"
        );
        assert!(
            calls[1].last_is_user,
            "Last message in phase 2 should be the action prompt"
        );

        assert_eq!(decision.tool_calls.len(), 1);
        assert_eq!(decision.tool_calls[0].name, "read_file");
        assert!(!decision.completed);
    }
}