1use crate::json_utils;
2use crate::json_utils::merge;
3use bytes::Bytes;
4use rig::agent::Text;
5use rig::client::{
6 BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
7};
8use rig::completion::{CompletionError, CompletionRequest};
9use rig::http_client::HttpClientExt;
10use rig::message::MessageError;
11use rig::providers::openai;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use rig::{OneOrMany, client, completion, http_client, message};
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
/// Base URL of the Bigmodel (Zhipu) OpenAI-compatible REST API.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
23#[derive(Debug, Default, Clone, Copy)]
24pub struct BigmodelExt;
25
26#[derive(Debug, Default, Clone, Copy)]
27
28pub struct BigmodelBuilder;
29
30type BigmodelApiKey = BearerAuth;
31
/// A chat-completion model bound to a specific Bigmodel model id.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared provider client used to issue HTTP requests.
    client: Client<T>,
    /// Model identifier, e.g. [`BIGMODEL_GLM_4_7_FLASH`].
    pub model: String,
}
37
38impl<T> CompletionModel<T> {
39 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
40 Self {
41 client,
42 model: model.into(),
43 }
44 }
45
46 fn create_completion_request(
47 &self,
48 completion_request: CompletionRequest,
49 ) -> Result<Value, CompletionError> {
50 let mut partial_history = vec![];
52 if let Some(docs) = completion_request.normalized_documents() {
53 partial_history.push(docs);
54 }
55 partial_history.extend(completion_request.chat_history);
56
57 let mut full_history: Vec<Message> = completion_request
59 .preamble
60 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
61
62 full_history.extend(
64 partial_history
65 .into_iter()
66 .map(message::Message::try_into)
67 .collect::<Result<Vec<Message>, _>>()?,
68 );
69
70 let request = if completion_request.tools.is_empty() {
71 json!({
72 "model": self.model,
73 "messages": full_history,
74 "temperature": completion_request.temperature,
75 })
76 } else {
77 let tools = completion_request
79 .tools
80 .into_iter()
81 .map(|item| {
82 let custom_function = Function {
83 name: item.name,
84 description: item.description,
85 parameters: item.parameters,
86 };
87 CustomFunctionDefinition {
88 type_field: "function".to_string(),
89 function: custom_function,
90 }
91 })
92 .collect::<Vec<_>>();
93
94 tracing::debug!("tools: {:?}", tools);
95
96 json!({
97 "model": self.model,
98 "messages": full_history,
99 "temperature": completion_request.temperature,
100 "tools": tools,
101 "tool_choice": "auto",
102 })
103 };
104
105 let request = if let Some(params) = completion_request.additional_params {
106 json_utils::merge(request, params)
107 } else {
108 request
109 };
110
111 Ok(request)
112 }
113}
114
impl Provider for BigmodelExt {
    // NOTE(review): "api/tags" is Ollama's model-listing endpoint; it looks
    // like copy-paste residue. Confirm the correct health-check path for
    // open.bigmodel.cn before relying on client verification.
    const VERIFY_PATH: &'static str = "api/tags";
    type Builder = BigmodelBuilder;
}
119
120impl<H> Capabilities<H> for BigmodelExt {
122 type Completion = Capable<CompletionModel<H>>;
123 type Embeddings = Nothing;
124 type Transcription = Nothing;
125 type ModelListing = Nothing;
126
127 }
133
134impl DebugExt for BigmodelExt {}
135
/// Builder wiring: stamps out the (stateless) `BigmodelExt` extension for any
/// HTTP client implementation.
impl ProviderBuilder for BigmodelBuilder {
    // The extension type does not depend on the HTTP client implementation.
    type Extension<H>
        = BigmodelExt
    where
        H: HttpClientExt;
    type ApiKey = BigmodelApiKey;
    const BASE_URL: &'static str = BIGMODEL_API_BASE_URL;

