1use crate::json_utils;
2use crate::json_utils::merge;
3use bytes::Bytes;
4use rig::agent::Text;
5use rig::client::{
6 BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
7};
8use rig::completion::{CompletionError, CompletionRequest};
9use rig::http_client::HttpClientExt;
10use rig::message::MessageError;
11use rig::providers::openai;
12use rig::providers::openai::send_compatible_streaming_request;
13use rig::streaming::StreamingCompletionResponse;
14use rig::{OneOrMany, client, completion, http_client, message};
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
/// Base URL for the BigModel (Zhipu AI / GLM) OpenAI-compatible API.
const BIGMODEL_API_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/";
22
/// Marker type implementing the rig `Provider` extension for BigModel.
#[derive(Debug, Default, Clone, Copy)]
pub struct BigmodelExt;

/// Builder marker consumed by `ProviderBuilder` to configure a BigModel client.
#[derive(Debug, Default, Clone, Copy)]

pub struct BigmodelBuilder;

// BigModel authenticates with a standard `Authorization: Bearer <key>` header.
type BigmodelApiKey = BearerAuth;
31
/// A BigModel chat-completion model bound to a configured client.
///
/// `T` is the underlying HTTP client type (defaults to `reqwest::Client`).
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier sent in the request body, e.g. `glm-4.7-flash`.
    pub model: String,
}
37
38impl<T> CompletionModel<T> {
39 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
40 Self {
41 client,
42 model: model.into(),
43 }
44 }
45
46 fn create_completion_request(
47 &self,
48 completion_request: CompletionRequest,
49 ) -> Result<Value, CompletionError> {
50 let mut partial_history = vec![];
52 if let Some(docs) = completion_request.normalized_documents() {
53 partial_history.push(docs);
54 }
55 partial_history.extend(completion_request.chat_history);
56
57 let mut full_history: Vec<Message> = completion_request
59 .preamble
60 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
61
62 full_history.extend(
64 partial_history
65 .into_iter()
66 .map(message::Message::try_into)
67 .collect::<Result<Vec<Message>, _>>()?,
68 );
69
70 let request = if completion_request.tools.is_empty() {
71 json!({
72 "model": self.model,
73 "messages": full_history,
74 "temperature": completion_request.temperature,
75 })
76 } else {
77 let tools = completion_request
79 .tools
80 .into_iter()
81 .map(|item| {
82 let custom_function = Function {
83 name: item.name,
84 description: item.description,
85 parameters: item.parameters,
86 };
87 CustomFunctionDefinition {
88 type_field: "function".to_string(),
89 function: custom_function,
90 }
91 })
92 .collect::<Vec<_>>();
93
94 tracing::debug!("tools: {:?}", tools);
95
96 json!({
97 "model": self.model,
98 "messages": full_history,
99 "temperature": completion_request.temperature,
100 "tools": tools,
101 "tool_choice": "auto",
102 })
103 };
104
105 let request = if let Some(params) = completion_request.additional_params {
106 json_utils::merge(request, params)
107 } else {
108 request
109 };
110
111 Ok(request)
112 }
113}
114
impl Provider for BigmodelExt {
    // NOTE(review): "api/tags" is the Ollama model-listing endpoint; it is not
    // an obvious BigModel route — confirm this is the intended verification
    // path for this provider.
    const VERIFY_PATH: &'static str = "api/tags";
    type Builder = BigmodelBuilder;

    /// The provider carries no state, so building it ignores the client builder.
    fn build<H>(
        _builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> rig::http_client::Result<Self> {
        Ok(Self)
    }
}
129
130impl<H> Capabilities<H> for BigmodelExt {
132 type Completion = Capable<CompletionModel<H>>;
133 type Embeddings = Nothing;
134 type Transcription = Nothing;
135
136 }
142
// Use the default `Debug`-based formatting supplied by `DebugExt`.
impl DebugExt for BigmodelExt {}
144
impl ProviderBuilder for BigmodelBuilder {
    type Output = BigmodelExt;
    // API keys are plain bearer tokens.
    type ApiKey = BigmodelApiKey;
    const BASE_URL: &'static str = BIGMODEL_API_BASE_URL;
}
150
/// Convenience alias for a rig client configured for BigModel.
pub type Client<H = reqwest::Client> = client::Client<BigmodelExt, H>;
/// Convenience alias for the corresponding client builder.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<BigmodelBuilder, String, H>;
153
/// Error payload returned by the BigModel API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
173
/// Either a successful payload of type `T` or an API error body.
/// `untagged` means serde tries the `Ok(T)` shape first and falls back to
/// the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
180
/// Model identifier for GLM-4.7 Flash.
pub const BIGMODEL_GLM_4_7_FLASH: &str = "glm-4.7-flash";
186
187#[derive(Debug, Deserialize, Serialize)]
188#[serde(rename_all = "camelCase")]
189pub struct CompletionResponse {
190 pub choices: Vec<Choice>,
191 pub created: i64,
192 pub id: String,
193 pub model: String,
194 #[serde(rename = "request_id")]
195 pub request_id: String,
196 pub usage: Option<Usage>,
197}
198
/// A chat message in BigModel's wire format.
///
/// The `role` field acts as the serde tag ("user" / "assistant" / "system" /
/// "tool"), so each variant (de)serializes flat alongside its fields.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        content: Option<String>,
        // `null` and a missing field both deserialize to an empty vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    System {
        content: String,
    },
    // Tool results use role "tool" and reference the originating call id.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
