Skip to main content

rig/providers/
moonshot.rs

//! Moonshot API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::moonshot;
//!
//! // `Client::new` is fallible, so the result has to be handled before use.
//! let client = moonshot::Client::new("YOUR_API_KEY").expect("failed to build Moonshot client");
//!
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
//! ```
11use crate::client::{
12    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
13    ProviderClient,
14};
15use crate::http_client::HttpClientExt;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19    completion::{self, CompletionError, CompletionRequest},
20    json_utils,
21    providers::openai,
22};
23use crate::{http_client, message};
24use serde::{Deserialize, Serialize};
25use tracing::{Instrument, info_span};
26
27// ================================================================
28// Main Moonshot Client
29// ================================================================
/// Base URL of the Moonshot API; exposed to the client machinery as
/// `ProviderBuilder::BASE_URL` on `MoonshotBuilder`.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
31
/// Stateless marker type that identifies Moonshot as the provider behind [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotExt;
/// Stateless builder marker for Moonshot; its `ProviderBuilder` impl names the output
/// extension type, the API-key type and the base URL.
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotBuilder;
36
/// API-key type used by Moonshot clients: an alias for the crate's `BearerAuth`.
type MoonshotApiKey = BearerAuth;
38
39impl Provider for MoonshotExt {
40    type Builder = MoonshotBuilder;
41
42    const VERIFY_PATH: &'static str = "/models";
43
44    fn build<H>(
45        _: &crate::client::ClientBuilder<
46            Self::Builder,
47            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
48            H,
49        >,
50    ) -> http_client::Result<Self> {
51        Ok(Self)
52    }
53}
54
// Empty impl: nothing is overridden, so `MoonshotExt` relies entirely on what the
// `DebugExt` trait itself provides.
impl DebugExt for MoonshotExt {}
56
/// Wires the Moonshot-specific types and base URL into the generic client builder.
impl ProviderBuilder for MoonshotBuilder {
    /// Extension type produced by this builder.
    type Output = MoonshotExt;
    /// API-key type accepted by this builder.
    type ApiKey = MoonshotApiKey;

    /// Base URL for this provider (`MOONSHOT_API_BASE_URL`).
    const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
}
63
/// Capability table for Moonshot: only `Completion` is marked `Capable`; every other
/// capability is declared as `Nothing`.
impl<H> Capabilities<H> for MoonshotExt {
    // Completions are served by this module's `CompletionModel`.
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two entries below only exist when the corresponding crate feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
74
/// Moonshot client: the generic `client::Client` specialised with [`MoonshotExt`].
/// The HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
/// Builder for [`Client`]: the generic `client::ClientBuilder` specialised with
/// [`MoonshotBuilder`] and the Moonshot API-key type.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;
78
79impl ProviderClient for Client {
80    type Input = String;
81
82    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
83    /// Panics if the environment variable is not set.
84    fn from_env() -> Self {
85        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
86        Self::new(&api_key).unwrap()
87    }
88
89    fn from_val(input: Self::Input) -> Self {
90        Self::new(&input).unwrap()
91    }
92}
93
/// Error envelope as deserialized from a response body: an object with a single `error` field.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// The error details nested under the `error` key.
    error: MoonshotError,
}
98
/// Error details carried inside [`ApiErrorResponse`]; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct MoonshotError {
    /// Error text; surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
103
/// Response body that is either a successful payload `T` or an error envelope.
///
/// The enum is `untagged`, so serde picks the variant by shape, trying `Ok` before `Err`
/// (declaration order).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error envelope instead.
    Err(ApiErrorResponse),
}
110
111// ================================================================
112// Moonshot Completion API
113// ================================================================
114
/// Model identifier `moonshot-v1-128k`, suitable for passing as the model name when creating a
/// completion model.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
116
/// JSON body that this module posts to `/chat/completions`.
///
/// Optional fields are left out of the serialized JSON when unset or empty.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MoonshotCompletionRequest {
    /// Model identifier to run the completion against.
    model: String,
    /// Conversation messages in the OpenAI provider's message format.
    pub messages: Vec<openai::Message>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions offered to the model; omitted when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    /// Token limit for the completion; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Tool-choice setting, using the OpenAI provider's type; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra provider parameters; `flatten` merges their keys into the top level of the
    /// request object rather than nesting them. Omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
