// rig_extra/extra_providers/bigmodel.rs

use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient};
use rig::completion::{CompletionError, CompletionRequest};
use rig::extractor::ExtractorBuilder;
use rig::message::{MessageError, Text};
use rig::providers::openai;
use rig::providers::openai::send_compatible_streaming_request;
use rig::streaming::StreamingCompletionResponse;
use rig::{OneOrMany, completion, message};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};

use crate::json_utils;

16// ================================================================
17// Main BIGMODEL Client
18// ================================================================
19const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23    base_url: String,
24    http_client: reqwest::Client,
25}
26
27impl Client {
28    pub fn new(api_key: &str) -> Self {
29        Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30    }
31
32    pub fn from_url(api_key: &str, base_url: &str) -> Self {
33        Self {
34            base_url: base_url.to_string(),
35            http_client: reqwest::Client::builder()
36                .default_headers({
37                    let mut headers = reqwest::header::HeaderMap::new();
38                    headers.insert(
39                        "Authorization",
40                        format!("Bearer {api_key}")
41                            .parse()
42                            .expect("Bearer token should parse"),
43                    );
44                    headers
45                })
46                .build()
47                .expect("bigmodel reqwest client should build"),
48        }
49    }
50
51
52
53    fn post(&self, path: &str) -> reqwest::RequestBuilder {
54        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55        self.http_client.post(url)
56    }
57
58    pub fn completion_model(&self, model: &str) -> CompletionModel {
59        CompletionModel::new(self.clone(), model)
60    }
61
62    // pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
63    //     AgentBuilder::new(self.completion_model(model))
64    // }
65
66    /// Create an extractor builder with the given completion model.
67    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
68        &self,
69        model: &str,
70    ) -> ExtractorBuilder<T, CompletionModel> {
71        ExtractorBuilder::new(self.completion_model(model))
72    }
73}
74
75impl ProviderClient for Client {
76    fn from_env() -> Self
77    where
78        Self: Sized
79    {
80        let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
81        Self::new(&api_key)
82    }
83}
84
85impl AsTranscription for Client {}
86
87impl AsEmbeddings for Client {}
88
89
90
91impl CompletionClient for Client {
92    type CompletionModel = CompletionModel;
93
94    fn completion_model(&self, model: &str) -> Self::CompletionModel {
95        CompletionModel::new(self.clone(), model)
96    }
97}
98
99#[derive(Debug, Deserialize)]
100struct ApiErrorResponse {
101    message: String,
102}
103
104#[derive(Debug, Deserialize)]
105#[serde(untagged)]
106enum ApiResponse<T> {
107    Ok(T),
108    Err(ApiErrorResponse),
109}
110
111// ================================================================
112// Bigmodel Completion API
113// ================================================================
114pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
115
116#[derive(Debug, Deserialize)]
117#[serde(rename_all = "camelCase")]
118pub struct CompletionResponse {
119    pub choices: Vec<Choice>,
120    pub created: i64,
121    pub id: String,
122    pub model: String,
123    #[serde(rename = "request_id")]
124    pub request_id: String,
125    pub usage: Usage,
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129#[serde(tag = "role", rename_all = "lowercase")]
130pub enum Message {
131    User {
132        content: String,
133    },
134    Assistant {
135        content: Option<String>,
136        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
137        tool_calls: Vec<ToolCall>,
138    },
139    System {
140        content: String,
141    },
142    #[serde(rename = "tool")]
143    ToolResult {
144        tool_call_id: String,
145        content: String,
146    },
147}
148
149impl Message {
150    pub fn system(content: &str) -> Message {
151        Message::System {
152            content: content.to_owned(),
153        }
154    }
155}
156
157#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
158pub struct ToolResultContent {
159    text: String,
160}
161impl TryFrom<message::ToolResultContent> for ToolResultContent {
162    type Error = MessageError;
163    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
164        let message::ToolResultContent::Text(Text { text }) = value else {
165            return Err(MessageError::ConversionError(
166                "Non-text tool results not supported".into(),
167            ));
168        };
169
170        Ok(Self { text })
171    }
172}
173
174impl TryFrom<message::Message> for Message {
175    type Error = MessageError;
176
177    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
178        Ok(match message {
179            message::Message::User { content } => {
180                let mut texts = Vec::new();
181                let mut images = Vec::new();
182
183                for uc in content.into_iter() {
184                    match uc {
185                        message::UserContent::Text(message::Text { text }) => texts.push(text),
186                        message::UserContent::Image(img) => images.push(img.data),
187                        message::UserContent::ToolResult(result) => {
188                            let content = result
189                                .content
190                                .into_iter()
191                                .map(ToolResultContent::try_from)
192                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
193
194                            let content = OneOrMany::many(content).map_err(|x| {
195                                MessageError::ConversionError(format!(
196                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
197                                ))
198                            })?;
199
200                            return Ok(Message::ToolResult {
201                                tool_call_id: result.id,
202                                content: content.first().text,
203                            });
204                        }
205                        _ => {}
206                    }
207                }
208
209                let collapsed_content = texts.join(" ");
210
211                Message::User {
212                    content: collapsed_content,
213                }
214            }
215            message::Message::Assistant { content, .. } => {
216                let mut texts = Vec::new();
217                let mut tool_calls = Vec::new();
218
219                for ac in content.into_iter() {
220                    match ac {
221                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
222                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
223                    }
224                }
225
226                let collapsed_content = texts.join(" ");
227
228                Message::Assistant {
229                    content: Some(collapsed_content),
230                    tool_calls,
231                }
232            }
233        })
234    }
235}
236
237impl From<message::ToolResult> for Message {
238    fn from(tool_result: message::ToolResult) -> Self {
239        let content = match tool_result.content.first() {
240            message::ToolResultContent::Text(text) => text.text,
241            message::ToolResultContent::Image(_) => String::from("[Image]"),
242        };
243
244        Message::ToolResult {
245            tool_call_id: tool_result.id,
246            content,
247        }
248    }
249}
250
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct ToolCall {
254    pub function: CallFunction,
255    pub id: String,
256    pub index: usize,
257    #[serde(default)]
258    pub r#type: ToolType,
259}
260
261impl From<message::ToolCall> for ToolCall {
262    fn from(tool_call: message::ToolCall) -> Self {
263        Self {
264            id: tool_call.id,
265            // TODO: update index when we have it
266            index: 0,
267            r#type: ToolType::Function,
268            function: CallFunction {
269                name: tool_call.function.name,
270                arguments: tool_call.function.arguments,
271            },
272        }
273    }
274}
275
276#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
277#[serde(rename_all = "lowercase")]
278pub enum ToolType {
279    #[default]
280    Function,
281}
282
283#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
284pub struct CallFunction {
285    pub name: String,
286    #[serde(with = "json_utils::stringified_json")]
287    pub arguments: serde_json::Value,
288}
289
290#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
291#[serde(rename_all = "lowercase")]
292pub enum Role {
293    System,
294    User,
295    Assistant,
296}
297
298#[derive(Debug, Serialize, Deserialize)]
299#[serde(rename_all = "camelCase")]
300pub struct Choice {
301    #[serde(rename = "finish_reason")]
302    pub finish_reason: String,
303    pub index: i64,
304    pub message: Message,
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(rename_all = "camelCase")]
309pub struct Usage {
310    #[serde(rename = "completion_tokens")]
311    pub completion_tokens: i64,
312    #[serde(rename = "prompt_tokens")]
313    pub prompt_tokens: i64,
314    #[serde(rename = "total_tokens")]
315    pub total_tokens: i64,
316}
317
318impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
319    type Error = CompletionError;
320
321    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
322        let choice = response.choices.first().ok_or_else(|| {
323            CompletionError::ResponseError("Response contained no choices".to_owned())
324        })?;
325
326        match &choice.message {
327            Message::Assistant {
328                tool_calls,
329                content,
330            } => {
331                if !tool_calls.is_empty() {
332                    let tool_result = tool_calls
333                        .iter()
334                        .map(|call| {
335                            completion::AssistantContent::tool_call(
336                                &call.function.name,
337                                &call.function.name,
338                                call.function.arguments.clone(),
339                            )
340                        })
341                        .collect::<Vec<_>>();
342
343                    let choice = OneOrMany::many(tool_result).map_err(|_| {
344                        CompletionError::ResponseError(
345                            "Response contained no message or tool call (empty)".to_owned(),
346                        )
347                    })?;
348                    // let usage = completion::Usage {
349                    //     input_tokens: response.usage.prompt_tokens as u64,
350                    //     output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
351                    //         as u64,
352                    //     total_tokens: response.usage.total_tokens as u64,
353                    // };
354                    tracing::debug!("response choices: {:?}: ", choice);
355                    Ok(completion::CompletionResponse {
356                        choice,
357                        // usage,
358                        raw_response: response,
359                    })
360                } else {
361                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
362                        text: content.clone().unwrap_or_else(|| "".to_owned()),
363                    }));
364                    // let usage = completion::Usage {
365                    //     input_tokens: response.usage.prompt_tokens as u64,
366                    //     output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
367                    //         as u64,
368                    //     total_tokens: response.usage.total_tokens as u64,
369                    // };
370                    Ok(completion::CompletionResponse {
371                        choice,
372                        // usage,
373                        raw_response: response,
374                    })
375                }
376            }
377            // Message::Assistant { tool_calls } => {}
378            _ => Err(CompletionError::ResponseError(
379                "Chat response does not include an assistant message".into(),
380            )),
381        }
382    }
383}
384
385#[derive(Clone)]
386pub struct CompletionModel {
387    client: Client,
388    pub model: String,
389}
390
// impl completion::CompletionModelDyn for CompletionModel {
//     async fn completion(&self, request: CompletionRequest) -> Result<completion::CompletionResponse<()>, CompletionError> {
//
//     }
//
//     async fn stream(&self, request: CompletionRequest) -> Result<StreamingCompletionResponse<()>, CompletionError> {
//
//     }
//
//     fn completion_request(&self, prompt: completion::Message) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
//         todo!()
//     }
// }