    /// The extension carries no state, so construction cannot fail.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(BigmodelExt)
    }
}
153
/// Convenience alias: a rig client specialized for the Bigmodel provider.
pub type Client<H = reqwest::Client> = client::Client<BigmodelExt, H>;
/// Builder alias; the API key is supplied as a plain `String` bearer token.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<BigmodelBuilder, String, H>;
156
/// Error payload returned by the Bigmodel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a body is first tried as a successful `T`, then as an
/// API error payload.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

/// Model id for GLM-4.7-Flash.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
189
/// Raw (non-streaming) chat-completion response as returned by the API.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    /// Unix timestamp of when the completion was created.
    pub created: i64,
    pub id: String,
    pub model: String,
    // NOTE(review): `rename_all = "camelCase"` is a no-op on this struct —
    // the only multi-word field carries an explicit snake_case rename.
    #[serde(rename = "request_id")]
    pub request_id: String,
    /// Token accounting; may be absent on some responses.
    pub usage: Option<Usage>,
}
201
/// Wire-format chat message. `role` is the serde tag, so each variant
/// serializes as `{"role": "...", ...}` with a lowercase role name.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        content: Option<String>,
        // The API may send `null` instead of an empty array here.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    /// Result of a prior tool call, echoed back with `role: "tool"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
222
223impl Message {
224 pub fn system(content: &str) -> Message {
225 Message::System {
226 content: content.to_owned(),
227 }
228 }
229}
230
/// Text payload of a tool result in the Bigmodel wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
235impl TryFrom<message::ToolResultContent> for ToolResultContent {
236 type Error = MessageError;
237 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
238 let message::ToolResultContent::Text(Text { text }) = value else {
239 return Err(MessageError::ConversionError(
240 "Non-text tool results not supported".into(),
241 ));
242 };
243
244 Ok(Self { text })
245 }
246}
247
248impl TryFrom<message::Message> for Message {
249 type Error = MessageError;
250
251 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
252 Ok(match message {
253 message::Message::User { content } => {
254 let mut texts = Vec::new();
255 let mut images = Vec::new();
256
257 for uc in content.into_iter() {
258 match uc {
259 message::UserContent::Text(message::Text { text }) => texts.push(text),
260 message::UserContent::Image(img) => images.push(img.data),
261 message::UserContent::ToolResult(result) => {
262 let content = result
263 .content
264 .into_iter()
265 .map(ToolResultContent::try_from)
266 .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;
267
268 let content = OneOrMany::many(content).map_err(|x| {
269 MessageError::ConversionError(format!(
270 "Couldn't make a OneOrMany from a list of tool results: {x}"
271 ))
272 })?;
273
274 return Ok(Message::ToolResult {
275 tool_call_id: result.id,
276 content: content.first().text,
277 });
278 }
279 _ => {}
280 }
281 }
282
283 let collapsed_content = texts.join(" ");
284
285 Message::User {
286 content: collapsed_content,
287 }
288 }
289 message::Message::Assistant { content, .. } => {
290 let mut texts = Vec::new();
291 let mut tool_calls = Vec::new();
292
293 for ac in content.into_iter() {
294 match ac {
295 message::AssistantContent::Text(message::Text { text }) => texts.push(text),
296 message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
297 _ => {}
298 }
299 }
300
301 let collapsed_content = texts.join(" ");
302
303 Message::Assistant {
304 content: Some(collapsed_content),
305 tool_calls,
306 }
307 }
308 })
309 }
310}
311
312impl From<message::ToolResult> for Message {
313 fn from(tool_result: message::ToolResult) -> Self {
314 let content = match tool_result.content.first() {
315 message::ToolResultContent::Text(text) => text.text,
316 message::ToolResultContent::Image(_) => String::from("[Image]"),
317 };
318
319 Message::ToolResult {
320 tool_call_id: tool_result.id,
321 content,
322 }
323 }
324}
325
/// A tool call emitted by the model inside an assistant message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    /// Provider-assigned call id, echoed back in the tool-result message.
    pub id: String,
    /// Position of this call within the assistant message.
    pub index: usize,
    #[serde(default)]
    pub r#type: ToolType,
}
335
336impl From<message::ToolCall> for ToolCall {
337 fn from(tool_call: message::ToolCall) -> Self {
338 Self {
339 id: tool_call.id,
340 index: 0,
341 r#type: ToolType::Function,
342 function: CallFunction {
343 name: tool_call.function.name,
344 arguments: tool_call.function.arguments,
345 },
346 }
347 }
348}
349
/// Kind of tool; only `"function"` exists in the API today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Function invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    /// JSON arguments; transported over the wire as a stringified JSON blob.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
363
/// Chat roles (serialized lowercase). Note: `Message` carries its own role
/// tag; this enum is kept for API-shape completeness.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}

