// rig/providers/moonshot.rs

//! Moonshot API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::moonshot;
//!
//! let client = moonshot::Client::new("YOUR_API_KEY");
//!
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
//! ```
use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::http_client::HttpClientExt;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    providers::openai,
};
use crate::{http_client, message};
use serde::{Deserialize, Serialize};
use tracing::{Instrument, info_span};

// ================================================================
// Main Moonshot Client
// ================================================================
/// Base URL of the Moonshot OpenAI-compatible REST API.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";

/// Zero-sized marker type carrying the Moonshot [`Provider`] implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotExt;
/// Zero-sized marker type carrying the Moonshot [`ProviderBuilder`] implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotBuilder;

/// Moonshot authenticates via a standard `Authorization: Bearer <key>` header.
type MoonshotApiKey = BearerAuth;

39impl Provider for MoonshotExt {
40    type Builder = MoonshotBuilder;
41
42    const VERIFY_PATH: &'static str = "/models";
43
44    fn build<H>(
45        _: &crate::client::ClientBuilder<
46            Self::Builder,
47            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
48            H,
49        >,
50    ) -> http_client::Result<Self> {
51        Ok(Self)
52    }
53}
54
// Marker impl; `DebugExt` requires no items here.
impl DebugExt for MoonshotExt {}

impl ProviderBuilder for MoonshotBuilder {
    type Output = MoonshotExt;
    type ApiKey = MoonshotApiKey;

    // All request paths are resolved relative to this base URL.
    const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
}

