//! Bigmodel (Zhipu AI) provider extension for rig.
//!
//! File: rig_extra/extra_providers/bigmodel.rs

1use crate::json_utils;
2use crate::json_utils::merge;
3use bytes::Bytes;
4use rig::agent::Text;
5use rig::client::{
6    BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
7};
8use rig::completion::{CompletionError, CompletionRequest};
9use rig::http_client::HttpClientExt;
10use rig::message::MessageError;
11use rig::providers::openai;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use rig::{OneOrMany, client, completion, http_client, message};
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
19// use rig::providers::openai::{Message as OpenAIMessage};
20
/// Base URL of Bigmodel's (Zhipu AI) OpenAI-compatible REST API.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";

/// Marker type implementing rig's provider extension for Bigmodel.
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelExt;

/// Stateless builder marker used by rig's `ProviderBuilder` machinery.
#[derive(Debug, Default, Clone, Copy)]

pub struct BigmodelBuilder;

// Bigmodel authenticates with a bearer token, so rig's `BearerAuth` is reused.
type BigmodelApiKey = BearerAuth;
31
/// A Bigmodel chat-completion model: a client handle plus a model name.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier, e.g. `BIGMODEL_GLM_4_7_FLASH`.
    pub model: String,
}
37
38impl<T> CompletionModel<T> {
39    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
40        Self {
41            client,
42            model: model.into(),
43        }
44    }
45
46    fn create_completion_request(
47        &self,
48        completion_request: CompletionRequest,
49    ) -> Result<Value, CompletionError> {
50        // 构建消息顺序(上下文、聊天历史、提示)
51        let mut partial_history = vec![];
52        if let Some(docs) = completion_request.normalized_documents() {
53            partial_history.push(docs);
54        }
55        partial_history.extend(completion_request.chat_history);
56
57        // 使用前言初始化完整历史(如果不存在则为空)
58        let mut full_history: Vec<Message> = completion_request
59            .preamble
60            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
61
62        // 转换并扩展其余历史
63        full_history.extend(
64            partial_history
65                .into_iter()
66                .map(message::Message::try_into)
67                .collect::<Result<Vec<Message>, _>>()?,
68        );
69
70        let request = if completion_request.tools.is_empty() {
71            json!({
72                "model": self.model,
73                "messages": full_history,
74                "temperature": completion_request.temperature,
75            })
76        } else {
77            // tools
78            let tools = completion_request
79                .tools
80                .into_iter()
81                .map(|item| {
82                    let custom_function = Function {
83                        name: item.name,
84                        description: item.description,
85                        parameters: item.parameters,
86                    };
87                    CustomFunctionDefinition {
88                        type_field: "function".to_string(),
89                        function: custom_function,
90                    }
91                })
92                .collect::<Vec<_>>();
93
94            tracing::debug!("tools: {:?}", tools);
95
96            json!({
97                "model": self.model,
98                "messages": full_history,
99                "temperature": completion_request.temperature,
100                "tools": tools,
101                "tool_choice": "auto",
102            })
103        };
104
105        let request = if let Some(params) = completion_request.additional_params {
106            json_utils::merge(request, params)
107        } else {
108            request
109        };
110
111        Ok(request)
112    }
113}
114
impl Provider for BigmodelExt {
    // NOTE(review): "api/tags" is Ollama's model-listing endpoint and is
    // unlikely to exist on open.bigmodel.cn — verify this path against the
    // Bigmodel API before relying on client verification.
    const VERIFY_PATH: &'static str = "api/tags";
    type Builder = BigmodelBuilder;
}
119
/// Capabilities this provider supports: only chat completion is wired up;
/// embeddings, transcription, and model listing are not implemented.
impl<H> Capabilities<H> for BigmodelExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;

    // #[cfg(feature = "image")]
    // type ImageGeneration = Nothing;
    //
    // #[cfg(feature = "audio")]
    // type AudioGeneration = Nothing;
}
133
// Use rig's default debug formatting for the extension.
impl DebugExt for BigmodelExt {}
135
impl ProviderBuilder for BigmodelBuilder {
    type Extension<H>
        = BigmodelExt
    where
        H: HttpClientExt;
    type ApiKey = BigmodelApiKey;
    const BASE_URL: &'static str = BIGMODEL_API_BASE_URL;