132
133impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
134    type Error = CompletionError;
135
136    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
137        if req.output_schema.is_some() {
138            tracing::warn!("Structured outputs currently not supported for Moonshot");
139        }
140        let model = req.model.clone().unwrap_or_else(|| model.to_string());
141        // Build up the order of messages (context, chat_history, prompt)
142        let mut partial_history = vec![];
143        if let Some(docs) = req.normalized_documents() {
144            partial_history.push(docs);
145        }
146        partial_history.extend(req.chat_history);
147
148        // Add preamble to chat history (if available)
149        let mut full_history: Vec<openai::Message> = match &req.preamble {
150            Some(preamble) => vec![openai::Message::system(preamble)],
151            None => vec![],
152        };
153
154        // Convert and extend the rest of the history
155        full_history.extend(
156            partial_history
157                .into_iter()
158                .map(message::Message::try_into)
159                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
160                .into_iter()
161                .flatten()
162                .collect::<Vec<_>>(),
163        );
164
165        let tool_choice = req
166            .tool_choice
167            .clone()
168            .map(crate::providers::openai::ToolChoice::try_from)
169            .transpose()?;
170
171        Ok(Self {
172            model: model.to_string(),
173            messages: full_history,
174            temperature: req.temperature,
175            max_tokens: req.max_tokens,
176            tools: req
177                .tools
178                .clone()
179                .into_iter()
180                .map(openai::ToolDefinition::from)
181                .collect::<Vec<_>>(),
182            tool_choice,
183            additional_params: req.additional_params,
184        })
185    }
186}
187
/// Handle for running completions against one Moonshot model through a given [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send the HTTP requests.
    client: Client<T>,
    /// Default model name; used when a request does not specify its own model.
    pub model: String,
}
193
194impl<T> CompletionModel<T> {
195    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
196        Self {
197            client,
198            model: model.into(),
199        }
200    }
201}
202
203impl<T> completion::CompletionModel for CompletionModel<T>
204where
205    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
206{
207    type Response = openai::CompletionResponse;
208    type StreamingResponse = openai::StreamingCompletionResponse;
209
210    type Client = Client<T>;
211
212    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
213        Self::new(client.clone(), model)
214    }
215
216    async fn completion(
217        &self,
218        completion_request: CompletionRequest,
219    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
220        let span = if tracing::Span::current().is_disabled() {
221            info_span!(
222                target: "rig::completions",
223                "chat",
224                gen_ai.operation.name = "chat",
225                gen_ai.provider.name = "moonshot",
226                gen_ai.request.model = self.model,
227                gen_ai.system_instructions = tracing::field::Empty,
228                gen_ai.response.id = tracing::field::Empty,
229                gen_ai.response.model = tracing::field::Empty,
230                gen_ai.usage.output_tokens = tracing::field::Empty,
231                gen_ai.usage.input_tokens = tracing::field::Empty,
232            )
233        } else {
234            tracing::Span::current()
235        };
236
237        span.record("gen_ai.system_instructions", &completion_request.preamble);
238
239        let request =
240            MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
241
242        if tracing::enabled!(tracing::Level::TRACE) {
243            tracing::trace!(target: "rig::completions",
244                "MoonShot completion request: {}",
245                serde_json::to_string_pretty(&request)?
246            );
247        }
248
249        let body = serde_json::to_vec(&request)?;
250        let req = self
251            .client
252            .post("/chat/completions")?
253            .body(body)
254            .map_err(http_client::Error::from)?;
255
256        let async_block = async move {
257            let response = self.client.send::<_, bytes::Bytes>(req).await?;
258
259            let status = response.status();
260            let response_body = response.into_body().into_future().await?.to_vec();
261
262            if status.is_success() {
263                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
264                    &response_body,
265                )? {
266                    ApiResponse::Ok(response) => {
267                        let span = tracing::Span::current();
268                        span.record("gen_ai.response.id", response.id.clone());
269                        span.record("gen_ai.response.model_name", response.model.clone());
270                        if let Some(ref usage) = response.usage {
271                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
272                            span.record(
273                                "gen_ai.usage.output_tokens",
274                                usage.total_tokens - usage.prompt_tokens,
275                            );
276                        }
277                        if tracing::enabled!(tracing::Level::TRACE) {
278                            tracing::trace!(target: "rig::completions",
279                                "MoonShot completion response: {}",
280                                serde_json::to_string_pretty(&response)?
281                            );
282                        }
283                        response.try_into()
284                    }
285                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
286                }
287            } else {
288                Err(CompletionError::ProviderError(
289                    String::from_utf8_lossy(&response_body).to_string(),
290                ))
291            }
292        };
293
294        async_block.instrument(span).await
295    }
296
297    async fn stream(
298        &self,
299        request: CompletionRequest,
300    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
301        let span = if tracing::Span::current().is_disabled() {
302            info_span!(
303                target: "rig::completions",
304                "chat_streaming",
305                gen_ai.operation.name = "chat_streaming",
306                gen_ai.provider.name = "moonshot",
307                gen_ai.request.model = self.model,
308                gen_ai.system_instructions = tracing::field::Empty,
309                gen_ai.response.id = tracing::field::Empty,
310                gen_ai.response.model = tracing::field::Empty,
311                gen_ai.usage.output_tokens = tracing::field::Empty,
312                gen_ai.usage.input_tokens = tracing::field::Empty,
313            )
314        } else {
315            tracing::Span::current()
316        };
317
318        span.record("gen_ai.system_instructions", &request.preamble);
319        let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;
320
321        let params = json_utils::merge(
322            request.additional_params.unwrap_or(serde_json::json!({})),
323            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
324        );
325
326        request.additional_params = Some(params);
327
328        if tracing::enabled!(tracing::Level::TRACE) {
329            tracing::trace!(target: "rig::completions",
330                "MoonShot streaming completion request: {}",
331                serde_json::to_string_pretty(&request)?
332            );
333        }
334
335        let body = serde_json::to_vec(&request)?;
336        let req = self
337            .client
338            .post("/chat/completions")?
339            .body(body)
340            .map_err(http_client::Error::from)?;
341
342        send_compatible_streaming_request(self.client.clone(), req)
343            .instrument(span)
344            .await
345    }
346}
347
/// Moonshot-specific tool-choice setting with two variants, `None` and `Auto` (the default).
///
/// NOTE(review): `MoonshotCompletionRequest` in this file uses the OpenAI provider's
/// `ToolChoice`, not this enum. This enum also has no serde rename attribute, so its variants
/// serialize as `"None"` / `"Auto"` — confirm that is intended before sending it over the wire.
#[derive(Default, Debug, Deserialize, Serialize)]
pub enum ToolChoice {
    /// Corresponds to `message::ToolChoice::None`.
    None,
    /// Corresponds to `message::ToolChoice::Auto`; this is the default variant.
    #[default]
    Auto,
}
354
355impl TryFrom<message::ToolChoice> for ToolChoice {
356    type Error = CompletionError;
357
358    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
359        let res = match value {
360            message::ToolChoice::None => Self::None,
361            message::ToolChoice::Auto => Self::Auto,
362            choice => {
363                return Err(CompletionError::ProviderError(format!(
364                    "Unsupported tool choice type: {choice:?}"
365                )));
366            }
367        };
368
369        Ok(res)
370    }
371}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths (direct and via the builder) must succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _from_builder: Client = builder.build().expect("Client::builder() failed");
    }
}