Skip to main content

matrixcode_core/
models.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use crate::providers::{
5    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role,
6};
7
/// Built-in default model for the main (task execution) role.
pub const DEFAULT_MAIN_MODEL: &str = "claude-sonnet-4-20250514";
/// Built-in default model for the planning (task decomposition) role.
pub const DEFAULT_PLAN_MODEL: &str = "claude-sonnet-4-20250514";
/// Built-in default model for the context-compression (summarization) role.
pub const DEFAULT_COMPRESS_MODEL: &str = "claude-3-5-haiku-20241022";
/// Built-in default model for the fast (quick operations) role.
pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
13
14/// Model role - what purpose this model serves.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum ModelRole {
18    /// Main model for task execution.
19    Main,
20    /// Planning model for task decomposition.
21    Plan,
22    /// Compression model for context summarization.
23    Compress,
24    /// Fast model for quick operations.
25    Fast,
26}
27
28impl ModelRole {
29    pub fn default_model(&self) -> &'static str {
30        match self {
31            ModelRole::Main => DEFAULT_MAIN_MODEL,
32            ModelRole::Plan => DEFAULT_PLAN_MODEL,
33            ModelRole::Compress => DEFAULT_COMPRESS_MODEL,
34            ModelRole::Fast => DEFAULT_FAST_MODEL,
35        }
36    }
37}
38
39/// Configuration for a single model.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ModelConfig {
42    /// Model name/identifier.
43    pub name: String,
44    /// Maximum output tokens.
45    pub max_tokens: u32,
46    /// Whether to enable thinking for this model.
47    pub think: bool,
48    /// Estimated context window size.
49    pub context_size: Option<u32>,
50}
51
52impl ModelConfig {
53    pub fn new(name: String) -> Self {
54        Self {
55            name: name.clone(),
56            max_tokens: 16384,
57            think: true,
58            context_size: infer_context_size(&name),
59        }
60    }
61
62    /// Create with specific role defaults.
63    pub fn for_role(role: ModelRole) -> Self {
64        let name = role.default_model().to_string();
65        match role {
66            ModelRole::Main => Self::new(name),
67            ModelRole::Plan => Self::new(name),
68            ModelRole::Compress => Self {
69                name,
70                max_tokens: 1024,
71                think: false,
72                context_size: Some(200_000),
73            },
74            ModelRole::Fast => Self {
75                name,
76                max_tokens: 2048,
77                think: false,
78                context_size: Some(200_000),
79            },
80        }
81    }
82
83    pub fn display_name(&self) -> &str {
84        &self.name
85    }
86}
87
/// Infer the context window size (in tokens) for a model from its name.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable takes precedence
/// over name-based inference so users can override it. Unparsable or zero values
/// are ignored. Otherwise the name is lower-cased and matched against known model
/// families in order, so more specific patterns must come before general ones.
///
/// Returns `None` for unrecognised models.
pub fn context_window_for(model: &str) -> Option<u32> {
    // Allow user override via environment variable.
    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if env_override.is_some() {
        return env_override;
    }

    let m = model.to_ascii_lowercase();

    // Anthropic models. Names tagged "[1m]" are treated as 1M-token variants.
    if m.contains("[1m]") || m.contains("opus-4-7") || m.contains("opus-4.7") {
        return Some(1_000_000);
    }
    if m.contains("claude-3")
        || m.contains("claude-4")
        || m.contains("claude-opus")
        || m.contains("claude-sonnet")
        || m.contains("claude-haiku")
    {
        return Some(200_000);
    }
    if m.contains("claude-2") || m.contains("claude-instant") {
        return Some(100_000);
    }

    // OpenAI models
    if m.contains("gpt-4o") || m.contains("gpt-4-turbo") {
        return Some(128_000);
    }
    if is_openai_o_series(&m) {
        return Some(200_000);
    }
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    if m.contains("gpt-4") && !m.contains("turbo") && !m.contains('o') {
        return Some(8_192);
    }
    if m.contains("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if m.contains("gpt-3.5") {
        return Some(4_096);
    }

    // DeepSeek models
    if m.contains("deepseek-v3") || m.contains("deepseek-r1") {
        return Some(128_000);
    }
    if m.contains("deepseek") {
        return Some(64_000);
    }

    // Kimi models
    if m.contains("kimi") {
        return Some(128_000);
    }

    // Qwen models ("qwen2.5" also covers sized variants such as "qwen2.5-72b").
    if m.contains("qwen") {
        if m.contains("qwen-max") || m.contains("qwen2.5") {
            return Some(128_000);
        }
        if m.contains("qwen2") {
            return Some(32_000);
        }
        return Some(8_192);
    }

    // Llama models
    if m.contains("llama-3") || m.contains("llama3") {
        if m.contains("70b") || m.contains("405b") {
            return Some(128_000);
        }
        return Some(8_192);
    }

    // GLM models (Zhipu AI)
    if m.contains("glm") {
        return Some(128_000);
    }

    None
}

