1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::constants::{COMPRESS_MAX_TOKENS, DEFAULT_MAX_TOKENS, FAST_MAX_TOKENS};
5use crate::providers::{
6 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
7};
8
/// Model name returned by `ModelRole::Main.default_model()`.
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Model name returned by `ModelRole::Plan.default_model()`; currently the same
/// model as [`DEFAULT_MAIN_MODEL`].
pub const DEFAULT_PLAN_MODEL: &str = DEFAULT_MAIN_MODEL;
/// Model name returned by `ModelRole::Compress.default_model()`.
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Model name returned by `ModelRole::Fast.default_model()`; currently the same
/// model as [`DEFAULT_COMPRESS_MODEL`].
pub const DEFAULT_FAST_MODEL: &str = DEFAULT_COMPRESS_MODEL;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ModelRole {
19 Main,
21 Plan,
23 Compress,
25 Fast,
27}
28
29impl ModelRole {
30 pub fn default_model(&self) -> &'static str {
31 match self {
32 ModelRole::Main => DEFAULT_MAIN_MODEL,
33 ModelRole::Plan => DEFAULT_PLAN_MODEL,
34 ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
35 ModelRole::Fast => DEFAULT_FAST_MODEL,
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelConfig {
43 pub name: String,
45 pub max_tokens: u32,
47 pub think: bool,
49 pub context_size: Option<u32>,
51}
52
53impl ModelConfig {
54 pub fn new(name: String) -> Self {
55 Self {
56 name: name.clone(),
57 max_tokens: DEFAULT_MAX_TOKENS,
58 think: true,
59 context_size: infer_context_size(&name),
60 }
61 }
62
63 pub fn for_role(role: ModelRole) -> Self {
65 let name = role.default_model().to_string();
66 match role {
67 ModelRole::Main => Self::new(name),
68 ModelRole::Plan => Self::new(name),
69 ModelRole::Compress => Self {
70 name,
71 max_tokens: COMPRESS_MAX_TOKENS,
72 think: false,
73 context_size: Some(200_000),
74 },
75 ModelRole::Fast => Self {
76 name,
77 max_tokens: FAST_MAX_TOKENS,
78 think: false,
79 context_size: Some(200_000),
80 },
81 }
82 }
83
84 pub fn display_name(&self) -> &str {
85 &self.name
86 }
87}
88
/// Looks up the context window size for `model`.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable overrides every
/// heuristic below. Otherwise the ASCII-lowercased model name is matched against
/// known substrings, and the first rule that matches wins, so more specific
/// rules are listed before more general ones.
///
/// Returns `None` when there is no usable override and no rule matches.
pub fn context_window_for(model: &str) -> Option<u32> {
    // An unset, unparsable, or zero override is ignored.
    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if env_override.is_some() {
        return env_override;
    }

    let m = model.to_ascii_lowercase();
    let has = |needle: &str| m.contains(needle);

    // Claude family (the "1m" marker is checked first and applies to any name).
    if has("1m") || has("opus-4-7") || has("opus-4.7") {
        return Some(1_000_000);
    }
    if has("claude-3")
        || has("claude-4")
        || has("claude-opus")
        || has("claude-sonnet")
        || has("claude-haiku")
    {
        return Some(200_000);
    }
    if has("claude-2") || has("claude-instant") {
        return Some(100_000);
    }

    // o-series and GPT family.
    if has("o1") || has("o3") || has("o4") {
        return Some(200_000);
    }
    if has("gpt-4o") || has("gpt-4-turbo") || has("gpt-4.1") {
        return Some(128_000);
    }
    if has("gpt-4-32k") {
        return Some(32_768);
    }
    // Base GPT-4. Only the "turbo" and "4o" variants are excluded; rejecting
    // every name that contains the letter "o" also rejected provider-prefixed
    // names such as "openai/gpt-4".
    if has("gpt-4") && !has("turbo") && !has("4o") {
        return Some(8_192);
    }
    if has("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if has("gpt-3.5") {
        return Some(4_096);
    }

    // Gemini family.
    if (has("gemini-2") || has("gemini-1.5")) && has("pro") {
        return Some(2_000_000);
    }
    if has("gemini-1.5") {
        return Some(1_000_000);
    }
    if has("gemini") {
        return Some(32_000);
    }

    // GLM family ("glm-4" and every other "glm" name share the same size).
    if has("glm-5") || has("glm-4-long") {
        return Some(1_000_000);
    }
    if has("glm") {
        return Some(128_000);
    }

    // DeepSeek family.
    if has("deepseek-v3") || has("deepseek-r1") {
        return Some(128_000);
    }
    if has("deepseek") {
        return Some(64_000);
    }

    // Qwen family.
    if has("qwen-long") || has("qwen2.5-turbo") || has("qwen-turbo") {
        return Some(1_000_000);
    }
    if has("qwen2.5") || has("qwen3") || has("qwen-max") {
        return Some(128_000);
    }
    if has("qwen2") {
        return Some(32_000);
    }
    if has("qwen") {
        return Some(8_192);
    }

    // Kimi / Moonshot.
    if has("kimi") || has("moonshot") {
        return Some(2_000_000);
    }

    // Llama 3: only the 70b and 405b sizes get the large window.
    if has("llama-3") || has("llama3") {
        let large = has("70b") || has("405b");
        return Some(if large { 128_000 } else { 8_192 });
    }

    // Mistral family ("mistral-medium" shares the generic size).
    if has("mistral") {
        return Some(if has("large") { 128_000 } else { 32_000 });
    }

    None
}
280
281fn infer_context_size(model: &str) -> Option<u32> {
283 context_window_for(model)
284}
285
286#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct MultiModelConfig {
289 pub main: ModelConfig,
291 pub plan: ModelConfig,
293 pub compress: ModelConfig,
295 pub fast: ModelConfig,
297}
298
299impl Default for MultiModelConfig {
300 fn default() -> Self {
301 Self {
302 main: ModelConfig::for_role(ModelRole::Main),
303 plan: ModelConfig::for_role(ModelRole::Plan),
304 compress: ModelConfig::for_role(ModelRole::Compress),
305 fast: ModelConfig::for_role(ModelRole::Fast),
306 }
307 }
308}
309
310impl MultiModelConfig {
311 pub fn with_main(main_model: String) -> Self {
314 let main_config = ModelConfig::new(main_model);
315 Self {
316 main: main_config.clone(),
317 plan: main_config.clone(),
318 compress: main_config.clone(),
319 fast: main_config,
320 }
321 }
322
323 pub fn unified(model: String) -> Self {
325 let config = ModelConfig::new(model);
326 Self {
327 main: config.clone(),
328 plan: config.clone(),
329 compress: config.clone(),
330 fast: config,
331 }
332 }
333
334 pub fn get(&self, role: ModelRole) -> &ModelConfig {
336 match role {
337 ModelRole::Main => &self.main,
338 ModelRole::Plan => &self.plan,
339 ModelRole::Compress => &self.compress,
340 ModelRole::Fast => &self.fast,
341 }
342 }
343
344 pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
346 match role {
347 ModelRole::Main => self.main = config,
348 ModelRole::Plan => self.plan = config,
349 ModelRole::Compress => self.compress = config,
350 ModelRole::Fast => self.fast = config,
351 }
352 }
353
354 pub fn format_summary(&self) -> String {
356 format!(
357 "main: {}, plan: {}, compress: {}, fast: {}",
358 self.main.display_name(),
359 self.plan.display_name(),
360 self.compress.display_name(),
361 self.fast.display_name()
362 )
363 }
364}
365
366#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
368#[serde(rename_all = "snake_case")]
369pub enum TaskComplexity {
370 Simple,
371 Moderate,
372 Complex,
373}
374
375impl TaskComplexity {
376 pub fn display(&self) -> &'static str {
377 match self {
378 TaskComplexity::Simple => "简单",
379 TaskComplexity::Moderate => "中等",
380 TaskComplexity::Complex => "复杂",
381 }
382 }
383}
384
385#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
387#[serde(rename_all = "snake_case")]
388pub enum StepDifficulty {
389 Easy,
390 Medium,
391 Hard,
392}
393
394#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct PlanStep {
397 pub description: String,
399 pub tools: Vec<String>,
401 pub optional: bool,
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct TaskPlan {
408 pub request: String,
410 pub steps: Vec<PlanStep>,
412 pub complexity: TaskComplexity,
414 pub approach: String,
416 pub considerations: Vec<String>,
418}
419
420impl TaskPlan {
421 pub fn format(&self) -> String {
423 let mut output = String::new();
424
425 output.push_str(&format!("任务分析: {}\n", self.request));
426 output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
427 output.push_str(&format!("建议方案: {}\n\n", self.approach));
428
429 output.push_str("执行步骤:\n");
430 for (i, step) in self.steps.iter().enumerate() {
431 let marker = if step.optional { "[可选]" } else { "" };
432 output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
433 if !step.tools.is_empty() {
434 output.push_str(&format!(" 工具: {}\n", step.tools.join(", ")));
435 }
436 }
437
438 if !self.considerations.is_empty() {
439 output.push_str("\n注意事项:\n");
440 for c in &self.considerations {
441 output.push_str(&format!("• {}\n", c));
442 }
443 }
444
445 output
446 }
447
448 pub fn to_todo_items(&self) -> Vec<TodoItem> {
450 self.steps
451 .iter()
452 .enumerate()
453 .map(|(i, step)| TodoItem {
454 content: step.description.clone(),
455 active_form: format!("执行步骤 {}: {}", i + 1, step.description),
456 status: if i == 0 {
457 "in_progress".to_string()
458 } else {
459 "pending".to_string()
460 },
461 })
462 .collect()
463 }
464}
465
466#[derive(Debug, Clone, Serialize, Deserialize)]
468pub struct TodoItem {
469 pub content: String,
470 pub active_form: String,
471 pub status: String,
472}
473
474pub struct Planner {
476 provider: Box<dyn Provider>,
477 config: ModelConfig,
478}
479
480impl Planner {
481 pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
483 Self { provider, config }
484 }
485
486 pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
488 let prompt = build_plan_prompt(request, available_tools);
489
490 let chat_request = ChatRequest {
491 messages: vec![Message {
492 role: Role::User,
493 content: MessageContent::Text(prompt),
494 }],
495 tools: vec![],
496 system: Some(PLAN_SYSTEM_PROMPT.to_string()),
497 think: false,
498 max_tokens: self.config.max_tokens,
499 server_tools: vec![],
500 enable_caching: false,
501 };
502
503 let response = self.provider.chat(chat_request).await?;
504 let text = extract_text(&response);
505
506 parse_plan_response(request, &text)
507 }
508
509 pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
511 let prompt = format!(
512 "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
513 request
514 );
515
516 let chat_request = ChatRequest {
517 messages: vec![Message {
518 role: Role::User,
519 content: MessageContent::Text(prompt),
520 }],
521 tools: vec![],
522 system: None,
523 think: false,
524 max_tokens: 50,
525 server_tools: vec![],
526 enable_caching: false,
527 };
528
529 let response = self.provider.chat(chat_request).await?;
530 let text = extract_text(&response).to_lowercase();
531
532 if text.contains("简单") || text.contains("simple") {
533 Ok(TaskComplexity::Simple)
534 } else if text.contains("复杂") || text.contains("complex") {
535 Ok(TaskComplexity::Complex)
536 } else {
537 Ok(TaskComplexity::Moderate)
538 }
539 }
540}
541
/// System prompt sent by `Planner::plan`.
///
/// It tells the model to act as a task-planning assistant and to answer with a
/// JSON object holding `complexity`, `approach`, `steps` (each with
/// `description`, `tools`, `optional`) and `considerations`, which are the keys
/// `parse_plan_response` reads back.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
 "complexity": "simple|moderate|complex",
 "approach": "建议的方案(一句话)",
 "steps": [
 {
 "description": "步骤描述",
 "tools": ["需要的工具"],
 "optional": false
 }
 ],
 "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
