rig_extra/extra_providers/
bigmodel.rs

1use http::{HeaderValue, Method, header};
2use rig::client::{CompletionClient, ProviderClient};
3use rig::completion::{CompletionError, CompletionRequest};
4use rig::message::{MessageError, Text};
5use rig::providers::openai;
6use rig::{OneOrMany, completion, http_client, message};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9
10use crate::json_utils;
11use crate::json_utils::merge;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use tracing::{Instrument, info_span};
15
// ================================================================
// BigModel client
// ================================================================

/// Default base URL of the BigModel v4 API.
///
/// Endpoint paths such as `chat/completions` are appended to this value when requests are built.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23    api_key: String,
24    base_url: String,
25    default_headers: http_client::HeaderMap,
26    http_client: reqwest::Client,
27}
28
29impl Client {
30    pub fn new(api_key: &str) -> Self {
31        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
32    }
33
34    pub fn from_url(api_key: &str, base_url: &str) -> Self {
35        let mut default_headers = reqwest::header::HeaderMap::new();
36        default_headers.insert(
37            reqwest::header::CONTENT_TYPE,
38            "application/json".parse().unwrap(),
39        );
40
41        Self {
42            api_key: api_key.to_string(),
43            base_url: base_url.to_string(),
44            default_headers,
45            http_client: reqwest::Client::builder()
46                .default_headers({
47                    let mut headers = reqwest::header::HeaderMap::new();
48                    headers.insert(
49                        "Authorization",
50                        format!("Bearer {api_key}")
51                            .parse()
52                            .expect("Bearer token should parse"),
53                    );
54                    headers
55                })
56                .build()
57                .expect("bigmodel reqwest client should build"),
58        }
59    }
60
61    fn post(&self, path: &str) -> reqwest::RequestBuilder {
62        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
63        self.http_client.post(url)
64    }
65
66    pub fn completion_model(&self, model: &str) -> CompletionModel {
67        CompletionModel::new(self.clone(), model)
68    }
69
70    // 为completion模型创建提取构建器
71    // pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
72    //     &self,
73    //     model: &str,
74    // ) -> ExtractorBuilder<T, CompletionModel> {
75    //     ExtractorBuilder::new(self.completion_model(model))
76    // }
77}
78
79impl ProviderClient for Client {
80    type Input = String;
81
82    fn from_env() -> Self
83    where
84        Self: Sized,
85    {
86        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
87        Self::new(&api_key)
88    }
89
90    fn from_val(input: Self::Input) -> Self {
91        Self::new(&input)
92    }
93}
94impl CompletionClient for Client {
95    type CompletionModel = CompletionModel;
96
97    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
98        CompletionModel::new(self.clone(), &model.into())
99    }
100}
101
102#[derive(Debug, Deserialize)]
103struct ApiErrorResponse {
104    message: String,
105}
106
107#[derive(Debug, Deserialize)]
108#[serde(untagged)]
109enum ApiResponse<T> {
110    Ok(T),
111    Err(ApiErrorResponse),
112}
113
// ================================================================
// BigModel completion API
// ================================================================

