// rig_volcengine/completion.rs
use rig::completion::{self, CompletionError, CompletionRequest};
use rig::http_client;
use rig::message;
use rig::providers::openai;
use rig::providers::openai::completion::Usage;
use rig::streaming::StreamingCompletionResponse;

use serde_json::{Value, json};
use tracing::{Instrument, info_span};

use super::client::Client;
use super::types::{ApiResponse, ToolChoice};
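
/// Recursively deep-merges two JSON values: objects are merged key by key,
/// with values from `right` winning on conflicts; any non-object pair
/// resolves to `right`.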
fn merge(left: Value, right: Value) -> Value {
    match (left, right) {
        (Value::Object(mut a), Value::Object(b)) => {
            for (k, v) in b {
                let merged = match a.remove(&k) {
                    Some(existing) => merge(existing, v),
                    None => v,
                };
                a.insert(k, merged);
            }
            Value::Object(a)
        }
        (_, r) => r,
    }
}
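
/// A Volcengine completion model, generic over the underlying HTTP client
/// (`reqwest::Client` by default).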
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
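
    /// Builds the JSON body for the chat completions endpoint: the preamble
    /// becomes a system message, documents and chat history are flattened
    /// into OpenAI-compatible messages, and `additional_params` are
    /// deep-merged over the generated defaults.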
    pub(crate) fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
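        // Gather normalized documents (if any) ahead of the chat history.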
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);
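
        // Seed the full history with the system preamble, when present.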
        let mut full_history: Vec<openai::Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| {
                vec![openai::Message::system(&preamble)]
            });
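
        // Each rig message may convert into several OpenAI-style messages,
        // so flatten after the fallible conversion.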
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );
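
        // Translate the requested tool-choice mode, if one was set.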
        let tool_choice = completion_request
            .tool_choice
            .map(ToolChoice::try_from)
            .transpose()?;
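
        // Omit the `tools` and `tool_choice` fields entirely when no tools
        // are configured.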
        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "max_tokens": completion_request.max_tokens,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "max_tokens": completion_request.max_tokens,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": tool_choice,
            })
        };
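
        // Caller-supplied params override the generated request on conflict.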
        Ok(if let Some(params) = completion_request.additional_params {
            merge(request, params)
        } else {
            request
        })
    }
}
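
/// Maps rig's tool-choice variants onto the subset this provider supports;
/// any other variant is rejected as a provider error.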
impl TryFrom<message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Auto => Self::Auto,
            message::ToolChoice::Required => Self::Required,
            choice => {
                return Err(CompletionError::ProviderError(format!(
                    "Unsupported tool choice type: {choice:?}"
                )));
            }
        };

        Ok(res)
    }
}
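
// The completion trait is implemented on top of rig's OpenAI-compatible
// response types, since the provider speaks that wire format.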
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: http_client::HttpClientExt + Clone + Default + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }
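
    // Non-streaming completion: build the request, run the HTTP exchange
    // inside a GenAI tracing span, and convert the provider reply.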
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request = self.create_completion_request(completion_request)?;
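
        // Open a span following the GenAI semantic conventions unless the
        // caller already has one; response fields are recorded once the
        // reply arrives.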
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "volcengine",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap_or(&json!([]))).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
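
        // Serialize the body and POST it to the chat completions endpoint.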
        async move {
            let body = serde_json::to_vec(&request)?;
            let req = self
                .client
                .post("/chat/completions")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| CompletionError::HttpError(e.into()))?;

            let response = http_client::HttpClientExt::send(&self.client.http_client, req)
                .await
                .map_err(CompletionError::HttpError)?;
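
            // On success, parse the OpenAI-compatible payload and record
            // telemetry; otherwise surface the raw body as a provider error.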
            if response.status().is_success() {
                let t = http_client::text(response).await?;
                tracing::debug!(target: "rig::completions", "Volcengine completion response: {t}");

                match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices).unwrap(),
                        );
                        if let Some(Usage {
                            prompt_tokens,
                            total_tokens,
                            ..
                        }) = response.usage
                        {
                            span.record("gen_ai.usage.input_tokens", prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                total_tokens.saturating_sub(prompt_tokens),
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
                }
            } else {
                let t = http_client::text(response).await?;
                Err(CompletionError::ProviderError(t))
            }
        }
        .instrument(span)
        .await
    }
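
    // Streaming is delegated to this module's shared streaming
    // implementation.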
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        super::streaming::stream_completion(self, request).await
    }
}
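
// A minimal illustrative check of `merge` behavior (a hypothetical test
// added here for documentation, not part of the provider's API): nested
// objects merge recursively and the right-hand side wins on scalar
// conflicts, which is what lets `additional_params` override fields of the
// generated request.
#[cfg(test)]
mod tests {
    use super::merge;
    use serde_json::json;

    #[test]
    fn merge_prefers_right_on_conflict() {
        let base = json!({"model": "m", "options": {"temperature": 0.2}});
        let extra = json!({"options": {"temperature": 0.9, "top_p": 0.5}});
        assert_eq!(
            merge(base, extra),
            json!({"model": "m", "options": {"temperature": 0.9, "top_p": 0.5}})
        );
    }
}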