1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::providers::{
5 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
6};
7
/// Model name returned by [`ModelRole::default_model`] for [`ModelRole::Main`].
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Model name returned by [`ModelRole::default_model`] for [`ModelRole::Plan`].
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Model name returned by [`ModelRole::default_model`] for [`ModelRole::Compress`].
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Model name returned by [`ModelRole::default_model`] for [`ModelRole::Fast`].
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
13
/// Selects one of the four model slots held by [`MultiModelConfig`].
///
/// Serialized with `snake_case` variant names (`main`, `plan`, `compress`,
/// `fast`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelRole {
    /// Slot whose built-in model is [`DEFAULT_MAIN_MODEL`].
    Main,
    /// Slot whose built-in model is [`DEFAULT_PLAN_MODEL`].
    Plan,
    /// Slot whose built-in model is [`DEFAULT_COMPRESS_MODEL`].
    Compress,
    /// Slot whose built-in model is [`DEFAULT_FAST_MODEL`].
    Fast,
}
27
28impl ModelRole {
29 pub fn default_model(&self) -> &'static str {
30 match self {
31 ModelRole::Main => DEFAULT_MAIN_MODEL,
32 ModelRole::Plan => DEFAULT_PLAN_MODEL,
33 ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
34 ModelRole::Fast => DEFAULT_FAST_MODEL,
35 }
36 }
37}
38
/// Per-model settings stored for each slot of [`MultiModelConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier; `display_name` returns it unchanged.
    pub name: String,
    /// Value `Planner::plan` sends as the request's `max_tokens`.
    pub max_tokens: u32,
    /// Thinking toggle.
    /// NOTE(review): nothing in this file reads it (`Planner` always sends
    /// `think: false`); confirm where it is consumed.
    pub think: bool,
    /// Context-window size for the model when known; `new` fills it from
    /// `context_window_for`. Presumably counted in tokens; confirm.
    pub context_size: Option<u32>,
}
51
52impl ModelConfig {
53 pub fn new(name: String) -> Self {
54 Self {
55 name: name.clone(),
56 max_tokens: 16384,
57 think: true,
58 context_size: infer_context_size(&name),
59 }
60 }
61
62 pub fn for_role(role: ModelRole) -> Self {
64 let name = role.default_model().to_string();
65 match role {
66 ModelRole::Main => Self::new(name),
67 ModelRole::Plan => Self::new(name),
68 ModelRole::Compress => Self {
69 name,
70 max_tokens: 1024,
71 think: false,
72 context_size: Some(200_000),
73 },
74 ModelRole::Fast => Self {
75 name,
76 max_tokens: 2048,
77 think: false,
78 context_size: Some(200_000),
79 },
80 }
81 }
82
83 pub fn display_name(&self) -> &str {
84 &self.name
85 }
86}
87
/// Returns the context-window size to assume for `model`, or `None` when the
/// name matches no known family.
///
/// Resolution order:
/// 1. If the `CONTEXT_SIZE` environment variable holds a positive integer
///    (surrounding whitespace allowed), that value wins for every model.
///    Unset, unparsable, or zero values are ignored.
/// 2. Otherwise the ASCII-lower-cased name is matched against family
///    substrings, more specific rules first.
pub fn context_window_for(model: &str) -> Option<u32> {
    /// True when `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if env_override.is_some() {
        return env_override;
    }

    let m = model.to_ascii_lowercase();

    // Anthropic
    if contains_any(&m, &["[1m]", "opus-4-7", "opus-4.7"]) {
        return Some(1_000_000);
    }
    if contains_any(
        &m,
        &[
            "claude-3",
            "claude-4",
            "claude-opus",
            "claude-sonnet",
            "claude-haiku",
        ],
    ) {
        return Some(200_000);
    }
    if contains_any(&m, &["claude-2", "claude-instant"]) {
        return Some(100_000);
    }

    // OpenAI
    if contains_any(&m, &["gpt-4o", "gpt-4-turbo"]) {
        return Some(128_000);
    }
    if contains_any(&m, &["o1", "o3", "o4"]) {
        return Some(200_000);
    }
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    if let Some(pos) = m.find("gpt-4") {
        // Only the part after "gpt-4" identifies the variant. Testing the
        // whole name for the letter 'o' made provider-prefixed ids such as
        // "openai/gpt-4" fall through to `None`.
        let variant = &m[pos + "gpt-4".len()..];
        if !m.contains("turbo") && !variant.contains('o') {
            return Some(8_192);
        }
    }
    if m.contains("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if m.contains("gpt-3.5") {
        return Some(4_096);
    }

    // DeepSeek
    if contains_any(&m, &["deepseek-v3", "deepseek-r1"]) {
        return Some(128_000);
    }
    if m.contains("deepseek") {
        return Some(64_000);
    }

    // Kimi
    if m.contains("kimi") {
        return Some(128_000);
    }

    // Qwen ("qwen2.5" also covers every "qwen2.5-72b" name)
    if m.contains("qwen") {
        if contains_any(&m, &["qwen-max", "qwen2.5"]) {
            return Some(128_000);
        }
        if m.contains("qwen2") {
            return Some(32_000);
        }
        return Some(8_192);
    }

    // Llama 3
    if contains_any(&m, &["llama-3", "llama3"]) {
        if contains_any(&m, &["70b", "405b"]) {
            return Some(128_000);
        }
        return Some(8_192);
    }

    // GLM
    if m.contains("glm") {
        return Some(128_000);
    }

    None
}
176
/// Private alias for [`context_window_for`]; forwards `model` unchanged and
/// returns its result.
fn infer_context_size(model: &str) -> Option<u32> {
    context_window_for(model)
}
181
/// One [`ModelConfig`] per [`ModelRole`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModelConfig {
    /// Configuration selected by [`ModelRole::Main`].
    pub main: ModelConfig,
    /// Configuration selected by [`ModelRole::Plan`].
    pub plan: ModelConfig,
    /// Configuration selected by [`ModelRole::Compress`].
    pub compress: ModelConfig,
    /// Configuration selected by [`ModelRole::Fast`].
    pub fast: ModelConfig,
}
194
195impl Default for MultiModelConfig {
196 fn default() -> Self {
197 Self {
198 main: ModelConfig::for_role(ModelRole::Main),
199 plan: ModelConfig::for_role(ModelRole::Plan),
200 compress: ModelConfig::for_role(ModelRole::Compress),
201 fast: ModelConfig::for_role(ModelRole::Fast),
202 }
203 }
204}
205
206impl MultiModelConfig {
207 pub fn with_main(main_model: String) -> Self {
210 let main_config = ModelConfig::new(main_model);
211 Self {
212 main: main_config.clone(),
213 plan: main_config.clone(),
214 compress: main_config.clone(),
215 fast: main_config,
216 }
217 }
218
219 pub fn unified(model: String) -> Self {
221 let config = ModelConfig::new(model);
222 Self {
223 main: config.clone(),
224 plan: config.clone(),
225 compress: config.clone(),
226 fast: config,
227 }
228 }
229
230 pub fn get(&self, role: ModelRole) -> &ModelConfig {
232 match role {
233 ModelRole::Main => &self.main,
234 ModelRole::Plan => &self.plan,
235 ModelRole::Compress => &self.compress,
236 ModelRole::Fast => &self.fast,
237 }
238 }
239
240 pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
242 match role {
243 ModelRole::Main => self.main = config,
244 ModelRole::Plan => self.plan = config,
245 ModelRole::Compress => self.compress = config,
246 ModelRole::Fast => self.fast = config,
247 }
248 }
249
250 pub fn format_summary(&self) -> String {
252 format!(
253 "main: {}, plan: {}, compress: {}, fast: {}",
254 self.main.display_name(),
255 self.plan.display_name(),
256 self.compress.display_name(),
257 self.fast.display_name()
258 )
259 }
260}
261
/// Coarse size estimate attached to a [`TaskPlan`].
///
/// Serialized with `snake_case` variant names (`simple`, `moderate`,
/// `complex`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskComplexity {
    /// Displayed as "简单".
    Simple,
    /// Displayed as "中等"; also the fallback when a model reply names no
    /// recognised level.
    Moderate,
    /// Displayed as "复杂".
    Complex,
}
270
271impl TaskComplexity {
272 pub fn display(&self) -> &'static str {
273 match self {
274 TaskComplexity::Simple => "简单",
275 TaskComplexity::Moderate => "中等",
276 TaskComplexity::Complex => "复杂",
277 }
278 }
279}
280
/// Three-level difficulty rating, serialized with `snake_case` variant names
/// (`easy`, `medium`, `hard`).
///
/// NOTE(review): no other item in this file refers to this type; presumably
/// it is used by other modules. Confirm before changing it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepDifficulty {
    Easy,
    Medium,
    Hard,
}
289
/// One step of a [`TaskPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// What the step does; becomes the `content` of the matching [`TodoItem`].
    pub description: String,
    /// Names of the tools the step needs; `TaskPlan::format` lists them under
    /// the step when non-empty.
    pub tools: Vec<String>,
    /// When true, `TaskPlan::format` tags the step with "[可选]".
    pub optional: bool,
}
300
/// A plan produced for one user request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPlan {
    /// The request text the plan was generated for.
    pub request: String,
    /// Steps in the order they are rendered and turned into todo items.
    pub steps: Vec<PlanStep>,
    /// Estimated size of the task.
    pub complexity: TaskComplexity,
    /// Suggested approach; the planning prompt asks for a single sentence.
    pub approach: String,
    /// Extra notes rendered as a bullet list when non-empty.
    pub considerations: Vec<String>,
}
315
316impl TaskPlan {
317 pub fn format(&self) -> String {
319 let mut output = String::new();
320
321 output.push_str(&format!("任务分析: {}\n", self.request));
322 output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
323 output.push_str(&format!("建议方案: {}\n\n", self.approach));
324
325 output.push_str("执行步骤:\n");
326 for (i, step) in self.steps.iter().enumerate() {
327 let marker = if step.optional { "[可选]" } else { "" };
328 output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
329 if !step.tools.is_empty() {
330 output.push_str(&format!(" 工具: {}\n", step.tools.join(", ")));
331 }
332 }
333
334 if !self.considerations.is_empty() {
335 output.push_str("\n注意事项:\n");
336 for c in &self.considerations {
337 output.push_str(&format!("• {}\n", c));
338 }
339 }
340
341 output
342 }
343
344 pub fn to_todo_items(&self) -> Vec<TodoItem> {
346 self.steps
347 .iter()
348 .enumerate()
349 .map(|(i, step)| TodoItem {
350 content: step.description.clone(),
351 active_form: format!("执行步骤 {}: {}", i + 1, step.description),
352 status: if i == 0 {
353 "in_progress".to_string()
354 } else {
355 "pending".to_string()
356 },
357 })
358 .collect()
359 }
360}
361
/// Todo-list entry produced from a plan step by `TaskPlan::to_todo_items`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoItem {
    /// The step description.
    pub content: String,
    /// Progress label of the form "执行步骤 N: <description>".
    pub active_form: String,
    /// `"in_progress"` or `"pending"` as produced in this file.
    /// NOTE(review): other status values may exist elsewhere; confirm.
    pub status: String,
}
369
/// Asks a chat provider to plan tasks and to rate their complexity.
pub struct Planner {
    /// Backend that receives the chat requests.
    provider: Box<dyn Provider>,
    /// Model settings; within this file only `max_tokens` is read.
    config: ModelConfig,
}
375
376impl Planner {
377 pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
379 Self { provider, config }
380 }
381
382 pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
384 let prompt = build_plan_prompt(request, available_tools);
385
386 let chat_request = ChatRequest {
387 messages: vec![Message {
388 role: Role::User,
389 content: MessageContent::Text(prompt),
390 }],
391 tools: vec![],
392 system: Some(PLAN_SYSTEM_PROMPT.to_string()),
393 think: false,
394 max_tokens: self.config.max_tokens,
395 server_tools: vec![],
396 enable_caching: false,
397 };
398
399 let response = self.provider.chat(chat_request).await?;
400 let text = extract_text(&response);
401
402 parse_plan_response(request, &text)
403 }
404
405 pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
407 let prompt = format!(
408 "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
409 request
410 );
411
412 let chat_request = ChatRequest {
413 messages: vec![Message {
414 role: Role::User,
415 content: MessageContent::Text(prompt),
416 }],
417 tools: vec![],
418 system: None,
419 think: false,
420 max_tokens: 50,
421 server_tools: vec![],
422 enable_caching: false,
423 };
424
425 let response = self.provider.chat(chat_request).await?;
426 let text = extract_text(&response).to_lowercase();
427
428 if text.contains("简单") || text.contains("simple") {
429 Ok(TaskComplexity::Simple)
430 } else if text.contains("复杂") || text.contains("complex") {
431 Ok(TaskComplexity::Complex)
432 } else {
433 Ok(TaskComplexity::Moderate)
434 }
435 }
436}
437
/// System prompt sent by `Planner::plan`.
///
/// Written in Chinese. It tells the model to act as a task-planning assistant
/// and to answer with a JSON object holding `complexity`, `approach`, `steps`
/// (each with `description`, `tools`, `optional`) and `considerations`, then
/// lists five planning guidelines.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
 "complexity": "simple|moderate|complex",
 "approach": "建议的方案(一句话)",
 "steps": [
 {
 "description": "步骤描述",
 "tools": ["需要的工具"],
 "optional": false
 }
 ],
 "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
