rig_extra/extra_providers/
bigmodel.rs

1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message, client};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
// ================================================================
// BIGMODEL client
// ================================================================
/// Default base URL of the BigModel API; `Client::new` passes it to
/// `Client::from_url`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
/// HTTP client for the BigModel API.
///
/// Pairs the base URL that request paths are appended to with a
/// `reqwest::Client`; `from_url` configures that client with a default
/// `Authorization: Bearer <api_key>` header.
#[derive(Clone, Debug)]
pub struct Client {
    /// Base URL that request paths are joined onto.
    base_url: String,
    /// Underlying HTTP client used to issue requests.
    http_client: reqwest::Client,
}
26
27impl Client {
28    pub fn new(api_key: &str) -> Self {
29        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30    }
31
32    pub fn from_url(api_key: &str, base_url: &str) -> Self {
33        Self {
34            base_url: base_url.to_string(),
35            http_client: reqwest::Client::builder()
36                .default_headers({
37                    let mut headers = reqwest::header::HeaderMap::new();
38                    headers.insert(
39                        "Authorization",
40                        format!("Bearer {api_key}")
41                            .parse()
42                            .expect("Bearer token should parse"),
43                    );
44                    headers
45                })
46                .build()
47                .expect("bigmodel reqwest client should build"),
48        }
49    }
50
51
52
53    fn post(&self, path: &str) -> reqwest::RequestBuilder {
54        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55        self.http_client.post(url)
56    }
57
58    pub fn completion_model(&self, model: &str) -> CompletionModel {
59        CompletionModel::new(self.clone(), model)
60    }
61
62
63
64    /// 为completion模型创建提取构建器
65    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
66        &self,
67        model: &str,
68    ) -> ExtractorBuilder<T, CompletionModel> {
69        ExtractorBuilder::new(self.completion_model(model))
70    }
71}
72
73impl ProviderClient for Client {
74    fn from_env() -> Self
75    where
76        Self: Sized
77    {
78        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
79        Self::new(&api_key)
80    }
81
82    fn from_val(input: ProviderValue) -> Self
83    where
84        Self: Sized
85    {
86        let  client::ProviderValue::Simple(api_key) = input else {
87            panic!("Incorrect provider value type")
88        };
89        Self::new(&api_key)
90    }
91}
92
// Empty impl: nothing is overridden, so `Client` relies entirely on whatever
// defaults `AsTranscription` provides.
// NOTE(review): presumably this signals that no transcription model is
// offered by this provider — confirm against `rig::client`.
impl AsTranscription for Client {}