/// Whether `m` (already lower-cased) names an OpenAI o-series reasoning model.
///
/// Matches a whole name token equal to `o1`, `o3` or `o4`. Tokens are separated by
/// any non-alphanumeric character (e.g. "o3-mini", "openai/o1", "azure:o4-mini").
/// A plain substring test would also match names like "demo1" or "pro3". Because
/// this check runs early, such a match would shadow the families tested after it.
fn is_openai_o_series(m: &str) -> bool {
    m.split(|c: char| !c.is_ascii_alphanumeric())
        .any(|token| matches!(token, "o1" | "o3" | "o4"))
}
176
177/// Legacy alias for context_window_for (internal use).
178fn infer_context_size(model: &str) -> Option<u32> {
179    context_window_for(model)
180}
181
182/// Multi-model configuration manager.
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct MultiModelConfig {
185    /// Main model for primary tasks.
186    pub main: ModelConfig,
187    /// Planning model for task decomposition.
188    pub plan: ModelConfig,
189    /// Compression model for context summarization.
190    pub compress: ModelConfig,
191    /// Fast model for quick operations.
192    pub fast: ModelConfig,
193}
194
195impl Default for MultiModelConfig {
196    fn default() -> Self {
197        Self {
198            main: ModelConfig::for_role(ModelRole::Main),
199            plan: ModelConfig::for_role(ModelRole::Plan),
200            compress: ModelConfig::for_role(ModelRole::Compress),
201            fast: ModelConfig::for_role(ModelRole::Fast),
202        }
203    }
204}
205
206impl MultiModelConfig {
207    /// Create with a main model, all other roles also use this model by default.
208    /// This ensures that if no specific model is configured for a role, it falls back to the main model.
209    pub fn with_main(main_model: String) -> Self {
210        let main_config = ModelConfig::new(main_model);
211        Self {
212            main: main_config.clone(),
213            plan: main_config.clone(),
214            compress: main_config.clone(),
215            fast: main_config,
216        }
217    }
218
219    /// Create where all roles use the same model.
220    pub fn unified(model: String) -> Self {
221        let config = ModelConfig::new(model);
222        Self {
223            main: config.clone(),
224            plan: config.clone(),
225            compress: config.clone(),
226            fast: config,
227        }
228    }
229
230    /// Get config for a specific role.
231    pub fn get(&self, role: ModelRole) -> &ModelConfig {
232        match role {
233            ModelRole::Main => &self.main,
234            ModelRole::Plan => &self.plan,
235            ModelRole::Compress => &self.compress,
236            ModelRole::Fast => &self.fast,
237        }
238    }
239
240    /// Set model for a specific role.
241    pub fn set(&mut self, role: ModelRole, config: ModelConfig) {
242        match role {
243            ModelRole::Main => self.main = config,
244            ModelRole::Plan => self.plan = config,
245            ModelRole::Compress => self.compress = config,
246            ModelRole::Fast => self.fast = config,
247        }
248    }
249
250    /// Format for display.
251    pub fn format_summary(&self) -> String {
252        format!(
253            "main: {}, plan: {}, compress: {}, fast: {}",
254            self.main.display_name(),
255            self.plan.display_name(),
256            self.compress.display_name(),
257            self.fast.display_name()
258        )
259    }
260}
261
262/// Task complexity level.
263#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
264#[serde(rename_all = "snake_case")]
265pub enum TaskComplexity {
266    Simple,
267    Moderate,
268    Complex,
269}
270
271impl TaskComplexity {
272    pub fn display(&self) -> &'static str {
273        match self {
274            TaskComplexity::Simple => "简单",
275            TaskComplexity::Moderate => "中等",
276            TaskComplexity::Complex => "复杂",
277        }
278    }
279}
280
281/// Step difficulty level.
282#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
283#[serde(rename_all = "snake_case")]
284pub enum StepDifficulty {
285    Easy,
286    Medium,
287    Hard,
288}
289
290/// A single step in the task plan.
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct PlanStep {
293    /// Step description.
294    pub description: String,
295    /// Tools needed for this step.
296    pub tools: Vec<String>,
297    /// Whether this step is optional.
298    pub optional: bool,
299}
300
301/// Task plan generated by the planning model.
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct TaskPlan {
304    /// Original user request.
305    pub request: String,
306    /// Decomposed steps.
307    pub steps: Vec<PlanStep>,
308    /// Estimated complexity.
309    pub complexity: TaskComplexity,
310    /// Suggested approach summary.
311    pub approach: String,
312    /// Potential risks or considerations.
313    pub considerations: Vec<String>,
314}
315
316impl TaskPlan {
317    /// Format for display.
318    pub fn format(&self) -> String {
319        let mut output = String::new();
320
321        output.push_str(&format!("任务分析: {}\n", self.request));
322        output.push_str(&format!("复杂度: {}\n", self.complexity.display()));
323        output.push_str(&format!("建议方案: {}\n\n", self.approach));
324
325        output.push_str("执行步骤:\n");
326        for (i, step) in self.steps.iter().enumerate() {
327            let marker = if step.optional { "[可选]" } else { "" };
328            output.push_str(&format!("{}. {} {}\n", i + 1, step.description, marker));
329            if !step.tools.is_empty() {
330                output.push_str(&format!("   工具: {}\n", step.tools.join(", ")));
331            }
332        }
333
334        if !self.considerations.is_empty() {
335            output.push_str("\n注意事项:\n");
336            for c in &self.considerations {
337                output.push_str(&format!("• {}\n", c));
338            }
339        }
340
341        output
342    }
343
344    /// Convert to todo items for the agent.
345    pub fn to_todo_items(&self) -> Vec<TodoItem> {
346        self.steps
347            .iter()
348            .enumerate()
349            .map(|(i, step)| TodoItem {
350                content: step.description.clone(),
351                active_form: format!("执行步骤 {}: {}", i + 1, step.description),
352                status: if i == 0 {
353                    "in_progress".to_string()
354                } else {
355                    "pending".to_string()
356                },
357            })
358            .collect()
359    }
360}
361
362/// Todo item for task tracking.
363#[derive(Debug, Clone, Serialize, Deserialize)]
364pub struct TodoItem {
365    pub content: String,
366    pub active_form: String,
367    pub status: String,
368}
369
370/// Planner for generating task plans using the plan model.
371pub struct Planner {
372    provider: Box<dyn Provider>,
373    config: ModelConfig,
374}
375
376impl Planner {
377    /// Create a new planner.
378    pub fn new(provider: Box<dyn Provider>, config: ModelConfig) -> Self {
379        Self { provider, config }
380    }
381
382    /// Generate a task plan for the given request.
383    pub async fn plan(&self, request: &str, available_tools: &[&str]) -> Result<TaskPlan> {
384        let prompt = build_plan_prompt(request, available_tools);
385
386        let chat_request = ChatRequest {
387            messages: vec![Message {
388                role: Role::User,
389                content: MessageContent::Text(prompt),
390            }],
391            tools: vec![],
392            system: Some(PLAN_SYSTEM_PROMPT.to_string()),
393            think: false,
394            max_tokens: self.config.max_tokens,
395            server_tools: vec![],
396            enable_caching: false,
397        };
398
399        let response = self.provider.chat(chat_request).await?;
400        let text = extract_text(&response);
401
402        parse_plan_response(request, &text)
403    }
404
405    /// Quick complexity assessment using fast model.
406    pub async fn assess_complexity(&self, request: &str) -> Result<TaskComplexity> {
407        let prompt = format!(
408            "评估此任务的复杂度(简单/中等/复杂),只需回答一个词:\n{}",
409            request
410        );
411
412        let chat_request = ChatRequest {
413            messages: vec![Message {
414                role: Role::User,
415                content: MessageContent::Text(prompt),
416            }],
417            tools: vec![],
418            system: None,
419            think: false,
420            max_tokens: 50,
421            server_tools: vec![],
422            enable_caching: false,
423        };
424
425        let response = self.provider.chat(chat_request).await?;
426        let text = extract_text(&response).to_lowercase();
427
428        if text.contains("简单") || text.contains("simple") {
429            Ok(TaskComplexity::Simple)
430        } else if text.contains("复杂") || text.contains("complex") {
431            Ok(TaskComplexity::Complex)
432        } else {
433            Ok(TaskComplexity::Moderate)
434        }
435    }
436}
437
/// System prompt sent with every planning request (see `Planner::plan`).
///
/// Written in Chinese. It instructs the model to act as a task-planning assistant
/// and to reply with a JSON object containing `complexity`, `approach`, `steps`
/// (each with `description`, `tools`, `optional`) and `considerations`. These are
/// the fields `parse_plan_response` reads. It also gives planning guidelines, such
/// as the expected step count per complexity level.
const PLAN_SYSTEM_PROMPT: &str = r#"你是一个任务规划助手。你的职责是分析编程任务,并将其分解为清晰的执行步骤。

