1use http::{HeaderValue, Method, header};
2use rig::client::{CompletionClient, ProviderClient};
3use rig::completion::{CompletionError, CompletionRequest};
4use rig::message::{MessageError, Text};
5use rig::providers::openai;
6use rig::{OneOrMany, completion, http_client, message};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9
10use crate::json_utils;
11use crate::json_utils::merge;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use tracing::{Instrument, info_span};
15
/// Default base URL of the BigModel v4 API, used by `Client::new`.
///
/// Request paths (e.g. `chat/completions`) are appended to this URL.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23 api_key: String,
24 base_url: String,
25 default_headers: http_client::HeaderMap,
26 http_client: reqwest::Client,
27}
28
29impl Client {
30 pub fn new(api_key: &str) -> Self {
31 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
32 }
33
34 pub fn from_url(api_key: &str, base_url: &str) -> Self {
35 let mut default_headers = reqwest::header::HeaderMap::new();
36 default_headers.insert(
37 reqwest::header::CONTENT_TYPE,
38 "application/json".parse().unwrap(),
39 );
40
41 Self {
42 api_key: api_key.to_string(),
43 base_url: base_url.to_string(),
44 default_headers,
45 http_client: reqwest::Client::builder()
46 .default_headers({
47 let mut headers = reqwest::header::HeaderMap::new();
48 headers.insert(
49 "Authorization",
50 format!("Bearer {api_key}")
51 .parse()
52 .expect("Bearer token should parse"),
53 );
54 headers
55 })
56 .build()
57 .expect("bigmodel reqwest client should build"),
58 }
59 }
60
61 fn post(&self, path: &str) -> reqwest::RequestBuilder {
62 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
63 self.http_client.post(url)
64 }
65
66 pub fn completion_model(&self, model: &str) -> CompletionModel {
67 CompletionModel::new(self.clone(), model)
68 }
69
70 }
78
79impl ProviderClient for Client {
80 type Input = String;
81
82 fn from_env() -> Self
83 where
84 Self: Sized,
85 {
86 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
87 Self::new(&api_key)
88 }
89
90 fn from_val(input: Self::Input) -> Self {
91 Self::new(&input)
92 }
93}
94impl CompletionClient for Client {
95 type CompletionModel = CompletionModel;
96
97 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
98 CompletionModel::new(self.clone(), &model.into())
99 }
100}
101
/// Error payload shape: a JSON object with a top-level `message` string.
///
/// NOTE(review): confirm this matches BigModel's actual error body. If the API
/// nests the message (e.g. under an `error` key), neither `ApiResponse` variant
/// will match.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
106
/// Either a successful payload `T` or an `ApiErrorResponse`.
///
/// `untagged`: serde tries `Ok(T)` first and falls back to `Err` only if the
/// body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
113
/// Model id `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";

