rig_extra/
simple_rand_builder.rs

1use crate::agent_variant::AgentVariant;
2use crate::extra_providers::bigmodel;
3use crate::rand_agent::RandAgentBuilder;
4use rig::agent::AgentBuilder;
5use rig::client::CompletionClient;
6use rig::providers::*;
7use serde::{Deserialize, Serialize};
8use strum_macros::Display;
9
10#[derive(Debug, Display, Deserialize, Serialize)]
11#[serde(rename_all = "lowercase")]
12pub enum ProviderEnum {
13    Anthropic,
14    Cohere,
15    Gemini,
16    Huggingface,
17    Mistral,
18    OpenAi,
19    ResponsesOpenAi,
20    OpenRouter,
21    Together,
22    XAI,
23    Azure,
24    DeepSeek,
25    Galadriel,
26    Groq,
27    Hyperbolic,
28    Mira,
29    Mooshot,
30    Ollama,
31    Perplexity,
32    // embedding模型
33    // Voyageai,
34    Bigmodel,
35}
36
37#[derive(Debug, Deserialize)]
38pub struct AgentConfig {
39    pub id: i32,
40    pub provider: ProviderEnum,
41    pub model_name: String,
42    pub api_key: String,
43    pub api_base_url: Option<String>,
44    pub system_prompt: Option<String>,
45    pub agent_name: Option<String>,
46}
47
48impl RandAgentBuilder {
49    /// 简单构建器
50    pub fn simple_builder(
51        mut self,
52        agent_configs: Vec<AgentConfig>,
53        global_system_prompt: String,
54    ) -> Self {
55        for agent_conf in agent_configs {
56            let agent_name = agent_conf.agent_name.unwrap_or("rand agent".to_string());
57            let system_prompt = agent_conf
58                .system_prompt
59                .unwrap_or(global_system_prompt.clone());
60
61            match agent_conf.provider {
62                ProviderEnum::Anthropic => {
63                    let mut client_builder = anthropic::Client::builder();
64                    if let Some(api_base_url) = &agent_conf.api_base_url {
65                        client_builder = client_builder.base_url(api_base_url);
66                    }
67                    match client_builder.api_key(&agent_conf.api_key).build() {
68                        Ok(client) => {
69                            let agent = client
70                                .agent(&agent_conf.model_name)
71                                .name(agent_name.as_str())
72                                .preamble(&system_prompt)
73                                .build();
74                            self.agents.push((
75                                AgentVariant::Anthropic(agent),
76                                agent_conf.id,
77                                agent_conf.provider.to_string(),
78                                agent_conf.model_name,
79                            ));
80                        }
81                        Err(err) => {
82                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
83                        }
84                    }
85                }
86                ProviderEnum::Cohere => match cohere::Client::new(&agent_conf.api_key) {
87                    Ok(client) => {
88                        let agent = client
89                            .agent(&agent_conf.model_name)
90                            .name(agent_name.as_str())
91                            .preamble(&system_prompt)
92                            .build();
93                        self.agents.push((
94                            AgentVariant::Cohere(agent),
95                            agent_conf.id,
96                            agent_conf.provider.to_string(),
97                            agent_conf.model_name,
98                        ));
99                    }
100                    Err(err) => {
101                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
102                    }
103                },
104                ProviderEnum::Gemini => {
105                    let mut client_builder = gemini::Client::builder();
106                    if let Some(api_base_url) = &agent_conf.api_base_url {
107                        client_builder = client_builder.base_url(api_base_url);
108                    }
109                    match client_builder.api_key(&agent_conf.api_key).build() {
110                        Ok(client) => {
111                            let agent = client
112                                .agent(&agent_conf.model_name)
113                                .name(agent_name.as_str())
114                                .preamble(&system_prompt)
115                                .build();
116                            self.agents.push((
117                                AgentVariant::Gemini(agent),
118                                agent_conf.id,
119                                agent_conf.provider.to_string(),
120                                agent_conf.model_name,
121                            ));
122                        }
123                        Err(err) => {
124                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
125                        }
126                    }
127                }
128                ProviderEnum::Huggingface => match huggingface::Client::new(&agent_conf.api_key) {
129                    Ok(client) => {
130                        let agent = client
131                            .agent(&agent_conf.model_name)
132                            .name(agent_name.as_str())
133                            .preamble(&system_prompt)
134                            .build();
135                        self.agents.push((
136                            AgentVariant::Huggingface(agent),
137                            agent_conf.id,
138                            agent_conf.provider.to_string(),
139                            agent_conf.model_name,
140                        ));
141                    }
142                    Err(err) => {
143                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
144                    }
145                },
146                ProviderEnum::Mistral => match mistral::Client::new(&agent_conf.api_key) {
147                    Ok(client) => {
148                        let agent = client
149                            .agent(&agent_conf.model_name)
150                            .name(agent_name.as_str())
151                            .preamble(&system_prompt)
152                            .build();
153                        self.agents.push((
154                            AgentVariant::Mistral(agent),
155                            agent_conf.id,
156                            agent_conf.provider.to_string(),
157                            agent_conf.model_name,
158                        ));
159                    }
160                    Err(err) => {
161                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
162                    }
163                },
164                ProviderEnum::OpenAi => {
165                    let mut client_builder = openai::Client::builder();
166                    if let Some(api_base_url) = &agent_conf.api_base_url {
167                        client_builder = client_builder.base_url(api_base_url);
168                    }
169                    match client_builder.api_key(&agent_conf.api_key).build() {
170                        Ok(client) => {
171                            let agent = client
172                                .completions_api()
173                                .agent(&agent_conf.model_name)
174                                .name(agent_name.as_str())
175                                .preamble(&system_prompt)
176                                .build();
177
178                            self.agents.push((
179                                AgentVariant::OpenAI(agent),
180                                agent_conf.id,
181                                agent_conf.provider.to_string(),
182                                agent_conf.model_name,
183                            ));
184                        }
185                        Err(err) => {
186                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
187                        }
188                    }
189                }
190                ProviderEnum::ResponsesOpenAi => {
191                    let mut client_builder = openai::Client::builder();
192                    if let Some(api_base_url) = &agent_conf.api_base_url {
193                        client_builder = client_builder.base_url(api_base_url);
194                    }
195                    match client_builder.api_key(&agent_conf.api_key).build() {
196                        Ok(client) => {
197                            let agent = client
198                                .agent(&agent_conf.model_name)
199                                .name(agent_name.as_str())
200                                .preamble(&system_prompt)
201                                .build();
202                            self.agents.push((
203                                AgentVariant::ResponsesOpenAI(agent),
204                                agent_conf.id,
205                                agent_conf.provider.to_string(),
206                                agent_conf.model_name,
207                            ));
208                        }
209                        Err(err) => {
210                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
211                        }
212                    }
213                }
214                ProviderEnum::OpenRouter => {
215                    let mut client_builder = openrouter::Client::builder();
216                    if let Some(api_base_url) = &agent_conf.api_base_url {
217                        client_builder = client_builder.base_url(api_base_url);
218                    }
219                    match client_builder.api_key(&agent_conf.api_key).build() {
220                        Ok(client) => {
221                            let agent = client
222                                .agent(&agent_conf.model_name)
223                                .name(agent_name.as_str())
224                                .preamble(&system_prompt)
225                                .build();
226                            self.agents.push((
227                                AgentVariant::OpenRouter(agent),
228                                agent_conf.id,
229                                agent_conf.provider.to_string(),
230                                agent_conf.model_name,
231                            ));
232                        }
233                        Err(err) => {
234                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
235                        }
236                    }
237                }
238                ProviderEnum::Together => match together::Client::new(&agent_conf.api_key) {
239                    Ok(client) => {
240                        let agent = client
241                            .agent(&agent_conf.model_name)
242                            .name(agent_name.as_str())
243                            .preamble(&system_prompt)
244                            .build();
245                        self.agents.push((
246                            AgentVariant::Together(agent),
247                            agent_conf.id,
248                            agent_conf.provider.to_string(),
249                            agent_conf.model_name,
250                        ));
251                    }
252                    Err(err) => {
253                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
254                    }
255                },
256                ProviderEnum::XAI => match xai::Client::new(&agent_conf.api_key) {
257                    Ok(client) => {
258                        let agent = client
259                            .agent(&agent_conf.model_name)
260                            .name(agent_name.as_str())
261                            .preamble(&system_prompt)
262                            .build();
263                        self.agents.push((
264                            AgentVariant::XAI(agent),
265                            agent_conf.id,
266                            agent_conf.provider.to_string(),
267                            agent_conf.model_name,
268                        ));
269                    }
270                    Err(err) => {
271                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
272                    }
273                },
274                ProviderEnum::Azure => {
275                    tracing::info!("Azure simple_builder暂不支持,参数有点多,可以自行添加........ ")
276                }
277                ProviderEnum::DeepSeek => match deepseek::Client::new(&agent_conf.api_key) {
278                    Ok(client) => {
279                        let agent = client
280                            .agent(&agent_conf.model_name)
281                            .name(agent_name.as_str())
282                            .preamble(&system_prompt)
283                            .build();
284                        self.agents.push((
285                            AgentVariant::DeepSeek(agent),
286                            agent_conf.id,
287                            agent_conf.provider.to_string(),
288                            agent_conf.model_name,
289                        ));
290                    }
291                    Err(err) => {
292                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
293                    }
294                },
295                ProviderEnum::Galadriel => match galadriel::Client::new(&agent_conf.api_key) {
296                    Ok(client) => {
297                        let agent = client
298                            .agent(&agent_conf.model_name)
299                            .name(agent_name.as_str())
300                            .preamble(&system_prompt)
301                            .build();
302                        self.agents.push((
303                            AgentVariant::Galadriel(agent),
304                            agent_conf.id,
305                            agent_conf.provider.to_string(),
306                            agent_conf.model_name,
307                        ));
308                    }
309                    Err(err) => {
310                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
311                    }
312                },
313                ProviderEnum::Groq => match groq::Client::new(&agent_conf.api_key) {
314                    Ok(client) => {
315                        let agent = client
316                            .agent(&agent_conf.model_name)
317                            .name(agent_name.as_str())
318                            .preamble(&system_prompt)
319                            .build();
320                        self.agents.push((
321                            AgentVariant::Groq(agent),
322                            agent_conf.id,
323                            agent_conf.provider.to_string(),
324                            agent_conf.model_name,
325                        ));
326                    }
327                    Err(err) => {
328                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
329                    }
330                },
331                ProviderEnum::Hyperbolic => match hyperbolic::Client::new(&agent_conf.api_key) {
332                    Ok(client) => {
333                        let agent = client
334                            .agent(&agent_conf.model_name)
335                            .name(agent_name.as_str())
336                            .preamble(&system_prompt)
337                            .build();
338                        self.agents.push((
339                            AgentVariant::Hyperbolic(agent),
340                            agent_conf.id,
341                            agent_conf.provider.to_string(),
342                            agent_conf.model_name,
343                        ));
344                    }
345                    Err(err) => {
346                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
347                    }
348                },
349                ProviderEnum::Mira => match mira::Client::new(&agent_conf.api_key) {
350                    Ok(client) => {
351                        let agent = client
352                            .agent(&agent_conf.model_name)
353                            .name(agent_name.as_str())
354                            .preamble(&system_prompt)
355                            .build();
356                        self.agents.push((
357                            AgentVariant::Mira(agent),
358                            agent_conf.id,
359                            agent_conf.provider.to_string(),
360                            agent_conf.model_name,
361                        ));
362                    }
363                    Err(err) => {
364                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
365                    }
366                },
367                ProviderEnum::Mooshot => match moonshot::Client::new(&agent_conf.api_key) {
368                    Ok(client) => {
369                        let agent = client
370                            .agent(&agent_conf.model_name)
371                            .name(agent_name.as_str())
372                            .preamble(&system_prompt)
373                            .build();
374                        self.agents.push((
375                            AgentVariant::Mooshot(agent),
376                            agent_conf.id,
377                            agent_conf.provider.to_string(),
378                            agent_conf.model_name,
379                        ));
380                    }
381                    Err(err) => {
382                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
383                    }
384                },
385                ProviderEnum::Ollama => {
386                    // Ollama uses Nothing as API key type, which means it doesn't need an API key
387                    // In rig-core 0.25, we need to use builder pattern with no API key
388                    use rig::client::Nothing;
389                    let client_result = if let Some(api_base_url) = &agent_conf.api_base_url {
390                        ollama::Client::builder()
391                            .base_url(api_base_url)
392                            .api_key(Nothing)
393                            .build()
394                    } else {
395                        ollama::Client::builder().api_key(Nothing).build()
396                    };
397                    match client_result {
398                        Ok(client) => {
399                            let agent = client
400                                .agent(&agent_conf.model_name)
401                                .name(agent_name.as_str())
402                                .preamble(&system_prompt)
403                                .build();
404                            self.agents.push((
405                                AgentVariant::Ollama(agent),
406                                agent_conf.id,
407                                agent_conf.provider.to_string(),
408                                agent_conf.model_name,
409                            ));
410                        }
411                        Err(err) => {
412                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
413                        }
414                    }
415                }
416                ProviderEnum::Perplexity => match perplexity::Client::new(&agent_conf.api_key) {
417                    Ok(client) => {
418                        let agent = client
419                            .agent(&agent_conf.model_name)
420                            .name(agent_name.as_str())
421                            .preamble(&system_prompt)
422                            .build();
423                        self.agents.push((
424                            AgentVariant::Perplexity(agent),
425                            agent_conf.id,
426                            agent_conf.provider.to_string(),
427                            agent_conf.model_name,
428                        ));
429                    }
430                    Err(err) => {
431                        tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
432                    }
433                },
434                ProviderEnum::Bigmodel => {
435                    let client = if let Some(api_base_url) = agent_conf.api_base_url {
436                        bigmodel::Client::from_url(&agent_conf.api_key, &api_base_url)
437                    } else {
438                        bigmodel::Client::new(&agent_conf.api_key)
439                    };
440                    let model = client.completion_model(&agent_conf.model_name);
441                    let agent = AgentBuilder::new(model)
442                        .name(agent_name.as_str())
443                        .preamble(&system_prompt)
444                        .build();
445                    self.agents.push((
446                        AgentVariant::Bigmodel(agent),
447                        agent_conf.id,
448                        agent_conf.provider.to_string(),
449                        agent_conf.model_name,
450                    ));
451                }
452            }
453        }
454        self
455    }
456}