1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message, client};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
/// Default base URL of the BigModel API; used by `Client::new`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
/// HTTP client for the BigModel API.
///
/// Holds the API base URL and a `reqwest` client whose default headers carry the
/// bearer token (configured in `Client::from_url`).
#[derive(Clone, Debug)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// Pre-configured HTTP client (authorization header included).
    http_client: reqwest::Client,
}
26
27impl Client {
28 pub fn new(api_key: &str) -> Self {
29 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30 }
31
32 pub fn from_url(api_key: &str, base_url: &str) -> Self {
33 Self {
34 base_url: base_url.to_string(),
35 http_client: reqwest::Client::builder()
36 .default_headers({
37 let mut headers = reqwest::header::HeaderMap::new();
38 headers.insert(
39 "Authorization",
40 format!("Bearer {api_key}")
41 .parse()
42 .expect("Bearer token should parse"),
43 );
44 headers
45 })
46 .build()
47 .expect("bigmodel reqwest client should build"),
48 }
49 }
50
51
52
53 fn post(&self, path: &str) -> reqwest::RequestBuilder {
54 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55 self.http_client.post(url)
56 }
57
58 pub fn completion_model(&self, model: &str) -> CompletionModel {
59 CompletionModel::new(self.clone(), model)
60 }
61
62
63
64 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
66 &self,
67 model: &str,
68 ) -> ExtractorBuilder<T, CompletionModel> {
69 ExtractorBuilder::new(self.completion_model(model))
70 }
71}
72
73impl ProviderClient for Client {
74 fn from_env() -> Self
75 where
76 Self: Sized
77 {
78 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
79 Self::new(&api_key)
80 }
81
82 fn from_val(input: ProviderValue) -> Self
83 where
84 Self: Sized
85 {
86 let client::ProviderValue::Simple(api_key) = input else {
87 panic!("Incorrect provider value type")
88 };
89 Self::new(&api_key)
90 }
91}
92
93impl AsTranscription for Client {}
94
95impl AsEmbeddings for Client {}
96
97
98
99impl CompletionClient for Client {
100 type CompletionModel = CompletionModel;
101
102 fn completion_model(&self, model: &str) -> Self::CompletionModel {
103 CompletionModel::new(self.clone(), model)
104 }
105}
106
/// Error body returned by the BigModel API.
// NOTE(review): this expects a top-level `message` field; confirm it matches the
// provider's actual error JSON shape (it may be nested under an `error` object).
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `untagged`: serde tries `Ok(T)` first and only falls back to `Err` when the body
/// does not deserialize as `T`, so the variant order is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
118
/// Model identifier for `glm-4-flash`, usable with `Client::completion_model`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
123
124#[derive(Debug, Deserialize,Serialize)]
125#[serde(rename_all = "camelCase")]
126pub struct CompletionResponse {
127 pub choices: Vec<Choice>,
128 pub created: i64,
129 pub id: String,
130 pub model: String,
131 #[serde(rename = "request_id")]
132 pub request_id: String,
133 pub usage: Usage,
134}
135
/// A chat message in the BigModel wire format.
///
/// Serialized as an internally tagged object: the variant name becomes the `role`
/// field (`"user"`, `"assistant"`, `"system"`, or `"tool"` for [`Message::ToolResult`]).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Text sent by the user.
    User {
        content: String,
    },
    /// Model reply. `content` is optional; `tool_calls` falls back to an empty list when
    /// missing (presumably also when `null`, via `json_utils::null_or_vec`).
    Assistant {
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System instructions (the request preamble).
    System {
        content: String,
    },
    /// Output of a tool, linked to the originating call through `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
156
157impl Message {
158 pub fn system(content: &str) -> Message {
159 Message::System {
160 content: content.to_owned(),
161 }
162 }
163}
164
165#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
166pub struct ToolResultContent {
167 text: String,
168}
169impl TryFrom<message::ToolResultContent> for ToolResultContent {
170 type Error = MessageError;
171 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
172 let message::ToolResultContent::Text(Text { text }) = value else {
173 return Err(MessageError::ConversionError(
174 "Non-text tool results not supported".into(),
175 ));
176 };
177
178 Ok(Self { text })
179 }
180}
181
182impl TryFrom<message::Message> for Message {
183 type Error = MessageError;
184
185 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
186 Ok(match message {
187 message::Message::User { content } => {
188 let mut texts = Vec::new();
189 let mut images = Vec::new();
190
191 for uc in content.into_iter() {
192 match uc {
193 message::UserContent::Text(message::Text { text }) => texts.push(text),
194 message::UserContent::Image(img) => images.push(img.data),
195 message::UserContent::ToolResult(result) => {
196 let content = result
197 .content
198 .into_iter()
199 .map(ToolResultContent::try_from)
200 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
201
202 let content = OneOrMany::many(content).map_err(|x| {
203 MessageError::ConversionError(format!(
204 "Couldn't make a OneOrMany from a list of tool results: {x}"
205 ))
206 })?;
207
208 return Ok(Message::ToolResult {
209 tool_call_id: result.id,
210 content: content.first().text,
211 });
212 }
213 _ => {}
214 }
215 }
216
217 let collapsed_content = texts.join(" ");
218
219 Message::User {
220 content: collapsed_content,
221 }
222 }
223 message::Message::Assistant { content, .. } => {
224 let mut texts = Vec::new();
225 let mut tool_calls = Vec::new();
226
227 for ac in content.into_iter() {
228 match ac {
229 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
230 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
231 _ => {}
232 }
233 }
234
235 let collapsed_content = texts.join(" ");
236
237 Message::Assistant {
238 content: Some(collapsed_content),
239 tool_calls,
240 }
241 }
242 })
243 }
244}
245
246impl From<message::ToolResult> for Message {
247 fn from(tool_result: message::ToolResult) -> Self {
248 let content = match tool_result.content.first() {
249 message::ToolResultContent::Text(text) => text.text,
250 message::ToolResultContent::Image(_) => String::from("[Image]"),
251 };
252
253 Message::ToolResult {
254 tool_call_id: tool_result.id,
255 content,
256 }
257 }
258}
259
/// A tool (function) call made by the assistant, as it appears in `tool_calls`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    /// Name and arguments of the function to call.
    pub function: CallFunction,
    /// Call id; tool results refer back to it through `tool_call_id`.
    pub id: String,
    /// Index of the call (presumably its position in the `tool_calls` list).
    pub index: usize,
    /// Kind of tool call; defaults to `function` when absent from the JSON.
    #[serde(default)]
    pub r#type: ToolType,
}
269
270impl From<message::ToolCall> for ToolCall {
271 fn from(tool_call: message::ToolCall) -> Self {
272 Self {
273 id: tool_call.id,
274 index: 0,
275 r#type: ToolType::Function,
276 function: CallFunction {
277 name: tool_call.function.name,
278 arguments: tool_call.function.arguments,
279 },
280 }
281 }
282}
283
/// Kind of tool call. `function` (serialized lowercase) is the only kind and the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
290
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    // Held as a JSON value in Rust but (de)serialized through
    // `json_utils::stringified_json` — presumably a JSON-encoded string on the wire.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
