use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::http_client::HttpClientExt;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    providers::openai,
};
use crate::{http_client, message};
use serde::{Deserialize, Serialize};
use tracing::{Instrument, info_span};

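/// Base URL for Moonshot's OpenAI-compatible HTTP API.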
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";

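/// Stateless marker type that wires Moonshot into the generic [`client::Client`]
/// machinery via the [`Provider`] and [`Capabilities`] impls below.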
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotExt;
#[derive(Debug, Default, Clone, Copy)]
pub struct MoonshotBuilder;

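/// Moonshot authenticates requests with a bearer token, so the API key is just
/// a [`BearerAuth`] value.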
type MoonshotApiKey = BearerAuth;

impl Provider for MoonshotExt {
    type Builder = MoonshotBuilder;

    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}

impl DebugExt for MoonshotExt {}

impl ProviderBuilder for MoonshotBuilder {
    type Output = MoonshotExt;
    type ApiKey = MoonshotApiKey;

    const BASE_URL: &'static str = MOONSHOT_API_BASE_URL;
}

impl<H> Capabilities<H> for MoonshotExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

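/// Moonshot client over the shared [`client::Client`] plumbing, defaulting to
/// `reqwest` for HTTP.
///
/// Illustrative sketch only: it assumes `MOONSHOT_API_KEY` is set and that this
/// module and the `ProviderClient` trait are exposed as
/// `rig::providers::moonshot` and `rig::client::ProviderClient`.
///
/// ```ignore
/// use rig::client::ProviderClient;
/// use rig::providers::moonshot;
///
/// let client = moonshot::Client::from_env();
/// let model = moonshot::CompletionModel::new(client, moonshot::MOONSHOT_CHAT);
/// ```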
pub type Client<H = reqwest::Client> = client::Client<MoonshotExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<MoonshotBuilder, MoonshotApiKey, H>;

impl ProviderClient for Client {
    type Input = String;

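    // Reads the API key from the `MOONSHOT_API_KEY` environment variable;
    // panics if the variable is unset or the client cannot be constructed.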
    fn from_env() -> Self {
        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}

#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

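/// The `moonshot-v1-128k` chat completion model.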
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";

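/// Request body sent to Moonshot's `/chat/completions` endpoint. Messages and
/// tool definitions reuse the OpenAI-compatible types, and any
/// `additional_params` are flattened into the top-level JSON object.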
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MoonshotCompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

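// Assembles the outgoing message list: an optional system preamble first, then
// any normalized documents, then the chat history converted into
// OpenAI-compatible messages.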
impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(req.chat_history);

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

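/// A Moonshot completion model bound to a [`Client`] and a model name such as
/// [`MOONSHOT_CHAT`].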
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

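// Non-streaming and streaming completions share the same request shape: both
// serialize a `MoonshotCompletionRequest` and POST it to `/chat/completions`,
// reusing the OpenAI response types for deserialization.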
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request =
            MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        tracing::trace!(
            "Moonshot API input: {request}",
            request = serde_json::to_string_pretty(&request).unwrap()
        );

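        // Open a dedicated GenAI "chat" span only when the current span is
        // disabled; otherwise record onto the caller's existing span.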
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "moonshot",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

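        // Serialize the request and POST it to the OpenAI-compatible
        // `/chat/completions` endpoint.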
        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices).unwrap(),
                        );
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }
                        tracing::trace!(
                            target: "rig::completions",
                            "MoonShot completion response: {}",
                            serde_json::to_string_pretty(&response)?
                        );
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = request.preamble.clone();
        let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "moonshot",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

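        // Force `stream: true` and request usage reporting via
        // `stream_options.include_usage`, merging with any caller-supplied
        // additional params.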
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.http_client().clone(), req)
            .instrument(span)
            .await
    }
}

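/// Tool-choice modes supported by this provider. Only `None` and `Auto` map
/// over from [`message::ToolChoice`]; richer variants are rejected by the
/// [`TryFrom`] impl below.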
#[derive(Default, Debug, Deserialize, Serialize)]
pub enum ToolChoice {
    None,
    #[default]
    Auto,
}

impl TryFrom<message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Auto => Self::Auto,
            choice => {
                return Err(CompletionError::ProviderError(format!(
                    "Unsupported tool choice type: {choice:?}"
                )));
            }
        };

        Ok(res)
    }
}