// rig_extra/extra_providers/bigmodel.rs

1use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
2use rig::completion::{CompletionError, CompletionRequest};
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, client, completion, http_client, message};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8
9use rig::providers::openai::send_compatible_streaming_request;
10use rig::streaming::StreamingCompletionResponse;
11
12use crate::json_utils;
13
// ================================================================
// Bigmodel client
// ================================================================

/// Default base URL of the Bigmodel v4 API; request paths are appended to it.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
18
19#[derive(Clone, Debug)]
20pub struct Client {
21    base_url: String,
22    http_client: reqwest::Client,
23}
24
25impl Client {
26    pub fn new(api_key: &str) -> Self {
27        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
28    }
29
30    pub fn from_url(api_key: &str, base_url: &str) -> Self {
31        Self {
32            base_url: base_url.to_string(),
33            http_client: reqwest::Client::builder()
34                .default_headers({
35                    let mut headers = reqwest::header::HeaderMap::new();
36                    headers.insert(
37                        "Authorization",
38                        format!("Bearer {api_key}")
39                            .parse()
40                            .expect("Bearer token should parse"),
41                    );
42                    headers
43                })
44                .build()
45                .expect("bigmodel reqwest client should build"),
46        }
47    }
48
49    fn post(&self, path: &str) -> reqwest::RequestBuilder {
50        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
51        self.http_client.post(url)
52    }
53
54    pub fn completion_model(&self, model: &str) -> CompletionModel {
55        CompletionModel::new(self.clone(), model)
56    }
57
58    // 为completion模型创建提取构建器
59    // pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
60    //     &self,
61    //     model: &str,
62    // ) -> ExtractorBuilder<T, CompletionModel> {
63    //     ExtractorBuilder::new(self.completion_model(model))
64    // }
65}
66
67impl ProviderClient for Client {
68    fn from_env() -> Self
69    where
70        Self: Sized,
71    {
72        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
73        Self::new(&api_key)
74    }
75
76    fn from_val(input: ProviderValue) -> Self
77    where
78        Self: Sized,
79    {
80        let client::ProviderValue::Simple(api_key) = input else {
81            panic!("Incorrect provider value type")
82        };
83        Self::new(&api_key)
84    }
85}
86
87impl AsTranscription for Client {}
88
89impl AsEmbeddings for Client {}
90
91impl CompletionClient for Client {
92    type CompletionModel = CompletionModel;
93
94    fn completion_model(&self, model: &str) -> Self::CompletionModel {
95        CompletionModel::new(self.clone(), model)
96    }
97}
98
99#[derive(Debug, Deserialize)]
100struct ApiErrorResponse {
101    message: String,
102}
103
104#[derive(Debug, Deserialize)]
105#[serde(untagged)]
106enum ApiResponse<T> {
107    Ok(T),
108    Err(ApiErrorResponse),
109}
110
// ================================================================
// Bigmodel Completion API
// ================================================================

