1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::constants::{COMPRESS_MAX_TOKENS, DEFAULT_MAX_TOKENS, FAST_MAX_TOKENS};
5use crate::providers::{
6 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
7};
8
/// Default model identifier for [`ModelRole::Main`].
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for [`ModelRole::Plan`].
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for [`ModelRole::Compress`].
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Default model identifier for [`ModelRole::Fast`].
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
14
/// Names one of the four model slots held by [`MultiModelConfig`].
///
/// Variants are (de)serialized in `snake_case` (`main`, `plan`, `compress`, `fast`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelRole {
    /// Slot backed by [`DEFAULT_MAIN_MODEL`] by default.
    Main,
    /// Slot backed by [`DEFAULT_PLAN_MODEL`] by default.
    Plan,
    /// Slot backed by [`DEFAULT_COMPRESS_MODEL`] by default.
    Compress,
    /// Slot backed by [`DEFAULT_FAST_MODEL`] by default.
    Fast,
}
28
29impl ModelRole {
30 pub fn default_model(&self) -> &'static str {
31 match self {
32 ModelRole::Main => DEFAULT_MAIN_MODEL,
33 ModelRole::Plan => DEFAULT_PLAN_MODEL,
34 ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
35 ModelRole::Fast => DEFAULT_FAST_MODEL,
36 }
37 }
38}
39
/// Settings for a single model slot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier; also what `display_name` returns.
    pub name: String,
    /// Output limit; `Planner::plan` forwards it as the request's `max_tokens`.
    pub max_tokens: u32,
    /// Enabled by `ModelConfig::new`, disabled for the compress/fast role presets.
    /// NOTE(review): presumably toggles the model's thinking mode; this file never reads it — confirm.
    pub think: bool,
    /// Context window for the model (presumably in tokens), or `None` when it is unknown.
    pub context_size: Option<u32>,
}
52
53impl ModelConfig {
54 pub fn new(name: String) -> Self {
55 Self {
56 name: name.clone(),
57 max_tokens: DEFAULT_MAX_TOKENS,
58 think: true,
59 context_size: infer_context_size(&name),
60 }
61 }
62
63 pub fn for_role(role: ModelRole) -> Self {
65 let name = role.default_model().to_string();
66 match role {
67 ModelRole::Main => Self::new(name),
68 ModelRole::Plan => Self::new(name),
69 ModelRole::Compress => Self {
70 name,
71 max_tokens: COMPRESS_MAX_TOKENS,
72 think: false,
73 context_size: Some(200_000),
74 },
75 ModelRole::Fast => Self {
76 name,
77 max_tokens: FAST_MAX_TOKENS,
78 think: false,
79 context_size: Some(200_000),
80 },
81 }
82 }
83
84 pub fn display_name(&self) -> &str {
85 &self.name
86 }
87}
88
/// Looks up the context window size for a model name.
///
/// A `CONTEXT_SIZE` environment variable that parses (after trimming) as a
/// positive `u32` overrides everything else. Otherwise the lowercased name is
/// matched against known substrings in a fixed order; the first hit wins.
/// Returns `None` for names that match no known family.
pub fn context_window_for(model: &str) -> Option<u32> {
    /// True when `name` contains at least one of `needles`.
    fn mentions(name: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| name.contains(*needle))
    }

    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if env_override.is_some() {
        return env_override;
    }

    let m = model.to_ascii_lowercase();

    // Order matters: more specific patterns must be tested before the broader
    // ones they overlap with (e.g. `gpt-4-32k` before plain `gpt-4`).
    let window = if mentions(&m, &["1m", "opus-4-7", "opus-4.7"]) {
        1_000_000
    } else if mentions(
        &m,
        &[
            "claude-3",
            "claude-4",
            "claude-opus",
            "claude-sonnet",
            "claude-haiku",
        ],
    ) {
        200_000
    } else if mentions(&m, &["claude-2", "claude-instant"]) {
        100_000
    } else if mentions(&m, &["gpt-4o", "gpt-4-turbo"]) {
        128_000
    } else if mentions(&m, &["o1", "o3", "o4"]) {
        200_000
    } else if m.contains("gpt-4-32k") {
        32_768
    } else if m.contains("gpt-4") && !m.contains('o') {
        // Rejecting any 'o' also rejects every name containing "turbo".
        8_192
    } else if m.contains("gpt-3.5-turbo-16k") {
        16_384
    } else if m.contains("gpt-3.5") {
        4_096
    } else if mentions(&m, &["deepseek-v3", "deepseek-r1"]) {
        128_000
    } else if m.contains("deepseek") {
        64_000
    } else if m.contains("kimi") {
        128_000
    } else if m.contains("qwen") {
        // "qwen2.5" already covers the more specific "qwen2.5-72b".
        if mentions(&m, &["qwen-max", "qwen2.5"]) {
            128_000
        } else if m.contains("qwen2") {
            32_000
        } else {
            8_192
        }
    } else if mentions(&m, &["llama-3", "llama3"]) {
        if mentions(&m, &["70b", "405b"]) {
            128_000
        } else {
            8_192
        }
    } else if m.contains("glm") {
        128_000
    } else {
        return None;
    };

    Some(window)
}
177
/// Private alias for [`context_window_for`]; returns exactly what that function returns.
fn infer_context_size(model: &str) -> Option<u32> {
    context_window_for(model)
}
182
/// One [`ModelConfig`] per [`ModelRole`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModelConfig {
    /// Configuration returned for [`ModelRole::Main`].
    pub main: ModelConfig,
    /// Configuration returned for [`ModelRole::Plan`].
    pub plan: ModelConfig,
    /// Configuration returned for [`ModelRole::Compress`].
    pub compress: ModelConfig,
    /// Configuration returned for [`ModelRole::Fast`].
    pub fast: ModelConfig,
}
195
196impl Default for MultiModelConfig {
197 fn default() -> Self {
198 Self {
199 main: ModelConfig::for_role(ModelRole::Main),
200 plan: ModelConfig::for_role(ModelRole::Plan),
201 compress: ModelConfig::for_role(ModelRole::Compress),
202 fast: ModelConfig::for_role(ModelRole::Fast),
203 }
204 }
205}
206
207impl MultiModelConfig {
208 pub fn with_main(main_model: String) -> Self {
211 let main_config = ModelConfig::new(main_model);
212 Self {
213 main: main_config.clone(),
214 plan: main_config.clone(),
215 compress: main_config.clone(),
216 fast: main_config,
217 }
218 }
219
220 pub fn unified(model: String) -> Self {
222 let config = ModelConfig::new(model);
223 Self {
224 main: config.clone(),
225 plan: config.clone(),
226 compress: config.clone(),
227 fast: config,
228 }
229 }
230
231 pub fn get(&self, role: ModelRole) -> &ModelConfig {
233 match role {
234 ModelRole::Main => &self.main,
235 ModelRole::Plan => &self.plan,
236 ModelRole::Compress => &self.compress,
237 ModelRole::Fast => &self.fast,
238 }
239 }
240
241 pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
243 match role {
244 ModelRole::Main => self.main = config,
245 ModelRole::Plan => self.plan = config,
246 ModelRole::Compress => self.compress = config,
247 ModelRole::Fast => self.fast = config,
248 }
249 }
250
251 pub fn format_summary(&self) -> String {
253 format!(
254 "main: {}, plan: {}, compress: {}, fast: {}",
255 self.main.display_name(),
256 self.plan.display_name(),
257 self.compress.display_name(),
258 self.fast.display_name()
259 )
260 }
261}
262
/// Three-level complexity rating attached to a [`TaskPlan`].
///
/// Variants are (de)serialized in `snake_case` (`simple`, `moderate`, `complex`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    Simple,
    /// Also the fallback used in this file when a rating cannot be recognized.
    Moderate,
    Complex,
}
271
272impl TaskComplexity {
273 pub fn display(&self) -> &'static str {
274 match self {
275 TaskComplexity::Simple => "简单",
276 TaskComplexity::Moderate => "中等",
277 TaskComplexity::Complex => "复杂",
278 }
279 }
280}
281
/// Three-level difficulty rating, (de)serialized in `snake_case`.
///
/// NOTE(review): nothing else in this file references this type — presumably
/// used by other modules; confirm before changing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepDifficulty {
    Easy,
    Medium,
    Hard,
}
290
/// A single step of a [`TaskPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Human-readable description of what the step does.
    pub description: String,
    /// Names of the tools listed for this step; may be empty.
    pub tools: Vec<String>,
    /// When true, `TaskPlan::format` tags the step as optional.
    pub optional: bool,
}
301
/// A plan produced for a user request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPlan {
    /// The request text the plan was produced for.
    pub request: String,
    /// Steps in the order they are listed and numbered by `format`.
    pub steps: Vec<PlanStep>,
    /// Complexity rating of the task.
    pub complexity: TaskComplexity,
    /// Short description of the suggested approach.
    pub approach: String,
    /// Extra notes; `format` prints them as a bullet list when non-empty.
    pub considerations: Vec<String>,
}
316
317impl TaskPlan {
318 pub fn format(&self) -> String {
320 let mut output = String::new();
321
322 output.push_str(&format!("任务分析: {}\n", self.request));
323 output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
324 output.push_str(&format!("建议方案: {}\n\n", self.approach));
325
326 output.push_str("执行步骤:\n");
327 for (i, step) in self.steps.iter().enumerate() {
328 let marker = if step.optional { "[可选]" } else { "" };
329 output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
330 if !step.tools.is_empty() {
331 output.push_str(&format!(" 工具: {}\n", step.tools.join(", ")));
332 }
333 }
334
335 if !self.considerations.is_empty() {
336 output.push_str("\n注意事项:\n");
337 for c in &self.considerations {
338 output.push_str(&format!("• {}\n", c));
339 }
340 }
341
342 output
343 }
344
345 pub fn to_todo_items(&self) -> Vec<TodoItem> {
347 self.steps
348 .iter()
349 .enumerate()
350 .map(|(i, step)| TodoItem {
351 content: step.description.clone(),
352 active_form: format!("执行步骤 {}: {}", i + 1, step.description),
353 status: if i == 0 {
354 "in_progress".to_string()
355 } else {
356 "pending".to_string()
357 },
358 })
359 .collect()
360 }
361}
362
/// A todo entry derived from a plan step (see `TaskPlan::to_todo_items`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoItem {
    /// The step description.
    pub content: String,
    /// Progress text that includes the 1-based step number and the description.
    pub active_form: String,
    /// `"in_progress"` or `"pending"` when produced by `TaskPlan::to_todo_items`.
    pub status: String,
}
370
/// Produces task plans and complexity ratings by querying a chat provider.
pub struct Planner {
    /// Backend that receives the chat requests.
    provider: Box<dyn Provider>,
    /// Model settings; within this file only `max_tokens` is read (by `plan`).
    config: ModelConfig,
}
376
377impl Planner {
378 pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
380 Self { provider, config }
381 }
382
383 pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
385 let prompt = build_plan_prompt(request, available_tools);
386
387 let chat_request = ChatRequest {
388 messages: vec![Message {
389 role: Role::User,
390 content: MessageContent::Text(prompt),
391 }],
392 tools: vec![],
393 system: Some(PLAN_SYSTEM_PROMPT.to_string()),
394 think: false,
395 max_tokens: self.config.max_tokens,
396 server_tools: vec![],
397 enable_caching: false,
398 };
399
400 let response = self.provider.chat(chat_request).await?;
401 let text = extract_text(&response);
402
403 parse_plan_response(request, &text)
404 }
405
406 pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
408 let prompt = format!(
409 "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
410 request
411 );
412
413 let chat_request = ChatRequest {
414 messages: vec![Message {
415 role: Role::User,
416 content: MessageContent::Text(prompt),
417 }],
418 tools: vec![],
419 system: None,
420 think: false,
421 max_tokens: 50,
422 server_tools: vec![],
423 enable_caching: false,
424 };
425
426 let response = self.provider.chat(chat_request).await?;
427 let text = extract_text(&response).to_lowercase();
428
429 if text.contains("简单") || text.contains("simple") {
430 Ok(TaskComplexity::Simple)
431 } else if text.contains("复杂") || text.contains("complex") {
432 Ok(TaskComplexity::Complex)
433 } else {
434 Ok(TaskComplexity::Moderate)
435 }
436 }
437}
438
/// System prompt sent by `Planner::plan`.
///
/// It instructs the model to answer in JSON with the keys `complexity`,
/// `approach`, `steps` and `considerations`, which are the keys read by
/// `parse_plan_response`. The text is runtime data: edit it with care.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
 "complexity": "simple|moderate|complex",
 "approach": "建议的方案(一句话)",
 "steps": [
 {
 "description": "步骤描述",
 "tools": ["需要的工具"],
 "optional": false
 }
 ],
 "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
