1use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
2use rig::completion::{CompletionError, CompletionRequest};
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, client, completion, http_client, message};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8
9use rig::providers::openai::send_compatible_streaming_request;
10use rig::streaming::StreamingCompletionResponse;
11
12use crate::json_utils;
13
/// Default base URL of the BigModel API, used by `Client::new`.
///
/// Request paths (e.g. `chat/completions`) are appended to it by `Client::post`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
18
/// HTTP client for the BigModel API.
///
/// Holds the base URL that request paths are resolved against, plus a `reqwest::Client`
/// that `from_url` preconfigures with a default bearer-token `Authorization` header.
#[derive(Clone, Debug)]
pub struct Client {
    // Base URL that `post` appends request paths to.
    base_url: String,
    // reqwest client carrying the default `Authorization` header.
    http_client: reqwest::Client,
}
24
25impl Client {
26 pub fn new(api_key: &str) -> Self {
27 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
28 }
29
30 pub fn from_url(api_key: &str, base_url: &str) -> Self {
31 Self {
32 base_url: base_url.to_string(),
33 http_client: reqwest::Client::builder()
34 .default_headers({
35 let mut headers = reqwest::header::HeaderMap::new();
36 headers.insert(
37 "Authorization",
38 format!("Bearer {api_key}")
39 .parse()
40 .expect("Bearer token should parse"),
41 );
42 headers
43 })
44 .build()
45 .expect("bigmodel reqwest client should build"),
46 }
47 }
48
49 fn post(&self, path: &str) -> reqwest::RequestBuilder {
50 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
51 self.http_client.post(url)
52 }
53
54 pub fn completion_model(&self, model: &str) -> CompletionModel {
55 CompletionModel::new(self.clone(), model)
56 }
57
58 }
66
67impl ProviderClient for Client {
68 fn from_env() -> Self
69 where
70 Self: Sized,
71 {
72 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
73 Self::new(&api_key)
74 }
75
76 fn from_val(input: ProviderValue) -> Self
77 where
78 Self: Sized,
79 {
80 let client::ProviderValue::Simple(api_key) = input else {
81 panic!("Incorrect provider value type")
82 };
83 Self::new(&api_key)
84 }
85}
86
// Empty impl: no methods are overridden, so the trait's default behaviour applies.
// NOTE(review): presumably the defaults report transcription as unsupported for this
// provider — confirm against rig's `AsTranscription` defaults.
impl AsTranscription for Client {}

// Empty impl: no methods are overridden, so the trait's default behaviour applies.
// NOTE(review): presumably the defaults report embeddings as unsupported for this
// provider — confirm against rig's `AsEmbeddings` defaults.
impl AsEmbeddings for Client {}
90
91impl CompletionClient for Client {
92 type CompletionModel = CompletionModel;
93
94 fn completion_model(&self, model: &str) -> Self::CompletionModel {
95 CompletionModel::new(self.clone(), model)
96 }
97}
98
/// Error payload in an API response body; only `message` is read.
///
/// `CompletionModel::completion` surfaces `message` as `CompletionError::ProviderError`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
103
/// A response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Because the enum is `untagged`, serde tries the variants in declaration order. It
/// attempts `Ok(T)` first and falls back to `Err` only when the body does not deserialize
/// as `T`, so the variant order matters.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
110
/// Model identifier `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model identifier `glm-4.5-flash`.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
116
/// Raw (non-streaming) chat completion response body.
///
/// Every field is required: a body missing any of them (e.g. `request_id`) fails to
/// deserialize. `rename_all = "camelCase"` would turn `request_id` into `requestId`, so it
/// is explicitly renamed to keep the snake_case wire name.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    pub created: i64,
    pub id: String,
    pub model: String,
    #[serde(rename = "request_id")]
    pub request_id: String,
    pub usage: Usage,
}
128
/// Chat message in the BigModel wire format, internally tagged by a `role` field.
///
/// Wire roles are `user`, `assistant`, `system` and `tool` (the last one for
/// [`Message::ToolResult`]).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        // `None` when the field is missing (or null) in the payload.
        content: Option<String>,
        // A missing field becomes an empty list via `default`. `json_utils::null_or_vec`
        // presumably also maps an explicit `null` to an empty list — confirm in json_utils.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    // Sent with role `tool`; `tool_call_id` names the tool call this result answers.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
