Skip to main content

rig/providers/
moonshot.rs

1//! Moonshot AI (Kimi) API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig::providers::moonshot;
6//! use rig::client::CompletionClient;
7//!
8//! let client = moonshot::Client::new("YOUR_API_KEY").expect("Failed to build client");
9//!
10//! let kimi_model = client.completion_model(moonshot::KIMI_K2_5);
11//! ```
12//!
13//! # Custom base URL
14//! The default base URL is `https://api.moonshot.cn/v1`. For global access,
15//! use `https://api.moonshot.ai/v1`:
16//! ```no_run
17//! use rig::providers::moonshot;
18//!
19//! let client = moonshot::Client::builder()
20//!     .api_key("YOUR_API_KEY")
21//!     .base_url("https://api.moonshot.ai/v1")
22//!     .build()
23//!     .expect("Failed to build Moonshot client");
24//! ```
25use crate::client::{
26    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
27    ProviderClient,
28};
29use crate::http_client::HttpClientExt;
30use crate::providers::openai::send_compatible_streaming_request;
31use crate::streaming::StreamingCompletionResponse;
32use crate::{
33    completion::{self, CompletionError, CompletionRequest},
34    json_utils,
35    providers::openai,
36};
37use crate::{http_client, message};
38use serde::{Deserialize, Serialize};
39use tracing::{Instrument, info_span};
40
// ================================================================
// Main Moonshot Client
// ================================================================

/// Default API root used when no custom base URL is configured.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";

/// Zero-sized provider extension marker for Moonshot; it carries no state.
#[derive(Clone, Copy, Debug, Default)]
pub struct MoonshotExt;

/// Zero-sized builder marker used to construct [`MoonshotExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MoonshotBuilder;
50
51type MoonshotApiKey = BearerAuth;
52
53impl Provider for MoonshotExt {
54    type Builder = MoonshotBuilder;
55
56    const VERIFY_PATH: &'static str = "/models";
57}
58
59impl DebugExt for MoonshotExt {}
60
61impl ProviderBuilder for MoonshotBuilder {
62    type Extension<H>
63        = MoonshotExt
64    where
65        H: HttpClientExt;
66    type ApiKey = MoonshotApiKey;
67
68    const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
69
70    fn build<H>(
71        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
72    ) -> http_client::Result<Self::Extension<H>>
73    where
74        H: HttpClientExt,
75    {
76        Ok(MoonshotExt)
77    }
78}
79
80impl<H> Capabilities<H> for MoonshotExt {
81    type Completion = Capable<CompletionModel<H>>;
82    type Embeddings = Nothing;
83    type Transcription = Nothing;
84    type ModelListing = Nothing;
85    #[cfg(feature = "image")]
86    type ImageGeneration = Nothing;
87    #[cfg(feature = "audio")]
88    type AudioGeneration = Nothing;
89}
90
91pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
92pub type ClientBuilder<H = reqwest::Client> =
93    client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;
94
95impl ProviderClient for Client {
96    type Input = String;
97
98    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
99    /// Panics if the environment variable is not set.
100    fn from_env() -> Self {
101        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
102        Self::new(&api_key).unwrap()
103    }
104
105    fn from_val(input: Self::Input) -> Self {
106        Self::new(&input).unwrap()
107    }
108}
109
110#[derive(Debug, Deserialize)]
111struct ApiErrorResponse {
112    error: MoonshotError,
113}
114
115#[derive(Debug, Deserialize)]
116struct MoonshotError {
117    message: String,
118}
119
120#[derive(Debug, Deserialize)]
121#[serde(untagged)]
122enum ApiResponse<T> {
123    Ok(T),
124    Err(ApiErrorResponse),
125}
126
// ================================================================
// Moonshot Completion API
// ================================================================

/// Kimi K2.5 — Native multimodal agentic model with 256K context
pub const KIMI_K2_5: &str = "kimi-k2.5";

