rig_extra/extra_providers/
bigmodel.rs

1use http::{HeaderValue, Method, header};
2use rig::client::{CompletionClient, ProviderClient};
3use rig::completion::{CompletionError, CompletionRequest};
4use rig::message::{MessageError, Text};
5use rig::providers::openai;
6use rig::{OneOrMany, completion, http_client, message};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9
10use crate::json_utils;
11use crate::json_utils::merge;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use tracing::{Instrument, info_span};
15
16// ================================================================
// BIGMODEL client
18// ================================================================
/// Base URL that [`Client::new`] passes to [`Client::from_url`] when no custom endpoint is given.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
/// HTTP client for the BigModel API.
///
/// Bundles the API key, the base URL that request paths are appended to, the headers added to
/// streaming requests, and the `reqwest` client used to send requests.
// NOTE(review): the derived `Debug` output includes `api_key` in plain text — confirm this is
// acceptable wherever clients are logged.
#[derive(Clone, Debug)]
pub struct Client {
    // API key; sent as an `Authorization: Bearer <key>` header.
    api_key: String,
    // Endpoint prefix that request paths are joined onto.
    base_url: String,
    // Headers copied onto hand-built streaming requests (contains `Content-Type`).
    default_headers: http_client::HeaderMap,
    // `reqwest` client pre-configured with the `Authorization` default header.
    http_client: reqwest::Client,
}
28
29impl Client {
30    pub fn new(api_key: &str) -> Self {
31        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
32    }
33
34    pub fn from_url(api_key: &str, base_url: &str) -> Self {
35        let mut default_headers = reqwest::header::HeaderMap::new();
36        default_headers.insert(
37            reqwest::header::CONTENT_TYPE,
38            "application/json".parse().unwrap(),
39        );
40
41        Self {
42            api_key: api_key.to_string(),
43            base_url: base_url.to_string(),
44            default_headers,
45            http_client: reqwest::Client::builder()
46                .default_headers({
47                    let mut headers = reqwest::header::HeaderMap::new();
48                    headers.insert(
49                        "Authorization",
50                        format!("Bearer {api_key}")
51                            .parse()
52                            .expect("Bearer token should parse"),
53                    );
54                    headers
55                })
56                .build()
57                .expect("bigmodel reqwest client should build"),
58        }
59    }
60
61    fn post(&self, path: &str) -> reqwest::RequestBuilder {
62        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
63        self.http_client.post(url)
64    }
65
66    pub fn completion_model(&self, model: &str) -> CompletionModel {
67        CompletionModel::new(self.clone(), model)
68    }
69
70    // 为completion模型创建提取构建器
71    // pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
72    //     &self,
73    //     model: &str,
74    // ) -> ExtractorBuilder<T, CompletionModel> {
75    //     ExtractorBuilder::new(self.completion_model(model))
76    // }
77}
78
79impl ProviderClient for Client {
80    type Input = String;
81
82    fn from_env() -> Self
83    where
84        Self: Sized,
85    {
86        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
87        Self::new(&api_key)
88    }
89
90    fn from_val(input: Self::Input) -> Self {
91        Self::new(&input)
92    }
93}
94impl CompletionClient for Client {
95    type CompletionModel = CompletionModel;
96
97    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
98        CompletionModel::new(self.clone(), &model.into())
99    }
100}
101
/// Error body of an API response; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Error text, surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
106
/// Either a successful payload `T` or an error body.
///
/// Untagged: deserialization tries `Ok(T)` first and falls back to `Err` when the JSON does not
/// match `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    // The JSON matched the expected payload type.
    Ok(T),
    // The JSON matched the error shape (an object with a `message` field).
    Err(ApiErrorResponse),
}
113
114// ================================================================
115// Bigmodel Completion API
116// ================================================================
/// Model name `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";

