rig_extra/extra_providers/
bigmodel.rs

1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message, client};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
// ================================================================
// BIGMODEL client
// ================================================================
/// Base URL used by `Client::new` when the caller does not supply one.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
/// HTTP client for the BigModel API.
///
/// Pairs the base URL that endpoint paths are appended to with a `reqwest`
/// client whose default headers carry the `Authorization` bearer token
/// (installed in `Client::from_url`).
#[derive(Clone, Debug)]
pub struct Client {
    /// Base URL that request paths are joined onto.
    base_url: String,
    /// Pre-configured HTTP client used for every request.
    http_client: reqwest::Client,
}
26
27impl Client {
28    pub fn new(api_key: &str) -> Self {
29        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30    }
31
32    pub fn from_url(api_key: &str, base_url: &str) -> Self {
33        Self {
34            base_url: base_url.to_string(),
35            http_client: reqwest::Client::builder()
36                .default_headers({
37                    let mut headers = reqwest::header::HeaderMap::new();
38                    headers.insert(
39                        "Authorization",
40                        format!("Bearer {api_key}")
41                            .parse()
42                            .expect("Bearer token should parse"),
43                    );
44                    headers
45                })
46                .build()
47                .expect("bigmodel reqwest client should build"),
48        }
49    }
50
51
52
53    fn post(&self, path: &str) -> reqwest::RequestBuilder {
54        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55        self.http_client.post(url)
56    }
57
58    pub fn completion_model(&self, model: &str) -> CompletionModel {
59        CompletionModel::new(self.clone(), model)
60    }
61
62
63
64    /// 为completion模型创建提取构建器
65    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
66        &self,
67        model: &str,
68    ) -> ExtractorBuilder<T, CompletionModel> {
69        ExtractorBuilder::new(self.completion_model(model))
70    }
71}
72
73impl ProviderClient for Client {
74    fn from_env() -> Self
75    where
76        Self: Sized
77    {
78        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
79        Self::new(&api_key)
80    }
81
82    fn from_val(input: ProviderValue) -> Self
83    where
84        Self: Sized
85    {
86        let  client::ProviderValue::Simple(api_key) = input else {
87            panic!("Incorrect provider value type")
88        };
89        Self::new(&api_key)
90    }
91}
92
// Empty impl: nothing is overridden, so whatever defaults `AsTranscription`
// declares apply (presumably "no transcription support" — confirm against rig).
impl AsTranscription for Client {}

