Skip to main content

matrixcode_core/
models.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::constants::{DEFAULT_MAX_TOKENS, COMPRESS_MAX_TOKENS, FAST_MAX_TOKENS};
5use crate::providers::{
6    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
7};
8
/// Default model identifier for [`ModelRole::Main`].
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for [`ModelRole::Plan`] (same value as the main default).
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model identifier for [`ModelRole::Compress`].
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Default model identifier for [`ModelRole::Fast`] (same value as the compress default).
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
14
15/// Model role - what purpose this model serves.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ModelRole {
19    /// Main model for task execution.
20    Main,
21    /// Planning model for task decomposition.
22    Plan,
23    /// Compression model for context summarization.
24    Compress,
25    /// Fast model for quick operations.
26    Fast,
27}
28
29impl ModelRole {
30    pub fn default_model(&self) -> &'static str {
31        match self {
32            ModelRole::Main => DEFAULT_MAIN_MODEL,
33            ModelRole::Plan => DEFAULT_PLAN_MODEL,
34            ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
35            ModelRole::Fast => DEFAULT_FAST_MODEL,
36        }
37    }
38}
39
40/// Configuration for a single model.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelConfig {
43    /// Model name/identifier.
44    pub name: String,
45    /// Maximum output tokens.
46    pub max_tokens: u32,
47    /// Whether to enable thinking for this model.
48    pub think: bool,
49    /// Estimated context window size.
50    pub context_size: Option<u32>,
51}
52
53impl ModelConfig {
54    pub fn new(name: String) -> Self {
55        Self {
56            name: name.clone(),
57            max_tokens: DEFAULT_MAX_TOKENS,
58            think: true,
59            context_size: infer_context_size(&name),
60        }
61    }
62
63    /// Create with specific role defaults.
64    pub fn for_role(role: ModelRole) -> Self {
65        let name = role.default_model().to_string();
66        match role {
67            ModelRole::Main => Self::new(name),
68            ModelRole::Plan => Self::new(name),
69            ModelRole::Compress => Self {
70                name,
71                max_tokens: COMPRESS_MAX_TOKENS,
72                think: false,
73                context_size: Some(200_000),
74            },
75            ModelRole::Fast => Self {
76                name,
77                max_tokens: FAST_MAX_TOKENS,
78                think: false,
79                context_size: Some(200_000),
80            },
81        }
82    }
83
84    pub fn display_name(&self) -> &str {
85        &self.name
86    }
87}
88
/// Infer the context window size (in tokens) from a model name.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable takes
/// precedence over every built-in rule so users can override the guess.
///
/// Otherwise the name is lower-cased (ASCII) and checked against substring
/// rules from top to bottom; the first rule that matches wins. Returns `None`
/// when no rule matches.
pub fn context_window_for(model: &str) -> Option<u32> {
    /// True when `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    // User override: must parse (after trimming) as a strictly positive u32.
    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if env_override.is_some() {
        return env_override;
    }

    let m = model.to_ascii_lowercase();

    // Anthropic models. Note the "1m" marker is checked before anything else,
    // so it applies to any model name that contains it.
    if contains_any(&m, &["1m", "opus-4-7", "opus-4.7"]) {
        return Some(1_000_000);
    }
    if contains_any(
        &m,
        &[
            "claude-3",
            "claude-4",
            "claude-opus",
            "claude-sonnet",
            "claude-haiku",
        ],
    ) {
        return Some(200_000);
    }
    if contains_any(&m, &["claude-2", "claude-instant"]) {
        return Some(100_000);
    }

    // OpenAI models
    if contains_any(&m, &["gpt-4o", "gpt-4-turbo"]) {
        return Some(128_000);
    }
    if contains_any(&m, &["o1", "o3", "o4"]) {
        return Some(200_000);
    }
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    // "gpt-4o" has already returned above, so only "turbo" needs excluding.
    // Rejecting every name containing the letter 'o' would wrongly skip
    // provider-prefixed names such as "openai/gpt-4".
    if m.contains("gpt-4") && !m.contains("turbo") {
        return Some(8_192);
    }
    if m.contains("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if m.contains("gpt-3.5") {
        return Some(4_096);
    }

    // DeepSeek models
    if contains_any(&m, &["deepseek-v3", "deepseek-r1"]) {
        return Some(128_000);
    }
    if m.contains("deepseek") {
        return Some(64_000);
    }

    // Kimi models
    if m.contains("kimi") {
        return Some(128_000);
    }

    // Qwen models ("qwen2.5" also covers "qwen2.5-72b")
    if m.contains("qwen") {
        let size = if contains_any(&m, &["qwen-max", "qwen2.5"]) {
            128_000
        } else if m.contains("qwen2") {
            32_000
        } else {
            8_192
        };
        return Some(size);
    }

    // Llama models
    if contains_any(&m, &["llama-3", "llama3"]) {
        let size = if contains_any(&m, &["70b", "405b"]) {
            128_000
        } else {
            8_192
        };
        return Some(size);
    }

    // GLM models (Zhipu AI)
    if m.contains("glm") {
        return Some(128_000);
    }

    None
}
177
/// Legacy internal alias for [`context_window_for`].
///
/// Forwards `model` unchanged and returns whatever `context_window_for`
/// returns, including its `CONTEXT_SIZE` environment override.
fn infer_context_size(model: &str) -> Option<u32> {
    context_window_for(model)
}
182
183/// Multi-model configuration manager.
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct MultiModelConfig {
186    /// Main model for primary tasks.
187    pub main: ModelConfig,
188    /// Planning model for task decomposition.
189    pub plan: ModelConfig,
190    /// Compression model for context summarization.
191    pub compress: ModelConfig,
192    /// Fast model for quick operations.
193    pub fast: ModelConfig,
194}
195
196impl Default for MultiModelConfig {
197    fn default() -> Self {
198        Self {
199            main: ModelConfig::for_role(ModelRole::Main),
200            plan: ModelConfig::for_role(ModelRole::Plan),
201            compress: ModelConfig::for_role(ModelRole::Compress),
202            fast: ModelConfig::for_role(ModelRole::Fast),
203        }
204    }
205}
206
207impl MultiModelConfig {
208    /// Create with a main model, all other roles also use this model by default.
209    /// This ensures that if no specific model is configured for a role, it falls back to the main model.
210    pub fn with_main(main_model: String) -> Self {
211        let main_config = ModelConfig::new(main_model);
212        Self {
213            main: main_config.clone(),
214            plan: main_config.clone(),
215            compress: main_config.clone(),
216            fast: main_config,
217        }
218    }
219
220    /// Create where all roles use the same model.
221    pub fn unified(model: String) -> Self {
222        let config = ModelConfig::new(model);
223        Self {
224            main: config.clone(),
225            plan: config.clone(),
226            compress: config.clone(),
227            fast: config,
228        }
229    }
230
231    /// Get config for a specific role.
232    pub fn get(&self, role: ModelRole) -> &ModelConfig {
233        match role {
234            ModelRole::Main => &self.main,
235            ModelRole::Plan => &self.plan,
236            ModelRole::Compress => &self.compress,
237            ModelRole::Fast => &self.fast,
238        }
239    }
240
241    /// Set model for a specific role.
242    pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
243        match role {
244            ModelRole::Main => self.main = config,
245            ModelRole::Plan => self.plan = config,
246            ModelRole::Compress => self.compress = config,
247            ModelRole::Fast => self.fast = config,
248        }
249    }
250
251    /// Format for display.
252    pub fn format_summary(&self) -> String {
253        format!(
254            "main: {}, plan: {}, compress: {}, fast: {}",
255            self.main.display_name(),
256            self.plan.display_name(),
257            self.compress.display_name(),
258            self.fast.display_name()
259        )
260    }
261}
262
263/// Task complexity level.
264#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
265#[serde(rename_all = "snake_case")]
266pub enum TaskComplexity {
267    Simple,
268    Moderate,
269    Complex,
270}
271
272impl TaskComplexity {
273    pub fn display(&self) -> &'static str {
274        match self {
275            TaskComplexity::Simple => "简单",
276            TaskComplexity::Moderate => "中等",
277            TaskComplexity::Complex => "复杂",
278        }
279    }
280}
281
/// Difficulty level of a single step.
///
/// Variants are (de)serialized in snake_case: `easy`, `medium`, `hard`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepDifficulty {
    Easy,
    Medium,
    Hard,
}
290
/// A single step in a [`TaskPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Human-readable description of what the step does.
    pub description: String,
    /// Names of the tools needed for this step (may be empty).
    pub tools: Vec<String>,
    /// Whether this step is optional.
    pub optional: bool,
}
301
302/// Task plan generated by the planning model.
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct TaskPlan {
305    /// Original user request.
306    pub request: String,
307    /// Decomposed steps.
308    pub steps: Vec<PlanStep>,
309    /// Estimated complexity.
310    pub complexity: TaskComplexity,
311    /// Suggested approach summary.
312    pub approach: String,
313    /// Potential risks or considerations.
314    pub considerations: Vec<String>,
315}
316
317impl TaskPlan {
318    /// Format for display.
319    pub fn format(&self) -> String {
320        let mut output = String::new();
321
322        output.push_str(&format!("任务分析: {}\n", self.request));
323        output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
324        output.push_str(&format!("建议方案: {}\n\n", self.approach));
325
326        output.push_str("执行步骤:\n");
327        for (i, step) in self.steps.iter().enumerate() {
328            let marker = if step.optional { "[可选]" } else { "" };
329            output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
330            if !step.tools.is_empty() {
331                output.push_str(&format!("   工具: {}\n", step.tools.join(", ")));
332            }
333        }
334
335        if !self.considerations.is_empty() {
336            output.push_str("\n注意事项:\n");
337            for c in &self.considerations {
338                output.push_str(&format!("• {}\n", c));
339            }
340        }
341
342        output
343    }
344
345    /// Convert to todo items for the agent.
346    pub fn to_todo_items(&self) -> Vec<TodoItem> {
347        self.steps
348            .iter()
349            .enumerate()
350            .map(|(i, step)| TodoItem {
351                content: step.description.clone(),
352                active_form: format!("执行步骤 {}: {}", i + 1, step.description),
353                status: if i == 0 {
354                    "in_progress".to_string()
355                } else {
356                    "pending".to_string()
357                },
358            })
359            .collect()
360    }
361}
362
/// Todo item for task tracking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoItem {
    /// Text of the item.
    pub content: String,
    /// Alternative wording of the item, stored alongside `content`.
    pub active_form: String,
    /// Status as a free-form string; code in this file produces
    /// `"in_progress"` and `"pending"`.
    pub status: String,
}
370
371/// Planner for generating task plans using the plan model.
372pub struct Planner {
373    provider: Box<dyn Provider>,
374    config: ModelConfig,
375}
376
377impl Planner {
378    /// Create a new planner.
379    pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
380        Self { provider, config }
381    }
382
383    /// Generate a task plan for the given request.
384    pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
385        let prompt = build_plan_prompt(request, available_tools);
386
387        let chat_request = ChatRequest {
388            messages: vec![Message {
389                role: Role::User,
390                content: MessageContent::Text(prompt),
391            }],
392            tools: vec![],
393            system: Some(PLAN_SYSTEM_PROMPT.to_string()),
394            think: false,
395            max_tokens: self.config.max_tokens,
396            server_tools: vec![],
397            enable_caching: false,
398        };
399
400        let response = self.provider.chat(chat_request).await?;
401        let text = extract_text(&response);
402
403        parse_plan_response(request, &text)
404    }
405
406    /// Quick complexity assessment using fast model.
407    pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
408        let prompt = format!(
409            "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
410            request
411        );
412
413        let chat_request = ChatRequest {
414            messages: vec![Message {
415                role: Role::User,
416                content: MessageContent::Text(prompt),
417            }],
418            tools: vec![],
419            system: None,
420            think: false,
421            max_tokens: 50,
422            server_tools: vec![],
423            enable_caching: false,
424        };
425
426        let response = self.provider.chat(chat_request).await?;
427        let text = extract_text(&response).to_lowercase();
428
429        if text.contains("简单") || text.contains("simple") {
430            Ok(TaskComplexity::Simple)
431        } else if text.contains("复杂") || text.contains("complex") {
432            Ok(TaskComplexity::Complex)
433        } else {
434            Ok(TaskComplexity::Moderate)
435        }
436    }
437}
438
/// System prompt (Chinese) for planning requests.
///
/// It tells the model to act as a task-planning assistant and to answer with
/// a JSON object containing `complexity`, `approach`, `steps` (each with
/// `description`, `tools`, `optional`) and `considerations`, followed by
/// guidance on how many steps each complexity level should have.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
  "complexity": "simple|moderate|complex",
  "approach": "建议的方案(一句话)",
  "steps": [
    {
      "description": "步骤描述",
      "tools": ["需要的工具"],
      "optional": false
    }
  ],
  "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