149
150impl Message {
151 pub fn system(content: &str) -> Message {
152 Message::System {
153 content: content.to_owned(),
154 }
155 }
156}
157
/// Text-only tool result content.
///
/// Built from rig's `message::ToolResultContent` through the `TryFrom` impl below, which
/// rejects non-text content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
162impl TryFrom<message::ToolResultContent> for ToolResultContent {
163 type Error = MessageError;
164 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
165 let message::ToolResultContent::Text(Text { text }) = value else {
166 return Err(MessageError::ConversionError(
167 "Non-text tool results not supported".into(),
168 ));
169 };
170
171 Ok(Self { text })
172 }
173}
174
175impl TryFrom<message::Message> for Message {
176 type Error = MessageError;
177
178 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
179 Ok(match message {
180 message::Message::User { content } => {
181 let mut texts = Vec::new();
182 let mut images = Vec::new();
183
184 for uc in content.into_iter() {
185 match uc {
186 message::UserContent::Text(message::Text { text }) => texts.push(text),
187 message::UserContent::Image(img) => images.push(img.data),
188 message::UserContent::ToolResult(result) => {
189 let content = result
190 .content
191 .into_iter()
192 .map(ToolResultContent::try_from)
193 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
194
195 let content = OneOrMany::many(content).map_err(|x| {
196 MessageError::ConversionError(format!(
197 "Couldn't make a OneOrMany from a list of tool results: {x}"
198 ))
199 })?;
200
201 return Ok(Message::ToolResult {
202 tool_call_id: result.id,
203 content: content.first().text,
204 });
205 }
206 _ => {}
207 }
208 }
209
210 let collapsed_content = texts.join(" ");
211
212 Message::User {
213 content: collapsed_content,
214 }
215 }
216 message::Message::Assistant { content, .. } => {
217 let mut texts = Vec::new();
218 let mut tool_calls = Vec::new();
219
220 for ac in content.into_iter() {
221 match ac {
222 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
223 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
224 _ => {}
225 }
226 }
227
228 let collapsed_content = texts.join(" ");
229
230 Message::Assistant {
231 content: Some(collapsed_content),
232 tool_calls,
233 }
234 }
235 })
236 }
237}
238
239impl From<message::ToolResult> for Message {
240 fn from(tool_result: message::ToolResult) -> Self {
241 let content = match tool_result.content.first() {
242 message::ToolResultContent::Text(text) => text.text,
243 message::ToolResultContent::Image(_) => String::from("[Image]"),
244 };
245
246 Message::ToolResult {
247 tool_call_id: tool_result.id,
248 content,
249 }
250 }
251}
252
/// A tool call in an assistant message, either received in a response or sent back in
/// the chat history.
///
/// `type` defaults to `function` when it is missing from the payload. `index` has no
/// default and must be present when deserializing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    // Call id. A later `tool` message refers back to it through `tool_call_id`.
    pub id: String,
    // Conversions from rig's `message::ToolCall` always set this to 0.
    pub index: usize,
    #[serde(default)]
    pub r#type: ToolType,
}
262
263impl From<message::ToolCall> for ToolCall {
264 fn from(tool_call: message::ToolCall) -> Self {
265 Self {
266 id: tool_call.id,
267 index: 0,
268 r#type: ToolType::Function,
269 function: CallFunction {
270 name: tool_call.function.name,
271 arguments: tool_call.function.arguments,
272 },
273 }
274 }
275}
276
/// Kind of tool call, serialized in lowercase.
///
/// `function` is the only variant and also the `Default`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
283
/// Name and arguments of a function tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    // NOTE(review): `json_utils::stringified_json` presumably (de)serializes between a
    // JSON-encoded string on the wire and a `Value` here — confirm in json_utils.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
