1use rig::completion::{CompletionError, CompletionRequest};
2use rig::extractor::ExtractorBuilder;
3use rig::message::{MessageError, Text};
4use rig::providers::openai;
5use rig::{OneOrMany, completion, message};
6use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
/// Default base URL of the BigModel API (`api/paas/v4`); this is the URL
/// `Client::new` hands to `Client::from_url`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
/// HTTP client for the BigModel API.
///
/// Pairs the base URL that request paths are appended to with a
/// `reqwest::Client` whose default headers carry the
/// `Authorization: Bearer <api_key>` token.
#[derive(Clone, Debug)]
pub struct Client {
    /// Base URL onto which every request path is joined.
    base_url: String,
    /// Underlying HTTP client, pre-configured with the bearer-token header.
    http_client: reqwest::Client,
}
26
27impl Client {
28 pub fn new(api_key: &str) -> Self {
29 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30 }
31
32 pub fn from_url(api_key: &str, base_url: &str) -> Self {
33 Self {
34 base_url: base_url.to_string(),
35 http_client: reqwest::Client::builder()
36 .default_headers({
37 let mut headers = reqwest::header::HeaderMap::new();
38 headers.insert(
39 "Authorization",
40 format!("Bearer {api_key}")
41 .parse()
42 .expect("Bearer token should parse"),
43 );
44 headers
45 })
46 .build()
47 .expect("bigmodel reqwest client should build"),
48 }
49 }
50
51
52
53 fn post(&self, path: &str) -> reqwest::RequestBuilder {
54 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
55 self.http_client.post(url)
56 }
57
58 pub fn completion_model(&self, model: &str) -> CompletionModel {
59 CompletionModel::new(self.clone(), model)
60 }
61
62 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
68 &self,
69 model: &str,
70 ) -> ExtractorBuilder<T, CompletionModel> {
71 ExtractorBuilder::new(self.completion_model(model))
72 }
73}
74
75impl ProviderClient for Client {
76 fn from_env() -> Self
77 where
78 Self: Sized
79 {
80 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
81 Self::new(&api_key)
82 }
83}
84
// Empty impl: nothing is overridden, so `Client` gets only whatever default
// items `AsTranscription` itself provides.
impl AsTranscription for Client {}
86
// Empty impl: nothing is overridden, so `Client` gets only whatever default
// items `AsEmbeddings` itself provides.
impl AsEmbeddings for Client {}
88
89
90
91impl CompletionClient for Client {
92 type CompletionModel = CompletionModel;
93
94 fn completion_model(&self, model: &str) -> Self::CompletionModel {
95 CompletionModel::new(self.clone(), model)
96 }
97}
98
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
103
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`: serde tries the variants in declaration order, so
/// the body is first decoded as `T` and only then as an error object.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
110
/// Model identifier string `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
115
116#[derive(Debug, Deserialize)]
117#[serde(rename_all = "camelCase")]
118pub struct CompletionResponse {
119 pub choices: Vec<Choice>,
120 pub created: i64,
121 pub id: String,
122 pub model: String,
123 #[serde(rename = "request_id")]
124 pub request_id: String,
125 pub usage: Usage,
126}
127
/// A chat message in the wire format used by this provider.
///
/// Serialized as an internally tagged object: the `role` key holds the
/// lowercased variant name (`user`, `assistant`, `system`), except for
/// [`Message::ToolResult`], which is written as `tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message authored by the user.
    User {
        content: String,
    },
    /// Message authored by the model, optionally carrying tool calls.
    Assistant {
        content: Option<String>,
        // A missing `tool_calls` key falls back to an empty vec via `default`.
        // NOTE(review): `null_or_vec` presumably also maps an explicit `null`
        // to an empty vec — confirm in `json_utils`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System prompt.
    System {
        content: String,
    },
    /// Result of a tool invocation, tied to the originating call by
    /// `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
148
149impl Message {
150 pub fn system(content: &str) -> Message {
151 Message::System {
152 content: content.to_owned(),
153 }
154 }
155}
156
/// Text-only tool-result content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// The textual result.
    text: String,
}
161impl TryFrom<message::ToolResultContent> for ToolResultContent {
162 type Error = MessageError;
163 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
164 let message::ToolResultContent::Text(Text { text }) = value else {
165 return Err(MessageError::ConversionError(
166 "Non-text tool results not supported".into(),
167 ));
168 };
169
170 Ok(Self { text })
171 }
172}
173
174impl TryFrom<message::Message> for Message {
175 type Error = MessageError;
176
177 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
178 Ok(match message {
179 message::Message::User { content } => {
180 let mut texts = Vec::new();
181 let mut images = Vec::new();
182
183 for uc in content.into_iter() {
184 match uc {
185 message::UserContent::Text(message::Text { text }) => texts.push(text),
186 message::UserContent::Image(img) => images.push(img.data),
187 message::UserContent::ToolResult(result) => {
188 let content = result
189 .content
190 .into_iter()
191 .map(ToolResultContent::try_from)
192 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
193
194 let content = OneOrMany::many(content).map_err(|x| {
195 MessageError::ConversionError(format!(
196 "Couldn't make a OneOrMany from a list of tool results: {x}"
197 ))
198 })?;
199
200 return Ok(Message::ToolResult {
201 tool_call_id: result.id,
202 content: content.first().text,
203 });
204 }
205 _ => {}
206 }
207 }
208
209 let collapsed_content = texts.join(" ");
210
211 Message::User {
212 content: collapsed_content,
213 }
214 }
215 message::Message::Assistant { content, .. } => {
216 let mut texts = Vec::new();
217 let mut tool_calls = Vec::new();
218
219 for ac in content.into_iter() {
220 match ac {
221 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
222 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
223 }
224 }
225
226 let collapsed_content = texts.join(" ");
227
228 Message::Assistant {
229 content: Some(collapsed_content),
230 tool_calls,
231 }
232 }
233 })
234 }
235}
236
237impl From<message::ToolResult> for Message {
238 fn from(tool_result: message::ToolResult) -> Self {
239 let content = match tool_result.content.first() {
240 message::ToolResultContent::Text(text) => text.text,
241 message::ToolResultContent::Image(_) => String::from("[Image]"),
242 };
243
244 Message::ToolResult {
245 tool_call_id: tool_result.id,
246 content,
247 }
248 }
249}
250
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct ToolCall {
254 pub function: CallFunction,
255 pub id: String,
256 pub index: usize,
257 #[serde(default)]
258 pub r#type: ToolType,
259}
260
261impl From<message::ToolCall> for ToolCall {
262 fn from(tool_call: message::ToolCall) -> Self {
263 Self {
264 id: tool_call.id,
265 index: 0,
267 r#type: ToolType::Function,
268 function: CallFunction {
269 name: tool_call.function.name,
270 arguments: tool_call.function.arguments,
271 },
272 }
273 }
274}
275
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase
/// (`function`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function call; the only variant and therefore the default.
    #[default]
    Function,
}
282
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments for the call.
    // NOTE(review): `stringified_json` presumably (de)serializes this value
    // as a JSON-encoded string — confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
