1use crate::json_utils;
2use crate::json_utils::merge;
3use bytes::Bytes;
4use rig::agent::Text;
5use rig::client::{
6 BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
7};
8use rig::completion::{CompletionError, CompletionRequest};
9use rig::http_client::HttpClientExt;
10use rig::message::MessageError;
11use rig::providers::openai;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use rig::{OneOrMany, client, completion, http_client, message};
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
/// Base URL of the Bigmodel (open.bigmodel.cn) OpenAI-compatible REST API;
/// request paths such as `/chat/completions` are resolved against it.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
/// Zero-sized marker type identifying the Bigmodel provider inside rig's
/// generic client machinery (see the `Provider`/`Capabilities` impls below).
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelExt;
25
/// Zero-sized builder marker for the Bigmodel provider; carries no state and
/// only supplies associated items via its `ProviderBuilder` impl.
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelBuilder;
29
// The API key is presented as a bearer token (`Authorization: Bearer <key>`),
// reusing rig's stock `BearerAuth` scheme.
type BigmodelApiKey = BearerAuth;
31
/// Handle for one Bigmodel chat model: a provider client plus the model id.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    // HTTP-capable client used to reach the Bigmodel API.
    client: Client<T>,
    /// Model identifier sent as the `"model"` field of every request.
    pub model: String,
}
37
38impl<T> CompletionModel<T> {
39 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
40 Self {
41 client,
42 model: model.into(),
43 }
44 }
45
46 fn create_completion_request(
47 &self,
48 completion_request: CompletionRequest,
49 ) -> Result<Value, CompletionError> {
50 let mut partial_history = vec![];
52 if let Some(docs) = completion_request.normalized_documents() {
53 partial_history.push(docs);
54 }
55 partial_history.extend(completion_request.chat_history);
56
57 let mut full_history: Vec<Message> = completion_request
59 .preamble
60 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
61
62 full_history.extend(
64 partial_history
65 .into_iter()
66 .map(message::Message::try_into)
67 .collect::<Result<Vec<Message>, _>>()?,
68 );
69
70 let request = if completion_request.tools.is_empty() {
71 json!({
72 "model": self.model,
73 "messages": full_history,
74 "temperature": completion_request.temperature,
75 })
76 } else {
77 let tools = completion_request
79 .tools
80 .into_iter()
81 .map(|item| {
82 let custom_function = Function {
83 name: item.name,
84 description: item.description,
85 parameters: item.parameters,
86 };
87 CustomFunctionDefinition {
88 type_field: "function".to_string(),
89 function: custom_function,
90 }
91 })
92 .collect::<Vec<_>>();
93
94 tracing::debug!("tools: {:?}", tools);
95
96 json!({
97 "model": self.model,
98 "messages": full_history,
99 "temperature": completion_request.temperature,
100 "tools": tools,
101 "tool_choice": "auto",
102 })
103 };
104
105 let request = if let Some(params) = completion_request.additional_params {
106 json_utils::merge(request, params)
107 } else {
108 request
109 };
110
111 Ok(request)
112 }
113}
114
/// Provider wiring: ties the marker type to its builder and names the path
/// used by rig to verify connectivity/credentials.
impl Provider for BigmodelExt {
    // NOTE(review): "api/tags" is Ollama's model-listing path; Bigmodel's API
    // is OpenAI-compatible, so confirm this verification endpoint actually
    // exists on open.bigmodel.cn before relying on verify().
    const VERIFY_PATH: &'static str = "api/tags";
    type Builder = BigmodelBuilder;

