rig/providers/
moonshot.rs

//! Moonshot API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::moonshot;
//!
//! let client: moonshot::Client = moonshot::Client::new("YOUR_API_KEY");
//!
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
//! ```
11use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
12use crate::http_client::HttpClientExt;
13use crate::json_utils::merge;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    completion::{self, CompletionError, CompletionRequest},
18    json_utils,
19    providers::openai,
20};
21use crate::{http_client, impl_conversion_traits, message};
22use serde::{Deserialize, Serialize};
23use serde_json::{Value, json};
24use tracing::{Instrument, info_span};
25
// ================================================================
// Main Moonshot Client
// ================================================================
/// Base URL used by [`ClientBuilder::new`] unless overridden via `base_url`.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
30
31pub struct ClientBuilder<'a, T = reqwest::Client> {
32    api_key: &'a str,
33    base_url: &'a str,
34    http_client: T,
35}
36
37impl<'a, T> ClientBuilder<'a, T>
38where
39    T: Default,
40{
41    pub fn new(api_key: &'a str) -> Self {
42        Self {
43            api_key,
44            base_url: MOONSHOT_API_BASE_URL,
45            http_client: Default::default(),
46        }
47    }
48}
49
50impl<'a, T> ClientBuilder<'a, T> {
51    pub fn base_url(mut self, base_url: &'a str) -> Self {
52        self.base_url = base_url;
53        self
54    }
55
56    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
57        ClientBuilder {
58            api_key: self.api_key,
59            base_url: self.base_url,
60            http_client,
61        }
62    }
63
64    pub fn build(self) -> Client<T> {
65        Client {
66            base_url: self.base_url.to_string(),
67            api_key: self.api_key.to_string(),
68            http_client: self.http_client,
69        }
70    }
71}
72
73#[derive(Clone)]
74pub struct Client<T = reqwest::Client> {
75    base_url: String,
76    api_key: String,
77    http_client: T,
78}
79
80impl<T> std::fmt::Debug for Client<T>
81where
82    T: std::fmt::Debug,
83{
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.debug_struct("Client")
86            .field("base_url", &self.base_url)
87            .field("http_client", &self.http_client)
88            .field("api_key", &"<REDACTED>")
89            .finish()
90    }
91}
92
93impl<T> Client<T>
94where
95    T: Default,
96{
97    /// Create a new Moonshot client builder.
98    ///
99    /// # Example
100    /// ```
101    /// use rig::providers::moonshot::{ClientBuilder, self};
102    ///
103    /// // Initialize the Moonshot client
104    /// let moonshot = Client::builder("your-moonshot-api-key")
105    ///    .build()
106    /// ```
107    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
108        ClientBuilder::new(api_key)
109    }
110
111    /// Create a new Moonshot client. For more control, use the `builder` method.
112    ///
113    /// # Panics
114    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
115    pub fn new(api_key: &str) -> Self {
116        Self::builder(api_key).build()
117    }
118}
119
120impl<T> Client<T>
121where
122    T: HttpClientExt,
123{
124    fn req(
125        &self,
126        method: http_client::Method,
127        path: &str,
128    ) -> http_client::Result<http_client::Builder> {
129        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
130
131        http_client::with_bearer_auth(
132            http_client::Builder::new().method(method).uri(url),
133            &self.api_key,
134        )
135    }
136
137    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
138        self.req(http_client::Method::GET, path)
139    }
140}
141
142impl Client<reqwest::Client> {
143    fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
144        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
145
146        self.http_client.post(url).bearer_auth(&self.api_key)
147    }
148}
149
150impl ProviderClient for Client<reqwest::Client> {
151    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
152    /// Panics if the environment variable is not set.
153    fn from_env() -> Self {
154        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
155        Self::new(&api_key)
156    }
157
158    fn from_val(input: crate::client::ProviderValue) -> Self {
159        let crate::client::ProviderValue::Simple(api_key) = input else {
160            panic!("Incorrect provider value type")
161        };
162        Self::new(&api_key)
163    }
164}
165
166impl CompletionClient for Client<reqwest::Client> {
167    type CompletionModel = CompletionModel<reqwest::Client>;
168
169    /// Create a completion model with the given name.
170    ///
171    /// # Example
172    /// ```
173    /// use rig::providers::moonshot::{Client, self};
174    ///
175    /// // Initialize the Moonshot client
176    /// let moonshot = Client::new("your-moonshot-api-key");
177    ///
178    /// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT);
179    /// ```
180    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
181        CompletionModel::new(self.clone(), model)
182    }
183}
184
185impl VerifyClient for Client<reqwest::Client> {
186    #[cfg_attr(feature = "worker", worker::send)]
187    async fn verify(&self) -> Result<(), VerifyError> {
188        let req = self
189            .get("/models")?
190            .body(http_client::NoBody)
191            .map_err(http_client::Error::from)?;
192
193        let response = HttpClientExt::send(&self.http_client, req).await?;
194
195        match response.status() {
196            reqwest::StatusCode::OK => Ok(()),
197            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
198            reqwest::StatusCode::INTERNAL_SERVER_ERROR
199            | reqwest::StatusCode::SERVICE_UNAVAILABLE
200            | reqwest::StatusCode::BAD_GATEWAY => {
201                let text = http_client::text(response).await?;
202                Err(VerifyError::ProviderError(text))
203            }
204            _ => {
205                //response.error_for_status()?;
206                Ok(())
207            }
208        }
209    }
210}
211
// Invokes the crate's `impl_conversion_traits!` macro for `Client<T>` with the
// embedding, transcription, image-generation and audio-generation conversion traits.
// This file defines only chat completions for Moonshot, so the macro presumably
// supplies "capability not supported" implementations of these traits.
// NOTE(review): confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
218
219#[derive(Debug, Deserialize)]
220struct ApiErrorResponse {
221    error: MoonshotError,
222}
223
224#[derive(Debug, Deserialize)]
225struct MoonshotError {
226    message: String,
227}
228
229#[derive(Debug, Deserialize)]
230#[serde(untagged)]
231enum ApiResponse<T> {
232    Ok(T),
233    Err(ApiErrorResponse),
234}
235
// ================================================================
// Moonshot Completion API
// ================================================================
/// Model identifier `moonshot-v1-128k`; the chat model used in this module's examples.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
240
241#[derive(Clone)]
242pub struct CompletionModel<T = reqwest::Client> {
243    client: Client<T>,
244    pub model: String,
245}
246
247impl<T> CompletionModel<T> {
248    pub fn new(client: Client<T>, model: &str) -> Self {
249        Self {
250            client,
251            model: model.to_string(),
252        }
253    }
254
255    fn create_completion_request(
256        &self,
257        completion_request: CompletionRequest,
258    ) -> Result<Value, CompletionError> {
259        // Build up the order of messages (context, chat_history)
260        let mut partial_history = vec![];
261        if let Some(docs) = completion_request.normalized_documents() {
262            partial_history.push(docs);
263        }
264        partial_history.extend(completion_request.chat_history);
265
266        // Initialize full history with preamble (or empty if non-existent)
267        let mut full_history: Vec<openai::Message> = completion_request
268            .preamble
269            .map_or_else(Vec::new, |preamble| {
270                vec![openai::Message::system(&preamble)]
271            });
272
273        // Convert and extend the rest of the history
274        full_history.extend(
275            partial_history
276                .into_iter()
277                .map(message::Message::try_into)
278                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
279                .into_iter()
280                .flatten()
281                .collect::<Vec<_>>(),
282        );
283
284        let tool_choice = completion_request
285            .tool_choice
286            .map(ToolChoice::try_from)
287            .transpose()?;
288
289        let request = if completion_request.tools.is_empty() {
290            json!({
291                "model": self.model,
292                "messages": full_history,
293                "temperature": completion_request.temperature,
294            })
295        } else {
296            json!({
297                "model": self.model,
298                "messages": full_history,
299                "temperature": completion_request.temperature,
300                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
301                "tool_choice": tool_choice,
302            })
303        };
304
305        let request = if let Some(params) = completion_request.additional_params {
306            json_utils::merge(request, params)
307        } else {
308            request
309        };
310
311        Ok(request)
312    }
313}
314
315impl completion::CompletionModel for CompletionModel<reqwest::Client> {
316    type Response = openai::CompletionResponse;
317    type StreamingResponse = openai::StreamingCompletionResponse;
318
319    #[cfg_attr(feature = "worker", worker::send)]
320    async fn completion(
321        &self,
322        completion_request: CompletionRequest,
323    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
324        let preamble = completion_request.preamble.clone();
325        let request = self.create_completion_request(completion_request)?;
326
327        let span = if tracing::Span::current().is_disabled() {
328            info_span!(
329                target: "rig::completions",
330                "chat",
331                gen_ai.operation.name = "chat",
332                gen_ai.provider.name = "moonshot",
333                gen_ai.request.model = self.model,
334                gen_ai.system_instructions = preamble,
335                gen_ai.response.id = tracing::field::Empty,
336                gen_ai.response.model = tracing::field::Empty,
337                gen_ai.usage.output_tokens = tracing::field::Empty,
338                gen_ai.usage.input_tokens = tracing::field::Empty,
339                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
340                gen_ai.output.messages = tracing::field::Empty,
341            )
342        } else {
343            tracing::Span::current()
344        };
345
346        let async_block = async move {
347            let response = self
348                .client
349                .reqwest_post("/chat/completions")
350                .json(&request)
351                .send()
352                .await
353                .map_err(|e| http_client::Error::Instance(e.into()))?;
354
355            if response.status().is_success() {
356                let t = response
357                    .text()
358                    .await
359                    .map_err(|e| http_client::Error::Instance(e.into()))?;
360                tracing::debug!(target: "rig::completions", "MoonShot completion response: {t}");
361
362                match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
363                    ApiResponse::Ok(response) => {
364                        let span = tracing::Span::current();
365                        span.record("gen_ai.response.id", response.id.clone());
366                        span.record("gen_ai.response.model_name", response.model.clone());
367                        span.record(
368                            "gen_ai.output.messages",
369                            serde_json::to_string(&response.choices).unwrap(),
370                        );
371                        if let Some(ref usage) = response.usage {
372                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
373                            span.record(
374                                "gen_ai.usage.output_tokens",
375                                usage.total_tokens - usage.prompt_tokens,
376                            );
377                        }
378                        response.try_into()
379                    }
380                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
381                }
382            } else {
383                Err(CompletionError::ProviderError(
384                    response
385                        .text()
386                        .await
387                        .map_err(|e| http_client::Error::Instance(e.into()))?,
388                ))
389            }
390        };
391
392        async_block.instrument(span).await
393    }
394
395    #[cfg_attr(feature = "worker", worker::send)]
396    async fn stream(
397        &self,
398        request: CompletionRequest,
399    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
400        let preamble = request.preamble.clone();
401        let mut request = self.create_completion_request(request)?;
402
403        let span = if tracing::Span::current().is_disabled() {
404            info_span!(
405                target: "rig::completions",
406                "chat_streaming",
407                gen_ai.operation.name = "chat_streaming",
408                gen_ai.provider.name = "moonshot",
409                gen_ai.request.model = self.model,
410                gen_ai.system_instructions = preamble,
411                gen_ai.response.id = tracing::field::Empty,
412                gen_ai.response.model = tracing::field::Empty,
413                gen_ai.usage.output_tokens = tracing::field::Empty,
414                gen_ai.usage.input_tokens = tracing::field::Empty,
415                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
416                gen_ai.output.messages = tracing::field::Empty,
417            )
418        } else {
419            tracing::Span::current()
420        };
421
422        request = merge(
423            request,
424            json!({"stream": true, "stream_options": {"include_usage": true}}),
425        );
426
427        let builder = self.client.reqwest_post("/chat/completions").json(&request);
428
429        send_compatible_streaming_request(builder)
430            .instrument(span)
431            .await
432    }
433}
434
435#[derive(Default, Debug, Deserialize, Serialize)]
436pub enum ToolChoice {
437    None,
438    #[default]
439    Auto,
440}
441
442impl TryFrom<message::ToolChoice> for ToolChoice {
443    type Error = CompletionError;
444
445    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
446        let res = match value {
447            message::ToolChoice::None => Self::None,
448            message::ToolChoice::Auto => Self::Auto,
449            choice => {
450                return Err(CompletionError::ProviderError(format!(
451                    "Unsupported tool choice type: {choice:?}"
452                )));
453            }
454        };
455
456        Ok(res)
457    }
458}