// rig_extra/extra_providers/bigmodel.rs

1use http::header;
2use reqwest::Method;
3use reqwest::header::HeaderValue;
4use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
5use rig::completion::{CompletionError, CompletionRequest};
6use rig::message::{MessageError, Text};
7use rig::providers::openai;
8use rig::{OneOrMany, client, completion, http_client, message};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11
12use crate::json_utils;
13use crate::json_utils::merge;
14use rig::providers::openai::send_compatible_streaming_request;
15use rig::streaming::StreamingCompletionResponse;
16use tracing::{Instrument, info_span};
17
// ================================================================
// BIGMODEL client
// ================================================================
/// Base URL used by [`Client::new`]; note that it ends with a trailing `/`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
/// Handle for talking to the BigModel HTTP API.
///
/// NOTE(review): the derived `Debug` output includes `api_key` verbatim, so
/// avoid logging a `Client` with `{:?}`.
#[derive(Clone, Debug)]
pub struct Client {
    /// API key; sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    /// Prefix that request paths are appended to.
    base_url: String,
    /// Headers copied onto hand-built streaming requests (`Content-Type: application/json`).
    default_headers: http_client::HeaderMap,
    /// reqwest client whose default headers already carry the `Authorization` header.
    http_client: reqwest::Client,
}
30
31impl Client {
32    pub fn new(api_key: &str) -> Self {
33        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
34    }
35
36    pub fn from_url(api_key: &str, base_url: &str) -> Self {
37        let mut default_headers = reqwest::header::HeaderMap::new();
38        default_headers.insert(
39            reqwest::header::CONTENT_TYPE,
40            "application/json".parse().unwrap(),
41        );
42
43        Self {
44            api_key: api_key.to_string(),
45            base_url: base_url.to_string(),
46            default_headers,
47            http_client: reqwest::Client::builder()
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert(
51                        "Authorization",
52                        format!("Bearer {api_key}")
53                            .parse()
54                            .expect("Bearer token should parse"),
55                    );
56                    headers
57                })
58                .build()
59                .expect("bigmodel reqwest client should build"),
60        }
61    }
62
63    fn post(&self, path: &str) -> reqwest::RequestBuilder {
64        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
65        self.http_client.post(url)
66    }
67
68    pub fn completion_model(&self, model: &str) -> CompletionModel {
69        CompletionModel::new(self.clone(), model)
70    }
71
72    // 为completion模型创建提取构建器
73    // pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
74    //     &self,
75    //     model: &str,
76    // ) -> ExtractorBuilder<T, CompletionModel> {
77    //     ExtractorBuilder::new(self.completion_model(model))
78    // }
79}
80
81impl ProviderClient for Client {
82    fn from_env() -> Self
83    where
84        Self: Sized,
85    {
86        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
87        Self::new(&api_key)
88    }
89
90    fn from_val(input: ProviderValue) -> Self
91    where
92        Self: Sized,
93    {
94        let client::ProviderValue::Simple(api_key) = input else {
95            panic!("Incorrect provider value type")
96        };
97        Self::new(&api_key)
98    }
99}
100
// Empty impl: nothing is overridden, so the trait's own defaults apply.
// NOTE(review): presumably this means no transcription model is exposed — confirm against rig.
impl AsTranscription for Client {}

