1use crate::agent_variant::AgentVariant;
2use crate::extra_providers::bigmodel;
3use crate::rand_agent::RandAgentBuilder;
4use rig::agent::AgentBuilder;
5use rig::client::CompletionClient;
6use rig::providers::*;
7use serde::{Deserialize, Serialize};
8use strum_macros::Display;
9
10#[derive(Debug, Display, Deserialize, Serialize)]
11#[serde(rename_all = "lowercase")]
12pub enum ProviderEnum {
13 Anthropic,
14 Cohere,
15 Gemini,
16 Huggingface,
17 Mistral,
18 OpenAi,
19 ResponsesOpenAi,
20 OpenRouter,
21 Together,
22 XAI,
23 Azure,
24 DeepSeek,
25 Galadriel,
26 Groq,
27 Hyperbolic,
28 Mira,
29 Mooshot,
30 Ollama,
31 Perplexity,
32 Bigmodel,
35}
36
37#[derive(Debug, Deserialize)]
38pub struct AgentConfig {
39 pub id: i32,
40 pub provider: ProviderEnum,
41 pub model_name: String,
42 pub api_key: String,
43 pub api_base_url: Option<String>,
44 pub system_prompt: Option<String>,
45 pub agent_name: Option<String>,
46}
47
48impl RandAgentBuilder {
49 pub fn simple_builder(
51 mut self,
52 agent_configs: Vec<AgentConfig>,
53 global_system_prompt: String,
54 ) -> Self {
55 for agent_conf in agent_configs {
56 let agent_name = agent_conf.agent_name.unwrap_or("rand agent".to_string());
57 let system_prompt = agent_conf
58 .system_prompt
59 .unwrap_or(global_system_prompt.clone());
60
61 match agent_conf.provider {
62 ProviderEnum::Anthropic => {
63 let mut client_builder = anthropic::Client::builder();
64 if let Some(api_base_url) = &agent_conf.api_base_url {
65 client_builder = client_builder.base_url(api_base_url);
66 }
67 match client_builder.api_key(&agent_conf.api_key).build() {
68 Ok(client) => {
69 let agent = client
70 .agent(&agent_conf.model_name)
71 .name(agent_name.as_str())
72 .preamble(&system_prompt)
73 .build();
74 self.agents.push((
75 AgentVariant::Anthropic(agent),
76 agent_conf.id,
77 agent_conf.provider.to_string(),
78 agent_conf.model_name,
79 ));
80 }
81 Err(err) => {
82 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
83 }
84 }
85 }
86 ProviderEnum::Cohere => match cohere::Client::new(&agent_conf.api_key) {
87 Ok(client) => {
88 let agent = client
89 .agent(&agent_conf.model_name)
90 .name(agent_name.as_str())
91 .preamble(&system_prompt)
92 .build();
93 self.agents.push((
94 AgentVariant::Cohere(agent),
95 agent_conf.id,
96 agent_conf.provider.to_string(),
97 agent_conf.model_name,
98 ));
99 }
100 Err(err) => {
101 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
102 }
103 },
104 ProviderEnum::Gemini => {
105 let mut client_builder = gemini::Client::builder();
106 if let Some(api_base_url) = &agent_conf.api_base_url {
107 client_builder = client_builder.base_url(api_base_url);
108 }
109 match client_builder.api_key(&agent_conf.api_key).build() {
110 Ok(client) => {
111 let agent = client
112 .agent(&agent_conf.model_name)
113 .name(agent_name.as_str())
114 .preamble(&system_prompt)
115 .build();
116 self.agents.push((
117 AgentVariant::Gemini(agent),
118 agent_conf.id,
119 agent_conf.provider.to_string(),
120 agent_conf.model_name,
121 ));
122 }
123 Err(err) => {
124 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
125 }
126 }
127 }
128 ProviderEnum::Huggingface => match huggingface::Client::new(&agent_conf.api_key) {
129 Ok(client) => {
130 let agent = client
131 .agent(&agent_conf.model_name)
132 .name(agent_name.as_str())
133 .preamble(&system_prompt)
134 .build();
135 self.agents.push((
136 AgentVariant::Huggingface(agent),
137 agent_conf.id,
138 agent_conf.provider.to_string(),
139 agent_conf.model_name,
140 ));
141 }
142 Err(err) => {
143 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
144 }
145 },
146 ProviderEnum::Mistral => match mistral::Client::new(&agent_conf.api_key) {
147 Ok(client) => {
148 let agent = client
149 .agent(&agent_conf.model_name)
150 .name(agent_name.as_str())
151 .preamble(&system_prompt)
152 .build();
153 self.agents.push((
154 AgentVariant::Mistral(agent),
155 agent_conf.id,
156 agent_conf.provider.to_string(),
157 agent_conf.model_name,
158 ));
159 }
160 Err(err) => {
161 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
162 }
163 },
164 ProviderEnum::OpenAi => {
165 let mut client_builder = openai::Client::builder();
166 if let Some(api_base_url) = &agent_conf.api_base_url {
167 client_builder = client_builder.base_url(api_base_url);
168 }
169 match client_builder.api_key(&agent_conf.api_key).build() {
170 Ok(client) => {
171 let agent = client
172 .completions_api()
173 .agent(&agent_conf.model_name)
174 .name(agent_name.as_str())
175 .preamble(&system_prompt)
176 .build();
177
178 self.agents.push((
179 AgentVariant::OpenAI(agent),
180 agent_conf.id,
181 agent_conf.provider.to_string(),
182 agent_conf.model_name,
183 ));
184 }
185 Err(err) => {
186 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
187 }
188 }
189 }
190 ProviderEnum::ResponsesOpenAi => {
191 let mut client_builder = openai::Client::builder();
192 if let Some(api_base_url) = &agent_conf.api_base_url {
193 client_builder = client_builder.base_url(api_base_url);
194 }
195 match client_builder.api_key(&agent_conf.api_key).build() {
196 Ok(client) => {
197 let agent = client
198 .agent(&agent_conf.model_name)
199 .name(agent_name.as_str())
200 .preamble(&system_prompt)
201 .build();
202 self.agents.push((
203 AgentVariant::ResponsesOpenAI(agent),
204 agent_conf.id,
205 agent_conf.provider.to_string(),
206 agent_conf.model_name,
207 ));
208 }
209 Err(err) => {
210 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
211 }
212 }
213 }
214 ProviderEnum::OpenRouter => {
215 let mut client_builder = openrouter::Client::builder();
216 if let Some(api_base_url) = &agent_conf.api_base_url {
217 client_builder = client_builder.base_url(api_base_url);
218 }
219 match client_builder.api_key(&agent_conf.api_key).build() {
220 Ok(client) => {
221 let agent = client
222 .agent(&agent_conf.model_name)
223 .name(agent_name.as_str())
224 .preamble(&system_prompt)
225 .build();
226 self.agents.push((
227 AgentVariant::OpenRouter(agent),
228 agent_conf.id,
229 agent_conf.provider.to_string(),
230 agent_conf.model_name,
231 ));
232 }
233 Err(err) => {
234 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
235 }
236 }
237 }
238 ProviderEnum::Together => match together::Client::new(&agent_conf.api_key) {
239 Ok(client) => {
240 let agent = client
241 .agent(&agent_conf.model_name)
242 .name(agent_name.as_str())
243 .preamble(&system_prompt)
244 .build();
245 self.agents.push((
246 AgentVariant::Together(agent),
247 agent_conf.id,
248 agent_conf.provider.to_string(),
249 agent_conf.model_name,
250 ));
251 }
252 Err(err) => {
253 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
254 }
255 },
256 ProviderEnum::XAI => match xai::Client::new(&agent_conf.api_key) {
257 Ok(client) => {
258 let agent = client
259 .agent(&agent_conf.model_name)
260 .name(agent_name.as_str())
261 .preamble(&system_prompt)
262 .build();
263 self.agents.push((
264 AgentVariant::XAI(agent),
265 agent_conf.id,
266 agent_conf.provider.to_string(),
267 agent_conf.model_name,
268 ));
269 }
270 Err(err) => {
271 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
272 }
273 },
274 ProviderEnum::Azure => {
275 tracing::info!("Azure simple_builder暂不支持,参数有点多,可以自行添加........ ")
276 }
277 ProviderEnum::DeepSeek => match deepseek::Client::new(&agent_conf.api_key) {
278 Ok(client) => {
279 let agent = client
280 .agent(&agent_conf.model_name)
281 .name(agent_name.as_str())
282 .preamble(&system_prompt)
283 .build();
284 self.agents.push((
285 AgentVariant::DeepSeek(agent),
286 agent_conf.id,
287 agent_conf.provider.to_string(),
288 agent_conf.model_name,
289 ));
290 }
291 Err(err) => {
292 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
293 }
294 },
295 ProviderEnum::Galadriel => match galadriel::Client::new(&agent_conf.api_key) {
296 Ok(client) => {
297 let agent = client
298 .agent(&agent_conf.model_name)
299 .name(agent_name.as_str())
300 .preamble(&system_prompt)
301 .build();
302 self.agents.push((
303 AgentVariant::Galadriel(agent),
304 agent_conf.id,
305 agent_conf.provider.to_string(),
306 agent_conf.model_name,
307 ));
308 }
309 Err(err) => {
310 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
311 }
312 },
313 ProviderEnum::Groq => match groq::Client::new(&agent_conf.api_key) {
314 Ok(client) => {
315 let agent = client
316 .agent(&agent_conf.model_name)
317 .name(agent_name.as_str())
318 .preamble(&system_prompt)
319 .build();
320 self.agents.push((
321 AgentVariant::Groq(agent),
322 agent_conf.id,
323 agent_conf.provider.to_string(),
324 agent_conf.model_name,
325 ));
326 }
327 Err(err) => {
328 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
329 }
330 },
331 ProviderEnum::Hyperbolic => match hyperbolic::Client::new(&agent_conf.api_key) {
332 Ok(client) => {
333 let agent = client
334 .agent(&agent_conf.model_name)
335 .name(agent_name.as_str())
336 .preamble(&system_prompt)
337 .build();
338 self.agents.push((
339 AgentVariant::Hyperbolic(agent),
340 agent_conf.id,
341 agent_conf.provider.to_string(),
342 agent_conf.model_name,
343 ));
344 }
345 Err(err) => {
346 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
347 }
348 },
349 ProviderEnum::Mira => match mira::Client::new(&agent_conf.api_key) {
350 Ok(client) => {
351 let agent = client
352 .agent(&agent_conf.model_name)
353 .name(agent_name.as_str())
354 .preamble(&system_prompt)
355 .build();
356 self.agents.push((
357 AgentVariant::Mira(agent),
358 agent_conf.id,
359 agent_conf.provider.to_string(),
360 agent_conf.model_name,
361 ));
362 }
363 Err(err) => {
364 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
365 }
366 },
367 ProviderEnum::Mooshot => match moonshot::Client::new(&agent_conf.api_key) {
368 Ok(client) => {
369 let agent = client
370 .agent(&agent_conf.model_name)
371 .name(agent_name.as_str())
372 .preamble(&system_prompt)
373 .build();
374 self.agents.push((
375 AgentVariant::Mooshot(agent),
376 agent_conf.id,
377 agent_conf.provider.to_string(),
378 agent_conf.model_name,
379 ));
380 }
381 Err(err) => {
382 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
383 }
384 },
385 ProviderEnum::Ollama => {
386 use rig::client::Nothing;
389 let client_result = if let Some(api_base_url) = &agent_conf.api_base_url {
390 ollama::Client::builder()
391 .base_url(api_base_url)
392 .api_key(Nothing)
393 .build()
394 } else {
395 ollama::Client::builder().api_key(Nothing).build()
396 };
397 match client_result {
398 Ok(client) => {
399 let agent = client
400 .agent(&agent_conf.model_name)
401 .name(agent_name.as_str())
402 .preamble(&system_prompt)
403 .build();
404 self.agents.push((
405 AgentVariant::Ollama(agent),
406 agent_conf.id,
407 agent_conf.provider.to_string(),
408 agent_conf.model_name,
409 ));
410 }
411 Err(err) => {
412 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
413 }
414 }
415 }
416 ProviderEnum::Perplexity => match perplexity::Client::new(&agent_conf.api_key) {
417 Ok(client) => {
418 let agent = client
419 .agent(&agent_conf.model_name)
420 .name(agent_name.as_str())
421 .preamble(&system_prompt)
422 .build();
423 self.agents.push((
424 AgentVariant::Perplexity(agent),
425 agent_conf.id,
426 agent_conf.provider.to_string(),
427 agent_conf.model_name,
428 ));
429 }
430 Err(err) => {
431 tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
432 }
433 },
434 ProviderEnum::Bigmodel => {
435 let client = if let Some(api_base_url) = agent_conf.api_base_url {
436 bigmodel::Client::from_url(&agent_conf.api_key, &api_base_url)
437 } else {
438 bigmodel::Client::new(&agent_conf.api_key)
439 };
440 let model = client.completion_model(&agent_conf.model_name);
441 let agent = AgentBuilder::new(model)
442 .name(agent_name.as_str())
443 .preamble(&system_prompt)
444 .build();
445 self.agents.push((
446 AgentVariant::Bigmodel(agent),
447 agent_conf.id,
448 agent_conf.provider.to_string(),
449 agent_conf.model_name,
450 ));
451 }
452 }
453 }
454 self
455 }
456}