219
220impl Message {
221 pub fn system(content: &str) -> Message {
222 Message::System {
223 content: content.to_owned(),
224 }
225 }
226}
227
/// Text payload of a tool result; non-text results are rejected on conversion.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    text: String,
}
232impl TryFrom<message::ToolResultContent> for ToolResultContent {
233 type Error = MessageError;
234 fn try_from(value: message::ToolResultContent) -> Result<Self, Self::Error> {
235 let message::ToolResultContent::Text(Text { text }) = value else {
236 return Err(MessageError::ConversionError(
237 "Non-text tool results not supported".into(),
238 ));
239 };
240
241 Ok(Self { text })
242 }
243}
244
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    /// Convert a rig message into BigModel's wire format.
    ///
    /// User content is flattened: text parts are joined with single spaces.
    /// If the content contains a tool result, the whole message becomes a
    /// `ToolResult` instead — returning early and discarding any text parts
    /// gathered so far. NOTE(review): confirm mixed text + tool-result user
    /// messages are not expected here.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => {
                let mut texts = Vec::new();
                // NOTE(review): image data is collected here but never used —
                // either forward it to the API or drop this vec.
                let mut images = Vec::new();

                for uc in content.into_iter() {
                    match uc {
                        message::UserContent::Text(message::Text { text }) => texts.push(text),
                        message::UserContent::Image(img) => images.push(img.data),
                        message::UserContent::ToolResult(result) => {
                            let content = result
                                .content
                                .into_iter()
                                .map(ToolResultContent::try_from)
                                .collect::<Result<Vec<ToolResultContent>, MessageError>>()?;

                            let content = OneOrMany::many(content).map_err(|x| {
                                MessageError::ConversionError(format!(
                                    "Couldn't make a OneOrMany from a list of tool results: {x}"
                                ))
                            })?;

                            // Only the first text chunk is forwarded; any
                            // additional chunks are dropped.
                            return Ok(Message::ToolResult {
                                tool_call_id: result.id,
                                content: content.first().text,
                            });
                        }
                        _ => {}
                    }
                }

                let collapsed_content = texts.join(" ");

                Message::User {
                    content: collapsed_content,
                }
            }
            message::Message::Assistant { content, .. } => {
                let mut texts = Vec::new();
                let mut tool_calls = Vec::new();

                for ac in content.into_iter() {
                    match ac {
                        message::AssistantContent::Text(message::Text { text }) => texts.push(text),
                        message::AssistantContent::ToolCall(tc) => tool_calls.push(tc.into()),
                        _ => {}
                    }
                }

                let collapsed_content = texts.join(" ");

                // Content is always `Some(...)` here, even when no text parts
                // were present (joining an empty vec yields "").
                Message::Assistant {
                    content: Some(collapsed_content),
                    tool_calls,
                }
            }
        })
    }
}
308
309impl From<message::ToolResult> for Message {
310 fn from(tool_result: message::ToolResult) -> Self {
311 let content = match tool_result.content.first() {
312 message::ToolResultContent::Text(text) => text.text,
313 message::ToolResultContent::Image(_) => String::from("[Image]"),
314 };
315
316 Message::ToolResult {
317 tool_call_id: tool_result.id,
318 content,
319 }
320 }
321}
322
/// A tool call emitted by the assistant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// NOTE(review): every field name here is a single lowercase word, so this
// camelCase rename is currently a no-op.
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    pub function: CallFunction,
    pub id: String,
    pub index: usize,
    // Defaults to `ToolType::Function` when the API omits the field.
    #[serde(default)]
    pub r#type: ToolType,
}
332
333impl From<message::ToolCall> for ToolCall {
334 fn from(tool_call: message::ToolCall) -> Self {
335 Self {
336 id: tool_call.id,
337 index: 0,
338 r#type: ToolType::Function,
339 function: CallFunction {
340 name: tool_call.function.name,
341 arguments: tool_call.function.arguments,
342 },
343 }
344 }
345}
346
/// Tool kind; serialized as `"function"`, the only variant defined here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
353
/// The function half of a tool call: its name plus JSON arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct CallFunction {
    pub name: String,
    // The API transmits arguments as a JSON-encoded string; this adapter
    // (de)serializes them to/from a structured `Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