    /// Constructs the stateless provider marker; the builder carries no
    /// configuration that needs to be copied over.
    fn build<H>(
        _builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> rig::http_client::Result<Self> {
        Ok(Self)
    }
}
129
130impl<H> Capabilities<H> for BigmodelExt {
132 type Completion = Capable<CompletionModel<H>>;
133 type Embeddings = Nothing;
134 type Transcription = Nothing;
135 type ModelListing = Nothing;
136
137 }
143
// Marker impl; no items overridden, so DebugExt's defaults apply.
impl DebugExt for BigmodelExt {}
145
/// Builder configuration: produces `BigmodelExt`, authenticates with a bearer
/// token, and defaults to the public Bigmodel base URL.
impl ProviderBuilder for BigmodelBuilder {
    type Output = BigmodelExt;
    type ApiKey = BigmodelApiKey;
    const BASE_URL: &'static str = BIGMODEL_API_BASE_URL;
}
151
/// Rig client specialized for the Bigmodel provider (reqwest by default).
pub type Client<H = reqwest::Client> = client::Client<BigmodelExt, H>;
/// Builder alias for [`Client`].
// NOTE(review): the key type parameter here is `String`, while
// `ProviderBuilder::ApiKey` is `BearerAuth` — confirm this mismatch is
// intentional (other rig providers typically use the ApiKey type here).
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<BigmodelBuilder, String, H>;
154
/// Error payload the API returns in a failed (but parseable) response body.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description surfaced as CompletionError::ProviderError.
    message: String,
}
174
/// Either a successful payload or an API error envelope. `untagged` lets
/// serde try each shape in order against the raw JSON body.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
181
/// Model id constant for the GLM "flash" tier.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
187
188#[derive(Debug, Deserialize, Serialize)]
189#[serde(rename_all = "camelCase")]
190pub struct CompletionResponse {
191 pub choices: Vec<Choice>,
192 pub created: i64,
193 pub id: String,
194 pub model: String,
195 #[serde(rename = "request_id")]
196 pub request_id: String,
197 pub usage: Option<Usage>,
198}
199
/// Wire-format chat message, internally tagged by the lowercase `role` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        /// Assistant text; may be absent when the turn consists only of tool calls.
        content: Option<String>,
        /// Tool invocations; a JSON `null` is treated the same as a missing list.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    /// Result of a tool invocation; explicit rename because the lowercased
    /// variant name would otherwise serialize as "toolresult".
    #[serde(rename = "tool")]
    ToolResult {
        /// Id of the tool call this result answers.
        tool_call_id: String,
        content: String,
    },
}
220
221impl Message {
222 pub fn system(content: &str) -> Message {
223 Message::System {
224 content: content.to_owned(),
225 }
226 }
227}
228
/// Text-only tool-result content; the only tool-result form this provider supports.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
233impl TryFrom<message::ToolResultContent> for ToolResultContent {
234 type Error = MessageError;
235 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
236 let message::ToolResultContent::Text(Text { text }) = value else {
237 return Err(MessageError::ConversionError(
238 "Non-text tool results not supported".into(),
239 ));
240 };
241
242 Ok(Self { text })
243 }
244}
245
246impl TryFrom<message::Message> for Message {
247 type Error = MessageError;
248
249 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
250 Ok(match message {
251 message::Message::User { content } => {
252 let mut texts = Vec::new();
253 let mut images = Vec::new();
254
255 for uc in content.into_iter() {
256 match uc {
257 message::UserContent::Text(message::Text { text }) => texts.push(text),
258 message::UserContent::Image(img) => images.push(img.data),
259 message::UserContent::ToolResult(result) => {
260 let content = result
261 .content
262 .into_iter()
263 .map(ToolResultContent::try_from)
264 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
265
266 let content = OneOrMany::many(content).map_err(|x| {
267 MessageError::ConversionError(format!(
268 "Couldn't make a OneOrMany from a list of tool results: {x}"
269 ))
270 })?;
271
272 return Ok(Message::ToolResult {
273 tool_call_id: result.id,
274 content: content.first().text,
275 });
276 }
277 _ => {}
278 }
279 }
280
281 let collapsed_content = texts.join(" ");
282
283 Message::User {
284 content: collapsed_content,
285 }
286 }
287 message::Message::Assistant { content, .. } => {
288 let mut texts = Vec::new();
289 let mut tool_calls = Vec::new();
290
291 for ac in content.into_iter() {
292 match ac {
293 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
294 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
295 _ => {}
296 }
297 }
298
299 let collapsed_content = texts.join(" ");
300
301 Message::Assistant {
302 content: Some(collapsed_content),
303 tool_calls,
304 }
305 }
306 })
307 }
308}
309
impl From<message::ToolResult> for Message {
    /// Flattens a rig tool result into a role-`tool` message. Only the first
    /// content part is used; an image part becomes the literal "[Image]"
    /// placeholder string.
    fn from(tool_result: message::ToolResult) -> Self {
        let content = match tool_result.content.first() {
            message::ToolResultContent::Text(text) => text.text,
            message::ToolResultContent::Image(_) => String::from("[Image]"),
        };

        Message::ToolResult {
            tool_call_id: tool_result.id,
            content,
        }
    }
}
323
/// A tool invocation emitted by the assistant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    /// Provider-assigned id used to correlate tool results with this call.
    pub id: String,
    pub index: usize,
    /// Defaults to `function` when the field is absent from the payload.
    #[serde(default)]
    pub r#type: ToolType,
}
333
334impl From<message::ToolCall> for ToolCall {
335 fn from(tool_call: message::ToolCall) -> Self {
336 Self {
337 id: tool_call.id,
338 index: 0,
339 r#type: ToolType::Function,
340 function: CallFunction {
341 name: tool_call.function.name,
342 arguments: tool_call.function.arguments,
343 },
344 }
345 }
346}
347
/// Kind of tool; only `function` exists and it is the serde default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
354
/// Name/arguments payload of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    /// Arguments travel on the wire as a JSON-encoded string; this adapter
    /// (de)serializes them to and from a structured `Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
