rig_extra/extra_providers/
bigmodel.rs

1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
// ================================================================
// BigModel client
// ================================================================
/// Base URL that [`Client::new`] passes to [`Client::from_url`].
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23    base_url: String,
24    http_client: reqwest::Client,
25}
26
27impl Client {
28    pub fn new(api_key: &str) -> Self {
29        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30    }
31
32    pub fn from_url(api_key: &str, base_url: &str) -> Self {
33        Self {
34            base_url: base_url.to_string(),
35            http_client: reqwest::Client::builder()
36                .default_headers({
37                    let mut headers = reqwest::header::HeaderMap::new();
38                    headers.insert(
39                        "Authorization",
40                        format!("Bearer {api_key}")
41                            .parse()
42                            .expect("Bearer token should parse"),
43                    );
44                    headers
45                })
46                .build()
47                .expect("bigmodel reqwest client should build"),
48        }
49    }
50
51
52
53    fn post(&self, path: &str) -> reqwest::RequestBuilder {
54        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55        self.http_client.post(url)
56    }
57
58    pub fn completion_model(&self, model: &str) -> CompletionModel {
59        CompletionModel::new(self.clone(), model)
60    }
61
62
63
64    /// 为completion模型创建提取构建器
65    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
66        &self,
67        model: &str,
68    ) -> ExtractorBuilder<T, CompletionModel> {
69        ExtractorBuilder::new(self.completion_model(model))
70    }
71}
72
73impl ProviderClient for Client {
74    fn from_env() -> Self
75    where
76        Self: Sized
77    {
78        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
79        Self::new(&api_key)
80    }
81}
82
83impl AsTranscription for Client {}
84
85impl AsEmbeddings for Client {}
86
87
88
89impl CompletionClient for Client {
90    type CompletionModel = CompletionModel;
91
92    fn completion_model(&self, model: &str) -> Self::CompletionModel {
93        CompletionModel::new(self.clone(), model)
94    }
95}
96
97#[derive(Debug, Deserialize)]
98struct ApiErrorResponse {
99    message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104enum ApiResponse<T> {
105    Ok(T),
106    Err(ApiErrorResponse),
107}
108
// ================================================================
// BigModel Completion API
// ================================================================
/// Model name `glm-4-flash`, usable with `Client::completion_model`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
113
114#[derive(Debug, Deserialize)]
115#[serde(rename_all = "camelCase")]
116pub struct CompletionResponse {
117    pub choices: Vec<Choice>,
118    pub created: i64,
119    pub id: String,
120    pub model: String,
121    #[serde(rename = "request_id")]
122    pub request_id: String,
123    pub usage: Usage,
124}
125
126#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
127#[serde(tag = "role", rename_all = "lowercase")]
128pub enum Message {
129    User {
130        content: String,
131    },
132    Assistant {
133        content: Option<String>,
134        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
135        tool_calls: Vec<ToolCall>,
136    },
137    System {
138        content: String,
139    },
140    #[serde(rename = "tool")]
141    ToolResult {
142        tool_call_id: String,
143        content: String,
144    },
145}
146
147impl Message {
148    pub fn system(content: &str) -> Message {
149        Message::System {
150            content: content.to_owned(),
151        }
152    }
153}
154
155#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
156pub struct ToolResultContent {
157    text: String,
158}
159impl TryFrom<message::ToolResultContent> for ToolResultContent {
160    type Error = MessageError;
161    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
162        let message::ToolResultContent::Text(Text { text }) = value else {
163            return Err(MessageError::ConversionError(
164                "Non-text tool results not supported".into(),
165            ));
166        };
167
168        Ok(Self { text })
169    }
170}
171
172impl TryFrom<message::Message> for Message {
173    type Error = MessageError;
174
175    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
176        Ok(match message {
177            message::Message::User { content } => {
178                let mut texts = Vec::new();
179                let mut images = Vec::new();
180
181                for uc in content.into_iter() {
182                    match uc {
183                        message::UserContent::Text(message::Text { text }) => texts.push(text),
184                        message::UserContent::Image(img) => images.push(img.data),
185                        message::UserContent::ToolResult(result) => {
186                            let content = result
187                                .content
188                                .into_iter()
189                                .map(ToolResultContent::try_from)
190                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
191
192                            let content = OneOrMany::many(content).map_err(|x| {
193                                MessageError::ConversionError(format!(
194                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
195                                ))
196                            })?;
197
198                            return Ok(Message::ToolResult {
199                                tool_call_id: result.id,
200                                content: content.first().text,
201                            });
202                        }
203                        _ => {}
204                    }
205                }
206
207                let collapsed_content = texts.join(" ");
208
209                Message::User {
210                    content: collapsed_content,
211                }
212            }
213            message::Message::Assistant { content, .. } => {
214                let mut texts = Vec::new();
215                let mut tool_calls = Vec::new();
216
217                for ac in content.into_iter() {
218                    match ac {
219                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
220                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
221                    }
222                }
223
224                let collapsed_content = texts.join(" ");
225
226                Message::Assistant {
227                    content: Some(collapsed_content),
228                    tool_calls,
229                }
230            }
231        })
232    }
233}
234
235impl From<message::ToolResult> for Message {
236    fn from(tool_result: message::ToolResult) -> Self {
237        let content = match tool_result.content.first() {
238            message::ToolResultContent::Text(text) => text.text,
239            message::ToolResultContent::Image(_) => String::from("[Image]"),
240        };
241
242        Message::ToolResult {
243            tool_call_id: tool_result.id,
244            content,
245        }
246    }
247}
248
249#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
250#[serde(rename_all = "camelCase")]
251pub struct ToolCall {
252    pub function: CallFunction,
253    pub id: String,
254    pub index: usize,
255    #[serde(default)]
256    pub r#type: ToolType,
257}
258
259impl From<message::ToolCall> for ToolCall {
260    fn from(tool_call: message::ToolCall) -> Self {
261        Self {
262            id: tool_call.id,
263            index: 0,
264            r#type: ToolType::Function,
265            function: CallFunction {
266                name: tool_call.function.name,
267                arguments: tool_call.function.arguments,
268            },
269        }
270    }
271}
272
273#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
274#[serde(rename_all = "lowercase")]
275pub enum ToolType {
276    #[default]
277    Function,
278}
279
280#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
281pub struct CallFunction {
282    pub name: String,
283    #[serde(with = "json_utils::stringified_json")]
284    pub arguments: serde_json::Value,
285}
286
287#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
288#[serde(rename_all = "lowercase")]
289pub enum Role {
290    System,
291    User,
292    Assistant,
293}
294
295#[derive(Debug, Serialize, Deserialize)]
296#[serde(rename_all = "camelCase")]
297pub struct Choice {
298    #[serde(rename = "finish_reason")]
299    pub finish_reason: String,
300    pub index: i64,
301    pub message: Message,
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305#[serde(rename_all = "camelCase")]
306pub struct Usage {
307    #[serde(rename = "completion_tokens")]
308    pub completion_tokens: i64,
309    #[serde(rename = "prompt_tokens")]
310    pub prompt_tokens: i64,
311    #[serde(rename = "total_tokens")]
312    pub total_tokens: i64,
313}
314
315impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
316    type Error = CompletionError;
317
318    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
319        let choice = response.choices.first().ok_or_else(|| {
320            CompletionError::ResponseError("Response contained no choices".to_owned())
321        })?;
322
323        match &choice.message {
324            Message::Assistant {
325                tool_calls,
326                content,
327            } => {
328                if !tool_calls.is_empty() {
329                    let tool_result = tool_calls
330                        .iter()
331                        .map(|call| {
332                            completion::AssistantContent::tool_call(
333                                &call.function.name,
334                                &call.function.name,
335                                call.function.arguments.clone(),
336                            )
337                        })
338                        .collect::<Vec<_>>();
339
340                    let choice = OneOrMany::many(tool_result).map_err(|_| {
341                        CompletionError::ResponseError(
342                            "Response contained no message or tool call (empty)".to_owned(),
343                        )
344                    })?;
345                    // let usage = completion::Usage {
346                    //     input_tokens: response.usage.prompt_tokens as u64,
347                    //     output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
348                    //         as u64,
349                    //     total_tokens: response.usage.total_tokens as u64,
350                    // };
351                    tracing::debug!("response choices: {:?}: ", choice);
352                    Ok(completion::CompletionResponse {
353                        choice,
354                        // usage,
355                        raw_response: response,
356                    })
357                } else {
358                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
359                        text: content.clone().unwrap_or_else(|| "".to_owned()),
360                    }));
361                    // let usage = completion::Usage {
362                    //     input_tokens: response.usage.prompt_tokens as u64,
363                    //     output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
364                    //         as u64,
365                    //     total_tokens: response.usage.total_tokens as u64,
366                    // };
367                    Ok(completion::CompletionResponse {
368                        choice,
369                        // usage,
370                        raw_response: response,
371                    })
372                }
373            }
374            // Message::Assistant { tool_calls } => {}
375            _ => Err(CompletionError::ResponseError(
376                "Chat response does not include an assistant message".into(),
377            )),
378        }
379    }
380}
381
382#[derive(Clone)]
383pub struct CompletionModel {
384    client: Client,
385    pub model: String,
386}
387
388
389
390// 函数定义
391#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
392#[serde(rename_all = "camelCase")]
393pub struct CustomFunctionDefinition {
394    #[serde(rename = "type")]
395    pub type_field: String,
396    pub function: Function,
397}
398
399#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
400#[serde(rename_all = "camelCase")]
401pub struct Function {
402    pub name: String,
403    pub description: String,
404    pub parameters: serde_json::Value,
405}
406
407impl CompletionModel {
408    pub fn new(client: Client, model: &str) -> Self {
409        Self {
410            client,
411            model: model.to_string(),
412        }
413    }
414
415    fn create_completion_request(
416        &self,
417        completion_request: CompletionRequest,
418    ) -> Result<Value, CompletionError> {
419        // 构建消息顺序(上下文、聊天历史、提示)
420        let mut partial_history = vec![];
421        if let Some(docs) = completion_request.normalized_documents() {
422            partial_history.push(docs);
423        }
424        partial_history.extend(completion_request.chat_history);
425
426        // 使用前言初始化完整历史(如果不存在则为空)
427        let mut full_history: Vec<Message> = completion_request
428            .preamble
429            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
430
431        // 转换并扩展其余历史
432        full_history.extend(
433            partial_history
434                .into_iter()
435                .map(message::Message::try_into)
436                .collect::<Result<Vec<Message>, _>>()?,
437        );
438
439        let request = if completion_request.tools.is_empty() {
440            json!({
441                "model": self.model,
442                "messages": full_history,
443                "temperature": completion_request.temperature,
444            })
445        } else {
446            // tools
447            let tools = completion_request
448                .tools
449                .into_iter()
450                .map(|item| {
451                    let custom_function = Function {
452                        name: item.name,
453                        description: item.description,
454                        parameters: item.parameters,
455                    };
456                    CustomFunctionDefinition {
457                        type_field: "function".to_string(),
458                        function: custom_function,
459                    }
460                })
461                .collect::<Vec<_>>();
462
463            tracing::debug!("tools: {:?}", tools);
464
465            json!({
466                "model": self.model,
467                "messages": full_history,
468                "temperature": completion_request.temperature,
469                "tools": tools,
470                "tool_choice": "auto",
471            })
472        };
473
474        let request = if let Some(params) = completion_request.additional_params {
475            json_utils::merge(request, params)
476        } else {
477            request
478        };
479
480        Ok(request)
481    }
482}
483
484/// 同步请求
485impl completion::CompletionModel for CompletionModel {
486    type Response = CompletionResponse;
487    type StreamingResponse = openai::StreamingCompletionResponse;
488
489    async fn completion(
490        &self,
491        completion_request: CompletionRequest,
492    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
493        tracing::debug!("create_completion_request========");
494        let request = self.create_completion_request(completion_request)?;
495
496        tracing::debug!(
497            "request: \r\n {}",
498            serde_json::to_string_pretty(&request).unwrap()
499        );
500
501        let response = self
502            .client
503            .post("/chat/completions")
504            .json(&request)
505            .send()
506            .await?;
507
508        if response.status().is_success() {
509            let data: Value = response.json().await.expect("api error");
510            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
511            let data: ApiResponse<CompletionResponse> =
512                serde_json::from_value(data).expect("deserialize completion response");
513            match data {
514                ApiResponse::Ok(response) => {
515                    tracing::info!(target: "rig",
516                        "bigmodel completion token usage: {:?}",
517                        response.usage
518                    );
519                    response.try_into()
520                }
521                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
522            }
523        } else {
524            Err(CompletionError::ProviderError(response.text().await?))
525        }
526    }
527
528    async fn stream(
529        &self,
530        request: CompletionRequest,
531    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
532        let mut request = self.create_completion_request(request)?;
533
534        request = json_utils::merge(request, json!({"stream": true}));
535
536        let builder = self.client.post("/chat/completions").json(&request);
537
538        send_compatible_streaming_request(builder).await
539    }
540}
541
542