//! rig_extra/extra_providers/bigmodel.rs
//!
//! BigModel (open.bigmodel.cn / Zhipu AI) chat-completions provider for rig,
//! speaking the OpenAI-compatible API at `https://open.bigmodel.cn/api/paas/v4/`.
1use crate::json_utils;
2use crate::json_utils::merge;
3use bytes::Bytes;
4use rig::agent::Text;
5use rig::client::{
6    BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
7};
8use rig::completion::{CompletionError, CompletionRequest};
9use rig::http_client::HttpClientExt;
10use rig::message::MessageError;
11use rig::providers::openai;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use rig::{OneOrMany, client, completion, http_client, message};
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
/// Base URL of BigModel's OpenAI-compatible REST API.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
/// Marker type identifying the BigModel provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelExt;

/// Builder marker used to construct a BigModel [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelBuilder;

/// BigModel authenticates with a bearer-token API key.
type BigmodelApiKey = BearerAuth;
31
/// Chat-completion model handle: an HTTP client plus the model id to call.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier, e.g. [`BIGMODEL_GLM_4_7_FLASH`].
    pub model: String,
}
37
38impl<T> CompletionModel<T> {
39    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
40        Self {
41            client,
42            model: model.into(),
43        }
44    }
45
46    fn create_completion_request(
47        &self,
48        completion_request: CompletionRequest,
49    ) -> Result<Value, CompletionError> {
50        // 构建消息顺序(上下文、聊天历史、提示)
51        let mut partial_history = vec![];
52        if let Some(docs) = completion_request.normalized_documents() {
53            partial_history.push(docs);
54        }
55        partial_history.extend(completion_request.chat_history);
56
57        // 使用前言初始化完整历史(如果不存在则为空)
58        let mut full_history: Vec<Message> = completion_request
59            .preamble
60            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
61
62        // 转换并扩展其余历史
63        full_history.extend(
64            partial_history
65                .into_iter()
66                .map(message::Message::try_into)
67                .collect::<Result<Vec<Message>, _>>()?,
68        );
69
70        let request = if completion_request.tools.is_empty() {
71            json!({
72                "model": self.model,
73                "messages": full_history,
74                "temperature": completion_request.temperature,
75            })
76        } else {
77            // tools
78            let tools = completion_request
79                .tools
80                .into_iter()
81                .map(|item| {
82                    let custom_function = Function {
83                        name: item.name,
84                        description: item.description,
85                        parameters: item.parameters,
86                    };
87                    CustomFunctionDefinition {
88                        type_field: "function".to_string(),
89                        function: custom_function,
90                    }
91                })
92                .collect::<Vec<_>>();
93
94            tracing::debug!("tools: {:?}", tools);
95
96            json!({
97                "model": self.model,
98                "messages": full_history,
99                "temperature": completion_request.temperature,
100                "tools": tools,
101                "tool_choice": "auto",
102            })
103        };
104
105        let request = if let Some(params) = completion_request.additional_params {
106            json_utils::merge(request, params)
107        } else {
108            request
109        };
110
111        Ok(request)
112    }
113}
114
impl Provider for BigmodelExt {
    // NOTE(review): "api/tags" is an Ollama-style endpoint; it is unclear
    // whether open.bigmodel.cn serves it, so client verification may always
    // fail — confirm the correct health-check path for this API.
    const VERIFY_PATH: &'static str = "api/tags";
    type Builder = BigmodelBuilder;

    /// The extension carries no state, so building it ignores the client
    /// builder entirely and always succeeds.
    fn build<H>(
        _builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> rig::http_client::Result<Self> {
        Ok(Self)
    }
}
129
/// Capabilities this provider exposes: chat completion only; embeddings and
/// transcription are not supported.
impl<H> Capabilities<H> for BigmodelExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;

    // #[cfg(feature = "image")]
    // type ImageGeneration = Nothing;
    //
    // #[cfg(feature = "audio")]
    // type AudioGeneration = Nothing;
}

impl DebugExt for BigmodelExt {}
144
impl ProviderBuilder for BigmodelBuilder {
    type Output = BigmodelExt;
    type ApiKey = BigmodelApiKey;
    const BASE_URL: &'static str = BIGMODEL_API_BASE_URL;
}

