1use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
12use crate::http_client::HttpClientExt;
13use crate::json_utils::merge;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 completion::{self, CompletionError, CompletionRequest},
18 json_utils,
19 providers::openai,
20};
21use crate::{http_client, impl_conversion_traits, message};
22use serde::{Deserialize, Serialize};
23use serde_json::{Value, json};
24use tracing::{Instrument, info_span};
25
/// Base URL used by [`ClientBuilder::new`] until it is overridden with
/// [`ClientBuilder::base_url`]. Request paths are appended to it after a `/`.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
30
/// Builder for a Moonshot [`Client`].
///
/// The API key and base URL are borrowed for `'a` and are only copied into
/// owned `String`s when [`ClientBuilder::build`] is called. `T` is the HTTP
/// client type and defaults to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key that the built client sends as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// HTTP client instance handed over to the built [`Client`].
    http_client: T,
}
36
37impl<'a, T> ClientBuilder<'a, T>
38where
39 T: Default,
40{
41 pub fn new(api_key: &'a str) -> Self {
42 Self {
43 api_key,
44 base_url: MOONSHOT_API_BASE_URL,
45 http_client: Default::default(),
46 }
47 }
48}
49
50impl<'a, T> ClientBuilder<'a, T> {
51 pub fn base_url(mut self, base_url: &'a str) -> Self {
52 self.base_url = base_url;
53 self
54 }
55
56 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
57 ClientBuilder {
58 api_key: self.api_key,
59 base_url: self.base_url,
60 http_client,
61 }
62 }
63
64 pub fn build(self) -> Client<T> {
65 Client {
66 base_url: self.base_url.to_string(),
67 api_key: self.api_key.to_string(),
68 http_client: self.http_client,
69 }
70 }
71}
72
/// Moonshot API client.
///
/// Holds the base URL, the API key and the HTTP client of type `T`
/// (`reqwest::Client` by default). The `Debug` implementation in this file
/// prints a placeholder instead of the API key.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as a bearer token on requests.
    api_key: String,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