// Empty impl: nothing is overridden, so whatever defaults `AsEmbeddings`
// declares apply (presumably "no embeddings support" — confirm against rig).
impl AsEmbeddings for Client {}
96
97
98
99impl CompletionClient for Client {
100    type CompletionModel = CompletionModel;
101
102    fn completion_model(&self, model: &str) -> Self::CompletionModel {
103        CompletionModel::new(self.clone(), model)
104    }
105}
106
/// Error payload shape: a JSON object with a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
111
/// Either a successful payload `T` or an error payload.
///
/// Untagged: serde tries the variants in declaration order, so a body is
/// first parsed as `T` and only then as `ApiErrorResponse`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
118
// ================================================================
// Bigmodel Completion API
// ================================================================
/// Model identifier string `glm-4-flash`, for use as the `model` argument.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
123
124#[derive(Debug, Deserialize,Serialize)]
125#[serde(rename_all = "camelCase")]
126pub struct CompletionResponse {
127    pub choices: Vec<Choice>,
128    pub created: i64,
129    pub id: String,
130    pub model: String,
131    #[serde(rename = "request_id")]
132    pub request_id: String,
133    pub usage: Usage,
134}
135
/// Chat message in the provider's wire format.
///
/// Internally tagged on the `role` key with lowercase variant names
/// (`user`, `assistant`, `system`); `ToolResult` is tagged as `tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message written by the end user.
    User {
        content: String,
    },
    /// Message produced by the model: optional text plus any tool calls.
    Assistant {
        content: Option<String>,
        // A missing `tool_calls` key falls back to an empty vec via `default`;
        // `null_or_vec` presumably also maps JSON `null` to an empty vec — TODO confirm.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System prompt.
    System {
        content: String,
    },
    /// Result of a tool invocation, linked to its call by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
156
157impl Message {
158    pub fn system(content: &str) -> Message {
159        Message::System {
160            content: content.to_owned(),
161        }
162    }
163}
164
/// Text-only tool result content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// The textual result.
    text: String,
}
169impl TryFrom<message::ToolResultContent> for ToolResultContent {
170    type Error = MessageError;
171    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
172        let message::ToolResultContent::Text(Text { text }) = value else {
173            return Err(MessageError::ConversionError(
174                "Non-text tool results not supported".into(),
175            ));
176        };
177
178        Ok(Self { text })
179    }
180}
181
182impl TryFrom<message::Message> for Message {
183    type Error = MessageError;
184
185    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
186        Ok(match message {
187            message::Message::User { content } => {
188                let mut texts = Vec::new();
189                let mut images = Vec::new();
190
191                for uc in content.into_iter() {
192                    match uc {
193                        message::UserContent::Text(message::Text { text }) => texts.push(text),
194                        message::UserContent::Image(img) => images.push(img.data),
195                        message::UserContent::ToolResult(result) => {
196                            let content = result
197                                .content
198                                .into_iter()
199                                .map(ToolResultContent::try_from)
200                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
201
202                            let content = OneOrMany::many(content).map_err(|x| {
203                                MessageError::ConversionError(format!(
204                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
205                                ))
206                            })?;
207
208                            return Ok(Message::ToolResult {
209                                tool_call_id: result.id,
210                                content: content.first().text,
211                            });
212                        }
213                        _ => {}
214                    }
215                }
216
217                let collapsed_content = texts.join(" ");
218
219                Message::User {
220                    content: collapsed_content,
221                }
222            }
223            message::Message::Assistant { content, .. } => {
224                let mut texts = Vec::new();
225                let mut tool_calls = Vec::new();
226
227                for ac in content.into_iter() {
228                    match ac {
229                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
230                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
231                        _ => {}
232                    }
233                }
234
235                let collapsed_content = texts.join(" ");
236
237                Message::Assistant {
238                    content: Some(collapsed_content),
239                    tool_calls,
240                }
241            }
242        })
243    }
244}
245
246impl From<message::ToolResult> for Message {
247    fn from(tool_result: message::ToolResult) -> Self {
248        let content = match tool_result.content.first() {
249            message::ToolResultContent::Text(text) => text.text,
250            message::ToolResultContent::Image(_) => String::from("[Image]"),
251        };
252
253        Message::ToolResult {
254            tool_call_id: tool_result.id,
255            content,
256        }
257    }
258}
259
/// A tool (function) call attached to an assistant message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    /// Function name and arguments of the call.
    pub function: CallFunction,
    /// Identifier of the call; tool results refer back to it.
    pub id: String,
    /// Position of the call; required when deserializing (no default).
    pub index: usize,
    /// Kind of tool; falls back to `ToolType::Function` when absent.
    #[serde(default)]
    pub r#type: ToolType,
}
269
270impl From<message::ToolCall> for ToolCall {
271    fn from(tool_call: message::ToolCall) -> Self {
272        Self {
273            id: tool_call.id,
274            index: 0,
275            r#type: ToolType::Function,
276            function: CallFunction {
277                name: tool_call.function.name,
278                arguments: tool_call.function.arguments,
279            },
280        }
281    }
282}
283
/// Kind of tool referenced by a `ToolCall`; serialized in lowercase
/// (`function`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function call; the only variant, and the default.
    #[default]
    Function,
}
290
/// Name and arguments of a function invoked through a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    /// Name of the function to call.
    pub name: String,
    // (De)serialized through `json_utils::stringified_json`; presumably the
    // arguments travel as a JSON-encoded string on the wire — TODO confirm.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