输出要求(JSON格式):
```json
{
  "complexity": "simple|moderate|complex",
  "approach": "建议的方案(一句话)",
  "steps": [
    {
      "description": "步骤描述",
      "tools": ["需要的工具"],
      "optional": false
    }
  ],
  "considerations": ["注意事项"]
}
```

规划原则:
1. 简单任务(如读取文件、简单查询)只需1-2步
2. 中等任务(如修改代码、添加功能)需要3-5步
3. 复杂任务(如重构、跨模块修改)需要详细规划
4. 每个步骤要具体、可执行
5. 标记可选步骤和潜在风险"#;
463
/// Render the user-turn prompt for the planning model.
///
/// The prompt contains the raw request, then the available tools joined with
/// ", ", then a request for a JSON plan.
fn build_plan_prompt(request: &str, available_tools: &[&str]) -> String {
    let tool_list = available_tools.join(", ");
    format!(
        r#"用户请求:
{}

可用工具:
{}

请分析任务并生成执行计划(JSON格式)。"#,
        request, tool_list
    )
}
478
479/// Parse planning response into TaskPlan.
480fn parse_plan_response(request: &str, text: &str) -> Result<TaskPlan> {
481    // Try to parse as JSON
482    if let Some(json_start) = text.find('{')
483        && let Some(json_end) = text.rfind('}')
484    {
485        let json_str = &text[json_start..=json_end];
486        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
487            return Ok(TaskPlan {
488                request: request.to_string(),
489                steps: parse_steps(&parsed["steps"]),
490                complexity: parse_complexity(&parsed["complexity"]),
491                approach: parsed["approach"]
492                    .as_str()
493                    .unwrap_or("直接执行")
494                    .to_string(),
495                considerations: parsed["considerations"]
496                    .as_array()
497                    .map(|arr| {
498                        arr.iter()
499                            .filter_map(|v| v.as_str().map(String::from))
500                            .collect()
501                    })
502                    .unwrap_or_default(),
503            });
504        }
505    }
506
507    // Fallback: create simple plan from text
508    Ok(TaskPlan {
509        request: request.to_string(),
510        steps: parse_steps_from_text(text),
511        complexity: TaskComplexity::Moderate,
512        approach: "按步骤执行".to_string(),
513        considerations: vec!["请检查执行结果".to_string()],
514    })
515}
516
517fn parse_steps(value: &serde_json::Value) -> Vec<PlanStep> {
518    value
519        .as_array()
520        .map(|arr| {
521            arr.iter()
522                .filter_map(|v| {
523                    Some(PlanStep {
524                        description: v["description"].as_str()?.to_string(),
525                        tools: v["tools"]
526                            .as_array()
527                            .map(|t| {
528                                t.iter()
529                                    .filter_map(|x| x.as_str().map(String::from))
530                                    .collect()
531                            })
532                            .unwrap_or_default(),
533                        optional: v["optional"].as_bool().unwrap_or(false),
534                    })
535                })
536                .collect()
537        })
538        .unwrap_or_default()
539}
540
541fn parse_complexity(value: &serde_json::Value) -> TaskComplexity {
542    match value.as_str().map(|s| s.to_lowercase()) {
543        Some(s) if s.contains("simple") || s.contains("简单") => TaskComplexity::Simple,
544        Some(s) if s.contains("complex") || s.contains("复杂") => TaskComplexity::Complex,
545        _ => TaskComplexity::Moderate,
546    }
547}
548
549fn parse_steps_from_text(text: &str) -> Vec<PlanStep> {
550    text.lines()
551        .filter(|l| l.trim().starts_with(|c: char| c.is_ascii_digit()))
552        .take(5)
553        .map(|l| PlanStep {
554            description: l.split_whitespace().skip(1).collect::<Vec<_>>().join(" "),
555            tools: vec!["read".to_string()],
556            optional: false,
557        })
558        .collect()
559}
560
561fn extract_text(response: &ChatResponse) -> String {
562    response
563        .content
564        .iter()
565        .filter_map(|block| {
566            if let ContentBlock::Text { text } = block {
567                Some(text.clone())
568            } else {
569                None
570            }
571        })
572        .collect::<Vec<_>>()
573        .join("\n")
574}
575
#[cfg(test)]
mod tests {
    // Unit tests for model configuration, context-size inference, and task-plan
    // formatting/parsing.
    use super::*;