    /// Produces the (stateless) extension marker. The builder carries no
    /// configuration, so construction cannot fail.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(BigmodelExt)
    }
}
153
/// Convenience alias for a Bigmodel client over any HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<BigmodelExt, H>;
/// Convenience alias for the Bigmodel client builder.
// NOTE(review): the key parameter here is `String`, while `ProviderBuilder`
// declares `ApiKey = BigmodelApiKey` (BearerAuth) — confirm rig accepts both
// or align the two.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<BigmodelBuilder, String, H>;
156
/// Rust's orphan rule forbids implementing a foreign trait for a foreign
/// type, which is why a `ProviderClient` impl like the one sketched below
/// cannot live in this crate as written:
// impl ProviderClient for Client{
//     type Input = String;
//
//     fn from_env() -> Self {
//         let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_API_KEY not set");
//         Self::new(&api_key).unwrap()
//     }
//
//     fn from_val(input: Self::Input) -> Self {
//         Self::new(&input).unwrap()
//     }
// }
170
// ---------- API Error and Response Structures ----------

/// Error payload returned by the API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
176
/// Untagged wrapper: the body is either the expected payload or an API
/// error. serde tries `Ok` first and falls back to `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
183
184// ================================================================
185// Bigmodel Completion API
186// ================================================================
187
188pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
189
190#[derive(Debug, Deserialize, Serialize)]
191#[serde(rename_all = "camelCase")]
192pub struct CompletionResponse {
193    pub choices: Vec<Choice>,
194    pub created: i64,
195    pub id: String,
196    pub model: String,
197    #[serde(rename = "request_id")]
198    pub request_id: String,
199    pub usage: Option<Usage>,
200}
201
/// A chat message in Bigmodel's wire format, internally tagged by the
/// `role` field ("user" / "assistant" / "system" / "tool").
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        content: Option<String>,
        // `null_or_vec` tolerates the API sending `null` instead of an
        // empty array.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    /// Result of a tool invocation; serialized with role "tool".
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
222
223impl Message {
224    pub fn system(content: &str) -> Message {
225        Message::System {
226            content: content.to_owned(),
227        }
228    }
229}
230
/// Text-only tool-result payload in the provider's format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
235impl TryFrom<message::ToolResultContent> for ToolResultContent {
236    type Error = MessageError;
237    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
238        let message::ToolResultContent::Text(Text { text }) = value else {
239            return Err(MessageError::ConversionError(
240                "Non-text tool results not supported".into(),
241            ));
242        };
243
244        Ok(Self { text })
245    }
246}
247
248impl TryFrom<message::Message> for Message {
249    type Error = MessageError;
250
251    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
252        Ok(match message {
253            message::Message::User { content } => {
254                let mut texts = Vec::new();
255                let mut images = Vec::new();
256
257                for uc in content.into_iter() {
258                    match uc {
259                        message::UserContent::Text(message::Text { text }) => texts.push(text),
260                        message::UserContent::Image(img) => images.push(img.data),
261                        message::UserContent::ToolResult(result) => {
262                            let content = result
263                                .content
264                                .into_iter()
265                                .map(ToolResultContent::try_from)
266                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
267
268                            let content = OneOrMany::many(content).map_err(|x| {
269                                MessageError::ConversionError(format!(
270                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
271                                ))
272                            })?;
273
274                            return Ok(Message::ToolResult {
275                                tool_call_id: result.id,
276                                content: content.first().text,
277                            });
278                        }
279                        _ => {}
280                    }
281                }
282
283                let collapsed_content = texts.join(" ");
284
285                Message::User {
286                    content: collapsed_content,
287                }
288            }
289            message::Message::Assistant { content, .. } => {
290                let mut texts = Vec::new();
291                let mut tool_calls = Vec::new();
292
293                for ac in content.into_iter() {
294                    match ac {
295                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
296                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
297                        _ => {}
298                    }
299                }
300
301                let collapsed_content = texts.join(" ");
302
303                Message::Assistant {
304                    content: Some(collapsed_content),
305                    tool_calls,
306                }
307            }
308        })
309    }
310}
311
312impl From<message::ToolResult> for Message {
313    fn from(tool_result: message::ToolResult) -> Self {
314        let content = match tool_result.content.first() {
315            message::ToolResultContent::Text(text) => text.text,
316            message::ToolResultContent::Image(_) => String::from("[Image]"),
317        };
318
319        Message::ToolResult {
320            tool_call_id: tool_result.id,
321            content,
322        }
323    }
324}
325
/// A tool call emitted by the assistant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    pub id: String,
    /// Position of this call within the message's tool-call list.
    pub index: usize,
    // Defaults to `function`, the only supported tool type.
    #[serde(default)]
    pub r#type: ToolType,
}
335
336impl From<message::ToolCall> for ToolCall {
337    fn from(tool_call: message::ToolCall) -> Self {
338        Self {
339            id: tool_call.id,
340            index: 0,
341            r#type: ToolType::Function,
342            function: CallFunction {
343                name: tool_call.function.name,
344                arguments: tool_call.function.arguments,
345            },
346        }
347    }
348}
349
/// Tool kind; only function tools are defined.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Function name plus its arguments. On the wire `arguments` is a
/// JSON-encoded string, hence the `stringified_json` (de)serializer.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
363
/// Chat roles.
// NOTE(review): not referenced anywhere in this file — possibly kept for
// API parity; confirm external usage before removing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
371
372#[derive(Debug, Serialize, Deserialize)]
373#[serde(rename_all = "camelCase")]
374pub struct Choice {
375    #[serde(rename = "finish_reason")]
376    pub finish_reason: String,
377    pub index: i64,
378    pub message: Message,
379}
380
381#[derive(Debug, Clone, Serialize, Deserialize)]
382#[serde(rename_all = "camelCase")]
383pub struct Usage {
384    #[serde(rename = "completion_tokens")]
385    pub completion_tokens: i64,
386    #[serde(rename = "prompt_tokens")]
387    pub prompt_tokens: i64,
388    #[serde(rename = "total_tokens")]
389    pub total_tokens: i64,
390    #[serde(skip_serializing_if = "Option::is_none")]
391    pub prompt_tokens_details: Option<PromptTokensDetails>,
392}
393
/// Breakdown details for prompt-token accounting.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Cached tokens from prompt caching
    #[serde(default)]
    pub cached_tokens: usize,
}
400
401impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
402    type Error = CompletionError;
403
404    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
405        let choice = response.choices.first().ok_or_else(|| {
406            CompletionError::ResponseError("Response contained no choices".to_owned())
407        })?;
408
409        match &choice.message {
410            Message::Assistant {
411                tool_calls,
412                content,
413            } => {
414                if !tool_calls.is_empty() {
415                    let tool_result = tool_calls
416                        .iter()
417                        .map(|call| {
418                            completion::AssistantContent::tool_call(
419                                &call.function.name,
420                                &call.function.name,
421                                call.function.arguments.clone(),
422                            )
423                        })
424                        .collect::<Vec<_>>();
425
426                    let choice = OneOrMany::many(tool_result).map_err(|_| {
427                        CompletionError::ResponseError(
428                            "Response contained no message or tool call (empty)".to_owned(),
429                        )
430                    })?;
431                    let usage = response
432                        .usage
433                        .as_ref()
434                        .map(|usage| completion::Usage {
435                            input_tokens: usage.prompt_tokens as u64,
436                            output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
437                            total_tokens: usage.total_tokens as u64,
438                            cached_input_tokens: usage
439                                .prompt_tokens_details
440                                .as_ref()
441                                .map(|d| d.cached_tokens as u64)
442                                .unwrap_or(0),
443                        })
444                        .unwrap_or_default();
445                    tracing::debug!("response choices: {:?}: ", choice);
446                    Ok(completion::CompletionResponse {
447                        choice,
448                        usage,
449                        raw_response: response,
450                        message_id: None,
451                    })
452                } else {
453                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
454                        text: content.clone().unwrap_or_else(|| "".to_owned()),
455                    }));
456                    let usage = response
457                        .usage
458                        .as_ref()
459                        .map(|usage| completion::Usage {
460                            input_tokens: usage.prompt_tokens as u64,
461                            output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
462                            total_tokens: usage.total_tokens as u64,
463                            cached_input_tokens: usage
464                                .prompt_tokens_details
465                                .as_ref()
466                                .map(|d| d.cached_tokens as u64)
467                                .unwrap_or(0),
468                        })
469                        .unwrap_or_default();
470                    Ok(completion::CompletionResponse {
471                        choice,
472                        usage,
473                        raw_response: response,
474                        message_id: None,
475                    })
476                }
477            }
478            // Message::Assistant { tool_calls } => {}
479            _ => Err(CompletionError::ResponseError(
480                "Chat response does not include an assistant message".into(),
481            )),
482        }
483    }
484}
485
// Tool (function) definitions sent in the request's `tools` array.

