// rig_extra/simple_rand_builder.rs

1use crate::extra_providers::bigmodel;
2use crate::get_openai_agent::get_openai_agent;
3use crate::rand_agent::RandAgentBuilder;
4use rig::client::completion::CompletionClientDyn;
5use rig::providers::*;
6use serde::{Deserialize, Serialize};
7use strum_macros::Display;
8
9#[derive(Debug, Display, Deserialize, Serialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ProviderEnum {
12    Anthropic,
13    Cohere,
14    Gemini,
15    Huggingface,
16    Mistral,
17    OpenAi,
18    OpenRouter,
19    Together,
20    XAI,
21    Azure,
22    DeepSeek,
23    Galadriel,
24    Groq,
25    Hyperbolic,
26    Mira,
27    Mooshot,
28    Ollama,
29    Perplexity,
30    // embedding模型
31    // Voyageai,
32    Bigmodel,
33}
34
35#[derive(Debug, Deserialize)]
36pub struct AgentConfig {
37    pub id: i32,
38    pub provider: ProviderEnum,
39    pub model_name: String,
40    pub api_key: String,
41    pub api_base_url: Option<String>,
42    pub system_prompt: Option<String>,
43    pub agent_name: Option<String>,
44}
45
46impl RandAgentBuilder {
47    /// 简单构建器
48    pub fn simple_builder(
49        mut self,
50        agent_configs: Vec<AgentConfig>,
51        global_system_prompt: String,
52    ) -> Self {
53        for agent_conf in agent_configs {
54            let agent_name = agent_conf.agent_name.unwrap_or("rand agent".to_string());
55            let system_prompt = agent_conf
56                .system_prompt
57                .unwrap_or(global_system_prompt.clone());
58
59            match agent_conf.provider {
60                ProviderEnum::Anthropic => {
61                    let mut client_builder = anthropic::Client::builder(&agent_conf.api_key);
62                    if let Some(api_base_url) = &agent_conf.api_base_url {
63                        client_builder = client_builder.base_url(api_base_url);
64                    }
65                    match client_builder.build() {
66                        Ok(client) => {
67                            let agent = client
68                                .agent(&agent_conf.model_name)
69                                .name(agent_name.as_str())
70                                .preamble(&system_prompt)
71                                .build();
72                            self.agents.push((
73                                agent,
74                                agent_conf.id,
75                                agent_conf.provider.to_string(),
76                                agent_conf.model_name,
77                            ));
78                        }
79                        Err(err) => {
80                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
81                        }
82                    }
83                }
84                ProviderEnum::Cohere => {
85                    let client = cohere::Client::new(&agent_conf.api_key);
86                    let agent = client
87                        .agent(&agent_conf.model_name)
88                        .name(agent_name.as_str())
89                        .preamble(&system_prompt)
90                        .build();
91                    self.agents.push((
92                        agent,
93                        agent_conf.id,
94                        agent_conf.provider.to_string(),
95                        agent_conf.model_name,
96                    ));
97                }
98                ProviderEnum::Gemini => {
99                    let mut client_builder = gemini::Client::builder(&agent_conf.api_key);
100                    if let Some(api_base_url) = &agent_conf.api_base_url {
101                        client_builder = client_builder.base_url(api_base_url);
102                    }
103                    match client_builder.build() {
104                        Ok(client) => {
105                            let agent = client
106                                .agent(&agent_conf.model_name)
107                                .name(agent_name.as_str())
108                                .preamble(&system_prompt)
109                                .build();
110                            self.agents.push((
111                                agent,
112                                agent_conf.id,
113                                agent_conf.provider.to_string(),
114                                agent_conf.model_name,
115                            ));
116                        }
117                        Err(err) => {
118                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
119                        }
120                    }
121                }
122                ProviderEnum::Huggingface => {
123                    let client = huggingface::Client::new(&agent_conf.api_key);
124                    let agent = client
125                        .agent(&agent_conf.model_name)
126                        .name(agent_name.as_str())
127                        .preamble(&system_prompt)
128                        .build();
129                    self.agents.push((
130                        agent,
131                        agent_conf.id,
132                        agent_conf.provider.to_string(),
133                        agent_conf.model_name,
134                    ));
135                }
136                ProviderEnum::Mistral => {
137                    let client = mistral::Client::new(&agent_conf.api_key);
138                    let agent = client
139                        .agent(&agent_conf.model_name)
140                        .name(agent_name.as_str())
141                        .preamble(&system_prompt)
142                        .build();
143                    self.agents.push((
144                        agent,
145                        agent_conf.id,
146                        agent_conf.provider.to_string(),
147                        agent_conf.model_name,
148                    ));
149                }
150                ProviderEnum::OpenAi => {
151                    let mut client_builder = openai::Client::builder(&agent_conf.api_key);
152                    if let Some(api_base_url) = &agent_conf.api_base_url {
153                        client_builder = client_builder.base_url(api_base_url)
154                    }
155
156                    match client_builder.build() {
157                        Ok(client) => {
158                            let agent = get_openai_agent(
159                                client,
160                                &agent_conf.model_name,
161                                agent_name,
162                                system_prompt,
163                            );
164                            self.agents.push((
165                                agent,
166                                agent_conf.id,
167                                agent_conf.provider.to_string(),
168                                agent_conf.model_name,
169                            ));
170                        }
171                        Err(err) => {
172                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
173                        }
174                    }
175                }
176                ProviderEnum::OpenRouter => {
177                    let mut client_builder = openrouter::Client::builder(&agent_conf.api_key);
178                    if let Some(api_base_url) = &agent_conf.api_base_url {
179                        client_builder = client_builder.base_url(api_base_url)
180                    }
181
182                    match client_builder.build() {
183                        Ok(client) => {
184                            let agent = client
185                                .agent(&agent_conf.model_name)
186                                .name(agent_name.as_str())
187                                .preamble(&system_prompt)
188                                .build();
189                            self.agents.push((
190                                agent,
191                                agent_conf.id,
192                                agent_conf.provider.to_string(),
193                                agent_conf.model_name,
194                            ));
195                        }
196                        Err(err) => {
197                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
198                        }
199                    }
200                }
201                ProviderEnum::Together => {
202                    let client = together::Client::new(&agent_conf.api_key);
203                    let agent = client
204                        .agent(&agent_conf.model_name)
205                        .name(agent_name.as_str())
206                        .preamble(&system_prompt)
207                        .build();
208                    self.agents.push((
209                        agent,
210                        agent_conf.id,
211                        agent_conf.provider.to_string(),
212                        agent_conf.model_name,
213                    ));
214                }
215                ProviderEnum::XAI => {
216                    let client = xai::Client::new(&agent_conf.api_key);
217                    let agent = client
218                        .agent(&agent_conf.model_name)
219                        .name(agent_name.as_str())
220                        .preamble(&system_prompt)
221                        .build();
222                    self.agents.push((
223                        agent,
224                        agent_conf.id,
225                        agent_conf.provider.to_string(),
226                        agent_conf.model_name,
227                    ));
228                }
229                ProviderEnum::Azure => {
230                    tracing::info!("Azure simple_builder暂不支持,参数有点多,可以自行添加........ ")
231                }
232                ProviderEnum::DeepSeek => {
233                    let client = deepseek::Client::new(&agent_conf.api_key);
234                    let agent = client
235                        .agent(&agent_conf.model_name)
236                        .name(agent_name.as_str())
237                        .preamble(&system_prompt)
238                        .build();
239                    self.agents.push((
240                        agent,
241                        agent_conf.id,
242                        agent_conf.provider.to_string(),
243                        agent_conf.model_name,
244                    ));
245                }
246                ProviderEnum::Galadriel => {
247                    let client = galadriel::Client::new(&agent_conf.api_key);
248                    let agent = client
249                        .agent(&agent_conf.model_name)
250                        .name(agent_name.as_str())
251                        .preamble(&system_prompt)
252                        .build();
253                    self.agents.push((
254                        agent,
255                        agent_conf.id,
256                        agent_conf.provider.to_string(),
257                        agent_conf.model_name,
258                    ));
259                }
260                ProviderEnum::Groq => {
261                    let client = groq::Client::new(&agent_conf.api_key);
262                    let agent = client
263                        .agent(&agent_conf.model_name)
264                        .name(agent_name.as_str())
265                        .preamble(&system_prompt)
266                        .build();
267                    self.agents.push((
268                        agent,
269                        agent_conf.id,
270                        agent_conf.provider.to_string(),
271                        agent_conf.model_name,
272                    ));
273                }
274                ProviderEnum::Hyperbolic => {
275                    let client = hyperbolic::Client::new(&agent_conf.api_key);
276                    let agent = client
277                        .agent(&agent_conf.model_name)
278                        .name(agent_name.as_str())
279                        .preamble(&system_prompt)
280                        .build();
281                    self.agents.push((
282                        agent,
283                        agent_conf.id,
284                        agent_conf.provider.to_string(),
285                        agent_conf.model_name,
286                    ));
287                }
288                ProviderEnum::Mira => {
289                    let client = mira::Client::new(&agent_conf.api_key);
290                    let agent = client
291                        .agent(&agent_conf.model_name)
292                        .name(agent_name.as_str())
293                        .preamble(&system_prompt)
294                        .build();
295                    self.agents.push((
296                        agent,
297                        agent_conf.id,
298                        agent_conf.provider.to_string(),
299                        agent_conf.model_name,
300                    ));
301                }
302                ProviderEnum::Mooshot => {
303                    let client = moonshot::Client::new(&agent_conf.api_key);
304                    let agent = client
305                        .agent(&agent_conf.model_name)
306                        .name(agent_name.as_str())
307                        .preamble(&system_prompt)
308                        .build();
309                    self.agents.push((
310                        agent,
311                        agent_conf.id,
312                        agent_conf.provider.to_string(),
313                        agent_conf.model_name,
314                    ));
315                }
316                ProviderEnum::Ollama => {
317                    let mut client_builder = ollama::Client::builder();
318                    if let Some(api_base_url) = &agent_conf.api_base_url {
319                        client_builder = client_builder.base_url(api_base_url);
320                    }
321
322                    match client_builder.build() {
323                        Ok(client) => {
324                            let agent = client
325                                .agent(&agent_conf.model_name)
326                                .name(agent_name.as_str())
327                                .preamble(&system_prompt)
328                                .build();
329                            self.agents.push((
330                                agent,
331                                agent_conf.id,
332                                agent_conf.provider.to_string(),
333                                agent_conf.model_name,
334                            ));
335                        }
336                        Err(err) => {
337                            tracing::error!("添加 {} 错误: {}", agent_conf.provider, err);
338                        }
339                    }
340                }
341                ProviderEnum::Perplexity => {
342                    // let client = perplexity::Client::new(&agent_conf.api_key);
343                    // let agent = client
344                    //     .agent(&agent_conf.model_name)
345                    //     .name(agent_name.as_str())
346                    //     .preamble(&system_prompt)
347                    //     .build();
348                    // self.agents.push((
349                    //     agent,
350                    //     agent_conf.id,
351                    //     agent_conf.provider.to_string(),
352                    //     agent_conf.model_name,
353                    // ));
354                    tracing::info!("Perplexity 暂不支持,没有实现BoxAgent........ ")
355                }
356                ProviderEnum::Bigmodel => {
357                    let client = if let Some(api_base_url) = agent_conf.api_base_url {
358                        bigmodel::Client::from_url(&agent_conf.api_key, &api_base_url)
359                    } else {
360                        bigmodel::Client::new(&agent_conf.api_key)
361                    };
362                    let agent = client
363                        .agent(&agent_conf.model_name)
364                        .name(agent_name.as_str())
365                        .preamble(&system_prompt)
366                        .build();
367                    self.agents.push((
368                        agent,
369                        agent_conf.id,
370                        agent_conf.provider.to_string(),
371                        agent_conf.model_name,
372                    ));
373                }
374            }
375        }
376        self
377    }
378}