1use http::header;
2use reqwest::Method;
3use reqwest::header::HeaderValue;
4use rig::client::{AsEmbeddings, AsTranscription, CompletionClient, ProviderClient, ProviderValue};
5use rig::completion::{CompletionError, CompletionRequest};
6use rig::message::{MessageError, Text};
7use rig::providers::openai;
8use rig::{OneOrMany, client, completion, http_client, message};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11
12use crate::json_utils;
13use crate::json_utils::merge;
14use rig::providers::openai::send_compatible_streaming_request;
15use rig::streaming::StreamingCompletionResponse;
16use tracing::{Instrument, info_span};
17
/// Default base URL of the BigModel (open.bigmodel.cn) v4 HTTP API used by [`Client::new`].
/// Note that it ends with a trailing `/`, so URL joins must avoid producing `//`.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
/// HTTP client for the BigModel chat API.
///
/// NOTE(review): the derived `Debug` prints `api_key` in plain text; avoid logging this value.
#[derive(Clone, Debug)]
pub struct Client {
    /// API key, sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    /// Base URL that request paths are appended to.
    base_url: String,
    /// Headers copied onto hand-built (streaming) requests; holds `Content-Type: application/json`.
    default_headers: http_client::HeaderMap,
    /// `reqwest` client whose default headers carry the bearer `Authorization` header.
    http_client: reqwest::Client,
}
30
31impl Client {
32 pub fn new(api_key: &str) -> Self {
33 Self::from_url(api_key, BIGMODEL_API_BASE_URL)
34 }
35
36 pub fn from_url(api_key: &str, base_url: &str) -> Self {
37 let mut default_headers = reqwest::header::HeaderMap::new();
38 default_headers.insert(
39 reqwest::header::CONTENT_TYPE,
40 "application/json".parse().unwrap(),
41 );
42
43 Self {
44 api_key: api_key.to_string(),
45 base_url: base_url.to_string(),
46 default_headers,
47 http_client: reqwest::Client::builder()
48 .default_headers({
49 let mut headers = reqwest::header::HeaderMap::new();
50 headers.insert(
51 "Authorization",
52 format!("Bearer {api_key}")
53 .parse()
54 .expect("Bearer token should parse"),
55 );
56 headers
57 })
58 .build()
59 .expect("bigmodel reqwest client should build"),
60 }
61 }
62
63 fn post(&self, path: &str) -> reqwest::RequestBuilder {
64 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
65 self.http_client.post(url)
66 }
67
68 pub fn completion_model(&self, model: &str) -> CompletionModel {
69 CompletionModel::new(self.clone(), model)
70 }
71
72 }
80
81impl ProviderClient for Client {
82 fn from_env() -> Self
83 where
84 Self: Sized,
85 {
86 let api_key = std::env::var("BIGMODEL_API_KEY").expect("BIGMODEL_KEY not set");
87 Self::new(&api_key)
88 }
89
90 fn from_val(input: ProviderValue) -> Self
91 where
92 Self: Sized,
93 {
94 let client::ProviderValue::Simple(api_key) = input else {
95 panic!("Incorrect provider value type")
96 };
97 Self::new(&api_key)
98 }
99}
100
// Empty impl: relies entirely on the trait's default methods (presumably reporting
// transcription as unsupported — confirm against rig's `AsTranscription`).
impl AsTranscription for Client {}

