1use crate::client::{
12 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
13 ProviderClient,
14};
15use crate::http_client::HttpClientExt;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 completion::{self, CompletionError, CompletionRequest},
20 json_utils,
21 providers::openai,
22};
23use crate::{http_client, message};
24use serde::{Deserialize, Serialize};
25use tracing::{Instrument, info_span};
26
/// Default base URL for the Moonshot API (OpenAI-compatible endpoints).
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
31
/// Zero-sized marker type implementing [`Provider`] for Moonshot.
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotExt;
/// Zero-sized marker type implementing [`ProviderBuilder`] for Moonshot.
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotBuilder;

// Moonshot authenticates with a bearer token (`Authorization` header).
type MoonshotApiKey = BearerAuth;
38
impl Provider for MoonshotExt {
    type Builder = MoonshotBuilder;

    // Path appended to the base URL when verifying the client's
    // connectivity/credentials (see `Provider::VERIFY_PATH`).
    const VERIFY_PATH: &'static str = "/models";

    /// Builds the provider marker. Moonshot needs no configuration from the
    /// client builder, so the argument is ignored and this never fails.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

// Use the default `Debug`-based formatting supplied by `DebugExt`.
impl DebugExt for MoonshotExt {}
56
impl ProviderBuilder for MoonshotBuilder {
    type Output = MoonshotExt;
    type ApiKey = MoonshotApiKey;

    // All requests are made relative to the public Moonshot endpoint.
    const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
}
63
/// Only chat completions are wired up for Moonshot; embeddings,
/// transcription, image and audio generation are unsupported (`Nothing`).
impl<H> Capabilities<H> for MoonshotExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
73
/// Moonshot client aliases over the generic client machinery, defaulting the
/// HTTP transport to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;
77
impl ProviderClient for Client {
    // The raw API key string.
    type Input = String;