290
/// Message role, serialized in lowercase (`system`, `user`, `assistant`).
///
/// NOTE(review): not referenced by the code visible in this file; `Message` carries its
/// role through its serde tag instead.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
298
/// One completion choice inside a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    // An explicit rename keeps the snake_case wire name despite `rename_all`.
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    pub index: i64,
    pub message: Message,
}
307
/// Token usage reported with a completion.
///
/// Each field is explicitly renamed so its wire name stays snake_case, overriding
/// `rename_all`. The conversion into rig's usage derives output tokens as
/// `total_tokens - prompt_tokens` rather than reading `completion_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    #[serde(rename = "completion_tokens")]
    pub completion_tokens: i64,
    #[serde(rename = "prompt_tokens")]
    pub prompt_tokens: i64,
    #[serde(rename = "total_tokens")]
    pub total_tokens: i64,
}
318
319impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
320 type Error = CompletionError;
321
322 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
323 let choice = response.choices.first().ok_or_else(|| {
324 CompletionError::ResponseError("Response contained no choices".to_owned())
325 })?;
326
327 match &choice.message {
328 Message::Assistant {
329 tool_calls,
330 content,
331 } => {
332 if !tool_calls.is_empty() {
333 let tool_result = tool_calls
334 .iter()
335 .map(|call| {
336 completion::AssistantContent::tool_call(
337 &call.function.name,
338 &call.function.name,
339 call.function.arguments.clone(),
340 )
341 })
342 .collect::<Vec<_>>();
343
344 let choice = OneOrMany::many(tool_result).map_err(|_| {
345 CompletionError::ResponseError(
346 "Response contained no message or tool call (empty)".to_owned(),
347 )
348 })?;
349 let usage = completion::Usage {
350 input_tokens: response.usage.prompt_tokens as u64,
351 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
352 as u64,
353 total_tokens: response.usage.total_tokens as u64,
354 };
355 tracing::debug!("response choices: {:?}: ", choice);
356 Ok(completion::CompletionResponse {
357 choice,
358 usage,
359 raw_response: response,
360 })
361 } else {
362 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
363 text: content.clone().unwrap_or_else(|| "".to_owned()),
364 }));
365 let usage = completion::Usage {
366 input_tokens: response.usage.prompt_tokens as u64,
367 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
368 as u64,
369 total_tokens: response.usage.total_tokens as u64,
370 };
371 Ok(completion::CompletionResponse {
372 choice,
373 usage,
374 raw_response: response,
375 })
376 }
377 }
378 _ => Err(CompletionError::ResponseError(
380 "Chat response does not include an assistant message".into(),
381 )),
382 }
383 }
384}
385
/// A BigModel chat model bound to a [`Client`].
///
/// Created with `CompletionModel::new` or `Client::completion_model`.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    // Model identifier, sent as `"model"` in every request body.
    pub model: String,
}
391
/// One entry of the request's `tools` array: `{"type": ..., "function": {...}}`.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    // Serialized as `type`. `create_completion_request` always sets it to "function".
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}
400
/// Name, description and parameters of a tool offered to the model.
///
/// `create_completion_request` copies all three fields unchanged from rig's tool
/// definitions.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
408
409impl CompletionModel {
410 pub fn new(client: Client, model: &str) -> Self {
411 Self {
412 client,
413 model: model.to_string(),
414 }
415 }
416
417 fn create_completion_request(
418 &self,
419 completion_request: CompletionRequest,
420 ) -> Result<Value, CompletionError> {
421 let mut partial_history = vec![];
423 if let Some(docs) = completion_request.normalized_documents() {
424 partial_history.push(docs);
425 }
426 partial_history.extend(completion_request.chat_history);
427
428 let mut full_history: Vec<Message> = completion_request
430 .preamble
431 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
432
433 full_history.extend(
435 partial_history
436 .into_iter()
437 .map(message::Message::try_into)
438 .collect::<Result<Vec<Message>, _>>()?,
439 );
440
441 let request = if completion_request.tools.is_empty() {
442 json!({
443 "model": self.model,
444 "messages": full_history,
445 "temperature": completion_request.temperature,
446 })
447 } else {
448 let tools = completion_request
450 .tools
451 .into_iter()
452 .map(|item| {
453 let custom_function = Function {
454 name: item.name,
455 description: item.description,
456 parameters: item.parameters,
457 };
458 CustomFunctionDefinition {
459 type_field: "function".to_string(),
460 function: custom_function,
461 }
462 })
463 .collect::<Vec<_>>();
464
465 tracing::debug!("tools: {:?}", tools);
466
467 json!({
468 "model": self.model,
469 "messages": full_history,
470 "temperature": completion_request.temperature,
471 "tools": tools,
472 "tool_choice": "auto",
473 })
474 };
475
476 let request = if let Some(params) = completion_request.additional_params {
477 json_utils::merge(request, params)
478 } else {
479 request
480 };
481
482 Ok(request)
483 }
484}
485
486impl completion::CompletionModel for CompletionModel {
488 type Response = CompletionResponse;
489 type StreamingResponse = openai::StreamingCompletionResponse;
490
491 async fn completion(
492 &self,
493 completion_request: CompletionRequest,
494 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
495 tracing::debug!("create_completion_request========");
496 let request = self.create_completion_request(completion_request)?;
497
498 tracing::debug!(
499 "request: \r\n {}",
500 serde_json::to_string_pretty(&request).unwrap()
501 );
502
503 let response = self
504 .client
505 .post("/chat/completions")
506 .json(&request)
507 .send()
508 .await
509 .map_err(|e| http_client::Error::Instance(e.into()))?;
510
511 if response.status().is_success() {
512 let data: Value = response.json().await.expect("api error");
513 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
514 let data: ApiResponse<CompletionResponse> =
515 serde_json::from_value(data).expect("deserialize completion response");
516 match data {
517 ApiResponse::Ok(response) => {
518 tracing::info!(target: "rig",
519 "bigmodel completion token usage: {:?}",
520 response.usage
521 );
522 response.try_into()
523 }
524 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
525 }
526 } else {
527 Err(CompletionError::ProviderError(
528 response
529 .text()
530 .await
531 .map_err(|e| http_client::Error::Instance(e.into()))?,
532 ))
533 }
534 }
535
536 async fn stream(
537 &self,
538 request: CompletionRequest,
539 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
540 let mut request = self.create_completion_request(request)?;
541
542 request = json_utils::merge(request, json!({"stream": true}));
543
544 let builder = self.client.post("/chat/completions").json(&request);
545
546 send_compatible_streaming_request(builder).await
547 }
548}