1use crate::client::{
12 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
13 ProviderClient,
14};
15use crate::http_client::HttpClientExt;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 completion::{self, CompletionError, CompletionRequest},
20 json_utils,
21 providers::openai,
22};
23use crate::{http_client, message};
24use serde::{Deserialize, Serialize};
25use tracing::{Instrument, info_span};
26
/// Base URL of the Moonshot HTTP API; endpoint paths such as `/chat/completions`
/// are resolved relative to it.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
31
/// Zero-sized marker that plugs the Moonshot provider into the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct MoonshotExt;

/// Zero-sized marker selecting Moonshot's builder configuration (base URL, API key type).
#[derive(Clone, Copy, Debug, Default)]
pub struct MoonshotBuilder;
36
37type MoonshotApiKey = BearerAuth;
38
39impl Provider for MoonshotExt {
40 type Builder = MoonshotBuilder;
41
42 const VERIFY_PATH: &'static str = "/models";
43
44 fn build<H>(
45 _: &crate::client::ClientBuilder<
46 Self::Builder,
47 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
48 H,
49 >,
50 ) -> http_client::Result<Self> {
51 Ok(Self)
52 }
53}
54
55impl DebugExt for MoonshotExt {}
56
57impl ProviderBuilder for MoonshotBuilder {
58 type Output = MoonshotExt;
59 type ApiKey = MoonshotApiKey;
60
61 const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
62}
63
64impl<H> Capabilities<H> for MoonshotExt {
65 type Completion = Capable<CompletionModel<H>>;
66 type Embeddings = Nothing;
67 type Transcription = Nothing;
68 type ModelListing = Nothing;
69 #[cfg(feature = "image")]
70 type ImageGeneration = Nothing;
71 #[cfg(feature = "audio")]
72 type AudioGeneration = Nothing;
73}
74
75pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
76pub type ClientBuilder<H = reqwest::Client> =
77 client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;
78
79impl ProviderClient for Client {
80 type Input = String;
81
82 fn from_env() -> Self {
85 let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
86 Self::new(&api_key).unwrap()
87 }
88
89 fn from_val(input: Self::Input) -> Self {
90 Self::new(&input).unwrap()
91 }
92}
93
94#[derive(Debug, Deserialize)]
95struct ApiErrorResponse {
96 error: MoonshotError,
97}
98
99#[derive(Debug, Deserialize)]
100struct MoonshotError {
101 message: String,
102}
103
104#[derive(Debug, Deserialize)]
105#[serde(untagged)]
106enum ApiResponse<T> {
107 Ok(T),
108 Err(ApiErrorResponse),
109}
110
/// Moonshot chat model identifier `moonshot-v1-128k`.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
116
117#[derive(Debug, Serialize, Deserialize)]
118pub(super) struct MoonshotCompletionRequest {
119 model: String,
120 pub messages: Vec<openai::Message>,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 temperature: Option<f64>,
123 #[serde(skip_serializing_if = "Vec::is_empty")]
124 tools: Vec<openai::ToolDefinition>,
125 #[serde(skip_serializing_if = "Option::is_none")]
126 max_tokens: Option<u64>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
129 #[serde(flatten, skip_serializing_if = "Option::is_none")]
130 pub additional_params: Option<serde_json::Value>,
131}
132
133impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
134 type Error = CompletionError;
135
136 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
137 if req.output_schema.is_some() {
138 tracing::warn!("Structured outputs currently not supported for Moonshot");
139 }
140 let model = req.model.clone().unwrap_or_else(|| model.to_string());
141 let mut partial_history = vec![];
143 if let Some(docs) = req.normalized_documents() {
144 partial_history.push(docs);
145 }
146 partial_history.extend(req.chat_history);
147
148 let mut full_history: Vec<openai::Message> = match &req.preamble {
150 Some(preamble) => vec![openai::Message::system(preamble)],
151 None => vec![],
152 };
153
154 full_history.extend(
156 partial_history
157 .into_iter()
158 .map(message::Message::try_into)
159 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
160 .into_iter()
161 .flatten()
162 .collect::<Vec<_>>(),
163 );
164
165 let tool_choice = req
166 .tool_choice
167 .clone()
168 .map(crate::providers::openai::ToolChoice::try_from)
169 .transpose()?;
170
171 Ok(Self {
172 model: model.to_string(),
173 messages: full_history,
174 temperature: req.temperature,
175 max_tokens: req.max_tokens,
176 tools: req
177 .tools
178 .clone()
179 .into_iter()
180 .map(openai::ToolDefinition::from)
181 .collect::<Vec<_>>(),
182 tool_choice,
183 additional_params: req.additional_params,
184 })
185 }
186}
187
188#[derive(Clone)]
189pub struct CompletionModel<T = reqwest::Client> {
190 client: Client<T>,
191 pub model: String,
192}
193
194impl<T> CompletionModel<T> {
195 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
196 Self {
197 client,
198 model: model.into(),
199 }
200 }
201}
202
203impl<T> completion::CompletionModel for CompletionModel<T>
204where
205 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
206{
207 type Response = openai::CompletionResponse;
208 type StreamingResponse = openai::StreamingCompletionResponse;
209
210 type Client = Client<T>;
211
212 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
213 Self::new(client.clone(), model)
214 }
215
216 async fn completion(
217 &self,
218 completion_request: CompletionRequest,
219 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
220 let span = if tracing::Span::current().is_disabled() {
221 info_span!(
222 target: "rig::completions",
223 "chat",
224 gen_ai.operation.name = "chat",
225 gen_ai.provider.name = "moonshot",
226 gen_ai.request.model = self.model,
227 gen_ai.system_instructions = tracing::field::Empty,
228 gen_ai.response.id = tracing::field::Empty,
229 gen_ai.response.model = tracing::field::Empty,
230 gen_ai.usage.output_tokens = tracing::field::Empty,
231 gen_ai.usage.input_tokens = tracing::field::Empty,
232 )
233 } else {
234 tracing::Span::current()
235 };
236
237 span.record("gen_ai.system_instructions", &completion_request.preamble);
238
239 let request =
240 MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
241
242 if tracing::enabled!(tracing::Level::TRACE) {
243 tracing::trace!(target: "rig::completions",
244 "MoonShot completion request: {}",
245 serde_json::to_string_pretty(&request)?
246 );
247 }
248
249 let body = serde_json::to_vec(&request)?;
250 let req = self
251 .client
252 .post("/chat/completions")?
253 .body(body)
254 .map_err(http_client::Error::from)?;
255
256 let async_block = async move {
257 let response = self.client.send::<_, bytes::Bytes>(req).await?;
258
259 let status = response.status();
260 let response_body = response.into_body().into_future().await?.to_vec();
261
262 if status.is_success() {
263 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
264 &response_body,
265 )? {
266 ApiResponse::Ok(response) => {
267 let span = tracing::Span::current();
268 span.record("gen_ai.response.id", response.id.clone());
269 span.record("gen_ai.response.model_name", response.model.clone());
270 if let Some(ref usage) = response.usage {
271 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
272 span.record(
273 "gen_ai.usage.output_tokens",
274 usage.total_tokens - usage.prompt_tokens,
275 );
276 }
277 if tracing::enabled!(tracing::Level::TRACE) {
278 tracing::trace!(target: "rig::completions",
279 "MoonShot completion response: {}",
280 serde_json::to_string_pretty(&response)?
281 );
282 }
283 response.try_into()
284 }
285 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
286 }
287 } else {
288 Err(CompletionError::ProviderError(
289 String::from_utf8_lossy(&response_body).to_string(),
290 ))
291 }
292 };
293
294 async_block.instrument(span).await
295 }
296
297 async fn stream(
298 &self,
299 request: CompletionRequest,
300 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
301 let span = if tracing::Span::current().is_disabled() {
302 info_span!(
303 target: "rig::completions",
304 "chat_streaming",
305 gen_ai.operation.name = "chat_streaming",
306 gen_ai.provider.name = "moonshot",
307 gen_ai.request.model = self.model,
308 gen_ai.system_instructions = tracing::field::Empty,
309 gen_ai.response.id = tracing::field::Empty,
310 gen_ai.response.model = tracing::field::Empty,
311 gen_ai.usage.output_tokens = tracing::field::Empty,
312 gen_ai.usage.input_tokens = tracing::field::Empty,
313 )
314 } else {
315 tracing::Span::current()
316 };
317
318 span.record("gen_ai.system_instructions", &request.preamble);
319 let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;
320
321 let params = json_utils::merge(
322 request.additional_params.unwrap_or(serde_json::json!({})),
323 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
324 );
325
326 request.additional_params = Some(params);
327
328 if tracing::enabled!(tracing::Level::TRACE) {
329 tracing::trace!(target: "rig::completions",
330 "MoonShot streaming completion request: {}",
331 serde_json::to_string_pretty(&request)?
332 );
333 }
334
335 let body = serde_json::to_vec(&request)?;
336 let req = self
337 .client
338 .post("/chat/completions")?
339 .body(body)
340 .map_err(http_client::Error::from)?;
341
342 send_compatible_streaming_request(self.client.clone(), req)
343 .instrument(span)
344 .await
345 }
346}
347
348#[derive(Default, Debug, Deserialize, Serialize)]
349pub enum ToolChoice {
350 None,
351 #[default]
352 Auto,
353}
354
355impl TryFrom<message::ToolChoice> for ToolChoice {
356 type Error = CompletionError;
357
358 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
359 let res = match value {
360 message::ToolChoice::None => Self::None,
361 message::ToolChoice::Auto => Self::Auto,
362 choice => {
363 return Err(CompletionError::ProviderError(format!(
364 "Unsupported tool choice type: {choice:?}"
365 )));
366 }
367 };
368
369 Ok(res)
370 }
371}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths (direct `new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}
384}