Skip to main content

rig_extra/
simple_rand_builder.rs

1use crate::agent_variant::AgentVariant;
2use crate::extra_providers::bigmodel;
3use crate::rand_agent::RandAgentBuilder;
4use rig::client::CompletionClient;
5use rig::providers::*;
6use serde::{Deserialize, Serialize};
7use strum_macros::Display;
8
9#[derive(Debug, Display, Deserialize, Serialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ProviderEnum {
12    Anthropic,
13    Cohere,
14    Gemini,
15    Huggingface,
16    Mistral,
17    OpenAi,
18    ResponsesOpenAi,
19    OpenRouter,
20    Together,
21    XAI,
22    Azure,
23    DeepSeek,
24    Galadriel,
25    Groq,
26    Hyperbolic,
27    Mira,
28    Mooshot,
29    Ollama,
30    Perplexity,
31    // embedding模型
32    // Voyageai,
33    Bigmodel,
34}
35
36#[derive(Debug, Deserialize)]
37pub struct AgentConfig {
38    pub id: i32,
39    pub provider: ProviderEnum,
40    pub model_name: String,
41    pub api_key: String,
42    pub api_base_url: Option<String>,
43    pub system_prompt: Option<String>,
44    pub agent_name: Option<String>,
45}
46
47impl RandAgentBuilder {
48    /// 简单构建器
49    pub fn simple_builder(
50        mut self,
51        agent_configs: Vec<AgentConfig>,
52        global_system_prompt: String,
53    ) -> Self {
54        for agent_conf in agent_configs {
55            let agent_name = agent_conf.agent_name.unwrap_or("rand agent".to_string());
56            let system_prompt = agent_conf
57                .system_prompt
58                .unwrap_or(global_system_prompt.clone());
59
60            match agent_conf.provider {
61                ProviderEnum::Anthropic => {
62                    let mut client_builder = anthropic::Client::builder();
63                    if let Some(api_base_url) = &agent_conf.api_base_url {
64                        client_builder = client_builder.base_url(api_base_url);
65                    }
66                    match client_builder.api_key(&agent_conf.api_key).build() {
67                        Ok(client) => {
68                            let agent = client
69                                .agent(&agent_conf.model_name)
70                                .name(agent_name.as_str())
71                                .preamble(&system_prompt)
72                                .build();
73                            self.agents.push((
74                                AgentVariant::Anthropic(agent),
75                                agent_conf.id,
76                                agent_conf.provider.to_string(),
77                                agent_conf.model_name,
78                            ));
79                        }
80                        Err(err) => {
81                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
82                        }
83                    }
84                }
85                ProviderEnum::Cohere => match cohere::Client::new(&agent_conf.api_key) {
86                    Ok(client) => {
87                        let agent = client
88                            .agent(&agent_conf.model_name)
89                            .name(agent_name.as_str())
90                            .preamble(&system_prompt)
91                            .build();
92                        self.agents.push((
93                            AgentVariant::Cohere(agent),
94                            agent_conf.id,
95                            agent_conf.provider.to_string(),
96                            agent_conf.model_name,
97                        ));
98                    }
99                    Err(err) => {
100                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
101                    }
102                },
103                ProviderEnum::Gemini => {
104                    let mut client_builder = gemini::Client::builder();
105                    if let Some(api_base_url) = &agent_conf.api_base_url {
106                        client_builder = client_builder.base_url(api_base_url);
107                    }
108                    match client_builder.api_key(&agent_conf.api_key).build() {
109                        Ok(client) => {
110                            let agent = client
111                                .agent(&agent_conf.model_name)
112                                .name(agent_name.as_str())
113                                .preamble(&system_prompt)
114                                .build();
115                            self.agents.push((
116                                AgentVariant::Gemini(agent),
117                                agent_conf.id,
118                                agent_conf.provider.to_string(),
119                                agent_conf.model_name,
120                            ));
121                        }
122                        Err(err) => {
123                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
124                        }
125                    }
126                }
127                ProviderEnum::Huggingface => match huggingface::Client::new(&agent_conf.api_key) {
128                    Ok(client) => {
129                        let agent = client
130                            .agent(&agent_conf.model_name)
131                            .name(agent_name.as_str())
132                            .preamble(&system_prompt)
133                            .build();
134                        self.agents.push((
135                            AgentVariant::Huggingface(agent),
136                            agent_conf.id,
137                            agent_conf.provider.to_string(),
138                            agent_conf.model_name,
139                        ));
140                    }
141                    Err(err) => {
142                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
143                    }
144                },
145                ProviderEnum::Mistral => match mistral::Client::new(&agent_conf.api_key) {
146                    Ok(client) => {
147                        let agent = client
148                            .agent(&agent_conf.model_name)
149                            .name(agent_name.as_str())
150                            .preamble(&system_prompt)
151                            .build();
152                        self.agents.push((
153                            AgentVariant::Mistral(agent),
154                            agent_conf.id,
155                            agent_conf.provider.to_string(),
156                            agent_conf.model_name,
157                        ));
158                    }
159                    Err(err) => {
160                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
161                    }
162                },
163                ProviderEnum::OpenAi => {
164                    let mut client_builder = openai::Client::builder();
165                    if let Some(api_base_url) = &agent_conf.api_base_url {
166                        client_builder = client_builder.base_url(api_base_url);
167                    }
168                    match client_builder.api_key(&agent_conf.api_key).build() {
169                        Ok(client) => {
170                            let agent = client
171                                .completions_api()
172                                .agent(&agent_conf.model_name)
173                                .name(agent_name.as_str())
174                                .preamble(&system_prompt)
175                                .build();
176
177                            self.agents.push((
178                                AgentVariant::OpenAI(agent),
179                                agent_conf.id,
180                                agent_conf.provider.to_string(),
181                                agent_conf.model_name,
182                            ));
183                        }
184                        Err(err) => {
185                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
186                        }
187                    }
188                }
189                ProviderEnum::ResponsesOpenAi => {
190                    let mut client_builder = openai::Client::builder();
191                    if let Some(api_base_url) = &agent_conf.api_base_url {
192                        client_builder = client_builder.base_url(api_base_url);
193                    }
194                    match client_builder.api_key(&agent_conf.api_key).build() {
195                        Ok(client) => {
196                            let agent = client
197                                .agent(&agent_conf.model_name)
198                                .name(agent_name.as_str())
199                                .preamble(&system_prompt)
200                                .build();
201                            self.agents.push((
202                                AgentVariant::ResponsesOpenAI(agent),
203                                agent_conf.id,
204                                agent_conf.provider.to_string(),
205                                agent_conf.model_name,
206                            ));
207                        }
208                        Err(err) => {
209                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
210                        }
211                    }
212                }
213                ProviderEnum::OpenRouter => {
214                    let mut client_builder = openrouter::Client::builder();
215                    if let Some(api_base_url) = &agent_conf.api_base_url {
216                        client_builder = client_builder.base_url(api_base_url);
217                    }
218                    match client_builder.api_key(&agent_conf.api_key).build() {
219                        Ok(client) => {
220                            let agent = client
221                                .agent(&agent_conf.model_name)
222                                .name(agent_name.as_str())
223                                .preamble(&system_prompt)
224                                .build();
225                            self.agents.push((
226                                AgentVariant::OpenRouter(agent),
227                                agent_conf.id,
228                                agent_conf.provider.to_string(),
229                                agent_conf.model_name,
230                            ));
231                        }
232                        Err(err) => {
233                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
234                        }
235                    }
236                }
237                ProviderEnum::Together => match together::Client::new(&agent_conf.api_key) {
238                    Ok(client) => {
239                        let agent = client
240                            .agent(&agent_conf.model_name)
241                            .name(agent_name.as_str())
242                            .preamble(&system_prompt)
243                            .build();
244                        self.agents.push((
245                            AgentVariant::Together(agent),
246                            agent_conf.id,
247                            agent_conf.provider.to_string(),
248                            agent_conf.model_name,
249                        ));
250                    }
251                    Err(err) => {
252                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
253                    }
254                },
255                ProviderEnum::XAI => match xai::Client::new(&agent_conf.api_key) {
256                    Ok(client) => {
257                        let agent = client
258                            .agent(&agent_conf.model_name)
259                            .name(agent_name.as_str())
260                            .preamble(&system_prompt)
261                            .build();
262                        self.agents.push((
263                            AgentVariant::XAI(agent),
264                            agent_conf.id,
265                            agent_conf.provider.to_string(),
266                            agent_conf.model_name,
267                        ));
268                    }
269                    Err(err) => {
270                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
271                    }
272                },
273                ProviderEnum::Azure => {
274                    tracing::info!("Azure simple_builder暂不支持,参数有点多,可以自行添加........ ")
275                }
276                ProviderEnum::DeepSeek => match deepseek::Client::new(&agent_conf.api_key) {
277                    Ok(client) => {
278                        let agent = client
279                            .agent(&agent_conf.model_name)
280                            .name(agent_name.as_str())
281                            .preamble(&system_prompt)
282                            .build();
283                        self.agents.push((
284                            AgentVariant::DeepSeek(agent),
285                            agent_conf.id,
286                            agent_conf.provider.to_string(),
287                            agent_conf.model_name,
288                        ));
289                    }
290                    Err(err) => {
291                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
292                    }
293                },
294                ProviderEnum::Galadriel => match galadriel::Client::new(&agent_conf.api_key) {
295                    Ok(client) => {
296                        let agent = client
297                            .agent(&agent_conf.model_name)
298                            .name(agent_name.as_str())
299                            .preamble(&system_prompt)
300                            .build();
301                        self.agents.push((
302                            AgentVariant::Galadriel(agent),
303                            agent_conf.id,
304                            agent_conf.provider.to_string(),
305                            agent_conf.model_name,
306                        ));
307                    }
308                    Err(err) => {
309                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
310                    }
311                },
312                ProviderEnum::Groq => match groq::Client::new(&agent_conf.api_key) {
313                    Ok(client) => {
314                        let agent = client
315                            .agent(&agent_conf.model_name)
316                            .name(agent_name.as_str())
317                            .preamble(&system_prompt)
318                            .build();
319                        self.agents.push((
320                            AgentVariant::Groq(agent),
321                            agent_conf.id,
322                            agent_conf.provider.to_string(),
323                            agent_conf.model_name,
324                        ));
325                    }
326                    Err(err) => {
327                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
328                    }
329                },
330                ProviderEnum::Hyperbolic => match hyperbolic::Client::new(&agent_conf.api_key) {
331                    Ok(client) => {
332                        let agent = client
333                            .agent(&agent_conf.model_name)
334                            .name(agent_name.as_str())
335                            .preamble(&system_prompt)
336                            .build();
337                        self.agents.push((
338                            AgentVariant::Hyperbolic(agent),
339                            agent_conf.id,
340                            agent_conf.provider.to_string(),
341                            agent_conf.model_name,
342                        ));
343                    }
344                    Err(err) => {
345                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
346                    }
347                },
348                ProviderEnum::Mira => match mira::Client::new(&agent_conf.api_key) {
349                    Ok(client) => {
350                        let agent = client
351                            .agent(&agent_conf.model_name)
352                            .name(agent_name.as_str())
353                            .preamble(&system_prompt)
354                            .build();
355                        self.agents.push((
356                            AgentVariant::Mira(agent),
357                            agent_conf.id,
358                            agent_conf.provider.to_string(),
359                            agent_conf.model_name,
360                        ));
361                    }
362                    Err(err) => {
363                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
364                    }
365                },
366                ProviderEnum::Mooshot => match moonshot::Client::new(&agent_conf.api_key) {
367                    Ok(client) => {
368                        let agent = client
369                            .agent(&agent_conf.model_name)
370                            .name(agent_name.as_str())
371                            .preamble(&system_prompt)
372                            .build();
373                        self.agents.push((
374                            AgentVariant::Mooshot(agent),
375                            agent_conf.id,
376                            agent_conf.provider.to_string(),
377                            agent_conf.model_name,
378                        ));
379                    }
380                    Err(err) => {
381                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
382                    }
383                },
384                ProviderEnum::Ollama => {
385                    // Ollama uses Nothing as API key type, which means it doesn't need an API key
386                    // In rig-core 0.25, we need to use builder pattern with no API key
387                    use rig::client::Nothing;
388                    let client_result = if let Some(api_base_url) = &agent_conf.api_base_url {
389                        ollama::Client::builder()
390                            .base_url(api_base_url)
391                            .api_key(Nothing)
392                            .build()
393                    } else {
394                        ollama::Client::builder().api_key(Nothing).build()
395                    };
396                    match client_result {
397                        Ok(client) => {
398                            let agent = client
399                                .agent(&agent_conf.model_name)
400                                .name(agent_name.as_str())
401                                .preamble(&system_prompt)
402                                .build();
403                            self.agents.push((
404                                AgentVariant::Ollama(agent),
405                                agent_conf.id,
406                                agent_conf.provider.to_string(),
407                                agent_conf.model_name,
408                            ));
409                        }
410                        Err(err) => {
411                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
412                        }
413                    }
414                }
415                ProviderEnum::Perplexity => match perplexity::Client::new(&agent_conf.api_key) {
416                    Ok(client) => {
417                        let agent = client
418                            .agent(&agent_conf.model_name)
419                            .name(agent_name.as_str())
420                            .preamble(&system_prompt)
421                            .build();
422                        self.agents.push((
423                            AgentVariant::Perplexity(agent),
424                            agent_conf.id,
425                            agent_conf.provider.to_string(),
426                            agent_conf.model_name,
427                        ));
428                    }
429                    Err(err) => {
430                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
431                    }
432                },
433                ProviderEnum::Bigmodel => match bigmodel::Client::new(&agent_conf.api_key) {
434                    Ok(client) => {
435                        let agent = client
436                            .agent(&agent_conf.model_name)
437                            .name(agent_name.as_str())
438                            .preamble(&system_prompt)
439                            .build();
440                        self.agents.push((
441                            AgentVariant::Bigmodel(agent),
442                            agent_conf.id,
443                            agent_conf.provider.to_string(),
444                            agent_conf.model_name,
445                        ));
446                    }
447                    Err(err) => {
448                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
449                    }
450                },
451            }
452        }
453        self
454    }
455}