/// Model id `glm-4.5-flash`.
///
/// Deprecated: per the attribute note, GLM-4.5-Flash is scheduled to go offline
/// on 2026-01-30.
#[deprecated(note = "GLM-4.5-Flash 将于2026年1月30日下线")]
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
/// Model id `glm-4.7-flash`.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
122
123#[derive(Debug, Deserialize, Serialize)]
124#[serde(rename_all = "camelCase")]
125pub struct CompletionResponse {
126 pub choices: Vec<Choice>,
127 pub created: i64,
128 pub id: String,
129 pub model: String,
130 #[serde(rename = "request_id")]
131 pub request_id: String,
132 pub usage: Usage,
133}
134
/// A chat message in BigModel's wire format.
///
/// Internally tagged by the `role` key. Variant names are lowercased (`user`,
/// `assistant`, `system`), except `ToolResult`, which is sent as `tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// User turn with plain-text content.
    User {
        content: String,
    },
    /// Assistant turn; `content` is optional (presumably absent when the model
    /// only returns tool calls; TODO confirm).
    Assistant {
        content: Option<String>,
        // A missing key yields an empty Vec via `default`. `null_or_vec` is a
        // project helper, presumably mapping a JSON `null` to an empty Vec too.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System instructions; the request preamble is sent as this variant.
    System {
        content: String,
    },
    /// Output of a tool, linked to its call by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
155
156impl Message {
157 pub fn system(content: &str) -> Message {
158 Message::System {
159 content: content.to_owned(),
160 }
161 }
162}
163
/// A text-only part of a tool result.
///
/// Built via `TryFrom<message::ToolResultContent>`, which rejects non-text parts.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
168impl TryFrom<message::ToolResultContent> for ToolResultContent {
169 type Error = MessageError;
170 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
171 let message::ToolResultContent::Text(Text { text }) = value else {
172 return Err(MessageError::ConversionError(
173 "Non-text tool results not supported".into(),
174 ));
175 };
176
177 Ok(Self { text })
178 }
179}
180
181impl TryFrom<message::Message> for Message {
182 type Error = MessageError;
183
184 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
185 Ok(match message {
186 message::Message::User { content } => {
187 let mut texts = Vec::new();
188 let mut images = Vec::new();
189
190 for uc in content.into_iter() {
191 match uc {
192 message::UserContent::Text(message::Text { text }) => texts.push(text),
193 message::UserContent::Image(img) => images.push(img.data),
194 message::UserContent::ToolResult(result) => {
195 let content = result
196 .content
197 .into_iter()
198 .map(ToolResultContent::try_from)
199 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
200
201 let content = OneOrMany::many(content).map_err(|x| {
202 MessageError::ConversionError(format!(
203 "Couldn't make a OneOrMany from a list of tool results: {x}"
204 ))
205 })?;
206
207 return Ok(Message::ToolResult {
208 tool_call_id: result.id,
209 content: content.first().text,
210 });
211 }
212 _ => {}
213 }
214 }
215
216 let collapsed_content = texts.join(" ");
217
218 Message::User {
219 content: collapsed_content,
220 }
221 }
222 message::Message::Assistant { content, .. } => {
223 let mut texts = Vec::new();
224 let mut tool_calls = Vec::new();
225
226 for ac in content.into_iter() {
227 match ac {
228 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
229 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
230 _ => {}
231 }
232 }
233
234 let collapsed_content = texts.join(" ");
235
236 Message::Assistant {
237 content: Some(collapsed_content),
238 tool_calls,
239 }
240 }
241 })
242 }
243}
244
245impl From<message::ToolResult> for Message {
246 fn from(tool_result: message::ToolResult) -> Self {
247 let content = match tool_result.content.first() {
248 message::ToolResultContent::Text(text) => text.text,
249 message::ToolResultContent::Image(_) => String::from("[Image]"),
250 };
251
252 Message::ToolResult {
253 tool_call_id: tool_result.id,
254 content,
255 }
256 }
257}
258
259#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
260#[serde(rename_all = "camelCase")]
261pub struct ToolCall {
262 pub function: CallFunction,
263 pub id: String,
264 pub index: usize,
265 #[serde(default)]
266 pub r#type: ToolType,
267}
268
269impl From<message::ToolCall> for ToolCall {
270 fn from(tool_call: message::ToolCall) -> Self {
271 Self {
272 id: tool_call.id,
273 index: 0,
274 r#type: ToolType::Function,
275 function: CallFunction {
276 name: tool_call.function.name,
277 arguments: tool_call.function.arguments,
278 },
279 }
280 }
281}
282
/// Kind of tool call. Only `function` exists; it is serialized lowercase and is
/// the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
289
/// The function part of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    /// Name of the function the model wants to call.
    pub name: String,
    // `stringified_json` is a project helper; presumably the arguments travel
    // as a JSON-encoded string on the wire and are exposed here as a parsed
    // Value (TODO confirm).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
