Skip to main content

matrixcode_core/
models.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::constants::{COMPRESS_MAX_TOKENS, DEFAULT_MAX_TOKENS, FAST_MAX_TOKENS};
5use crate::providers::{
6    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
7};
8
/// Default model identifier for the main (task execution) role.
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for the planning (task decomposition) role.
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for the compression (context summarization) role.
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Default model identifier for the fast (quick operations) role.
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
14
15/// Model role - what purpose this model serves.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ModelRole {
19    /// Main model for task execution.
20    Main,
21    /// Planning model for task decomposition.
22    Plan,
23    /// Compression model for context summarization.
24    Compress,
25    /// Fast model for quick operations.
26    Fast,
27}
28
29impl ModelRole {
30    pub fn default_model(&self) -> &'static str {
31        match self {
32            ModelRole::Main => DEFAULT_MAIN_MODEL,
33            ModelRole::Plan => DEFAULT_PLAN_MODEL,
34            ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
35            ModelRole::Fast => DEFAULT_FAST_MODEL,
36        }
37    }
38}
39
40/// Configuration for a single model.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelConfig {
43    /// Model name/identifier.
44    pub name: String,
45    /// Maximum output tokens.
46    pub max_tokens: u32,
47    /// Whether to enable thinking for this model.
48    pub think: bool,
49    /// Estimated context window size.
50    pub context_size: Option<u32>,
51}
52
53impl ModelConfig {
54    pub fn new(name: String) -> Self {
55        Self {
56            name: name.clone(),
57            max_tokens: DEFAULT_MAX_TOKENS,
58            think: true,
59            context_size: infer_context_size(&name),
60        }
61    }
62
63    /// Create with specific role defaults.
64    pub fn for_role(role: ModelRole) -> Self {
65        let name = role.default_model().to_string();
66        match role {
67            ModelRole::Main => Self::new(name),
68            ModelRole::Plan => Self::new(name),
69            ModelRole::Compress => Self {
70                name,
71                max_tokens: COMPRESS_MAX_TOKENS,
72                think: false,
73                context_size: Some(200_000),
74            },
75            ModelRole::Fast => Self {
76                name,
77                max_tokens: FAST_MAX_TOKENS,
78                think: false,
79                context_size: Some(200_000),
80            },
81        }
82    }
83
84    pub fn display_name(&self) -> &str {
85        &self.name
86    }
87}
88
/// Estimates a model's context window (in tokens) from its name.
///
/// If the `CONTEXT_SIZE` environment variable holds a positive integer, that
/// value is returned for every model and the name is not inspected. Otherwise
/// the name is lower-cased (ASCII) and matched by substring against the rules
/// below, in order; the first matching rule wins.
///
/// Returns `None` when no rule matches, so the caller can pick its own fallback.
///
/// # Rules (first match wins)
///
/// | Provider | Name contains | Context window |
/// |----------|---------------|----------------|
/// | Anthropic | `1m`, `opus-4-7`, `opus-4.7` | 1M |
/// | Anthropic | `claude-3`, `claude-4`, `claude-opus`, `claude-sonnet`, `claude-haiku` | 200K |
/// | Anthropic | `claude-2`, `claude-instant` | 100K |
/// | OpenAI | `o1`, `o3`, `o4` | 200K |
/// | OpenAI | `gpt-4o`, `gpt-4-turbo`, `gpt-4.1`, `gpt-4-vision` | 128K |
/// | OpenAI | `gpt-4-32k` | 32K |
/// | OpenAI | other `gpt-4` (without `turbo`) | 8K |
/// | OpenAI | `gpt-3.5-turbo-16k` / other `gpt-3.5` | 16K / 4K |
/// | Google | `gemini-2` + `pro`; `gemini-1.5` + `pro` | 2M |
/// | Google | other `gemini-2`, other `gemini-1.5` | 1M |
/// | Google | other `gemini` | 32K |
/// | Zhipu AI | `glm-5`, `glm-4-long` | 1M |
/// | Zhipu AI | other `glm` | 128K |
/// | DeepSeek | `deepseek-v3`, `deepseek-r1` | 128K |
/// | DeepSeek | other `deepseek` | 64K |
/// | Alibaba Qwen | `qwen-long`, `qwen2.5-turbo`, `qwen-turbo` | 1M |
/// | Alibaba Qwen | `qwen2.5`, `qwen3`, `qwen-max` | 128K |
/// | Alibaba Qwen | other `qwen2` / other `qwen` | 32K / 8K |
/// | Moonshot AI | `kimi`, `moonshot` | 2M |
/// | Meta | `llama-3` / `llama3` with `70b` or `405b` | 128K |
/// | Meta | other `llama-3` / `llama3` | 8K |
/// | Mistral | `mistral` + `large` | 128K |
/// | Mistral | other `mistral` | 32K |
pub fn context_window_for(model: &str) -> Option<u32> {
    // A positive integer in CONTEXT_SIZE overrides every rule below; unset,
    // unparsable or zero values are ignored.
    let user_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if user_override.is_some() {
        return user_override;
    }

    let m = model.to_ascii_lowercase();
    let has = |needle: &str| m.contains(needle);
    let has_any = |needles: &[&str]| needles.iter().any(|needle| m.contains(*needle));

    // ---- Anthropic Claude ----
    // 1M-context variants are checked first so they win over the 200K rule.
    if has_any(&["1m", "opus-4-7", "opus-4.7"]) {
        return Some(1_000_000);
    }
    if has_any(&[
        "claude-3",
        "claude-4",
        "claude-opus",
        "claude-sonnet",
        "claude-haiku",
    ]) {
        return Some(200_000);
    }
    if has_any(&["claude-2", "claude-instant"]) {
        return Some(100_000);
    }

    // ---- OpenAI ----
    // o1 / o3 / o4 reasoning models.
    if has_any(&["o1", "o3", "o4"]) {
        return Some(200_000);
    }
    // `gpt-4-vision` is listed here so that vision previews do not fall into
    // the 8K base bucket below.
    if has_any(&["gpt-4o", "gpt-4-turbo", "gpt-4.1", "gpt-4-vision"]) {
        return Some(128_000);
    }
    if has("gpt-4-32k") {
        return Some(32_768);
    }
    // Base GPT-4. The 4o / turbo / 4.1 / 32k variants were handled above, so no
    // further exclusion is needed. (An earlier version rejected any name
    // containing the letter 'o', which made provider-prefixed names such as
    // "openai/gpt-4" come back as unknown.)
    if has("gpt-4") && !has("turbo") {
        return Some(8_192);
    }
    if has("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if has("gpt-3.5") {
        return Some(4_096);
    }

    // ---- Google Gemini ----
    if has("gemini-2") {
        // Non-Pro 2.x models (e.g. Flash) get the 1M bucket rather than
        // falling through to the 32K Gemini 1.0 default.
        return Some(if has("pro") { 2_000_000 } else { 1_000_000 });
    }
    if has("gemini-1.5") {
        // 1.5 Pro: 2M; Flash and other 1.5 variants: 1M.
        return Some(if has("pro") { 2_000_000 } else { 1_000_000 });
    }
    // Remaining Gemini names are treated as Gemini 1.0.
    if has("gemini") {
        return Some(32_000);
    }

    // ---- Zhipu AI GLM ----
    if has_any(&["glm-5", "glm-4-long"]) {
        return Some(1_000_000);
    }
    // GLM-4 and every other GLM model (conservative estimate).
    if has("glm") {
        return Some(128_000);
    }

    // ---- DeepSeek ----
    if has_any(&["deepseek-v3", "deepseek-r1"]) {
        return Some(128_000);
    }
    // DeepSeek-V2 and others (default API limit).
    if has("deepseek") {
        return Some(64_000);
    }

    // ---- Alibaba Qwen ----
    if has_any(&["qwen-long", "qwen2.5-turbo", "qwen-turbo"]) {
        return Some(1_000_000);
    }
    if has_any(&["qwen2.5", "qwen3", "qwen-max"]) {
        return Some(128_000);
    }
    if has("qwen2") {
        return Some(32_000);
    }
    // Other Qwen models (conservative estimate).
    if has("qwen") {
        return Some(8_192);
    }

    // ---- Moonshot AI (Kimi) ----
    // NOTE(review): names such as moonshot-v1-8k / -32k appear to encode a much
    // smaller window than 2M; the value is kept because the existing unit
    // tests pin it. Confirm against the provider's documentation.
    if has_any(&["kimi", "moonshot"]) {
        return Some(2_000_000);
    }

    // ---- Meta Llama 3 ----
    if has_any(&["llama-3", "llama3"]) {
        return Some(if has_any(&["70b", "405b"]) {
            128_000
        } else {
            8_192
        });
    }

    // ---- Mistral ----
    // Large: 128K; Medium, Small, 7B and everything else: 32K.
    if has("mistral") {
        return Some(if has("large") { 128_000 } else { 32_000 });
    }

    // Unknown model: the caller decides what to do.
    None
}
280
281/// Legacy alias for context_window_for (internal use).
282fn infer_context_size(model: &str) -> Option<u32> {
283    context_window_for(model)
284}
285
286/// Multi-model configuration manager.
287#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct MultiModelConfig {
289    /// Main model for primary tasks.
290    pub main: ModelConfig,
291    /// Planning model for task decomposition.
292    pub plan: ModelConfig,
293    /// Compression model for context summarization.
294    pub compress: ModelConfig,
295    /// Fast model for quick operations.
296    pub fast: ModelConfig,
297}
298
299impl Default for MultiModelConfig {
300    fn default() -> Self {
301        Self {
302            main: ModelConfig::for_role(ModelRole::Main),
303            plan: ModelConfig::for_role(ModelRole::Plan),
304            compress: ModelConfig::for_role(ModelRole::Compress),
305            fast: ModelConfig::for_role(ModelRole::Fast),
306        }
307    }
308}
309
310impl MultiModelConfig {
311    /// Create with a main model, all other roles also use this model by default.
312    /// This ensures that if no specific model is configured for a role, it falls back to the main model.
313    pub fn with_main(main_model: String) -> Self {
314        let main_config = ModelConfig::new(main_model);
315        Self {
316            main: main_config.clone(),
317            plan: main_config.clone(),
318            compress: main_config.clone(),
319            fast: main_config,
320        }
321    }
322
323    /// Create where all roles use the same model.
324    pub fn unified(model: String) -> Self {
325        let config = ModelConfig::new(model);
326        Self {
327            main: config.clone(),
328            plan: config.clone(),
329            compress: config.clone(),
330            fast: config,
331        }
332    }
333
334    /// Get config for a specific role.
335    pub fn get(&self, role: ModelRole) -> &ModelConfig {
336        match role {
337            ModelRole::Main => &self.main,
338            ModelRole::Plan => &self.plan,
339            ModelRole::Compress => &self.compress,
340            ModelRole::Fast => &self.fast,
341        }
342    }
343
344    /// Set model for a specific role.
345    pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
346        match role {
347            ModelRole::Main => self.main = config,
348            ModelRole::Plan => self.plan = config,
349            ModelRole::Compress => self.compress = config,
350            ModelRole::Fast => self.fast = config,
351        }
352    }
353
354    /// Format for display.
355    pub fn format_summary(&self) -> String {
356        format!(
357            "main: {}, plan: {}, compress: {}, fast: {}",
358            self.main.display_name(),
359            self.plan.display_name(),
360            self.compress.display_name(),
361            self.fast.display_name()
362        )
363    }
364}
365
366/// Task complexity level.
367#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
368#[serde(rename_all = "snake_case")]
369pub enum TaskComplexity {
370    Simple,
371    Moderate,
372    Complex,
373}
374
375impl TaskComplexity {
376    pub fn display(&self) -> &'static str {
377        match self {
378            TaskComplexity::Simple => "简单",
379            TaskComplexity::Moderate => "中等",
380            TaskComplexity::Complex => "复杂",
381        }
382    }
383}
384
385/// Step difficulty level.
386#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
387#[serde(rename_all = "snake_case")]
388pub enum StepDifficulty {
389    Easy,
390    Medium,
391    Hard,
392}
393
394/// A single step in the task plan.
395#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct PlanStep {
397    /// Step description.
398    pub description: String,
399    /// Tools needed for this step.
400    pub tools: Vec<String>,
401    /// Whether this step is optional.
402    pub optional: bool,
403}
404
405/// Task plan generated by the planning model.
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct TaskPlan {
408    /// Original user request.
409    pub request: String,
410    /// Decomposed steps.
411    pub steps: Vec<PlanStep>,
412    /// Estimated complexity.
413    pub complexity: TaskComplexity,
414    /// Suggested approach summary.
415    pub approach: String,
416    /// Potential risks or considerations.
417    pub considerations: Vec<String>,
418}
419
420impl TaskPlan {
421    /// Format for display.
422    pub fn format(&self) -> String {
423        let mut output = String::new();
424
425        output.push_str(&format!("任务分析: {}\n", self.request));
426        output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
427        output.push_str(&format!("建议方案: {}\n\n", self.approach));
428
429        output.push_str("执行步骤:\n");
430        for (i, step) in self.steps.iter().enumerate() {
431            let marker = if step.optional { "[可选]" } else { "" };
432            output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
433            if !step.tools.is_empty() {
434                output.push_str(&format!("   工具: {}\n", step.tools.join(", ")));
435            }
436        }
437
438        if !self.considerations.is_empty() {
439            output.push_str("\n注意事项:\n");
440            for c in &self.considerations {
441                output.push_str(&format!("• {}\n", c));
442            }
443        }
444
445        output
446    }
447
448    /// Convert to todo items for the agent.
449    pub fn to_todo_items(&self) -> Vec<TodoItem> {
450        self.steps
451            .iter()
452            .enumerate()
453            .map(|(i, step)| TodoItem {
454                content: step.description.clone(),
455                active_form: format!("执行步骤 {}: {}", i + 1, step.description),
456                status: if i == 0 {
457                    "in_progress".to_string()
458                } else {
459                    "pending".to_string()
460                },
461            })
462            .collect()
463    }
464}
465
466/// Todo item for task tracking.
467#[derive(Debug, Clone, Serialize, Deserialize)]
468pub struct TodoItem {
469    pub content: String,
470    pub active_form: String,
471    pub status: String,
472}
473
474/// Planner for generating task plans using the plan model.
475pub struct Planner {
476    provider: Box<dyn Provider>,
477    config: ModelConfig,
478}
479
480impl Planner {
481    /// Create a new planner.
482    pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
483        Self { provider, config }
484    }
485
486    /// Generate a task plan for the given request.
487    pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
488        let prompt = build_plan_prompt(request, available_tools);
489
490        let chat_request = ChatRequest {
491            messages: vec![Message {
492                role: Role::User,
493                content: MessageContent::Text(prompt),
494            }],
495            tools: vec![],
496            system: Some(PLAN_SYSTEM_PROMPT.to_string()),
497            think: false,
498            max_tokens: self.config.max_tokens,
499            server_tools: vec![],
500            enable_caching: false,
501        };
502
503        let response = self.provider.chat(chat_request).await?;
504        let text = extract_text(&response);
505
506        parse_plan_response(request, &text)
507    }
508
509    /// Quick complexity assessment using fast model.
510    pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
511        let prompt = format!(
512            "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
513            request
514        );
515
516        let chat_request = ChatRequest {
517            messages: vec![Message {
518                role: Role::User,
519                content: MessageContent::Text(prompt),
520            }],
521            tools: vec![],
522            system: None,
523            think: false,
524            max_tokens: 50,
525            server_tools: vec![],
526            enable_caching: false,
527        };
528
529        let response = self.provider.chat(chat_request).await?;
530        let text = extract_text(&response).to_lowercase();
531
532        if text.contains("简单") || text.contains("simple") {
533            Ok(TaskComplexity::Simple)
534        } else if text.contains("复杂") || text.contains("complex") {
535            Ok(TaskComplexity::Complex)
536        } else {
537            Ok(TaskComplexity::Moderate)
538        }
539    }
540}
541
/// System prompt sent with planning requests.
///
/// Written in Chinese. It tells the model to act as a task-planning assistant,
/// to answer in JSON with the keys `complexity`, `approach`, `steps` and
/// `considerations`, and lists five planning principles.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
  "complexity": "simple|moderate|complex",
  "approach": "建议的方案(一句话)",
  "steps": [
    {
      "description": "步骤描述",
      "tools": ["需要的工具"],
      "optional": false
    }
  ],
  "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
