// matrixcode_core/models.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role};
5
// Built-in default model identifiers, one per `ModelRole`.

/// Default model for the main (task-execution) role.
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model for the planning (task-decomposition) role.
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Default model for the context-compression role.
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Default model for the quick-operations role.
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
11
12/// Model role - what purpose this model serves.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum ModelRole {
16    /// Main model for task execution.
17    Main,
18    /// Planning model for task decomposition.
19    Plan,
20    /// Compression model for context summarization.
21    Compress,
22    /// Fast model for quick operations.
23    Fast,
24}
25
26impl ModelRole {
27    pub fn default_model(&self) -> &'static str {
28        match self {
29            ModelRole::Main => DEFAULT_MAIN_MODEL,
30            ModelRole::Plan => DEFAULT_PLAN_MODEL,
31            ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
32            ModelRole::Fast => DEFAULT_FAST_MODEL,
33        }
34    }
35}
36
37/// Configuration for a single model.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ModelConfig {
40    /// Model name/identifier.
41    pub name: String,
42    /// Maximum output tokens.
43    pub max_tokens: u32,
44    /// Whether to enable thinking for this model.
45    pub think: bool,
46    /// Estimated context window size.
47    pub context_size: Option<u32>,
48}
49
50impl ModelConfig {
51    pub fn new(name: String) -> Self {
52        Self {
53            name: name.clone(),
54            max_tokens: 16384,
55            think: true,
56            context_size: infer_context_size(&name),
57        }
58    }
59
60    /// Create with specific role defaults.
61    pub fn for_role(role: ModelRole) -> Self {
62        let name = role.default_model().to_string();
63        match role {
64            ModelRole::Main => Self::new(name),
65            ModelRole::Plan => Self::new(name),
66            ModelRole::Compress => Self {
67                name,
68                max_tokens: 1024,
69                think: false,
70                context_size: Some(200_000),
71            },
72            ModelRole::Fast => Self {
73                name,
74                max_tokens: 2048,
75                think: false,
76                context_size: Some(200_000),
77            },
78        }
79    }
80
81    pub fn display_name(&self) -> &str {
82        &self.name
83    }
84}
85
/// Infer a model's context window size (in tokens) from its name.
///
/// Matching is case-insensitive and substring-based, so provider-prefixed
/// names (e.g. `anthropic/claude-sonnet-4`) are recognised. Returns `None`
/// for model families not covered here.
fn infer_context_size(model: &str) -> Option<u32> {
    let m = model.to_ascii_lowercase();

    // Anthropic models.
    if m.contains("opus-4-7") || m.contains("opus-4.7") || m.contains("[1m]") {
        return Some(1_000_000);
    }
    if m.contains("claude") {
        return Some(200_000);
    }

    // OpenAI models. More specific GPT-4 variants are checked before the
    // bare `gpt-4` fallback so they are not misreported as 8K.
    if m.contains("gpt-4.1") {
        return Some(1_047_576);
    }
    if m.contains("gpt-4o") || m.contains("gpt-4-turbo") || m.contains("gpt-4.5") {
        return Some(128_000);
    }
    if m.contains("o1") {
        return Some(200_000);
    }
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    if m.contains("gpt-4") {
        // Dated GPT-4 Turbo snapshots (`gpt-4-1106-preview`,
        // `gpt-4-0125-preview`, `gpt-4-vision-preview`) are tagged
        // `-preview`; the original GPT-4 snapshots have an 8K window.
        return Some(if m.contains("preview") { 128_000 } else { 8_192 });
    }
    if m.contains("gpt-3.5") {
        return Some(16_384);
    }

    // DeepSeek models.
    if m.contains("deepseek-v3") || m.contains("deepseek-r1") {
        return Some(128_000);
    }
    if m.contains("deepseek") {
        return Some(64_000);
    }

    None
}
125
126/// Multi-model configuration manager.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct MultiModelConfig {
129    /// Main model for primary tasks.
130    pub main: ModelConfig,
131    /// Planning model for task decomposition.
132    pub plan: ModelConfig,
133    /// Compression model for context summarization.
134    pub compress: ModelConfig,
135    /// Fast model for quick operations.
136    pub fast: ModelConfig,
137}
138
139impl Default for MultiModelConfig {
140    fn default() -> Self {
141        Self {
142            main: ModelConfig::for_role(ModelRole::Main),
143            plan: ModelConfig::for_role(ModelRole::Plan),
144            compress: ModelConfig::for_role(ModelRole::Compress),
145            fast: ModelConfig::for_role(ModelRole::Fast),
146        }
147    }
148}
149
150impl MultiModelConfig {
151    /// Create with a main model, all other roles also use this model by default.
152    /// This ensures that if no specific model is configured for a role, it falls back to the main model.
153    pub fn with_main(main_model: String) -> Self {
154        let main_config = ModelConfig::new(main_model);
155        Self {
156            main: main_config.clone(),
157            plan: main_config.clone(),
158            compress: main_config.clone(),
159            fast: main_config,
160        }
161    }
162
163    /// Create where all roles use the same model.
164    pub fn unified(model: String) -> Self {
165        let config = ModelConfig::new(model);
166        Self {
167            main: config.clone(),
168            plan: config.clone(),
169            compress: config.clone(),
170            fast: config,
171        }
172    }
173
174    /// Get config for a specific role.
175    pub fn get(&self, role: ModelRole) -> &ModelConfig {
176        match role {
177            ModelRole::Main => &self.main,
178            ModelRole::Plan => &self.plan,
179            ModelRole::Compress => &self.compress,
180            ModelRole::Fast => &self.fast,
181        }
182    }
183
184    /// Set model for a specific role.
185    pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
186        match role {
187            ModelRole::Main => self.main = config,
188            ModelRole::Plan => self.plan = config,
189            ModelRole::Compress => self.compress = config,
190            ModelRole::Fast => self.fast = config,
191        }
192    }
193
194    /// Format for display.
195    pub fn format_summary(&self) -> String {
196        format!(
197            "main: {}, plan: {}, compress: {}, fast: {}",
198            self.main.display_name(),
199            self.plan.display_name(),
200            self.compress.display_name(),
201            self.fast.display_name()
202        )
203    }
204}
205
206/// Task complexity level.
207#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
208#[serde(rename_all = "snake_case")]
209pub enum TaskComplexity {
210    Simple,
211    Moderate,
212    Complex,
213}
214
215impl TaskComplexity {
216    pub fn display(&self) -> &'static str {
217        match self {
218            TaskComplexity::Simple => "简单",
219            TaskComplexity::Moderate => "中等",
220            TaskComplexity::Complex => "复杂",
221        }
222    }
223}
224
225/// Step difficulty level.
226#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
227#[serde(rename_all = "snake_case")]
228pub enum StepDifficulty {
229    Easy,
230    Medium,
231    Hard,
232}
233
234/// A single step in the task plan.
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct PlanStep {
237    /// Step description.
238    pub description: String,
239    /// Tools needed for this step.
240    pub tools: Vec<String>,
241    /// Whether this step is optional.
242    pub optional: bool,
243}
244
245/// Task plan generated by the planning model.
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct TaskPlan {
248    /// Original user request.
249    pub request: String,
250    /// Decomposed steps.
251    pub steps: Vec<PlanStep>,
252    /// Estimated complexity.
253    pub complexity: TaskComplexity,
254    /// Suggested approach summary.
255    pub approach: String,
256    /// Potential risks or considerations.
257    pub considerations: Vec<String>,
258}
259
260impl TaskPlan {
261    /// Format for display.
262    pub fn format(&self) -> String {
263        let mut output = String::new();
264        
265        output.push_str(&format!("任务分析: {}\n", self.request));
266        output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
267        output.push_str(&format!("建议方案: {}\n\n", self.approach));
268        
269        output.push_str("执行步骤:\n");
270        for (i, step) in self.steps.iter().enumerate() {
271            let marker = if step.optional { "[可选]" } else { "" };
272            output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
273            if !step.tools.is_empty() {
274                output.push_str(&format!("   工具: {}\n", step.tools.join(", ")));
275            }
276        }
277        
278        if !self.considerations.is_empty() {
279            output.push_str("\n注意事项:\n");
280            for c in &self.considerations {
281                output.push_str(&format!("• {}\n", c));
282            }
283        }
284        
285        output
286    }
287
288    /// Convert to todo items for the agent.
289    pub fn to_todo_items(&self) -> Vec<TodoItem> {
290        self.steps
291            .iter()
292            .enumerate()
293            .map(|(i, step)| TodoItem {
294                content: step.description.clone(),
295                active_form: format!("执行步骤 {}: {}", i + 1, step.description),
296                status: if i == 0 { "in_progress".to_string() } else { "pending".to_string() },
297            })
298            .collect()
299    }
300}
301
302/// Todo item for task tracking.
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct TodoItem {
305    pub content: String,
306    pub active_form: String,
307    pub status: String,
308}
309
310/// Planner for generating task plans using the plan model.
311pub struct Planner {
312    provider: Box<dyn Provider>,
313    config: ModelConfig,
314}
315
316impl Planner {
317    /// Create a new planner.
318    pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
319        Self { provider, config }
320    }
321
322    /// Generate a task plan for the given request.
323    pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
324        let prompt = build_plan_prompt(request, available_tools);
325        
326        let chat_request = ChatRequest {
327            messages: vec![Message {
328                role: Role::User,
329                content: MessageContent::Text(prompt),
330            }],
331            tools: vec![],
332            system: Some(PLAN_SYSTEM_PROMPT.to_string()),
333            think: false,
334            max_tokens: self.config.max_tokens,
335            server_tools: vec![],
336            enable_caching: false,
337        };
338
339        let response = self.provider.chat(chat_request).await?;
340        let text = extract_text(&response);
341        
342        parse_plan_response(request, &text)
343    }
344
345    /// Quick complexity assessment using fast model.
346    pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
347        let prompt = format!(
348            "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
349            request
350        );
351
352        let chat_request = ChatRequest {
353            messages: vec![Message {
354                role: Role::User,
355                content: MessageContent::Text(prompt),
356            }],
357            tools: vec![],
358            system: None,
359            think: false,
360            max_tokens: 50,
361            server_tools: vec![],
362            enable_caching: false,
363        };
364
365        let response = self.provider.chat(chat_request).await?;
366        let text = extract_text(&response).to_lowercase();
367        
368        if text.contains("简单") || text.contains("simple") {
369            Ok(TaskComplexity::Simple)
370        } else if text.contains("复杂") || text.contains("complex") {
371            Ok(TaskComplexity::Complex)
372        } else {
373            Ok(TaskComplexity::Moderate)
374        }
375    }
376}
377
/// System prompt attached to every request made by `Planner::plan`.
///
/// It asks the model (in Chinese) to reply with a JSON object containing
/// `complexity`, `approach`, `steps` (each with `description`, `tools`,
/// `optional`) and `considerations`. These are the keys
/// `parse_plan_response` reads. It also gives rough step-count guidance for
/// each complexity level.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
  "complexity": "simple|moderate|complex",
  "approach": "建议的方案(一句话)",
  "steps": [
    {
      "description": "步骤描述",
      "tools": ["需要的工具"],
      "optional": false
    }
  ],
  "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
