1use rig::agent::AgentBuilder;
2use rig::completion::{CompletionError, CompletionRequest};
3use rig::extractor::ExtractorBuilder;
4use rig::message::{MessageError, Text};
5use rig::providers::openai;
6use rig::{OneOrMany, completion, message};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use rig::providers::openai::send_compatible_streaming_request;
12use rig::streaming::StreamingCompletionResponse;
13
14use crate::json_utils;
15
/// Default base URL for the BigModel OpenAI-compatible chat API.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
/// HTTP client for the BigModel completion API.
///
/// Holds the API base URL and a `reqwest` client pre-configured (in
/// `from_url`) with a bearer-token `Authorization` default header.
#[derive(Clone, Debug)]
pub struct Client {
    base_url: String,
    http_client: reqwest::Client,
}
26
27impl Client {
28 pub fn new(api_key: &str) -> Self {
29 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
30 }
31
32 pub fn from_url(api_key: &str, base_url: &str) -> Self {
33 Self {
34 base_url: base_url.to_string(),
35 http_client: reqwest::Client::builder()
36 .default_headers({
37 let mut headers = reqwest::header::HeaderMap::new();
38 headers.insert(
39 "Authorization",
40 format!("Bearer {}", api_key)
41 .parse()
42 .expect("Bearer token should parse"),
43 );
44 headers
45 })
46 .build()
47 .expect("bigmodel reqwest client should build"),
48 }
49 }
50
51 pub fn from_env() -> Self {
52 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
53 Self::new(&api_key)
54 }
55
56 fn post(&self, path: &str) -> reqwest::RequestBuilder {
57 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
58 self.http_client.post(url)
59 }
60
61 pub fn completion_model(&self, model: &str) -> CompletionModel {
62 CompletionModel::new(self.clone(), model)
63 }
64
65 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
66 AgentBuilder::new(self.completion_model(model))
67 }
68
69 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
71 &self,
72 model: &str,
73 ) -> ExtractorBuilder<T, CompletionModel> {
74 ExtractorBuilder::new(self.completion_model(model))
75 }
76}
77
/// Error body returned by the BigModel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API error body.
/// Untagged: serde attempts the variants in order, so `T` is tried first
/// and the error shape is the fallback.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
89
/// Model identifier for `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
94
95#[derive(Debug, Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub struct CompletionResponse {
98 pub choices: Vec<Choice>,
99 pub created: i64,
100 pub id: String,
101 pub model: String,
102 #[serde(rename = "request_id")]
103 pub request_id: String,
104 pub usage: Usage,
105}
106
/// Wire-format chat message, internally tagged by the `role` field
/// ("user", "assistant", "system", "tool").
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// End-user input.
    User {
        content: String,
    },
    /// Model output; may carry tool calls instead of (or alongside) text.
    Assistant {
        content: Option<String>,
        // `tool_calls` may be absent or `null` in responses; normalize to an
        // empty vec via `json_utils::null_or_vec`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System / preamble instructions.
    System {
        content: String,
    },
    /// Result of a prior tool call; serialized with role "tool".
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
127
128impl Message {
129 pub fn system(content: &str) -> Message {
130 Message::System {
131 content: content.to_owned(),
132 }
133 }
134}
135
/// Text payload of a tool result in the provider wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
140impl TryFrom<message::ToolResultContent> for ToolResultContent {
141 type Error = MessageError;
142 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
143 let message::ToolResultContent::Text(Text { text }) = value else {
144 return Err(MessageError::ConversionError(
145 "Non-text tool results not supported".into(),
146 ));
147 };
148
149 Ok(Self { text })
150 }
151}
152
153impl TryFrom<message::Message> for Message {
154 type Error = MessageError;
155
156 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
157 Ok(match message {
158 message::Message::User { content } => {
159 let mut texts = Vec::new();
160 let mut images = Vec::new();
161
162 for uc in content.into_iter() {
163 match uc {
164 message::UserContent::Text(message::Text { text }) => texts.push(text),
165 message::UserContent::Image(img) => images.push(img.data),
166 message::UserContent::ToolResult(result) => {
167 let content = result
168 .content
169 .into_iter()
170 .map(ToolResultContent::try_from)
171 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
172
173 let content = OneOrMany::many(content).map_err(|x| {
174 MessageError::ConversionError(format!(
175 "Couldn't make a OneOrMany from a list of tool results: {x}"
176 ))
177 })?;
178
179 return Ok(Message::ToolResult {
180 tool_call_id: result.id,
181 content: content.first().text,
182 });
183 }
184 _ => {}
185 }
186 }
187
188 let collapsed_content = texts.join(" ");
189
190 Message::User {
191 content: collapsed_content,
192 }
193 }
194 message::Message::Assistant { content } => {
195 let mut texts = Vec::new();
196 let mut tool_calls = Vec::new();
197
198 for ac in content.into_iter() {
199 match ac {
200 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
201 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
202 }
203 }
204
205 let collapsed_content = texts.join(" ");
206
207 Message::Assistant {
208 content: Some(collapsed_content),
209 tool_calls,
210 }
211 }
212 })
213 }
214}
215
216impl From<message::ToolResult> for Message {
217 fn from(tool_result: message::ToolResult) -> Self {
218 let content = match tool_result.content.first() {
219 message::ToolResultContent::Text(text) => text.text,
220 message::ToolResultContent::Image(_) => String::from("[Image]"),
221 };
222
223 Message::ToolResult {
224 tool_call_id: tool_result.id,
225 content,
226 }
227 }
228}
229
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231#[serde(rename_all = "camelCase")]
232pub struct ToolCall {
233 pub function: CallFunction,
234 pub id: String,
235 pub index: usize,
236 #[serde(default)]
237 pub r#type: ToolType,
238}
239
240impl From<message::ToolCall> for ToolCall {
241 fn from(tool_call: message::ToolCall) -> Self {
242 Self {
243 id: tool_call.id,
244 index: 0,
246 r#type: ToolType::Function,
247 function: CallFunction {
248 name: tool_call.function.name,
249 arguments: tool_call.function.arguments,
250 },
251 }
252 }
253}
254
/// Tool kind; only functions are supported ("function" on the wire).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// A called function: its name plus arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    // On the wire the arguments travel as a JSON-encoded string; they are
    // exposed here as a structured value via `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