/// Model id `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model id `glm-4.5-flash`.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
116
117#[derive(Debug, Deserialize, Serialize)]
118#[serde(rename_all = "camelCase")]
119pub struct CompletionResponse {
120    pub choices: Vec<Choice>,
121    pub created: i64,
122    pub id: String,
123    pub model: String,
124    #[serde(rename = "request_id")]
125    pub request_id: String,
126    pub usage: Usage,
127}
128
129#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
130#[serde(tag = "role", rename_all = "lowercase")]
131pub enum Message {
132    User {
133        content: String,
134    },
135    Assistant {
136        content: Option<String>,
137        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
138        tool_calls: Vec<ToolCall>,
139    },
140    System {
141        content: String,
142    },
143    #[serde(rename = "tool")]
144    ToolResult {
145        tool_call_id: String,
146        content: String,
147    },
148}
149
150impl Message {
151    pub fn system(content: &str) -> Message {
152        Message::System {
153            content: content.to_owned(),
154        }
155    }
156}
157
158#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
159pub struct ToolResultContent {
160    text: String,
161}
162impl TryFrom<message::ToolResultContent> for ToolResultContent {
163    type Error = MessageError;
164    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
165        let message::ToolResultContent::Text(Text { text }) = value else {
166            return Err(MessageError::ConversionError(
167                "Non-text tool results not supported".into(),
168            ));
169        };
170
171        Ok(Self { text })
172    }
173}
174
175impl TryFrom<message::Message> for Message {
176    type Error = MessageError;
177
178    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
179        Ok(match message {
180            message::Message::User { content } => {
181                let mut texts = Vec::new();
182                let mut images = Vec::new();
183
184                for uc in content.into_iter() {
185                    match uc {
186                        message::UserContent::Text(message::Text { text }) => texts.push(text),
187                        message::UserContent::Image(img) => images.push(img.data),
188                        message::UserContent::ToolResult(result) => {
189                            let content = result
190                                .content
191                                .into_iter()
192                                .map(ToolResultContent::try_from)
193                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
194
195                            let content = OneOrMany::many(content).map_err(|x| {
196                                MessageError::ConversionError(format!(
197                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
198                                ))
199                            })?;
200
201                            return Ok(Message::ToolResult {
202                                tool_call_id: result.id,
203                                content: content.first().text,
204                            });
205                        }
206                        _ => {}
207                    }
208                }
209
210                let collapsed_content = texts.join(" ");
211
212                Message::User {
213                    content: collapsed_content,
214                }
215            }
216            message::Message::Assistant { content, .. } => {
217                let mut texts = Vec::new();
218                let mut tool_calls = Vec::new();
219
220                for ac in content.into_iter() {
221                    match ac {
222                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
223                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
224                        _ => {}
225                    }
226                }
227
228                let collapsed_content = texts.join(" ");
229
230                Message::Assistant {
231                    content: Some(collapsed_content),
232                    tool_calls,
233                }
234            }
235        })
236    }
237}
238
239impl From<message::ToolResult> for Message {
240    fn from(tool_result: message::ToolResult) -> Self {
241        let content = match tool_result.content.first() {
242            message::ToolResultContent::Text(text) => text.text,
243            message::ToolResultContent::Image(_) => String::from("[Image]"),
244        };
245
246        Message::ToolResult {
247            tool_call_id: tool_result.id,
248            content,
249        }
250    }
251}
252
253#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
254#[serde(rename_all = "camelCase")]
255pub struct ToolCall {
256    pub function: CallFunction,
257    pub id: String,
258    pub index: usize,
259    #[serde(default)]
260    pub r#type: ToolType,
261}
262
263impl From<message::ToolCall> for ToolCall {
264    fn from(tool_call: message::ToolCall) -> Self {
265        Self {
266            id: tool_call.id,
267            index: 0,
268            r#type: ToolType::Function,
269            function: CallFunction {
270                name: tool_call.function.name,
271                arguments: tool_call.function.arguments,
272            },
273        }
274    }
275}
276
277#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
278#[serde(rename_all = "lowercase")]
279pub enum ToolType {
280    #[default]
281    Function,
282}
283
284#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
285pub struct CallFunction {
286    pub name: String,
287    #[serde(with = "json_utils::stringified_json")]
288    pub arguments: serde_json::Value,
289}
290
291#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
292#[serde(rename_all = "lowercase")]
293pub enum Role {
294    System,
295    User,
296    Assistant,
297}
298
299#[derive(Debug, Serialize, Deserialize)]
300#[serde(rename_all = "camelCase")]
301pub struct Choice {
302    #[serde(rename = "finish_reason")]
303    pub finish_reason: String,
304    pub index: i64,
305    pub message: Message,
306}
307
308#[derive(Debug, Clone, Serialize, Deserialize)]
309#[serde(rename_all = "camelCase")]
310pub struct Usage {
311    #[serde(rename = "completion_tokens")]
312    pub completion_tokens: i64,
313    #[serde(rename = "prompt_tokens")]
314    pub prompt_tokens: i64,
315    #[serde(rename = "total_tokens")]
316    pub total_tokens: i64,
317}
318
319impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
320    type Error = CompletionError;
321
322    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
323        let choice = response.choices.first().ok_or_else(|| {
324            CompletionError::ResponseError("Response contained no choices".to_owned())
325        })?;
326
327        match &choice.message {
328            Message::Assistant {
329                tool_calls,
330                content,
331            } => {
332                if !tool_calls.is_empty() {
333                    let tool_result = tool_calls
334                        .iter()
335                        .map(|call| {
336                            completion::AssistantContent::tool_call(
337                                &call.function.name,
338                                &call.function.name,
339                                call.function.arguments.clone(),
340                            )
341                        })
342                        .collect::<Vec<_>>();
343
344                    let choice = OneOrMany::many(tool_result).map_err(|_| {
345                        CompletionError::ResponseError(
346                            "Response contained no message or tool call (empty)".to_owned(),
347                        )
348                    })?;
349                    let usage = completion::Usage {
350                        input_tokens: response.usage.prompt_tokens as u64,
351                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
352                            as u64,
353                        total_tokens: response.usage.total_tokens as u64,
354                    };
355                    tracing::debug!("response choices: {:?}: ", choice);
356                    Ok(completion::CompletionResponse {
357                        choice,
358                        usage,
359                        raw_response: response,
360                    })
361                } else {
362                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
363                        text: content.clone().unwrap_or_else(|| "".to_owned()),
364                    }));
365                    let usage = completion::Usage {
366                        input_tokens: response.usage.prompt_tokens as u64,
367                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
368                            as u64,
369                        total_tokens: response.usage.total_tokens as u64,
370                    };
371                    Ok(completion::CompletionResponse {
372                        choice,
373                        usage,
374                        raw_response: response,
375                    })
376                }
377            }
378            // Message::Assistant { tool_calls } => {}
379            _ => Err(CompletionError::ResponseError(
380                "Chat response does not include an assistant message".into(),
381            )),
382        }
383    }
384}
385
386#[derive(Clone)]
387pub struct CompletionModel {
388    client: Client,
389    pub model: String,
390}
391
392// 函数定义
393#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
394#[serde(rename_all = "camelCase")]
395pub struct CustomFunctionDefinition {
396    #[serde(rename = "type")]
397    pub type_field: String,
398    pub function: Function,
399}
400
401#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
402#[serde(rename_all = "camelCase")]
403pub struct Function {
404    pub name: String,
405    pub description: String,
406    pub parameters: serde_json::Value,
407}
408
409impl CompletionModel {
410    pub fn new(client: Client, model: &str) -> Self {
411        Self {
412            client,
413            model: model.to_string(),
414        }
415    }
416
417    fn create_completion_request(
418        &self,
419        completion_request: CompletionRequest,
420    ) -> Result<Value, CompletionError> {
421        // 构建消息顺序(上下文、聊天历史、提示)
422        let mut partial_history = vec![];
423        if let Some(docs) = completion_request.normalized_documents() {
424            partial_history.push(docs);
425        }
426        partial_history.extend(completion_request.chat_history);
427
428        // 使用前言初始化完整历史(如果不存在则为空)
429        let mut full_history: Vec<Message> = completion_request
430            .preamble
431            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
432
433        // 转换并扩展其余历史
434        full_history.extend(
435            partial_history
436                .into_iter()
437                .map(message::Message::try_into)
438                .collect::<Result<Vec<Message>, _>>()?,
439        );
440
441        let request = if completion_request.tools.is_empty() {
442            json!({
443                "model": self.model,
444                "messages": full_history,
445                "temperature": completion_request.temperature,
446            })
447        } else {
448            // tools
449            let tools = completion_request
450                .tools
451                .into_iter()
452                .map(|item| {
453                    let custom_function = Function {
454                        name: item.name,
455                        description: item.description,
456                        parameters: item.parameters,
457                    };
458                    CustomFunctionDefinition {
459                        type_field: "function".to_string(),
460                        function: custom_function,
461                    }
462                })
463                .collect::<Vec<_>>();
464
465            tracing::debug!("tools: {:?}", tools);
466
467            json!({
468                "model": self.model,
469                "messages": full_history,
470                "temperature": completion_request.temperature,
471                "tools": tools,
472                "tool_choice": "auto",
473            })
474        };
475
476        let request = if let Some(params) = completion_request.additional_params {
477            json_utils::merge(request, params)
478        } else {
479            request
480        };
481
482        Ok(request)
483    }
484}
485
486/// 同步请求
487impl completion::CompletionModel for CompletionModel {
488    type Response = CompletionResponse;
489    type StreamingResponse = openai::StreamingCompletionResponse;
490
491    async fn completion(
492        &self,
493        completion_request: CompletionRequest,
494    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
495        tracing::debug!("create_completion_request========");
496        let request = self.create_completion_request(completion_request)?;
497
498        tracing::debug!(
499            "request: \r\n {}",
500            serde_json::to_string_pretty(&request).unwrap()
501        );
502
503        let response = self
504            .client
505            .post("/chat/completions")
506            .json(&request)
507            .send()
508            .await
509            .map_err(|e| http_client::Error::Instance(e.into()))?;
510
511        if response.status().is_success() {
512            let data: Value = response.json().await.expect("api error");
513            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
514            let data: ApiResponse<CompletionResponse> =
515                serde_json::from_value(data).expect("deserialize completion response");
516            match data {
517                ApiResponse::Ok(response) => {
518                    tracing::info!(target: "rig",
519                        "bigmodel completion token usage: {:?}",
520                        response.usage
521                    );
522                    response.try_into()
523                }
524                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
525            }
526        } else {
527            Err(CompletionError::ProviderError(
528                response
529                    .text()
530                    .await
531                    .map_err(|e| http_client::Error::Instance(e.into()))?,
532            ))
533        }
534    }
535
536    async fn stream(
537        &self,
538        request: CompletionRequest,
539    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
540        let mut request = self.create_completion_request(request)?;
541
542        request = json_utils::merge(request, json!({"stream": true}));
543
544        let builder = self.client.post("/chat/completions").json(&request);
545
546        send_compatible_streaming_request(builder).await
547    }
548}