464
/// Builds the user message for a planning request.
///
/// The message contains the request text, the tool names joined with `", "`,
/// and a closing instruction asking for a JSON plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");
    format!(
        "用户请求:\n{request}\n\n可用工具:\n{tools}\n\n请分析任务并生成执行计划(JSON格式)。",
        request = request,
        tools = tool_list
    )
}
479
480fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
482 if let Some(json_start) = text.find('{')
484 && let Some(json_end) = text.rfind('}')
485 {
486 let json_str = &text[json_start..=json_end];
487 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
488 return Ok(TaskPlan {
489 request: request.to_string(),
490 steps: parse_steps(&parsed["steps"]),
491 complexity: parse_complexity(&parsed["complexity"]),
492 approach: parsed["approach"]
493 .as_str()
494 .unwrap_or("直接执行")
495 .to_string(),
496 considerations: parsed["considerations"]
497 .as_array()
498 .map(|arr| {
499 arr.iter()
500 .filter_map(|v| v.as_str().map(String::from))
501 .collect()
502 })
503 .unwrap_or_default(),
504 });
505 }
506 }
507
508 Ok(TaskPlan {
510 request: request.to_string(),
511 steps: parse_steps_from_text(text),
512 complexity: TaskComplexity::Moderate,
513 approach: "按步骤执行".to_string(),
514 considerations: vec!["请检查执行结果".to_string()],
515 })
516}
517
518fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
519 value
520 .as_array()
521 .map(|arr| {
522 arr.iter()
523 .filter_map(|v| {
524 Some(PlanStep {
525 description: v["description"].as_str()?.to_string(),
526 tools: v["tools"]
527 .as_array()
528 .map(|t| {
529 t.iter()
530 .filter_map(|x| x.as_str().map(String::from))
531 .collect()
532 })
533 .unwrap_or_default(),
534 optional: v["optional"].as_bool().unwrap_or(false),
535 })
536 })
537 .collect()
538 })
539 .unwrap_or_default()
540}
541
542fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
543 match value.as_str().map(|s| s.to_lowercase()) {
544 Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
545 Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
546 _ => TaskComplexity::Moderate,
547 }
548}
549
550fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
551 text.lines()
552 .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
553 .take(5)
554 .map(|l| PlanStep {
555 description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
556 tools: vec!["read".to_string()],
557 optional: false,
558 })
559 .collect()
560}
561
562fn extract_text(response: &ChatResponse) -> String {
563 response
564 .content
565 .iter()
566 .filter_map(|block| {
567 if let ContentBlock::Text { text } = block {
568 Some(text.clone())
569 } else {
570 None
571 }
572 })
573 .collect::<Vec<_>>()
574 .join("\n")
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    // Role presets: main uses a Claude model with thinking on; compress uses
    // a Haiku model with thinking off.
    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
    }

    // These expectations only hold while the CONTEXT_SIZE environment
    // variable is unset or invalid, because `context_window_for` checks it
    // first.
    #[test]
    fn test_infer_context_size() {
        assert_eq!(infer_context_size("claude-sonnet-4"), Some(200_000));
        assert_eq!(infer_context_size("gpt-4o"), Some(128_000));
        assert_eq!(infer_context_size("claude-3-5-haiku"), Some(200_000));
        assert_eq!(infer_context_size("claude-sonnet-4-1m"), Some(1_000_000));
        assert_eq!(infer_context_size("claude-opus-4-7"), Some(1_000_000));
    }

    // The default configuration uses the per-role default models.
    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    // `with_main` puts the same model, with thinking on, into every slot.
    #[test]
    fn test_multi_model_config_with_main() {
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-sonnet-4");
        assert_eq!(config.fast.name, "claude-sonnet-4");

        assert!(config.main.think);
        assert!(config.plan.think);
        assert!(config.compress.think);
        assert!(config.fast.think);
    }

    // Overriding one role leaves the other three untouched.
    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        config.set(
            ModelRole::Compress,
            ModelConfig::new("claude-3-5-haiku".to_string()),
        );

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4");
    }

    // The rendered plan mentions the request, the complexity label and the
    // step description.
    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![PlanStep {
                description: "读取文件".to_string(),
                tools: vec!["read".to_string()],
                optional: false,
            }],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        assert!(formatted.contains("测试任务"));
        assert!(formatted.contains("简单"));
        assert!(formatted.contains("读取文件"));
    }

    // Each complexity level maps to its Chinese label.
    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    // The first todo item is in progress; later ones are pending.
    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![
                PlanStep {
                    description: "步骤1".to_string(),
                    tools: vec![],
                    optional: false,
                },
                PlanStep {
                    description: "步骤2".to_string(),
                    tools: vec![],
                    optional: false,
                },
            ],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    // A well-formed JSON reply is parsed into complexity and steps.
    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }
}