/// One candidate completion within a response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Choice {
    /// Why generation stopped, e.g. "stop" or "tool_calls".
    #[serde(rename = "finish_reason")]
    pub finish_reason: String,
    pub index: i64,
    pub message: Message,
}
380
/// Token accounting reported by the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    #[serde(rename = "completion_tokens")]
    pub completion_tokens: i64,
    #[serde(rename = "prompt_tokens")]
    pub prompt_tokens: i64,
    #[serde(rename = "total_tokens")]
    pub total_tokens: i64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

/// Breakdown of prompt tokens; currently only the cached-token count.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    // Defaults to 0 when the API omits the field.
    #[serde(default)]
    pub cached_tokens: usize,
}
400
401impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
402 type Error = CompletionError;
403
404 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
405 let choice = response.choices.first().ok_or_else(|| {
406 CompletionError::ResponseError("Response contained no choices".to_owned())
407 })?;
408
409 match &choice.message {
410 Message::Assistant {
411 tool_calls,
412 content,
413 } => {
414 if !tool_calls.is_empty() {
415 let tool_result = tool_calls
416 .iter()
417 .map(|call| {
418 completion::AssistantContent::tool_call(
419 &call.function.name,
420 &call.function.name,
421 call.function.arguments.clone(),
422 )
423 })
424 .collect::<Vec<_>>();
425
426 let choice = OneOrMany::many(tool_result).map_err(|_| {
427 CompletionError::ResponseError(
428 "Response contained no message or tool call (empty)".to_owned(),
429 )
430 })?;
431 let usage = response
432 .usage
433 .as_ref()
434 .map(|usage| completion::Usage {
435 input_tokens: usage.prompt_tokens as u64,
436 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
437 total_tokens: usage.total_tokens as u64,
438 cached_input_tokens: usage
439 .prompt_tokens_details
440 .as_ref()
441 .map(|d| d.cached_tokens as u64)
442 .unwrap_or(0),
443 })
444 .unwrap_or_default();
445 tracing::debug!("response choices: {:?}: ", choice);
446 Ok(completion::CompletionResponse {
447 choice,
448 usage,
449 raw_response: response,
450 message_id: None,
451 })
452 } else {
453 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
454 text: content.clone().unwrap_or_else(|| "".to_owned()),
455 }));
456 let usage = response
457 .usage
458 .as_ref()
459 .map(|usage| completion::Usage {
460 input_tokens: usage.prompt_tokens as u64,
461 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
462 total_tokens: usage.total_tokens as u64,
463 cached_input_tokens: usage
464 .prompt_tokens_details
465 .as_ref()
466 .map(|d| d.cached_tokens as u64)
467 .unwrap_or(0),
468 })
469 .unwrap_or_default();
470 Ok(completion::CompletionResponse {
471 choice,
472 usage,
473 raw_response: response,
474 message_id: None,
475 })
476 }
477 }
478 _ => Err(CompletionError::ResponseError(
480 "Chat response does not include an assistant message".into(),
481 )),
482 }
483 }
484}
485
/// Tool definition sent in the request's `tools` array
/// (`{"type": "function", "function": {...}}`).
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    // Always "function"; serialized under the reserved key "type".
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}

