rig_extra/extra_providers/
bigmodel.rs

1use rig::agent::AgentBuilder;
2use rig::completion::{CompletionError, CompletionRequest};
3use rig::extractor::ExtractorBuilder;
4use rig::message::{MessageError, Text};
5use rig::providers::openai;
6use rig::{OneOrMany, completion, message};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
16// ================================================================
17// Main BIGMODEL Client
18// ================================================================
/// Base URL used by [`Client::new`] when no custom endpoint is supplied.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23    base_url: String,
24    http_client: reqwest::Client,
25}
26
27impl Client {
28    pub fn new(api_key: &str) -> Self {
29        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30    }
31
32    pub fn from_url(api_key: &str, base_url: &str) -> Self {
33        Self {
34            base_url: base_url.to_string(),
35            http_client: reqwest::Client::builder()
36                .default_headers({
37                    let mut headers = reqwest::header::HeaderMap::new();
38                    headers.insert(
39                        "Authorization",
40                        format!("Bearer {}", api_key)
41                            .parse()
42                            .expect("Bearer token should parse"),
43                    );
44                    headers
45                })
46                .build()
47                .expect("bigmodel reqwest client should build"),
48        }
49    }
50
51    pub fn from_env() -> Self {
52        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
53        Self::new(&api_key)
54    }
55
56    fn post(&self, path: &str) -> reqwest::RequestBuilder {
57        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
58        self.http_client.post(url)
59    }
60
61    pub fn completion_model(&self, model: &str) -> CompletionModel {
62        CompletionModel::new(self.clone(), model)
63    }
64
65    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
66        AgentBuilder::new(self.completion_model(model))
67    }
68
69    /// Create an extractor builder with the given completion model.
70    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
71        &self,
72        model: &str,
73    ) -> ExtractorBuilder<T, CompletionModel> {
74        ExtractorBuilder::new(self.completion_model(model))
75    }
76}
77
78#[derive(Debug, Deserialize)]
79struct ApiErrorResponse {
80    message: String,
81}
82
83#[derive(Debug, Deserialize)]
84#[serde(untagged)]
85enum ApiResponse<T> {
86    Ok(T),
87    Err(ApiErrorResponse),
88}
89
90// ================================================================
91// Bigmodel Completion API
92// ================================================================
/// Model identifier `glm-4-flash`, usable wherever a model name is expected.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
94
95#[derive(Debug, Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub struct CompletionResponse {
98    pub choices: Vec<Choice>,
99    pub created: i64,
100    pub id: String,
101    pub model: String,
102    #[serde(rename = "request_id")]
103    pub request_id: String,
104    pub usage: Usage,
105}
106
107#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
108#[serde(tag = "role", rename_all = "lowercase")]
109pub enum Message {
110    User {
111        content: String,
112    },
113    Assistant {
114        content: Option<String>,
115        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
116        tool_calls: Vec<ToolCall>,
117    },
118    System {
119        content: String,
120    },
121    #[serde(rename = "tool")]
122    ToolResult {
123        tool_call_id: String,
124        content: String,
125    },
126}
127
128impl Message {
129    pub fn system(content: &str) -> Message {
130        Message::System {
131            content: content.to_owned(),
132        }
133    }
134}
135
136#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
137pub struct ToolResultContent {
138    text: String,
139}
140impl TryFrom<message::ToolResultContent> for ToolResultContent {
141    type Error = MessageError;
142    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
143        let message::ToolResultContent::Text(Text { text }) = value else {
144            return Err(MessageError::ConversionError(
145                "Non-text tool results not supported".into(),
146            ));
147        };
148
149        Ok(Self { text })
150    }
151}
152
153impl TryFrom<message::Message> for Message {
154    type Error = MessageError;
155
156    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
157        Ok(match message {
158            message::Message::User { content } => {
159                let mut texts = Vec::new();
160                let mut images = Vec::new();
161
162                for uc in content.into_iter() {
163                    match uc {
164                        message::UserContent::Text(message::Text { text }) => texts.push(text),
165                        message::UserContent::Image(img) => images.push(img.data),
166                        message::UserContent::ToolResult(result) => {
167                            let content = result
168                                .content
169                                .into_iter()
170                                .map(ToolResultContent::try_from)
171                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
172
173                            let content = OneOrMany::many(content).map_err(|x| {
174                                MessageError::ConversionError(format!(
175                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
176                                ))
177                            })?;
178
179                            return Ok(Message::ToolResult {
180                                tool_call_id: result.id,
181                                content: content.first().text,
182                            });
183                        }
184                        _ => {}
185                    }
186                }
187
188                let collapsed_content = texts.join(" ");
189
190                Message::User {
191                    content: collapsed_content,
192                }
193            }
194            message::Message::Assistant { content } => {
195                let mut texts = Vec::new();
196                let mut tool_calls = Vec::new();
197
198                for ac in content.into_iter() {
199                    match ac {
200                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
201                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
202                    }
203                }
204
205                let collapsed_content = texts.join(" ");
206
207                Message::Assistant {
208                    content: Some(collapsed_content),
209                    tool_calls,
210                }
211            }
212        })
213    }
214}
215
216impl From<message::ToolResult> for Message {
217    fn from(tool_result: message::ToolResult) -> Self {
218        let content = match tool_result.content.first() {
219            message::ToolResultContent::Text(text) => text.text,
220            message::ToolResultContent::Image(_) => String::from("[Image]"),
221        };
222
223        Message::ToolResult {
224            tool_call_id: tool_result.id,
225            content,
226        }
227    }
228}
229
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231#[serde(rename_all = "camelCase")]
232pub struct ToolCall {
233    pub function: CallFunction,
234    pub id: String,
235    pub index: usize,
236    #[serde(default)]
237    pub r#type: ToolType,
238}
239
240impl From<message::ToolCall> for ToolCall {
241    fn from(tool_call: message::ToolCall) -> Self {
242        Self {
243            id: tool_call.id,
244            // TODO: update index when we have it
245            index: 0,
246            r#type: ToolType::Function,
247            function: CallFunction {
248                name: tool_call.function.name,
249                arguments: tool_call.function.arguments,
250            },
251        }
252    }
253}
254
255#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
256#[serde(rename_all = "lowercase")]
257pub enum ToolType {
258    #[default]
259    Function,
260}
261
262#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
263pub struct CallFunction {
264    pub name: String,
265    #[serde(with = "json_utils::stringified_json")]
266    pub arguments: serde_json::Value,
267}
268
269#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
270#[serde(rename_all = "lowercase")]
271pub enum Role {
272    System,
273    User,
274    Assistant,
275}
276
277#[derive(Debug, Serialize, Deserialize)]
278#[serde(rename_all = "camelCase")]
279pub struct Choice {
280    #[serde(rename = "finish_reason")]
281    pub finish_reason: String,
282    pub index: i64,
283    pub message: Message,
284}
285
286#[derive(Debug, Clone, Serialize, Deserialize)]
287#[serde(rename_all = "camelCase")]
288pub struct Usage {
289    #[serde(rename = "completion_tokens")]
290    pub completion_tokens: i64,
291    #[serde(rename = "prompt_tokens")]
292    pub prompt_tokens: i64,
293    #[serde(rename = "total_tokens")]
294    pub total_tokens: i64,
295}
296
297impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
298    type Error = CompletionError;
299
300    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
301        let choice = response.choices.first().ok_or_else(|| {
302            CompletionError::ResponseError("Response contained no choices".to_owned())
303        })?;
304
305        match &choice.message {
306            Message::Assistant {
307                tool_calls,
308                content,
309            } => {
310                if !tool_calls.is_empty() {
311                    let tool_result = tool_calls
312                        .iter()
313                        .map(|call| {
314                            completion::AssistantContent::tool_call(
315                                &call.function.name,
316                                &call.function.name,
317                                call.function.arguments.clone(),
318                            )
319                        })
320                        .collect::<Vec<_>>();
321
322                    let choice = OneOrMany::many(tool_result).map_err(|_| {
323                        CompletionError::ResponseError(
324                            "Response contained no message or tool call (empty)".to_owned(),
325                        )
326                    })?;
327                    tracing::debug!("response choices: {:?}: ", choice);
328                    Ok(completion::CompletionResponse {
329                        choice,
330                        raw_response: response,
331                    })
332                } else {
333                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
334                        text: content.clone().unwrap_or_else(|| "".to_owned()),
335                    }));
336                    Ok(completion::CompletionResponse {
337                        choice,
338                        raw_response: response,
339                    })
340                }
341            }
342            // Message::Assistant { tool_calls } => {}
343            _ => Err(CompletionError::ResponseError(
344                "Chat response does not include an assistant message".into(),
345            )),
346        }
347    }
348}
349
350#[derive(Clone)]
351pub struct CompletionModel {
352    client: Client,
353    pub model: String,
354}
355
356// 函数定义
357#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
358#[serde(rename_all = "camelCase")]
359pub struct CustomFunctionDefinition {
360    #[serde(rename = "type")]
361    pub type_field: String,
362    pub function: Function,
363}
364
365#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
366#[serde(rename_all = "camelCase")]
367pub struct Function {
368    pub name: String,
369    pub description: String,
370    pub parameters: serde_json::Value,
371}
372
373impl CompletionModel {
374    pub fn new(client: Client, model: &str) -> Self {
375        Self {
376            client,
377            model: model.to_string(),
378        }
379    }
380
381    fn create_completion_request(
382        &self,
383        completion_request: CompletionRequest,
384    ) -> Result<Value, CompletionError> {
385        // Build up the order of messages (context, chat_history, prompt)
386        let mut partial_history = vec![];
387        if let Some(docs) = completion_request.normalized_documents() {
388            partial_history.push(docs);
389        }
390        partial_history.extend(completion_request.chat_history);
391
392        // Initialize full history with preamble (or empty if non-existent)
393        let mut full_history: Vec<Message> = completion_request
394            .preamble
395            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
396
397        // Convert and extend the rest of the history
398        full_history.extend(
399            partial_history
400                .into_iter()
401                .map(message::Message::try_into)
402                .collect::<Result<Vec<Message>, _>>()?,
403        );
404
405        let request = if completion_request.tools.is_empty() {
406            json!({
407                "model": self.model,
408                "messages": full_history,
409                "temperature": completion_request.temperature,
410            })
411        } else {
412            // tools
413            let tools = completion_request
414                .tools
415                .into_iter()
416                .map(|item| {
417                    let custom_function = Function {
418                        name: item.name,
419                        description: item.description,
420                        parameters: item.parameters,
421                    };
422                    CustomFunctionDefinition {
423                        type_field: "function".to_string(),
424                        function: custom_function,
425                    }
426                })
427                .collect::<Vec<_>>();
428
429            tracing::debug!("tools: {:?}", tools);
430
431            json!({
432                "model": self.model,
433                "messages": full_history,
434                "temperature": completion_request.temperature,
435                "tools": tools,
436                "tool_choice": "auto",
437            })
438        };
439
440        let request = if let Some(params) = completion_request.additional_params {
441            json_utils::merge(request, params)
442        } else {
443            request
444        };
445
446        Ok(request)
447    }
448}
449
450/// 同步请求
451impl completion::CompletionModel for CompletionModel {
452    type Response = CompletionResponse;
453    type StreamingResponse = openai::StreamingCompletionResponse;
454
455    async fn completion(
456        &self,
457        completion_request: CompletionRequest,
458    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
459        tracing::debug!("create_completion_request========");
460        let request = self.create_completion_request(completion_request)?;
461
462        tracing::debug!(
463            "request: \r\n {}",
464            serde_json::to_string_pretty(&request).unwrap()
465        );
466
467        let response = self
468            .client
469            .post("/chat/completions")
470            .json(&request)
471            .send()
472            .await?;
473
474        if response.status().is_success() {
475            let data: Value = response.json().await.expect("api error");
476            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
477            let data: ApiResponse<CompletionResponse> =
478                serde_json::from_value(data).expect("deserialize completion response");
479            match data {
480                ApiResponse::Ok(response) => {
481                    tracing::info!(target: "rig",
482                        "bigmodel completion token usage: {:?}",
483                        response.usage
484                    );
485                    response.try_into()
486                }
487                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
488            }
489        } else {
490            Err(CompletionError::ProviderError(response.text().await?))
491        }
492    }
493
494    async fn stream(
495        &self,
496        request: CompletionRequest,
497    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
498        let mut request = self.create_completion_request(request)?;
499
500        request = json_utils::merge(request, json!({"stream": true}));
501
502        let builder = self.client.post("/chat/completions").json(&request);
503
504        send_compatible_streaming_request(builder).await
505    }
506}