    /// Builds a client from the `MOONSHOT_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is unset or client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Builds a client from an API key value.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}
92
/// Error envelope returned by the Moonshot API:
/// `{"error": {"message": "..."}}`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}

#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}

/// Untagged wrapper over a response body: serde attempts the expected
/// payload `T` first, then falls back to the error envelope.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
109
/// The `moonshot-v1-128k` chat model identifier.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
115
/// Wire-format body for Moonshot's OpenAI-compatible `/chat/completions`
/// endpoint. Optional fields are omitted from the JSON when unset.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MoonshotCompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened into the top-level JSON object (e.g. `stream`, penalties).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
131
132impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
133 type Error = CompletionError;
134
135 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
136 let mut partial_history = vec![];
138 if let Some(docs) = req.normalized_documents() {
139 partial_history.push(docs);
140 }
141 partial_history.extend(req.chat_history);
142
143 let mut full_history: Vec<openai::Message> = match &req.preamble {
145 Some(preamble) => vec![openai::Message::system(preamble)],
146 None => vec![],
147 };
148
149 full_history.extend(
151 partial_history
152 .into_iter()
153 .map(message::Message::try_into)
154 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
155 .into_iter()
156 .flatten()
157 .collect::<Vec<_>>(),
158 );
159
160 let tool_choice = req
161 .tool_choice
162 .clone()
163 .map(crate::providers::openai::ToolChoice::try_from)
164 .transpose()?;
165
166 Ok(Self {
167 model: model.to_string(),
168 messages: full_history,
169 temperature: req.temperature,
170 max_tokens: req.max_tokens,
171 tools: req
172 .tools
173 .clone()
174 .into_iter()
175 .map(openai::ToolDefinition::from)
176 .collect::<Vec<_>>(),
177 tool_choice,
178 additional_params: req.additional_params,
179 })
180 }
181}
182
/// A handle binding a Moonshot [`Client`] to a specific model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
188
189impl<T> CompletionModel<T> {
190 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
191 Self {
192 client,
193 model: model.into(),
194 }
195 }
196}
197
198impl<T> completion::CompletionModel for CompletionModel<T>
199where
200 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
201{
202 type Response = openai::CompletionResponse;
203 type StreamingResponse = openai::StreamingCompletionResponse;
204
205 type Client = Client<T>;
206
207 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
208 Self::new(client.clone(), model)
209 }
210
211 #[cfg_attr(feature = "worker", worker::send)]
212 async fn completion(
213 &self,
214 completion_request: CompletionRequest,
215 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
216 let span = if tracing::Span::current().is_disabled() {
217 info_span!(
218 target: "rig::completions",
219 "chat",
220 gen_ai.operation.name = "chat",
221 gen_ai.provider.name = "moonshot",
222 gen_ai.request.model = self.model,
223 gen_ai.system_instructions = tracing::field::Empty,
224 gen_ai.response.id = tracing::field::Empty,
225 gen_ai.response.model = tracing::field::Empty,
226 gen_ai.usage.output_tokens = tracing::field::Empty,
227 gen_ai.usage.input_tokens = tracing::field::Empty,
228 )
229 } else {
230 tracing::Span::current()
231 };
232
233 span.record("gen_ai.system_instructions", &completion_request.preamble);
234
235 let request =
236 MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
237
238 if tracing::enabled!(tracing::Level::TRACE) {
239 tracing::trace!(target: "rig::completions",
240 "MoonShot completion request: {}",
241 serde_json::to_string_pretty(&request)?
242 );
243 }
244
245 let body = serde_json::to_vec(&request)?;
246 let req = self
247 .client
248 .post("/chat/completions")?
249 .body(body)
250 .map_err(http_client::Error::from)?;
251
252 let async_block = async move {
253 let response = self.client.send::<_, bytes::Bytes>(req).await?;
254
255 let status = response.status();
256 let response_body = response.into_body().into_future().await?.to_vec();
257
258 if status.is_success() {
259 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
260 &response_body,
261 )? {
262 ApiResponse::Ok(response) => {
263 let span = tracing::Span::current();
264 span.record("gen_ai.response.id", response.id.clone());
265 span.record("gen_ai.response.model_name", response.model.clone());
266 if let Some(ref usage) = response.usage {
267 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
268 span.record(
269 "gen_ai.usage.output_tokens",
270 usage.total_tokens - usage.prompt_tokens,
271 );
272 }
273 if tracing::enabled!(tracing::Level::TRACE) {
274 tracing::trace!(target: "rig::completions",
275 "MoonShot completion response: {}",
276 serde_json::to_string_pretty(&response)?
277 );
278 }
279 response.try_into()
280 }
281 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
282 }
283 } else {
284 Err(CompletionError::ProviderError(
285 String::from_utf8_lossy(&response_body).to_string(),
286 ))
287 }
288 };
289
290 async_block.instrument(span).await
291 }
292
293 #[cfg_attr(feature = "worker", worker::send)]
294 async fn stream(
295 &self,
296 request: CompletionRequest,
297 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
298 let span = if tracing::Span::current().is_disabled() {
299 info_span!(
300 target: "rig::completions",
301 "chat_streaming",
302 gen_ai.operation.name = "chat_streaming",
303 gen_ai.provider.name = "moonshot",
304 gen_ai.request.model = self.model,
305 gen_ai.system_instructions = tracing::field::Empty,
306 gen_ai.response.id = tracing::field::Empty,
307 gen_ai.response.model = tracing::field::Empty,
308 gen_ai.usage.output_tokens = tracing::field::Empty,
309 gen_ai.usage.input_tokens = tracing::field::Empty,
310 )
311 } else {
312 tracing::Span::current()
313 };
314
315 span.record("gen_ai.system_instructions", &request.preamble);
316 let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;
317
318 let params = json_utils::merge(
319 request.additional_params.unwrap_or(serde_json::json!({})),
320 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
321 );
322
323 request.additional_params = Some(params);
324
325 if tracing::enabled!(tracing::Level::TRACE) {
326 tracing::trace!(target: "rig::completions",
327 "MoonShot streaming completion request: {}",
328 serde_json::to_string_pretty(&request)?
329 );
330 }
331
332 let body = serde_json::to_vec(&request)?;
333 let req = self
334 .client
335 .post("/chat/completions")?
336 .body(body)
337 .map_err(http_client::Error::from)?;
338
339 send_compatible_streaming_request(self.client.clone(), req)
340 .instrument(span)
341 .await
342 }
343}
344
/// Tool-choice options supported for Moonshot requests: only `none` and
/// `auto`; defaults to `Auto`.
#[derive(Default, Debug, Deserialize, Serialize)]
pub enum ToolChoice {
    None,
    #[default]
    Auto,
}
351
352impl TryFrom<message::ToolChoice> for ToolChoice {
353 type Error = CompletionError;
354
355 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
356 let res = match value {
357 message::ToolChoice::None => Self::None,
358 message::ToolChoice::Auto => Self::Auto,
359 choice => {
360 return Err(CompletionError::ProviderError(format!(
361 "Unsupported tool choice type: {choice:?}"
362 )));
363 }
364 };
365
366 Ok(res)
367 }
368}