// Empty impl: nothing is overridden, so the trait's own defaults apply.
// NOTE(review): presumably this means no embedding model is exposed — confirm against rig.
impl AsEmbeddings for Client {}
104
105impl CompletionClient for Client {
106    type CompletionModel = CompletionModel;
107
108    fn completion_model(&self, model: &str) -> Self::CompletionModel {
109        CompletionModel::new(self.clone(), model)
110    }
111}
112
/// Error payload shape: a JSON object carrying a top-level `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, surfaced to callers as a provider error.
    message: String,
}
117
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is untagged, so serde picks the variant by shape, trying `Ok`
/// before `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as `T`.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
124
// ================================================================
// Bigmodel Completion API
// ================================================================
/// Model identifier string `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model identifier string `glm-4.5-flash`.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
130
/// Body of a successful (non-streaming) chat completion response.
///
/// `rename_all = "camelCase"` leaves the single-word fields unchanged;
/// `request_id` is explicitly renamed back to its snake_case wire name.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    /// Returned choices; only the first one is used when converting to rig's response type.
    pub choices: Vec<Choice>,
    /// NOTE(review): presumably a creation timestamp — confirm units against the API docs.
    pub created: i64,
    /// Response identifier.
    pub id: String,
    /// Model name reported in the response.
    pub model: String,
    /// Request identifier (wire name `request_id`).
    #[serde(rename = "request_id")]
    pub request_id: String,
    /// Token accounting for this request.
    pub usage: Usage,
}
142
/// A chat message in the wire format, discriminated by a lowercase `role` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "user"` — plain text from the user.
    User {
        content: String,
    },
    /// `role: "assistant"` — optional text plus any tool calls.
    Assistant {
        content: Option<String>,
        // NOTE(review): `null_or_vec` presumably maps a JSON `null` to an empty
        // vector — confirm in `json_utils`. A missing field uses the default.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "system"` — system prompt.
    System {
        content: String,
    },
    /// `role: "tool"` — result of a tool call, keyed by the call id.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
163
164impl Message {
165    pub fn system(content: &str) -> Message {
166        Message::System {
167            content: content.to_owned(),
168        }
169    }
170}
171
/// Text-only tool result content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// The textual result.
    text: String,
}
176impl TryFrom<message::ToolResultContent> for ToolResultContent {
177    type Error = MessageError;
178    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
179        let message::ToolResultContent::Text(Text { text }) = value else {
180            return Err(MessageError::ConversionError(
181                "Non-text tool results not supported".into(),
182            ));
183        };
184
185        Ok(Self { text })
186    }
187}
188
189impl TryFrom<message::Message> for Message {
190    type Error = MessageError;
191
192    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
193        Ok(match message {
194            message::Message::User { content } => {
195                let mut texts = Vec::new();
196                let mut images = Vec::new();
197
198                for uc in content.into_iter() {
199                    match uc {
200                        message::UserContent::Text(message::Text { text }) => texts.push(text),
201                        message::UserContent::Image(img) => images.push(img.data),
202                        message::UserContent::ToolResult(result) => {
203                            let content = result
204                                .content
205                                .into_iter()
206                                .map(ToolResultContent::try_from)
207                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
208
209                            let content = OneOrMany::many(content).map_err(|x| {
210                                MessageError::ConversionError(format!(
211                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
212                                ))
213                            })?;
214
215                            return Ok(Message::ToolResult {
216                                tool_call_id: result.id,
217                                content: content.first().text,
218                            });
219                        }
220                        _ => {}
221                    }
222                }
223
224                let collapsed_content = texts.join(" ");
225
226                Message::User {
227                    content: collapsed_content,
228                }
229            }
230            message::Message::Assistant { content, .. } => {
231                let mut texts = Vec::new();
232                let mut tool_calls = Vec::new();
233
234                for ac in content.into_iter() {
235                    match ac {
236                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
237                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
238                        _ => {}
239                    }
240                }
241
242                let collapsed_content = texts.join(" ");
243
244                Message::Assistant {
245                    content: Some(collapsed_content),
246                    tool_calls,
247                }
248            }
249        })
250    }
251}
252
253impl From<message::ToolResult> for Message {
254    fn from(tool_result: message::ToolResult) -> Self {
255        let content = match tool_result.content.first() {
256            message::ToolResultContent::Text(text) => text.text,
257            message::ToolResultContent::Image(_) => String::from("[Image]"),
258        };
259
260        Message::ToolResult {
261            tool_call_id: tool_result.id,
262            content,
263        }
264    }
265}
266
/// A tool invocation attached to an assistant message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    /// Function name and arguments.
    pub function: CallFunction,
    /// Identifier of this call; tool results refer back to it.
    pub id: String,
    /// Position of the call; set to 0 when converting from a rig tool call.
    pub index: usize,
    /// Kind of tool; falls back to `ToolType::Function` when absent.
    #[serde(default)]
    pub r#type: ToolType,
}
276
277impl From<message::ToolCall> for ToolCall {
278    fn from(tool_call: message::ToolCall) -> Self {
279        Self {
280            id: tool_call.id,
281            index: 0,
282            r#type: ToolType::Function,
283            function: CallFunction {
284                name: tool_call.function.name,
285                arguments: tool_call.function.arguments,
286            },
287        }
288    }
289}
290
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only (and default) kind: a function call, serialized as `"function"`.
    #[default]
    Function,
}
297
/// Name and arguments of a called function.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    /// Name of the function to call.
    pub name: String,
    // NOTE(review): `stringified_json` presumably carries the arguments as a
    // JSON-encoded string on the wire — confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
