1use crate::client::{
26 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
27 ProviderClient,
28};
29use crate::http_client::HttpClientExt;
30use crate::providers::openai::send_compatible_streaming_request;
31use crate::streaming::StreamingCompletionResponse;
32use crate::{
33 completion::{self, CompletionError, CompletionRequest},
34 json_utils,
35 providers::openai,
36};
37use crate::{http_client, message};
38use serde::{Deserialize, Serialize};
39use tracing::{Instrument, info_span};
40
/// Default base URL that Moonshot requests are issued against.
///
/// Endpoint paths used in this module (for example `/chat/completions`) are
/// relative to this value.
const MOONSHOT_API_BASE_URL: &'static str = "https://api.moonshot.cn/v1";
45
/// Stateless marker type identifying the Moonshot provider.
///
/// It carries no data; all provider-specific behaviour is attached through the
/// trait implementations on this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct MoonshotExt;
/// Stateless builder marker used to construct a Moonshot client.
///
/// Like [`MoonshotExt`], this is a unit struct with no configuration of its own.
#[derive(Clone, Copy, Debug, Default)]
pub struct MoonshotBuilder;
50
51type MoonshotApiKey = BearerAuth;
52
53impl Provider for MoonshotExt {
54 type Builder = MoonshotBuilder;
55
56 const VERIFY_PATH: &'static str = "/models";
57}
58
59impl DebugExt for MoonshotExt {}
60
61impl ProviderBuilder for MoonshotBuilder {
62 type Extension<H>
63 = MoonshotExt
64 where
65 H: HttpClientExt;
66 type ApiKey = MoonshotApiKey;
67
68 const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
69
70 fn build<H>(
71 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
72 ) -> http_client::Result<Self::Extension<H>>
73 where
74 H: HttpClientExt,
75 {
76 Ok(MoonshotExt)
77 }
78}
79
80impl<H> Capabilities<H> for MoonshotExt {
81 type Completion = Capable<CompletionModel<H>>;
82 type Embeddings = Nothing;
83 type Transcription = Nothing;
84 type ModelListing = Nothing;
85 #[cfg(feature = "image")]
86 type ImageGeneration = Nothing;
87 #[cfg(feature = "audio")]
88 type AudioGeneration = Nothing;
89}
90
91pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
92pub type ClientBuilder<H = reqwest::Client> =
93 client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;
94
95impl ProviderClient for Client {
96 type Input = String;
97
98 fn from_env() -> Self {
101 let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
102 Self::new(&api_key).unwrap()
103 }
104
105 fn from_val(input: Self::Input) -> Self {
106 Self::new(&input).unwrap()
107 }
108}
109
110#[derive(Debug, Deserialize)]
111struct ApiErrorResponse {
112 error: MoonshotError,
113}
114
115#[derive(Debug, Deserialize)]
116struct MoonshotError {
117 message: String,
118}
119
120#[derive(Debug, Deserialize)]
121#[serde(untagged)]
122enum ApiResponse<T> {
123 Ok(T),
124 Err(ApiErrorResponse),
125}
126
/// Model identifier `moonshot-v1-128k`.
pub const MOONSHOT_CHAT: &'static str = "moonshot-v1-128k";

/// Model identifier `kimi-k2`.
pub const KIMI_K2: &'static str = "kimi-k2";

