1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role};
5
/// Default model identifier for [`ModelRole::Main`].
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for [`ModelRole::Plan`] (currently the same model as main).
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for [`ModelRole::Compress`].
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Default model identifier for [`ModelRole::Fast`] (currently the same model as compress).
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum ModelRole {
16 Main,
18 Plan,
20 Compress,
22 Fast,
24}
25
26impl ModelRole {
27 pub fn default_model(&self) -> &'static str {
28 match self {
29 ModelRole::Main => DEFAULT_MAIN_MODEL,
30 ModelRole::Plan => DEFAULT_PLAN_MODEL,
31 ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
32 ModelRole::Fast => DEFAULT_FAST_MODEL,
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ModelConfig {
40 pub name: String,
42 pub max_tokens: u32,
44 pub think: bool,
46 pub context_size: Option<u32>,
48}
49
50impl ModelConfig {
51 pub fn new(name: String) -> Self {
52 Self {
53 name: name.clone(),
54 max_tokens: 16384,
55 think: true,
56 context_size: infer_context_size(&name),
57 }
58 }
59
60 pub fn for_role(role: ModelRole) -> Self {
62 let name = role.default_model().to_string();
63 match role {
64 ModelRole::Main => Self::new(name),
65 ModelRole::Plan => Self::new(name),
66 ModelRole::Compress => Self {
67 name,
68 max_tokens: 1024,
69 think: false,
70 context_size: Some(200_000),
71 },
72 ModelRole::Fast => Self {
73 name,
74 max_tokens: 2048,
75 think: false,
76 context_size: Some(200_000),
77 },
78 }
79 }
80
81 pub fn display_name(&self) -> &str {
82 &self.name
83 }
84}
85
/// Guesses a model's context window (in tokens) from its name.
///
/// Matching is case-insensitive and substring-based. Rules are tried in order,
/// so the more specific patterns (1M-context Claude variants, `gpt-4o`,
/// `gpt-4-32k`, versioned DeepSeek models) come before the broader family
/// fallbacks. Returns `None` for unrecognised names.
fn infer_context_size(model: &str) -> Option<u32> {
    let lower = model.to_ascii_lowercase();
    let has = |needle: &str| lower.contains(needle);

    let tokens = if has("opus-4-7") || has("opus-4.7") || has("[1m]") {
        1_000_000
    } else if has("claude") {
        200_000
    } else if has("gpt-4o") || has("gpt-4-turbo") {
        128_000
    } else if has("o1") {
        200_000
    } else if has("gpt-4-32k") {
        32_768
    } else if has("gpt-4") && !has("turbo") && !has("o") {
        8_192
    } else if has("gpt-3.5") {
        16_384
    } else if has("deepseek-v3") || has("deepseek-r1") {
        128_000
    } else if has("deepseek") {
        64_000
    } else {
        return None;
    };

    Some(tokens)
}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct MultiModelConfig {
129 pub main: ModelConfig,
131 pub plan: ModelConfig,
133 pub compress: ModelConfig,
135 pub fast: ModelConfig,
137}
138
139impl Default for MultiModelConfig {
140 fn default() -> Self {
141 Self {
142 main: ModelConfig::for_role(ModelRole::Main),
143 plan: ModelConfig::for_role(ModelRole::Plan),
144 compress: ModelConfig::for_role(ModelRole::Compress),
145 fast: ModelConfig::for_role(ModelRole::Fast),
146 }
147 }
148}
149
150impl MultiModelConfig {
151 pub fn with_main(main_model: String) -> Self {
154 let main_config = ModelConfig::new(main_model);
155 Self {
156 main: main_config.clone(),
157 plan: main_config.clone(),
158 compress: main_config.clone(),
159 fast: main_config,
160 }
161 }
162
163 pub fn unified(model: String) -> Self {
165 let config = ModelConfig::new(model);
166 Self {
167 main: config.clone(),
168 plan: config.clone(),
169 compress: config.clone(),
170 fast: config,
171 }
172 }
173
174 pub fn get(&self, role: ModelRole) -> &ModelConfig {
176 match role {
177 ModelRole::Main => &self.main,
178 ModelRole::Plan => &self.plan,
179 ModelRole::Compress => &self.compress,
180 ModelRole::Fast => &self.fast,
181 }
182 }
183
184 pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
186 match role {
187 ModelRole::Main => self.main = config,
188 ModelRole::Plan => self.plan = config,
189 ModelRole::Compress => self.compress = config,
190 ModelRole::Fast => self.fast = config,
191 }
192 }
193
194 pub fn format_summary(&self) -> String {
196 format!(
197 "main: {}, plan: {}, compress: {}, fast: {}",
198 self.main.display_name(),
199 self.plan.display_name(),
200 self.compress.display_name(),
201 self.fast.display_name()
202 )
203 }
204}
205
206#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
208#[serde(rename_all = "snake_case")]
209pub enum TaskComplexity {
210 Simple,
211 Moderate,
212 Complex,
213}
214
215impl TaskComplexity {
216 pub fn display(&self) -> &'static str {
217 match self {
218 TaskComplexity::Simple => "简单",
219 TaskComplexity::Moderate => "中等",
220 TaskComplexity::Complex => "复杂",
221 }
222 }
223}
224
225#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
227#[serde(rename_all = "snake_case")]
228pub enum StepDifficulty {
229 Easy,
230 Medium,
231 Hard,
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct PlanStep {
237 pub description: String,
239 pub tools: Vec<String>,
241 pub optional: bool,
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct TaskPlan {
248 pub request: String,
250 pub steps: Vec<PlanStep>,
252 pub complexity: TaskComplexity,
254 pub approach: String,
256 pub considerations: Vec<String>,
258}
259
260impl TaskPlan {
261 pub fn format(&self) -> String {
263 let mut output = String::new();
264
265 output.push_str(&format!("任务分析: {}\n", self.request));
266 output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
267 output.push_str(&format!("建议方案: {}\n\n", self.approach));
268
269 output.push_str("执行步骤:\n");
270 for (i, step) in self.steps.iter().enumerate() {
271 let marker = if step.optional { "[可选]" } else { "" };
272 output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
273 if !step.tools.is_empty() {
274 output.push_str(&format!(" 工具: {}\n", step.tools.join(", ")));
275 }
276 }
277
278 if !self.considerations.is_empty() {
279 output.push_str("\n注意事项:\n");
280 for c in &self.considerations {
281 output.push_str(&format!("• {}\n", c));
282 }
283 }
284
285 output
286 }
287
288 pub fn to_todo_items(&self) -> Vec<TodoItem> {
290 self.steps
291 .iter()
292 .enumerate()
293 .map(|(i, step)| TodoItem {
294 content: step.description.clone(),
295 active_form: format!("执行步骤 {}: {}", i + 1, step.description),
296 status: if i == 0 { "in_progress".to_string() } else { "pending".to_string() },
297 })
298 .collect()
299 }
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct TodoItem {
305 pub content: String,
306 pub active_form: String,
307 pub status: String,
308}
309
310pub struct Planner {
312 provider: Box<dyn Provider>,
313 config: ModelConfig,
314}
315
316impl Planner {
317 pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
319 Self { provider, config }
320 }
321
322 pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
324 let prompt = build_plan_prompt(request, available_tools);
325
326 let chat_request = ChatRequest {
327 messages: vec![Message {
328 role: Role::User,
329 content: MessageContent::Text(prompt),
330 }],
331 tools: vec![],
332 system: Some(PLAN_SYSTEM_PROMPT.to_string()),
333 think: false,
334 max_tokens: self.config.max_tokens,
335 server_tools: vec![],
336 enable_caching: false,
337 };
338
339 let response = self.provider.chat(chat_request).await?;
340 let text = extract_text(&response);
341
342 parse_plan_response(request, &text)
343 }
344
345 pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
347 let prompt = format!(
348 "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
349 request
350 );
351
352 let chat_request = ChatRequest {
353 messages: vec![Message {
354 role: Role::User,
355 content: MessageContent::Text(prompt),
356 }],
357 tools: vec![],
358 system: None,
359 think: false,
360 max_tokens: 50,
361 server_tools: vec![],
362 enable_caching: false,
363 };
364
365 let response = self.provider.chat(chat_request).await?;
366 let text = extract_text(&response).to_lowercase();
367
368 if text.contains("简单") || text.contains("simple") {
369 Ok(TaskComplexity::Simple)
370 } else if text.contains("复杂") || text.contains("complex") {
371 Ok(TaskComplexity::Complex)
372 } else {
373 Ok(TaskComplexity::Moderate)
374 }
375 }
376}
377
/// System prompt sent with every `Planner::plan` request.
///
/// English summary of its content:
/// - It casts the model as a task-planning assistant that breaks programming
///   tasks into clear execution steps.
/// - It requires a JSON reply with `complexity` (`simple|moderate|complex`), a
///   one-sentence `approach`, a `steps` array (`description`, `tools`,
///   `optional`) and `considerations`.
/// - Planning rules: simple tasks take 1-2 steps and moderate tasks 3-5 steps;
///   complex tasks need detailed planning; every step must be concrete; optional
///   steps and risks should be flagged.
///
/// NOTE(review): `parse_plan_response` reads exactly these JSON keys. Keep the
/// schema in sync when editing this prompt.
378const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。
380
381输出要求(JSON格式):
382```json
383{
384 "complexity": "simple|moderate|complex",
385 "approach": "建议的方案(一句话)",
386 "steps": [
387 {
388 "description": "步骤描述",
389 "tools": ["需要的工具"],
390 "optional": false
391 }
392 ],
393 "considerations": ["注意事项"]
394}
395```
396
397规划原则:
3981. 简单任务(如读取文件、简单查询)只需1-2步
3992. 中等任务(如修改代码、添加功能)需要3-5步
4003. 复杂任务(如重构、跨模块修改)需要详细规划
4014. 每个步骤要具体、可执行
4025. 标记可选步骤和潜在风险"#;
403
/// Builds the user message for a planning request.
///
/// The message contains the raw request, the comma-separated list of available
/// tools, and a closing instruction asking for a JSON plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");
    let mut prompt = String::from("用户请求:\n");
    prompt.push_str(request);
    prompt.push_str("\n\n可用工具:\n");
    prompt.push_str(&tool_list);
    prompt.push_str("\n\n请分析任务并生成执行计划(JSON格式)。");
    prompt
}
418
419fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
421 if let Some(json_start) = text.find('{')
423 && let Some(json_end) = text.rfind('}') {
424 let json_str = &text[json_start..=json_end];
425 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
426 return Ok(TaskPlan {
427 request: request.to_string(),
428 steps: parse_steps(&parsed["steps"]),
429 complexity: parse_complexity(&parsed["complexity"]),
430 approach: parsed["approach"].as_str().unwrap_or("直接执行").to_string(),
431 considerations: parsed["considerations"]
432 .as_array()
433 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
434 .unwrap_or_default(),
435 });
436 }
437 }
438
439 Ok(TaskPlan {
441 request: request.to_string(),
442 steps: parse_steps_from_text(text),
443 complexity: TaskComplexity::Moderate,
444 approach: "按步骤执行".to_string(),
445 considerations: vec!["请检查执行结果".to_string()],
446 })
447}
448
449fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
450 value.as_array()
451 .map(|arr| arr.iter().filter_map(|v| {
452 Some(PlanStep {
453 description: v["description"].as_str()?.to_string(),
454 tools: v["tools"].as_array()
455 .map(|t| t.iter().filter_map(|x| x.as_str().map(String::from)).collect())
456 .unwrap_or_default(),
457 optional: v["optional"].as_bool().unwrap_or(false),
458 })
459 }).collect())
460 .unwrap_or_default()
461}
462
463fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
464 match value.as_str().map(|s| s.to_lowercase()) {
465 Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
466 Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
467 _ => TaskComplexity::Moderate,
468 }
469}
470
471fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
472 text.lines()
473 .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
474 .take(5)
475 .map(|l| PlanStep {
476 description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
477 tools: vec!["read".to_string()],
478 optional: false,
479 })
480 .collect()
481}
482
483fn extract_text(response: &ChatResponse) -> String {
484 response.content
485 .iter()
486 .filter_map(|block| {
487 if let ContentBlock::Text { text } = block {
488 Some(text.clone())
489 } else {
490 None
491 }
492 })
493 .collect::<Vec<_>>()
494 .join("\n")
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
    }

    #[test]
    fn test_infer_context_size() {
        assert_eq!(infer_context_size("claude-sonnet-4"), Some(200_000));
        assert_eq!(infer_context_size("gpt-4o"), Some(128_000));
        assert_eq!(infer_context_size("claude-3-5-haiku"), Some(200_000));
    }

    #[test]
    fn test_infer_context_size_other_families() {
        assert_eq!(infer_context_size("claude-opus-4-7"), Some(1_000_000));
        assert_eq!(infer_context_size("Claude-Sonnet-4[1M]"), Some(1_000_000));
        assert_eq!(infer_context_size("o1-preview"), Some(200_000));
        assert_eq!(infer_context_size("gpt-4-32k"), Some(32_768));
        assert_eq!(infer_context_size("gpt-4"), Some(8_192));
        assert_eq!(infer_context_size("gpt-3.5-turbo"), Some(16_384));
        assert_eq!(infer_context_size("deepseek-r1"), Some(128_000));
        assert_eq!(infer_context_size("deepseek-coder"), Some(64_000));
        assert_eq!(infer_context_size("llama-3-70b"), None);
    }

    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    #[test]
    fn test_multi_model_config_with_main() {
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-sonnet-4");
        assert_eq!(config.fast.name, "claude-sonnet-4");

        assert!(config.main.think);
        assert!(config.plan.think);
        assert!(config.compress.think);
        assert!(config.fast.think);
    }

    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        config.set(ModelRole::Compress, ModelConfig::new("claude-3-5-haiku".to_string()));

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4");
    }

    #[test]
    fn test_multi_model_config_get_and_summary() {
        let mut config = MultiModelConfig::default();
        assert_eq!(
            config.format_summary(),
            "main: claude-sonnet-4-20250514, plan: claude-sonnet-4-20250514, \
             compress: claude-3-5-haiku-20241022, fast: claude-3-5-haiku-20241022"
        );

        config.set(ModelRole::Fast, ModelConfig::new("gpt-4o".to_string()));
        assert_eq!(config.get(ModelRole::Fast).name, "gpt-4o");
        assert_eq!(config.get(ModelRole::Fast).context_size, Some(128_000));
        assert_eq!(config.get(ModelRole::Main).name, DEFAULT_MAIN_MODEL);
    }

    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![PlanStep {
                description: "读取文件".to_string(),
                tools: vec!["read".to_string()],
                optional: false,
            }],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        assert!(formatted.contains("测试任务"));
        assert!(formatted.contains("简单"));
        assert!(formatted.contains("读取文件"));
    }

    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![
                PlanStep { description: "步骤1".to_string(), tools: vec![], optional: false },
                PlanStep { description: "步骤2".to_string(), tools: vec![], optional: false },
            ],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[0].active_form, "执行步骤 1: 步骤1");
        assert_eq!(todos[1].status, "pending");
    }

    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }

    #[test]
    fn test_parse_plan_response_json_inside_prose() {
        let text = r#"好的,计划如下:
```json
{"complexity":"complex","approach":"分步重构","steps":[{"description":"a","tools":["read","edit"],"optional":true},{"tools":["read"]}],"considerations":["备份"]}
```"#;
        let plan = parse_plan_response("refactor", text).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Complex);
        assert_eq!(plan.approach, "分步重构");
        // The step without a description is dropped.
        assert_eq!(plan.steps.len(), 1);
        assert!(plan.steps[0].optional);
        assert_eq!(plan.steps[0].tools, vec!["read", "edit"]);
        assert_eq!(plan.considerations, vec!["备份"]);
    }

    #[test]
    fn test_parse_plan_response_text_fallback() {
        let plan = parse_plan_response("task", "1. 读取文件\n说明文字\n2. 修改代码").unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Moderate);
        assert_eq!(plan.approach, "按步骤执行");
        let descriptions: Vec<&str> = plan.steps.iter().map(|s| s.description.as_str()).collect();
        assert_eq!(descriptions, ["读取文件", "修改代码"]);
    }

    #[test]
    fn test_parse_complexity() {
        use serde_json::Value;

        assert_eq!(parse_complexity(&Value::String("Simple".into())), TaskComplexity::Simple);
        assert_eq!(parse_complexity(&Value::String("复杂".into())), TaskComplexity::Complex);
        assert_eq!(parse_complexity(&Value::String("moderate".into())), TaskComplexity::Moderate);
        assert_eq!(parse_complexity(&Value::Null), TaskComplexity::Moderate);
    }

    #[test]
    fn test_build_plan_prompt() {
        let prompt = build_plan_prompt("修复bug", &["read", "edit"]);
        assert!(prompt.contains("修复bug"));
        assert!(prompt.contains("read, edit"));
    }
}
614}