/// Model id string for GLM-4-Flash.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model id string for GLM-4.5-Flash.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
119
120#[derive(Debug, Deserialize, Serialize)]
121#[serde(rename_all = "camelCase")]
122pub struct CompletionResponse {
123    pub choices: Vec<Choice>,
124    pub created: i64,
125    pub id: String,
126    pub model: String,
127    #[serde(rename = "request_id")]
128    pub request_id: String,
129    pub usage: Usage,
130}
131
132#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
133#[serde(tag = "role", rename_all = "lowercase")]
134pub enum Message {
135    User {
136        content: String,
137    },
138    Assistant {
139        content: Option<String>,
140        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
141        tool_calls: Vec<ToolCall>,
142    },
143    System {
144        content: String,
145    },
146    #[serde(rename = "tool")]
147    ToolResult {
148        tool_call_id: String,
149        content: String,
150    },
151}
152
153impl Message {
154    pub fn system(content: &str) -> Message {
155        Message::System {
156            content: content.to_owned(),
157        }
158    }
159}
160
161#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
162pub struct ToolResultContent {
163    text: String,
164}
165impl TryFrom<message::ToolResultContent> for ToolResultContent {
166    type Error = MessageError;
167    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
168        let message::ToolResultContent::Text(Text { text }) = value else {
169            return Err(MessageError::ConversionError(
170                "Non-text tool results not supported".into(),
171            ));
172        };
173
174        Ok(Self { text })
175    }
176}
177
178impl TryFrom<message::Message> for Message {
179    type Error = MessageError;
180
181    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
182        Ok(match message {
183            message::Message::User { content } => {
184                let mut texts = Vec::new();
185                let mut images = Vec::new();
186
187                for uc in content.into_iter() {
188                    match uc {
189                        message::UserContent::Text(message::Text { text }) => texts.push(text),
190                        message::UserContent::Image(img) => images.push(img.data),
191                        message::UserContent::ToolResult(result) => {
192                            let content = result
193                                .content
194                                .into_iter()
195                                .map(ToolResultContent::try_from)
196                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
197
198                            let content = OneOrMany::many(content).map_err(|x| {
199                                MessageError::ConversionError(format!(
200                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
201                                ))
202                            })?;
203
204                            return Ok(Message::ToolResult {
205                                tool_call_id: result.id,
206                                content: content.first().text,
207                            });
208                        }
209                        _ => {}
210                    }
211                }
212
213                let collapsed_content = texts.join(" ");
214
215                Message::User {
216                    content: collapsed_content,
217                }
218            }
219            message::Message::Assistant { content, .. } => {
220                let mut texts = Vec::new();
221                let mut tool_calls = Vec::new();
222
223                for ac in content.into_iter() {
224                    match ac {
225                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
226                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
227                        _ => {}
228                    }
229                }
230
231                let collapsed_content = texts.join(" ");
232
233                Message::Assistant {
234                    content: Some(collapsed_content),
235                    tool_calls,
236                }
237            }
238        })
239    }
240}
241
242impl From<message::ToolResult> for Message {
243    fn from(tool_result: message::ToolResult) -> Self {
244        let content = match tool_result.content.first() {
245            message::ToolResultContent::Text(text) => text.text,
246            message::ToolResultContent::Image(_) => String::from("[Image]"),
247        };
248
249        Message::ToolResult {
250            tool_call_id: tool_result.id,
251            content,
252        }
253    }
254}
255
256#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
257#[serde(rename_all = "camelCase")]
258pub struct ToolCall {
259    pub function: CallFunction,
260    pub id: String,
261    pub index: usize,
262    #[serde(default)]
263    pub r#type: ToolType,
264}
265
266impl From<message::ToolCall> for ToolCall {
267    fn from(tool_call: message::ToolCall) -> Self {
268        Self {
269            id: tool_call.id,
270            index: 0,
271            r#type: ToolType::Function,
272            function: CallFunction {
273                name: tool_call.function.name,
274                arguments: tool_call.function.arguments,
275            },
276        }
277    }
278}
279
280#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
281#[serde(rename_all = "lowercase")]
282pub enum ToolType {
283    #[default]
284    Function,
285}
286
287#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
288pub struct CallFunction {
289    pub name: String,
290    #[serde(with = "json_utils::stringified_json")]
291    pub arguments: serde_json::Value,
292}
293
294#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
295#[serde(rename_all = "lowercase")]
296pub enum Role {
297    System,
298    User,
299    Assistant,
300}
301
302#[derive(Debug, Serialize, Deserialize)]
303#[serde(rename_all = "camelCase")]
304pub struct Choice {
305    #[serde(rename = "finish_reason")]
306    pub finish_reason: String,
307    pub index: i64,
308    pub message: Message,
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize)]
312#[serde(rename_all = "camelCase")]
313pub struct Usage {
314    #[serde(rename = "completion_tokens")]
315    pub completion_tokens: i64,
316    #[serde(rename = "prompt_tokens")]
317    pub prompt_tokens: i64,
318    #[serde(rename = "total_tokens")]
319    pub total_tokens: i64,
320}
321
322impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
323    type Error = CompletionError;
324
325    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
326        let choice = response.choices.first().ok_or_else(|| {
327            CompletionError::ResponseError("Response contained no choices".to_owned())
328        })?;
329
330        match &choice.message {
331            Message::Assistant {
332                tool_calls,
333                content,
334            } => {
335                if !tool_calls.is_empty() {
336                    let tool_result = tool_calls
337                        .iter()
338                        .map(|call| {
339                            completion::AssistantContent::tool_call(
340                                &call.function.name,
341                                &call.function.name,
342                                call.function.arguments.clone(),
343                            )
344                        })
345                        .collect::<Vec<_>>();
346
347                    let choice = OneOrMany::many(tool_result).map_err(|_| {
348                        CompletionError::ResponseError(
349                            "Response contained no message or tool call (empty)".to_owned(),
350                        )
351                    })?;
352                    let usage = completion::Usage {
353                        input_tokens: response.usage.prompt_tokens as u64,
354                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
355                            as u64,
356                        total_tokens: response.usage.total_tokens as u64,
357                    };
358                    tracing::debug!("response choices: {:?}: ", choice);
359                    Ok(completion::CompletionResponse {
360                        choice,
361                        usage,
362                        raw_response: response,
363                    })
364                } else {
365                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
366                        text: content.clone().unwrap_or_else(|| "".to_owned()),
367                    }));
368                    let usage = completion::Usage {
369                        input_tokens: response.usage.prompt_tokens as u64,
370                        output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
371                            as u64,
372                        total_tokens: response.usage.total_tokens as u64,
373                    };
374                    Ok(completion::CompletionResponse {
375                        choice,
376                        usage,
377                        raw_response: response,
378                    })
379                }
380            }
381            // Message::Assistant { tool_calls } => {}
382            _ => Err(CompletionError::ResponseError(
383                "Chat response does not include an assistant message".into(),
384            )),
385        }
386    }
387}
388
389#[derive(Clone)]
390pub struct CompletionModel {
391    client: Client,
392    pub model: String,
393}
394
395// 函数定义
396#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
397#[serde(rename_all = "camelCase")]
398pub struct CustomFunctionDefinition {
399    #[serde(rename = "type")]
400    pub type_field: String,
401    pub function: Function,
402}
403
404#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
405#[serde(rename_all = "camelCase")]
406pub struct Function {
407    pub name: String,
408    pub description: String,
409    pub parameters: serde_json::Value,
410}
411
412impl CompletionModel {
413    pub fn new(client: Client, model: &str) -> Self {
414        Self {
415            client,
416            model: model.to_string(),
417        }
418    }
419
420    fn create_completion_request(
421        &self,
422        completion_request: CompletionRequest,
423    ) -> Result<Value, CompletionError> {
424        // 构建消息顺序(上下文、聊天历史、提示)
425        let mut partial_history = vec![];
426        if let Some(docs) = completion_request.normalized_documents() {
427            partial_history.push(docs);
428        }
429        partial_history.extend(completion_request.chat_history);
430
431        // 使用前言初始化完整历史(如果不存在则为空)
432        let mut full_history: Vec<Message> = completion_request
433            .preamble
434            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
435
436        // 转换并扩展其余历史
437        full_history.extend(
438            partial_history
439                .into_iter()
440                .map(message::Message::try_into)
441                .collect::<Result<Vec<Message>, _>>()?,
442        );
443
444        let request = if completion_request.tools.is_empty() {
445            json!({
446                "model": self.model,
447                "messages": full_history,
448                "temperature": completion_request.temperature,
449            })
450        } else {
451            // tools
452            let tools = completion_request
453                .tools
454                .into_iter()
455                .map(|item| {
456                    let custom_function = Function {
457                        name: item.name,
458                        description: item.description,
459                        parameters: item.parameters,
460                    };
461                    CustomFunctionDefinition {
462                        type_field: "function".to_string(),
463                        function: custom_function,
464                    }
465                })
466                .collect::<Vec<_>>();
467
468            tracing::debug!("tools: {:?}", tools);
469
470            json!({
471                "model": self.model,
472                "messages": full_history,
473                "temperature": completion_request.temperature,
474                "tools": tools,
475                "tool_choice": "auto",
476            })
477        };
478
479        let request = if let Some(params) = completion_request.additional_params {
480            json_utils::merge(request, params)
481        } else {
482            request
483        };
484
485        Ok(request)
486    }
487}
488
489/// 同步请求
490impl completion::CompletionModel for CompletionModel {
491    type Response = CompletionResponse;
492    type StreamingResponse = openai::StreamingCompletionResponse;
493    type Client = Client;
494
495    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
496        Self::new(client.clone(), &model.into())
497    }
498
499    async fn completion(
500        &self,
501        completion_request: CompletionRequest,
502    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
503        tracing::debug!("create_completion_request========");
504        let request = self.create_completion_request(completion_request)?;
505
506        tracing::debug!(
507            "request: \r\n {}",
508            serde_json::to_string_pretty(&request).unwrap()
509        );
510
511        let response = self
512            .client
513            .post("/chat/completions")
514            .json(&request)
515            .send()
516            .await
517            .map_err(|e| http_client::Error::Instance(e.into()))?;
518
519        if response.status().is_success() {
520            let data: Value = response.json().await.expect("api error");
521            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
522            let data: ApiResponse<CompletionResponse> =
523                serde_json::from_value(data).expect("deserialize completion response");
524            match data {
525                ApiResponse::Ok(response) => {
526                    tracing::info!(target: "rig",
527                        "bigmodel completion token usage: {:?}",
528                        response.usage
529                    );
530                    response.try_into()
531                }
532                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
533            }
534        } else {
535            Err(CompletionError::ProviderError(
536                response
537                    .text()
538                    .await
539                    .map_err(|e| http_client::Error::Instance(e.into()))?,
540            ))
541        }
542    }
543
544    async fn stream(
545        &self,
546        request: CompletionRequest,
547    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
548        let preamble = request.preamble.clone();
549
550        let mut request = self.create_completion_request(request)?;
551
552        request = merge(request, json!({"stream": true}));
553
554        let body = serde_json::to_vec(&request)?;
555
556        let url = format!(
557            "{}/{}",
558            self.client.base_url,
559            "/chat/completions".trim_start_matches('/')
560        );
561
562        let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
563        for (header, value) in &self.client.default_headers {
564            builder = builder.header(header, value);
565        }
566
567        let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
568            .map_err(http::Error::from)
569            .map_err(rig::http_client::Error::from)?;
570
571        builder = builder.header(header::AUTHORIZATION, auth_header);
572        builder = builder.header("Content-Type", "application/json");
573
574        let req = builder
575            .body(body)
576            .map_err(|e| CompletionError::HttpError(e.into()))?;
577
578        let span = if tracing::Span::current().is_disabled() {
579            info_span!(
580                target: "rig::completions",
581                "chat_streaming",
582                gen_ai.operation.name = "chat_streaming",
583                gen_ai.provider.name = "galadriel",
584                gen_ai.request.model = self.model,
585                gen_ai.system_instructions = preamble,
586                gen_ai.response.id = tracing::field::Empty,
587                gen_ai.response.model = tracing::field::Empty,
588                gen_ai.usage.output_tokens = tracing::field::Empty,
589                gen_ai.usage.input_tokens = tracing::field::Empty,
590                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
591                gen_ai.output.messages = tracing::field::Empty,
592            )
593        } else {
594            tracing::Span::current()
595        };
596
597        send_compatible_streaming_request(self.client.http_client.clone(), req)
598            .instrument(span)
599            .await
600    }
601}