360
/// Chat roles. NOTE(review): this enum appears unused in this module —
/// `Message` already carries the role via its serde tag; consider removing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
368
369#[derive(Debug, Serialize, Deserialize)]
370#[serde(rename_all = "camelCase")]
371pub struct Choice {
372 #[serde(rename = "finish_reason")]
373 pub finish_reason: String,
374 pub index: i64,
375 pub message: Message,
376}
377
378#[derive(Debug, Clone, Serialize, Deserialize)]
379#[serde(rename_all = "camelCase")]
380pub struct Usage {
381 #[serde(rename = "completion_tokens")]
382 pub completion_tokens: i64,
383 #[serde(rename = "prompt_tokens")]
384 pub prompt_tokens: i64,
385 #[serde(rename = "total_tokens")]
386 pub total_tokens: i64,
387 #[serde(skip_serializing_if = "Option::is_none")]
388 pub prompt_tokens_details: Option<PromptTokensDetails>,
389}
390
/// Breakdown of prompt tokens; only the cached count is tracked.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    // Missing on the wire deserializes to 0.
    #[serde(default)]
    pub cached_tokens: usize,
}
397
398impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
399 type Error = CompletionError;
400
401 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
402 let choice = response.choices.first().ok_or_else(|| {
403 CompletionError::ResponseError("Response contained no choices".to_owned())
404 })?;
405
406 match &choice.message {
407 Message::Assistant {
408 tool_calls,
409 content,
410 } => {
411 if !tool_calls.is_empty() {
412 let tool_result = tool_calls
413 .iter()
414 .map(|call| {
415 completion::AssistantContent::tool_call(
416 &call.function.name,
417 &call.function.name,
418 call.function.arguments.clone(),
419 )
420 })
421 .collect::<Vec<_>>();
422
423 let choice = OneOrMany::many(tool_result).map_err(|_| {
424 CompletionError::ResponseError(
425 "Response contained no message or tool call (empty)".to_owned(),
426 )
427 })?;
428 let usage = response
429 .usage
430 .as_ref()
431 .map(|usage| completion::Usage {
432 input_tokens: usage.prompt_tokens as u64,
433 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
434 total_tokens: usage.total_tokens as u64,
435 cached_input_tokens: usage
436 .prompt_tokens_details
437 .as_ref()
438 .map(|d| d.cached_tokens as u64)
439 .unwrap_or(0),
440 })
441 .unwrap_or_default();
442 tracing::debug!("response choices: {:?}: ", choice);
443 Ok(completion::CompletionResponse {
444 choice,
445 usage,
446 raw_response: response,
447 })
448 } else {
449 let choice = OneOrMany::one(message::AssistantContent::Text(Text {
450 text: content.clone().unwrap_or_else(|| "".to_owned()),
451 }));
452 let usage = response
453 .usage
454 .as_ref()
455 .map(|usage| completion::Usage {
456 input_tokens: usage.prompt_tokens as u64,
457 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
458 total_tokens: usage.total_tokens as u64,
459 cached_input_tokens: usage
460 .prompt_tokens_details
461 .as_ref()
462 .map(|d| d.cached_tokens as u64)
463 .unwrap_or(0),
464 })
465 .unwrap_or_default();
466 Ok(completion::CompletionResponse {
467 choice,
468 usage,
469 raw_response: response,
470 })
471 }
472 }
473 _ => Err(CompletionError::ResponseError(
475 "Chat response does not include an assistant message".into(),
476 )),
477 }
478 }
479}
480
/// OpenAI-style tool definition wrapper: `{"type": "function", "function": …}`.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomFunctionDefinition {
    // Serialized as "type"; always set to "function" by this adapter.
    #[serde(rename = "type")]
    pub type_field: String,
    pub function: Function,
}
489
/// Description of a callable tool function sent with the request.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Function {
    pub name: String,
    pub description: String,
    // JSON schema describing the function's parameters.
    pub parameters: serde_json::Value,
}
497
498impl<T> completion::CompletionModel for CompletionModel<T>
500where
501 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
502{
503 type Response = CompletionResponse;
504 type StreamingResponse = openai::StreamingCompletionResponse;
505 type Client = Client<T>;
506
507 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
508 Self::new(client.clone(), model.into())
509 }
510
511 async fn completion(
512 &self,
513 completion_request: CompletionRequest,
514 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
515 let span = if tracing::Span::current().is_disabled() {
516 info_span!(
517 target: "rig::completions",
518 "chat",
519 gen_ai.operation.name = "chat",
520 gen_ai.provider.name = "groq",
521 gen_ai.request.model = self.model,
522 gen_ai.system_instructions = tracing::field::Empty,
523 gen_ai.response.id = tracing::field::Empty,
524 gen_ai.response.model = tracing::field::Empty,
525 gen_ai.usage.output_tokens = tracing::field::Empty,
526 gen_ai.usage.input_tokens = tracing::field::Empty,
527 )
528 } else {
529 tracing::Span::current()
530 };
531
532 span.record("gen_ai.system_instructions", &completion_request.preamble);
533
534 let request = self.create_completion_request(completion_request)?;
535
536 if tracing::enabled!(tracing::Level::TRACE) {
537 tracing::trace!(target: "rig::completions",
538 "Groq completion request: {}",
539 serde_json::to_string_pretty(&request)?
540 );
541 }
542
543 let body = serde_json::to_vec(&request)?;
544 let req = self
545 .client
546 .post("/chat/completions")?
547 .body(body)
548 .map_err(|e| http_client::Error::Instance(e.into()))?;
549
550 let async_block = async move {
551 let response = self.client.send::<_, Bytes>(req).await?;
552 let status = response.status();
553 let response_body = response.into_body().into_future().await?.to_vec();
554
555 let tt = response_body.clone();
556 let response = serde_json::from_slice::<serde_json::Value>(&tt)?;
557 println!(
558 "response:\r\n {}",
559 serde_json::to_string_pretty(&response).unwrap()
560 );
561
562 if status.is_success() {
563 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
564 ApiResponse::Ok(response) => {
565 let span = tracing::Span::current();
566 span.record("gen_ai.response.id", response.id.clone());
567 span.record("gen_ai.response.model_name", response.model.clone());
568 if let Some(ref usage) = response.usage {
569 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
570 span.record(
571 "gen_ai.usage.output_tokens",
572 usage.total_tokens - usage.prompt_tokens,
573 );
574 }
575
576 if tracing::enabled!(tracing::Level::TRACE) {
577 tracing::trace!(target: "rig::completions",
578 "Groq completion response: {}",
579 serde_json::to_string_pretty(&response)?
580 );
581 }
582
583 response.try_into()
584 }
585 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
586 }
587 } else {
588 Err(CompletionError::ProviderError(
589 String::from_utf8_lossy(&response_body).to_string(),
590 ))
591 }
592 };
593
594 tracing::Instrument::instrument(async_block, span).await
595 }
596
597 async fn stream(
598 &self,
599 request: CompletionRequest,
600 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
601 let preamble = request.preamble.clone();
602
603 let mut request = self.create_completion_request(request)?;
604
605 request = merge(request, json!({"stream": true}));
606
607 let body = serde_json::to_vec(&request)?;
608
609 let req = self
610 .client
611 .post("/chat/completions")?
612 .body(body)
613 .map_err(|e| http_client::Error::Instance(e.into()))?;
614
615 let span = if tracing::Span::current().is_disabled() {
616 info_span!(
617 target: "rig::completions",
618 "chat_streaming",
619 gen_ai.operation.name = "chat_streaming",
620 gen_ai.provider.name = "galadriel",
621 gen_ai.request.model = self.model,
622 gen_ai.system_instructions = preamble,
623 gen_ai.response.id = tracing::field::Empty,
624 gen_ai.response.model = tracing::field::Empty,
625 gen_ai.usage.output_tokens = tracing::field::Empty,
626 gen_ai.usage.input_tokens = tracing::field::Empty,
627 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
628 gen_ai.output.messages = tracing::field::Empty,
629 )
630 } else {
631 tracing::Span::current()
632 };
633
634 send_compatible_streaming_request(self.client.clone(), req)
635 .instrument(span)
636 .await
637 }
638}