297
/// Chat roles, serialized in lowercase (`system`, `user`, `assistant`).
///
/// NOTE(review): not referenced by the other items visible in this part of
/// the file — confirm whether external callers use it.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
305
306#[derive(Debug, Serialize, Deserialize)]
307#[serde(rename_all = "camelCase")]
308pub struct Choice {
309    #[serde(rename = "finish_reason")]
310    pub finish_reason: String,
311    pub index: i64,
312    pub message: Message,
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize)]
316#[serde(rename_all = "camelCase")]
317pub struct Usage {
318    #[serde(rename = "completion_tokens")]
319    pub completion_tokens: i64,
320    #[serde(rename = "prompt_tokens")]
321    pub prompt_tokens: i64,
322    #[serde(rename = "total_tokens")]
323    pub total_tokens: i64,
324}
325
326impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
327    type Error = CompletionError;
328
329    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
330        let choice = response.choices.first().ok_or_else(|| {
331            CompletionError::ResponseError("Response contained no choices".to_owned())
332        })?;
333
334        match &choice.message {
335            Message::Assistant {
336                tool_calls,
337                content,
338            } => {
339                if !tool_calls.is_empty() {
340                    let tool_result = tool_calls
341                        .iter()
342                        .map(|call| {
343                            completion::AssistantContent::tool_call(
344                                &call.function.name,
345                                &call.function.name,
346                                call.function.arguments.clone(),
347                            )
348                        })
349                        .collect::<Vec<_>>();
350
351                    let choice = OneOrMany::many(tool_result).map_err(|_| {
352                        CompletionError::ResponseError(
353                            "Response contained no message or tool call (empty)".to_owned(),
354                        )
355                    })?;
356                    let usage = completion::Usage {
357                        input_tokens: response.usage.prompt_tokens as u64,
358                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
359                            as u64,
360                        total_tokens: response.usage.total_tokens as u64,
361                    };
362                    tracing::debug!("response choices: {:?}: ", choice);
363                    Ok(completion::CompletionResponse {
364                        choice,
365                        usage,
366                        raw_response: response,
367                    })
368                } else {
369                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
370                        text: content.clone().unwrap_or_else(|| "".to_owned()),
371                    }));
372                    let usage = completion::Usage {
373                        input_tokens: response.usage.prompt_tokens as u64,
374                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
375                            as u64,
376                        total_tokens: response.usage.total_tokens as u64,
377                    };
378                    Ok(completion::CompletionResponse {
379                        choice,
380                        usage,
381                        raw_response: response,
382                    })
383                }
384            }
385            // Message::Assistant { tool_calls } => {}
386            _ => Err(CompletionError::ResponseError(
387                "Chat response does not include an assistant message".into(),
388            )),
389        }
390    }
391}
392
/// A completion model: a model name paired with the client used to reach it.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client that performs the HTTP requests.
    client: Client,
    /// Model name sent as the `model` field of each request.
    pub model: String,
}
398
399
400
// Function definition
/// One entry of the request's `tools` array: a type tag plus the function
/// it describes.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Serialized as `type`; set to `"function"` when requests are built.
    #[serde(rename = "type")]
    pub type_field: String,
    /// The function being offered to the model.
    pub function: Function,
}
409
/// Description of a callable function offered to the model.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Free-text description of the function.
    pub description: String,
    /// Parameter description as arbitrary JSON (presumably a JSON Schema
    /// object — confirm against the tool definitions passed in).
    pub parameters: serde_json::Value,
}
417
418impl CompletionModel {
419    pub fn new(client: Client, model: &str) -> Self {
420        Self {
421            client,
422            model: model.to_string(),
423        }
424    }
425
426    fn create_completion_request(
427        &self,
428        completion_request: CompletionRequest,
429    ) -> Result<Value, CompletionError> {
430        // 构建消息顺序(上下文、聊天历史、提示)
431        let mut partial_history = vec![];
432        if let Some(docs) = completion_request.normalized_documents() {
433            partial_history.push(docs);
434        }
435        partial_history.extend(completion_request.chat_history);
436
437        // 使用前言初始化完整历史(如果不存在则为空)
438        let mut full_history: Vec<Message> = completion_request
439            .preamble
440            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
441
442        // 转换并扩展其余历史
443        full_history.extend(
444            partial_history
445                .into_iter()
446                .map(message::Message::try_into)
447                .collect::<Result<Vec<Message>, _>>()?,
448        );
449
450        let request = if completion_request.tools.is_empty() {
451            json!({
452                "model": self.model,
453                "messages": full_history,
454                "temperature": completion_request.temperature,
455            })
456        } else {
457            // tools
458            let tools = completion_request
459                .tools
460                .into_iter()
461                .map(|item| {
462                    let custom_function = Function {
463                        name: item.name,
464                        description: item.description,
465                        parameters: item.parameters,
466                    };
467                    CustomFunctionDefinition {
468                        type_field: "function".to_string(),
469                        function: custom_function,
470                    }
471                })
472                .collect::<Vec<_>>();
473
474            tracing::debug!("tools: {:?}", tools);
475
476            json!({
477                "model": self.model,
478                "messages": full_history,
479                "temperature": completion_request.temperature,
480                "tools": tools,
481                "tool_choice": "auto",
482            })
483        };
484
485        let request = if let Some(params) = completion_request.additional_params {
486            json_utils::merge(request, params)
487        } else {
488            request
489        };
490
491        Ok(request)
492    }
493}
494
495/// 同步请求
496impl completion::CompletionModel for CompletionModel {
497    type Response = CompletionResponse;
498    type StreamingResponse = openai::StreamingCompletionResponse;
499
500    async fn completion(
501        &self,
502        completion_request: CompletionRequest,
503    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
504        tracing::debug!("create_completion_request========");
505        let request = self.create_completion_request(completion_request)?;
506
507        tracing::debug!(
508            "request: \r\n {}",
509            serde_json::to_string_pretty(&request).unwrap()
510        );
511
512        let response = self
513            .client
514            .post("/chat/completions")
515            .json(&request)
516            .send()
517            .await?;
518
519        if response.status().is_success() {
520            let data: Value = response.json().await.expect("api error");
521            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
522            let data: ApiResponse<CompletionResponse> =
523                serde_json::from_value(data).expect("deserialize completion response");
524            match data {
525                ApiResponse::Ok(response) => {
526                    tracing::info!(target: "rig",
527                        "bigmodel completion token usage: {:?}",
528                        response.usage
529                    );
530                    response.try_into()
531                }
532                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
533            }
534        } else {
535            Err(CompletionError::ProviderError(response.text().await?))
536        }
537    }
538
539    async fn stream(
540        &self,
541        request: CompletionRequest,
542    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
543        let mut request = self.create_completion_request(request)?;
544
545        request = json_utils::merge(request, json!({"stream": true}));
546
547        let builder = self.client.post("/chat/completions").json(&request);
548
549        send_compatible_streaming_request(builder).await
550    }
551}
552
553