/// JSON-schema description of a callable function exposed to the model.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    /// JSON schema of the function's parameters.
    pub parameters: serde_json::Value,
}
502
503impl<T> completion::CompletionModel for CompletionModel<T>
505where
506 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
507{
508 type Response = CompletionResponse;
509 type StreamingResponse = openai::StreamingCompletionResponse;
510 type Client = Client<T>;
511
512 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
513 Self::new(client.clone(), model.into())
514 }
515
516 async fn completion(
517 &self,
518 completion_request: CompletionRequest,
519 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
520 let span = if tracing::Span::current().is_disabled() {
521 info_span!(
522 target: "rig::completions",
523 "chat",
524 gen_ai.operation.name = "chat",
525 gen_ai.provider.name = "groq",
526 gen_ai.request.model = self.model,
527 gen_ai.system_instructions = tracing::field::Empty,
528 gen_ai.response.id = tracing::field::Empty,
529 gen_ai.response.model = tracing::field::Empty,
530 gen_ai.usage.output_tokens = tracing::field::Empty,
531 gen_ai.usage.input_tokens = tracing::field::Empty,
532 )
533 } else {
534 tracing::Span::current()
535 };
536
537 span.record("gen_ai.system_instructions", &completion_request.preamble);
538
539 let request = self.create_completion_request(completion_request)?;
540
541 if tracing::enabled!(tracing::Level::TRACE) {
542 tracing::trace!(target: "rig::completions",
543 "Groq completion request: {}",
544 serde_json::to_string_pretty(&request)?
545 );
546 }
547
548 let body = serde_json::to_vec(&request)?;
549 let req = self
550 .client
551 .post("/chat/completions")?
552 .body(body)
553 .map_err(|e| http_client::Error::Instance(e.into()))?;
554
555 let async_block = async move {
556 let response = self.client.send::<_, Bytes>(req).await?;
557 let status = response.status();
558 let response_body = response.into_body().into_future().await?.to_vec();
559
560 let tt = response_body.clone();
561 let response = serde_json::from_slice::<serde_json::Value>(&tt)?;
562 println!(
563 "response:\r\n {}",
564 serde_json::to_string_pretty(&response).unwrap()
565 );
566
567 if status.is_success() {
568 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
569 ApiResponse::Ok(response) => {
570 let span = tracing::Span::current();
571 span.record("gen_ai.response.id", response.id.clone());
572 span.record("gen_ai.response.model_name", response.model.clone());
573 if let Some(ref usage) = response.usage {
574 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
575 span.record(
576 "gen_ai.usage.output_tokens",
577 usage.total_tokens - usage.prompt_tokens,
578 );
579 }
580
581 if tracing::enabled!(tracing::Level::TRACE) {
582 tracing::trace!(target: "rig::completions",
583 "Groq completion response: {}",
584 serde_json::to_string_pretty(&response)?
585 );
586 }
587
588 response.try_into()
589 }
590 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
591 }
592 } else {
593 Err(CompletionError::ProviderError(
594 String::from_utf8_lossy(&response_body).to_string(),
595 ))
596 }
597 };
598
599 tracing::Instrument::instrument(async_block, span).await
600 }
601
602 async fn stream(
603 &self,
604 request: CompletionRequest,
605 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
606 let preamble = request.preamble.clone();
607
608 let mut request = self.create_completion_request(request)?;
609
610 request = merge(request, json!({"stream": true}));
611
612 let body = serde_json::to_vec(&request)?;
613
614 let req = self
615 .client
616 .post("/chat/completions")?
617 .body(body)
618 .map_err(|e| http_client::Error::Instance(e.into()))?;
619
620 let span = if tracing::Span::current().is_disabled() {
621 info_span!(
622 target: "rig::completions",
623 "chat_streaming",
624 gen_ai.operation.name = "chat_streaming",
625 gen_ai.provider.name = "galadriel",
626 gen_ai.request.model = self.model,
627 gen_ai.system_instructions = preamble,
628 gen_ai.response.id = tracing::field::Empty,
629 gen_ai.response.model = tracing::field::Empty,
630 gen_ai.usage.output_tokens = tracing::field::Empty,
631 gen_ai.usage.input_tokens = tracing::field::Empty,
632 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
633 gen_ai.output.messages = tracing::field::Empty,
634 )
635 } else {
636 tracing::Span::current()
637 };
638
639 send_compatible_streaming_request(self.client.clone(), req)
640 .instrument(span)
641 .await
642 }
643}