403
/// Render the user message for a planning request.
///
/// The message quotes `request`, lists `available_tools` comma-separated and
/// asks for a JSON plan. An empty tool slice leaves that section blank.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tools = available_tools.join(", ");
    format!("用户请求:\n{request}\n\n可用工具:\n{tools}\n\n请分析任务并生成执行计划(JSON格式)。")
}
418
419/// Parse planning response into TaskPlan.
420fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
421    // Try to parse as JSON
422    if let Some(json_start) = text.find('{')
423        && let Some(json_end) = text.rfind('}') {
424            let json_str = &text[json_start..=json_end];
425            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
426                return Ok(TaskPlan {
427                    request: request.to_string(),
428                    steps: parse_steps(&parsed["steps"]),
429                    complexity: parse_complexity(&parsed["complexity"]),
430                    approach: parsed["approach"].as_str().unwrap_or("直接执行").to_string(),
431                    considerations: parsed["considerations"]
432                        .as_array()
433                        .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
434                        .unwrap_or_default(),
435                });
436            }
437        }
438    
439    // Fallback: create simple plan from text
440    Ok(TaskPlan {
441        request: request.to_string(),
442        steps: parse_steps_from_text(text),
443        complexity: TaskComplexity::Moderate,
444        approach: "按步骤执行".to_string(),
445        considerations: vec!["请检查执行结果".to_string()],
446    })
447}
448
449fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
450    value.as_array()
451        .map(|arr| arr.iter().filter_map(|v| {
452            Some(PlanStep {
453                description: v["description"].as_str()?.to_string(),
454                tools: v["tools"].as_array()
455                    .map(|t| t.iter().filter_map(|x| x.as_str().map(String::from)).collect())
456                    .unwrap_or_default(),
457                optional: v["optional"].as_bool().unwrap_or(false),
458            })
459        }).collect())
460        .unwrap_or_default()
461}
462
463fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
464    match value.as_str().map(|s| s.to_lowercase()) {
465        Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
466        Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
467        _ => TaskComplexity::Moderate,
468    }
469}
470
471fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
472    text.lines()
473        .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
474        .take(5)
475        .map(|l| PlanStep {
476            description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
477            tools: vec!["read".to_string()],
478            optional: false,
479        })
480        .collect()
481}
482
483fn extract_text(response: &ChatResponse) -> String {
484    response.content
485        .iter()
486        .filter_map(|block| {
487            if let ContentBlock::Text { text } = block {
488                Some(text.clone())
489            } else {
490                None
491            }
492        })
493        .collect::<Vec<_>>()
494        .join("\n")
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
        assert_eq!(compress.max_tokens, 1024);
        assert_eq!(compress.context_size, Some(200_000));

        let fast = ModelConfig::for_role(ModelRole::Fast);
        assert!(!fast.think);
        assert_eq!(fast.max_tokens, 2048);
    }

    #[test]
    fn test_infer_context_size() {
        assert_eq!(infer_context_size("claude-sonnet-4"), Some(200_000));
        assert_eq!(infer_context_size("gpt-4o"), Some(128_000));
        assert_eq!(infer_context_size("claude-3-5-haiku"), Some(200_000));
    }

    #[test]
    fn test_infer_context_size_other_families() {
        assert_eq!(infer_context_size("claude-sonnet-4[1m]"), Some(1_000_000));
        assert_eq!(infer_context_size("gpt-4-32k"), Some(32_768));
        assert_eq!(infer_context_size("gpt-3.5-turbo"), Some(16_384));
        assert_eq!(infer_context_size("deepseek-r1"), Some(128_000));
        assert_eq!(infer_context_size("deepseek-chat"), Some(64_000));
        assert_eq!(infer_context_size("llama-3"), None);
    }

    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    #[test]
    fn test_multi_model_config_with_main() {
        // with_main should make all roles use the main model.
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // All roles should use the same model.
        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-sonnet-4");
        assert_eq!(config.fast.name, "claude-sonnet-4");

        // All should have thinking enabled (inherited from main).
        assert!(config.main.think);
        assert!(config.plan.think);
        assert!(config.compress.think);
        assert!(config.fast.think);
    }

    #[test]
    fn test_unified_matches_with_main() {
        let unified = MultiModelConfig::unified("gpt-4o".to_string());
        let with_main = MultiModelConfig::with_main("gpt-4o".to_string());
        assert_eq!(unified.format_summary(), with_main.format_summary());

        for role in [ModelRole::Main, ModelRole::Plan, ModelRole::Compress, ModelRole::Fast] {
            assert_eq!(unified.get(role).name, "gpt-4o");
            assert_eq!(unified.get(role).context_size, Some(128_000));
        }
    }

    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // Override the compress model only.
        config.set(ModelRole::Compress, ModelConfig::new("claude-3-5-haiku".to_string()));

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4"); // Still uses main.
    }

    #[test]
    fn test_set_then_get() {
        let mut config = MultiModelConfig::default();
        config.set(ModelRole::Fast, ModelConfig::new("deepseek-chat".to_string()));

        assert_eq!(config.get(ModelRole::Fast).name, "deepseek-chat");
        assert_eq!(config.get(ModelRole::Fast).context_size, Some(64_000));
        assert!(config.get(ModelRole::Main).name.contains("sonnet"));
    }

    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![PlanStep {
                description: "读取文件".to_string(),
                tools: vec!["read".to_string()],
                optional: false,
            }],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        assert!(formatted.contains("测试任务"));
        assert!(formatted.contains("简单"));
        assert!(formatted.contains("读取文件"));
    }

    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![
                PlanStep { description: "步骤1".to_string(), tools: vec![], optional: false },
                PlanStep { description: "步骤2".to_string(), tools: vec![], optional: false },
            ],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }

    #[test]
    fn test_parse_plan_response_json_inside_prose() {
        let text = r#"计划如下:
```json
{"complexity": "complex", "approach": "分模块重构", "steps": [{"description": "分析依赖", "tools": ["grep", "read"]}, {"description": "补充测试", "optional": true}], "considerations": ["保持接口兼容"]}
```"#;
        let plan = parse_plan_response("重构", text).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Complex);
        assert_eq!(plan.approach, "分模块重构");
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].tools, vec!["grep", "read"]);
        assert!(!plan.steps[0].optional);
        assert!(plan.steps[1].optional);
        assert!(plan.steps[1].tools.is_empty());
        assert_eq!(plan.considerations, vec!["保持接口兼容"]);
    }

    #[test]
    fn test_parse_plan_response_text_fallback() {
        let text = "步骤如下:\n1. 读取 文件\n2. 修改代码\n完成";
        let plan = parse_plan_response("任务", text).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Moderate);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].description, "读取 文件");
        assert_eq!(plan.steps[1].description, "修改代码");
        assert_eq!(plan.steps[0].tools, vec!["read"]);
    }

    #[test]
    fn test_parse_complexity_labels() {
        use serde_json::Value;

        assert_eq!(parse_complexity(&Value::String("Complex".to_string())), TaskComplexity::Complex);
        assert_eq!(parse_complexity(&Value::String("简单".to_string())), TaskComplexity::Simple);
        assert_eq!(parse_complexity(&Value::String("medium".to_string())), TaskComplexity::Moderate);
        assert_eq!(parse_complexity(&Value::Null), TaskComplexity::Moderate);
    }

    #[test]
    fn test_build_plan_prompt() {
        let prompt = build_plan_prompt("修复 bug", &["read", "edit"]);
        assert!(prompt.contains("修复 bug"));
        assert!(prompt.contains("read, edit"));
    }
}