//! mermaid_cli/models/config.rs — model and backend configuration types.

use crate::prompts;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ModelConfig {
13 pub model: String,
16
17 #[serde(default = "default_temperature")]
19 pub temperature: f32,
20
21 #[serde(default = "default_max_tokens")]
23 pub max_tokens: usize,
24
25 pub top_p: Option<f32>,
27
28 pub frequency_penalty: Option<f32>,
30
31 pub presence_penalty: Option<f32>,
33
34 pub system_prompt: Option<String>,
36
37 #[serde(default)]
40 pub backend_options: HashMap<String, HashMap<String, String>>,
41}
42
43impl Default for ModelConfig {
44 fn default() -> Self {
45 Self {
46 model: "ollama/tinyllama".to_string(),
47 temperature: default_temperature(),
48 max_tokens: default_max_tokens(),
49 top_p: Some(default_top_p()),
50 frequency_penalty: None,
51 presence_penalty: None,
52 system_prompt: Some(prompts::get_system_prompt()),
53 backend_options: HashMap::new(),
54 }
55 }
56}
57
58impl ModelConfig {
59 pub fn get_backend_option(&self, backend: &str, key: &str) -> Option<&String> {
61 self.backend_options.get(backend)?.get(key)
62 }
63
64 pub fn get_backend_option_i32(&self, backend: &str, key: &str) -> Option<i32> {
66 self.get_backend_option(backend, key)?
67 .parse::<i32>()
68 .ok()
69 }
70
71 pub fn get_backend_option_bool(&self, backend: &str, key: &str) -> Option<bool> {
73 self.get_backend_option(backend, key)?
74 .parse::<bool>()
75 .ok()
76 }
77
78 pub fn set_backend_option(&mut self, backend: String, key: String, value: String) {
80 self.backend_options
81 .entry(backend)
82 .or_insert_with(HashMap::new)
83 .insert(key, value);
84 }
85
86 pub fn ollama_options(&self) -> OllamaOptions {
88 OllamaOptions {
89 num_gpu: self.get_backend_option_i32("ollama", "num_gpu"),
90 num_thread: self.get_backend_option_i32("ollama", "num_thread"),
91 num_ctx: self.get_backend_option_i32("ollama", "num_ctx"),
92 numa: self.get_backend_option_bool("ollama", "numa"),
93 cloud_api_key: self.get_backend_option("ollama", "cloud_api_key").cloned(),
94 }
95 }
96
97 pub fn with_plan_mode() -> Self {
99 let mut config = Self::default();
100 config.system_prompt = Some(prompts::get_system_prompt_with_plan_mode());
101 config
102 }
103}
104
/// Typed view of the Ollama-specific entries in
/// `ModelConfig::backend_options`; every field is optional, with `None`
/// meaning "let Ollama decide".
#[derive(Debug, Clone, Default)]
pub struct OllamaOptions {
    /// Number of GPU layers to offload.
    pub num_gpu: Option<i32>,
    /// Number of CPU threads.
    pub num_thread: Option<i32>,
    /// Context window size in tokens.
    pub num_ctx: Option<i32>,
    /// Whether to enable NUMA support.
    pub numa: Option<bool>,
    /// API key for Ollama cloud, if configured.
    pub cloud_api_key: Option<String>,
}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct BackendConfig {
118 #[serde(default = "default_ollama_url")]
120 pub ollama_url: String,
121
122 #[serde(default = "default_vllm_url")]
124 pub vllm_url: String,
125
126 #[serde(default = "default_litellm_url")]
128 pub litellm_url: String,
129
130 pub litellm_master_key: Option<String>,
132
133 #[serde(default = "default_timeout")]
135 pub timeout_secs: u64,
136
137 #[serde(default = "default_request_timeout")]
139 pub request_timeout_secs: u64,
140
141 #[serde(default = "default_max_idle")]
143 pub max_idle_per_host: usize,
144
145 #[serde(default = "default_health_check_interval")]
147 pub health_check_interval_secs: u64,
148}
149
150impl Default for BackendConfig {
151 fn default() -> Self {
152 Self {
153 ollama_url: default_ollama_url(),
154 vllm_url: default_vllm_url(),
155 litellm_url: default_litellm_url(),
156 litellm_master_key: None,
157 timeout_secs: default_timeout(),
158 request_timeout_secs: default_request_timeout(),
159 max_idle_per_host: default_max_idle(),
160 health_check_interval_secs: default_health_check_interval(),
161 }
162 }
163}
164
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    0.7
}
169
/// Serde default for `ModelConfig::max_tokens`.
fn default_max_tokens() -> usize {
    4096
}
173
/// Default `top_p` used by `ModelConfig::default()` (no-op nucleus cutoff).
fn default_top_p() -> f32 {
    1.0
}
177
/// Serde default for `BackendConfig::ollama_url`: honors `OLLAMA_HOST`,
/// otherwise the standard local Ollama port.
fn default_ollama_url() -> String {
    std::env::var("OLLAMA_HOST").unwrap_or_else(|_| "http://localhost:11434".to_string())
}
181
/// Serde default for `BackendConfig::vllm_url`: honors `VLLM_API_BASE`,
/// otherwise the standard local vLLM port.
fn default_vllm_url() -> String {
    std::env::var("VLLM_API_BASE").unwrap_or_else(|_| "http://localhost:8000".to_string())
}
185
/// Serde default for `BackendConfig::litellm_url`: honors
/// `LITELLM_PROXY_URL`, otherwise the standard local LiteLLM proxy port.
fn default_litellm_url() -> String {
    std::env::var("LITELLM_PROXY_URL").unwrap_or_else(|_| "http://localhost:4000".to_string())
}
189
/// Serde default for `BackendConfig::timeout_secs` (connection timeout).
fn default_timeout() -> u64 {
    10
}
193
/// Serde default for `BackendConfig::request_timeout_secs`.
fn default_request_timeout() -> u64 {
    120
}
197
/// Serde default for `BackendConfig::max_idle_per_host`.
fn default_max_idle() -> usize {
    10
}
201
/// Serde default for `BackendConfig::health_check_interval_secs`.
fn default_health_check_interval() -> u64 {
    30
}