/// BigModel client over a pluggable HTTP transport (reqwest by default).
pub type Client<H = reqwest::Client> = client::Client<BigmodelExt, H>;
/// Client builder alias; the API key is supplied as a plain `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<BigmodelBuilder, String, H>;
153
// NOTE: `ProviderClient` cannot be implemented for `Client` here because of
// Rust's orphan rule (per the original author's note: both the trait and the
// generic client type are foreign to this module). Kept for reference:
// impl ProviderClient for Client {
//     type Input = String;
//
//     fn from_env() -> Self {
//         let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_API_KEY not set");
//         Self::new(&api_key).unwrap()
//     }
//
//     fn from_val(input: Self::Input) -> Self {
//         Self::new(&input).unwrap()
//     }
// }
167
// ---------- API Error and Response Structures ----------

/// Error payload returned by the BigModel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a response body is either the expected payload or an
/// API error; serde tries the variants in declaration order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
180
// ================================================================
// Bigmodel Completion API
// ================================================================

/// Model id for GLM-4.7 Flash.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
186
/// Top-level chat-completion response payload.
///
/// NOTE(review): the container-level `camelCase` rename is inert for the
/// single-word fields and is overridden for `request_id`, so the effective
/// wire format is snake_case throughout.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    /// Creation time as a Unix timestamp.
    pub created: i64,
    pub id: String,
    pub model: String,
    #[serde(rename = "request_id")]
    pub request_id: String,
    /// May be absent on some responses.
    pub usage: Option<Usage>,
}
198
/// Chat message in BigModel's wire format, discriminated by the `role` tag.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        /// Absent when the assistant responds purely with tool calls.
        content: Option<String>,
        /// The API may send `null` instead of an empty array.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    /// Serialized with role `"tool"`; carries the result of a prior tool call.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
219
220impl Message {
221    pub fn system(content: &str) -> Message {
222        Message::System {
223            content: content.to_owned(),
224        }
225    }
226}
227
/// Text-only tool-result payload.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
impl TryFrom<message::ToolResultContent> for ToolResultContent {
    type Error = MessageError;
    /// Accepts only text results; image (or other) tool results are rejected
    /// with a conversion error.
    fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
        let message::ToolResultContent::Text(Text { text }) = value else {
            return Err(MessageError::ConversionError(
                "Non-text tool results not supported".into(),
            ));
        };