405// 函数定义
406#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
407#[serde(rename_all = "camelCase")]
408pub struct CustomFunctionDefinition {
409    #[serde(rename = "type")]
410    pub type_field: String,
411    pub function: Function,
412}
413
414#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
415#[serde(rename_all = "camelCase")]
416pub struct Function {
417    pub name: String,
418    pub description: String,
419    pub parameters: serde_json::Value,
420}
421
422impl CompletionModel {
423    pub fn new(client: Client, model: &str) -> Self {
424        Self {
425            client,
426            model: model.to_string(),
427        }
428    }
429
430    fn create_completion_request(
431        &self,
432        completion_request: CompletionRequest,
433    ) -> Result<Value, CompletionError> {
434        // Build up the order of messages (context, chat_history, prompt)
435        let mut partial_history = vec![];
436        if let Some(docs) = completion_request.normalized_documents() {
437            partial_history.push(docs);
438        }
439        partial_history.extend(completion_request.chat_history);
440
441        // Initialize full history with preamble (or empty if non-existent)
442        let mut full_history: Vec<Message> = completion_request
443            .preamble
444            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
445
446        // Convert and extend the rest of the history
447        full_history.extend(
448            partial_history
449                .into_iter()
450                .map(message::Message::try_into)
451                .collect::<Result<Vec<Message>, _>>()?,
452        );
453
454        let request = if completion_request.tools.is_empty() {
455            json!({
456                "model": self.model,
457                "messages": full_history,
458                "temperature": completion_request.temperature,
459            })
460        } else {
461            // tools
462            let tools = completion_request
463                .tools
464                .into_iter()
465                .map(|item| {
466                    let custom_function = Function {
467                        name: item.name,
468                        description: item.description,
469                        parameters: item.parameters,
470                    };
471                    CustomFunctionDefinition {
472                        type_field: "function".to_string(),
473                        function: custom_function,
474                    }
475                })
476                .collect::<Vec<_>>();
477
478            tracing::debug!("tools: {:?}", tools);
479
480            json!({
481                "model": self.model,
482                "messages": full_history,
483                "temperature": completion_request.temperature,
484                "tools": tools,
485                "tool_choice": "auto",
486            })
487        };
488
489        let request = if let Some(params) = completion_request.additional_params {
490            json_utils::merge(request, params)
491        } else {
492            request
493        };
494
495        Ok(request)
496    }
497}
498
499/// 同步请求
500impl completion::CompletionModel for CompletionModel {
501    type Response = CompletionResponse;
502    type StreamingResponse = openai::StreamingCompletionResponse;
503
504    async fn completion(
505        &self,
506        completion_request: CompletionRequest,
507    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
508        tracing::debug!("create_completion_request========");
509        let request = self.create_completion_request(completion_request)?;
510
511        tracing::debug!(
512            "request: \r\n {}",
513            serde_json::to_string_pretty(&request).unwrap()
514        );
515
516        let response = self
517            .client
518            .post("/chat/completions")
519            .json(&request)
520            .send()
521            .await?;
522
523        if response.status().is_success() {
524            let data: Value = response.json().await.expect("api error");
525            tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
526            let data: ApiResponse<CompletionResponse> =
527                serde_json::from_value(data).expect("deserialize completion response");
528            match data {
529                ApiResponse::Ok(response) => {
530                    tracing::info!(target: "rig",
531                        "bigmodel completion token usage: {:?}",
532                        response.usage
533                    );
534                    response.try_into()
535                }
536                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
537            }
538        } else {
539            Err(CompletionError::ProviderError(response.text().await?))
540        }
541    }
542
543    async fn stream(
544        &self,
545        request: CompletionRequest,
546    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
547        let mut request = self.create_completion_request(request)?;
548
549        request = json_utils::merge(request, json!({"stream": true}));
550
551        let builder = self.client.post("/chat/completions").json(&request);
552
553        send_compatible_streaming_request(builder).await
554    }
555}
556
557