361
/// Chat roles, serialized lowercase.
// NOTE(review): not referenced anywhere in this file (Message carries its own
// role tag); possibly kept for external callers — confirm before removing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
369
370#[derive(Debug, Serialize, Deserialize)]
371#[serde(rename_all = "camelCase")]
372pub struct Choice {
373 #[serde(rename = "finish_reason")]
374 pub finish_reason: String,
375 pub index: i64,
376 pub message: Message,
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380#[serde(rename_all = "camelCase")]
381pub struct Usage {
382 #[serde(rename = "completion_tokens")]
383 pub completion_tokens: i64,
384 #[serde(rename = "prompt_tokens")]
385 pub prompt_tokens: i64,
386 #[serde(rename = "total_tokens")]
387 pub total_tokens: i64,
388 #[serde(skip_serializing_if = "Option::is_none")]
389 pub prompt_tokens_details: Option<PromptTokensDetails>,
390}
391
/// Breakdown of prompt-token usage.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from the provider's cache; zero when the field is absent.
    #[serde(default)]
    pub cached_tokens: usize,
}
398
399impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
400 type Error = CompletionError;
401
402 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
403 let choice = response.choices.first().ok_or_else(|| {
404 CompletionError::ResponseError("Response contained no choices".to_owned())
405 })?;
406
407 match &choice.message {
408 Message::Assistant {
409 tool_calls,
410 content,
411 } => {
412 if !tool_calls.is_empty() {
413 let tool_result = tool_calls
414 .iter()
415 .map(|call| {
416 completion::AssistantContent::tool_call(
417 &call.function.name,
418 &call.function.name,
419 call.function.arguments.clone(),
420 )
421 })
422 .collect::<Vec<_>>();
423
424 let choice = OneOrMany::many(tool_result).map_err(|_| {
425 CompletionError::ResponseError(
426 "Response contained no message or tool call (empty)".to_owned(),
427 )
428 })?;
429 let usage = response
430 .usage
431 .as_ref()
432 .map(|usage| completion::Usage {
433 input_tokens: usage.prompt_tokens as u64,
434 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
435 total_tokens: usage.total_tokens as u64,
436 cached_input_tokens: usage
437 .prompt_tokens_details
438 .as_ref()
439 .map(|d| d.cached_tokens as u64)
440 .unwrap_or(0),
441 })
442 .unwrap_or_default();
443 tracing::debug!("response choices: {:?}: ", choice);
444 Ok(completion::CompletionResponse {
445 choice,
446 usage,
447 raw_response: response,
448 message_id: None,
449 })
450 } else {
451 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
452 text: content.clone().unwrap_or_else(|| "".to_owned()),
453 }));
454 let usage = response
455 .usage
456 .as_ref()
457 .map(|usage| completion::Usage {
458 input_tokens: usage.prompt_tokens as u64,
459 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
460 total_tokens: usage.total_tokens as u64,
461 cached_input_tokens: usage
462 .prompt_tokens_details
463 .as_ref()
464 .map(|d| d.cached_tokens as u64)
465 .unwrap_or(0),
466 })
467 .unwrap_or_default();
468 Ok(completion::CompletionResponse {
469 choice,
470 usage,
471 raw_response: response,
472 message_id: None,
473 })
474 }
475 }
476 _ => Err(CompletionError::ResponseError(
478 "Chat response does not include an assistant message".into(),
479 )),
480 }
481 }
482}
483
484#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
486#[serde(rename_all = "camelCase")]
487pub struct CustomFunctionDefinition {
488 #[serde(rename = "type")]
489 pub type_field: String,
490 pub function: Function,
491}
492
493#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
494#[serde(rename_all = "camelCase")]
495pub struct Function {
496 pub name: String,
497 pub description: String,
498 pub parameters: serde_json::Value,
499}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
503where
504 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
505{
506 type Response = CompletionResponse;
507 type StreamingResponse = openai::StreamingCompletionResponse;
508 type Client = Client<T>;
509
510 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511 Self::new(client.clone(), model.into())
512 }
513
514 async fn completion(
515 &self,
516 completion_request: CompletionRequest,
517 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
518 let span = if tracing::Span::current().is_disabled() {
519 info_span!(
520 target: "rig::completions",
521 "chat",
522 gen_ai.operation.name = "chat",
523 gen_ai.provider.name = "groq",
524 gen_ai.request.model = self.model,
525 gen_ai.system_instructions = tracing::field::Empty,
526 gen_ai.response.id = tracing::field::Empty,
527 gen_ai.response.model = tracing::field::Empty,
528 gen_ai.usage.output_tokens = tracing::field::Empty,
529 gen_ai.usage.input_tokens = tracing::field::Empty,
530 )
531 } else {
532 tracing::Span::current()
533 };
534
535 span.record("gen_ai.system_instructions", &completion_request.preamble);
536
537 let request = self.create_completion_request(completion_request)?;
538
539 if tracing::enabled!(tracing::Level::TRACE) {
540 tracing::trace!(target: "rig::completions",
541 "Groq completion request: {}",
542 serde_json::to_string_pretty(&request)?
543 );
544 }
545
546 let body = serde_json::to_vec(&request)?;
547 let req = self
548 .client
549 .post("/chat/completions")?
550 .body(body)
551 .map_err(|e| http_client::Error::Instance(e.into()))?;
552
553 let async_block = async move {
554 let response = self.client.send::<_, Bytes>(req).await?;
555 let status = response.status();
556 let response_body = response.into_body().into_future().await?.to_vec();
557
558 let tt = response_body.clone();
559 let response = serde_json::from_slice::<serde_json::Value>(&tt)?;
560 println!(
561 "response:\r\n {}",
562 serde_json::to_string_pretty(&response).unwrap()
563 );
564
565 if status.is_success() {
566 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
567 ApiResponse::Ok(response) => {
568 let span = tracing::Span::current();
569 span.record("gen_ai.response.id", response.id.clone());
570 span.record("gen_ai.response.model_name", response.model.clone());
571 if let Some(ref usage) = response.usage {
572 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
573 span.record(
574 "gen_ai.usage.output_tokens",
575 usage.total_tokens - usage.prompt_tokens,
576 );
577 }
578
579 if tracing::enabled!(tracing::Level::TRACE) {
580 tracing::trace!(target: "rig::completions",
581 "Groq completion response: {}",
582 serde_json::to_string_pretty(&response)?
583 );
584 }
585
586 response.try_into()
587 }
588 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
589 }
590 } else {
591 Err(CompletionError::ProviderError(
592 String::from_utf8_lossy(&response_body).to_string(),
593 ))
594 }
595 };
596
597 tracing::Instrument::instrument(async_block, span).await
598 }
599
600 async fn stream(
601 &self,
602 request: CompletionRequest,
603 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
604 let preamble = request.preamble.clone();
605
606 let mut request = self.create_completion_request(request)?;
607
608 request = merge(request, json!({"stream": true}));
609
610 let body = serde_json::to_vec(&request)?;
611
612 let req = self
613 .client
614 .post("/chat/completions")?
615 .body(body)
616 .map_err(|e| http_client::Error::Instance(e.into()))?;
617
618 let span = if tracing::Span::current().is_disabled() {
619 info_span!(
620 target: "rig::completions",
621 "chat_streaming",
622 gen_ai.operation.name = "chat_streaming",
623 gen_ai.provider.name = "galadriel",
624 gen_ai.request.model = self.model,
625 gen_ai.system_instructions = preamble,
626 gen_ai.response.id = tracing::field::Empty,
627 gen_ai.response.model = tracing::field::Empty,
628 gen_ai.usage.output_tokens = tracing::field::Empty,
629 gen_ai.usage.input_tokens = tracing::field::Empty,
630 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
631 gen_ai.output.messages = tracing::field::Empty,
632 )
633 } else {
634 tracing::Span::current()
635 };
636
637 send_compatible_streaming_request(self.client.clone(), req)
638 .instrument(span)
639 .await
640 }
641}