        Ok(Self { text })
    }
}
244
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    /// Converts a rig message into BigModel's wire format.
    ///
    /// NOTE(review): when a user message contains a tool result, this returns
    /// a `tool`-role message immediately, discarding any other content parts
    /// of that message, and only the FIRST converted tool-result text is
    /// forwarded even though all parts are converted. Image parts are
    /// collected but never emitted. Confirm these limitations are intentional.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => {
                let mut texts = Vec::new();
                let mut images = Vec::new();

                for uc in content.into_iter() {
                    match uc {
                        message::UserContent::Text(message::Text { text }) => texts.push(text),
                        // Collected but currently unused (see note above).
                        message::UserContent::Image(img) => images.push(img.data),
                        message::UserContent::ToolResult(result) => {
                            // Convert every part, failing on non-text parts.
                            let content = result
                                .content
                                .into_iter()
                                .map(ToolResultContent::try_from)
                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;

                            // `OneOrMany::many` errors on an empty list.
                            let content = OneOrMany::many(content).map_err(|x| {
                                MessageError::ConversionError(format!(
                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
                                ))
                            })?;

                            // Early return: only the first text reaches the API.
                            return Ok(Message::ToolResult {
                                tool_call_id: result.id,
                                content: content.first().text,
                            });
                        }
                        // Other user content kinds are silently ignored.
                        _ => {}
                    }
                }

                // Multiple text parts collapse into one space-joined string.
                let collapsed_content = texts.join(" ");

                Message::User {
                    content: collapsed_content,
                }
            }
            message::Message::Assistant { content, .. } => {
                let mut texts = Vec::new();
                let mut tool_calls = Vec::new();

                for ac in content.into_iter() {
                    match ac {
                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
                        // Other assistant content kinds are silently dropped.
                        _ => {}
                    }
                }

                let collapsed_content = texts.join(" ");

                Message::Assistant {
                    content: Some(collapsed_content),
                    tool_calls,
                }
            }
        })
    }
}
308
309impl From<message::ToolResult> for Message {
310    fn from(tool_result: message::ToolResult) -> Self {
311        let content = match tool_result.content.first() {
312            message::ToolResultContent::Text(text) => text.text,
313            message::ToolResultContent::Image(_) => String::from("[Image]"),
314        };
315
316        Message::ToolResult {
317            tool_call_id: tool_result.id,
318            content,
319        }
320    }
321}
322
/// A tool invocation requested by the assistant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    pub id: String,
    /// Position of this call within the assistant message.
    pub index: usize,
    /// Defaults to `function` when the API omits the field.
    #[serde(default)]
    pub r#type: ToolType,
}
332
333impl From<message::ToolCall> for ToolCall {
334    fn from(tool_call: message::ToolCall) -> Self {
335        Self {
336            id: tool_call.id,
337            index: 0,
338            r#type: ToolType::Function,
339            function: CallFunction {
340                name: tool_call.function.name,
341                arguments: tool_call.function.arguments,
342            },
343        }
344    }
345}
346
/// Tool kind; `function` is currently the only variant.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// A function name plus its arguments; the API transports arguments as a
/// JSON-encoded string, round-tripped through `stringified_json`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Chat roles. NOTE(review): appears unused in this file — `Message` carries
/// its own `role` tag; confirm whether this can be removed.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
368
/// One completion candidate from the response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    /// Presumably OpenAI-style values ("stop", "tool_calls", ...) — confirm
    /// against the BigModel API docs.
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    pub index: i64,
    pub message: Message,
}
377
378#[derive(Debug, Clone, Serialize, Deserialize)]
379#[serde(rename_all = "camelCase")]
380pub struct Usage {
381    #[serde(rename = "completion_tokens")]
382    pub completion_tokens: i64,
383    #[serde(rename = "prompt_tokens")]
384    pub prompt_tokens: i64,
385    #[serde(rename = "total_tokens")]
386    pub total_tokens: i64,
387    #[serde(skip_serializing_if = "Option::is_none")]
388    pub prompt_tokens_details: Option<PromptTokensDetails>,
389}
390
/// Breakdown of prompt tokens reported by the API.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Cached tokens from prompt caching
    #[serde(default)]
    pub cached_tokens: usize,
}
397
398impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
399    type Error = CompletionError;
400
401    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
402        let choice = response.choices.first().ok_or_else(|| {
403            CompletionError::ResponseError("Response contained no choices".to_owned())
404        })?;
405
406        match &choice.message {
407            Message::Assistant {
408                tool_calls,
409                content,
410            } => {
411                if !tool_calls.is_empty() {
412                    let tool_result = tool_calls
413                        .iter()
414                        .map(|call| {
415                            completion::AssistantContent::tool_call(
416                                &call.function.name,
417                                &call.function.name,
418                                call.function.arguments.clone(),
419                            )
420                        })
421                        .collect::<Vec<_>>();
422
423                    let choice = OneOrMany::many(tool_result).map_err(|_| {
424                        CompletionError::ResponseError(
425                            "Response contained no message or tool call (empty)".to_owned(),
426                        )
427                    })?;
428                    let usage = response
429                        .usage
430                        .as_ref()
431                        .map(|usage| completion::Usage {
432                            input_tokens: usage.prompt_tokens as u64,
433                            output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
434                            total_tokens: usage.total_tokens as u64,
435                            cached_input_tokens: usage
436                                .prompt_tokens_details
437                                .as_ref()
438                                .map(|d| d.cached_tokens as u64)
439                                .unwrap_or(0),
440                        })
441                        .unwrap_or_default();
442                    tracing::debug!("response choices: {:?}: ", choice);
443                    Ok(completion::CompletionResponse {
444                        choice,
445                        usage,
446                        raw_response: response,
447                    })
448                } else {
449                    let choice = OneOrMany::one(message::AssistantContent::Text(Text {
450                        text: content.clone().unwrap_or_else(|| "".to_owned()),
451                    }));
452                    let usage = response
453                        .usage
454                        .as_ref()
455                        .map(|usage| completion::Usage {
456                            input_tokens: usage.prompt_tokens as u64,
457                            output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
458                            total_tokens: usage.total_tokens as u64,
459                            cached_input_tokens: usage
460                                .prompt_tokens_details
461                                .as_ref()
462                                .map(|d| d.cached_tokens as u64)
463                                .unwrap_or(0),
464                        })
465                        .unwrap_or_default();
466                    Ok(completion::CompletionResponse {
467                        choice,
468                        usage,
469                        raw_response: response,
470                    })
471                }
472            }
473            // Message::Assistant { tool_calls } => {}
474            _ => Err(CompletionError::ResponseError(
475                "Chat response does not include an assistant message".into(),
476            )),
477        }
478    }
479}
480
// Tool/function definition as it appears in the request payload.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Serialized as "type"; always "function" in practice (see
    /// `create_completion_request`).
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}

