1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
/// Default base URL for the BigModel (Zhipu) v4 API, used by [`Client::new`].
///
/// Note the trailing slash: request paths are appended to this value by `Client::post`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23 base_url: String,
24 http_client: reqwest::Client,
25}
26
27impl Client {
28 pub fn new(api_key: &str) -> Self {
29 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30 }
31
32 pub fn from_url(api_key: &str, base_url: &str) -> Self {
33 Self {
34 base_url: base_url.to_string(),
35 http_client: reqwest::Client::builder()
36 .default_headers({
37 let mut headers = reqwest::header::HeaderMap::new();
38 headers.insert(
39 "Authorization",
40 format!("Bearer {api_key}")
41 .parse()
42 .expect("Bearer token should parse"),
43 );
44 headers
45 })
46 .build()
47 .expect("bigmodel reqwest client should build"),
48 }
49 }
50
51
52
53 fn post(&self, path: &str) -> reqwest::RequestBuilder {
54 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55 self.http_client.post(url)
56 }
57
58 pub fn completion_model(&self, model: &str) -> CompletionModel {
59 CompletionModel::new(self.clone(), model)
60 }
61
62
63
64 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
66 &self,
67 model: &str,
68 ) -> ExtractorBuilder<T, CompletionModel> {
69 ExtractorBuilder::new(self.completion_model(model))
70 }
71}
72
73impl ProviderClient for Client {
74 fn from_env() -> Self
75 where
76 Self: Sized
77 {
78 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
79 Self::new(&api_key)
80 }
81}
82
// Empty impls: these rely entirely on the traits' default methods. Presumably this marks
// the client as not offering transcription / embedding models. Confirm against rig's
// `AsTranscription` / `AsEmbeddings` documentation.
impl AsTranscription for Client {}

