// agents_runtime/planner.rs

use std::sync::Arc;

use agents_core::agent::{PlannerAction, PlannerContext, PlannerDecision, PlannerHandle};
use agents_core::llm::{LanguageModel, LlmRequest};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::state::AgentStateSnapshot;
use anyhow::Context as _;
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::Value;

11#[derive(Clone)]
12pub struct LlmBackedPlanner {
13 model: Arc<dyn LanguageModel>,
14}
15
16impl LlmBackedPlanner {
17 pub fn new(model: Arc<dyn LanguageModel>) -> Self {
18 Self { model }
19 }
20}
21
22#[derive(Debug, Deserialize)]
23struct ToolCall {
24 name: String,
25 #[serde(default)]
26 args: Value,
27}
28
29#[derive(Debug, Deserialize)]
30struct PlannerOutput {
31 #[serde(default)]
32 tool_calls: Vec<ToolCall>,
33 #[serde(default)]
34 response: Option<String>,
35}
36
37#[async_trait]
38impl PlannerHandle for LlmBackedPlanner {
39 async fn plan(
40 &self,
41 context: PlannerContext,
42 _state: Arc<AgentStateSnapshot>,
43 ) -> anyhow::Result<PlannerDecision> {
44 let request = LlmRequest {
45 system_prompt: context.system_prompt.clone(),
46 messages: context.history.clone(),
47 };
48 let response = self.model.generate(request).await?;
49 let message = response.message;
50
51 match parse_planner_output(&message)? {
52 PlannerOutputVariant::ToolCall { name, args } => Ok(PlannerDecision {
53 next_action: PlannerAction::CallTool {
54 tool_name: name,
55 payload: args,
56 },
57 }),
58 PlannerOutputVariant::Respond(text) => Ok(PlannerDecision {
59 next_action: PlannerAction::Respond {
60 message: AgentMessage {
61 role: MessageRole::Agent,
62 content: MessageContent::Text(text),
63 metadata: message.metadata,
64 },
65 },
66 }),
67 }
68 }
69}
70
71enum PlannerOutputVariant {
72 ToolCall { name: String, args: Value },
73 Respond(String),
74}
75
76fn parse_planner_output(message: &AgentMessage) -> anyhow::Result<PlannerOutputVariant> {
77 match &message.content {
78 MessageContent::Json(value) => parse_from_value(value.clone()),
79 MessageContent::Text(text) => {
80 if let Some(parsed) = parse_from_text(text) {
82 if let Some(tc) = parsed.tool_calls.first() {
83 return Ok(PlannerOutputVariant::ToolCall {
84 name: tc.name.clone(),
85 args: tc.args.clone(),
86 });
87 }
88 if let Some(resp) = parsed.response {
89 return Ok(PlannerOutputVariant::Respond(resp));
90 }
91 }
92 Ok(PlannerOutputVariant::Respond(text.clone()))
93 }
94 }
95}
96
97fn parse_from_value(value: Value) -> anyhow::Result<PlannerOutputVariant> {
98 let parsed: PlannerOutput = serde_json::from_value(value)?;
99 if let Some(tool_call) = parsed.tool_calls.first() {
100 Ok(PlannerOutputVariant::ToolCall {
101 name: tool_call.name.clone(),
102 args: tool_call.args.clone(),
103 })
104 } else if let Some(response) = parsed.response {
105 Ok(PlannerOutputVariant::Respond(response))
106 } else {
107 anyhow::bail!("LLM response missing tool call and response fields")
108 }
109}
110
111fn parse_from_text(text: &str) -> Option<PlannerOutput> {
112 if let Some(parsed) = decode_output_from_str(text) {
114 return Some(parsed);
115 }
116 let trimmed = text.trim();
118 if trimmed.starts_with("```") {
119 let without_ticks = trimmed.trim_start_matches("```");
120 let without_lang = without_ticks
122 .trim_start_matches(|c: char| c.is_alphabetic())
123 .trim_start();
124 let inner = if let Some(end) = without_lang.rfind("```") {
125 &without_lang[..end]
126 } else {
127 without_lang
128 };
129 if let Some(parsed) = decode_output_from_str(inner) {
130 return Some(parsed);
131 }
132 }
133 None
134}
135
136fn decode_output_from_str(s: &str) -> Option<PlannerOutput> {
138 serde_json::from_str::<Value>(s)
139 .ok()
140 .and_then(|v| serde_json::from_value::<PlannerOutput>(v).ok())
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::llm::{LanguageModel, LlmResponse};
    use agents_core::messaging::MessageMetadata;
    use async_trait::async_trait;

    /// Model that replies with the last request message, or with empty text if there is none.
    struct EchoModel;

    #[async_trait]
    impl LanguageModel for EchoModel {
        async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: request.messages.last().cloned().unwrap_or_else(|| AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("".into()),
                    metadata: None,
                }),
            })
        }
    }

    /// Build an agent message with plain-text content and no metadata.
    fn agent_text(text: &str) -> AgentMessage {
        AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    #[tokio::test]
    async fn planner_falls_back_to_text_response() {
        let planner = LlmBackedPlanner::new(Arc::new(EchoModel));
        let context = PlannerContext {
            history: vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
            system_prompt: "Be helpful".into(),
        };

        let decision = planner
            .plan(context, Arc::new(AgentStateSnapshot::default()))
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::Respond { message } => match message.content {
                MessageContent::Text(text) => assert_eq!(text, "Hi"),
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected respond"),
        }
    }

    /// Model that always requests a single `write_file` tool call as JSON content.
    struct ToolCallModel;

    #[async_trait]
    impl LanguageModel for ToolCallModel {
        async fn generate(&self, _request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": [
                            {
                                "name": "write_file",
                                "args": { "path": "notes.txt" }
                            }
                        ]
                    })),
                    metadata: Some(MessageMetadata {
                        tool_call_id: Some("call-1".into()),
                        cache_control: None,
                    }),
                },
            })
        }
    }

    #[tokio::test]
    async fn planner_parses_tool_call() {
        let planner = LlmBackedPlanner::new(Arc::new(ToolCallModel));
        let decision = planner
            .plan(
                PlannerContext {
                    history: vec![],
                    system_prompt: "System".into(),
                },
                Arc::new(AgentStateSnapshot::default()),
            )
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::CallTool { tool_name, payload } => {
                assert_eq!(tool_name, "write_file");
                assert_eq!(payload["path"], "notes.txt");
            }
            _ => panic!("expected tool call"),
        }
    }

    #[test]
    fn fenced_json_text_is_parsed_as_tool_call() {
        let message = agent_text(
            "```json\n{\"tool_calls\": [{\"name\": \"ls\", \"args\": {\"dir\": \"/\"}}]}\n```",
        );
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "ls");
                assert_eq!(args["dir"], "/");
            }
            PlannerOutputVariant::Respond(text) => {
                panic!("expected tool call, got response {text:?}")
            }
        }
    }

    #[test]
    fn tool_call_without_args_defaults_to_null() {
        let message = agent_text(r#"{"tool_calls": [{"name": "noop"}]}"#);
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "noop");
                assert!(args.is_null());
            }
            PlannerOutputVariant::Respond(text) => {
                panic!("expected tool call, got response {text:?}")
            }
        }
    }

    #[test]
    fn json_text_response_field_is_used() {
        let message = agent_text(r#"{"response": "done"}"#);
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, "done"),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name}")
            }
        }
    }

    #[test]
    fn json_text_without_known_fields_falls_back_to_raw_text() {
        let raw = r#"{"unrelated": 1}"#;
        match parse_planner_output(&agent_text(raw)).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, raw),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name}")
            }
        }
    }

    #[test]
    fn json_content_without_action_is_an_error() {
        let message = AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Json(serde_json::json!({})),
            metadata: None,
        };
        assert!(parse_planner_output(&message).is_err());
    }
}