289
/// Message author role, serialized in lowercase (`system`, `user`,
/// `assistant`).
// NOTE(review): not referenced anywhere else in the visible part of this
// file — confirm whether external callers still use it.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
297
298#[derive(Debug, Serialize, Deserialize)]
299#[serde(rename_all = "camelCase")]
300pub struct Choice {
301 #[serde(rename = "finish_reason")]
302 pub finish_reason: String,
303 pub index: i64,
304 pub message: Message,
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(rename_all = "camelCase")]
309pub struct Usage {
310 #[serde(rename = "completion_tokens")]
311 pub completion_tokens: i64,
312 #[serde(rename = "prompt_tokens")]
313 pub prompt_tokens: i64,
314 #[serde(rename = "total_tokens")]
315 pub total_tokens: i64,
316}
317
318impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
319 type Error = CompletionError;
320
321 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
322 let choice = response.choices.first().ok_or_else(|| {
323 CompletionError::ResponseError("Response contained no choices".to_owned())
324 })?;
325
326 match &choice.message {
327 Message::Assistant {
328 tool_calls,
329 content,
330 } => {
331 if !tool_calls.is_empty() {
332 let tool_result = tool_calls
333 .iter()
334 .map(|call| {
335 completion::AssistantContent::tool_call(
336 &call.function.name,
337 &call.function.name,
338 call.function.arguments.clone(),
339 )
340 })
341 .collect::<Vec<_>>();
342
343 let choice = OneOrMany::many(tool_result).map_err(|_| {
344 CompletionError::ResponseError(
345 "Response contained no message or tool call (empty)".to_owned(),
346 )
347 })?;
348 tracing::debug!("response choices: {:?}: ", choice);
355 Ok(completion::CompletionResponse {
356 choice,
357 raw_response: response,
359 })
360 } else {
361 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
362 text: content.clone().unwrap_or_else(|| "".to_owned()),
363 }));
364 Ok(completion::CompletionResponse {
371 choice,
372 raw_response: response,
374 })
375 }
376 }
377 _ => Err(CompletionError::ResponseError(
379 "Chat response does not include an assistant message".into(),
380 )),
381 }
382 }
383}
384
/// A chat-completion model: a [`Client`] paired with the name of the model
/// to request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send the requests.
    client: Client,
    /// Model identifier placed in the `model` field of each request body.
    pub model: String,
}
390
391#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
407#[serde(rename_all = "camelCase")]
408pub struct CustomFunctionDefinition {
409 #[serde(rename = "type")]
410 pub type_field: String,
411 pub function: Function,
412}
413
414#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
415#[serde(rename_all = "camelCase")]
416pub struct Function {
417 pub name: String,
418 pub description: String,
419 pub parameters: serde_json::Value,
420}
421
422impl CompletionModel {
423 pub fn new(client: Client, model: &str) -> Self {
424 Self {
425 client,
426 model: model.to_string(),
427 }
428 }
429
430 fn create_completion_request(
431 &self,
432 completion_request: CompletionRequest,
433 ) -> Result<Value, CompletionError> {
434 let mut partial_history = vec![];
436 if let Some(docs) = completion_request.normalized_documents() {
437 partial_history.push(docs);
438 }
439 partial_history.extend(completion_request.chat_history);
440
441 let mut full_history: Vec<Message> = completion_request
443 .preamble
444 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
445
446 full_history.extend(
448 partial_history
449 .into_iter()
450 .map(message::Message::try_into)
451 .collect::<Result<Vec<Message>, _>>()?,
452 );
453
454 let request = if completion_request.tools.is_empty() {
455 json!({
456 "model": self.model,
457 "messages": full_history,
458 "temperature": completion_request.temperature,
459 })
460 } else {
461 let tools = completion_request
463 .tools
464 .into_iter()
465 .map(|item| {
466 let custom_function = Function {
467 name: item.name,
468 description: item.description,
469 parameters: item.parameters,
470 };
471 CustomFunctionDefinition {
472 type_field: "function".to_string(),
473 function: custom_function,
474 }
475 })
476 .collect::<Vec<_>>();
477
478 tracing::debug!("tools: {:?}", tools);
479
480 json!({
481 "model": self.model,
482 "messages": full_history,
483 "temperature": completion_request.temperature,
484 "tools": tools,
485 "tool_choice": "auto",
486 })
487 };
488
489 let request = if let Some(params) = completion_request.additional_params {
490 json_utils::merge(request, params)
491 } else {
492 request
493 };
494
495 Ok(request)
496 }
497}
498
499impl completion::CompletionModel for CompletionModel {
501 type Response = CompletionResponse;
502 type StreamingResponse = openai::StreamingCompletionResponse;
503
504 async fn completion(
505 &self,
506 completion_request: CompletionRequest,
507 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
508 tracing::debug!("create_completion_request========");
509 let request = self.create_completion_request(completion_request)?;
510
511 tracing::debug!(
512 "request: \r\n {}",
513 serde_json::to_string_pretty(&request).unwrap()
514 );
515
516 let response = self
517 .client
518 .post("/chat/completions")
519 .json(&request)
520 .send()
521 .await?;
522
523 if response.status().is_success() {
524 let data: Value = response.json().await.expect("api error");
525 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
526 let data: ApiResponse<CompletionResponse> =
527 serde_json::from_value(data).expect("deserialize completion response");
528 match data {
529 ApiResponse::Ok(response) => {
530 tracing::info!(target: "rig",
531 "bigmodel completion token usage: {:?}",
532 response.usage
533 );
534 response.try_into()
535 }
536 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
537 }
538 } else {
539 Err(CompletionError::ProviderError(response.text().await?))
540 }
541 }
542
543 async fn stream(
544 &self,
545 request: CompletionRequest,
546 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
547 let mut request = self.create_completion_request(request)?;
548
549 request = json_utils::merge(request, json!({"stream": true}));
550
551 let builder = self.client.post("/chat/completions").json(&request);
552
553 send_compatible_streaming_request(builder).await
554 }
555}
556
557