rig_volcengine/completion.rs

1use rig::completion::{self, CompletionError, CompletionRequest};
2use rig::http_client;
3use rig::message;
4use rig::providers::openai;
5use rig::providers::openai::completion::Usage;
6use rig::streaming::StreamingCompletionResponse;
7
8use serde_json::{Value, json};
9use tracing::{Instrument, info_span};
10
11use super::client::Client;
12use super::types::{ApiResponse, ToolChoice};
13
14/// Local deep-merge helper to avoid private rig::json_utils.
15/// - Merge objects recursively, right overrides left; otherwise returns right.
16fn merge(left: Value, right: Value) -> Value {
17    match (left, right) {
18        (Value::Object(mut a), Value::Object(b)) => {
19            for (k, v) in b {
20                let merged = match a.remove(&k) {
21                    Some(existing) => merge(existing, v),
22                    None => v,
23                };
24                a.insert(k, merged);
25            }
26            Value::Object(a)
27        }
28        (_, r) => r,
29    }
30}
31
/// Chat completion model: CompletionModel<T>
///
/// `T` is the underlying HTTP client type (defaults to [`reqwest::Client`]).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared provider client carrying transport and configuration.
    pub(crate) client: Client<T>,
    // Model identifier sent as the `"model"` field of each request payload.
    pub model: String,
}
38
39impl<T> CompletionModel<T> {
40    pub fn new(client: Client<T>, model: &str) -> Self {
41        Self {
42            client,
43            model: model.to_string(),
44        }
45    }
46
47    pub(crate) fn create_completion_request(
48        &self,
49        completion_request: CompletionRequest,
50    ) -> Result<Value, CompletionError> {
51        // Build messages (include context documents if any)
52        let mut partial_history = vec![];
53        if let Some(docs) = completion_request.normalized_documents() {
54            partial_history.push(docs);
55        }
56        partial_history.extend(completion_request.chat_history);
57
58        // Preamble (system) goes first
59        let mut full_history: Vec<openai::Message> = completion_request
60            .preamble
61            .map_or_else(Vec::new, |preamble| {
62                vec![openai::Message::system(&preamble)]
63            });
64
65        // Convert user/assistant messages
66        full_history.extend(
67            partial_history
68                .into_iter()
69                .map(message::Message::try_into)
70                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
71                .into_iter()
72                .flatten()
73                .collect::<Vec<_>>(),
74        );
75
76        let tool_choice = completion_request
77            .tool_choice
78            .map(ToolChoice::try_from)
79            .transpose()?;
80
81        // OpenAI-compatible payload
82        let request = if completion_request.tools.is_empty() {
83            json!({
84                "model": self.model,
85                "messages": full_history,
86                "temperature": completion_request.temperature,
87                "max_tokens": completion_request.max_tokens,
88            })
89        } else {
90            json!({
91                "model": self.model,
92                "messages": full_history,
93                "temperature": completion_request.temperature,
94                "max_tokens": completion_request.max_tokens,
95                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
96                "tool_choice": tool_choice,
97            })
98        };
99
100        Ok(if let Some(params) = completion_request.additional_params {
101            merge(request, params)
102        } else {
103            request
104        })
105    }
106}
107
108impl TryFrom<message::ToolChoice> for ToolChoice {
109    type Error = CompletionError;
110
111    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
112        let res = match value {
113            message::ToolChoice::None => Self::None,
114            message::ToolChoice::Auto => Self::Auto,
115            message::ToolChoice::Required => Self::Required,
116            choice => {
117                return Err(CompletionError::ProviderError(format!(
118                    "Unsupported tool choice type: {choice:?}"
119                )));
120            }
121        };
122
123        Ok(res)
124    }
125}
126
127impl completion::CompletionModel for CompletionModel<reqwest::Client> {
128    type Response = openai::CompletionResponse;
129    type StreamingResponse = openai::StreamingCompletionResponse;
130
131    async fn completion(
132        &self,
133        completion_request: CompletionRequest,
134    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
135        let preamble = completion_request.preamble.clone();
136        let request = self.create_completion_request(completion_request)?;
137
138        let span = if tracing::Span::current().is_disabled() {
139            info_span!(
140                target: "rig::completions",
141                "chat",
142                gen_ai.operation.name = "chat",
143                gen_ai.provider.name = "volcengine",
144                gen_ai.request.model = self.model,
145                gen_ai.system_instructions = preamble,
146                gen_ai.response.id = tracing::field::Empty,
147                gen_ai.response.model = tracing::field::Empty,
148                gen_ai.usage.output_tokens = tracing::field::Empty,
149                gen_ai.usage.input_tokens = tracing::field::Empty,
150                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap_or(&json!([]))).unwrap(),
151                gen_ai.output.messages = tracing::field::Empty,
152            )
153        } else {
154            tracing::Span::current()
155        };
156
157        async move {
158            let response = self
159                .client
160                .reqwest_post("/chat/completions")
161                .json(&request)
162                .send()
163                .await
164                .map_err(|e| http_client::Error::Instance(e.into()))?;
165
166            if response.status().is_success() {
167                let t = response
168                    .text()
169                    .await
170                    .map_err(|e| http_client::Error::Instance(e.into()))?;
171                tracing::debug!(target: "rig::completions", "Volcengine completion response: {t}");
172
173                match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
174                    ApiResponse::Ok(response) => {
175                        let span = tracing::Span::current();
176                        span.record("gen_ai.response.id", response.id.clone());
177                        span.record("gen_ai.response.model_name", response.model.clone());
178                        span.record(
179                            "gen_ai.output.messages",
180                            serde_json::to_string(&response.choices).unwrap(),
181                        );
182                        if let Some(Usage {
183                            prompt_tokens,
184                            total_tokens,
185                            ..
186                        }) = response.usage
187                        {
188                            span.record("gen_ai.usage.input_tokens", prompt_tokens);
189                            span.record(
190                                "gen_ai.usage.output_tokens",
191                                total_tokens.saturating_sub(prompt_tokens),
192                            );
193                        }
194                        response.try_into()
195                    }
196                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
197                }
198            } else {
199                Err(CompletionError::ProviderError(
200                    response
201                        .text()
202                        .await
203                        .map_err(|e| http_client::Error::Instance(e.into()))?,
204                ))
205            }
206        }
207        .instrument(span)
208        .await
209    }
210
211    async fn stream(
212        &self,
213        request: CompletionRequest,
214    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
215        super::streaming::stream_completion(self, request).await
216    }
217}