rig_extra/
simple_rand_builder.rs

1use crate::extra_providers::bigmodel;
2use crate::rand_agent::RandAgentBuilder;
3use rig::client::completion::CompletionClientDyn;
4use rig::providers::*;
5use serde::{Deserialize, Serialize};
6use strum_macros::Display;
7
8#[derive(Debug, Display, Deserialize, Serialize)]
9#[serde(rename_all = "lowercase")]
10pub enum ProviderEnum {
11    Anthropic,
12    Cohere,
13    Gemini,
14    Huggingface,
15    Mistral,
16    OpenAi,
17    OpenRouter,
18    Together,
19    XAI,
20    Azure,
21    DeepSeek,
22    Galadriel,
23    Groq,
24    Hyperbolic,
25    Mira,
26    Mooshot,
27    Ollama,
28    Perplexity,
29    Voyageai,
30    Bigmodel,
31}
32
33#[derive(Debug, Deserialize)]
34pub struct AgentConfig {
35    pub id: i32,
36    pub provider: ProviderEnum,
37    pub model_name: String,
38    pub api_key: String,
39    pub api_base_url: Option<String>,
40    pub system_prompt: Option<String>,
41    pub agent_name: Option<String>,
42}
43
44impl RandAgentBuilder {
45    /// 简单构建器
46    pub fn simple_builder(
47        mut self,
48        agent_configs: Vec<AgentConfig>,
49        global_system_prompt: String,
50    ) -> Self {
51        for agent_conf in agent_configs {
52            let agent_name = agent_conf.agent_name.unwrap_or("rand agent".to_string());
53            let system_prompt = agent_conf
54                .system_prompt
55                .unwrap_or(global_system_prompt.clone());
56
57            match agent_conf.provider {
58                ProviderEnum::Anthropic => {
59                    let mut client_builder = anthropic::Client::builder(&agent_conf.api_key);
60                    if let Some(api_base_url) = &agent_conf.api_base_url {
61                        client_builder = client_builder.base_url(api_base_url);
62                    }
63                    match client_builder.build() {
64                        Ok(client) => {
65                            let agent = client
66                                .agent(&agent_conf.model_name)
67                                .name(agent_name.as_str())
68                                .preamble(&system_prompt)
69                                .build();
70                            self.agents.push((
71                                agent,
72                                agent_conf.id,
73                                agent_conf.provider.to_string(),
74                                agent_conf.model_name,
75                            ));
76                        }
77                        Err(err) => {
78                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
79                        }
80                    }
81                }
82                ProviderEnum::Cohere => {
83                    let client = cohere::Client::new(&agent_conf.api_key);
84                    let agent = client
85                        .agent(&agent_conf.model_name)
86                        .name(agent_name.as_str())
87                        .preamble(&system_prompt)
88                        .build();
89                    self.agents.push((
90                        agent,
91                        agent_conf.id,
92                        agent_conf.provider.to_string(),
93                        agent_conf.model_name,
94                    ));
95                }
96                ProviderEnum::Gemini => {
97                    tracing::info!("Gemini 暂不支持,没有实现BoxAgent........ ")
98                }
99                ProviderEnum::Huggingface => {
100                    let client = huggingface::Client::new(&agent_conf.api_key);
101                    let agent = client
102                        .agent(&agent_conf.model_name)
103                        .name(agent_name.as_str())
104                        .preamble(&system_prompt)
105                        .build();
106                    self.agents.push((
107                        agent,
108                        agent_conf.id,
109                        agent_conf.provider.to_string(),
110                        agent_conf.model_name,
111                    ));
112                }
113                ProviderEnum::Mistral => {
114                    let client = mistral::Client::new(&agent_conf.api_key);
115                    let agent = client
116                        .agent(&agent_conf.model_name)
117                        .name(agent_name.as_str())
118                        .preamble(&system_prompt)
119                        .build();
120                    self.agents.push((
121                        agent,
122                        agent_conf.id,
123                        agent_conf.provider.to_string(),
124                        agent_conf.model_name,
125                    ));
126                }
127                ProviderEnum::OpenAi => {
128                    let mut client_builder = openai::Client::builder(&agent_conf.api_key);
129                    if let Some(api_base_url) = &agent_conf.api_base_url {
130                        client_builder = client_builder.base_url(api_base_url)
131                    }
132
133                    match client_builder.build() {
134                        Ok(client) => {
135                            // 不支持 completions_api,至少ollama使用这个会报错
136                            let agent = client
137                                .agent(&agent_conf.model_name)
138                                .name(agent_name.as_str())
139                                .preamble(&system_prompt)
140                                .build();
141                            self.agents.push((
142                                agent,
143                                agent_conf.id,
144                                agent_conf.provider.to_string(),
145                                agent_conf.model_name,
146                            ));
147                        }
148                        Err(err) => {
149                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
150                        }
151                    }
152                }
153                ProviderEnum::OpenRouter => {
154                    let mut client_builder = openrouter::Client::builder(&agent_conf.api_key);
155                    if let Some(api_base_url) = &agent_conf.api_base_url {
156                        client_builder = client_builder.base_url(api_base_url)
157                    }
158
159                    match client_builder.build() {
160                        Ok(client) => {
161                            let agent = client
162                                .agent(&agent_conf.model_name)
163                                .name(agent_name.as_str())
164                                .preamble(&system_prompt)
165                                .build();
166                            self.agents.push((
167                                agent,
168                                agent_conf.id,
169                                agent_conf.provider.to_string(),
170                                agent_conf.model_name,
171                            ));
172                        }
173                        Err(err) => {
174                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
175                        }
176                    }
177                }
178                ProviderEnum::Together => {
179                    let client = together::Client::new(&agent_conf.api_key);
180                    let agent = client
181                        .agent(&agent_conf.model_name)
182                        .name(agent_name.as_str())
183                        .preamble(&system_prompt)
184                        .build();
185                    self.agents.push((
186                        agent,
187                        agent_conf.id,
188                        agent_conf.provider.to_string(),
189                        agent_conf.model_name,
190                    ));
191                }
192                ProviderEnum::XAI => {
193                    let client = xai::Client::new(&agent_conf.api_key);
194                    let agent = client
195                        .agent(&agent_conf.model_name)
196                        .name(agent_name.as_str())
197                        .preamble(&system_prompt)
198                        .build();
199                    self.agents.push((
200                        agent,
201                        agent_conf.id,
202                        agent_conf.provider.to_string(),
203                        agent_conf.model_name,
204                    ));
205                }
206                ProviderEnum::Azure => {
207                    tracing::info!("Azure simple_builder暂不支持,参数有点多,可以自行添加........ ")
208                }
209                ProviderEnum::DeepSeek => {
210                    let client = deepseek::Client::new(&agent_conf.api_key);
211                    let agent = client
212                        .agent(&agent_conf.model_name)
213                        .name(agent_name.as_str())
214                        .preamble(&system_prompt)
215                        .build();
216                    self.agents.push((
217                        agent,
218                        agent_conf.id,
219                        agent_conf.provider.to_string(),
220                        agent_conf.model_name,
221                    ));
222                }
223                ProviderEnum::Galadriel => {
224                    tracing::info!("Galadriel simple_builder暂不支持,可以自行添加........ ")
225                }
226                ProviderEnum::Groq => {
227                    let client = groq::Client::new(&agent_conf.api_key);
228                    let agent = client
229                        .agent(&agent_conf.model_name)
230                        .name(agent_name.as_str())
231                        .preamble(&system_prompt)
232                        .build();
233                    self.agents.push((
234                        agent,
235                        agent_conf.id,
236                        agent_conf.provider.to_string(),
237                        agent_conf.model_name,
238                    ));
239                }
240                ProviderEnum::Hyperbolic => {
241                    let client = hyperbolic::Client::new(&agent_conf.api_key);
242                    let agent = client
243                        .agent(&agent_conf.model_name)
244                        .name(agent_name.as_str())
245                        .preamble(&system_prompt)
246                        .build();
247                    self.agents.push((
248                        agent,
249                        agent_conf.id,
250                        agent_conf.provider.to_string(),
251                        agent_conf.model_name,
252                    ));
253                }
254                ProviderEnum::Mira => {
255                    let client = mira::Client::new(&agent_conf.api_key);
256                    let agent = client
257                        .agent(&agent_conf.model_name)
258                        .name(agent_name.as_str())
259                        .preamble(&system_prompt)
260                        .build();
261                    self.agents.push((
262                        agent,
263                        agent_conf.id,
264                        agent_conf.provider.to_string(),
265                        agent_conf.model_name,
266                    ));
267                }
268                ProviderEnum::Mooshot => {
269                    let client = moonshot::Client::new(&agent_conf.api_key);
270                    let agent = client
271                        .agent(&agent_conf.model_name)
272                        .name(agent_name.as_str())
273                        .preamble(&system_prompt)
274                        .build();
275                    self.agents.push((
276                        agent,
277                        agent_conf.id,
278                        agent_conf.provider.to_string(),
279                        agent_conf.model_name,
280                    ));
281                }
282                ProviderEnum::Ollama => {
283                    let mut client_builder = ollama::Client::builder();
284                    if let Some(api_base_url) = &agent_conf.api_base_url {
285                        client_builder = client_builder.base_url(api_base_url);
286                    }
287
288                    match client_builder.build() {
289                        Ok(client) => {
290                            let agent = client
291                                .agent(&agent_conf.model_name)
292                                .name(agent_name.as_str())
293                                .preamble(&system_prompt)
294                                .build();
295                            self.agents.push((
296                                agent,
297                                agent_conf.id,
298                                agent_conf.provider.to_string(),
299                                agent_conf.model_name,
300                            ));
301                        }
302                        Err(err) => {
303                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
304                        }
305                    }
306                }
307                ProviderEnum::Perplexity => {
308                    tracing::info!("Perplexity 暂不支持,没有实现BoxAgent........ ")
309                }
310                ProviderEnum::Voyageai => {
311                    tracing::info!("Voyageai 暂不支持,........ ")
312                }
313                ProviderEnum::Bigmodel => {
314                    let client = if let Some(api_base_url) = agent_conf.api_base_url {
315                        bigmodel::Client::from_url(&agent_conf.api_key, &api_base_url)
316                    } else {
317                        bigmodel::Client::new(&agent_conf.api_key)
318                    };
319                    let agent = client
320                        .agent(&agent_conf.model_name)
321                        .name(agent_name.as_str())
322                        .preamble(&system_prompt)
323                        .build();
324                    self.agents.push((
325                        agent,
326                        agent_conf.id,
327                        agent_conf.provider.to_string(),
328                        agent_conf.model_name,
329                    ));
330                }
331            }
332        }
333        self
334    }
335}