agents_runtime/
planner.rs

1use std::sync::Arc;
2
3use agents_core::agent::{PlannerAction, PlannerContext, PlannerDecision, PlannerHandle};
4use agents_core::llm::{LanguageModel, LlmRequest};
5use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
6use agents_core::state::AgentStateSnapshot;
7use async_trait::async_trait;
8use serde::Deserialize;
9use serde_json::Value;
10
11#[derive(Clone)]
12pub struct LlmBackedPlanner {
13    model: Arc<dyn LanguageModel>,
14}
15
16impl LlmBackedPlanner {
17    pub fn new(model: Arc<dyn LanguageModel>) -> Self {
18        Self { model }
19    }
20
21    /// Get the underlying language model for direct access (e.g., streaming)
22    pub fn model(&self) -> &Arc<dyn LanguageModel> {
23        &self.model
24    }
25}
26
27#[derive(Debug, Deserialize)]
28struct ToolCall {
29    name: String,
30    #[serde(default)]
31    args: Value,
32}
33
34#[derive(Debug, Deserialize)]
35struct PlannerOutput {
36    #[serde(default)]
37    tool_calls: Vec<ToolCall>,
38    #[serde(default)]
39    response: Option<String>,
40}
41
42#[async_trait]
43impl PlannerHandle for LlmBackedPlanner {
44    async fn plan(
45        &self,
46        context: PlannerContext,
47        _state: Arc<AgentStateSnapshot>,
48    ) -> anyhow::Result<PlannerDecision> {
49        let request = LlmRequest::new(context.system_prompt.clone(), context.history.clone());
50        let response = self.model.generate(request).await?;
51        let message = response.message;
52
53        match parse_planner_output(&message)? {
54            PlannerOutputVariant::ToolCall { name, args } => Ok(PlannerDecision {
55                next_action: PlannerAction::CallTool {
56                    tool_name: name,
57                    payload: args,
58                },
59            }),
60            PlannerOutputVariant::Respond(text) => Ok(PlannerDecision {
61                next_action: PlannerAction::Respond {
62                    message: AgentMessage {
63                        role: MessageRole::Agent,
64                        content: MessageContent::Text(text),
65                        metadata: message.metadata,
66                    },
67                },
68            }),
69        }
70    }
71
72    fn as_any(&self) -> &dyn std::any::Any {
73        self
74    }
75}
76
77enum PlannerOutputVariant {
78    ToolCall { name: String, args: Value },
79    Respond(String),
80}
81
82fn parse_planner_output(message: &AgentMessage) -> anyhow::Result<PlannerOutputVariant> {
83    match &message.content {
84        MessageContent::Json(value) => parse_from_value(value.clone()),
85        MessageContent::Text(text) => {
86            // Try to parse JSON even when returned as text, optionally in code fences.
87            if let Some(parsed) = parse_from_text(text) {
88                if let Some(tc) = parsed.tool_calls.first() {
89                    return Ok(PlannerOutputVariant::ToolCall {
90                        name: tc.name.clone(),
91                        args: tc.args.clone(),
92                    });
93                }
94                if let Some(resp) = parsed.response {
95                    return Ok(PlannerOutputVariant::Respond(resp));
96                }
97            }
98            Ok(PlannerOutputVariant::Respond(text.clone()))
99        }
100    }
101}
102
103fn parse_from_value(value: Value) -> anyhow::Result<PlannerOutputVariant> {
104    let parsed: PlannerOutput = serde_json::from_value(value)?;
105    if let Some(tool_call) = parsed.tool_calls.first() {
106        Ok(PlannerOutputVariant::ToolCall {
107            name: tool_call.name.clone(),
108            args: tool_call.args.clone(),
109        })
110    } else if let Some(response) = parsed.response {
111        Ok(PlannerOutputVariant::Respond(response))
112    } else {
113        anyhow::bail!("LLM response missing tool call and response fields")
114    }
115}
116
117fn parse_from_text(text: &str) -> Option<PlannerOutput> {
118    // 1) Raw JSON
119    if let Some(parsed) = decode_output_from_str(text) {
120        return Some(parsed);
121    }
122    // 2) Remove common code fences ```json ... ``` or ``` ... ```
123    let trimmed = text.trim();
124    if trimmed.starts_with("```") {
125        let without_ticks = trimmed.trim_start_matches("```");
126        // optional language tag (e.g., json)
127        let without_lang = without_ticks
128            .trim_start_matches(|c: char| c.is_alphabetic())
129            .trim_start();
130        let inner = if let Some(end) = without_lang.rfind("```") {
131            &without_lang[..end]
132        } else {
133            without_lang
134        };
135        if let Some(parsed) = decode_output_from_str(inner) {
136            return Some(parsed);
137        }
138    }
139    None
140}
141
142/// Attempt to decode PlannerOutput from a JSON string; returns None on failure.
143fn decode_output_from_str(s: &str) -> Option<PlannerOutput> {
144    serde_json::from_str::<Value>(s)
145        .ok()
146        .and_then(|v| serde_json::from_value::<PlannerOutput>(v).ok())
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::llm::{LanguageModel, LlmResponse};
    use agents_core::messaging::MessageMetadata;
    use async_trait::async_trait;

    /// Model stub that answers with the last message of the request, or with an empty agent
    /// text message when the request carries no messages.
    struct EchoModel;

    #[async_trait]
    impl LanguageModel for EchoModel {
        async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
            let message = match request.messages.last() {
                Some(last) => last.clone(),
                None => AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("".into()),
                    metadata: None,
                },
            };
            Ok(LlmResponse { message })
        }
    }

    /// Plain text that is not planner JSON must come back verbatim as a `Respond` action.
    #[tokio::test]
    async fn planner_falls_back_to_text_response() {
        let user_message = AgentMessage {
            role: MessageRole::User,
            content: MessageContent::Text("Hi".into()),
            metadata: None,
        };
        let context = PlannerContext {
            history: vec![user_message],
            system_prompt: "Be helpful".into(),
        };
        let state = Arc::new(AgentStateSnapshot::default());
        let planner = LlmBackedPlanner::new(Arc::new(EchoModel));

        let decision = planner.plan(context, state).await.unwrap();

        if let PlannerAction::Respond { message } = decision.next_action {
            match message.content {
                MessageContent::Text(text) => assert_eq!(text, "Hi"),
                other => panic!("expected text, got {other:?}"),
            }
        } else {
            panic!("expected respond");
        }
    }

    /// Model stub that always requests a single `write_file` tool call via JSON content.
    struct ToolCallModel;

    #[async_trait]
    impl LanguageModel for ToolCallModel {
        async fn generate(&self, _request: LlmRequest) -> anyhow::Result<LlmResponse> {
            let payload = serde_json::json!({
                "tool_calls": [
                    {
                        "name": "write_file",
                        "args": { "path": "notes.txt" }
                    }
                ]
            });
            let metadata = MessageMetadata {
                tool_call_id: Some("call-1".into()),
                cache_control: None,
            };
            Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(payload),
                    metadata: Some(metadata),
                },
            })
        }
    }

    /// A JSON reply containing `tool_calls` must be surfaced as a `CallTool` action.
    #[tokio::test]
    async fn planner_parses_tool_call() {
        let context = PlannerContext {
            history: vec![],
            system_prompt: "System".into(),
        };
        let state = Arc::new(AgentStateSnapshot::default());
        let planner = LlmBackedPlanner::new(Arc::new(ToolCallModel));

        let decision = planner.plan(context, state).await.unwrap();

        if let PlannerAction::CallTool { tool_name, payload } = decision.next_action {
            assert_eq!(tool_name, "write_file");
            assert_eq!(payload["path"], "notes.txt");
        } else {
            panic!("expected tool call");
        }
    }
}
244}