//! BigModel (Zhipu AI, open.bigmodel.cn) provider extension for `rig`.
//! File: rig_extra/extra_providers/bigmodel.rs
1use crate::json_utils;
2use crate::json_utils::merge;
3use bytes::Bytes;
4use rig::agent::Text;
5use rig::client::{
6    BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
7};
8use rig::completion::{CompletionError, CompletionRequest};
9use rig::http_client::HttpClientExt;
10use rig::message::MessageError;
11use rig::providers::openai;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use rig::{OneOrMany, client, completion, http_client, message};
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
19// use rig::providers::openai::{Message as OpenAIMessage};
20
/// Base URL of the BigModel (Zhipu AI) OpenAI-compatible REST API.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
/// Stateless marker type identifying the BigModel provider extension for `rig`.
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelExt;
25
/// Builder marker for the BigModel provider (see the `ProviderBuilder` impl).
// The stray blank line between the attribute and the struct has been removed;
// the derive still applied across it, but the gap read like a detached attribute.
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelBuilder;
29
/// BigModel authenticates with a standard `Authorization: Bearer <key>` header.
type BigmodelApiKey = BearerAuth;
31
/// A BigModel chat-completion model: a client handle plus a model name.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    // HTTP client used to reach the BigModel API.
    client: Client<T>,
    /// Model identifier sent in the request body, e.g. `"glm-4.7-flash"`.
    pub model: String,
}
37
38impl<T> CompletionModel<T> {
39    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
40        Self {
41            client,
42            model: model.into(),
43        }
44    }
45
46    fn create_completion_request(
47        &self,
48        completion_request: CompletionRequest,
49    ) -> Result<Value, CompletionError> {
50        // 构建消息顺序(上下文、聊天历史、提示)
51        let mut partial_history = vec![];
52        if let Some(docs) = completion_request.normalized_documents() {
53            partial_history.push(docs);
54        }
55        partial_history.extend(completion_request.chat_history);
56
57        // 使用前言初始化完整历史(如果不存在则为空)
58        let mut full_history: Vec<Message> = completion_request
59            .preamble
60            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
61
62        // 转换并扩展其余历史
63        full_history.extend(
64            partial_history
65                .into_iter()
66                .map(message::Message::try_into)
67                .collect::<Result<Vec<Message>, _>>()?,
68        );
69
70        let request = if completion_request.tools.is_empty() {
71            json!({
72                "model": self.model,
73                "messages": full_history,
74                "temperature": completion_request.temperature,
75            })
76        } else {
77            // tools
78            let tools = completion_request
79                .tools
80                .into_iter()
81                .map(|item| {
82                    let custom_function = Function {
83                        name: item.name,
84                        description: item.description,
85                        parameters: item.parameters,
86                    };
87                    CustomFunctionDefinition {
88                        type_field: "function".to_string(),
89                        function: custom_function,
90                    }
91                })
92                .collect::<Vec<_>>();
93
94            tracing::debug!("tools: {:?}", tools);
95
96            json!({
97                "model": self.model,
98                "messages": full_history,
99                "temperature": completion_request.temperature,
100                "tools": tools,
101                "tool_choice": "auto",
102            })
103        };
104
105        let request = if let Some(params) = completion_request.additional_params {
106            json_utils::merge(request, params)
107        } else {
108            request
109        };
110
111        Ok(request)
112    }
113}
114
/// Registers BigModel as a `rig` provider.
impl Provider for BigmodelExt {
    // NOTE(review): "api/tags" is Ollama's model-listing endpoint; it is
    // unlikely to exist on the BigModel (open.bigmodel.cn) API, so `verify()`
    // will probably 404 — confirm the correct health/verify path against the
    // BigModel REST documentation.
    const VERIFY_PATH: &'static str = "api/tags";
    type Builder = BigmodelBuilder;

    // The extension is stateless, so building ignores the client builder.
    fn build<H>(
        _builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> rig::http_client::Result<Self> {
        Ok(Self)
    }
}
129
/// Declares which capabilities this provider supports: chat completions
/// only — embeddings, transcription and model listing are not implemented.
impl<H> Capabilities<H> for BigmodelExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;