297
/// Chat role names, serialized lowercase.
// NOTE(review): not referenced in this section — `Message` carries its role through
// its serde tag instead. Possibly dead code; confirm before removing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
305
306#[derive(Debug, Serialize, Deserialize)]
307#[serde(rename_all = "camelCase")]
308pub struct Choice {
309 #[serde(rename = "finish_reason")]
310 pub finish_reason: String,
311 pub index: i64,
312 pub message: Message,
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize)]
316#[serde(rename_all = "camelCase")]
317pub struct Usage {
318 #[serde(rename = "completion_tokens")]
319 pub completion_tokens: i64,
320 #[serde(rename = "prompt_tokens")]
321 pub prompt_tokens: i64,
322 #[serde(rename = "total_tokens")]
323 pub total_tokens: i64,
324}
325
326impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
327 type Error = CompletionError;
328
329 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
330 let choice = response.choices.first().ok_or_else(|| {
331 CompletionError::ResponseError("Response contained no choices".to_owned())
332 })?;
333
334 match &choice.message {
335 Message::Assistant {
336 tool_calls,
337 content,
338 } => {
339 if !tool_calls.is_empty() {
340 let tool_result = tool_calls
341 .iter()
342 .map(|call| {
343 completion::AssistantContent::tool_call(
344 &call.function.name,
345 &call.function.name,
346 call.function.arguments.clone(),
347 )
348 })
349 .collect::<Vec<_>>();
350
351 let choice = OneOrMany::many(tool_result).map_err(|_| {
352 CompletionError::ResponseError(
353 "Response contained no message or tool call (empty)".to_owned(),
354 )
355 })?;
356 let usage = completion::Usage {
357 input_tokens: response.usage.prompt_tokens as u64,
358 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
359 as u64,
360 total_tokens: response.usage.total_tokens as u64,
361 };
362 tracing::debug!("response choices: {:?}: ", choice);
363 Ok(completion::CompletionResponse {
364 choice,
365 usage,
366 raw_response: response,
367 })
368 } else {
369 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
370 text: content.clone().unwrap_or_else(|| "".to_owned()),
371 }));
372 let usage = completion::Usage {
373 input_tokens: response.usage.prompt_tokens as u64,
374 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
375 as u64,
376 total_tokens: response.usage.total_tokens as u64,
377 };
378 Ok(completion::CompletionResponse {
379 choice,
380 usage,
381 raw_response: response,
382 })
383 }
384 }
385 _ => Err(CompletionError::ResponseError(
387 "Chat response does not include an assistant message".into(),
388 )),
389 }
390 }
391}
392
/// A BigModel chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    client: Client,
    /// Model identifier sent as `model` in each request (e.g. `BIGMODEL_GLM_4_FLASH`).
    pub model: String,
}
398
399
400
/// One entry of the request's `tools` array: `{"type": ..., "function": {...}}`.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Serialized as `type`; set to `"function"` when the request is built.
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}
409
/// Name, description and parameter description of a tool offered to the model.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    // Copied from rig's tool definition; presumably a JSON Schema object — confirm.
    pub parameters: serde_json::Value,
}
417
418impl CompletionModel {
419 pub fn new(client: Client, model: &str) -> Self {
420 Self {
421 client,
422 model: model.to_string(),
423 }
424 }
425
426 fn create_completion_request(
427 &self,
428 completion_request: CompletionRequest,
429 ) -> Result<Value, CompletionError> {
430 let mut partial_history = vec![];
432 if let Some(docs) = completion_request.normalized_documents() {
433 partial_history.push(docs);
434 }
435 partial_history.extend(completion_request.chat_history);
436
437 let mut full_history: Vec<Message> = completion_request
439 .preamble
440 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
441
442 full_history.extend(
444 partial_history
445 .into_iter()
446 .map(message::Message::try_into)
447 .collect::<Result<Vec<Message>, _>>()?,
448 );
449
450 let request = if completion_request.tools.is_empty() {
451 json!({
452 "model": self.model,
453 "messages": full_history,
454 "temperature": completion_request.temperature,
455 })
456 } else {
457 let tools = completion_request
459 .tools
460 .into_iter()
461 .map(|item| {
462 let custom_function = Function {
463 name: item.name,
464 description: item.description,
465 parameters: item.parameters,
466 };
467 CustomFunctionDefinition {
468 type_field: "function".to_string(),
469 function: custom_function,
470 }
471 })
472 .collect::<Vec<_>>();
473
474 tracing::debug!("tools: {:?}", tools);
475
476 json!({
477 "model": self.model,
478 "messages": full_history,
479 "temperature": completion_request.temperature,
480 "tools": tools,
481 "tool_choice": "auto",
482 })
483 };
484
485 let request = if let Some(params) = completion_request.additional_params {
486 json_utils::merge(request, params)
487 } else {
488 request
489 };
490
491 Ok(request)
492 }
493}
494
495impl completion::CompletionModel for CompletionModel {
497 type Response = CompletionResponse;
498 type StreamingResponse = openai::StreamingCompletionResponse;
499
500 async fn completion(
501 &self,
502 completion_request: CompletionRequest,
503 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
504 tracing::debug!("create_completion_request========");
505 let request = self.create_completion_request(completion_request)?;
506
507 tracing::debug!(
508 "request: \r\n {}",
509 serde_json::to_string_pretty(&request).unwrap()
510 );
511
512 let response = self
513 .client
514 .post("/chat/completions")
515 .json(&request)
516 .send()
517 .await?;
518
519 if response.status().is_success() {
520 let data: Value = response.json().await.expect("api error");
521 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
522 let data: ApiResponse<CompletionResponse> =
523 serde_json::from_value(data).expect("deserialize completion response");
524 match data {
525 ApiResponse::Ok(response) => {
526 tracing::info!(target: "rig",
527 "bigmodel completion token usage: {:?}",
528 response.usage
529 );
530 response.try_into()
531 }
532 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
533 }
534 } else {
535 Err(CompletionError::ProviderError(response.text().await?))
536 }
537 }
538
539 async fn stream(
540 &self,
541 request: CompletionRequest,
542 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
543 let mut request = self.create_completion_request(request)?;
544
545 request = json_utils::merge(request, json!({"stream": true}));
546
547 let builder = self.client.post("/chat/completions").json(&request);
548
549 send_compatible_streaming_request(builder).await
550 }
551}
552
553