268
/// Chat roles as serialized in payloads.
// NOTE(review): not referenced anywhere in this file — presumably kept for
// API parity or external callers; confirm before removing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
276
277#[derive(Debug, Serialize, Deserialize)]
278#[serde(rename_all = "camelCase")]
279pub struct Choice {
280 #[serde(rename = "finish_reason")]
281 pub finish_reason: String,
282 pub index: i64,
283 pub message: Message,
284}
285
286#[derive(Debug, Clone, Serialize, Deserialize)]
287#[serde(rename_all = "camelCase")]
288pub struct Usage {
289 #[serde(rename = "completion_tokens")]
290 pub completion_tokens: i64,
291 #[serde(rename = "prompt_tokens")]
292 pub prompt_tokens: i64,
293 #[serde(rename = "total_tokens")]
294 pub total_tokens: i64,
295}
296
297impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
298 type Error = CompletionError;
299
300 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
301 let choice = response.choices.first().ok_or_else(|| {
302 CompletionError::ResponseError("Response contained no choices".to_owned())
303 })?;
304
305 match &choice.message {
306 Message::Assistant {
307 tool_calls,
308 content,
309 } => {
310 if !tool_calls.is_empty() {
311 let tool_result = tool_calls
312 .iter()
313 .map(|call| {
314 completion::AssistantContent::tool_call(
315 &call.function.name,
316 &call.function.name,
317 call.function.arguments.clone(),
318 )
319 })
320 .collect::<Vec<_>>();
321
322 let choice = OneOrMany::many(tool_result).map_err(|_| {
323 CompletionError::ResponseError(
324 "Response contained no message or tool call (empty)".to_owned(),
325 )
326 })?;
327 tracing::debug!("response choices: {:?}: ", choice);
328 Ok(completion::CompletionResponse {
329 choice,
330 raw_response: response,
331 })
332 } else {
333 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
334 text: content.clone().unwrap_or_else(|| "".to_owned()),
335 }));
336 Ok(completion::CompletionResponse {
337 choice,
338 raw_response: response,
339 })
340 }
341 }
342 _ => Err(CompletionError::ResponseError(
344 "Chat response does not include an assistant message".into(),
345 )),
346 }
347 }
348}
349
/// A BigModel chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
}
355
356#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
358#[serde(rename_all = "camelCase")]
359pub struct CustomFunctionDefinition {
360 #[serde(rename = "type")]
361 pub type_field: String,
362 pub function: Function,
363}
364
365#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
366#[serde(rename_all = "camelCase")]
367pub struct Function {
368 pub name: String,
369 pub description: String,
370 pub parameters: serde_json::Value,
371}
372
373impl CompletionModel {
374 pub fn new(client: Client, model: &str) -> Self {
375 Self {
376 client,
377 model: model.to_string(),
378 }
379 }
380
381 fn create_completion_request(
382 &self,
383 completion_request: CompletionRequest,
384 ) -> Result<Value, CompletionError> {
385 let mut partial_history = vec![];
387 if let Some(docs) = completion_request.normalized_documents() {
388 partial_history.push(docs);
389 }
390 partial_history.extend(completion_request.chat_history);
391
392 let mut full_history: Vec<Message> = completion_request
394 .preamble
395 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
396
397 full_history.extend(
399 partial_history
400 .into_iter()
401 .map(message::Message::try_into)
402 .collect::<Result<Vec<Message>, _>>()?,
403 );
404
405 let request = if completion_request.tools.is_empty() {
406 json!({
407 "model": self.model,
408 "messages": full_history,
409 "temperature": completion_request.temperature,
410 })
411 } else {
412 let tools = completion_request
414 .tools
415 .into_iter()
416 .map(|item| {
417 let custom_function = Function {
418 name: item.name,
419 description: item.description,
420 parameters: item.parameters,
421 };
422 CustomFunctionDefinition {
423 type_field: "function".to_string(),
424 function: custom_function,
425 }
426 })
427 .collect::<Vec<_>>();
428
429 tracing::debug!("tools: {:?}", tools);
430
431 json!({
432 "model": self.model,
433 "messages": full_history,
434 "temperature": completion_request.temperature,
435 "tools": tools,
436 "tool_choice": "auto",
437 })
438 };
439
440 let request = if let Some(params) = completion_request.additional_params {
441 json_utils::merge(request, params)
442 } else {
443 request
444 };
445
446 Ok(request)
447 }
448}
449
450impl completion::CompletionModel for CompletionModel {
452 type Response = CompletionResponse;
453 type StreamingResponse = openai::StreamingCompletionResponse;
454
455 async fn completion(
456 &self,
457 completion_request: CompletionRequest,
458 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
459 tracing::debug!("create_completion_request========");
460 let request = self.create_completion_request(completion_request)?;
461
462 tracing::debug!(
463 "request: \r\n {}",
464 serde_json::to_string_pretty(&request).unwrap()
465 );
466
467 let response = self
468 .client
469 .post("/chat/completions")
470 .json(&request)
471 .send()
472 .await?;
473
474 if response.status().is_success() {
475 let data: Value = response.json().await.expect("api error");
476 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
477 let data: ApiResponse<CompletionResponse> =
478 serde_json::from_value(data).expect("deserialize completion response");
479 match data {
480 ApiResponse::Ok(response) => {
481 tracing::info!(target: "rig",
482 "bigmodel completion token usage: {:?}",
483 response.usage
484 );
485 response.try_into()
486 }
487 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
488 }
489 } else {
490 Err(CompletionError::ProviderError(response.text().await?))
491 }
492 }
493
494 async fn stream(
495 &self,
496 request: CompletionRequest,
497 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
498 let mut request = self.create_completion_request(request)?;
499
500 request = json_utils::merge(request, json!({"stream": true}));
501
502 let builder = self.client.post("/chat/completions").json(&request);
503
504 send_compatible_streaming_request(builder).await
505 }
506}