1use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
12use crate::http_client::HttpClientExt;
13use crate::json_utils::merge;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 completion::{self, CompletionError, CompletionRequest},
18 json_utils,
19 providers::openai,
20};
21use crate::{http_client, impl_conversion_traits, message};
22use http::Method;
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25use tracing::{Instrument, info_span};
26
/// Default root URL for Moonshot API requests; request paths are appended to it.
/// Can be overridden per client with [`ClientBuilder::base_url`].
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
31
/// Builder for a Moonshot [`Client`].
///
/// Holds borrowed configuration (API key and base URL) together with the HTTP
/// client the built [`Client`] will use. `T` defaults to [`reqwest::Client`].
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key; the built client sends it as a bearer token.
    api_key: &'a str,
    /// Root URL that request paths are appended to.
    base_url: &'a str,
    /// HTTP client moved into the built [`Client`].
    http_client: T,
}
37
38impl<'a, T> ClientBuilder<'a, T>
39where
40 T: Default,
41{
42 pub fn new(api_key: &'a str) -> Self {
43 Self {
44 api_key,
45 base_url: MOONSHOT_API_BASE_URL,
46 http_client: Default::default(),
47 }
48 }
49}
50
51impl<'a, T> ClientBuilder<'a, T> {
52 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
53 Self {
54 api_key,
55 base_url: MOONSHOT_API_BASE_URL,
56 http_client,
57 }
58 }
59
60 pub fn base_url(mut self, base_url: &'a str) -> Self {
61 self.base_url = base_url;
62 self
63 }
64
65 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
66 ClientBuilder {
67 api_key: self.api_key,
68 base_url: self.base_url,
69 http_client,
70 }
71 }
72
73 pub fn build(self) -> Client<T> {
74 Client {
75 base_url: self.base_url.to_string(),
76 api_key: self.api_key.to_string(),
77 http_client: self.http_client,
78 }
79 }
80}
81
/// Moonshot API client.
///
/// Owns the base URL, the API key (never printed by its `Debug` impl) and the
/// HTTP client used to send requests. `T` defaults to [`reqwest::Client`].
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Root URL that request paths are appended to.
    base_url: String,
    /// API key attached to every request as bearer authentication.
    api_key: String,
    /// HTTP client used to send requests.
    http_client: T,
}
88
89impl<T> std::fmt::Debug for Client<T>
90where
91 T: std::fmt::Debug,
92{
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("Client")
95 .field("base_url", &self.base_url)
96 .field("http_client", &self.http_client)
97 .field("api_key", &"<REDACTED>")
98 .finish()
99 }
100}
101
102impl<T> Client<T>
103where
104 T: HttpClientExt,
105{
106 fn req(
107 &self,
108 method: http_client::Method,
109 path: &str,
110 ) -> http_client::Result<http_client::Builder> {
111 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
112
113 http_client::with_bearer_auth(
114 http_client::Builder::new().method(method).uri(url),
115 &self.api_key,
116 )
117 }
118
119 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
120 self.req(http_client::Method::GET, path)
121 }
122}
123
124impl Client<reqwest::Client> {
125 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
126 ClientBuilder::new(api_key)
127 }
128
129 pub fn new(api_key: &str) -> Self {
130 Self::builder(api_key).build()
131 }
132
133 pub fn from_env() -> Self {
134 <Self as ProviderClient>::from_env()
135 }
136}
137
138impl<T> ProviderClient for Client<T>
139where
140 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
141{
142 fn from_env() -> Self {
145 let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
146 ClientBuilder::<T>::new(&api_key).build()
147 }
148
149 fn from_val(input: crate::client::ProviderValue) -> Self {
150 let crate::client::ProviderValue::Simple(api_key) = input else {
151 panic!("Incorrect provider value type")
152 };
153 ClientBuilder::<T>::new(&api_key).build()
154 }
155}
156
157impl<T> CompletionClient for Client<T>
158where
159 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
160{
161 type CompletionModel = CompletionModel<T>;
162
163 fn completion_model(&self, model: &str) -> Self::CompletionModel {
175 CompletionModel::new(self.clone(), model)
176 }
177}
178
179impl<T> VerifyClient for Client<T>
180where
181 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
182{
183 #[cfg_attr(feature = "worker", worker::send)]
184 async fn verify(&self) -> Result<(), VerifyError> {
185 let req = self
186 .get("/models")?
187 .body(http_client::NoBody)
188 .map_err(http_client::Error::from)?;
189
190 let response = HttpClientExt::send(&self.http_client, req).await?;
191
192 match response.status() {
193 reqwest::StatusCode::OK => Ok(()),
194 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
195 reqwest::StatusCode::INTERNAL_SERVER_ERROR
196 | reqwest::StatusCode::SERVICE_UNAVAILABLE
197 | reqwest::StatusCode::BAD_GATEWAY => {
198 let text = http_client::text(response).await?;
199 Err(VerifyError::ProviderError(text))
200 }
201 _ => Ok(()),
202 }
203 }
204}
205
// Supplies the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion traits for `Client<T>`. This module only
// implements completions; presumably the macro emits "not supported" defaults.
// TODO confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
212
/// Error envelope of the form `{"error": {"message": "..."}}` in a Moonshot response body.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}
217
/// Inner error object. Only the human-readable `message` is kept; any other
/// fields in the payload are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}
222
/// Response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Because the enum is `untagged`, serde tries `Ok` first and falls back to
/// `Err`, so the variant order is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
229
/// Model identifier for Moonshot's `moonshot-v1-128k` chat model.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
234
/// Moonshot chat completion model bound to a specific model identifier.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests to the Moonshot API.
    client: Client<T>,
    /// Model identifier sent as the request's `model` field (e.g. [`MOONSHOT_CHAT`]).
    pub model: String,
}
240
241impl<T> CompletionModel<T> {
242 pub fn new(client: Client<T>, model: &str) -> Self {
243 Self {
244 client,
245 model: model.to_string(),
246 }
247 }
248
249 fn create_completion_request(
250 &self,
251 completion_request: CompletionRequest,
252 ) -> Result<Value, CompletionError> {
253 let mut partial_history = vec![];
255 if let Some(docs) = completion_request.normalized_documents() {
256 partial_history.push(docs);
257 }
258 partial_history.extend(completion_request.chat_history);
259
260 let mut full_history: Vec<openai::Message> = completion_request
262 .preamble
263 .map_or_else(Vec::new, |preamble| {
264 vec![openai::Message::system(&preamble)]
265 });
266
267 full_history.extend(
269 partial_history
270 .into_iter()
271 .map(message::Message::try_into)
272 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
273 .into_iter()
274 .flatten()
275 .collect::<Vec<_>>(),
276 );
277
278 let tool_choice = completion_request
279 .tool_choice
280 .map(ToolChoice::try_from)
281 .transpose()?;
282
283 let request = if completion_request.tools.is_empty() {
284 json!({
285 "model": self.model,
286 "messages": full_history,
287 "temperature": completion_request.temperature,
288 "max_tokens": completion_request.max_tokens,
289 })
290 } else {
291 json!({
292 "model": self.model,
293 "messages": full_history,
294 "temperature": completion_request.temperature,
295 "max_tokens": completion_request.max_tokens,
296 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
297 "tool_choice": tool_choice,
298 })
299 };
300
301 let request = if let Some(params) = completion_request.additional_params {
302 json_utils::merge(request, params)
303 } else {
304 request
305 };
306
307 Ok(request)
308 }
309}
310
311impl<T> completion::CompletionModel for CompletionModel<T>
312where
313 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
314{
315 type Response = openai::CompletionResponse;
316 type StreamingResponse = openai::StreamingCompletionResponse;
317
318 #[cfg_attr(feature = "worker", worker::send)]
319 async fn completion(
320 &self,
321 completion_request: CompletionRequest,
322 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
323 let preamble = completion_request.preamble.clone();
324 let request = self.create_completion_request(completion_request)?;
325
326 println!(
327 "Moonshot API input: {request}",
328 request = serde_json::to_string_pretty(&request).unwrap()
329 );
330
331 let span = if tracing::Span::current().is_disabled() {
332 info_span!(
333 target: "rig::completions",
334 "chat",
335 gen_ai.operation.name = "chat",
336 gen_ai.provider.name = "moonshot",
337 gen_ai.request.model = self.model,
338 gen_ai.system_instructions = preamble,
339 gen_ai.response.id = tracing::field::Empty,
340 gen_ai.response.model = tracing::field::Empty,
341 gen_ai.usage.output_tokens = tracing::field::Empty,
342 gen_ai.usage.input_tokens = tracing::field::Empty,
343 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
344 gen_ai.output.messages = tracing::field::Empty,
345 )
346 } else {
347 tracing::Span::current()
348 };
349
350 let body = serde_json::to_vec(&request)?;
351 let req = self
352 .client
353 .req(Method::POST, "/chat/completions")?
354 .header("Content-Type", "application/json")
355 .body(body)
356 .map_err(http_client::Error::from)?;
357
358 let async_block = async move {
359 let response = self.client.http_client.send::<_, bytes::Bytes>(req).await?;
360
361 let status = response.status();
362 let response_body = response.into_body().into_future().await?.to_vec();
363
364 if status.is_success() {
365 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
366 &response_body,
367 )? {
368 ApiResponse::Ok(response) => {
369 tracing::debug!(target: "rig::completions", "MoonShot completion response: {t}", t = serde_json::to_string_pretty(&response)?);
370 let span = tracing::Span::current();
371 span.record("gen_ai.response.id", response.id.clone());
372 span.record("gen_ai.response.model_name", response.model.clone());
373 span.record(
374 "gen_ai.output.messages",
375 serde_json::to_string(&response.choices).unwrap(),
376 );
377 if let Some(ref usage) = response.usage {
378 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
379 span.record(
380 "gen_ai.usage.output_tokens",
381 usage.total_tokens - usage.prompt_tokens,
382 );
383 }
384 response.try_into()
385 }
386 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
387 }
388 } else {
389 Err(CompletionError::ProviderError(
390 String::from_utf8_lossy(&response_body).to_string(),
391 ))
392 }
393 };
394
395 async_block.instrument(span).await
396 }
397
398 #[cfg_attr(feature = "worker", worker::send)]
399 async fn stream(
400 &self,
401 request: CompletionRequest,
402 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
403 let preamble = request.preamble.clone();
404 let mut request = self.create_completion_request(request)?;
405
406 let span = if tracing::Span::current().is_disabled() {
407 info_span!(
408 target: "rig::completions",
409 "chat_streaming",
410 gen_ai.operation.name = "chat_streaming",
411 gen_ai.provider.name = "moonshot",
412 gen_ai.request.model = self.model,
413 gen_ai.system_instructions = preamble,
414 gen_ai.response.id = tracing::field::Empty,
415 gen_ai.response.model = tracing::field::Empty,
416 gen_ai.usage.output_tokens = tracing::field::Empty,
417 gen_ai.usage.input_tokens = tracing::field::Empty,
418 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
419 gen_ai.output.messages = tracing::field::Empty,
420 )
421 } else {
422 tracing::Span::current()
423 };
424
425 request = merge(
426 request,
427 json!({"stream": true, "stream_options": {"include_usage": true}}),
428 );
429
430 let body = serde_json::to_vec(&request)?;
431 let req = self
432 .client
433 .req(Method::POST, "/chat/completions")?
434 .header("Content-Type", "application/json")
435 .body(body)
436 .map_err(http_client::Error::from)?;
437
438 send_compatible_streaming_request(self.client.http_client.clone(), req)
439 .instrument(span)
440 .await
441 }
442}
443
444#[derive(Default, Debug, Deserialize, Serialize)]
445pub enum ToolChoice {
446 None,
447 #[default]
448 Auto,
449}
450
451impl TryFrom<message::ToolChoice> for ToolChoice {
452 type Error = CompletionError;
453
454 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
455 let res = match value {
456 message::ToolChoice::None => Self::None,
457 message::ToolChoice::Auto => Self::Auto,
458 choice => {
459 return Err(CompletionError::ProviderError(format!(
460 "Unsupported tool choice type: {choice:?}"
461 )));
462 }
463 };
464
465 Ok(res)
466 }
467}