agents_runtime/
planner.rs

1use std::sync::Arc;
2
3use agents_core::agent::{PlannerAction, PlannerContext, PlannerDecision, PlannerHandle};
4use agents_core::llm::{LanguageModel, LlmRequest};
5use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
6use agents_core::state::AgentStateSnapshot;
7use async_trait::async_trait;
8use serde::Deserialize;
9use serde_json::Value;
10
11#[derive(Clone)]
12pub struct LlmBackedPlanner {
13    model: Arc<dyn LanguageModel>,
14}
15
16impl LlmBackedPlanner {
17    pub fn new(model: Arc<dyn LanguageModel>) -> Self {
18        Self { model }
19    }
20}
21
/// A single tool invocation requested by the model.
///
/// `name` is required: an entry without it makes the whole planner output fail
/// to decode.
#[derive(Debug, Deserialize)]
struct ToolCall {
    /// Name of the tool to invoke.
    name: String,
    /// Tool arguments as raw JSON; falls back to `Value`'s default when the model
    /// omits `args`.
    #[serde(default)]
    args: Value,
}
28
/// Schema the model is expected to emit.
///
/// The schema arrives either as structured JSON content or as JSON text
/// (optionally wrapped in a Markdown code fence).
///
/// Both fields are `#[serde(default)]`, so an object containing neither key
/// (e.g. `{}`) still decodes; the callers decide how to treat that case.
#[derive(Debug, Deserialize)]
struct PlannerOutput {
    /// Requested tool invocations; only the first one is acted upon.
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
    /// Plain-text answer, used when no tool call is present.
    #[serde(default)]
    response: Option<String>,
}
36
37#[async_trait]
38impl PlannerHandle for LlmBackedPlanner {
39    async fn plan(
40        &self,
41        context: PlannerContext,
42        _state: Arc<AgentStateSnapshot>,
43    ) -> anyhow::Result<PlannerDecision> {
44        let request = LlmRequest {
45            system_prompt: context.system_prompt.clone(),
46            messages: context.history.clone(),
47        };
48        let response = self.model.generate(request).await?;
49        let message = response.message;
50
51        match parse_planner_output(&message)? {
52            PlannerOutputVariant::ToolCall { name, args } => Ok(PlannerDecision {
53                next_action: PlannerAction::CallTool {
54                    tool_name: name,
55                    payload: args,
56                },
57            }),
58            PlannerOutputVariant::Respond(text) => Ok(PlannerDecision {
59                next_action: PlannerAction::Respond {
60                    message: AgentMessage {
61                        role: MessageRole::Agent,
62                        content: MessageContent::Text(text),
63                        metadata: message.metadata,
64                    },
65                },
66            }),
67        }
68    }
69}
70
71enum PlannerOutputVariant {
72    ToolCall { name: String, args: Value },
73    Respond(String),
74}
75
76fn parse_planner_output(message: &AgentMessage) -> anyhow::Result<PlannerOutputVariant> {
77    match &message.content {
78        MessageContent::Json(value) => parse_from_value(value.clone()),
79        MessageContent::Text(text) => {
80            // Try to parse JSON even when returned as text, optionally in code fences.
81            if let Some(parsed) = parse_from_text(text) {
82                if let Some(tc) = parsed.tool_calls.first() {
83                    return Ok(PlannerOutputVariant::ToolCall {
84                        name: tc.name.clone(),
85                        args: tc.args.clone(),
86                    });
87                }
88                if let Some(resp) = parsed.response {
89                    return Ok(PlannerOutputVariant::Respond(resp));
90                }
91            }
92            Ok(PlannerOutputVariant::Respond(text.clone()))
93        }
94    }
95}
96
97fn parse_from_value(value: Value) -> anyhow::Result<PlannerOutputVariant> {
98    let parsed: PlannerOutput = serde_json::from_value(value)?;
99    if let Some(tool_call) = parsed.tool_calls.first() {
100        Ok(PlannerOutputVariant::ToolCall {
101            name: tool_call.name.clone(),
102            args: tool_call.args.clone(),
103        })
104    } else if let Some(response) = parsed.response {
105        Ok(PlannerOutputVariant::Respond(response))
106    } else {
107        anyhow::bail!("LLM response missing tool call and response fields")
108    }
109}
110
111fn parse_from_text(text: &str) -> Option<PlannerOutput> {
112    // 1) Raw JSON
113    if let Some(parsed) = decode_output_from_str(text) {
114        return Some(parsed);
115    }
116    // 2) Remove common code fences ```json ... ``` or ``` ... ```
117    let trimmed = text.trim();
118    if trimmed.starts_with("```") {
119        let without_ticks = trimmed.trim_start_matches("```");
120        // optional language tag (e.g., json)
121        let without_lang = without_ticks
122            .trim_start_matches(|c: char| c.is_alphabetic())
123            .trim_start();
124        let inner = if let Some(end) = without_lang.rfind("```") {
125            &without_lang[..end]
126        } else {
127            without_lang
128        };
129        if let Some(parsed) = decode_output_from_str(inner) {
130            return Some(parsed);
131        }
132    }
133    None
134}
135
136/// Attempt to decode PlannerOutput from a JSON string; returns None on failure.
137fn decode_output_from_str(s: &str) -> Option<PlannerOutput> {
138    serde_json::from_str::<Value>(s)
139        .ok()
140        .and_then(|v| serde_json::from_value::<PlannerOutput>(v).ok())
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::llm::{LanguageModel, LlmResponse};
    use agents_core::messaging::MessageMetadata;
    use async_trait::async_trait;

    /// Builds an agent-authored message with the given content and no metadata.
    fn agent_message(content: MessageContent) -> AgentMessage {
        AgentMessage {
            role: MessageRole::Agent,
            content,
            metadata: None,
        }
    }

    /// Model stub that replies with the last history message, or with an empty
    /// text message when the history is empty.
    struct EchoModel;

    #[async_trait]
    impl LanguageModel for EchoModel {
        async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: request
                    .messages
                    .last()
                    .cloned()
                    .unwrap_or_else(|| agent_message(MessageContent::Text("".into()))),
            })
        }
    }

    #[tokio::test]
    async fn planner_falls_back_to_text_response() {
        let planner = LlmBackedPlanner::new(Arc::new(EchoModel));
        let context = PlannerContext {
            history: vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
            system_prompt: "Be helpful".into(),
        };

        let decision = planner
            .plan(context, Arc::new(AgentStateSnapshot::default()))
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::Respond { message } => match message.content {
                MessageContent::Text(text) => assert_eq!(text, "Hi"),
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected respond"),
        }
    }

    /// Model stub that always requests a single `write_file` tool call.
    struct ToolCallModel;

    #[async_trait]
    impl LanguageModel for ToolCallModel {
        async fn generate(&self, _request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": [
                            {
                                "name": "write_file",
                                "args": { "path": "notes.txt" }
                            }
                        ]
                    })),
                    metadata: Some(MessageMetadata {
                        tool_call_id: Some("call-1".into()),
                        cache_control: None,
                    }),
                },
            })
        }
    }

    #[tokio::test]
    async fn planner_parses_tool_call() {
        let planner = LlmBackedPlanner::new(Arc::new(ToolCallModel));
        let decision = planner
            .plan(
                PlannerContext {
                    history: vec![],
                    system_prompt: "System".into(),
                },
                Arc::new(AgentStateSnapshot::default()),
            )
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::CallTool { tool_name, payload } => {
                assert_eq!(tool_name, "write_file");
                assert_eq!(payload["path"], "notes.txt");
            }
            _ => panic!("expected tool call"),
        }
    }

    #[test]
    fn parses_tool_call_from_fenced_text() {
        let message = agent_message(MessageContent::Text(
            "```json\n{\"tool_calls\":[{\"name\":\"ls\",\"args\":{\"path\":\"/\"}}]}\n```".into(),
        ));
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "ls");
                assert_eq!(args["path"], "/");
            }
            PlannerOutputVariant::Respond(text) => panic!("expected tool call, got {text}"),
        }
    }

    #[test]
    fn parses_response_field_from_raw_json_text() {
        let message = agent_message(MessageContent::Text(r#"{"response": "done"}"#.into()));
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, "done"),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name}")
            }
        }
    }

    #[test]
    fn text_json_without_known_fields_falls_back_to_raw_text() {
        let message = agent_message(MessageContent::Text("{}".into()));
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, "{}"),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name}")
            }
        }
    }

    #[test]
    fn first_tool_call_wins_over_later_calls_and_response() {
        let message = agent_message(MessageContent::Json(serde_json::json!({
            "tool_calls": [
                { "name": "first", "args": { "n": 1 } },
                { "name": "second" }
            ],
            "response": "ignored"
        })));
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "first");
                assert_eq!(args, serde_json::json!({ "n": 1 }));
            }
            PlannerOutputVariant::Respond(text) => panic!("expected tool call, got {text}"),
        }
    }

    #[test]
    fn json_without_tool_call_or_response_is_error() {
        let message = agent_message(MessageContent::Json(serde_json::json!({})));
        assert!(parse_planner_output(&message).is_err());
    }
}