// Empty impl: relies entirely on the trait's default methods (presumably reporting
// embeddings as unsupported — confirm against rig's `AsEmbeddings`).
impl AsEmbeddings for Client {}
104
105impl CompletionClient for Client {
106 type CompletionModel = CompletionModel;
107
108 fn completion_model(&self, model: &str) -> Self::CompletionModel {
109 CompletionModel::new(self.clone(), model)
110 }
111}
112
/// Error payload returned by the API; only its `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
117
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `untagged`: serde tries the variants in declaration order, so a body is parsed as `Ok(T)`
/// first and only falls back to `Err` if that fails. Keep `Ok` listed first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
124
/// Model identifier `glm-4-flash`.
pub const BIGMODEL_GLM_4_FLASH: &str = "glm-4-flash";
/// Model identifier `glm-4.5-flash`.
pub const BIGMODEL_GLM_4_5_FLASH: &str = "glm-4.5-flash";
130
131#[derive(Debug, Deserialize, Serialize)]
132#[serde(rename_all = "camelCase")]
133pub struct CompletionResponse {
134 pub choices: Vec<Choice>,
135 pub created: i64,
136 pub id: String,
137 pub model: String,
138 #[serde(rename = "request_id")]
139 pub request_id: String,
140 pub usage: Usage,
141}
142
/// A chat message in BigModel's wire format.
///
/// Serialized as an internally tagged object whose `role` field is the lowercase variant name
/// (`user`, `assistant`, `system`). `ToolResult` is the exception and uses role `tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Plain-text user turn.
    User {
        content: String,
    },
    /// Assistant turn. `content` is serialized as `null` when `None`.
    Assistant {
        content: Option<String>,
        // `default` covers a missing key; `json_utils::null_or_vec` presumably maps an
        // explicit `null` to an empty list — confirm in json_utils.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// System prompt.
    System {
        content: String,
    },
    /// Output of a tool invocation, linked to the originating call via `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
163
164impl Message {
165 pub fn system(content: &str) -> Message {
166 Message::System {
167 content: content.to_owned(),
168 }
169 }
170}
171
/// Text-only tool result content; the only shape this provider forwards for tool output.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
176impl TryFrom<message::ToolResultContent> for ToolResultContent {
177 type Error = MessageError;
178 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
179 let message::ToolResultContent::Text(Text { text }) = value else {
180 return Err(MessageError::ConversionError(
181 "Non-text tool results not supported".into(),
182 ));
183 };
184
185 Ok(Self { text })
186 }
187}
188
189impl TryFrom<message::Message> for Message {
190 type Error = MessageError;
191
192 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
193 Ok(match message {
194 message::Message::User { content } => {
195 let mut texts = Vec::new();
196 let mut images = Vec::new();
197
198 for uc in content.into_iter() {
199 match uc {
200 message::UserContent::Text(message::Text { text }) => texts.push(text),
201 message::UserContent::Image(img) => images.push(img.data),
202 message::UserContent::ToolResult(result) => {
203 let content = result
204 .content
205 .into_iter()
206 .map(ToolResultContent::try_from)
207 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
208
209 let content = OneOrMany::many(content).map_err(|x| {
210 MessageError::ConversionError(format!(
211 "Couldn't make a OneOrMany from a list of tool results: {x}"
212 ))
213 })?;
214
215 return Ok(Message::ToolResult {
216 tool_call_id: result.id,
217 content: content.first().text,
218 });
219 }
220 _ => {}
221 }
222 }
223
224 let collapsed_content = texts.join(" ");
225
226 Message::User {
227 content: collapsed_content,
228 }
229 }
230 message::Message::Assistant { content, .. } => {
231 let mut texts = Vec::new();
232 let mut tool_calls = Vec::new();
233
234 for ac in content.into_iter() {
235 match ac {
236 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
237 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
238 _ => {}
239 }
240 }
241
242 let collapsed_content = texts.join(" ");
243
244 Message::Assistant {
245 content: Some(collapsed_content),
246 tool_calls,
247 }
248 }
249 })
250 }
251}
252
253impl From<message::ToolResult> for Message {
254 fn from(tool_result: message::ToolResult) -> Self {
255 let content = match tool_result.content.first() {
256 message::ToolResultContent::Text(text) => text.text,
257 message::ToolResultContent::Image(_) => String::from("[Image]"),
258 };
259
260 Message::ToolResult {
261 tool_call_id: tool_result.id,
262 content,
263 }
264 }
265}
266
267#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
268#[serde(rename_all = "camelCase")]
269pub struct ToolCall {
270 pub function: CallFunction,
271 pub id: String,
272 pub index: usize,
273 #[serde(default)]
274 pub r#type: ToolType,
275}
276
277impl From<message::ToolCall> for ToolCall {
278 fn from(tool_call: message::ToolCall) -> Self {
279 Self {
280 id: tool_call.id,
281 index: 0,
282 r#type: ToolType::Function,
283 function: CallFunction {
284 name: tool_call.function.name,
285 arguments: tool_call.function.arguments,
286 },
287 }
288 }
289}
290
/// Kind of tool referenced by a [`ToolCall`]. Serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// Function tool; also the default when the field is absent.
    #[default]
    Function,
}
297
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    // Presumably encoded on the wire as a JSON string containing the arguments object
    // (see `json_utils::stringified_json`) — confirm in json_utils.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