/// Model name `glm-4.5-flash`.
///
/// Deprecated: per the deprecation note, GLM-4.5-Flash goes offline on 2026-01-30.
#[deprecated(note = "GLM-4.5-Flash 将于2026年1月30日下线")]
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
/// Model name `glm-4.7-flash`.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
122
/// Body of a successful chat-completion response.
///
/// `rename_all = "camelCase"` leaves the single-word field names unchanged; `request_id` is
/// explicitly renamed so it stays snake_case on the wire.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    /// Returned choices; only the first one is converted into a rig response.
    pub choices: Vec<Choice>,
    /// `created` value as returned by the API (presumably a timestamp — confirm units).
    pub created: i64,
    /// Identifier of this response.
    pub id: String,
    /// Model name reported by the API.
    pub model: String,
    /// Request identifier; (de)serialized as `request_id`.
    #[serde(rename = "request_id")]
    pub request_id: String,
    /// Token accounting for the request.
    pub usage: Usage,
}
134
/// A chat message in the provider's wire format.
///
/// Internally tagged by the `role` field; variant names are lowercased, except `ToolResult`,
/// which is serialized with role `tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `user`.
    User {
        content: String,
    },
    /// Message with role `assistant`: optional text plus any tool calls.
    Assistant {
        content: Option<String>,
        // A missing field defaults to an empty list; `null_or_vec` presumably also maps an
        // explicit `null` to an empty list — confirm in `json_utils`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message with role `system`.
    System {
        content: String,
    },
    /// Result of a tool call, serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
155
156impl Message {
157    pub fn system(content: &str) -> Message {
158        Message::System {
159            content: content.to_owned(),
160        }
161    }
162}
163
/// Text-only tool result content, used as an intermediate while converting rig tool results.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // The textual result.
    text: String,
}
168impl TryFrom<message::ToolResultContent> for ToolResultContent {
169    type Error = MessageError;
170    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
171        let message::ToolResultContent::Text(Text { text }) = value else {
172            return Err(MessageError::ConversionError(
173                "Non-text tool results not supported".into(),
174            ));
175        };
176
177        Ok(Self { text })
178    }
179}
180
181impl TryFrom<message::Message> for Message {
182    type Error = MessageError;
183
184    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
185        Ok(match message {
186            message::Message::User { content } => {
187                let mut texts = Vec::new();
188                let mut images = Vec::new();
189
190                for uc in content.into_iter() {
191                    match uc {
192                        message::UserContent::Text(message::Text { text }) => texts.push(text),
193                        message::UserContent::Image(img) => images.push(img.data),
194                        message::UserContent::ToolResult(result) => {
195                            let content = result
196                                .content
197                                .into_iter()
198                                .map(ToolResultContent::try_from)
199                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
200
201                            let content = OneOrMany::many(content).map_err(|x| {
202                                MessageError::ConversionError(format!(
203                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
204                                ))
205                            })?;
206
207                            return Ok(Message::ToolResult {
208                                tool_call_id: result.id,
209                                content: content.first().text,
210                            });
211                        }
212                        _ => {}
213                    }
214                }
215
216                let collapsed_content = texts.join(" ");
217
218                Message::User {
219                    content: collapsed_content,
220                }
221            }
222            message::Message::Assistant { content, .. } => {
223                let mut texts = Vec::new();
224                let mut tool_calls = Vec::new();
225
226                for ac in content.into_iter() {
227                    match ac {
228                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
229                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
230                        _ => {}
231                    }
232                }
233
234                let collapsed_content = texts.join(" ");
235
236                Message::Assistant {
237                    content: Some(collapsed_content),
238                    tool_calls,
239                }
240            }
241        })
242    }
243}
244
245impl From<message::ToolResult> for Message {
246    fn from(tool_result: message::ToolResult) -> Self {
247        let content = match tool_result.content.first() {
248            message::ToolResultContent::Text(text) => text.text,
249            message::ToolResultContent::Image(_) => String::from("[Image]"),
250        };
251
252        Message::ToolResult {
253            tool_call_id: tool_result.id,
254            content,
255        }
256    }
257}
258
/// A tool call requested by the assistant, in the provider's wire format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    /// The function to invoke and its arguments.
    pub function: CallFunction,
    /// Identifier of this tool call.
    pub id: String,
    /// Index of the call as reported by the API; calls converted from rig messages always use 0.
    pub index: usize,
    /// Kind of tool; falls back to `ToolType::Function` when absent from the JSON.
    #[serde(default)]
    pub r#type: ToolType,
}
268
269impl From<message::ToolCall> for ToolCall {
270    fn from(tool_call: message::ToolCall) -> Self {
271        Self {
272            id: tool_call.id,
273            index: 0,
274            r#type: ToolType::Function,
275            function: CallFunction {
276                name: tool_call.function.name,
277                arguments: tool_call.function.arguments,
278            },
279        }
280    }
281}
282
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (`"function"` on the wire); the only variant, and the default.
    #[default]
    Function,
}
289
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments as a JSON value.
    // NOTE(review): `stringified_json` presumably encodes this value as a JSON string on the
    // wire and parses it back on deserialization — confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