296
/// Chat roles, serialized lowercase.
///
/// NOTE(review): not referenced by the code shown here (`Message` carries its
/// role in its serde tag); confirm it is used elsewhere before keeping it.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
304
305#[derive(Debug, Serialize, Deserialize)]
306#[serde(rename_all = "camelCase")]
307pub struct Choice {
308 #[serde(rename = "finish_reason")]
309 pub finish_reason: String,
310 pub index: i64,
311 pub message: Message,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
315#[serde(rename_all = "camelCase")]
316pub struct Usage {
317 #[serde(rename = "completion_tokens")]
318 pub completion_tokens: i64,
319 #[serde(rename = "prompt_tokens")]
320 pub prompt_tokens: i64,
321 #[serde(rename = "total_tokens")]
322 pub total_tokens: i64,
323}
324
325impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
326 type Error = CompletionError;
327
328 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
329 let choice = response.choices.first().ok_or_else(|| {
330 CompletionError::ResponseError("Response contained no choices".to_owned())
331 })?;
332
333 match &choice.message {
334 Message::Assistant {
335 tool_calls,
336 content,
337 } => {
338 if !tool_calls.is_empty() {
339 let tool_result = tool_calls
340 .iter()
341 .map(|call| {
342 completion::AssistantContent::tool_call(
343 &call.function.name,
344 &call.function.name,
345 call.function.arguments.clone(),
346 )
347 })
348 .collect::<Vec<_>>();
349
350 let choice = OneOrMany::many(tool_result).map_err(|_| {
351 CompletionError::ResponseError(
352 "Response contained no message or tool call (empty)".to_owned(),
353 )
354 })?;
355 let usage = completion::Usage {
356 input_tokens: response.usage.prompt_tokens as u64,
357 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
358 as u64,
359 total_tokens: response.usage.total_tokens as u64,
360 };
361 tracing::debug!("response choices: {:?}: ", choice);
362 Ok(completion::CompletionResponse {
363 choice,
364 usage,
365 raw_response: response,
366 })
367 } else {
368 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
369 text: content.clone().unwrap_or_else(|| "".to_owned()),
370 }));
371 let usage = completion::Usage {
372 input_tokens: response.usage.prompt_tokens as u64,
373 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
374 as u64,
375 total_tokens: response.usage.total_tokens as u64,
376 };
377 Ok(completion::CompletionResponse {
378 choice,
379 usage,
380 raw_response: response,
381 })
382 }
383 }
384 _ => Err(CompletionError::ResponseError(
386 "Chat response does not include an assistant message".into(),
387 )),
388 }
389 }
390}
391
/// A BigModel chat model bound to a `Client`.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model id sent as the request's `model` field (e.g. `glm-4-flash`).
    pub model: String,
}
397
398#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
400#[serde(rename_all = "camelCase")]
401pub struct CustomFunctionDefinition {
402 #[serde(rename = "type")]
403 pub type_field: String,
404 pub function: Function,
405}
406
407#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
408#[serde(rename_all = "camelCase")]
409pub struct Function {
410 pub name: String,
411 pub description: String,
412 pub parameters: serde_json::Value,
413}
414
415impl CompletionModel {
416 pub fn new(client: Client, model: &str) -> Self {
417 Self {
418 client,
419 model: model.to_string(),
420 }
421 }
422
423 fn create_completion_request(
424 &self,
425 completion_request: CompletionRequest,
426 ) -> Result<Value, CompletionError> {
427 let mut partial_history = vec![];
429 if let Some(docs) = completion_request.normalized_documents() {
430 partial_history.push(docs);
431 }
432 partial_history.extend(completion_request.chat_history);
433
434 let mut full_history: Vec<Message> = completion_request
436 .preamble
437 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
438
439 full_history.extend(
441 partial_history
442 .into_iter()
443 .map(message::Message::try_into)
444 .collect::<Result<Vec<Message>, _>>()?,
445 );
446
447 let request = if completion_request.tools.is_empty() {
448 json!({
449 "model": self.model,
450 "messages": full_history,
451 "temperature": completion_request.temperature,
452 })
453 } else {
454 let tools = completion_request
456 .tools
457 .into_iter()
458 .map(|item| {
459 let custom_function = Function {
460 name: item.name,
461 description: item.description,
462 parameters: item.parameters,
463 };
464 CustomFunctionDefinition {
465 type_field: "function".to_string(),
466 function: custom_function,
467 }
468 })
469 .collect::<Vec<_>>();
470
471 tracing::debug!("tools: {:?}", tools);
472
473 json!({
474 "model": self.model,
475 "messages": full_history,
476 "temperature": completion_request.temperature,
477 "tools": tools,
478 "tool_choice": "auto",
479 })
480 };
481
482 let request = if let Some(params) = completion_request.additional_params {
483 json_utils::merge(request, params)
484 } else {
485 request
486 };
487
488 Ok(request)
489 }
490}
491
492impl completion::CompletionModel for CompletionModel {
494 type Response = CompletionResponse;
495 type StreamingResponse = openai::StreamingCompletionResponse;
496 type Client = Client;
497
498 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
499 Self::new(client.clone(), &model.into())
500 }
501
502 async fn completion(
503 &self,
504 completion_request: CompletionRequest,
505 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
506 tracing::debug!("create_completion_request========");
507 let request = self.create_completion_request(completion_request)?;
508
509 tracing::debug!(
510 "request: \r\n {}",
511 serde_json::to_string_pretty(&request).unwrap()
512 );
513
514 let response = self
515 .client
516 .post("/chat/completions")
517 .json(&request)
518 .send()
519 .await
520 .map_err(|e| http_client::Error::Instance(e.into()))?;
521
522 if response.status().is_success() {
523 let data: Value = response.json().await.expect("api error");
524 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
525 let data: ApiResponse<CompletionResponse> =
526 serde_json::from_value(data).expect("deserialize completion response");
527 match data {
528 ApiResponse::Ok(response) => {
529 tracing::info!(target: "rig",
530 "bigmodel completion token usage: {:?}",
531 response.usage
532 );
533 response.try_into()
534 }
535 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
536 }
537 } else {
538 Err(CompletionError::ProviderError(
539 response
540 .text()
541 .await
542 .map_err(|e| http_client::Error::Instance(e.into()))?,
543 ))
544 }
545 }
546
547 async fn stream(
548 &self,
549 request: CompletionRequest,
550 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
551 let preamble = request.preamble.clone();
552
553 let mut request = self.create_completion_request(request)?;
554
555 request = merge(request, json!({"stream": true}));
556
557 let body = serde_json::to_vec(&request)?;
558
559 let url = format!(
560 "{}/{}",
561 self.client.base_url,
562 "/chat/completions".trim_start_matches('/')
563 );
564
565 let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
566 for (header, value) in &self.client.default_headers {
567 builder = builder.header(header, value);
568 }
569
570 let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
571 .map_err(http::Error::from)
572 .map_err(rig::http_client::Error::from)?;
573
574 builder = builder.header(header::AUTHORIZATION, auth_header);
575 builder = builder.header("Content-Type", "application/json");
576
577 let req = builder
578 .body(body)
579 .map_err(|e| CompletionError::HttpError(e.into()))?;
580
581 let span = if tracing::Span::current().is_disabled() {
582 info_span!(
583 target: "rig::completions",
584 "chat_streaming",
585 gen_ai.operation.name = "chat_streaming",
586 gen_ai.provider.name = "galadriel",
587 gen_ai.request.model = self.model,
588 gen_ai.system_instructions = preamble,
589 gen_ai.response.id = tracing::field::Empty,
590 gen_ai.response.model = tracing::field::Empty,
591 gen_ai.usage.output_tokens = tracing::field::Empty,
592 gen_ai.usage.input_tokens = tracing::field::Empty,
593 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
594 gen_ai.output.messages = tracing::field::Empty,
595 )
596 } else {
597 tracing::Span::current()
598 };
599
600 send_compatible_streaming_request(self.client.http_client.clone(), req)
601 .instrument(span)
602 .await
603 }
604}