    // #[cfg(feature = "image")]
    // type ImageGeneration = Nothing;
    //
    // #[cfg(feature = "audio")]
    // type AudioGeneration = Nothing;
}
143
// Default debug formatting is sufficient for this stateless marker type.
impl DebugExt for BigmodelExt {}
145
/// Wires the builder marker to the provider: bearer-token auth against the
/// BigModel base URL.
impl ProviderBuilder for BigmodelBuilder {
    type Output = BigmodelExt;
    type ApiKey = BigmodelApiKey;
    const BASE_URL: &'static str = BIGMODEL_API_BASE_URL;
}
151
/// Convenience aliases specialising the generic `rig` client for BigModel.
pub type Client<H = reqwest::Client> = client::Client<BigmodelExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<BigmodelBuilder, String, H>;
154
// NOTE: a direct `ProviderClient` impl for `Client` is blocked by Rust's
// orphan rule; the attempt is kept below for reference. (Changed from `///`
// to `//` so this stray doc comment no longer attaches to a later item.)
156// impl ProviderClient for Client{
157//     type Input = String;
158//
159//     fn from_env() -> Self {
160//         let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_API_KEY not set");
161//         Self::new(&api_key).unwrap()
162//     }
163//
164//     fn from_val(input: Self::Input) -> Self {
165//         Self::new(&input).unwrap()
166//     }
167// }
168
169// ---------- API Error and Response Structures ----------
/// Error payload returned by the BigModel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
174
/// Untagged wrapper: a response body is either the expected payload or an
/// API error; serde tries the variants in declaration order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
181
182// ================================================================
183// Bigmodel Completion API
184// ================================================================
185
/// Model identifier for the `glm-4.7-flash` chat model.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
187
188#[derive(Debug, Deserialize, Serialize)]
189#[serde(rename_all = "camelCase")]
190pub struct CompletionResponse {
191    pub choices: Vec<Choice>,
192    pub created: i64,
193    pub id: String,
194    pub model: String,
195    #[serde(rename = "request_id")]
196    pub request_id: String,
197    pub usage: Option<Usage>,
198}
199
/// A chat message in BigModel's OpenAI-style wire format, tagged by `role`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        /// May be absent when the assistant only issues tool calls.
        content: Option<String>,
        /// The API may send `null` instead of an empty array; `null_or_vec`
        /// normalizes both to an empty `Vec`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    /// Result of a tool invocation; serialized with role `"tool"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
