rig_extra/extra_providers/bigmodel.rs

1use http::header;
2use reqwest::Method;
3use reqwest::header::HeaderValue;
4use rig::client::{CompletionClient, ProviderClient};
5use rig::completion::{CompletionError, CompletionRequest};
6use rig::message::{MessageError, Text};
7use rig::providers::openai;
8use rig::{OneOrMany, completion, http_client, message};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11
12use crate::json_utils;
13use crate::json_utils::merge;
14use rig::providers::openai::send_compatible_streaming_request;
15use rig::streaming::StreamingCompletionResponse;
16use tracing::{Instrument, info_span};
17
// ================================================================
// BigModel client
// ================================================================

/// Default base URL of the BigModel open platform v4 API.
///
/// `Client::new` uses this value; `Client::from_url` accepts any other base URL.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
23#[derive(Clone, Debug)]
24pub struct Client {
25    api_key: String,
26    base_url: String,
27    default_headers: http_client::HeaderMap,
28    http_client: reqwest::Client,
29}
30
31impl Client {
32    pub fn new(api_key: &str) -> Self {
33        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
34    }
35
36    pub fn from_url(api_key: &str, base_url: &str) -> Self {
37        let mut default_headers = reqwest::header::HeaderMap::new();
38        default_headers.insert(
39            reqwest::header::CONTENT_TYPE,
40            "application/json".parse().unwrap(),
41        );
42
43        Self {
44            api_key: api_key.to_string(),
45            base_url: base_url.to_string(),
46            default_headers,
47            http_client: reqwest::Client::builder()
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert(
51                        "Authorization",
52                        format!("Bearer {api_key}")
53                            .parse()
54                            .expect("Bearer token should parse"),
55                    );
56                    headers
57                })
58                .build()
59                .expect("bigmodel reqwest client should build"),
60        }
61    }
62
63    fn post(&self, path: &str) -> reqwest::RequestBuilder {
64        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
65        self.http_client.post(url)
66    }
67
68    pub fn completion_model(&self, model: &str) -> CompletionModel {
69        CompletionModel::new(self.clone(), model)
70    }
71
72    // 为completion模型创建提取构建器
73    // pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
74    //     &self,
75    //     model: &str,
76    // ) -> ExtractorBuilder<T, CompletionModel> {
77    //     ExtractorBuilder::new(self.completion_model(model))
78    // }
79}
80
81impl ProviderClient for Client {
82    type Input = String;
83
84    fn from_env() -> Self
85    where
86        Self: Sized,
87    {
88        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
89        Self::new(&api_key)
90    }
91
92    fn from_val(input: Self::Input) -> Self {
93        Self::new(&input)
94    }
95}
96impl CompletionClient for Client {
97    type CompletionModel = CompletionModel;
98
99    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
100        CompletionModel::new(self.clone(), &model.into())
101    }
102}
103
104#[derive(Debug, Deserialize)]
105struct ApiErrorResponse {
106    message: String,
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(untagged)]
111enum ApiResponse<T> {
112    Ok(T),
113    Err(ApiErrorResponse),
114}
115
// ================================================================
// BigModel completion API
// ================================================================