464
/// Assembles the user prompt for a planning request.
///
/// The prompt contains the request text, the available tools joined with
/// `", "`, and a closing instruction asking for a JSON plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");

    let mut prompt = String::new();
    prompt.push_str("用户请求:\n");
    prompt.push_str(request);
    prompt.push_str("\n\n可用工具:\n");
    prompt.push_str(&tool_list);
    prompt.push_str("\n\n请分析任务并生成执行计划(JSON格式)。");
    prompt
}
479
480/// Parse planning response into TaskPlan.
481fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
482    // Try to parse as JSON
483    if let Some(json_start) = text.find('{')
484        && let Some(json_end) = text.rfind('}')
485    {
486        let json_str = &text[json_start..=json_end];
487        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
488            return Ok(TaskPlan {
489                request: request.to_string(),
490                steps: parse_steps(&parsed["steps"]),
491                complexity: parse_complexity(&parsed["complexity"]),
492                approach: parsed["approach"]
493                    .as_str()
494                    .unwrap_or("直接执行")
495                    .to_string(),
496                considerations: parsed["considerations"]
497                    .as_array()
498                    .map(|arr| {
499                        arr.iter()
500                            .filter_map(|v| v.as_str().map(String::from))
501                            .collect()
502                    })
503                    .unwrap_or_default(),
504            });
505        }
506    }
507
508    // Fallback: create simple plan from text
509    Ok(TaskPlan {
510        request: request.to_string(),
511        steps: parse_steps_from_text(text),
512        complexity: TaskComplexity::Moderate,
513        approach: "按步骤执行".to_string(),
514        considerations: vec!["请检查执行结果".to_string()],
515    })
516}
517
518fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
519    value
520        .as_array()
521        .map(|arr| {
522            arr.iter()
523                .filter_map(|v| {
524                    Some(PlanStep {
525                        description: v["description"].as_str()?.to_string(),
526                        tools: v["tools"]
527                            .as_array()
528                            .map(|t| {
529                                t.iter()
530                                    .filter_map(|x| x.as_str().map(String::from))
531                                    .collect()
532                            })
533                            .unwrap_or_default(),
534                        optional: v["optional"].as_bool().unwrap_or(false),
535                    })
536                })
537                .collect()
538        })
539        .unwrap_or_default()
540}
541
542fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
543    match value.as_str().map(|s| s.to_lowercase()) {
544        Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
545        Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
546        _ => TaskComplexity::Moderate,
547    }
548}
549
550fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
551    text.lines()
552        .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
553        .take(5)
554        .map(|l| PlanStep {
555            description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
556            tools: vec!["read".to_string()],
557            optional: false,
558        })
559        .collect()
560}
561
562fn extract_text(response: &ChatResponse) -> String {
563    response
564        .content
565        .iter()
566        .filter_map(|block| {
567            if let ContentBlock::Text { text } = block {
568                Some(text.clone())
569            } else {
570                None
571            }
572        })
573        .collect::<Vec<_>>()
574        .join("\n")
575}
576
/// Unit tests for model configuration, context-size inference, plan
/// formatting and plan parsing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Role defaults: main is a thinking Claude model, compress is a
    /// non-thinking Haiku model.
    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
    }

    /// Context-size rules for a few well-known names.
    ///
    /// NOTE(review): `context_window_for` returns the `CONTEXT_SIZE`
    /// environment variable when it is set to a positive integer, so these
    /// assertions only hold when that variable is unset.
    #[test]
    fn test_infer_context_size() {
        assert_eq!(infer_context_size("claude-sonnet-4"), Some(200_000));
        assert_eq!(infer_context_size("gpt-4o"), Some(128_000));
        assert_eq!(infer_context_size("claude-3-5-haiku"), Some(200_000));
        // 1M context models
        assert_eq!(infer_context_size("claude-sonnet-4-1m"), Some(1_000_000));
        assert_eq!(infer_context_size("claude-opus-4-7"), Some(1_000_000));
    }

    /// The default multi-model configuration uses Sonnet for main and Haiku
    /// for compress.
    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    /// `with_main` copies the main model's configuration to every role.
    #[test]
    fn test_multi_model_config_with_main() {
        // with_main should make all roles use the main model
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // All roles should use the same model
        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-sonnet-4");
        assert_eq!(config.fast.name, "claude-sonnet-4");

        // All should have thinking enabled (inherited from main)
        assert!(config.main.think);
        assert!(config.plan.think);
        assert!(config.compress.think);
        assert!(config.fast.think);
    }

    /// `set` replaces only the targeted role's configuration.
    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // Override compress model
        config.set(
            ModelRole::Compress,
            ModelConfig::new("claude-3-5-haiku".to_string()),
        );

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4"); // Still uses main
    }

    /// The formatted plan mentions the request, the complexity label and the
    /// step description.
    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![PlanStep {
                description: "读取文件".to_string(),
                tools: vec!["read".to_string()],
                optional: false,
            }],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        assert!(formatted.contains("测试任务"));
        assert!(formatted.contains("简单"));
        assert!(formatted.contains("读取文件"));
    }

    /// Each complexity level maps to its Chinese display label.
    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    /// The first todo item is in progress and the rest are pending.
    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![
                PlanStep {
                    description: "步骤1".to_string(),
                    tools: vec![],
                    optional: false,
                },
                PlanStep {
                    description: "步骤2".to_string(),
                    tools: vec![],
                    optional: false,
                },
            ],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    /// A well-formed JSON reply is parsed into complexity and steps.
    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }
}
702        assert_eq!(plan.complexity, TaskComplexity::Simple);
703        assert_eq!(plan.steps.len(), 1);
704        assert_eq!(plan.steps[0].description, "read file");
705    }
706}