1pub mod cli_provider;
7pub mod ollama;
8pub mod openrouter;
9pub mod provider;
10pub mod router;
11pub mod schema;
12
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15
16use crate::types::StepType;
17
/// Category of work a model is assigned to.
///
/// [`ModelRole::for_step_type`] maps every `StepType` onto one of these roles.
/// Variants serialize in `snake_case` (e.g. `ModelRole::Thinking` <-> `"thinking"`), and the
/// `Hash` derive allows a role to be used as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ModelRole {
    /// Selected for `Analyze`, `Plan` and `Debug` steps.
    Thinking,
    /// Selected for `Code`, `Refactor` and `Fix` steps.
    Coding,
    /// Selected for `Search`, `Classify`, `Summarize`, `Execute` and `Other` steps.
    Task,
    /// Never returned by `for_step_type`; presumably chosen directly by callers that need
    /// embeddings — TODO confirm against call sites.
    Embedding,
    /// Selected for `SecurityCheck` steps.
    Auditor,
}
40
41impl ModelRole {
42 pub fn for_step_type(step_type: &StepType) -> Self {
44 match step_type {
45 StepType::Analyze | StepType::Plan | StepType::Debug => ModelRole::Thinking,
46 StepType::Code | StepType::Refactor | StepType::Fix => ModelRole::Coding,
47 StepType::Search | StepType::Classify | StepType::Summarize => ModelRole::Task,
48 StepType::SecurityCheck => ModelRole::Auditor,
49 StepType::Execute | StepType::Other => ModelRole::Task,
50 }
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
56#[serde(rename_all = "snake_case")]
57pub enum Provider {
58 Anthropic,
59 OpenAI,
60 Ollama,
61 OpenRouter,
62}
63
/// Settings for a single model: which provider serves it, which model to request, and
/// per-call generation/cost settings.
///
/// The `Default` impl yields an Ollama `llama3.2:3b` configuration. The struct has no
/// `#[serde(default)]` attribute, so deserialization does not use that impl to fill in
/// missing fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelConfig {
    /// Backend that serves the model.
    pub provider: Provider,
    /// Provider-specific model identifier, e.g. `"llama3.2:3b"` in the default config.
    pub model: String,
    /// Presumably the *name* of an environment variable holding the API key, not the key
    /// itself. It is `None` in the default Ollama config. TODO confirm in the provider modules.
    pub api_key_env: Option<String>,
    /// Presumably a cap on the cost of a single call. Units (USD?) and enforcement are not
    /// visible here; confirm with callers.
    pub max_cost_per_call: f64,
    /// Sampling temperature, presumably forwarded to the provider as-is (default `0.3`).
    pub temperature: f64,
    /// Token limit for a call (default `4096`). TODO confirm whether it bounds output only.
    pub max_tokens: u32,
}
80
81impl Default for ModelConfig {
82 fn default() -> Self {
83 Self {
84 provider: Provider::Ollama,
85 model: "llama3.2:3b".into(),
86 temperature: 0.3,
87 max_tokens: 4096,
88 max_cost_per_call: 1.0,
89 api_key_env: None,
90 }
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the role chosen for every `StepType` variant handled by `for_step_type`.
    #[test]
    fn test_model_role_for_step_type() {
        let cases = [
            (StepType::Analyze, ModelRole::Thinking),
            (StepType::Plan, ModelRole::Thinking),
            (StepType::Debug, ModelRole::Thinking),
            (StepType::Code, ModelRole::Coding),
            (StepType::Refactor, ModelRole::Coding),
            (StepType::Fix, ModelRole::Coding),
            (StepType::Search, ModelRole::Task),
            (StepType::Classify, ModelRole::Task),
            (StepType::Summarize, ModelRole::Task),
            (StepType::SecurityCheck, ModelRole::Auditor),
            (StepType::Execute, ModelRole::Task),
            (StepType::Other, ModelRole::Task),
        ];
        for (index, (step, expected)) in cases.iter().enumerate() {
            assert_eq!(&ModelRole::for_step_type(step), expected, "case #{index}");
        }
    }

    /// Every role serializes to its snake_case name and deserializes back unchanged.
    #[test]
    fn test_model_role_serialization() {
        let expected = [
            (ModelRole::Thinking, "\"thinking\""),
            (ModelRole::Coding, "\"coding\""),
            (ModelRole::Task, "\"task\""),
            (ModelRole::Embedding, "\"embedding\""),
            (ModelRole::Auditor, "\"auditor\""),
        ];
        for (role, json) in expected {
            assert_eq!(serde_json::to_string(&role).unwrap(), json);
            let deserialized: ModelRole = serde_json::from_str(json).unwrap();
            assert_eq!(deserialized, role);
        }
    }

    /// Spot-checks the wire names of providers whose spelling is unambiguous.
    #[test]
    fn test_provider_serialization() {
        let expected = [
            (Provider::Anthropic, "\"anthropic\""),
            (Provider::Ollama, "\"ollama\""),
            (Provider::OpenRouter, "\"open_router\""),
        ];
        for (provider, json) in expected {
            assert_eq!(serde_json::to_string(&provider).unwrap(), json);
        }
    }

    /// Every provider survives a serialize/deserialize round trip.
    #[test]
    fn test_provider_round_trip() {
        let providers = [
            Provider::Anthropic,
            Provider::OpenAI,
            Provider::Ollama,
            Provider::OpenRouter,
        ];
        for provider in providers {
            let json = serde_json::to_string(&provider).unwrap();
            let back: Provider = serde_json::from_str(&json).unwrap();
            assert_eq!(back, provider, "round trip through {json}");
        }
    }

    /// Pins every field of the default configuration.
    #[test]
    fn test_model_config_default() {
        let config = ModelConfig::default();
        assert_eq!(config.provider, Provider::Ollama);
        assert_eq!(config.model, "llama3.2:3b");
        assert!(config.api_key_env.is_none());
        assert_eq!(config.max_tokens, 4096);
        // Float fields are compared with a tolerance rather than `==`.
        assert!((config.temperature - 0.3).abs() < f64::EPSILON);
        assert!((config.max_cost_per_call - 1.0).abs() < f64::EPSILON);
    }
}