79
80impl<T> std::fmt::Debug for Client<T>
81where
82 T: std::fmt::Debug,
83{
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("Client")
86 .field("base_url", &self.base_url)
87 .field("http_client", &self.http_client)
88 .field("api_key", &"<REDACTED>")
89 .finish()
90 }
91}
92
93impl<T> Client<T>
94where
95 T: Default,
96{
97 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
108 ClientBuilder::new(api_key)
109 }
110
111 pub fn new(api_key: &str) -> Self {
116 Self::builder(api_key).build()
117 }
118}
119
120impl<T> Client<T>
121where
122 T: HttpClientExt,
123{
124 fn req(
125 &self,
126 method: http_client::Method,
127 path: &str,
128 ) -> http_client::Result<http_client::Builder> {
129 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
130
131 http_client::with_bearer_auth(
132 http_client::Builder::new().method(method).uri(url),
133 &self.api_key,
134 )
135 }
136
137 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
138 self.req(http_client::Method::GET, path)
139 }
140}
141
142impl Client<reqwest::Client> {
143 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
144 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
145
146 self.http_client.post(url).bearer_auth(&self.api_key)
147 }
148}
149
150impl ProviderClient for Client<reqwest::Client> {
151 fn from_env() -> Self {
154 let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
155 Self::new(&api_key)
156 }
157
158 fn from_val(input: crate::client::ProviderValue) -> Self {
159 let crate::client::ProviderValue::Simple(api_key) = input else {
160 panic!("Incorrect provider value type")
161 };
162 Self::new(&api_key)
163 }
164}
165
166impl CompletionClient for Client<reqwest::Client> {
167 type CompletionModel = CompletionModel<reqwest::Client>;
168
169 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
181 CompletionModel::new(self.clone(), model)
182 }
183}
184
185impl VerifyClient for Client<reqwest::Client> {
186 #[cfg_attr(feature = "worker", worker::send)]
187 async fn verify(&self) -> Result<(), VerifyError> {
188 let req = self
189 .get("/models")?
190 .body(http_client::NoBody)
191 .map_err(http_client::Error::from)?;
192
193 let response = HttpClientExt::send(&self.http_client, req).await?;
194
195 match response.status() {
196 reqwest::StatusCode::OK => Ok(()),
197 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
198 reqwest::StatusCode::INTERNAL_SERVER_ERROR
199 | reqwest::StatusCode::SERVICE_UNAVAILABLE
200 | reqwest::StatusCode::BAD_GATEWAY => {
201 let text = http_client::text(response).await?;
202 Err(VerifyError::ProviderError(text))
203 }
204 _ => {
205 Ok(())
207 }
208 }
209 }
210}
211
// Invokes the crate's `impl_conversion_traits!` macro for `Client<T>` with the
// `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` traits.
// NOTE(review): presumably this generates "capability not available"
// implementations of those traits for this provider — confirm against the
// macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
218
/// Error envelope: a JSON object whose `error` field carries the details.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// The error payload.
    error: MoonshotError,
}
223
/// Error payload nested inside [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
struct MoonshotError {
    /// Error message text; forwarded to callers as
    /// `CompletionError::ProviderError`.
    message: String,
}
228
/// A response body that is either a successful payload `T` or an error
/// envelope.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error envelope.
    Err(ApiErrorResponse),
}
235
/// Model identifier `moonshot-v1-128k`, for use as the `model` argument when
/// creating a [`CompletionModel`].
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
240
/// A completion model bound to a Moonshot [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send the requests.
    client: Client<T>,
    /// Model name placed in the `model` field of every request body.
    pub model: String,
}
246
247impl<T> CompletionModel<T> {
248 pub fn new(client: Client<T>, model: &str) -> Self {
249 Self {
250 client,
251 model: model.to_string(),
252 }
253 }
254
255 fn create_completion_request(
256 &self,
257 completion_request: CompletionRequest,
258 ) -> Result<Value, CompletionError> {
259 let mut partial_history = vec![];
261 if let Some(docs) = completion_request.normalized_documents() {
262 partial_history.push(docs);
263 }
264 partial_history.extend(completion_request.chat_history);
265
266 let mut full_history: Vec<openai::Message> = completion_request
268 .preamble
269 .map_or_else(Vec::new, |preamble| {
270 vec![openai::Message::system(&preamble)]
271 });
272
273 full_history.extend(
275 partial_history
276 .into_iter()
277 .map(message::Message::try_into)
278 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
279 .into_iter()
280 .flatten()
281 .collect::<Vec<_>>(),
282 );
283
284 let tool_choice = completion_request
285 .tool_choice
286 .map(ToolChoice::try_from)
287 .transpose()?;
288
289 let request = if completion_request.tools.is_empty() {
290 json!({
291 "model": self.model,
292 "messages": full_history,
293 "temperature": completion_request.temperature,
294 })
295 } else {
296 json!({
297 "model": self.model,
298 "messages": full_history,
299 "temperature": completion_request.temperature,
300 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
301 "tool_choice": tool_choice,
302 })
303 };
304
305 let request = if let Some(params) = completion_request.additional_params {
306 json_utils::merge(request, params)
307 } else {
308 request
309 };
310
311 Ok(request)
312 }
313}
314
315impl completion::CompletionModel for CompletionModel<reqwest::Client> {
316 type Response = openai::CompletionResponse;
317 type StreamingResponse = openai::StreamingCompletionResponse;
318
319 #[cfg_attr(feature = "worker", worker::send)]
320 async fn completion(
321 &self,
322 completion_request: CompletionRequest,
323 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
324 let preamble = completion_request.preamble.clone();
325 let request = self.create_completion_request(completion_request)?;
326
327 let span = if tracing::Span::current().is_disabled() {
328 info_span!(
329 target: "rig::completions",
330 "chat",
331 gen_ai.operation.name = "chat",
332 gen_ai.provider.name = "moonshot",
333 gen_ai.request.model = self.model,
334 gen_ai.system_instructions = preamble,
335 gen_ai.response.id = tracing::field::Empty,
336 gen_ai.response.model = tracing::field::Empty,
337 gen_ai.usage.output_tokens = tracing::field::Empty,
338 gen_ai.usage.input_tokens = tracing::field::Empty,
339 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
340 gen_ai.output.messages = tracing::field::Empty,
341 )
342 } else {
343 tracing::Span::current()
344 };
345
346 let async_block = async move {
347 let response = self
348 .client
349 .reqwest_post("/chat/completions")
350 .json(&request)
351 .send()
352 .await
353 .map_err(|e| http_client::Error::Instance(e.into()))?;
354
355 if response.status().is_success() {
356 let t = response
357 .text()
358 .await
359 .map_err(|e| http_client::Error::Instance(e.into()))?;
360 tracing::debug!(target: "rig::completions", "MoonShot completion response: {t}");
361
362 match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
363 ApiResponse::Ok(response) => {
364 let span = tracing::Span::current();
365 span.record("gen_ai.response.id", response.id.clone());
366 span.record("gen_ai.response.model_name", response.model.clone());
367 span.record(
368 "gen_ai.output.messages",
369 serde_json::to_string(&response.choices).unwrap(),
370 );
371 if let Some(ref usage) = response.usage {
372 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
373 span.record(
374 "gen_ai.usage.output_tokens",
375 usage.total_tokens - usage.prompt_tokens,
376 );
377 }
378 response.try_into()
379 }
380 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
381 }
382 } else {
383 Err(CompletionError::ProviderError(
384 response
385 .text()
386 .await
387 .map_err(|e| http_client::Error::Instance(e.into()))?,
388 ))
389 }
390 };
391
392 async_block.instrument(span).await
393 }
394
395 #[cfg_attr(feature = "worker", worker::send)]
396 async fn stream(
397 &self,
398 request: CompletionRequest,
399 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
400 let preamble = request.preamble.clone();
401 let mut request = self.create_completion_request(request)?;
402
403 let span = if tracing::Span::current().is_disabled() {
404 info_span!(
405 target: "rig::completions",
406 "chat_streaming",
407 gen_ai.operation.name = "chat_streaming",
408 gen_ai.provider.name = "moonshot",
409 gen_ai.request.model = self.model,
410 gen_ai.system_instructions = preamble,
411 gen_ai.response.id = tracing::field::Empty,
412 gen_ai.response.model = tracing::field::Empty,
413 gen_ai.usage.output_tokens = tracing::field::Empty,
414 gen_ai.usage.input_tokens = tracing::field::Empty,
415 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
416 gen_ai.output.messages = tracing::field::Empty,
417 )
418 } else {
419 tracing::Span::current()
420 };
421
422 request = merge(
423 request,
424 json!({"stream": true, "stream_options": {"include_usage": true}}),
425 );
426
427 let builder = self.client.reqwest_post("/chat/completions").json(&request);
428
429 send_compatible_streaming_request(builder)
430 .instrument(span)
431 .await
432 }
433}
434
435#[derive(Default, Debug, Deserialize, Serialize)]
436pub enum ToolChoice {
437 None,
438 #[default]
439 Auto,
440}
441
442impl TryFrom<message::ToolChoice> for ToolChoice {
443 type Error = CompletionError;
444
445 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
446 let res = match value {
447 message::ToolChoice::None => Self::None,
448 message::ToolChoice::Auto => Self::Auto,
449 choice => {
450 return Err(CompletionError::ProviderError(format!(
451 "Unsupported tool choice type: {choice:?}"
452 )));
453 }
454 };
455
456 Ok(res)
457 }
458}