/// Declaration of a callable function exposed to the model.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    /// JSON schema describing the function's parameters.
    pub parameters: serde_json::Value,
}
497
498/// 同步请求
499impl<T> completion::CompletionModel for CompletionModel<T>
500where
501    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
502{
503    type Response = CompletionResponse;
504    type StreamingResponse = openai::StreamingCompletionResponse;
505    type Client = Client<T>;
506
507    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
508        Self::new(client.clone(), model.into())
509    }
510
511    async fn completion(
512        &self,
513        completion_request: CompletionRequest,
514    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
515        let span = if tracing::Span::current().is_disabled() {
516            info_span!(
517                target: "rig::completions",
518                "chat",
519                gen_ai.operation.name = "chat",
520                gen_ai.provider.name = "groq",
521                gen_ai.request.model = self.model,
522                gen_ai.system_instructions = tracing::field::Empty,
523                gen_ai.response.id = tracing::field::Empty,
524                gen_ai.response.model = tracing::field::Empty,
525                gen_ai.usage.output_tokens = tracing::field::Empty,
526                gen_ai.usage.input_tokens = tracing::field::Empty,
527            )
528        } else {
529            tracing::Span::current()
530        };
531
532        span.record("gen_ai.system_instructions", &completion_request.preamble);
533
534        let request = self.create_completion_request(completion_request)?;
535
536        if tracing::enabled!(tracing::Level::TRACE) {
537            tracing::trace!(target: "rig::completions",
538                "Groq completion request: {}",
539                serde_json::to_string_pretty(&request)?
540            );
541        }
542
543        let body = serde_json::to_vec(&request)?;
544        let req = self
545            .client
546            .post("/chat/completions")?
547            .body(body)
548            .map_err(|e| http_client::Error::Instance(e.into()))?;
549
550        let async_block = async move {
551            let response = self.client.send::<_, Bytes>(req).await?;
552            let status = response.status();
553            let response_body = response.into_body().into_future().await?.to_vec();
554
555            let tt = response_body.clone();
556            let response = serde_json::from_slice::<serde_json::Value>(&tt)?;
557            println!(
558                "response:\r\n {}",
559                serde_json::to_string_pretty(&response).unwrap()
560            );
561
562            if status.is_success() {
563                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
564                    ApiResponse::Ok(response) => {
565                        let span = tracing::Span::current();
566                        span.record("gen_ai.response.id", response.id.clone());
567                        span.record("gen_ai.response.model_name", response.model.clone());
568                        if let Some(ref usage) = response.usage {
569                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
570                            span.record(
571                                "gen_ai.usage.output_tokens",
572                                usage.total_tokens - usage.prompt_tokens,
573                            );
574                        }
575
576                        if tracing::enabled!(tracing::Level::TRACE) {
577                            tracing::trace!(target: "rig::completions",
578                                "Groq completion response: {}",
579                                serde_json::to_string_pretty(&response)?
580                            );
581                        }
582
583                        response.try_into()
584                    }
585                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
586                }
587            } else {
588                Err(CompletionError::ProviderError(
589                    String::from_utf8_lossy(&response_body).to_string(),
590                ))
591            }
592        };
593
594        tracing::Instrument::instrument(async_block, span).await
595    }
596
597    async fn stream(
598        &self,
599        request: CompletionRequest,
600    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
601        let preamble = request.preamble.clone();
602
603        let mut request = self.create_completion_request(request)?;
604
605        request = merge(request, json!({"stream": true}));
606
607        let body = serde_json::to_vec(&request)?;
608
609        let req = self
610            .client
611            .post("/chat/completions")?
612            .body(body)
613            .map_err(|e| http_client::Error::Instance(e.into()))?;
614
615        let span = if tracing::Span::current().is_disabled() {
616            info_span!(
617                target: "rig::completions",
618                "chat_streaming",
619                gen_ai.operation.name = "chat_streaming",
620                gen_ai.provider.name = "galadriel",
621                gen_ai.request.model = self.model,
622                gen_ai.system_instructions = preamble,
623                gen_ai.response.id = tracing::field::Empty,
624                gen_ai.response.model = tracing::field::Empty,
625                gen_ai.usage.output_tokens = tracing::field::Empty,
626                gen_ai.usage.input_tokens = tracing::field::Empty,
627                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
628                gen_ai.output.messages = tracing::field::Empty,
629            )
630        } else {
631            tracing::Span::current()
632        };
633
634        send_compatible_streaming_request(self.client.clone(), req)
635            .instrument(span)
636            .await
637    }
638}