/// Wrapper pairing the tool `type` tag (always "function" here) with its
/// function definition.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}

/// Description of a callable function exposed to the model; `parameters`
/// carries the JSON-schema of its arguments.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
502
503/// 同步请求
504impl<T> completion::CompletionModel for CompletionModel<T>
505where
506    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
507{
508    type Response = CompletionResponse;
509    type StreamingResponse = openai::StreamingCompletionResponse;
510    type Client = Client<T>;
511
512    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
513        Self::new(client.clone(), model.into())
514    }
515
516    async fn completion(
517        &self,
518        completion_request: CompletionRequest,
519    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
520        let span = if tracing::Span::current().is_disabled() {
521            info_span!(
522                target: "rig::completions",
523                "chat",
524                gen_ai.operation.name = "chat",
525                gen_ai.provider.name = "groq",
526                gen_ai.request.model = self.model,
527                gen_ai.system_instructions = tracing::field::Empty,
528                gen_ai.response.id = tracing::field::Empty,
529                gen_ai.response.model = tracing::field::Empty,
530                gen_ai.usage.output_tokens = tracing::field::Empty,
531                gen_ai.usage.input_tokens = tracing::field::Empty,
532            )
533        } else {
534            tracing::Span::current()
535        };
536
537        span.record("gen_ai.system_instructions", &completion_request.preamble);
538
539        let request = self.create_completion_request(completion_request)?;
540
541        if tracing::enabled!(tracing::Level::TRACE) {
542            tracing::trace!(target: "rig::completions",
543                "Groq completion request: {}",
544                serde_json::to_string_pretty(&request)?
545            );
546        }
547
548        let body = serde_json::to_vec(&request)?;
549        let req = self
550            .client
551            .post("/chat/completions")?
552            .body(body)
553            .map_err(|e| http_client::Error::Instance(e.into()))?;
554
555        let async_block = async move {
556            let response = self.client.send::<_, Bytes>(req).await?;
557            let status = response.status();
558            let response_body = response.into_body().into_future().await?.to_vec();
559
560            let tt = response_body.clone();
561            let response = serde_json::from_slice::<serde_json::Value>(&tt)?;
562            println!(
563                "response:\r\n {}",
564                serde_json::to_string_pretty(&response).unwrap()
565            );
566
567            if status.is_success() {
568                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
569                    ApiResponse::Ok(response) => {
570                        let span = tracing::Span::current();
571                        span.record("gen_ai.response.id", response.id.clone());
572                        span.record("gen_ai.response.model_name", response.model.clone());
573                        if let Some(ref usage) = response.usage {
574                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
575                            span.record(
576                                "gen_ai.usage.output_tokens",
577                                usage.total_tokens - usage.prompt_tokens,
578                            );
579                        }
580
581                        if tracing::enabled!(tracing::Level::TRACE) {
582                            tracing::trace!(target: "rig::completions",
583                                "Groq completion response: {}",
584                                serde_json::to_string_pretty(&response)?
585                            );
586                        }
587
588                        response.try_into()
589                    }
590                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
591                }
592            } else {
593                Err(CompletionError::ProviderError(
594                    String::from_utf8_lossy(&response_body).to_string(),
595                ))
596            }
597        };
598
599        tracing::Instrument::instrument(async_block, span).await
600    }
601
602    async fn stream(
603        &self,
604        request: CompletionRequest,
605    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
606        let preamble = request.preamble.clone();
607
608        let mut request = self.create_completion_request(request)?;
609
610        request = merge(request, json!({"stream": true}));
611
612        let body = serde_json::to_vec(&request)?;
613
614        let req = self
615            .client
616            .post("/chat/completions")?
617            .body(body)
618            .map_err(|e| http_client::Error::Instance(e.into()))?;
619
620        let span = if tracing::Span::current().is_disabled() {
621            info_span!(
622                target: "rig::completions",
623                "chat_streaming",
624                gen_ai.operation.name = "chat_streaming",
625                gen_ai.provider.name = "galadriel",
626                gen_ai.request.model = self.model,
627                gen_ai.system_instructions = preamble,
628                gen_ai.response.id = tracing::field::Empty,
629                gen_ai.response.model = tracing::field::Empty,
630                gen_ai.usage.output_tokens = tracing::field::Empty,
631                gen_ai.usage.input_tokens = tracing::field::Empty,
632                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
633                gen_ai.output.messages = tracing::field::Empty,
634            )
635        } else {
636            tracing::Span::current()
637        };
638
639        send_compatible_streaming_request(self.client.clone(), req)
640            .instrument(span)
641            .await
642    }
643}