304
/// Chat roles, serialized in lowercase.
///
/// Not referenced by the other definitions visible in this file.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
312
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    /// Wire name `finish_reason` (explicit rename overrides the camelCase rule).
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    /// Position of this choice in the response.
    pub index: i64,
    /// The message carried by this choice.
    pub message: Message,
}
321
/// Token accounting reported with a completion response.
///
/// Each field is explicitly renamed to its snake_case wire name, overriding the
/// struct-level camelCase rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    /// Wire name `completion_tokens`.
    #[serde(rename = "completion_tokens")]
    pub completion_tokens: i64,
    /// Wire name `prompt_tokens`.
    #[serde(rename = "prompt_tokens")]
    pub prompt_tokens: i64,
    /// Wire name `total_tokens`.
    #[serde(rename = "total_tokens")]
    pub total_tokens: i64,
}
332
333impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
334    type Error = CompletionError;
335
336    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
337        let choice = response.choices.first().ok_or_else(|| {
338            CompletionError::ResponseError("Response contained no choices".to_owned())
339        })?;
340
341        match &choice.message {
342            Message::Assistant {
343                tool_calls,
344                content,
345            } => {
346                if !tool_calls.is_empty() {
347                    let tool_result = tool_calls
348                        .iter()
349                        .map(|call| {
350                            completion::AssistantContent::tool_call(
351                                &call.function.name,
352                                &call.function.name,
353                                call.function.arguments.clone(),
354                            )
355                        })
356                        .collect::<Vec<_>>();
357
358                    let choice = OneOrMany::many(tool_result).map_err(|_| {
359                        CompletionError::ResponseError(
360                            "Response contained no message or tool call (empty)".to_owned(),
361                        )
362                    })?;
363                    let usage = completion::Usage {
364                        input_tokens: response.usage.prompt_tokens as u64,
365                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
366                            as u64,
367                        total_tokens: response.usage.total_tokens as u64,
368                    };
369                    tracing::debug!("response choices: {:?}: ", choice);
370                    Ok(completion::CompletionResponse {
371                        choice,
372                        usage,
373                        raw_response: response,
374                    })
375                } else {
376                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
377                        text: content.clone().unwrap_or_else(|| "".to_owned()),
378                    }));
379                    let usage = completion::Usage {
380                        input_tokens: response.usage.prompt_tokens as u64,
381                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
382                            as u64,
383                        total_tokens: response.usage.total_tokens as u64,
384                    };
385                    Ok(completion::CompletionResponse {
386                        choice,
387                        usage,
388                        raw_response: response,
389                    })
390                }
391            }
392            // Message::Assistant { tool_calls } => {}
393            _ => Err(CompletionError::ResponseError(
394                "Chat response does not include an assistant message".into(),
395            )),
396        }
397    }
398}
399
/// A BigModel completion model: a client paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    client: Client,
    /// Model name sent as the `model` field of each request.
    pub model: String,
}
405
// Function definition
/// One entry of the `tools` array sent with a completion request.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Wire name `type`; set to `"function"` when requests are built.
    #[serde(rename = "type")]
    pub type_field: String,
    /// The function being described.
    pub function: Function,
}
414
/// Description of a callable function offered to the model.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// JSON value describing the parameters, copied from the tool definition.
    pub parameters: serde_json::Value,
}
422
423impl CompletionModel {
424    pub fn new(client: Client, model: &str) -> Self {
425        Self {
426            client,
427            model: model.to_string(),
428        }
429    }
430
431    fn create_completion_request(
432        &self,
433        completion_request: CompletionRequest,
434    ) -> Result<Value, CompletionError> {
435        // 构建消息顺序(上下文、聊天历史、提示)
436        let mut partial_history = vec![];
437        if let Some(docs) = completion_request.normalized_documents() {
438            partial_history.push(docs);
439        }
440        partial_history.extend(completion_request.chat_history);
441
442        // 使用前言初始化完整历史(如果不存在则为空)
443        let mut full_history: Vec<Message> = completion_request
444            .preamble
445            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
446
447        // 转换并扩展其余历史
448        full_history.extend(
449            partial_history
450                .into_iter()
451                .map(message::Message::try_into)
452                .collect::<Result<Vec<Message>, _>>()?,
453        );
454
455        let request = if completion_request.tools.is_empty() {
456            json!({
457                "model": self.model,
458                "messages": full_history,
459                "temperature": completion_request.temperature,
460            })
461        } else {
462            // tools
463            let tools = completion_request
464                .tools
465                .into_iter()
466                .map(|item| {
467                    let custom_function = Function {
468                        name: item.name,
469                        description: item.description,
470                        parameters: item.parameters,
471                    };
472                    CustomFunctionDefinition {
473                        type_field: "function".to_string(),
474                        function: custom_function,
475                    }
476                })
477                .collect::<Vec<_>>();
478
479            tracing::debug!("tools: {:?}", tools);
480
481            json!({
482                "model": self.model,
483                "messages": full_history,
484                "temperature": completion_request.temperature,
485                "tools": tools,
486                "tool_choice": "auto",
487            })
488        };
489
490        let request = if let Some(params) = completion_request.additional_params {
491            json_utils::merge(request, params)
492        } else {
493            request
494        };
495
496        Ok(request)
497    }
498}
499
500/// 同步请求
501impl completion::CompletionModel for CompletionModel {
502    type Response = CompletionResponse;
503    type StreamingResponse = openai::StreamingCompletionResponse;
504
505    async fn completion(
506        &self,
507        completion_request: CompletionRequest,
508    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
509        tracing::debug!("create_completion_request========");
510        let request = self.create_completion_request(completion_request)?;
511
512        tracing::debug!(
513            "request: \r\n {}",
514            serde_json::to_string_pretty(&request).unwrap()
515        );
516
517        let response = self
518            .client
519            .post("/chat/completions")
520            .json(&request)
521            .send()
522            .await
523            .map_err(|e| http_client::Error::Instance(e.into()))?;
524
525        if response.status().is_success() {
526            let data: Value = response.json().await.expect("api error");
527            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
528            let data: ApiResponse<CompletionResponse> =
529                serde_json::from_value(data).expect("deserialize completion response");
530            match data {
531                ApiResponse::Ok(response) => {
532                    tracing::info!(target: "rig",
533                        "bigmodel completion token usage: {:?}",
534                        response.usage
535                    );
536                    response.try_into()
537                }
538                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
539            }
540        } else {
541            Err(CompletionError::ProviderError(
542                response
543                    .text()
544                    .await
545                    .map_err(|e| http_client::Error::Instance(e.into()))?,
546            ))
547        }
548    }
549
550    async fn stream(
551        &self,
552        request: CompletionRequest,
553    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
554        let preamble = request.preamble.clone();
555
556        let mut request = self.create_completion_request(request)?;
557
558        request = merge(request, json!({"stream": true}));
559
560        let body = serde_json::to_vec(&request)?;
561
562        let url = format!(
563            "{}/{}",
564            self.client.base_url,
565            "/chat/completions".trim_start_matches('/')
566        );
567
568        let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
569        for (header, value) in &self.client.default_headers {
570            builder = builder.header(header, value);
571        }
572
573        let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
574            .map_err(http::Error::from)
575            .map_err(rig::http_client::Error::from)?;
576
577        builder = builder.header(header::AUTHORIZATION, auth_header);
578        builder = builder.header("Content-Type", "application/json");
579
580        let req = builder
581            .body(body)
582            .map_err(|e| CompletionError::HttpError(e.into()))?;
583
584        let span = if tracing::Span::current().is_disabled() {
585            info_span!(
586                target: "rig::completions",
587                "chat_streaming",
588                gen_ai.operation.name = "chat_streaming",
589                gen_ai.provider.name = "galadriel",
590                gen_ai.request.model = self.model,
591                gen_ai.system_instructions = preamble,
592                gen_ai.response.id = tracing::field::Empty,
593                gen_ai.response.model = tracing::field::Empty,
594                gen_ai.usage.output_tokens = tracing::field::Empty,
595                gen_ai.usage.input_tokens = tracing::field::Empty,
596                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
597                gen_ai.output.messages = tracing::field::Empty,
598            )
599        } else {
600            tracing::Span::current()
601        };
602
603        send_compatible_streaming_request(self.client.http_client.clone(), req)
604            .instrument(span)
605            .await
606    }
607}