296
/// Chat roles, serialized in lowercase (`system`, `user`, `assistant`).
// NOTE(review): not referenced by the code visible in this file — confirm whether it is used elsewhere.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
304
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    /// Finish reason as reported by the API; (de)serialized as `finish_reason`.
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    /// Position of this choice in the response.
    pub index: i64,
    /// The message for this choice; conversion to a rig response expects an assistant message.
    pub message: Message,
}
313
/// Token usage reported by the API; every field keeps its snake_case name on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    /// Completion token count as reported by the API.
    #[serde(rename = "completion_tokens")]
    pub completion_tokens: i64,
    /// Prompt token count as reported by the API.
    #[serde(rename = "prompt_tokens")]
    pub prompt_tokens: i64,
    /// Total token count as reported by the API.
    #[serde(rename = "total_tokens")]
    pub total_tokens: i64,
}
324
325impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
326    type Error = CompletionError;
327
328    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
329        let choice = response.choices.first().ok_or_else(|| {
330            CompletionError::ResponseError("Response contained no choices".to_owned())
331        })?;
332
333        match &choice.message {
334            Message::Assistant {
335                tool_calls,
336                content,
337            } => {
338                if !tool_calls.is_empty() {
339                    let tool_result = tool_calls
340                        .iter()
341                        .map(|call| {
342                            completion::AssistantContent::tool_call(
343                                &call.function.name,
344                                &call.function.name,
345                                call.function.arguments.clone(),
346                            )
347                        })
348                        .collect::<Vec<_>>();
349
350                    let choice = OneOrMany::many(tool_result).map_err(|_| {
351                        CompletionError::ResponseError(
352                            "Response contained no message or tool call (empty)".to_owned(),
353                        )
354                    })?;
355                    let usage = completion::Usage {
356                        input_tokens: response.usage.prompt_tokens as u64,
357                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
358                            as u64,
359                        total_tokens: response.usage.total_tokens as u64,
360                    };
361                    tracing::debug!("response choices: {:?}: ", choice);
362                    Ok(completion::CompletionResponse {
363                        choice,
364                        usage,
365                        raw_response: response,
366                    })
367                } else {
368                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
369                        text: content.clone().unwrap_or_else(|| "".to_owned()),
370                    }));
371                    let usage = completion::Usage {
372                        input_tokens: response.usage.prompt_tokens as u64,
373                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
374                            as u64,
375                        total_tokens: response.usage.total_tokens as u64,
376                    };
377                    Ok(completion::CompletionResponse {
378                        choice,
379                        usage,
380                        raw_response: response,
381                    })
382                }
383            }
384            // Message::Assistant { tool_calls } => {}
385            _ => Err(CompletionError::ResponseError(
386                "Chat response does not include an assistant message".into(),
387            )),
388        }
389    }
390}
391
/// A completion model handle: a [`Client`] paired with the name of the model to query.
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to send requests for this model.
    client: Client,
    /// Model name sent in the `model` field of each request.
    pub model: String,
}
397
// Function definition
/// One entry of the `tools` array sent with a completion request.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Tool kind, serialized as `type`; request building always sets it to `"function"`.
    #[serde(rename = "type")]
    pub type_field: String,
    /// The function being described.
    pub function: Function,
}
406
/// Name, description and parameters of a callable tool, as sent in a request.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    /// Tool name.
    pub name: String,
    /// Tool description.
    pub description: String,
    /// Tool parameters as a JSON value, copied from the rig tool definition.
    pub parameters: serde_json::Value,
}
414
415impl CompletionModel {
416    pub fn new(client: Client, model: &str) -> Self {
417        Self {
418            client,
419            model: model.to_string(),
420        }
421    }
422
423    fn create_completion_request(
424        &self,
425        completion_request: CompletionRequest,
426    ) -> Result<Value, CompletionError> {
427        // 构建消息顺序(上下文、聊天历史、提示)
428        let mut partial_history = vec![];
429        if let Some(docs) = completion_request.normalized_documents() {
430            partial_history.push(docs);
431        }
432        partial_history.extend(completion_request.chat_history);
433
434        // 使用前言初始化完整历史(如果不存在则为空)
435        let mut full_history: Vec<Message> = completion_request
436            .preamble
437            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
438
439        // 转换并扩展其余历史
440        full_history.extend(
441            partial_history
442                .into_iter()
443                .map(message::Message::try_into)
444                .collect::<Result<Vec<Message>, _>>()?,
445        );
446
447        let request = if completion_request.tools.is_empty() {
448            json!({
449                "model": self.model,
450                "messages": full_history,
451                "temperature": completion_request.temperature,
452            })
453        } else {
454            // tools
455            let tools = completion_request
456                .tools
457                .into_iter()
458                .map(|item| {
459                    let custom_function = Function {
460                        name: item.name,
461                        description: item.description,
462                        parameters: item.parameters,
463                    };
464                    CustomFunctionDefinition {
465                        type_field: "function".to_string(),
466                        function: custom_function,
467                    }
468                })
469                .collect::<Vec<_>>();
470
471            tracing::debug!("tools: {:?}", tools);
472
473            json!({
474                "model": self.model,
475                "messages": full_history,
476                "temperature": completion_request.temperature,
477                "tools": tools,
478                "tool_choice": "auto",
479            })
480        };
481
482        let request = if let Some(params) = completion_request.additional_params {
483            json_utils::merge(request, params)
484        } else {
485            request
486        };
487
488        Ok(request)
489    }
490}
491
492/// 同步请求
493impl completion::CompletionModel for CompletionModel {
494    type Response = CompletionResponse;
495    type StreamingResponse = openai::StreamingCompletionResponse;
496    type Client = Client;
497
498    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
499        Self::new(client.clone(), &model.into())
500    }
501
502    async fn completion(
503        &self,
504        completion_request: CompletionRequest,
505    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
506        tracing::debug!("create_completion_request========");
507        let request = self.create_completion_request(completion_request)?;
508
509        tracing::debug!(
510            "request: \r\n {}",
511            serde_json::to_string_pretty(&request).unwrap()
512        );
513
514        let response = self
515            .client
516            .post("/chat/completions")
517            .json(&request)
518            .send()
519            .await
520            .map_err(|e| http_client::Error::Instance(e.into()))?;
521
522        if response.status().is_success() {
523            let data: Value = response.json().await.expect("api error");
524            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
525            let data: ApiResponse<CompletionResponse> =
526                serde_json::from_value(data).expect("deserialize completion response");
527            match data {
528                ApiResponse::Ok(response) => {
529                    tracing::info!(target: "rig",
530                        "bigmodel completion token usage: {:?}",
531                        response.usage
532                    );
533                    response.try_into()
534                }
535                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
536            }
537        } else {
538            Err(CompletionError::ProviderError(
539                response
540                    .text()
541                    .await
542                    .map_err(|e| http_client::Error::Instance(e.into()))?,
543            ))
544        }
545    }
546
547    async fn stream(
548        &self,
549        request: CompletionRequest,
550    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
551        let preamble = request.preamble.clone();
552
553        let mut request = self.create_completion_request(request)?;
554
555        request = merge(request, json!({"stream": true}));
556
557        let body = serde_json::to_vec(&request)?;
558
559        let url = format!(
560            "{}/{}",
561            self.client.base_url,
562            "/chat/completions".trim_start_matches('/')
563        );
564
565        let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
566        for (header, value) in &self.client.default_headers {
567            builder = builder.header(header, value);
568        }
569
570        let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
571            .map_err(http::Error::from)
572            .map_err(rig::http_client::Error::from)?;
573
574        builder = builder.header(header::AUTHORIZATION, auth_header);
575        builder = builder.header("Content-Type", "application/json");
576
577        let req = builder
578            .body(body)
579            .map_err(|e| CompletionError::HttpError(e.into()))?;
580
581        let span = if tracing::Span::current().is_disabled() {
582            info_span!(
583                target: "rig::completions",
584                "chat_streaming",
585                gen_ai.operation.name = "chat_streaming",
586                gen_ai.provider.name = "galadriel",
587                gen_ai.request.model = self.model,
588                gen_ai.system_instructions = preamble,
589                gen_ai.response.id = tracing::field::Empty,
590                gen_ai.response.model = tracing::field::Empty,
591                gen_ai.usage.output_tokens = tracing::field::Empty,
592                gen_ai.usage.input_tokens = tracing::field::Empty,
593                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
594                gen_ai.output.messages = tracing::field::Empty,
595            )
596        } else {
597            tracing::Span::current()
598        };
599
600        send_compatible_streaming_request(self.client.http_client.clone(), req)
601            .instrument(span)
602            .await
603    }
604}