    // Role defaults: the main role thinks; the compress role uses haiku without thinking.
    #[test]
    fn test_model_config_defaults() {
        let main = ModelConfig::for_role(ModelRole::Main);
        assert!(main.name.contains("claude"));
        assert!(main.think);

        let compress = ModelConfig::for_role(ModelRole::Compress);
        assert!(compress.name.contains("haiku"));
        assert!(!compress.think);
    }

    // NOTE: `infer_context_size` honours the `CONTEXT_SIZE` environment variable
    // first, so this test fails if that variable is set to a positive integer
    // in the test environment.
    #[test]
    fn test_infer_context_size() {
        assert_eq!(infer_context_size("claude-sonnet-4"), Some(200_000));
        assert_eq!(infer_context_size("gpt-4o"), Some(128_000));
        assert_eq!(infer_context_size("claude-3-5-haiku"), Some(200_000));
    }

    // The default config assigns per-role models (sonnet for main, haiku for compress).
    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::default();
        assert!(config.main.name.contains("sonnet"));
        assert!(config.compress.name.contains("haiku"));
    }

    #[test]
    fn test_multi_model_config_with_main() {
        // with_main should make all roles use the main model
        let config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // All roles should use the same model
        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-sonnet-4");
        assert_eq!(config.fast.name, "claude-sonnet-4");

        // All should have thinking enabled (inherited from main)
        assert!(config.main.think);
        assert!(config.plan.think);
        assert!(config.compress.think);
        assert!(config.fast.think);
    }

    // `set` replaces only the targeted role; the others keep the main model.
    #[test]
    fn test_multi_model_config_override() {
        let mut config = MultiModelConfig::with_main("claude-sonnet-4".to_string());

        // Override compress model
        config.set(
            ModelRole::Compress,
            ModelConfig::new("claude-3-5-haiku".to_string()),
        );

        assert_eq!(config.main.name, "claude-sonnet-4");
        assert_eq!(config.plan.name, "claude-sonnet-4");
        assert_eq!(config.compress.name, "claude-3-5-haiku");
        assert_eq!(config.fast.name, "claude-sonnet-4"); // Still uses main
    }

    // The formatted plan includes the request, the complexity label and the step text.
    #[test]
    fn test_task_plan_format() {
        let plan = TaskPlan {
            request: "测试任务".to_string(),
            steps: vec![PlanStep {
                description: "读取文件".to_string(),
                tools: vec!["read".to_string()],
                optional: false,
            }],
            complexity: TaskComplexity::Simple,
            approach: "直接执行".to_string(),
            considerations: vec!["注意检查".to_string()],
        };

        let formatted = plan.format();
        assert!(formatted.contains("测试任务"));
        assert!(formatted.contains("简单"));
        assert!(formatted.contains("读取文件"));
    }

    // Each complexity level maps to its Chinese display label.
    #[test]
    fn test_complexity_display() {
        assert_eq!(TaskComplexity::Simple.display(), "简单");
        assert_eq!(TaskComplexity::Moderate.display(), "中等");
        assert_eq!(TaskComplexity::Complex.display(), "复杂");
    }

    // Todo conversion: one item per step; only the first starts "in_progress".
    #[test]
    fn test_task_plan_to_todo() {
        let plan = TaskPlan {
            request: "任务".to_string(),
            steps: vec![
                PlanStep {
                    description: "步骤1".to_string(),
                    tools: vec![],
                    optional: false,
                },
                PlanStep {
                    description: "步骤2".to_string(),
                    tools: vec![],
                    optional: false,
                },
            ],
            complexity: TaskComplexity::Simple,
            approach: "执行".to_string(),
            considerations: vec![],
        };

        let todos = plan.to_todo_items();
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].status, "in_progress");
        assert_eq!(todos[1].status, "pending");
    }

    // A well-formed JSON reply is parsed field by field.
    #[test]
    fn test_parse_plan_response_json() {
        let json = r#"{"complexity":"simple","approach":"直接读取","steps":[{"description":"read file","tools":["read"],"optional":false}],"considerations":[]}"#;
        let plan = parse_plan_response("test", json).unwrap();

        assert_eq!(plan.complexity, TaskComplexity::Simple);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].description, "read file");
    }
}
698        assert_eq!(plan.complexity, TaskComplexity::Simple);
699        assert_eq!(plan.steps.len(), 1);
700        assert_eq!(plan.steps[0].description, "read file");
701    }
702}