// Empty impl: nothing is overridden, so `Client` relies entirely on whatever
// defaults `AsEmbeddings` provides.
// NOTE(review): presumably this signals that no embedding model is offered by
// this provider — confirm against `rig::client`.
impl AsEmbeddings for Client {}
96
97
98
99impl CompletionClient for Client {
100    type CompletionModel = CompletionModel;
101
102    fn completion_model(&self, model: &str) -> Self::CompletionModel {
103        CompletionModel::new(self.clone(), model)
104    }
105}
106
/// Error payload shape: a JSON object with a `message` string.
/// Used as the `Err` variant of `ApiResponse`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; `completion` forwards it as a provider error.
    message: String,
}
111
/// Either a successful payload of type `T` or an error payload.
///
/// The enum is untagged, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
118
// ================================================================
// Bigmodel Completion API
// ================================================================
/// Model identifier string `"glm-4-flash"`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
123
/// Deserialized body of a chat completion response.
///
/// All fields are required when deserializing (none carries
/// `#[serde(default)]`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    /// Returned choices; the conversion into rig's completion response reads
    /// only the first one.
    pub choices: Vec<Choice>,
    /// `created` field of the response.
    pub created: i64,
    /// `id` field of the response.
    pub id: String,
    /// `model` field of the response.
    pub model: String,
    /// Explicitly mapped to the snake_case key `request_id`, overriding the
    /// struct-level camelCase renaming.
    #[serde(rename = "request_id")]
    pub request_id: String,
    /// Token usage reported with the response.
    pub usage: Usage,
}
135
/// Chat message in the provider's wire format.
///
/// Internally tagged by the `role` field; variant names are lowercased, so
/// the tags are `user`, `assistant`, `system`, and (via an explicit rename)
/// `tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `user`.
    User {
        content: String,
    },
    /// Message with role `assistant`.
    Assistant {
        /// Text content; optional.
        content: Option<String>,
        /// Tool calls attached to the message. Defaults to an empty vector
        /// when the key is missing.
        // NOTE(review): `json_utils::null_or_vec` presumably also maps an
        // explicit `null` to an empty vector — confirm in `json_utils`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message with role `system`.
    System {
        content: String,
    },
    /// Tool result message; serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Id of the tool call this result answers.
        tool_call_id: String,
        content: String,
    },
}
156
157impl Message {
158    pub fn system(content: &str) -> Message {
159        Message::System {
160            content: content.to_owned(),
161        }
162    }
163}
164
/// Text-only tool result content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// The tool result text.
    text: String,
}
169impl TryFrom<message::ToolResultContent> for ToolResultContent {
170    type Error = MessageError;
171    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
172        let message::ToolResultContent::Text(Text { text }) = value else {
173            return Err(MessageError::ConversionError(
174                "Non-text tool results not supported".into(),
175            ));
176        };
177
178        Ok(Self { text })
179    }
180}
181
182impl TryFrom<message::Message> for Message {
183    type Error = MessageError;
184
185    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
186        Ok(match message {
187            message::Message::User { content } => {
188                let mut texts = Vec::new();
189                let mut images = Vec::new();
190
191                for uc in content.into_iter() {
192                    match uc {
193                        message::UserContent::Text(message::Text { text }) => texts.push(text),
194                        message::UserContent::Image(img) => images.push(img.data),
195                        message::UserContent::ToolResult(result) => {
196                            let content = result
197                                .content
198                                .into_iter()
199                                .map(ToolResultContent::try_from)
200                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
201
202                            let content = OneOrMany::many(content).map_err(|x| {
203                                MessageError::ConversionError(format!(
204                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
205                                ))
206                            })?;
207
208                            return Ok(Message::ToolResult {
209                                tool_call_id: result.id,
210                                content: content.first().text,
211                            });
212                        }
213                        _ => {}
214                    }
215                }
216
217                let collapsed_content = texts.join(" ");
218
219                Message::User {
220                    content: collapsed_content,
221                }
222            }
223            message::Message::Assistant { content, .. } => {
224                let mut texts = Vec::new();
225                let mut tool_calls = Vec::new();
226
227                for ac in content.into_iter() {
228                    match ac {
229                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
230                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
231                    }
232                }
233
234                let collapsed_content = texts.join(" ");
235
236                Message::Assistant {
237                    content: Some(collapsed_content),
238                    tool_calls,
239                }
240            }
241        })
242    }
243}
244
245impl From<message::ToolResult> for Message {
246    fn from(tool_result: message::ToolResult) -> Self {
247        let content = match tool_result.content.first() {
248            message::ToolResultContent::Text(text) => text.text,
249            message::ToolResultContent::Image(_) => String::from("[Image]"),
250        };
251
252        Message::ToolResult {
253            tool_call_id: tool_result.id,
254            content,
255        }
256    }
257}
258
/// A tool call in the provider's wire format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    /// The function to invoke and its arguments.
    pub function: CallFunction,
    /// Identifier of this tool call.
    pub id: String,
    /// `index` field of the tool call; required when deserializing.
    pub index: usize,
    /// Tool type; falls back to `ToolType::default()` (`Function`) when the
    /// key is missing.
    #[serde(default)]
    pub r#type: ToolType,
}
268
269impl From<message::ToolCall> for ToolCall {
270    fn from(tool_call: message::ToolCall) -> Self {
271        Self {
272            id: tool_call.id,
273            index: 0,
274            r#type: ToolType::Function,
275            function: CallFunction {
276                name: tool_call.function.name,
277                arguments: tool_call.function.arguments,
278            },
279        }
280    }
281}
282
/// Kind of tool referenced by a `ToolCall`; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (serialized as `"function"`); the default variant.
    #[default]
    Function,
}
289
/// Function name and arguments carried by a `ToolCall`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments as a JSON value.
    // NOTE(review): `json_utils::stringified_json` presumably encodes this
    // value as a JSON string on the wire and parses it back on read —
    // confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
