1use http::{HeaderValue, Method, header};
2use rig::client::{CompletionClient, ProviderClient};
3use rig::completion::{CompletionError, CompletionRequest};
4use rig::message::{MessageError, Text};
5use rig::providers::openai;
6use rig::{OneOrMany, completion, http_client, message};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9
10use crate::json_utils;
11use crate::json_utils::merge;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use tracing::{Instrument, info_span};
15
16const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23 api_key: String,
24 base_url: String,
25 default_headers: http_client::HeaderMap,
26 http_client: reqwest::Client,
27}
28
29impl Client {
30 pub fn new(api_key: &str) -> Self {
31 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
32 }
33
34 pub fn from_url(api_key: &str, base_url: &str) -> Self {
35 let mut default_headers = reqwest::header::HeaderMap::new();
36 default_headers.insert(
37 reqwest::header::CONTENT_TYPE,
38 "application/json".parse().unwrap(),
39 );
40
41 Self {
42 api_key: api_key.to_string(),
43 base_url: base_url.to_string(),
44 default_headers,
45 http_client: reqwest::Client::builder()
46 .default_headers({
47 let mut headers = reqwest::header::HeaderMap::new();
48 headers.insert(
49 "Authorization",
50 format!("Bearer {api_key}")
51 .parse()
52 .expect("Bearer token should parse"),
53 );
54 headers
55 })
56 .build()
57 .expect("bigmodel reqwest client should build"),
58 }
59 }
60
61 fn post(&self, path: &str) -> reqwest::RequestBuilder {
62 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
63 self.http_client.post(url)
64 }
65
66 pub fn completion_model(&self, model: &str) -> CompletionModel {
67 CompletionModel::new(self.clone(), model)
68 }
69
70 }
78
79impl ProviderClient for Client {
80 type Input = String;
81
82 fn from_env() -> Self
83 where
84 Self: Sized,
85 {
86 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
87 Self::new(&api_key)
88 }
89
90 fn from_val(input: Self::Input) -> Self {
91 Self::new(&input)
92 }
93}
94impl CompletionClient for Client {
95 type CompletionModel = CompletionModel;
96
97 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
98 CompletionModel::new(self.clone(), &model.into())
99 }
100}
101
102#[derive(Debug, Deserialize)]
103struct ApiErrorResponse {
104 message: String,
105}
106
107#[derive(Debug, Deserialize)]
108#[serde(untagged)]
109enum ApiResponse<T> {
110 Ok(T),
111 Err(ApiErrorResponse),
112}
113
/// Model identifier for GLM-4-Flash, sent as the request's `model` field.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model identifier for GLM-4.5-Flash, sent as the request's `model` field.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
119
120#[derive(Debug, Deserialize, Serialize)]
121#[serde(rename_all = "camelCase")]
122pub struct CompletionResponse {
123 pub choices: Vec<Choice>,
124 pub created: i64,
125 pub id: String,
126 pub model: String,
127 #[serde(rename = "request_id")]
128 pub request_id: String,
129 pub usage: Usage,
130}
131
132#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
133#[serde(tag = "role", rename_all = "lowercase")]
134pub enum Message {
135 User {
136 content: String,
137 },
138 Assistant {
139 content: Option<String>,
140 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
141 tool_calls: Vec<ToolCall>,
142 },
143 System {
144 content: String,
145 },
146 #[serde(rename = "tool")]
147 ToolResult {
148 tool_call_id: String,
149 content: String,
150 },
151}
152
153impl Message {
154 pub fn system(content: &str) -> Message {
155 Message::System {
156 content: content.to_owned(),
157 }
158 }
159}
160
161#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
162pub struct ToolResultContent {
163 text: String,
164}
165impl TryFrom<message::ToolResultContent> for ToolResultContent {
166 type Error = MessageError;
167 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
168 let message::ToolResultContent::Text(Text { text }) = value else {
169 return Err(MessageError::ConversionError(
170 "Non-text tool results not supported".into(),
171 ));
172 };
173
174 Ok(Self { text })
175 }
176}
177
178impl TryFrom<message::Message> for Message {
179 type Error = MessageError;
180
181 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
182 Ok(match message {
183 message::Message::User { content } => {
184 let mut texts = Vec::new();
185 let mut images = Vec::new();
186
187 for uc in content.into_iter() {
188 match uc {
189 message::UserContent::Text(message::Text { text }) => texts.push(text),
190 message::UserContent::Image(img) => images.push(img.data),
191 message::UserContent::ToolResult(result) => {
192 let content = result
193 .content
194 .into_iter()
195 .map(ToolResultContent::try_from)
196 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
197
198 let content = OneOrMany::many(content).map_err(|x| {
199 MessageError::ConversionError(format!(
200 "Couldn't make a OneOrMany from a list of tool results: {x}"
201 ))
202 })?;
203
204 return Ok(Message::ToolResult {
205 tool_call_id: result.id,
206 content: content.first().text,
207 });
208 }
209 _ => {}
210 }
211 }
212
213 let collapsed_content = texts.join(" ");
214
215 Message::User {
216 content: collapsed_content,
217 }
218 }
219 message::Message::Assistant { content, .. } => {
220 let mut texts = Vec::new();
221 let mut tool_calls = Vec::new();
222
223 for ac in content.into_iter() {
224 match ac {
225 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
226 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
227 _ => {}
228 }
229 }
230
231 let collapsed_content = texts.join(" ");
232
233 Message::Assistant {
234 content: Some(collapsed_content),
235 tool_calls,
236 }
237 }
238 })
239 }
240}
241
242impl From<message::ToolResult> for Message {
243 fn from(tool_result: message::ToolResult) -> Self {
244 let content = match tool_result.content.first() {
245 message::ToolResultContent::Text(text) => text.text,
246 message::ToolResultContent::Image(_) => String::from("[Image]"),
247 };
248
249 Message::ToolResult {
250 tool_call_id: tool_result.id,
251 content,
252 }
253 }
254}
255
256#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
257#[serde(rename_all = "camelCase")]
258pub struct ToolCall {
259 pub function: CallFunction,
260 pub id: String,
261 pub index: usize,
262 #[serde(default)]
263 pub r#type: ToolType,
264}
265
266impl From<message::ToolCall> for ToolCall {
267 fn from(tool_call: message::ToolCall) -> Self {
268 Self {
269 id: tool_call.id,
270 index: 0,
271 r#type: ToolType::Function,
272 function: CallFunction {
273 name: tool_call.function.name,
274 arguments: tool_call.function.arguments,
275 },
276 }
277 }
278}
279
280#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
281#[serde(rename_all = "lowercase")]
282pub enum ToolType {
283 #[default]
284 Function,
285}
286
287#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
288pub struct CallFunction {
289 pub name: String,
290 #[serde(with = "json_utils::stringified_json")]
291 pub arguments: serde_json::Value,
292}
293
294#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
295#[serde(rename_all = "lowercase")]
296pub enum Role {
297 System,
298 User,
299 Assistant,
300}
301
302#[derive(Debug, Serialize, Deserialize)]
303#[serde(rename_all = "camelCase")]
304pub struct Choice {
305 #[serde(rename = "finish_reason")]
306 pub finish_reason: String,
307 pub index: i64,
308 pub message: Message,
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize)]
312#[serde(rename_all = "camelCase")]
313pub struct Usage {
314 #[serde(rename = "completion_tokens")]
315 pub completion_tokens: i64,
316 #[serde(rename = "prompt_tokens")]
317 pub prompt_tokens: i64,
318 #[serde(rename = "total_tokens")]
319 pub total_tokens: i64,
320}
321
322impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
323 type Error = CompletionError;
324
325 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
326 let choice = response.choices.first().ok_or_else(|| {
327 CompletionError::ResponseError("Response contained no choices".to_owned())
328 })?;
329
330 match &choice.message {
331 Message::Assistant {
332 tool_calls,
333 content,
334 } => {
335 if !tool_calls.is_empty() {
336 let tool_result = tool_calls
337 .iter()
338 .map(|call| {
339 completion::AssistantContent::tool_call(
340 &call.function.name,
341 &call.function.name,
342 call.function.arguments.clone(),
343 )
344 })
345 .collect::<Vec<_>>();
346
347 let choice = OneOrMany::many(tool_result).map_err(|_| {
348 CompletionError::ResponseError(
349 "Response contained no message or tool call (empty)".to_owned(),
350 )
351 })?;
352 let usage = completion::Usage {
353 input_tokens: response.usage.prompt_tokens as u64,
354 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
355 as u64,
356 total_tokens: response.usage.total_tokens as u64,
357 };
358 tracing::debug!("response choices: {:?}: ", choice);
359 Ok(completion::CompletionResponse {
360 choice,
361 usage,
362 raw_response: response,
363 })
364 } else {
365 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
366 text: content.clone().unwrap_or_else(|| "".to_owned()),
367 }));
368 let usage = completion::Usage {
369 input_tokens: response.usage.prompt_tokens as u64,
370 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
371 as u64,
372 total_tokens: response.usage.total_tokens as u64,
373 };
374 Ok(completion::CompletionResponse {
375 choice,
376 usage,
377 raw_response: response,
378 })
379 }
380 }
381 _ => Err(CompletionError::ResponseError(
383 "Chat response does not include an assistant message".into(),
384 )),
385 }
386 }
387}
388
389#[derive(Clone)]
390pub struct CompletionModel {
391 client: Client,
392 pub model: String,
393}
394
395#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
397#[serde(rename_all = "camelCase")]
398pub struct CustomFunctionDefinition {
399 #[serde(rename = "type")]
400 pub type_field: String,
401 pub function: Function,
402}
403
404#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
405#[serde(rename_all = "camelCase")]
406pub struct Function {
407 pub name: String,
408 pub description: String,
409 pub parameters: serde_json::Value,
410}
411
412impl CompletionModel {
413 pub fn new(client: Client, model: &str) -> Self {
414 Self {
415 client,
416 model: model.to_string(),
417 }
418 }
419
420 fn create_completion_request(
421 &self,
422 completion_request: CompletionRequest,
423 ) -> Result<Value, CompletionError> {
424 let mut partial_history = vec![];
426 if let Some(docs) = completion_request.normalized_documents() {
427 partial_history.push(docs);
428 }
429 partial_history.extend(completion_request.chat_history);
430
431 let mut full_history: Vec<Message> = completion_request
433 .preamble
434 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
435
436 full_history.extend(
438 partial_history
439 .into_iter()
440 .map(message::Message::try_into)
441 .collect::<Result<Vec<Message>, _>>()?,
442 );
443
444 let request = if completion_request.tools.is_empty() {
445 json!({
446 "model": self.model,
447 "messages": full_history,
448 "temperature": completion_request.temperature,
449 })
450 } else {
451 let tools = completion_request
453 .tools
454 .into_iter()
455 .map(|item| {
456 let custom_function = Function {
457 name: item.name,
458 description: item.description,
459 parameters: item.parameters,
460 };
461 CustomFunctionDefinition {
462 type_field: "function".to_string(),
463 function: custom_function,
464 }
465 })
466 .collect::<Vec<_>>();
467
468 tracing::debug!("tools: {:?}", tools);
469
470 json!({
471 "model": self.model,
472 "messages": full_history,
473 "temperature": completion_request.temperature,
474 "tools": tools,
475 "tool_choice": "auto",
476 })
477 };
478
479 let request = if let Some(params) = completion_request.additional_params {
480 json_utils::merge(request, params)
481 } else {
482 request
483 };
484
485 Ok(request)
486 }
487}
488
489impl completion::CompletionModel for CompletionModel {
491 type Response = CompletionResponse;
492 type StreamingResponse = openai::StreamingCompletionResponse;
493 type Client = Client;
494
495 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
496 Self::new(client.clone(), &model.into())
497 }
498
499 async fn completion(
500 &self,
501 completion_request: CompletionRequest,
502 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
503 tracing::debug!("create_completion_request========");
504 let request = self.create_completion_request(completion_request)?;
505
506 tracing::debug!(
507 "request: \r\n {}",
508 serde_json::to_string_pretty(&request).unwrap()
509 );
510
511 let response = self
512 .client
513 .post("/chat/completions")
514 .json(&request)
515 .send()
516 .await
517 .map_err(|e| http_client::Error::Instance(e.into()))?;
518
519 if response.status().is_success() {
520 let data: Value = response.json().await.expect("api error");
521 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
522 let data: ApiResponse<CompletionResponse> =
523 serde_json::from_value(data).expect("deserialize completion response");
524 match data {
525 ApiResponse::Ok(response) => {
526 tracing::info!(target: "rig",
527 "bigmodel completion token usage: {:?}",
528 response.usage
529 );
530 response.try_into()
531 }
532 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
533 }
534 } else {
535 Err(CompletionError::ProviderError(
536 response
537 .text()
538 .await
539 .map_err(|e| http_client::Error::Instance(e.into()))?,
540 ))
541 }
542 }
543
544 async fn stream(
545 &self,
546 request: CompletionRequest,
547 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
548 let preamble = request.preamble.clone();
549
550 let mut request = self.create_completion_request(request)?;
551
552 request = merge(request, json!({"stream": true}));
553
554 let body = serde_json::to_vec(&request)?;
555
556 let url = format!(
557 "{}/{}",
558 self.client.base_url,
559 "/chat/completions".trim_start_matches('/')
560 );
561
562 let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
563 for (header, value) in &self.client.default_headers {
564 builder = builder.header(header, value);
565 }
566
567 let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
568 .map_err(http::Error::from)
569 .map_err(rig::http_client::Error::from)?;
570
571 builder = builder.header(header::AUTHORIZATION, auth_header);
572 builder = builder.header("Content-Type", "application/json");
573
574 let req = builder
575 .body(body)
576 .map_err(|e| CompletionError::HttpError(e.into()))?;
577
578 let span = if tracing::Span::current().is_disabled() {
579 info_span!(
580 target: "rig::completions",
581 "chat_streaming",
582 gen_ai.operation.name = "chat_streaming",
583 gen_ai.provider.name = "galadriel",
584 gen_ai.request.model = self.model,
585 gen_ai.system_instructions = preamble,
586 gen_ai.response.id = tracing::field::Empty,
587 gen_ai.response.model = tracing::field::Empty,
588 gen_ai.usage.output_tokens = tracing::field::Empty,
589 gen_ai.usage.input_tokens = tracing::field::Empty,
590 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
591 gen_ai.output.messages = tracing::field::Empty,
592 )
593 } else {
594 tracing::Span::current()
595 };
596
597 send_compatible_streaming_request(self.client.http_client.clone(), req)
598 .instrument(span)
599 .await
600 }
601}