1use crate::agent_variant::AgentVariant;
2use crate::extra_providers::bigmodel;
3use crate::rand_agent::RandAgentBuilder;
4use rig::client::CompletionClient;
5use rig::providers::*;
6use serde::{Deserialize, Serialize};
7use strum_macros::Display;
8
9#[derive(Debug, Display, Deserialize, Serialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ProviderEnum {
12 Anthropic,
13 Cohere,
14 Gemini,
15 Huggingface,
16 Mistral,
17 OpenAi,
18 ResponsesOpenAi,
19 OpenRouter,
20 Together,
21 XAI,
22 Azure,
23 DeepSeek,
24 Galadriel,
25 Groq,
26 Hyperbolic,
27 Mira,
28 Mooshot,
29 Ollama,
30 Perplexity,
31 Bigmodel,
34}
35
36#[derive(Debug, Deserialize)]
37pub struct AgentConfig {
38 pub id: i32,
39 pub provider: ProviderEnum,
40 pub model_name: String,
41 pub api_key: String,
42 pub api_base_url: Option<String>,
43 pub system_prompt: Option<String>,
44 pub agent_name: Option<String>,
45}
46
47impl RandAgentBuilder {
48 pub fn simple_builder(
50 mut self,
51 agent_configs: Vec<AgentConfig>,
52 global_system_prompt: String,
53 ) -> Self {
54 for agent_conf in agent_configs {
55 let agent_name = agent_conf.agent_name.unwrap_or("rand agent".to_string());
56 let system_prompt = agent_conf
57 .system_prompt
58 .unwrap_or(global_system_prompt.clone());
59
60 match agent_conf.provider {
61 ProviderEnum::Anthropic => {
62 let mut client_builder = anthropic::Client::builder();
63 if let Some(api_base_url) = &agent_conf.api_base_url {
64 client_builder = client_builder.base_url(api_base_url);
65 }
66 match client_builder.api_key(&agent_conf.api_key).build() {
67 Ok(client) => {
68 let agent = client
69 .agent(&agent_conf.model_name)
70 .name(agent_name.as_str())
71 .preamble(&system_prompt)
72 .build();
73 self.agents.push((
74 AgentVariant::Anthropic(agent),
75 agent_conf.id,
76 agent_conf.provider.to_string(),
77 agent_conf.model_name,
78 ));
79 }
80 Err(err) => {
81 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
82 }
83 }
84 }
85 ProviderEnum::Cohere => match cohere::Client::new(&agent_conf.api_key) {
86 Ok(client) => {
87 let agent = client
88 .agent(&agent_conf.model_name)
89 .name(agent_name.as_str())
90 .preamble(&system_prompt)
91 .build();
92 self.agents.push((
93 AgentVariant::Cohere(agent),
94 agent_conf.id,
95 agent_conf.provider.to_string(),
96 agent_conf.model_name,
97 ));
98 }
99 Err(err) => {
100 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
101 }
102 },
103 ProviderEnum::Gemini => {
104 let mut client_builder = gemini::Client::builder();
105 if let Some(api_base_url) = &agent_conf.api_base_url {
106 client_builder = client_builder.base_url(api_base_url);
107 }
108 match client_builder.api_key(&agent_conf.api_key).build() {
109 Ok(client) => {
110 let agent = client
111 .agent(&agent_conf.model_name)
112 .name(agent_name.as_str())
113 .preamble(&system_prompt)
114 .build();
115 self.agents.push((
116 AgentVariant::Gemini(agent),
117 agent_conf.id,
118 agent_conf.provider.to_string(),
119 agent_conf.model_name,
120 ));
121 }
122 Err(err) => {
123 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
124 }
125 }
126 }
127 ProviderEnum::Huggingface => match huggingface::Client::new(&agent_conf.api_key) {
128 Ok(client) => {
129 let agent = client
130 .agent(&agent_conf.model_name)
131 .name(agent_name.as_str())
132 .preamble(&system_prompt)
133 .build();
134 self.agents.push((
135 AgentVariant::Huggingface(agent),
136 agent_conf.id,
137 agent_conf.provider.to_string(),
138 agent_conf.model_name,
139 ));
140 }
141 Err(err) => {
142 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
143 }
144 },
145 ProviderEnum::Mistral => match mistral::Client::new(&agent_conf.api_key) {
146 Ok(client) => {
147 let agent = client
148 .agent(&agent_conf.model_name)
149 .name(agent_name.as_str())
150 .preamble(&system_prompt)
151 .build();
152 self.agents.push((
153 AgentVariant::Mistral(agent),
154 agent_conf.id,
155 agent_conf.provider.to_string(),
156 agent_conf.model_name,
157 ));
158 }
159 Err(err) => {
160 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
161 }
162 },
163 ProviderEnum::OpenAi => {
164 let mut client_builder = openai::Client::builder();
165 if let Some(api_base_url) = &agent_conf.api_base_url {
166 client_builder = client_builder.base_url(api_base_url);
167 }
168 match client_builder.api_key(&agent_conf.api_key).build() {
169 Ok(client) => {
170 let agent = client
171 .completions_api()
172 .agent(&agent_conf.model_name)
173 .name(agent_name.as_str())
174 .preamble(&system_prompt)
175 .build();
176
177 self.agents.push((
178 AgentVariant::OpenAI(agent),
179 agent_conf.id,
180 agent_conf.provider.to_string(),
181 agent_conf.model_name,
182 ));
183 }
184 Err(err) => {
185 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
186 }
187 }
188 }
189 ProviderEnum::ResponsesOpenAi => {
190 let mut client_builder = openai::Client::builder();
191 if let Some(api_base_url) = &agent_conf.api_base_url {
192 client_builder = client_builder.base_url(api_base_url);
193 }
194 match client_builder.api_key(&agent_conf.api_key).build() {
195 Ok(client) => {
196 let agent = client
197 .agent(&agent_conf.model_name)
198 .name(agent_name.as_str())
199 .preamble(&system_prompt)
200 .build();
201 self.agents.push((
202 AgentVariant::ResponsesOpenAI(agent),
203 agent_conf.id,
204 agent_conf.provider.to_string(),
205 agent_conf.model_name,
206 ));
207 }
208 Err(err) => {
209 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
210 }
211 }
212 }
213 ProviderEnum::OpenRouter => {
214 let mut client_builder = openrouter::Client::builder();
215 if let Some(api_base_url) = &agent_conf.api_base_url {
216 client_builder = client_builder.base_url(api_base_url);
217 }
218 match client_builder.api_key(&agent_conf.api_key).build() {
219 Ok(client) => {
220 let agent = client
221 .agent(&agent_conf.model_name)
222 .name(agent_name.as_str())
223 .preamble(&system_prompt)
224 .build();
225 self.agents.push((
226 AgentVariant::OpenRouter(agent),
227 agent_conf.id,
228 agent_conf.provider.to_string(),
229 agent_conf.model_name,
230 ));
231 }
232 Err(err) => {
233 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
234 }
235 }
236 }
237 ProviderEnum::Together => match together::Client::new(&agent_conf.api_key) {
238 Ok(client) => {
239 let agent = client
240 .agent(&agent_conf.model_name)
241 .name(agent_name.as_str())
242 .preamble(&system_prompt)
243 .build();
244 self.agents.push((
245 AgentVariant::Together(agent),
246 agent_conf.id,
247 agent_conf.provider.to_string(),
248 agent_conf.model_name,
249 ));
250 }
251 Err(err) => {
252 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
253 }
254 },
255 ProviderEnum::XAI => match xai::Client::new(&agent_conf.api_key) {
256 Ok(client) => {
257 let agent = client
258 .agent(&agent_conf.model_name)
259 .name(agent_name.as_str())
260 .preamble(&system_prompt)
261 .build();
262 self.agents.push((
263 AgentVariant::XAI(agent),
264 agent_conf.id,
265 agent_conf.provider.to_string(),
266 agent_conf.model_name,
267 ));
268 }
269 Err(err) => {
270 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
271 }
272 },
273 ProviderEnum::Azure => {
274 tracing::info!("Azure simple_builder暂不支持,参数有点多,可以自行添加........ ")
275 }
276 ProviderEnum::DeepSeek => match deepseek::Client::new(&agent_conf.api_key) {
277 Ok(client) => {
278 let agent = client
279 .agent(&agent_conf.model_name)
280 .name(agent_name.as_str())
281 .preamble(&system_prompt)
282 .build();
283 self.agents.push((
284 AgentVariant::DeepSeek(agent),
285 agent_conf.id,
286 agent_conf.provider.to_string(),
287 agent_conf.model_name,
288 ));
289 }
290 Err(err) => {
291 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
292 }
293 },
294 ProviderEnum::Galadriel => match galadriel::Client::new(&agent_conf.api_key) {
295 Ok(client) => {
296 let agent = client
297 .agent(&agent_conf.model_name)
298 .name(agent_name.as_str())
299 .preamble(&system_prompt)
300 .build();
301 self.agents.push((
302 AgentVariant::Galadriel(agent),
303 agent_conf.id,
304 agent_conf.provider.to_string(),
305 agent_conf.model_name,
306 ));
307 }
308 Err(err) => {
309 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
310 }
311 },
312 ProviderEnum::Groq => match groq::Client::new(&agent_conf.api_key) {
313 Ok(client) => {
314 let agent = client
315 .agent(&agent_conf.model_name)
316 .name(agent_name.as_str())
317 .preamble(&system_prompt)
318 .build();
319 self.agents.push((
320 AgentVariant::Groq(agent),
321 agent_conf.id,
322 agent_conf.provider.to_string(),
323 agent_conf.model_name,
324 ));
325 }
326 Err(err) => {
327 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
328 }
329 },
330 ProviderEnum::Hyperbolic => match hyperbolic::Client::new(&agent_conf.api_key) {
331 Ok(client) => {
332 let agent = client
333 .agent(&agent_conf.model_name)
334 .name(agent_name.as_str())
335 .preamble(&system_prompt)
336 .build();
337 self.agents.push((
338 AgentVariant::Hyperbolic(agent),
339 agent_conf.id,
340 agent_conf.provider.to_string(),
341 agent_conf.model_name,
342 ));
343 }
344 Err(err) => {
345 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
346 }
347 },
348 ProviderEnum::Mira => match mira::Client::new(&agent_conf.api_key) {
349 Ok(client) => {
350 let agent = client
351 .agent(&agent_conf.model_name)
352 .name(agent_name.as_str())
353 .preamble(&system_prompt)
354 .build();
355 self.agents.push((
356 AgentVariant::Mira(agent),
357 agent_conf.id,
358 agent_conf.provider.to_string(),
359 agent_conf.model_name,
360 ));
361 }
362 Err(err) => {
363 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
364 }
365 },
366 ProviderEnum::Mooshot => match moonshot::Client::new(&agent_conf.api_key) {
367 Ok(client) => {
368 let agent = client
369 .agent(&agent_conf.model_name)
370 .name(agent_name.as_str())
371 .preamble(&system_prompt)
372 .build();
373 self.agents.push((
374 AgentVariant::Mooshot(agent),
375 agent_conf.id,
376 agent_conf.provider.to_string(),
377 agent_conf.model_name,
378 ));
379 }
380 Err(err) => {
381 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
382 }
383 },
384 ProviderEnum::Ollama => {
385 use rig::client::Nothing;
388 let client_result = if let Some(api_base_url) = &agent_conf.api_base_url {
389 ollama::Client::builder()
390 .base_url(api_base_url)
391 .api_key(Nothing)
392 .build()
393 } else {
394 ollama::Client::builder().api_key(Nothing).build()
395 };
396 match client_result {
397 Ok(client) => {
398 let agent = client
399 .agent(&agent_conf.model_name)
400 .name(agent_name.as_str())
401 .preamble(&system_prompt)
402 .build();
403 self.agents.push((
404 AgentVariant::Ollama(agent),
405 agent_conf.id,
406 agent_conf.provider.to_string(),
407 agent_conf.model_name,
408 ));
409 }
410 Err(err) => {
411 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
412 }
413 }
414 }
415 ProviderEnum::Perplexity => match perplexity::Client::new(&agent_conf.api_key) {
416 Ok(client) => {
417 let agent = client
418 .agent(&agent_conf.model_name)
419 .name(agent_name.as_str())
420 .preamble(&system_prompt)
421 .build();
422 self.agents.push((
423 AgentVariant::Perplexity(agent),
424 agent_conf.id,
425 agent_conf.provider.to_string(),
426 agent_conf.model_name,
427 ));
428 }
429 Err(err) => {
430 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
431 }
432 },
433 ProviderEnum::Bigmodel => match bigmodel::Client::new(&agent_conf.api_key) {
434 Ok(client) => {
435 let agent = client
436 .agent(&agent_conf.model_name)
437 .name(agent_name.as_str())
438 .preamble(&system_prompt)
439 .build();
440 self.agents.push((
441 AgentVariant::Bigmodel(agent),
442 agent_conf.id,
443 agent_conf.provider.to_string(),
444 agent_conf.model_name,
445 ));
446 }
447 Err(err) => {
448 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
449 }
450 },
451 }
452 }
453 self
454 }
455}