220
221impl Message {
222    pub fn system(content: &str) -> Message {
223        Message::System {
224            content: content.to_owned(),
225        }
226    }
227}
228
/// Text-only payload extracted from a rig tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
233impl TryFrom<message::ToolResultContent> for ToolResultContent {
234    type Error = MessageError;
235    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
236        let message::ToolResultContent::Text(Text { text }) = value else {
237            return Err(MessageError::ConversionError(
238                "Non-text tool results not supported".into(),
239            ));
240        };
241
242        Ok(Self { text })
243    }
244}
245
246impl TryFrom<message::Message> for Message {
247    type Error = MessageError;
248
249    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
250        Ok(match message {
251            message::Message::User { content } => {
252                let mut texts = Vec::new();
253                let mut images = Vec::new();
254
255                for uc in content.into_iter() {
256                    match uc {
257                        message::UserContent::Text(message::Text { text }) => texts.push(text),
258                        message::UserContent::Image(img) => images.push(img.data),
259                        message::UserContent::ToolResult(result) => {
260                            let content = result
261                                .content
262                                .into_iter()
263                                .map(ToolResultContent::try_from)
264                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
265
266                            let content = OneOrMany::many(content).map_err(|x| {
267                                MessageError::ConversionError(format!(
268                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
269                                ))
270                            })?;
271
272                            return Ok(Message::ToolResult {
273                                tool_call_id: result.id,
274                                content: content.first().text,
275                            });
276                        }
277                        _ => {}
278                    }
279                }
280
281                let collapsed_content = texts.join(" ");
282
283                Message::User {
284                    content: collapsed_content,
285                }
286            }
287            message::Message::Assistant { content, .. } => {
288                let mut texts = Vec::new();
289                let mut tool_calls = Vec::new();
290
291                for ac in content.into_iter() {
292                    match ac {
293                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
294                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
295                        _ => {}
296                    }
297                }
298
299                let collapsed_content = texts.join(" ");
300
301                Message::Assistant {
302                    content: Some(collapsed_content),
303                    tool_calls,
304                }
305            }
306        })
307    }
308}
309
310impl From<message::ToolResult> for Message {
311    fn from(tool_result: message::ToolResult) -> Self {
312        let content = match tool_result.content.first() {
313            message::ToolResultContent::Text(text) => text.text,
314            message::ToolResultContent::Image(_) => String::from("[Image]"),
315        };
316
317        Message::ToolResult {
318            tool_call_id: tool_result.id,
319            content,
320        }
321    }
322}
323
/// A tool (function) call emitted by the assistant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    pub id: String,
    /// Position of this call within the assistant turn.
    pub index: usize,
    /// Defaults to `function` when the API omits the field.
    #[serde(default)]
    pub r#type: ToolType,
}
333
334impl From<message::ToolCall> for ToolCall {
335    fn from(tool_call: message::ToolCall) -> Self {
336        Self {
337            id: tool_call.id,
338            index: 0,
339            r#type: ToolType::Function,
340            function: CallFunction {
341                name: tool_call.function.name,
342                arguments: tool_call.function.arguments,
343            },
344        }
345    }
346}
347
/// Tool kind; the API currently defines only `function`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
354
/// The name/arguments pair inside a tool call. Arguments travel as a
/// JSON-encoded string on the wire but are exposed as a parsed `Value`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
361
/// Chat roles. NOTE(review): appears unreferenced in this module — the role
/// is carried by the `Message` enum's serde tag; confirm usage elsewhere
/// before removing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
369
370#[derive(Debug, Serialize, Deserialize)]
371#[serde(rename_all = "camelCase")]
372pub struct Choice {
373    #[serde(rename = "finish_reason")]
374    pub finish_reason: String,
375    pub index: i64,
376    pub message: Message,
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380#[serde(rename_all = "camelCase")]
381pub struct Usage {
382    #[serde(rename = "completion_tokens")]
383    pub completion_tokens: i64,
384    #[serde(rename = "prompt_tokens")]
385    pub prompt_tokens: i64,
386    #[serde(rename = "total_tokens")]
387    pub total_tokens: i64,
388    #[serde(skip_serializing_if = "Option::is_none")]
389    pub prompt_tokens_details: Option<PromptTokensDetails>,
390}
391
/// Breakdown of prompt tokens reported by the API.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Cached tokens from prompt caching
    #[serde(default)]
    pub cached_tokens: usize,
}
398
399impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
400    type Error = CompletionError;
401
402    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
403        let choice = response.choices.first().ok_or_else(|| {
404            CompletionError::ResponseError("Response contained no choices".to_owned())
405        })?;
406
407        match &choice.message {
408            Message::Assistant {
409                tool_calls,
410                content,
411            } => {
412                if !tool_calls.is_empty() {
413                    let tool_result = tool_calls
414                        .iter()
415                        .map(|call| {
416                            completion::AssistantContent::tool_call(
417                                &call.function.name,
418                                &call.function.name,
419                                call.function.arguments.clone(),
420                            )
421                        })
422                        .collect::<Vec<_>>();
423
424                    let choice = OneOrMany::many(tool_result).map_err(|_| {
425                        CompletionError::ResponseError(
426                            "Response contained no message or tool call (empty)".to_owned(),
427                        )
428                    })?;
429                    let usage = response
430                        .usage
431                        .as_ref()
432                        .map(|usage| completion::Usage {
433                            input_tokens: usage.prompt_tokens as u64,
434                            output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
435                            total_tokens: usage.total_tokens as u64,
436                            cached_input_tokens: usage
437                                .prompt_tokens_details
438                                .as_ref()
439                                .map(|d| d.cached_tokens as u64)
440                                .unwrap_or(0),
441                        })
442                        .unwrap_or_default();
443                    tracing::debug!("response choices: {:?}: ", choice);
444                    Ok(completion::CompletionResponse {
445                        choice,
446                        usage,
447                        raw_response: response,
448                        message_id: None,
449                    })
450                } else {
451                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
452                        text: content.clone().unwrap_or_else(|| "".to_owned()),
453                    }));
454                    let usage = response
455                        .usage
456                        .as_ref()
457                        .map(|usage| completion::Usage {
458                            input_tokens: usage.prompt_tokens as u64,
459                            output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
460                            total_tokens: usage.total_tokens as u64,
461                            cached_input_tokens: usage
462                                .prompt_tokens_details
463                                .as_ref()
464                                .map(|d| d.cached_tokens as u64)
465                                .unwrap_or(0),
466                        })
467                        .unwrap_or_default();
468                    Ok(completion::CompletionResponse {
469                        choice,
470                        usage,
471                        raw_response: response,
472                        message_id: None,
473                    })
474                }
475            }
476            // Message::Assistant { tool_calls } => {}
477            _ => Err(CompletionError::ResponseError(
478                "Chat response does not include an assistant message".into(),
479            )),
480        }
481    }
482}
483
// ---------- Tool / function definition payloads sent to the API ----------
/// OpenAI-style tool envelope: `{"type": "function", "function": { … }}`.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Always set to `"function"` by `create_completion_request`.
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}
492
/// A callable function exposed to the model: name, description, and a
/// JSON-schema `parameters` object.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
501/// 同步请求
502impl<T> completion::CompletionModel for CompletionModel<T>
503where
504    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
505{
506    type Response = CompletionResponse;
507    type StreamingResponse = openai::StreamingCompletionResponse;
508    type Client = Client<T>;
509
510    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511        Self::new(client.clone(), model.into())
512    }
513
514    async fn completion(
515        &self,
516        completion_request: CompletionRequest,
517    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
518        let span = if tracing::Span::current().is_disabled() {
519            info_span!(
520                target: "rig::completions",
521                "chat",
522                gen_ai.operation.name = "chat",
523                gen_ai.provider.name = "groq",
524                gen_ai.request.model = self.model,
525                gen_ai.system_instructions = tracing::field::Empty,
526                gen_ai.response.id = tracing::field::Empty,
527                gen_ai.response.model = tracing::field::Empty,
528                gen_ai.usage.output_tokens = tracing::field::Empty,
529                gen_ai.usage.input_tokens = tracing::field::Empty,
530            )
531        } else {
532            tracing::Span::current()
533        };
534
535        span.record("gen_ai.system_instructions", &completion_request.preamble);
536
537        let request = self.create_completion_request(completion_request)?;
538
539        if tracing::enabled!(tracing::Level::TRACE) {
540            tracing::trace!(target: "rig::completions",
541                "Groq completion request: {}",
542                serde_json::to_string_pretty(&request)?
543            );
544        }
545
546        let body = serde_json::to_vec(&request)?;
547        let req = self
548            .client
549            .post("/chat/completions")?
550            .body(body)
551            .map_err(|e| http_client::Error::Instance(e.into()))?;
552
553        let async_block = async move {
554            let response = self.client.send::<_, Bytes>(req).await?;
555            let status = response.status();
556            let response_body = response.into_body().into_future().await?.to_vec();
557
558            let tt = response_body.clone();
559            let response = serde_json::from_slice::<serde_json::Value>(&tt)?;
560            println!(
561                "response:\r\n {}",
562                serde_json::to_string_pretty(&response).unwrap()
563            );
564
565            if status.is_success() {
566                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
567                    ApiResponse::Ok(response) => {
568                        let span = tracing::Span::current();
569                        span.record("gen_ai.response.id", response.id.clone());
570                        span.record("gen_ai.response.model_name", response.model.clone());
571                        if let Some(ref usage) = response.usage {
572                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
573                            span.record(
574                                "gen_ai.usage.output_tokens",
575                                usage.total_tokens - usage.prompt_tokens,
576                            );
577                        }
578
579                        if tracing::enabled!(tracing::Level::TRACE) {
580                            tracing::trace!(target: "rig::completions",
581                                "Groq completion response: {}",
582                                serde_json::to_string_pretty(&response)?
583                            );
584                        }
585
586                        response.try_into()
587                    }
588                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
589                }
590            } else {
591                Err(CompletionError::ProviderError(
592                    String::from_utf8_lossy(&response_body).to_string(),
593                ))
594            }
595        };
596
597        tracing::Instrument::instrument(async_block, span).await
598    }
599
600    async fn stream(
601        &self,
602        request: CompletionRequest,
603    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
604        let preamble = request.preamble.clone();
605
606        let mut request = self.create_completion_request(request)?;
607
608        request = merge(request, json!({"stream": true}));
609
610        let body = serde_json::to_vec(&request)?;
611
612        let req = self
613            .client
614            .post("/chat/completions")?
615            .body(body)
616            .map_err(|e| http_client::Error::Instance(e.into()))?;
617
618        let span = if tracing::Span::current().is_disabled() {
619            info_span!(
620                target: "rig::completions",
621                "chat_streaming",
622                gen_ai.operation.name = "chat_streaming",
623                gen_ai.provider.name = "galadriel",
624                gen_ai.request.model = self.model,
625                gen_ai.system_instructions = preamble,
626                gen_ai.response.id = tracing::field::Empty,
627                gen_ai.response.model = tracing::field::Empty,
628                gen_ai.usage.output_tokens = tracing::field::Empty,
629                gen_ai.usage.input_tokens = tracing::field::Empty,
630                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
631                gen_ai.output.messages = tracing::field::Empty,
632            )
633        } else {
634            tracing::Span::current()
635        };
636
637        send_compatible_streaming_request(self.client.clone(), req)
638            .instrument(span)
639            .await
640    }
641}