// rig_bailian/completion.rs

1use rig::completion::{self, CompletionError, CompletionRequest};
2use rig::http_client;
3use rig::message;
4use rig::providers::openai;
5use rig::providers::openai::completion::Usage;
6use rig::streaming::StreamingCompletionResponse;
7
8use serde_json::{Value, json};
9use tracing::{Instrument, info_span};
10
11use super::client::Client;
12use super::types::{ApiResponse, ToolChoice};
13
14/// Local deep-merge helper to avoid private rig::json_utils.
15/// - Merge objects recursively, right overrides left; otherwise returns right.
16fn merge(left: Value, right: Value) -> Value {
17    match (left, right) {
18        (Value::Object(mut a), Value::Object(b)) => {
19            for (k, v) in b {
20                let merged = match a.remove(&k) {
21                    Some(existing) => merge(existing, v),
22                    None => v,
23                };
24                a.insert(k, merged);
25            }
26            Value::Object(a)
27        }
28        (_, r) => r,
29    }
30}
31
32/// Chat completion model: CompletionModel<T>
33#[derive(Clone)]
34pub struct CompletionModel<T = reqwest::Client> {
35    pub(crate) client: Client<T>,
36    pub model: String,
37}
38
39impl<T> CompletionModel<T> {
40    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
41        Self {
42            client,
43            model: model.into(),
44        }
45    }
46
47    pub(crate) fn create_completion_request(
48        &self,
49        completion_request: CompletionRequest,
50    ) -> Result<Value, CompletionError> {
51        // Build messages (include context documents if any)
52        let mut partial_history = vec![];
53        if let Some(docs) = completion_request.normalized_documents() {
54            partial_history.push(docs);
55        }
56        partial_history.extend(completion_request.chat_history);
57
58        // Preamble (system) goes first
59        let mut full_history: Vec<openai::Message> = completion_request
60            .preamble
61            .map_or_else(Vec::new, |preamble| {
62                vec![openai::Message::system(&preamble)]
63            });
64
65        // Convert user/assistant messages
66        full_history.extend(
67            partial_history
68                .into_iter()
69                .map(message::Message::try_into)
70                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
71                .into_iter()
72                .flatten()
73                .collect::<Vec<_>>(),
74        );
75
76        let tool_choice = completion_request
77            .tool_choice
78            .map(ToolChoice::try_from)
79            .transpose()?;
80
81        // OpenAI-compatible payload
82        let request = if completion_request.tools.is_empty() {
83            json!({
84                "model": self.model,
85                "messages": full_history,
86                "temperature": completion_request.temperature,
87                "max_tokens": completion_request.max_tokens,
88            })
89        } else {
90            json!({
91                "model": self.model,
92                "messages": full_history,
93                "temperature": completion_request.temperature,
94                "max_tokens": completion_request.max_tokens,
95                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
96                "tool_choice": tool_choice,
97            })
98        };
99
100        Ok(if let Some(params) = completion_request.additional_params {
101            merge(request, params)
102        } else {
103            request
104        })
105    }
106}
107
108impl TryFrom<message::ToolChoice> for ToolChoice {
109    type Error = CompletionError;
110
111    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
112        let res = match value {
113            message::ToolChoice::None => Self::None,
114            message::ToolChoice::Auto => Self::Auto,
115            message::ToolChoice::Required => Self::Required,
116            choice => {
117                return Err(CompletionError::ProviderError(format!(
118                    "Unsupported tool choice type: {choice:?}"
119                )));
120            }
121        };
122
123        Ok(res)
124    }
125}
126
127impl<T> completion::CompletionModel for CompletionModel<T>
128where
129    T: http_client::HttpClientExt + Clone + Default + Send + 'static,
130{
131    type Response = openai::CompletionResponse;
132    type StreamingResponse = openai::StreamingCompletionResponse;
133    type Client = Client<T>;
134
135    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
136        Self::new(client.clone(), model)
137    }
138
139    async fn completion(
140        &self,
141        completion_request: CompletionRequest,
142    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
143        let preamble = completion_request.preamble.clone();
144        let request = self.create_completion_request(completion_request)?;
145
146        let span = if tracing::Span::current().is_disabled() {
147            info_span!(
148                target: "rig::completions",
149                "chat",
150                gen_ai.operation.name = "chat",
151                gen_ai.provider.name = "bailian",
152                gen_ai.request.model = self.model,
153                gen_ai.system_instructions = preamble,
154                gen_ai.response.id = tracing::field::Empty,
155                gen_ai.response.model = tracing::field::Empty,
156                gen_ai.usage.output_tokens = tracing::field::Empty,
157                gen_ai.usage.input_tokens = tracing::field::Empty,
158                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap_or(&json!([]))).unwrap(),
159                gen_ai.output.messages = tracing::field::Empty,
160            )
161        } else {
162            tracing::Span::current()
163        };
164
165        async move {
166            let body = serde_json::to_vec(&request)?;
167            let req = self
168                .client
169                .post("/chat/completions")?
170                .header("Content-Type", "application/json")
171                .body(body)
172                .map_err(|e| CompletionError::HttpError(e.into()))?;
173
174            let response = http_client::HttpClientExt::send(&self.client.http_client, req)
175                .await
176                .map_err(CompletionError::HttpError)?;
177
178            if response.status().is_success() {
179                let t = http_client::text(response).await?;
180                tracing::debug!(target: "rig::completions", "Bailian completion response: {t}");
181
182                match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
183                    ApiResponse::Ok(response) => {
184                        let span = tracing::Span::current();
185                        span.record("gen_ai.response.id", response.id.clone());
186                        span.record("gen_ai.response.model_name", response.model.clone());
187                        span.record(
188                            "gen_ai.output.messages",
189                            serde_json::to_string(&response.choices).unwrap(),
190                        );
191                        if let Some(Usage {
192                            prompt_tokens,
193                            total_tokens,
194                            ..
195                        }) = response.usage
196                        {
197                            span.record("gen_ai.usage.input_tokens", prompt_tokens);
198                            span.record(
199                                "gen_ai.usage.output_tokens",
200                                total_tokens.saturating_sub(prompt_tokens),
201                            );
202                        }
203                        response.try_into()
204                    }
205                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
206                }
207            } else {
208                let t = http_client::text(response).await?;
209                Err(CompletionError::ProviderError(t))
210            }
211        }
212        .instrument(span)
213        .await
214    }
215
216    async fn stream(
217        &self,
218        request: CompletionRequest,
219    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
220        super::streaming::stream_completion(self, request).await
221    }
222}