/// Model id `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model id `glm-4.5-flash`.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
121
122#[derive(Debug, Deserialize, Serialize)]
123#[serde(rename_all = "camelCase")]
124pub struct CompletionResponse {
125    pub choices: Vec<Choice>,
126    pub created: i64,
127    pub id: String,
128    pub model: String,
129    #[serde(rename = "request_id")]
130    pub request_id: String,
131    pub usage: Usage,
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135#[serde(tag = "role", rename_all = "lowercase")]
136pub enum Message {
137    User {
138        content: String,
139    },
140    Assistant {
141        content: Option<String>,
142        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
143        tool_calls: Vec<ToolCall>,
144    },
145    System {
146        content: String,
147    },
148    #[serde(rename = "tool")]
149    ToolResult {
150        tool_call_id: String,
151        content: String,
152    },
153}
154
155impl Message {
156    pub fn system(content: &str) -> Message {
157        Message::System {
158            content: content.to_owned(),
159        }
160    }
161}
162
163#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
164pub struct ToolResultContent {
165    text: String,
166}
167impl TryFrom<message::ToolResultContent> for ToolResultContent {
168    type Error = MessageError;
169    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
170        let message::ToolResultContent::Text(Text { text }) = value else {
171            return Err(MessageError::ConversionError(
172                "Non-text tool results not supported".into(),
173            ));
174        };
175
176        Ok(Self { text })
177    }
178}
179
180impl TryFrom<message::Message> for Message {
181    type Error = MessageError;
182
183    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
184        Ok(match message {
185            message::Message::User { content } => {
186                let mut texts = Vec::new();
187                let mut images = Vec::new();
188
189                for uc in content.into_iter() {
190                    match uc {
191                        message::UserContent::Text(message::Text { text }) => texts.push(text),
192                        message::UserContent::Image(img) => images.push(img.data),
193                        message::UserContent::ToolResult(result) => {
194                            let content = result
195                                .content
196                                .into_iter()
197                                .map(ToolResultContent::try_from)
198                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
199
200                            let content = OneOrMany::many(content).map_err(|x| {
201                                MessageError::ConversionError(format!(
202                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
203                                ))
204                            })?;
205
206                            return Ok(Message::ToolResult {
207                                tool_call_id: result.id,
208                                content: content.first().text,
209                            });
210                        }
211                        _ => {}
212                    }
213                }
214
215                let collapsed_content = texts.join(" ");
216
217                Message::User {
218                    content: collapsed_content,
219                }
220            }
221            message::Message::Assistant { content, .. } => {
222                let mut texts = Vec::new();
223                let mut tool_calls = Vec::new();
224
225                for ac in content.into_iter() {
226                    match ac {
227                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
228                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
229                        _ => {}
230                    }
231                }
232
233                let collapsed_content = texts.join(" ");
234
235                Message::Assistant {
236                    content: Some(collapsed_content),
237                    tool_calls,
238                }
239            }
240        })
241    }
242}
243
244impl From<message::ToolResult> for Message {
245    fn from(tool_result: message::ToolResult) -> Self {
246        let content = match tool_result.content.first() {
247            message::ToolResultContent::Text(text) => text.text,
248            message::ToolResultContent::Image(_) => String::from("[Image]"),
249        };
250
251        Message::ToolResult {
252            tool_call_id: tool_result.id,
253            content,
254        }
255    }
256}
257
258#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
259#[serde(rename_all = "camelCase")]
260pub struct ToolCall {
261    pub function: CallFunction,
262    pub id: String,
263    pub index: usize,
264    #[serde(default)]
265    pub r#type: ToolType,
266}
267
268impl From<message::ToolCall> for ToolCall {
269    fn from(tool_call: message::ToolCall) -> Self {
270        Self {
271            id: tool_call.id,
272            index: 0,
273            r#type: ToolType::Function,
274            function: CallFunction {
275                name: tool_call.function.name,
276                arguments: tool_call.function.arguments,
277            },
278        }
279    }
280}
281
282#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
283#[serde(rename_all = "lowercase")]
284pub enum ToolType {
285    #[default]
286    Function,
287}
288
289#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
290pub struct CallFunction {
291    pub name: String,
292    #[serde(with = "json_utils::stringified_json")]
293    pub arguments: serde_json::Value,
294}
295
296#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
297#[serde(rename_all = "lowercase")]
298pub enum Role {
299    System,
300    User,
301    Assistant,
302}
303
304#[derive(Debug, Serialize, Deserialize)]
305#[serde(rename_all = "camelCase")]
306pub struct Choice {
307    #[serde(rename = "finish_reason")]
308    pub finish_reason: String,
309    pub index: i64,
310    pub message: Message,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
314#[serde(rename_all = "camelCase")]
315pub struct Usage {
316    #[serde(rename = "completion_tokens")]
317    pub completion_tokens: i64,
318    #[serde(rename = "prompt_tokens")]
319    pub prompt_tokens: i64,
320    #[serde(rename = "total_tokens")]
321    pub total_tokens: i64,
322}
323
324impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
325    type Error = CompletionError;
326
327    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
328        let choice = response.choices.first().ok_or_else(|| {
329            CompletionError::ResponseError("Response contained no choices".to_owned())
330        })?;
331
332        match &choice.message {
333            Message::Assistant {
334                tool_calls,
335                content,
336            } => {
337                if !tool_calls.is_empty() {
338                    let tool_result = tool_calls
339                        .iter()
340                        .map(|call| {
341                            completion::AssistantContent::tool_call(
342                                &call.function.name,
343                                &call.function.name,
344                                call.function.arguments.clone(),
345                            )
346                        })
347                        .collect::<Vec<_>>();
348
349                    let choice = OneOrMany::many(tool_result).map_err(|_| {
350                        CompletionError::ResponseError(
351                            "Response contained no message or tool call (empty)".to_owned(),
352                        )
353                    })?;
354                    let usage = completion::Usage {
355                        input_tokens: response.usage.prompt_tokens as u64,
356                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
357                            as u64,
358                        total_tokens: response.usage.total_tokens as u64,
359                    };
360                    tracing::debug!("response choices: {:?}: ", choice);
361                    Ok(completion::CompletionResponse {
362                        choice,
363                        usage,
364                        raw_response: response,
365                    })
366                } else {
367                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
368                        text: content.clone().unwrap_or_else(|| "".to_owned()),
369                    }));
370                    let usage = completion::Usage {
371                        input_tokens: response.usage.prompt_tokens as u64,
372                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
373                            as u64,
374                        total_tokens: response.usage.total_tokens as u64,
375                    };
376                    Ok(completion::CompletionResponse {
377                        choice,
378                        usage,
379                        raw_response: response,
380                    })
381                }
382            }
383            // Message::Assistant { tool_calls } => {}
384            _ => Err(CompletionError::ResponseError(
385                "Chat response does not include an assistant message".into(),
386            )),
387        }
388    }
389}
390
391#[derive(Clone)]
392pub struct CompletionModel {
393    client: Client,
394    pub model: String,
395}
396
397// 函数定义
398#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
399#[serde(rename_all = "camelCase")]
400pub struct CustomFunctionDefinition {
401    #[serde(rename = "type")]
402    pub type_field: String,
403    pub function: Function,
404}
405
406#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
407#[serde(rename_all = "camelCase")]
408pub struct Function {
409    pub name: String,
410    pub description: String,
411    pub parameters: serde_json::Value,
412}
413
414impl CompletionModel {
415    pub fn new(client: Client, model: &str) -> Self {
416        Self {
417            client,
418            model: model.to_string(),
419        }
420    }
421
422    fn create_completion_request(
423        &self,
424        completion_request: CompletionRequest,
425    ) -> Result<Value, CompletionError> {
426        // 构建消息顺序(上下文、聊天历史、提示)
427        let mut partial_history = vec![];
428        if let Some(docs) = completion_request.normalized_documents() {
429            partial_history.push(docs);
430        }
431        partial_history.extend(completion_request.chat_history);
432
433        // 使用前言初始化完整历史(如果不存在则为空)
434        let mut full_history: Vec<Message> = completion_request
435            .preamble
436            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
437
438        // 转换并扩展其余历史
439        full_history.extend(
440            partial_history
441                .into_iter()
442                .map(message::Message::try_into)
443                .collect::<Result<Vec<Message>, _>>()?,
444        );
445
446        let request = if completion_request.tools.is_empty() {
447            json!({
448                "model": self.model,
449                "messages": full_history,
450                "temperature": completion_request.temperature,
451            })
452        } else {
453            // tools
454            let tools = completion_request
455                .tools
456                .into_iter()
457                .map(|item| {
458                    let custom_function = Function {
459                        name: item.name,
460                        description: item.description,
461                        parameters: item.parameters,
462                    };
463                    CustomFunctionDefinition {
464                        type_field: "function".to_string(),
465                        function: custom_function,
466                    }
467                })
468                .collect::<Vec<_>>();
469
470            tracing::debug!("tools: {:?}", tools);
471
472            json!({
473                "model": self.model,
474                "messages": full_history,
475                "temperature": completion_request.temperature,
476                "tools": tools,
477                "tool_choice": "auto",
478            })
479        };
480
481        let request = if let Some(params) = completion_request.additional_params {
482            json_utils::merge(request, params)
483        } else {
484            request
485        };
486
487        Ok(request)
488    }
489}
490
491/// 同步请求
492impl completion::CompletionModel for CompletionModel {
493    type Response = CompletionResponse;
494    type StreamingResponse = openai::StreamingCompletionResponse;
495    type Client = Client;
496
497    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
498        Self::new(client.clone(), &model.into())
499    }
500
501    async fn completion(
502        &self,
503        completion_request: CompletionRequest,
504    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
505        tracing::debug!("create_completion_request========");
506        let request = self.create_completion_request(completion_request)?;
507
508        tracing::debug!(
509            "request: \r\n {}",
510            serde_json::to_string_pretty(&request).unwrap()
511        );
512
513        let response = self
514            .client
515            .post("/chat/completions")
516            .json(&request)
517            .send()
518            .await
519            .map_err(|e| http_client::Error::Instance(e.into()))?;
520
521        if response.status().is_success() {
522            let data: Value = response.json().await.expect("api error");
523            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
524            let data: ApiResponse<CompletionResponse> =
525                serde_json::from_value(data).expect("deserialize completion response");
526            match data {
527                ApiResponse::Ok(response) => {
528                    tracing::info!(target: "rig",
529                        "bigmodel completion token usage: {:?}",
530                        response.usage
531                    );
532                    response.try_into()
533                }
534                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
535            }
536        } else {
537            Err(CompletionError::ProviderError(
538                response
539                    .text()
540                    .await
541                    .map_err(|e| http_client::Error::Instance(e.into()))?,
542            ))
543        }
544    }
545
546    async fn stream(
547        &self,
548        request: CompletionRequest,
549    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
550        let preamble = request.preamble.clone();
551
552        let mut request = self.create_completion_request(request)?;
553
554        request = merge(request, json!({"stream": true}));
555
556        let body = serde_json::to_vec(&request)?;
557
558        let url = format!(
559            "{}/{}",
560            self.client.base_url,
561            "/chat/completions".trim_start_matches('/')
562        );
563
564        let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
565        for (header, value) in &self.client.default_headers {
566            builder = builder.header(header, value);
567        }
568
569        let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
570            .map_err(http::Error::from)
571            .map_err(rig::http_client::Error::from)?;
572
573        builder = builder.header(header::AUTHORIZATION, auth_header);
574        builder = builder.header("Content-Type", "application/json");
575
576        let req = builder
577            .body(body)
578            .map_err(|e| CompletionError::HttpError(e.into()))?;
579
580        let span = if tracing::Span::current().is_disabled() {
581            info_span!(
582                target: "rig::completions",
583                "chat_streaming",
584                gen_ai.operation.name = "chat_streaming",
585                gen_ai.provider.name = "galadriel",
586                gen_ai.request.model = self.model,
587                gen_ai.system_instructions = preamble,
588                gen_ai.response.id = tracing::field::Empty,
589                gen_ai.response.model = tracing::field::Empty,
590                gen_ai.usage.output_tokens = tracing::field::Empty,
591                gen_ai.usage.input_tokens = tracing::field::Empty,
592                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
593                gen_ai.output.messages = tracing::field::Empty,
594            )
595        } else {
596            tracing::Span::current()
597        };
598
599        send_compatible_streaming_request(self.client.http_client.clone(), req)
600            .instrument(span)
601            .await
602    }
603}