agents_runtime/
planner.rs1use std::sync::Arc;
2
3use agents_core::agent::{PlannerAction, PlannerContext, PlannerDecision, PlannerHandle};
4use agents_core::llm::{LanguageModel, LlmRequest};
5use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
6use agents_core::state::AgentStateSnapshot;
7use async_trait::async_trait;
8use serde::Deserialize;
9use serde_json::Value;
10
11#[derive(Clone)]
12pub struct LlmBackedPlanner {
13 model: Arc<dyn LanguageModel>,
14}
15
16impl LlmBackedPlanner {
17 pub fn new(model: Arc<dyn LanguageModel>) -> Self {
18 Self { model }
19 }
20
21 pub fn model(&self) -> &Arc<dyn LanguageModel> {
23 &self.model
24 }
25}
26
27#[derive(Debug, Deserialize)]
28struct ToolCall {
29 name: String,
30 #[serde(default)]
31 args: Value,
32}
33
34#[derive(Debug, Deserialize)]
35struct PlannerOutput {
36 #[serde(default)]
37 tool_calls: Vec<ToolCall>,
38 #[serde(default)]
39 response: Option<String>,
40}
41
42#[async_trait]
43impl PlannerHandle for LlmBackedPlanner {
44 async fn plan(
45 &self,
46 context: PlannerContext,
47 _state: Arc<AgentStateSnapshot>,
48 ) -> anyhow::Result<PlannerDecision> {
49 let request = LlmRequest::new(context.system_prompt.clone(), context.history.clone());
50 let response = self.model.generate(request).await?;
51 let message = response.message;
52
53 match parse_planner_output(&message)? {
54 PlannerOutputVariant::ToolCall { name, args } => Ok(PlannerDecision {
55 next_action: PlannerAction::CallTool {
56 tool_name: name,
57 payload: args,
58 },
59 }),
60 PlannerOutputVariant::Respond(text) => Ok(PlannerDecision {
61 next_action: PlannerAction::Respond {
62 message: AgentMessage {
63 role: MessageRole::Agent,
64 content: MessageContent::Text(text),
65 metadata: message.metadata,
66 },
67 },
68 }),
69 }
70 }
71
72 fn as_any(&self) -> &dyn std::any::Any {
73 self
74 }
75}
76
77enum PlannerOutputVariant {
78 ToolCall { name: String, args: Value },
79 Respond(String),
80}
81
82fn parse_planner_output(message: &AgentMessage) -> anyhow::Result<PlannerOutputVariant> {
83 match &message.content {
84 MessageContent::Json(value) => parse_from_value(value.clone()),
85 MessageContent::Text(text) => {
86 if let Some(parsed) = parse_from_text(text) {
88 if let Some(tc) = parsed.tool_calls.first() {
89 return Ok(PlannerOutputVariant::ToolCall {
90 name: tc.name.clone(),
91 args: tc.args.clone(),
92 });
93 }
94 if let Some(resp) = parsed.response {
95 return Ok(PlannerOutputVariant::Respond(resp));
96 }
97 }
98 Ok(PlannerOutputVariant::Respond(text.clone()))
99 }
100 }
101}
102
103fn parse_from_value(value: Value) -> anyhow::Result<PlannerOutputVariant> {
104 let parsed: PlannerOutput = serde_json::from_value(value)?;
105 if let Some(tool_call) = parsed.tool_calls.first() {
106 Ok(PlannerOutputVariant::ToolCall {
107 name: tool_call.name.clone(),
108 args: tool_call.args.clone(),
109 })
110 } else if let Some(response) = parsed.response {
111 Ok(PlannerOutputVariant::Respond(response))
112 } else {
113 anyhow::bail!("LLM response missing tool call and response fields")
114 }
115}
116
117fn parse_from_text(text: &str) -> Option<PlannerOutput> {
118 if let Some(parsed) = decode_output_from_str(text) {
120 return Some(parsed);
121 }
122 let trimmed = text.trim();
124 if trimmed.starts_with("```") {
125 let without_ticks = trimmed.trim_start_matches("```");
126 let without_lang = without_ticks
128 .trim_start_matches(|c: char| c.is_alphabetic())
129 .trim_start();
130 let inner = if let Some(end) = without_lang.rfind("```") {
131 &without_lang[..end]
132 } else {
133 without_lang
134 };
135 if let Some(parsed) = decode_output_from_str(inner) {
136 return Some(parsed);
137 }
138 }
139 None
140}
141
142fn decode_output_from_str(s: &str) -> Option<PlannerOutput> {
144 serde_json::from_str::<Value>(s)
145 .ok()
146 .and_then(|v| serde_json::from_value::<PlannerOutput>(v).ok())
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::llm::{LanguageModel, LlmResponse};
    use agents_core::messaging::MessageMetadata;
    use async_trait::async_trait;

    /// Builds an agent-authored text message without metadata.
    fn agent_text(text: &str) -> AgentMessage {
        AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    /// Model stub that echoes the last request message back (or an empty agent message).
    struct EchoModel;

    #[async_trait]
    impl LanguageModel for EchoModel {
        async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
            // `unwrap_or_else` builds the fallback only when the request has no messages.
            let message = request
                .messages
                .last()
                .cloned()
                .unwrap_or_else(|| agent_text(""));
            Ok(LlmResponse { message })
        }
    }

    #[tokio::test]
    async fn planner_falls_back_to_text_response() {
        let planner = LlmBackedPlanner::new(Arc::new(EchoModel));
        let context = PlannerContext {
            history: vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
            system_prompt: "Be helpful".into(),
        };

        let decision = planner
            .plan(context, Arc::new(AgentStateSnapshot::default()))
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::Respond { message } => match message.content {
                MessageContent::Text(text) => assert_eq!(text, "Hi"),
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected respond"),
        }
    }

    /// Model stub that always requests a `write_file` tool call as JSON content.
    struct ToolCallModel;

    #[async_trait]
    impl LanguageModel for ToolCallModel {
        async fn generate(&self, _request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": [
                            {
                                "name": "write_file",
                                "args": { "path": "notes.txt" }
                            }
                        ]
                    })),
                    metadata: Some(MessageMetadata {
                        tool_call_id: Some("call-1".into()),
                        cache_control: None,
                    }),
                },
            })
        }
    }

    #[tokio::test]
    async fn planner_parses_tool_call() {
        let planner = LlmBackedPlanner::new(Arc::new(ToolCallModel));
        let decision = planner
            .plan(
                PlannerContext {
                    history: vec![],
                    system_prompt: "System".into(),
                },
                Arc::new(AgentStateSnapshot::default()),
            )
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::CallTool { tool_name, payload } => {
                assert_eq!(tool_name, "write_file");
                assert_eq!(payload["path"], "notes.txt");
            }
            _ => panic!("expected tool call"),
        }
    }

    #[test]
    fn parses_response_from_fenced_json_text() {
        let message = agent_text("```json\n{\"response\": \"done\"}\n```");
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, "done"),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected a response, got tool call `{name}`")
            }
        }
    }

    #[test]
    fn parses_tool_call_from_plain_json_text() {
        let message = agent_text(r#"{"tool_calls": [{"name": "ls", "args": {"path": "."}}]}"#);
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "ls");
                assert_eq!(args["path"], ".");
            }
            PlannerOutputVariant::Respond(text) => panic!("expected a tool call, got `{text}`"),
        }
    }

    #[test]
    fn json_content_without_tool_call_or_response_is_an_error() {
        let message = AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Json(serde_json::json!({})),
            metadata: None,
        };
        assert!(parse_planner_output(&message).is_err());
    }
}