304
/// Chat role, serialized in lowercase.
///
/// NOTE(review): not referenced by the code visible in this file; `Message` carries its role
/// through its serde tag instead. Confirm whether external callers use it.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
312
313#[derive(Debug, Serialize, Deserialize)]
314#[serde(rename_all = "camelCase")]
315pub struct Choice {
316 #[serde(rename = "finish_reason")]
317 pub finish_reason: String,
318 pub index: i64,
319 pub message: Message,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323#[serde(rename_all = "camelCase")]
324pub struct Usage {
325 #[serde(rename = "completion_tokens")]
326 pub completion_tokens: i64,
327 #[serde(rename = "prompt_tokens")]
328 pub prompt_tokens: i64,
329 #[serde(rename = "total_tokens")]
330 pub total_tokens: i64,
331}
332
333impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
334 type Error = CompletionError;
335
336 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
337 let choice = response.choices.first().ok_or_else(|| {
338 CompletionError::ResponseError("Response contained no choices".to_owned())
339 })?;
340
341 match &choice.message {
342 Message::Assistant {
343 tool_calls,
344 content,
345 } => {
346 if !tool_calls.is_empty() {
347 let tool_result = tool_calls
348 .iter()
349 .map(|call| {
350 completion::AssistantContent::tool_call(
351 &call.function.name,
352 &call.function.name,
353 call.function.arguments.clone(),
354 )
355 })
356 .collect::<Vec<_>>();
357
358 let choice = OneOrMany::many(tool_result).map_err(|_| {
359 CompletionError::ResponseError(
360 "Response contained no message or tool call (empty)".to_owned(),
361 )
362 })?;
363 let usage = completion::Usage {
364 input_tokens: response.usage.prompt_tokens as u64,
365 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
366 as u64,
367 total_tokens: response.usage.total_tokens as u64,
368 };
369 tracing::debug!("response choices: {:?}: ", choice);
370 Ok(completion::CompletionResponse {
371 choice,
372 usage,
373 raw_response: response,
374 })
375 } else {
376 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
377 text: content.clone().unwrap_or_else(|| "".to_owned()),
378 }));
379 let usage = completion::Usage {
380 input_tokens: response.usage.prompt_tokens as u64,
381 output_tokens: (response.usage.total_tokens - response.usage.prompt_tokens)
382 as u64,
383 total_tokens: response.usage.total_tokens as u64,
384 };
385 Ok(completion::CompletionResponse {
386 choice,
387 usage,
388 raw_response: response,
389 })
390 }
391 }
392 _ => Err(CompletionError::ResponseError(
394 "Chat response does not include an assistant message".into(),
395 )),
396 }
397 }
398}
399
/// A BigModel chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier sent as `model` in each request (e.g. `BIGMODEL_GLM_4_FLASH`).
    pub model: String,
}
405
406#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
408#[serde(rename_all = "camelCase")]
409pub struct CustomFunctionDefinition {
410 #[serde(rename = "type")]
411 pub type_field: String,
412 pub function: Function,
413}
414
415#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
416#[serde(rename_all = "camelCase")]
417pub struct Function {
418 pub name: String,
419 pub description: String,
420 pub parameters: serde_json::Value,
421}
422
423impl CompletionModel {
424 pub fn new(client: Client, model: &str) -> Self {
425 Self {
426 client,
427 model: model.to_string(),
428 }
429 }
430
431 fn create_completion_request(
432 &self,
433 completion_request: CompletionRequest,
434 ) -> Result<Value, CompletionError> {
435 let mut partial_history = vec![];
437 if let Some(docs) = completion_request.normalized_documents() {
438 partial_history.push(docs);
439 }
440 partial_history.extend(completion_request.chat_history);
441
442 let mut full_history: Vec<Message> = completion_request
444 .preamble
445 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
446
447 full_history.extend(
449 partial_history
450 .into_iter()
451 .map(message::Message::try_into)
452 .collect::<Result<Vec<Message>, _>>()?,
453 );
454
455 let request = if completion_request.tools.is_empty() {
456 json!({
457 "model": self.model,
458 "messages": full_history,
459 "temperature": completion_request.temperature,
460 })
461 } else {
462 let tools = completion_request
464 .tools
465 .into_iter()
466 .map(|item| {
467 let custom_function = Function {
468 name: item.name,
469 description: item.description,
470 parameters: item.parameters,
471 };
472 CustomFunctionDefinition {
473 type_field: "function".to_string(),
474 function: custom_function,
475 }
476 })
477 .collect::<Vec<_>>();
478
479 tracing::debug!("tools: {:?}", tools);
480
481 json!({
482 "model": self.model,
483 "messages": full_history,
484 "temperature": completion_request.temperature,
485 "tools": tools,
486 "tool_choice": "auto",
487 })
488 };
489
490 let request = if let Some(params) = completion_request.additional_params {
491 json_utils::merge(request, params)
492 } else {
493 request
494 };
495
496 Ok(request)
497 }
498}
499
500impl completion::CompletionModel for CompletionModel {
502 type Response = CompletionResponse;
503 type StreamingResponse = openai::StreamingCompletionResponse;
504
505 async fn completion(
506 &self,
507 completion_request: CompletionRequest,
508 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
509 tracing::debug!("create_completion_request========");
510 let request = self.create_completion_request(completion_request)?;
511
512 tracing::debug!(
513 "request: \r\n {}",
514 serde_json::to_string_pretty(&request).unwrap()
515 );
516
517 let response = self
518 .client
519 .post("/chat/completions")
520 .json(&request)
521 .send()
522 .await
523 .map_err(|e| http_client::Error::Instance(e.into()))?;
524
525 if response.status().is_success() {
526 let data: Value = response.json().await.expect("api error");
527 tracing::debug!("response: {}", serde_json::to_string_pretty(&data).unwrap());
528 let data: ApiResponse<CompletionResponse> =
529 serde_json::from_value(data).expect("deserialize completion response");
530 match data {
531 ApiResponse::Ok(response) => {
532 tracing::info!(target: "rig",
533 "bigmodel completion token usage: {:?}",
534 response.usage
535 );
536 response.try_into()
537 }
538 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
539 }
540 } else {
541 Err(CompletionError::ProviderError(
542 response
543 .text()
544 .await
545 .map_err(|e| http_client::Error::Instance(e.into()))?,
546 ))
547 }
548 }
549
550 async fn stream(
551 &self,
552 request: CompletionRequest,
553 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
554 let preamble = request.preamble.clone();
555
556 let mut request = self.create_completion_request(request)?;
557
558 request = merge(request, json!({"stream": true}));
559
560 let body = serde_json::to_vec(&request)?;
561
562 let url = format!(
563 "{}/{}",
564 self.client.base_url,
565 "/chat/completions".trim_start_matches('/')
566 );
567
568 let mut builder = http_client::Builder::new().uri(url).method(Method::POST);
569 for (header, value) in &self.client.default_headers {
570 builder = builder.header(header, value);
571 }
572
573 let auth_header = HeaderValue::from_str(&format!("Bearer {}", &self.client.api_key))
574 .map_err(http::Error::from)
575 .map_err(rig::http_client::Error::from)?;
576
577 builder = builder.header(header::AUTHORIZATION, auth_header);
578 builder = builder.header("Content-Type", "application/json");
579
580 let req = builder
581 .body(body)
582 .map_err(|e| CompletionError::HttpError(e.into()))?;
583
584 let span = if tracing::Span::current().is_disabled() {
585 info_span!(
586 target: "rig::completions",
587 "chat_streaming",
588 gen_ai.operation.name = "chat_streaming",
589 gen_ai.provider.name = "galadriel",
590 gen_ai.request.model = self.model,
591 gen_ai.system_instructions = preamble,
592 gen_ai.response.id = tracing::field::Empty,
593 gen_ai.response.model = tracing::field::Empty,
594 gen_ai.usage.output_tokens = tracing::field::Empty,
595 gen_ai.usage.input_tokens = tracing::field::Empty,
596 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
597 gen_ai.output.messages = tracing::field::Empty,
598 )
599 } else {
600 tracing::Span::current()
601 };
602
603 send_compatible_streaming_request(self.client.http_client.clone(), req)
604 .instrument(span)
605 .await
606 }
607}