/// Model identifier `kimi-k2.5`.
pub const KIMI_K2_5: &'static str = "kimi-k2.5";
139
140#[derive(Debug, Serialize, Deserialize)]
141pub(super) struct MoonshotCompletionRequest {
142 model: String,
143 pub messages: Vec<openai::Message>,
144 #[serde(skip_serializing_if = "Option::is_none")]
145 temperature: Option<f64>,
146 #[serde(skip_serializing_if = "Vec::is_empty")]
147 tools: Vec<openai::ToolDefinition>,
148 #[serde(skip_serializing_if = "Option::is_none")]
149 max_tokens: Option<u64>,
150 #[serde(skip_serializing_if = "Option::is_none")]
151 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
152 #[serde(flatten, skip_serializing_if = "Option::is_none")]
153 pub additional_params: Option<serde_json::Value>,
154}
155
156impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
157 type Error = CompletionError;
158
159 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
160 if req.output_schema.is_some() {
161 tracing::warn!("Structured outputs currently not supported for Moonshot");
162 }
163 let model = req.model.clone().unwrap_or_else(|| model.to_string());
164 let mut partial_history = vec![];
166 if let Some(docs) = req.normalized_documents() {
167 partial_history.push(docs);
168 }
169 partial_history.extend(req.chat_history);
170
171 let mut full_history: Vec<openai::Message> = match &req.preamble {
173 Some(preamble) => vec![openai::Message::system(preamble)],
174 None => vec![],
175 };
176
177 full_history.extend(
179 partial_history
180 .into_iter()
181 .map(message::Message::try_into)
182 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
183 .into_iter()
184 .flatten()
185 .collect::<Vec<_>>(),
186 );
187
188 let tool_choice = req
189 .tool_choice
190 .clone()
191 .map(crate::providers::openai::ToolChoice::try_from)
192 .transpose()?;
193
194 Ok(Self {
195 model: model.to_string(),
196 messages: full_history,
197 temperature: req.temperature,
198 max_tokens: req.max_tokens,
199 tools: req
200 .tools
201 .clone()
202 .into_iter()
203 .map(openai::ToolDefinition::from)
204 .collect::<Vec<_>>(),
205 tool_choice,
206 additional_params: req.additional_params,
207 })
208 }
209}
210
211#[derive(Clone)]
212pub struct CompletionModel<T = reqwest::Client> {
213 client: Client<T>,
214 pub model: String,
215}
216
217impl<T> CompletionModel<T> {
218 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
219 Self {
220 client,
221 model: model.into(),
222 }
223 }
224}
225
226impl<T> completion::CompletionModel for CompletionModel<T>
227where
228 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
229{
230 type Response = openai::CompletionResponse;
231 type StreamingResponse = openai::StreamingCompletionResponse;
232
233 type Client = Client<T>;
234
235 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
236 Self::new(client.clone(), model)
237 }
238
239 async fn completion(
240 &self,
241 completion_request: CompletionRequest,
242 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
243 let span = if tracing::Span::current().is_disabled() {
244 info_span!(
245 target: "rig::completions",
246 "chat",
247 gen_ai.operation.name = "chat",
248 gen_ai.provider.name = "moonshot",
249 gen_ai.request.model = self.model,
250 gen_ai.system_instructions = tracing::field::Empty,
251 gen_ai.response.id = tracing::field::Empty,
252 gen_ai.response.model = tracing::field::Empty,
253 gen_ai.usage.output_tokens = tracing::field::Empty,
254 gen_ai.usage.input_tokens = tracing::field::Empty,
255 )
256 } else {
257 tracing::Span::current()
258 };
259
260 span.record("gen_ai.system_instructions", &completion_request.preamble);
261
262 let request =
263 MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
264
265 if tracing::enabled!(tracing::Level::TRACE) {
266 tracing::trace!(target: "rig::completions",
267 "MoonShot completion request: {}",
268 serde_json::to_string_pretty(&request)?
269 );
270 }
271
272 let body = serde_json::to_vec(&request)?;
273 let req = self
274 .client
275 .post("/chat/completions")?
276 .body(body)
277 .map_err(http_client::Error::from)?;
278
279 let async_block = async move {
280 let response = self.client.send::<_, bytes::Bytes>(req).await?;
281
282 let status = response.status();
283 let response_body = response.into_body().into_future().await?.to_vec();
284
285 if status.is_success() {
286 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
287 &response_body,
288 )? {
289 ApiResponse::Ok(response) => {
290 let span = tracing::Span::current();
291 span.record("gen_ai.response.id", response.id.clone());
292 span.record("gen_ai.response.model_name", response.model.clone());
293 if let Some(ref usage) = response.usage {
294 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
295 span.record(
296 "gen_ai.usage.output_tokens",
297 usage.total_tokens - usage.prompt_tokens,
298 );
299 }
300 if tracing::enabled!(tracing::Level::TRACE) {
301 tracing::trace!(target: "rig::completions",
302 "MoonShot completion response: {}",
303 serde_json::to_string_pretty(&response)?
304 );
305 }
306 response.try_into()
307 }
308 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
309 }
310 } else {
311 Err(CompletionError::ProviderError(
312 String::from_utf8_lossy(&response_body).to_string(),
313 ))
314 }
315 };
316
317 async_block.instrument(span).await
318 }
319
320 async fn stream(
321 &self,
322 request: CompletionRequest,
323 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
324 let span = if tracing::Span::current().is_disabled() {
325 info_span!(
326 target: "rig::completions",
327 "chat_streaming",
328 gen_ai.operation.name = "chat_streaming",
329 gen_ai.provider.name = "moonshot",
330 gen_ai.request.model = self.model,
331 gen_ai.system_instructions = tracing::field::Empty,
332 gen_ai.response.id = tracing::field::Empty,
333 gen_ai.response.model = tracing::field::Empty,
334 gen_ai.usage.output_tokens = tracing::field::Empty,
335 gen_ai.usage.input_tokens = tracing::field::Empty,
336 )
337 } else {
338 tracing::Span::current()
339 };
340
341 span.record("gen_ai.system_instructions", &request.preamble);
342 let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;
343
344 let params = json_utils::merge(
345 request.additional_params.unwrap_or(serde_json::json!({})),
346 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
347 );
348
349 request.additional_params = Some(params);
350
351 if tracing::enabled!(tracing::Level::TRACE) {
352 tracing::trace!(target: "rig::completions",
353 "MoonShot streaming completion request: {}",
354 serde_json::to_string_pretty(&request)?
355 );
356 }
357
358 let body = serde_json::to_vec(&request)?;
359 let req = self
360 .client
361 .post("/chat/completions")?
362 .body(body)
363 .map_err(http_client::Error::from)?;
364
365 send_compatible_streaming_request(self.client.clone(), req)
366 .instrument(span)
367 .await
368 }
369}
370
371#[derive(Default, Debug, Deserialize, Serialize)]
372pub enum ToolChoice {
373 None,
374 #[default]
375 Auto,
376}
377
378impl TryFrom<message::ToolChoice> for ToolChoice {
379 type Error = CompletionError;
380
381 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
382 let res = match value {
383 message::ToolChoice::None => Self::None,
384 message::ToolChoice::Auto => Self::Auto,
385 choice => {
386 return Err(CompletionError::ProviderError(format!(
387 "Unsupported tool choice type: {choice:?}"
388 )));
389 }
390 };
391
392 Ok(res)
393 }
394}
#[cfg(test)]
mod tests {
    use crate::providers::moonshot::Client;

    /// Both construction paths (`Client::new` and the builder) must succeed
    /// with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}