567
/// Builds the user prompt for a planning request.
///
/// The prompt contains `request`, the comma-separated `available_tools`, and a
/// closing instruction asking for a JSON execution plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");
    format!(
        "用户请求:\n{}\n\n可用工具:\n{}\n\n请分析任务并生成执行计划(JSON格式)。",
        request, tool_list
    )
}
582
583fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
585 if let Some(json_start) = text.find('{')
587 && let Some(json_end) = text.rfind('}')
588 {
589 let json_str = &text[json_start..=json_end];
590 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
591 return Ok(TaskPlan {
592 request: request.to_string(),
593 steps: parse_steps(&parsed["steps"]),
594 complexity: parse_complexity(&parsed["complexity"]),
595 approach: parsed["approach"]
596 .as_str()
597 .unwrap_or("直接执行")
598 .to_string(),
599 considerations: parsed["considerations"]
600 .as_array()
601 .map(|arr| {
602 arr.iter()
603 .filter_map(|v| v.as_str().map(String::from))
604 .collect()
605 })
606 .unwrap_or_default(),
607 });
608 }
609 }
610
611 Ok(TaskPlan {
613 request: request.to_string(),
614 steps: parse_steps_from_text(text),
615 complexity: TaskComplexity::Moderate,
616 approach: "按步骤执行".to_string(),
617 considerations: vec!["请检查执行结果".to_string()],
618 })
619}
620
621fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
622 value
623 .as_array()
624 .map(|arr| {
625 arr.iter()
626 .filter_map(|v| {
627 Some(PlanStep {
628 description: v["description"].as_str()?.to_string(),
629 tools: v["tools"]
630 .as_array()
631 .map(|t| {
632 t.iter()
633 .filter_map(|x| x.as_str().map(String::from))
634 .collect()
635 })
636 .unwrap_or_default(),
637 optional: v["optional"].as_bool().unwrap_or(false),
638 })
639 })
640 .collect()
641 })
642 .unwrap_or_default()
643}
644
645fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
646 match value.as_str().map(|s| s.to_lowercase()) {
647 Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
648 Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
649 _ => TaskComplexity::Moderate,
650 }
651}
652
653fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
654 text.lines()
655 .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
656 .take(5)
657 .map(|l| PlanStep {
658 description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
659 tools: vec!["read".to_string()],
660 optional: false,
661 })
662 .collect()
663}
664
665fn extract_text(response: &ChatResponse) -> String {
666 response
667 .content
668 .iter()
669 .filter_map(|block| {
670 if let ContentBlock::Text { text } = block {
671 Some(text.clone())
672 } else {
673 None
674 }
675 })
676 .collect::<Vec<_>>()
677 .join("\n")
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// All four roles, in declaration order.
    const ALL_ROLES: [ModelRole; 4] = [
        ModelRole::Main,
        ModelRole::Plan,
        ModelRole::Compress,
        ModelRole::Fast,
    ];

    /// Builds a mandatory step with the given description and tools.
    fn step(description: &str, tools: &[&str]) -> PlanStep {
        PlanStep {
            description: description.to_string(),
            tools: tools.iter().map(|tool| tool.to_string()).collect(),
            optional: false,
        }
    }

    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
    }

    // These expectations hold only while the `CONTEXT_SIZE` environment variable
    // is unset, because a valid override takes precedence over every name rule.
    #[test]
    fn test_infer_context_size() {
        let expectations: &[(&str, u32)] = &[
            // Claude
            ("claude-sonnet-4", 200_000),
            ("claude-3-5-haiku", 200_000),
            ("claude-opus-4", 200_000),
            ("claude-2", 100_000),
            ("claude-sonnet-4-1m", 1_000_000),
            ("claude-opus-4-7", 1_000_000),
            // GPT and o-series
            ("gpt-4o", 128_000),
            ("gpt-4-turbo", 128_000),
            ("gpt-4", 8_192),
            ("gpt-4-32k", 32_768),
            ("gpt-3.5-turbo", 4_096),
            ("gpt-3.5-turbo-16k", 16_384),
            ("o1-preview", 200_000),
            ("o3-mini", 200_000),
            // Gemini
            ("gemini-1.5-pro", 2_000_000),
            ("gemini-1.5-flash", 1_000_000),
            ("gemini-2.0-pro", 2_000_000),
            ("gemini-pro", 32_000),
            // GLM
            ("glm-5", 1_000_000),
            ("glm-4-long", 1_000_000),
            ("glm-4", 128_000),
            ("glm", 128_000),
            // DeepSeek
            ("deepseek-v3", 128_000),
            ("deepseek-r1", 128_000),
            ("deepseek-chat", 64_000),
            ("deepseek", 64_000),
            // Qwen
            ("qwen2.5-turbo", 1_000_000),
            ("qwen-long", 1_000_000),
            ("qwen2.5-72b", 128_000),
            ("qwen3-32b", 128_000),
            ("qwen-max", 128_000),
            ("qwen2-7b", 32_000),
            ("qwen", 8_192),
            // Kimi / Moonshot
            ("kimi", 2_000_000),
            ("moonshot-v1-8k", 2_000_000),
            // Mistral
            ("mistral-large", 128_000),
            ("mistral-medium", 32_000),
            ("mistral-7b", 32_000),
            // Llama 3
            ("llama-3-70b", 128_000),
            ("llama-3-8b", 8_192),
            ("llama3-405b", 128_000),
        ];

        for &(model, window) in expectations {
            assert_eq!(infer_context_size(model), Some(window), "model: {}", model);
        }
    }

    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    #[test]
    fn test_multi_model_config_with_main() {
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        for role in ALL_ROLES {
            let model = config.get(role);
            assert_eq!(model.name, "claude-sonnet-4");
            assert!(model.think);
        }
    }

    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        config.set(
            ModelRole::Compress,
            ModelConfig::new("claude-3-5-haiku".to_string()),
        );

        // Only the compress slot changes; the others keep the main model.
        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4");
    }

    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![step("读取文件", &["read"])],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        for expected in ["测试任务", "简单", "读取文件"] {
            assert!(formatted.contains(expected));
        }
    }

    #[test]
    fn test_complexity_display() {
        let labels = [
            (TaskComplexity::Simple, "简单"),
            (TaskComplexity::Moderate, "中等"),
            (TaskComplexity::Complex, "复杂"),
        ];
        for (complexity, label) in labels {
            assert_eq!(complexity.display(), label);
        }
    }

    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![step("步骤1", &[]), step("步骤2", &[])],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }
}