1use http::header;
2use reqwest::Method;
3use reqwest::header::HeaderValue;
4use rig::client::{CompletionClient, ProviderClient};
5use rig::completion::{CompletionError, CompletionRequest};
6use rig::message::{MessageError, Text};
7use rig::providers::openai;
8use rig::{OneOrMany, completion, http_client, message};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11
12use crate::json_utils;
13use crate::json_utils::merge;
14use rig::providers::openai::send_compatible_streaming_request;
15use rig::streaming::StreamingCompletionResponse;
16use tracing::{Instrument, info_span};
17
/// Root URL that [`Client::new`] hands to [`Client::from_url`] when the caller
/// does not supply a base URL of their own.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
23#[derive(Clone, Debug)]
24pub struct Client {
25 api_key: String,
26 base_url: String,
27 default_headers: http_client::HeaderMap,
28 http_client: reqwest::Client,
29}
30
31impl Client {
32 pub fn new(api_key: &str) -> Self {
33 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
34 }
35
36 pub fn from_url(api_key: &str, base_url: &str) -> Self {
37 let mut default_headers = reqwest::header::HeaderMap::new();
38 default_headers.insert(
39 reqwest::header::CONTENT_TYPE,
40 "application/json".parse().unwrap(),
41 );
42
43 Self {
44 api_key: api_key.to_string(),
45 base_url: base_url.to_string(),
46 default_headers,
47 http_client: reqwest::Client::builder()
48 .default_headers({
49 let mut headers = reqwest::header::HeaderMap::new();
50 headers.insert(
51 "Authorization",
52 format!("Bearer {api_key}")
53 .parse()
54 .expect("Bearer token should parse"),
55 );
56 headers
57 })
58 .build()
59 .expect("bigmodel reqwest client should build"),
60 }
61 }
62
63 fn post(&self, path: &str) -> reqwest::RequestBuilder {
64 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
65 self.http_client.post(url)
66 }
67
68 pub fn completion_model(&self, model: &str) -> CompletionModel {
69 CompletionModel::new(self.clone(), model)
70 }
71
72 }
80
81impl ProviderClient for Client {
82 type Input = String;
83
84 fn from_env() -> Self
85 where
86 Self: Sized,
87 {
88 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
89 Self::new(&api_key)
90 }
91
92 fn from_val(input: Self::Input) -> Self {
93 Self::new(&input)
94 }
95}
96impl CompletionClient for Client {
97 type CompletionModel = CompletionModel;
98
99 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
100 CompletionModel::new(self.clone(), &model.into())
101 }
102}
103
104#[derive(Debug, Deserialize)]
105struct ApiErrorResponse {
106 message: String,
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(untagged)]
111enum ApiResponse<T> {
112 Ok(T),
113 Err(ApiErrorResponse),
114}
115
/// Model identifier `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model identifier `glm-4.5-flash`.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
121
122#[derive(Debug, Deserialize, Serialize)]
123#[serde(rename_all = "camelCase")]
124pub struct CompletionResponse {
125 pub choices: Vec<Choice>,
126 pub created: i64,
127 pub id: String,
128 pub model: String,
129 #[serde(rename = "request_id")]
130 pub request_id: String,
131 pub usage: Usage,
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135#[serde(tag = "role", rename_all = "lowercase")]
136pub enum Message {
137 User {
138 content: String,
139 },
140 Assistant {
141 content: Option<String>,
142 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
143 tool_calls: Vec<ToolCall>,
144 },
145 System {
146 content: String,
147 },
148 #[serde(rename = "tool")]
149 ToolResult {
150 tool_call_id: String,
151 content: String,
152 },
153}
154
155impl Message {
156 pub fn system(content: &str) -> Message {
157 Message::System {
158 content: content.to_owned(),
159 }
160 }
161}
162
163#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
164pub struct ToolResultContent {
165 text: String,
166}
167impl TryFrom<message::ToolResultContent> for ToolResultContent {
168 type Error = MessageError;
169 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
170 let message::ToolResultContent::Text(Text { text }) = value else {
171 return Err(MessageError::ConversionError(
172 "Non-text tool results not supported".into(),
173 ));
174 };
175
176 Ok(Self { text })
177 }
178}
179
180impl TryFrom<message::Message> for Message {
181 type Error = MessageError;
182
183 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
184 Ok(match message {
185 message::Message::User { content } => {
186 let mut texts = Vec::new();
187 let mut images = Vec::new();
188
189 for uc in content.into_iter() {
190 match uc {
191 message::UserContent::Text(message::Text { text }) => texts.push(text),
192 message::UserContent::Image(img) => images.push(img.data),
193 message::UserContent::ToolResult(result) => {
194 let content = result
195 .content
196 .into_iter()
197 .map(ToolResultContent::try_from)
198 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
199
200 let content = OneOrMany::many(content).map_err(|x| {
201 MessageError::ConversionError(format!(
202 "Couldn't make a OneOrMany from a list of tool results: {x}"
203 ))
204 })?;
205
206 return Ok(Message::ToolResult {
207 tool_call_id: result.id,
208 content: content.first().text,
209 });
210 }
211 _ => {}
212 }
213 }
214
215 let collapsed_content = texts.join(" ");
216
217 Message::User {
218 content: collapsed_content,
219 }
220 }
221 message::Message::Assistant { content, .. } => {
222 let mut texts = Vec::new();
223 let mut tool_calls = Vec::new();
224
225 for ac in content.into_iter() {
226 match ac {
227 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
228 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
229 _ => {}
230 }
231 }
232
233 let collapsed_content = texts.join(" ");
234
235 Message::Assistant {
236 content: Some(collapsed_content),
237 tool_calls,
238 }
239 }
240 })
241 }
242}
243
244impl From<message::ToolResult> for Message {
245 fn from(tool_result: message::ToolResult) -> Self {
246 let content = match tool_result.content.first() {
247 message::ToolResultContent::Text(text) => text.text,
248 message::ToolResultContent::Image(_) => String::from("[Image]"),
249 };
250
251 Message::ToolResult {
252 tool_call_id: tool_result.id,
253 content,
254 }
255 }
256}
257
258#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
259#[serde(rename_all = "camelCase")]
260pub struct ToolCall {
261 pub function: CallFunction,
262 pub id: String,
263 pub index: usize,
264 #[serde(default)]
265 pub r#type: ToolType,
266}
267
268impl From<message::ToolCall> for ToolCall {
269 fn from(tool_call: message::ToolCall) -> Self {
270 Self {
271 id: tool_call.id,
272 index: 0,
273 r#type: ToolType::Function,
274 function: CallFunction {
275 name: tool_call.function.name,
276 arguments: tool_call.function.arguments,
277 },
278 }
279 }
280}
281
282#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
283#[serde(rename_all = "lowercase")]
284pub enum ToolType {
285 #[default]
286 Function,
287}
288
289#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
290pub struct CallFunction {
291 pub name: String,
292 #[serde(with = "json_utils::stringified_json")]
293 pub arguments: serde_json::Value,
294}
295
296#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
297#[serde(rename_all = "lowercase")]
298pub enum Role {
299 System,
300 User,
301 Assistant,
302}
303
304#[derive(Debug, Serialize, Deserialize)]
305#[serde(rename_all = "camelCase")]
306pub struct Choice {
307 #[serde(rename = "finish_reason")]
308 pub finish_reason: String,
309 pub index: i64,
310 pub message: Message,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
314#[serde(rename_all = "camelCase")]
315pub struct Usage {
316 #[serde(rename = "completion_tokens")]
317 pub completion_tokens: i64,
318 #[serde(rename = "prompt_tokens")]
319 pub prompt_tokens: i64,
320 #[serde(rename = "total_tokens")]
321 pub total_tokens: i64,
322}
323
324impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
325 type Error = CompletionError;
326
327 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
328 let choice = response.choices.first().ok_or_else(|| {
329 CompletionError::ResponseError("Response contained no choices".to_owned())
330 })?;
331
332 match &choice.message {
333 Message::Assistant {
334 tool_calls,
335 content,
336 } => {
337 if !tool_calls.is_empty() {
338 let tool_result = tool_calls
339 .iter()
340 .map(|call| {
341 completion::AssistantContent::tool_call(
342 &call.function.name,
343 &call.function.name,
344 call.function.arguments.clone(),
345 )
346 })
347 .collect::<Vec<_>>();
348
349 let choice = OneOrMany::many(tool_result).map_err(|_| {
350 CompletionError::ResponseError(
351 "Response contained no message or tool call (empty)".to_owned(),
352 )
353 })?;
354 let usage = completion::Usage {
355 input_tokens: response.usage.prompt_tokens as u64,
356 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
357 as u64,
358 total_tokens: response.usage.total_tokens as u64,
359 };
360 tracing::debug!("response choices: {:?}: ", choice);
361 Ok(completion::CompletionResponse {
362 choice,
363 usage,
364 raw_response: response,
365 })
366 } else {
367 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
368 text: content.clone().unwrap_or_else(|| "".to_owned()),
369 }));
370 let usage = completion::Usage {
371 input_tokens: response.usage.prompt_tokens as u64,
372 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
373 as u64,
374 total_tokens: response.usage.total_tokens as u64,
375 };
376 Ok(completion::CompletionResponse {
377 choice,
378 usage,
379 raw_response: response,
380 })
381 }
382 }
383 _ => Err(CompletionError::ResponseError(
385 "Chat response does not include an assistant message".into(),
386 )),
387 }
388 }
389}
390
391#[derive(Clone)]
392pub struct CompletionModel {
393 client: Client,
394 pub model: String,
395}
396
397#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
399#[serde(rename_all = "camelCase")]
400pub struct CustomFunctionDefinition {
401 #[serde(rename = "type")]
402 pub type_field: String,
403 pub function: Function,
404}
405
406#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
407#[serde(rename_all = "camelCase")]
408pub struct Function {
409 pub name: String,
410 pub description: String,
411 pub parameters: serde_json::Value,
412}
413
414impl CompletionModel {
415 pub fn new(client: Client, model: &str) -> Self {
416 Self {
417 client,
418 model: model.to_string(),
419 }
420 }
421
422 fn create_completion_request(
423 &self,
424 completion_request: CompletionRequest,
425 ) -> Result<Value, CompletionError> {
426 let mut partial_history = vec![];
428 if let Some(docs) = completion_request.normalized_documents() {
429 partial_history.push(docs);
430 }
431 partial_history.extend(completion_request.chat_history);
432
433 let mut full_history: Vec<Message> = completion_request
435 .preamble
436 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
437
438 full_history.extend(
440 partial_history
441 .into_iter()
442 .map(message::Message::try_into)
443 .collect::<Result<Vec<Message>, _>>()?,
444 );
445
446 let request = if completion_request.tools.is_empty() {
447 json!({
448 "model": self.model,
449 "messages": full_history,
450 "temperature": completion_request.temperature,
451 })
452 } else {
453 let tools = completion_request
455 .tools
456 .into_iter()
457 .map(|item| {
458 let custom_function = Function {
459 name: item.name,
460 description: item.description,
461 parameters: item.parameters,
462 };
463 CustomFunctionDefinition {
464 type_field: "function".to_string(),
465 function: custom_function,
466 }
467 })
468 .collect::<Vec<_>>();
469
470 tracing::debug!("tools: {:?}", tools);
471
472 json!({
473 "model": self.model,
474 "messages": full_history,
475 "temperature": completion_request.temperature,
476 "tools": tools,
477 "tool_choice": "auto",
478 })
479 };
480
481 let request = if let Some(params) = completion_request.additional_params {
482 json_utils::merge(request, params)
483 } else {
484 request
485 };
486
487 Ok(request)
488 }
489}
490
491impl completion::CompletionModel for CompletionModel {
493 type Response = CompletionResponse;
494 type StreamingResponse = openai::StreamingCompletionResponse;
495 type Client = Client;
496
497 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
498 Self::new(client.clone(), &model.into())
499 }
500
501 async fn completion(
502 &self,
503 completion_request: CompletionRequest,
504 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
505 tracing::debug!("create_completion_request========");
506 let request = self.create_completion_request(completion_request)?;
507
508 tracing::debug!(
509 "request: \r\n {}",
510 serde_json::to_string_pretty(&request).unwrap()
511 );
512
513 let response = self
514 .client
515 .post("/chat/completions")
516 .json(&request)
517 .send()
518 .await
519 .map_err(|e| http_client::Error::Instance(e.into()))?;
520
521 if response.status().is_success() {
522 let data: Value = response.json().await.expect("api error");
523 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
524 let data: ApiResponse<CompletionResponse> =
525 serde_json::from_value(data).expect("deserialize completion response");
526 match data {
527 ApiResponse::Ok(response) => {
528 tracing::info!(target: "rig",
529 "bigmodel completion token usage: {:?}",
530 response.usage
531 );
532 response.try_into()
533 }
534 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
535 }
536 } else {
537 Err(CompletionError::ProviderError(
538 response
539 .text()
540 .await
541 .map_err(|e| http_client::Error::Instance(e.into()))?,
542 ))
543 }
544 }
545
546 async fn stream(
547 &self,
548 request: CompletionRequest,
549 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
550 let preamble = request.preamble.clone();
551
552 let mut request = self.create_completion_request(request)?;
553
554 request = merge(request, json!({"stream": true}));
555
556 let body = serde_json::to_vec(&request)?;
557
558 let url = format!(
559 "{}/{}",
560 self.client.base_url,
561 "/chat/completions".trim_start_matches('/')
562 );
563
564 let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
565 for (header, value) in &self.client.default_headers {
566 builder = builder.header(header, value);
567 }
568
569 let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
570 .map_err(http::Error::from)
571 .map_err(rig::http_client::Error::from)?;
572
573 builder = builder.header(header::AUTHORIZATION, auth_header);
574 builder = builder.header("Content-Type", "application/json");
575
576 let req = builder
577 .body(body)
578 .map_err(|e| CompletionError::HttpError(e.into()))?;
579
580 let span = if tracing::Span::current().is_disabled() {
581 info_span!(
582 target: "rig::completions",
583 "chat_streaming",
584 gen_ai.operation.name = "chat_streaming",
585 gen_ai.provider.name = "galadriel",
586 gen_ai.request.model = self.model,
587 gen_ai.system_instructions = preamble,
588 gen_ai.response.id = tracing::field::Empty,
589 gen_ai.response.model = tracing::field::Empty,
590 gen_ai.usage.output_tokens = tracing::field::Empty,
591 gen_ai.usage.input_tokens = tracing::field::Empty,
592 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
593 gen_ai.output.messages = tracing::field::Empty,
594 )
595 } else {
596 tracing::Span::current()
597 };
598
599 send_compatible_streaming_request(self.client.http_client.clone(), req)
600 .instrument(span)
601 .await
602 }
603}