impl<H> Capabilities<H> for MoonshotExt {
    // Only chat completions are wired up for Moonshot; every other
    // capability is marked absent via `Nothing`.
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

/// Moonshot client specialized over an HTTP implementation (reqwest by default).
pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
/// Builder for [`Client`], parameterized the same way.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;

78impl ProviderClient for Client {
79    type Input = String;
80
81    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
82    /// Panics if the environment variable is not set.
83    fn from_env() -> Self {
84        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
85        Self::new(&api_key).unwrap()
86    }
87
88    fn from_val(input: Self::Input) -> Self {
89        Self::new(&input).unwrap()
90    }
91}
92
/// Error envelope returned by the Moonshot API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}

/// Inner error object; only the human-readable message is surfaced.
#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}

/// A Moonshot response body: either the expected payload `T` or an error
/// envelope. `untagged` makes serde try each variant's shape in order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Moonshot Completion API
// ================================================================

/// `moonshot-v1-128k` chat model (128k-token context window).
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";

/// Request body for Moonshot's OpenAI-compatible `/chat/completions` endpoint.
/// Optional fields are omitted from the JSON rather than sent as `null`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MoonshotCompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Extra caller-supplied parameters, flattened into the top-level object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

132impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
133    type Error = CompletionError;
134
135    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
136        // Build up the order of messages (context, chat_history, prompt)
137        let mut partial_history = vec![];
138        if let Some(docs) = req.normalized_documents() {
139            partial_history.push(docs);
140        }
141        partial_history.extend(req.chat_history);
142
143        // Add preamble to chat history (if available)
144        let mut full_history: Vec<openai::Message> = match &req.preamble {
145            Some(preamble) => vec![openai::Message::system(preamble)],
146            None => vec![],
147        };
148
149        // Convert and extend the rest of the history
150        full_history.extend(
151            partial_history
152                .into_iter()
153                .map(message::Message::try_into)
154                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
155                .into_iter()
156                .flatten()
157                .collect::<Vec<_>>(),
158        );
159
160        let tool_choice = req
161            .tool_choice
162            .clone()
163            .map(crate::providers::openai::ToolChoice::try_from)
164            .transpose()?;
165
166        Ok(Self {
167            model: model.to_string(),
168            messages: full_history,
169            temperature: req.temperature,
170            max_tokens: req.max_tokens,
171            tools: req
172                .tools
173                .clone()
174                .into_iter()
175                .map(openai::ToolDefinition::from)
176                .collect::<Vec<_>>(),
177            tool_choice,
178            additional_params: req.additional_params,
179        })
180    }
181}
182
183#[derive(Clone)]
184pub struct CompletionModel<T = reqwest::Client> {
185    client: Client<T>,
186    pub model: String,
187}
188
189impl<T> CompletionModel<T> {
190    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
191        Self {
192            client,
193            model: model.into(),
194        }
195    }
196}
197
198impl<T> completion::CompletionModel for CompletionModel<T>
199where
200    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
201{
202    type Response = openai::CompletionResponse;
203    type StreamingResponse = openai::StreamingCompletionResponse;
204
205    type Client = Client<T>;
206
207    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
208        Self::new(client.clone(), model)
209    }
210
211    #[cfg_attr(feature = "worker", worker::send)]
212    async fn completion(
213        &self,
214        completion_request: CompletionRequest,
215    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
216        let span = if tracing::Span::current().is_disabled() {
217            info_span!(
218                target: "rig::completions",
219                "chat",
220                gen_ai.operation.name = "chat",
221                gen_ai.provider.name = "moonshot",
222                gen_ai.request.model = self.model,
223                gen_ai.system_instructions = tracing::field::Empty,
224                gen_ai.response.id = tracing::field::Empty,
225                gen_ai.response.model = tracing::field::Empty,
226                gen_ai.usage.output_tokens = tracing::field::Empty,
227                gen_ai.usage.input_tokens = tracing::field::Empty,
228            )
229        } else {
230            tracing::Span::current()
231        };
232
233        span.record("gen_ai.system_instructions", &completion_request.preamble);
234
235        let request =
236            MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
237
238        if tracing::enabled!(tracing::Level::TRACE) {
239            tracing::trace!(target: "rig::completions",
240                "MoonShot completion request: {}",
241                serde_json::to_string_pretty(&request)?
242            );
243        }
244
245        let body = serde_json::to_vec(&request)?;
246        let req = self
247            .client
248            .post("/chat/completions")?
249            .body(body)
250            .map_err(http_client::Error::from)?;
251
252        let async_block = async move {
253            let response = self.client.send::<_, bytes::Bytes>(req).await?;
254
255            let status = response.status();
256            let response_body = response.into_body().into_future().await?.to_vec();
257
258            if status.is_success() {
259                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
260                    &response_body,
261                )? {
262                    ApiResponse::Ok(response) => {
263                        let span = tracing::Span::current();
264                        span.record("gen_ai.response.id", response.id.clone());
265                        span.record("gen_ai.response.model_name", response.model.clone());
266                        if let Some(ref usage) = response.usage {
267                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
268                            span.record(
269                                "gen_ai.usage.output_tokens",
270                                usage.total_tokens - usage.prompt_tokens,
271                            );
272                        }
273                        if tracing::enabled!(tracing::Level::TRACE) {
274                            tracing::trace!(target: "rig::completions",
275                                "MoonShot completion response: {}",
276                                serde_json::to_string_pretty(&response)?
277                            );
278                        }
279                        response.try_into()
280                    }
281                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
282                }
283            } else {
284                Err(CompletionError::ProviderError(
285                    String::from_utf8_lossy(&response_body).to_string(),
286                ))
287            }
288        };
289
290        async_block.instrument(span).await
291    }
292
293    #[cfg_attr(feature = "worker", worker::send)]
294    async fn stream(
295        &self,
296        request: CompletionRequest,
297    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
298        let span = if tracing::Span::current().is_disabled() {
299            info_span!(
300                target: "rig::completions",
301                "chat_streaming",
302                gen_ai.operation.name = "chat_streaming",
303                gen_ai.provider.name = "moonshot",
304                gen_ai.request.model = self.model,
305                gen_ai.system_instructions = tracing::field::Empty,
306                gen_ai.response.id = tracing::field::Empty,
307                gen_ai.response.model = tracing::field::Empty,
308                gen_ai.usage.output_tokens = tracing::field::Empty,
309                gen_ai.usage.input_tokens = tracing::field::Empty,
310            )
311        } else {
312            tracing::Span::current()
313        };
314
315        span.record("gen_ai.system_instructions", &request.preamble);
316        let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;
317
318        let params = json_utils::merge(
319            request.additional_params.unwrap_or(serde_json::json!({})),
320            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
321        );
322
323        request.additional_params = Some(params);
324
325        if tracing::enabled!(tracing::Level::TRACE) {
326            tracing::trace!(target: "rig::completions",
327                "MoonShot streaming completion request: {}",
328                serde_json::to_string_pretty(&request)?
329            );
330        }
331
332        let body = serde_json::to_vec(&request)?;
333        let req = self
334            .client
335            .post("/chat/completions")?
336            .body(body)
337            .map_err(http_client::Error::from)?;
338
339        send_compatible_streaming_request(self.client.clone(), req)
340            .instrument(span)
341            .await
342    }
343}
344
/// Tool-choice modes supported by this provider; defaults to `Auto`.
///
/// NOTE(review): serde's derived impls will (de)serialize these variants as
/// "None"/"Auto" (capitalized), while OpenAI-compatible APIs typically expect
/// lowercase "none"/"auto". This type does not appear to be referenced by
/// `MoonshotCompletionRequest` (which uses `openai::completion::ToolChoice`),
/// so the intended wire format — and whether this type is dead code — should
/// be confirmed before adding `#[serde(rename_all = "lowercase")]`.
#[derive(Default, Debug, Deserialize, Serialize)]
pub enum ToolChoice {
    None,
    #[default]
    Auto,
}

352impl TryFrom<message::ToolChoice> for ToolChoice {
353    type Error = CompletionError;
354
355    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
356        let res = match value {
357            message::ToolChoice::None => Self::None,
358            message::ToolChoice::Auto => Self::Auto,
359            choice => {
360                return Err(CompletionError::ProviderError(format!(
361                    "Unsupported tool choice type: {choice:?}"
362                )));
363            }
364        };
365
366        Ok(res)
367    }
368}