296
/// Message roles, serialized in lowercase (`system`, `user`, `assistant`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
304
/// One entry of the `choices` array of a completion response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    /// Explicitly mapped to the snake_case key `finish_reason`, overriding
    /// the struct-level camelCase renaming.
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    /// `index` field of the choice.
    pub index: i64,
    /// The message carried by this choice.
    pub message: Message,
}
313
/// Token usage block of a completion response.
///
/// Every field is explicitly mapped to its snake_case key, overriding the
/// struct-level camelCase renaming.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    /// `completion_tokens` value of the response.
    #[serde(rename = "completion_tokens")]
    pub completion_tokens: i64,
    /// `prompt_tokens` value of the response.
    #[serde(rename = "prompt_tokens")]
    pub prompt_tokens: i64,
    /// `total_tokens` value of the response.
    #[serde(rename = "total_tokens")]
    pub total_tokens: i64,
}
324
325impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
326    type Error = CompletionError;
327
328    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
329        let choice = response.choices.first().ok_or_else(|| {
330            CompletionError::ResponseError("Response contained no choices".to_owned())
331        })?;
332
333        match &choice.message {
334            Message::Assistant {
335                tool_calls,
336                content,
337            } => {
338                if !tool_calls.is_empty() {
339                    let tool_result = tool_calls
340                        .iter()
341                        .map(|call| {
342                            completion::AssistantContent::tool_call(
343                                &call.function.name,
344                                &call.function.name,
345                                call.function.arguments.clone(),
346                            )
347                        })
348                        .collect::<Vec<_>>();
349
350                    let choice = OneOrMany::many(tool_result).map_err(|_| {
351                        CompletionError::ResponseError(
352                            "Response contained no message or tool call (empty)".to_owned(),
353                        )
354                    })?;
355                    let usage = completion::Usage {
356                        input_tokens: response.usage.prompt_tokens as u64,
357                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
358                            as u64,
359                        total_tokens: response.usage.total_tokens as u64,
360                    };
361                    tracing::debug!("response choices: {:?}: ", choice);
362                    Ok(completion::CompletionResponse {
363                        choice,
364                        usage,
365                        raw_response: response,
366                    })
367                } else {
368                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
369                        text: content.clone().unwrap_or_else(|| "".to_owned()),
370                    }));
371                    let usage = completion::Usage {
372                        input_tokens: response.usage.prompt_tokens as u64,
373                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
374                            as u64,
375                        total_tokens: response.usage.total_tokens as u64,
376                    };
377                    Ok(completion::CompletionResponse {
378                        choice,
379                        usage,
380                        raw_response: response,
381                    })
382                }
383            }
384            // Message::Assistant { tool_calls } => {}
385            _ => Err(CompletionError::ResponseError(
386                "Chat response does not include an assistant message".into(),
387            )),
388        }
389    }
390}
391
/// A BigModel completion model: a client paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    client: Client,
    /// Model name placed in the `model` field of each request.
    pub model: String,
}
397
398
399
// Function (tool) definitions
/// One entry of the request's `tools` array.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Serialized under the key `type`; `create_completion_request` sets it
    /// to `"function"`.
    #[serde(rename = "type")]
    pub type_field: String,
    /// The function being described.
    pub function: Function,
}
408
/// Description of a callable function inside a `CustomFunctionDefinition`.
///
/// `create_completion_request` fills the fields from the matching fields of
/// each tool in the completion request.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Function description.
    pub description: String,
    /// Function parameters as a JSON value.
    pub parameters: serde_json::Value,
}
416
417impl CompletionModel {
418    pub fn new(client: Client, model: &str) -> Self {
419        Self {
420            client,
421            model: model.to_string(),
422        }
423    }
424
425    fn create_completion_request(
426        &self,
427        completion_request: CompletionRequest,
428    ) -> Result<Value, CompletionError> {
429        // 构建消息顺序(上下文、聊天历史、提示)
430        let mut partial_history = vec![];
431        if let Some(docs) = completion_request.normalized_documents() {
432            partial_history.push(docs);
433        }
434        partial_history.extend(completion_request.chat_history);
435
436        // 使用前言初始化完整历史(如果不存在则为空)
437        let mut full_history: Vec<Message> = completion_request
438            .preamble
439            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
440
441        // 转换并扩展其余历史
442        full_history.extend(
443            partial_history
444                .into_iter()
445                .map(message::Message::try_into)
446                .collect::<Result<Vec<Message>, _>>()?,
447        );
448
449        let request = if completion_request.tools.is_empty() {
450            json!({
451                "model": self.model,
452                "messages": full_history,
453                "temperature": completion_request.temperature,
454            })
455        } else {
456            // tools
457            let tools = completion_request
458                .tools
459                .into_iter()
460                .map(|item| {
461                    let custom_function = Function {
462                        name: item.name,
463                        description: item.description,
464                        parameters: item.parameters,
465                    };
466                    CustomFunctionDefinition {
467                        type_field: "function".to_string(),
468                        function: custom_function,
469                    }
470                })
471                .collect::<Vec<_>>();
472
473            tracing::debug!("tools: {:?}", tools);
474
475            json!({
476                "model": self.model,
477                "messages": full_history,
478                "temperature": completion_request.temperature,
479                "tools": tools,
480                "tool_choice": "auto",
481            })
482        };
483
484        let request = if let Some(params) = completion_request.additional_params {
485            json_utils::merge(request, params)
486        } else {
487            request
488        };
489
490        Ok(request)
491    }
492}
493
494/// 同步请求
495impl completion::CompletionModel for CompletionModel {
496    type Response = CompletionResponse;
497    type StreamingResponse = openai::StreamingCompletionResponse;
498
499    async fn completion(
500        &self,
501        completion_request: CompletionRequest,
502    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
503        tracing::debug!("create_completion_request========");
504        let request = self.create_completion_request(completion_request)?;
505
506        tracing::debug!(
507            "request: \r\n {}",
508            serde_json::to_string_pretty(&request).unwrap()
509        );
510
511        let response = self
512            .client
513            .post("/chat/completions")
514            .json(&request)
515            .send()
516            .await?;
517
518        if response.status().is_success() {
519            let data: Value = response.json().await.expect("api error");
520            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
521            let data: ApiResponse<CompletionResponse> =
522                serde_json::from_value(data).expect("deserialize completion response");
523            match data {
524                ApiResponse::Ok(response) => {
525                    tracing::info!(target: "rig",
526                        "bigmodel completion token usage: {:?}",
527                        response.usage
528                    );
529                    response.try_into()
530                }
531                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
532            }
533        } else {
534            Err(CompletionError::ProviderError(response.text().await?))
535        }
536    }
537
538    async fn stream(
539        &self,
540        request: CompletionRequest,
541    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
542        let mut request = self.create_completion_request(request)?;
543
544        request = json_utils::merge(request, json!({"stream": true}));
545
546        let builder = self.client.post("/chat/completions").json(&request);
547
548        send_compatible_streaming_request(builder).await
549    }
550}
551
552