/// Kimi K2 — Mixture-of-Experts model (1T total params, 32B active)
pub const KIMI_K2: &str = "kimi-k2";

/// Moonshot v1 128K context model (legacy)
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
139
140#[derive(Debug, Serialize, Deserialize)]
141pub(super) struct MoonshotCompletionRequest {
142    model: String,
143    pub messages: Vec<openai::Message>,
144    #[serde(skip_serializing_if = "Option::is_none")]
145    temperature: Option<f64>,
146    #[serde(skip_serializing_if = "Vec::is_empty")]
147    tools: Vec<openai::ToolDefinition>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    max_tokens: Option<u64>,
150    #[serde(skip_serializing_if = "Option::is_none")]
151    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
152    #[serde(flatten, skip_serializing_if = "Option::is_none")]
153    pub additional_params: Option<serde_json::Value>,
154}
155
156impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
157    type Error = CompletionError;
158
159    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
160        if req.output_schema.is_some() {
161            tracing::warn!("Structured outputs currently not supported for Moonshot");
162        }
163        let model = req.model.clone().unwrap_or_else(|| model.to_string());
164        // Build up the order of messages (context, chat_history, prompt)
165        let mut partial_history = vec![];
166        if let Some(docs) = req.normalized_documents() {
167            partial_history.push(docs);
168        }
169        partial_history.extend(req.chat_history);
170
171        // Add preamble to chat history (if available)
172        let mut full_history: Vec<openai::Message> = match &req.preamble {
173            Some(preamble) => vec![openai::Message::system(preamble)],
174            None => vec![],
175        };
176
177        // Convert and extend the rest of the history
178        full_history.extend(
179            partial_history
180                .into_iter()
181                .map(message::Message::try_into)
182                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
183                .into_iter()
184                .flatten()
185                .collect::<Vec<_>>(),
186        );
187
188        let tool_choice = req
189            .tool_choice
190            .clone()
191            .map(crate::providers::openai::ToolChoice::try_from)
192            .transpose()?;
193
194        Ok(Self {
195            model: model.to_string(),
196            messages: full_history,
197            temperature: req.temperature,
198            max_tokens: req.max_tokens,
199            tools: req
200                .tools
201                .clone()
202                .into_iter()
203                .map(openai::ToolDefinition::from)
204                .collect::<Vec<_>>(),
205            tool_choice,
206            additional_params: req.additional_params,
207        })
208    }
209}
210
211#[derive(Clone)]
212pub struct CompletionModel<T = reqwest::Client> {
213    client: Client<T>,
214    pub model: String,
215}
216
217impl<T> CompletionModel<T> {
218    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
219        Self {
220            client,
221            model: model.into(),
222        }
223    }
224}
225
226impl<T> completion::CompletionModel for CompletionModel<T>
227where
228    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
229{
230    type Response = openai::CompletionResponse;
231    type StreamingResponse = openai::StreamingCompletionResponse;
232
233    type Client = Client<T>;
234
235    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
236        Self::new(client.clone(), model)
237    }
238
239    async fn completion(
240        &self,
241        completion_request: CompletionRequest,
242    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
243        let span = if tracing::Span::current().is_disabled() {
244            info_span!(
245                target: "rig::completions",
246                "chat",
247                gen_ai.operation.name = "chat",
248                gen_ai.provider.name = "moonshot",
249                gen_ai.request.model = self.model,
250                gen_ai.system_instructions = tracing::field::Empty,
251                gen_ai.response.id = tracing::field::Empty,
252                gen_ai.response.model = tracing::field::Empty,
253                gen_ai.usage.output_tokens = tracing::field::Empty,
254                gen_ai.usage.input_tokens = tracing::field::Empty,
255            )
256        } else {
257            tracing::Span::current()
258        };
259
260        span.record("gen_ai.system_instructions", &completion_request.preamble);
261
262        let request =
263            MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
264
265        if tracing::enabled!(tracing::Level::TRACE) {
266            tracing::trace!(target: "rig::completions",
267                "MoonShot completion request: {}",
268                serde_json::to_string_pretty(&request)?
269            );
270        }
271
272        let body = serde_json::to_vec(&request)?;
273        let req = self
274            .client
275            .post("/chat/completions")?
276            .body(body)
277            .map_err(http_client::Error::from)?;
278
279        let async_block = async move {
280            let response = self.client.send::<_, bytes::Bytes>(req).await?;
281
282            let status = response.status();
283            let response_body = response.into_body().into_future().await?.to_vec();
284
285            if status.is_success() {
286                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
287                    &response_body,
288                )? {
289                    ApiResponse::Ok(response) => {
290                        let span = tracing::Span::current();
291                        span.record("gen_ai.response.id", response.id.clone());
292                        span.record("gen_ai.response.model_name", response.model.clone());
293                        if let Some(ref usage) = response.usage {
294                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
295                            span.record(
296                                "gen_ai.usage.output_tokens",
297                                usage.total_tokens - usage.prompt_tokens,
298                            );
299                        }
300                        if tracing::enabled!(tracing::Level::TRACE) {
301                            tracing::trace!(target: "rig::completions",
302                                "MoonShot completion response: {}",
303                                serde_json::to_string_pretty(&response)?
304                            );
305                        }
306                        response.try_into()
307                    }
308                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
309                }
310            } else {
311                Err(CompletionError::ProviderError(
312                    String::from_utf8_lossy(&response_body).to_string(),
313                ))
314            }
315        };
316
317        async_block.instrument(span).await
318    }
319
320    async fn stream(
321        &self,
322        request: CompletionRequest,
323    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
324        let span = if tracing::Span::current().is_disabled() {
325            info_span!(
326                target: "rig::completions",
327                "chat_streaming",
328                gen_ai.operation.name = "chat_streaming",
329                gen_ai.provider.name = "moonshot",
330                gen_ai.request.model = self.model,
331                gen_ai.system_instructions = tracing::field::Empty,
332                gen_ai.response.id = tracing::field::Empty,
333                gen_ai.response.model = tracing::field::Empty,
334                gen_ai.usage.output_tokens = tracing::field::Empty,
335                gen_ai.usage.input_tokens = tracing::field::Empty,
336            )
337        } else {
338            tracing::Span::current()
339        };
340
341        span.record("gen_ai.system_instructions", &request.preamble);
342        let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;
343
344        let params = json_utils::merge(
345            request.additional_params.unwrap_or(serde_json::json!({})),
346            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
347        );
348
349        request.additional_params = Some(params);
350
351        if tracing::enabled!(tracing::Level::TRACE) {
352            tracing::trace!(target: "rig::completions",
353                "MoonShot streaming completion request: {}",
354                serde_json::to_string_pretty(&request)?
355            );
356        }
357
358        let body = serde_json::to_vec(&request)?;
359        let req = self
360            .client
361            .post("/chat/completions")?
362            .body(body)
363            .map_err(http_client::Error::from)?;
364
365        send_compatible_streaming_request(self.client.clone(), req)
366            .instrument(span)
367            .await
368    }
369}
370
371#[derive(Default, Debug, Deserialize, Serialize)]
372pub enum ToolChoice {
373    None,
374    #[default]
375    Auto,
376}
377
378impl TryFrom<message::ToolChoice> for ToolChoice {
379    type Error = CompletionError;
380
381    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
382        let res = match value {
383            message::ToolChoice::None => Self::None,
384            message::ToolChoice::Auto => Self::Auto,
385            choice => {
386                return Err(CompletionError::ProviderError(format!(
387                    "Unsupported tool choice type: {choice:?}"
388                )));
389            }
390        };
391
392        Ok(res)
393    }
394}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths (direct and via the builder) succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}