impl AsEmbeddings for Client {}
86
87
88
89impl CompletionClient for Client {
90 type CompletionModel = CompletionModel;
91
92 fn completion_model(&self, model: &str) -> Self::CompletionModel {
93 CompletionModel::new(self.clone(), model)
94 }
95}
96
/// Error body returned by the API; only the `message` field is read.
///
/// NOTE(review): this expects `message` at the top level of the JSON body. Confirm this
/// matches BigModel's actual error schema.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` tries the variants in declaration order. A body is therefore
/// treated as an error only if it fails to deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
108
/// Model identifier for `glm-4-flash`, usable with [`Client::completion_model`].
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
113
/// Raw (non-streaming) chat-completion response body from BigModel.
///
/// Kept as the `raw_response` of the converted rig completion response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    /// Candidate completions; only the first one is converted for rig.
    pub choices: Vec<Choice>,
    /// Creation time as reported by the API (presumably a Unix timestamp; confirm).
    pub created: i64,
    pub id: String,
    pub model: String,
    // Explicit rename: `rename_all = "camelCase"` would otherwise expect `requestId`.
    #[serde(rename = "request_id")]
    pub request_id: String,
    /// Token accounting for the request; logged by `completion`.
    pub usage: Usage,
}
125
/// A chat message in BigModel's wire format, tagged by its `role` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "user"`: plain-text user input.
    User {
        content: String,
    },
    /// `role: "assistant"`: optional text plus any tool calls the model requested.
    Assistant {
        content: Option<String>,
        // `default` accepts a missing `tool_calls` field. `null_or_vec` presumably also
        // maps an explicit `null` to an empty Vec (confirm in `json_utils`).
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "system"`: system prompt; the request preamble is sent as one of these.
    System {
        content: String,
    },
    /// `role: "tool"`: the result of a tool call, linked to it by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
146
147impl Message {
148 pub fn system(content: &str) -> Message {
149 Message::System {
150 content: content.to_owned(),
151 }
152 }
153}
154
/// Text-only payload extracted from one part of a rig tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
159impl TryFrom<message::ToolResultContent> for ToolResultContent {
160 type Error = MessageError;
161 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
162 let message::ToolResultContent::Text(Text { text }) = value else {
163 return Err(MessageError::ConversionError(
164 "Non-text tool results not supported".into(),
165 ));
166 };
167
168 Ok(Self { text })
169 }
170}
171
172impl TryFrom<message::Message> for Message {
173 type Error = MessageError;
174
175 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
176 Ok(match message {
177 message::Message::User { content } => {
178 let mut texts = Vec::new();
179 let mut images = Vec::new();
180
181 for uc in content.into_iter() {
182 match uc {
183 message::UserContent::Text(message::Text { text }) => texts.push(text),
184 message::UserContent::Image(img) => images.push(img.data),
185 message::UserContent::ToolResult(result) => {
186 let content = result
187 .content
188 .into_iter()
189 .map(ToolResultContent::try_from)
190 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
191
192 let content = OneOrMany::many(content).map_err(|x| {
193 MessageError::ConversionError(format!(
194 "Couldn't make a OneOrMany from a list of tool results: {x}"
195 ))
196 })?;
197
198 return Ok(Message::ToolResult {
199 tool_call_id: result.id,
200 content: content.first().text,
201 });
202 }
203 _ => {}
204 }
205 }
206
207 let collapsed_content = texts.join(" ");
208
209 Message::User {
210 content: collapsed_content,
211 }
212 }
213 message::Message::Assistant { content, .. } => {
214 let mut texts = Vec::new();
215 let mut tool_calls = Vec::new();
216
217 for ac in content.into_iter() {
218 match ac {
219 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
220 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
221 }
222 }
223
224 let collapsed_content = texts.join(" ");
225
226 Message::Assistant {
227 content: Some(collapsed_content),
228 tool_calls,
229 }
230 }
231 })
232 }
233}
234
235impl From<message::ToolResult> for Message {
236 fn from(tool_result: message::ToolResult) -> Self {
237 let content = match tool_result.content.first() {
238 message::ToolResultContent::Text(text) => text.text,
239 message::ToolResultContent::Image(_) => String::from("[Image]"),
240 };
241
242 Message::ToolResult {
243 tool_call_id: tool_result.id,
244 content,
245 }
246 }
247}
248
249#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
250#[serde(rename_all = "camelCase")]
251pub struct ToolCall {
252 pub function: CallFunction,
253 pub id: String,
254 pub index: usize,
255 #[serde(default)]
256 pub r#type: ToolType,
257}
258
259impl From<message::ToolCall> for ToolCall {
260 fn from(tool_call: message::ToolCall) -> Self {
261 Self {
262 id: tool_call.id,
263 index: 0,
264 r#type: ToolType::Function,
265 function: CallFunction {
266 name: tool_call.function.name,
267 arguments: tool_call.function.arguments,
268 },
269 }
270 }
271}
272
/// Kind of tool a call targets; `function` is the only kind modelled here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Name and arguments of the function a tool call invokes.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    // `stringified_json` presumably converts between a JSON-encoded string on the wire
    // and a structured `Value` here. Confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
286
/// Message author role.
///
/// NOTE(review): not referenced elsewhere in this view; [`Message`] encodes the role
/// through its serde `role` tag instead.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}

/// One candidate completion in a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    // The explicit snake_case rename overrides `rename_all = "camelCase"`.
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    pub index: i64,
    pub message: Message,
}

/// Token usage reported for a completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    // The explicit renames keep the API's snake_case keys despite `rename_all`.
    #[serde(rename = "completion_tokens")]
    pub completion_tokens: i64,
    #[serde(rename = "prompt_tokens")]
    pub prompt_tokens: i64,
    #[serde(rename = "total_tokens")]
    pub total_tokens: i64,
}
314
315impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
316 type Error = CompletionError;
317
318 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
319 let choice = response.choices.first().ok_or_else(|| {
320 CompletionError::ResponseError("Response contained no choices".to_owned())
321 })?;
322
323 match &choice.message {
324 Message::Assistant {
325 tool_calls,
326 content,
327 } => {
328 if !tool_calls.is_empty() {
329 let tool_result = tool_calls
330 .iter()
331 .map(|call| {
332 completion::AssistantContent::tool_call(
333 &call.function.name,
334 &call.function.name,
335 call.function.arguments.clone(),
336 )
337 })
338 .collect::<Vec<_>>();
339
340 let choice = OneOrMany::many(tool_result).map_err(|_| {
341 CompletionError::ResponseError(
342 "Response contained no message or tool call (empty)".to_owned(),
343 )
344 })?;
345 tracing::debug!("response choices: {:?}: ", choice);
352 Ok(completion::CompletionResponse {
353 choice,
354 raw_response: response,
356 })
357 } else {
358 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
359 text: content.clone().unwrap_or_else(|| "".to_owned()),
360 }));
361 Ok(completion::CompletionResponse {
368 choice,
369 raw_response: response,
371 })
372 }
373 }
374 _ => Err(CompletionError::ResponseError(
376 "Chat response does not include an assistant message".into(),
377 )),
378 }
379 }
380}
381
/// A BigModel chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier sent as `"model"` in each request (e.g. `"glm-4-flash"`).
    pub model: String,
}
387
388
389
/// Entry of the request's `tools` array: a `type` tag plus the function definition.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    /// Serialized as `type`; set to `"function"` when built from rig tool definitions.
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}

/// Name, description and parameter schema of a callable function.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    /// Parameter schema copied from the rig tool definition (presumably JSON Schema).
    pub parameters: serde_json::Value,
}
406
407impl CompletionModel {
408 pub fn new(client: Client, model: &str) -> Self {
409 Self {
410 client,
411 model: model.to_string(),
412 }
413 }
414
415 fn create_completion_request(
416 &self,
417 completion_request: CompletionRequest,
418 ) -> Result<Value, CompletionError> {
419 let mut partial_history = vec![];
421 if let Some(docs) = completion_request.normalized_documents() {
422 partial_history.push(docs);
423 }
424 partial_history.extend(completion_request.chat_history);
425
426 let mut full_history: Vec<Message> = completion_request
428 .preamble
429 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
430
431 full_history.extend(
433 partial_history
434 .into_iter()
435 .map(message::Message::try_into)
436 .collect::<Result<Vec<Message>, _>>()?,
437 );
438
439 let request = if completion_request.tools.is_empty() {
440 json!({
441 "model": self.model,
442 "messages": full_history,
443 "temperature": completion_request.temperature,
444 })
445 } else {
446 let tools = completion_request
448 .tools
449 .into_iter()
450 .map(|item| {
451 let custom_function = Function {
452 name: item.name,
453 description: item.description,
454 parameters: item.parameters,
455 };
456 CustomFunctionDefinition {
457 type_field: "function".to_string(),
458 function: custom_function,
459 }
460 })
461 .collect::<Vec<_>>();
462
463 tracing::debug!("tools: {:?}", tools);
464
465 json!({
466 "model": self.model,
467 "messages": full_history,
468 "temperature": completion_request.temperature,
469 "tools": tools,
470 "tool_choice": "auto",
471 })
472 };
473
474 let request = if let Some(params) = completion_request.additional_params {
475 json_utils::merge(request, params)
476 } else {
477 request
478 };
479
480 Ok(request)
481 }
482}
483
484impl completion::CompletionModel for CompletionModel {
486 type Response = CompletionResponse;
487 type StreamingResponse = openai::StreamingCompletionResponse;
488
489 async fn completion(
490 &self,
491 completion_request: CompletionRequest,
492 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
493 tracing::debug!("create_completion_request========");
494 let request = self.create_completion_request(completion_request)?;
495
496 tracing::debug!(
497 "request: \r\n {}",
498 serde_json::to_string_pretty(&request).unwrap()
499 );
500
501 let response = self
502 .client
503 .post("/chat/completions")
504 .json(&request)
505 .send()
506 .await?;
507
508 if response.status().is_success() {
509 let data: Value = response.json().await.expect("api error");
510 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
511 let data: ApiResponse<CompletionResponse> =
512 serde_json::from_value(data).expect("deserialize completion response");
513 match data {
514 ApiResponse::Ok(response) => {
515 tracing::info!(target: "rig",
516 "bigmodel completion token usage: {:?}",
517 response.usage
518 );
519 response.try_into()
520 }
521 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
522 }
523 } else {
524 Err(CompletionError::ProviderError(response.text().await?))
525 }
526 }
527
528 async fn stream(
529 &self,
530 request: CompletionRequest,
531 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
532 let mut request = self.create_completion_request(request)?;
533
534 request = json_utils::merge(request, json!({"stream": true}));
535
536 let builder = self.client.post("/chat/completions").json(&request);
537
538 send_compatible_streaming_request(builder).await
539 }
540}
541
542