463
/// Builds the user message for a planning request: the request text, the
/// tool names joined with ", ", and a closing instruction asking for a JSON
/// plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");
    format!(
        r#"用户请求:
{}

可用工具:
{}

请分析任务并生成执行计划(JSON格式)。"#,
        request, tool_list
    )
}
478
479fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
481 if let Some(json_start) = text.find('{')
483 && let Some(json_end) = text.rfind('}')
484 {
485 let json_str = &text[json_start..=json_end];
486 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
487 return Ok(TaskPlan {
488 request: request.to_string(),
489 steps: parse_steps(&parsed["steps"]),
490 complexity: parse_complexity(&parsed["complexity"]),
491 approach: parsed["approach"]
492 .as_str()
493 .unwrap_or("直接执行")
494 .to_string(),
495 considerations: parsed["considerations"]
496 .as_array()
497 .map(|arr| {
498 arr.iter()
499 .filter_map(|v| v.as_str().map(String::from))
500 .collect()
501 })
502 .unwrap_or_default(),
503 });
504 }
505 }
506
507 Ok(TaskPlan {
509 request: request.to_string(),
510 steps: parse_steps_from_text(text),
511 complexity: TaskComplexity::Moderate,
512 approach: "按步骤执行".to_string(),
513 considerations: vec!["请检查执行结果".to_string()],
514 })
515}
516
517fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
518 value
519 .as_array()
520 .map(|arr| {
521 arr.iter()
522 .filter_map(|v| {
523 Some(PlanStep {
524 description: v["description"].as_str()?.to_string(),
525 tools: v["tools"]
526 .as_array()
527 .map(|t| {
528 t.iter()
529 .filter_map(|x| x.as_str().map(String::from))
530 .collect()
531 })
532 .unwrap_or_default(),
533 optional: v["optional"].as_bool().unwrap_or(false),
534 })
535 })
536 .collect()
537 })
538 .unwrap_or_default()
539}
540
541fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
542 match value.as_str().map(|s| s.to_lowercase()) {
543 Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
544 Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
545 _ => TaskComplexity::Moderate,
546 }
547}
548
549fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
550 text.lines()
551 .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
552 .take(5)
553 .map(|l| PlanStep {
554 description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
555 tools: vec!["read".to_string()],
556 optional: false,
557 })
558 .collect()
559}
560
561fn extract_text(response: &ChatResponse) -> String {
562 response
563 .content
564 .iter()
565 .filter_map(|block| {
566 if let ContentBlock::Text { text } = block {
567 Some(text.clone())
568 } else {
569 None
570 }
571 })
572 .collect::<Vec<_>>()
573 .join("\n")
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mandatory step with the given description and tool names.
    fn mandatory_step(description: &str, tools: &[&str]) -> PlanStep {
        PlanStep {
            description: description.to_string(),
            tools: tools.iter().map(|tool| tool.to_string()).collect(),
            optional: false,
        }
    }

    #[test]
    fn test_model_config_defaults() {
        let main_cfg = ModelConfig::for_role(ModelRole::Main);
        assert!(main_cfg.name.contains("claude"));
        assert!(main_cfg.think);

        let compress_cfg = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress_cfg.name.contains("haiku"));
        assert!(!compress_cfg.think);
    }

    #[test]
    fn test_infer_context_size() {
        let expectations = [
            ("claude-sonnet-4", 200_000),
            ("gpt-4o", 128_000),
            ("claude-3-5-haiku", 200_000),
        ];
        for &(model, window) in expectations.iter() {
            assert_eq!(infer_context_size(model), Some(window));
        }
    }

    #[test]
    fn test_multi_model_config() {
        let defaults = MultiModelConfig::default();
        assert!(defaults.main.name.contains("sonnet"));
        assert!(defaults.compress.name.contains("haiku"));
    }

    #[test]
    fn test_multi_model_config_with_main() {
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());
        let slots = [&config.main, &config.plan, &config.compress, &config.fast];

        // Every role shares the main model's name.
        for slot in slots.iter() {
            assert_eq!(slot.name, "claude-sonnet-4");
        }

        // Every role also inherits the general-purpose thinking default.
        for slot in slots.iter() {
            assert!(slot.think);
        }
    }

    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        let replacement = ModelConfig::new("claude-3-5-haiku".to_string());
        config.set(ModelRole::Compress, replacement);

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4");
    }

    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![mandatory_step("读取文件", &["read"])],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let rendered = plan.format();
        for expected in ["测试任务", "简单", "读取文件"].iter() {
            assert!(rendered.contains(*expected));
        }
    }

    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![mandatory_step("步骤1", &[]), mandatory_step("步骤2", &[])],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: Vec::new(),
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    #[test]
    fn test_parse_plan_response_json() {
        let reply = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", reply).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }
}