567
/// Builds the user message for a planning request.
///
/// The message contains `request`, then `available_tools` joined with `", "`,
/// then an instruction to analyse the task and produce a JSON plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");
    format!(
        r#"用户请求:
{request}

可用工具:
{tools}

请分析任务并生成执行计划(JSON格式)。"#,
        request = request,
        tools = tool_list
    )
}
582
583/// Parse planning response into TaskPlan.
584fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
585    // Try to parse as JSON
586    if let Some(json_start) = text.find('{')
587        && let Some(json_end) = text.rfind('}')
588    {
589        let json_str = &text[json_start..=json_end];
590        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
591            return Ok(TaskPlan {
592                request: request.to_string(),
593                steps: parse_steps(&parsed["steps"]),
594                complexity: parse_complexity(&parsed["complexity"]),
595                approach: parsed["approach"]
596                    .as_str()
597                    .unwrap_or("直接执行")
598                    .to_string(),
599                considerations: parsed["considerations"]
600                    .as_array()
601                    .map(|arr| {
602                        arr.iter()
603                            .filter_map(|v| v.as_str().map(String::from))
604                            .collect()
605                    })
606                    .unwrap_or_default(),
607            });
608        }
609    }
610
611    // Fallback: create simple plan from text
612    Ok(TaskPlan {
613        request: request.to_string(),
614        steps: parse_steps_from_text(text),
615        complexity: TaskComplexity::Moderate,
616        approach: "按步骤执行".to_string(),
617        considerations: vec!["请检查执行结果".to_string()],
618    })
619}
620
621fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
622    value
623        .as_array()
624        .map(|arr| {
625            arr.iter()
626                .filter_map(|v| {
627                    Some(PlanStep {
628                        description: v["description"].as_str()?.to_string(),
629                        tools: v["tools"]
630                            .as_array()
631                            .map(|t| {
632                                t.iter()
633                                    .filter_map(|x| x.as_str().map(String::from))
634                                    .collect()
635                            })
636                            .unwrap_or_default(),
637                        optional: v["optional"].as_bool().unwrap_or(false),
638                    })
639                })
640                .collect()
641        })
642        .unwrap_or_default()
643}
644
645fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
646    match value.as_str().map(|s| s.to_lowercase()) {
647        Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
648        Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
649        _ => TaskComplexity::Moderate,
650    }
651}
652
653fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
654    text.lines()
655        .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
656        .take(5)
657        .map(|l| PlanStep {
658            description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
659            tools: vec!["read".to_string()],
660            optional: false,
661        })
662        .collect()
663}
664
665fn extract_text(response: &ChatResponse) -> String {
666    response
667        .content
668        .iter()
669        .filter_map(|block| {
670            if let ContentBlock::Text { text } = block {
671                Some(text.clone())
672            } else {
673                None
674            }
675        })
676        .collect::<Vec<_>>()
677        .join("\n")
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-optional step without tools.
    fn plain_step(description: &str) -> PlanStep {
        PlanStep {
            description: description.to_string(),
            tools: vec![],
            optional: false,
        }
    }

    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
    }

    // These expectations assume the CONTEXT_SIZE environment variable is not
    // set, because the lookup honours that variable before looking at the name.
    #[test]
    fn test_infer_context_size() {
        let expectations: &[(&str, u32)] = &[
            // Anthropic Claude models
            ("claude-sonnet-4", 200_000),
            ("claude-3-5-haiku", 200_000),
            ("claude-opus-4", 200_000),
            ("claude-2", 100_000),
            // 1M context variants
            ("claude-sonnet-4-1m", 1_000_000),
            ("claude-opus-4-7", 1_000_000),
            // OpenAI models
            ("gpt-4o", 128_000),
            ("gpt-4-turbo", 128_000),
            ("gpt-4", 8_192),
            ("gpt-4-32k", 32_768),
            ("gpt-3.5-turbo", 4_096),
            ("gpt-3.5-turbo-16k", 16_384),
            ("o1-preview", 200_000),
            ("o3-mini", 200_000),
            // Google Gemini models
            ("gemini-1.5-pro", 2_000_000),
            ("gemini-1.5-flash", 1_000_000),
            ("gemini-2.0-pro", 2_000_000),
            ("gemini-pro", 32_000),
            // Zhipu AI GLM models
            ("glm-5", 1_000_000),
            ("glm-4-long", 1_000_000),
            ("glm-4", 128_000),
            ("glm", 128_000),
            // DeepSeek models
            ("deepseek-v3", 128_000),
            ("deepseek-r1", 128_000),
            ("deepseek-chat", 64_000),
            ("deepseek", 64_000),
            // Alibaba Qwen models
            ("qwen2.5-turbo", 1_000_000),
            ("qwen-long", 1_000_000),
            ("qwen2.5-72b", 128_000),
            ("qwen3-32b", 128_000),
            ("qwen-max", 128_000),
            ("qwen2-7b", 32_000),
            ("qwen", 8_192),
            // Moonshot AI Kimi models
            ("kimi", 2_000_000),
            ("moonshot-v1-8k", 2_000_000),
            // Mistral models
            ("mistral-large", 128_000),
            ("mistral-medium", 32_000),
            ("mistral-7b", 32_000),
            // Llama models
            ("llama-3-70b", 128_000),
            ("llama-3-8b", 8_192),
            ("llama3-405b", 128_000),
        ];

        for &(model, expected) in expectations {
            assert_eq!(
                infer_context_size(model),
                Some(expected),
                "unexpected context size for {}",
                model
            );
        }
    }

    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    #[test]
    fn test_multi_model_config_with_main() {
        // with_main should make every role use the main model.
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        let per_role = [&config.main, &config.plan, &config.compress, &config.fast];
        for role_config in per_role.iter() {
            // Same model everywhere...
            assert_eq!(role_config.name, "claude-sonnet-4");
            // ...with thinking enabled (inherited from main).
            assert!(role_config.think);
        }
    }

    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // Override only the compress model.
        let replacement = ModelConfig::new("claude-3-5-haiku".to_string());
        config.set(ModelRole::Compress, replacement);

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4"); // still uses main
    }

    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![PlanStep {
                description: "读取文件".to_string(),
                tools: vec!["read".to_string()],
                optional: false,
            }],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        for expected in ["测试任务", "简单", "读取文件"].iter() {
            assert!(formatted.contains(expected));
        }
    }

    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![plain_step("步骤1"), plain_step("步骤2")],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }
}