1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message, client};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
16const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23 base_url: String,
24 http_client: reqwest::Client,
25}
26
27impl Client {
28 pub fn new(api_key: &str) -> Self {
29 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30 }
31
32 pub fn from_url(api_key: &str, base_url: &str) -> Self {
33 Self {
34 base_url: base_url.to_string(),
35 http_client: reqwest::Client::builder()
36 .default_headers({
37 let mut headers = reqwest::header::HeaderMap::new();
38 headers.insert(
39 "Authorization",
40 format!("Bearer {api_key}")
41 .parse()
42 .expect("Bearer token should parse"),
43 );
44 headers
45 })
46 .build()
47 .expect("bigmodel reqwest client should build"),
48 }
49 }
50
51
52
53 fn post(&self, path: &str) -> reqwest::RequestBuilder {
54 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55 self.http_client.post(url)
56 }
57
58 pub fn completion_model(&self, model: &str) -> CompletionModel {
59 CompletionModel::new(self.clone(), model)
60 }
61
62
63
64 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
66 &self,
67 model: &str,
68 ) -> ExtractorBuilder<T, CompletionModel> {
69 ExtractorBuilder::new(self.completion_model(model))
70 }
71}
72
73impl ProviderClient for Client {
74 fn from_env() -> Self
75 where
76 Self: Sized
77 {
78 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
79 Self::new(&api_key)
80 }
81
82 fn from_val(input: ProviderValue) -> Self
83 where
84 Self: Sized
85 {
86 let client::ProviderValue::Simple(api_key) = input else {
87 panic!("Incorrect provider value type")
88 };
89 Self::new(&api_key)
90 }
91}
92
93impl AsTranscription for Client {}
94
95impl AsEmbeddings for Client {}
96
97
98
99impl CompletionClient for Client {
100 type CompletionModel = CompletionModel;
101
102 fn completion_model(&self, model: &str) -> Self::CompletionModel {
103 CompletionModel::new(self.clone(), model)
104 }
105}
106
107#[derive(Debug, Deserialize)]
108struct ApiErrorResponse {
109 message: String,
110}
111
112#[derive(Debug, Deserialize)]
113#[serde(untagged)]
114enum ApiResponse<T> {
115 Ok(T),
116 Err(ApiErrorResponse),
117}
118
119pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
123
124#[derive(Debug, Deserialize)]
125#[serde(rename_all = "camelCase")]
126pub struct CompletionResponse {
127 pub choices: Vec<Choice>,
128 pub created: i64,
129 pub id: String,
130 pub model: String,
131 #[serde(rename = "request_id")]
132 pub request_id: String,
133 pub usage: Usage,
134}
135
136#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
137#[serde(tag = "role", rename_all = "lowercase")]
138pub enum Message {
139 User {
140 content: String,
141 },
142 Assistant {
143 content: Option<String>,
144 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
145 tool_calls: Vec<ToolCall>,
146 },
147 System {
148 content: String,
149 },
150 #[serde(rename = "tool")]
151 ToolResult {
152 tool_call_id: String,
153 content: String,
154 },
155}
156
157impl Message {
158 pub fn system(content: &str) -> Message {
159 Message::System {
160 content: content.to_owned(),
161 }
162 }
163}
164
165#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
166pub struct ToolResultContent {
167 text: String,
168}
169impl TryFrom<message::ToolResultContent> for ToolResultContent {
170 type Error = MessageError;
171 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
172 let message::ToolResultContent::Text(Text { text }) = value else {
173 return Err(MessageError::ConversionError(
174 "Non-text tool results not supported".into(),
175 ));
176 };
177
178 Ok(Self { text })
179 }
180}
181
182impl TryFrom<message::Message> for Message {
183 type Error = MessageError;
184
185 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
186 Ok(match message {
187 message::Message::User { content } => {
188 let mut texts = Vec::new();
189 let mut images = Vec::new();
190
191 for uc in content.into_iter() {
192 match uc {
193 message::UserContent::Text(message::Text { text }) => texts.push(text),
194 message::UserContent::Image(img) => images.push(img.data),
195 message::UserContent::ToolResult(result) => {
196 let content = result
197 .content
198 .into_iter()
199 .map(ToolResultContent::try_from)
200 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
201
202 let content = OneOrMany::many(content).map_err(|x| {
203 MessageError::ConversionError(format!(
204 "Couldn't make a OneOrMany from a list of tool results: {x}"
205 ))
206 })?;
207
208 return Ok(Message::ToolResult {
209 tool_call_id: result.id,
210 content: content.first().text,
211 });
212 }
213 _ => {}
214 }
215 }
216
217 let collapsed_content = texts.join(" ");
218
219 Message::User {
220 content: collapsed_content,
221 }
222 }
223 message::Message::Assistant { content, .. } => {
224 let mut texts = Vec::new();
225 let mut tool_calls = Vec::new();
226
227 for ac in content.into_iter() {
228 match ac {
229 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
230 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
231 }
232 }
233
234 let collapsed_content = texts.join(" ");
235
236 Message::Assistant {
237 content: Some(collapsed_content),
238 tool_calls,
239 }
240 }
241 })
242 }
243}
244
245impl From<message::ToolResult> for Message {
246 fn from(tool_result: message::ToolResult) -> Self {
247 let content = match tool_result.content.first() {
248 message::ToolResultContent::Text(text) => text.text,
249 message::ToolResultContent::Image(_) => String::from("[Image]"),
250 };
251
252 Message::ToolResult {
253 tool_call_id: tool_result.id,
254 content,
255 }
256 }
257}
258
259#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
260#[serde(rename_all = "camelCase")]
261pub struct ToolCall {
262 pub function: CallFunction,
263 pub id: String,
264 pub index: usize,
265 #[serde(default)]
266 pub r#type: ToolType,
267}
268
269impl From<message::ToolCall> for ToolCall {
270 fn from(tool_call: message::ToolCall) -> Self {
271 Self {
272 id: tool_call.id,
273 index: 0,
274 r#type: ToolType::Function,
275 function: CallFunction {
276 name: tool_call.function.name,
277 arguments: tool_call.function.arguments,
278 },
279 }
280 }
281}
282
283#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
284#[serde(rename_all = "lowercase")]
285pub enum ToolType {
286 #[default]
287 Function,
288}
289
290#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
291pub struct CallFunction {
292 pub name: String,
293 #[serde(with = "json_utils::stringified_json")]
294 pub arguments: serde_json::Value,
295}
296
297#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
298#[serde(rename_all = "lowercase")]
299pub enum Role {
300 System,
301 User,
302 Assistant,
303}
304
305#[derive(Debug, Serialize, Deserialize)]
306#[serde(rename_all = "camelCase")]
307pub struct Choice {
308 #[serde(rename = "finish_reason")]
309 pub finish_reason: String,
310 pub index: i64,
311 pub message: Message,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
315#[serde(rename_all = "camelCase")]
316pub struct Usage {
317 #[serde(rename = "completion_tokens")]
318 pub completion_tokens: i64,
319 #[serde(rename = "prompt_tokens")]
320 pub prompt_tokens: i64,
321 #[serde(rename = "total_tokens")]
322 pub total_tokens: i64,
323}
324
325impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
326 type Error = CompletionError;
327
328 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
329 let choice = response.choices.first().ok_or_else(|| {
330 CompletionError::ResponseError("Response contained no choices".to_owned())
331 })?;
332
333 match &choice.message {
334 Message::Assistant {
335 tool_calls,
336 content,
337 } => {
338 if !tool_calls.is_empty() {
339 let tool_result = tool_calls
340 .iter()
341 .map(|call| {
342 completion::AssistantContent::tool_call(
343 &call.function.name,
344 &call.function.name,
345 call.function.arguments.clone(),
346 )
347 })
348 .collect::<Vec<_>>();
349
350 let choice = OneOrMany::many(tool_result).map_err(|_| {
351 CompletionError::ResponseError(
352 "Response contained no message or tool call (empty)".to_owned(),
353 )
354 })?;
355 let usage = completion::Usage {
356 input_tokens: response.usage.prompt_tokens as u64,
357 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
358 as u64,
359 total_tokens: response.usage.total_tokens as u64,
360 };
361 tracing::debug!("response choices: {:?}: ", choice);
362 Ok(completion::CompletionResponse {
363 choice,
364 usage,
365 raw_response: response,
366 })
367 } else {
368 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
369 text: content.clone().unwrap_or_else(|| "".to_owned()),
370 }));
371 let usage = completion::Usage {
372 input_tokens: response.usage.prompt_tokens as u64,
373 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
374 as u64,
375 total_tokens: response.usage.total_tokens as u64,
376 };
377 Ok(completion::CompletionResponse {
378 choice,
379 usage,
380 raw_response: response,
381 })
382 }
383 }
384 _ => Err(CompletionError::ResponseError(
386 "Chat response does not include an assistant message".into(),
387 )),
388 }
389 }
390}
391
392#[derive(Clone)]
393pub struct CompletionModel {
394 client: Client,
395 pub model: String,
396}
397
398
399
400#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
402#[serde(rename_all = "camelCase")]
403pub struct CustomFunctionDefinition {
404 #[serde(rename = "type")]
405 pub type_field: String,
406 pub function: Function,
407}
408
409#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
410#[serde(rename_all = "camelCase")]
411pub struct Function {
412 pub name: String,
413 pub description: String,
414 pub parameters: serde_json::Value,
415}
416
417impl CompletionModel {
418 pub fn new(client: Client, model: &str) -> Self {
419 Self {
420 client,
421 model: model.to_string(),
422 }
423 }
424
425 fn create_completion_request(
426 &self,
427 completion_request: CompletionRequest,
428 ) -> Result<Value, CompletionError> {
429 let mut partial_history = vec![];
431 if let Some(docs) = completion_request.normalized_documents() {
432 partial_history.push(docs);
433 }
434 partial_history.extend(completion_request.chat_history);
435
436 let mut full_history: Vec<Message> = completion_request
438 .preamble
439 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
440
441 full_history.extend(
443 partial_history
444 .into_iter()
445 .map(message::Message::try_into)
446 .collect::<Result<Vec<Message>, _>>()?,
447 );
448
449 let request = if completion_request.tools.is_empty() {
450 json!({
451 "model": self.model,
452 "messages": full_history,
453 "temperature": completion_request.temperature,
454 })
455 } else {
456 let tools = completion_request
458 .tools
459 .into_iter()
460 .map(|item| {
461 let custom_function = Function {
462 name: item.name,
463 description: item.description,
464 parameters: item.parameters,
465 };
466 CustomFunctionDefinition {
467 type_field: "function".to_string(),
468 function: custom_function,
469 }
470 })
471 .collect::<Vec<_>>();
472
473 tracing::debug!("tools: {:?}", tools);
474
475 json!({
476 "model": self.model,
477 "messages": full_history,
478 "temperature": completion_request.temperature,
479 "tools": tools,
480 "tool_choice": "auto",
481 })
482 };
483
484 let request = if let Some(params) = completion_request.additional_params {
485 json_utils::merge(request, params)
486 } else {
487 request
488 };
489
490 Ok(request)
491 }
492}
493
494impl completion::CompletionModel for CompletionModel {
496 type Response = CompletionResponse;
497 type StreamingResponse = openai::StreamingCompletionResponse;
498
499 async fn completion(
500 &self,
501 completion_request: CompletionRequest,
502 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
503 tracing::debug!("create_completion_request========");
504 let request = self.create_completion_request(completion_request)?;
505
506 tracing::debug!(
507 "request: \r\n {}",
508 serde_json::to_string_pretty(&request).unwrap()
509 );
510
511 let response = self
512 .client
513 .post("/chat/completions")
514 .json(&request)
515 .send()
516 .await?;
517
518 if response.status().is_success() {
519 let data: Value = response.json().await.expect("api error");
520 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
521 let data: ApiResponse<CompletionResponse> =
522 serde_json::from_value(data).expect("deserialize completion response");
523 match data {
524 ApiResponse::Ok(response) => {
525 tracing::info!(target: "rig",
526 "bigmodel completion token usage: {:?}",
527 response.usage
528 );
529 response.try_into()
530 }
531 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
532 }
533 } else {
534 Err(CompletionError::ProviderError(response.text().await?))
535 }
536 }
537
538 async fn stream(
539 &self,
540 request: CompletionRequest,
541 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
542 let mut request = self.create_completion_request(request)?;
543
544 request = json_utils::merge(request, json!({"stream": true}));
545
546 let builder = self.client.post("/chat/completions").json(&request);
547
548 send_compatible_streaming_request(builder).await
549 }
550}
551
552