1use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
12use crate::http_client::HttpClientExt;
13use crate::json_utils::merge;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    completion::{self, CompletionError, CompletionRequest},
18    json_utils,
19    providers::openai,
20};
21use crate::{http_client, impl_conversion_traits, message};
22use http::Method;
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25use tracing::{Instrument, info_span};
26
/// Base URL used for Moonshot API requests unless a different one is supplied
/// through `ClientBuilder::base_url`.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
31
/// Builder for a Moonshot [`Client`].
///
/// The API key and base URL are borrowed for `'a`; `build` copies them into
/// owned strings. `T` is the HTTP client type and defaults to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key the built client sends as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// HTTP client handed to the built [`Client`].
    http_client: T,
}
37
38impl<'a, T> ClientBuilder<'a, T>
39where
40    T: Default,
41{
42    pub fn new(api_key: &'a str) -> Self {
43        Self {
44            api_key,
45            base_url: MOONSHOT_API_BASE_URL,
46            http_client: Default::default(),
47        }
48    }
49}
50
51impl<'a, T> ClientBuilder<'a, T> {
52    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
53        Self {
54            api_key,
55            base_url: MOONSHOT_API_BASE_URL,
56            http_client,
57        }
58    }
59
60    pub fn base_url(mut self, base_url: &'a str) -> Self {
61        self.base_url = base_url;
62        self
63    }
64
65    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
66        ClientBuilder {
67            api_key: self.api_key,
68            base_url: self.base_url,
69            http_client,
70        }
71    }
72
73    pub fn build(self) -> Client<T> {
74        Client {
75            base_url: self.base_url.to_string(),
76            api_key: self.api_key.to_string(),
77            http_client: self.http_client,
78        }
79    }
80}
81
/// Moonshot API client.
///
/// Holds the base URL, the API key, and the HTTP client of type `T`
/// (defaulting to `reqwest::Client`) used to send requests.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as a bearer token on every request built by this client.
    api_key: String,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
88
89impl<T> std::fmt::Debug for Client<T>
90where
91    T: std::fmt::Debug,
92{
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("Client")
95            .field("base_url", &self.base_url)
96            .field("http_client", &self.http_client)
97            .field("api_key", &"<REDACTED>")
98            .finish()
99    }
100}
101
102impl<T> Client<T>
103where
104    T: HttpClientExt,
105{
106    fn req(
107        &self,
108        method: http_client::Method,
109        path: &str,
110    ) -> http_client::Result<http_client::Builder> {
111        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
112
113        http_client::with_bearer_auth(
114            http_client::Builder::new().method(method).uri(url),
115            &self.api_key,
116        )
117    }
118
119    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
120        self.req(http_client::Method::GET, path)
121    }
122}
123
124impl Client<reqwest::Client> {
125    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
126        ClientBuilder::new(api_key)
127    }
128
129    pub fn new(api_key: &str) -> Self {
130        Self::builder(api_key).build()
131    }
132
133    pub fn from_env() -> Self {
134        <Self as ProviderClient>::from_env()
135    }
136}
137
138impl<T> ProviderClient for Client<T>
139where
140    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
141{
142    fn from_env() -> Self {
145        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
146        ClientBuilder::<T>::new(&api_key).build()
147    }
148
149    fn from_val(input: crate::client::ProviderValue) -> Self {
150        let crate::client::ProviderValue::Simple(api_key) = input else {
151            panic!("Incorrect provider value type")
152        };
153        ClientBuilder::<T>::new(&api_key).build()
154    }
155}
156
157impl<T> CompletionClient for Client<T>
158where
159    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
160{
161    type CompletionModel = CompletionModel<T>;
162
163    fn completion_model(&self, model: &str) -> Self::CompletionModel {
175        CompletionModel::new(self.clone(), model)
176    }
177}
178
179impl<T> VerifyClient for Client<T>
180where
181    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
182{
183    #[cfg_attr(feature = "worker", worker::send)]
184    async fn verify(&self) -> Result<(), VerifyError> {
185        let req = self
186            .get("/models")?
187            .body(http_client::NoBody)
188            .map_err(http_client::Error::from)?;
189
190        let response = HttpClientExt::send(&self.http_client, req).await?;
191
192        match response.status() {
193            reqwest::StatusCode::OK => Ok(()),
194            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
195            reqwest::StatusCode::INTERNAL_SERVER_ERROR
196            | reqwest::StatusCode::SERVICE_UNAVAILABLE
197            | reqwest::StatusCode::BAD_GATEWAY => {
198                let text = http_client::text(response).await?;
199                Err(VerifyError::ProviderError(text))
200            }
201            _ => Ok(()),
202        }
203    }
204}
205
// Invokes the crate's `impl_conversion_traits!` macro for the listed capability
// traits on `Client<T>`; this file itself only implements completion and
// verification for the client.
// NOTE(review): presumably the generated impls mark these capabilities as
// unavailable for Moonshot — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
212
/// Error envelope deserialized from a response body: a top-level `error`
/// object wrapping the error details.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// The wrapped error details.
    error: MoonshotError,
}
217
/// Error details carried inside an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
struct MoonshotError {
    /// Error message; surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
222
/// Response body that is either a successful payload of type `T` or an error
/// envelope.
///
/// The enum is untagged, so deserialization tries `Ok` first and falls back
/// to `Err` when the body does not match `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// Body matched the expected success payload.
    Ok(T),
    /// Body matched the error envelope.
    Err(ApiErrorResponse),
}
229
/// Model identifier `moonshot-v1-128k`, usable as the `model` argument when
/// creating a completion model.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
234
/// Completion model bound to a Moonshot [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    client: Client<T>,
    /// Model identifier sent in the `model` field of every request.
    pub model: String,
}
240
241impl<T> CompletionModel<T> {
242    pub fn new(client: Client<T>, model: &str) -> Self {
243        Self {
244            client,
245            model: model.to_string(),
246        }
247    }
248
249    fn create_completion_request(
250        &self,
251        completion_request: CompletionRequest,
252    ) -> Result<Value, CompletionError> {
253        let mut partial_history = vec![];
255        if let Some(docs) = completion_request.normalized_documents() {
256            partial_history.push(docs);
257        }
258        partial_history.extend(completion_request.chat_history);
259
260        let mut full_history: Vec<openai::Message> = completion_request
262            .preamble
263            .map_or_else(Vec::new, |preamble| {
264                vec![openai::Message::system(&preamble)]
265            });
266
267        full_history.extend(
269            partial_history
270                .into_iter()
271                .map(message::Message::try_into)
272                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
273                .into_iter()
274                .flatten()
275                .collect::<Vec<_>>(),
276        );
277
278        let tool_choice = completion_request
279            .tool_choice
280            .map(ToolChoice::try_from)
281            .transpose()?;
282
283        let request = if completion_request.tools.is_empty() {
284            json!({
285                "model": self.model,
286                "messages": full_history,
287                "temperature": completion_request.temperature,
288                "max_tokens": completion_request.max_tokens,
289            })
290        } else {
291            json!({
292                "model": self.model,
293                "messages": full_history,
294                "temperature": completion_request.temperature,
295                "max_tokens": completion_request.max_tokens,
296                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
297                "tool_choice": tool_choice,
298            })
299        };
300
301        let request = if let Some(params) = completion_request.additional_params {
302            json_utils::merge(request, params)
303        } else {
304            request
305        };
306
307        Ok(request)
308    }
309}
310
311impl<T> completion::CompletionModel for CompletionModel<T>
312where
313    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
314{
315    type Response = openai::CompletionResponse;
316    type StreamingResponse = openai::StreamingCompletionResponse;
317
318    #[cfg_attr(feature = "worker", worker::send)]
319    async fn completion(
320        &self,
321        completion_request: CompletionRequest,
322    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
323        let preamble = completion_request.preamble.clone();
324        let request = self.create_completion_request(completion_request)?;
325
326        println!(
327            "Moonshot API input: {request}",
328            request = serde_json::to_string_pretty(&request).unwrap()
329        );
330
331        let span = if tracing::Span::current().is_disabled() {
332            info_span!(
333                target: "rig::completions",
334                "chat",
335                gen_ai.operation.name = "chat",
336                gen_ai.provider.name = "moonshot",
337                gen_ai.request.model = self.model,
338                gen_ai.system_instructions = preamble,
339                gen_ai.response.id = tracing::field::Empty,
340                gen_ai.response.model = tracing::field::Empty,
341                gen_ai.usage.output_tokens = tracing::field::Empty,
342                gen_ai.usage.input_tokens = tracing::field::Empty,
343                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
344                gen_ai.output.messages = tracing::field::Empty,
345            )
346        } else {
347            tracing::Span::current()
348        };
349
350        let body = serde_json::to_vec(&request)?;
351        let req = self
352            .client
353            .req(Method::POST, "/chat/completions")?
354            .header("Content-Type", "application/json")
355            .body(body)
356            .map_err(http_client::Error::from)?;
357
358        let async_block = async move {
359            let response = self.client.http_client.send::<_, bytes::Bytes>(req).await?;
360
361            let status = response.status();
362            let response_body = response.into_body().into_future().await?.to_vec();
363
364            if status.is_success() {
365                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
366                    &response_body,
367                )? {
368                    ApiResponse::Ok(response) => {
369                        tracing::debug!(target: "rig::completions", "MoonShot completion response: {t}", t = serde_json::to_string_pretty(&response)?);
370                        let span = tracing::Span::current();
371                        span.record("gen_ai.response.id", response.id.clone());
372                        span.record("gen_ai.response.model_name", response.model.clone());
373                        span.record(
374                            "gen_ai.output.messages",
375                            serde_json::to_string(&response.choices).unwrap(),
376                        );
377                        if let Some(ref usage) = response.usage {
378                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
379                            span.record(
380                                "gen_ai.usage.output_tokens",
381                                usage.total_tokens - usage.prompt_tokens,
382                            );
383                        }
384                        response.try_into()
385                    }
386                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
387                }
388            } else {
389                Err(CompletionError::ProviderError(
390                    String::from_utf8_lossy(&response_body).to_string(),
391                ))
392            }
393        };
394
395        async_block.instrument(span).await
396    }
397
398    #[cfg_attr(feature = "worker", worker::send)]
399    async fn stream(
400        &self,
401        request: CompletionRequest,
402    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
403        let preamble = request.preamble.clone();
404        let mut request = self.create_completion_request(request)?;
405
406        let span = if tracing::Span::current().is_disabled() {
407            info_span!(
408                target: "rig::completions",
409                "chat_streaming",
410                gen_ai.operation.name = "chat_streaming",
411                gen_ai.provider.name = "moonshot",
412                gen_ai.request.model = self.model,
413                gen_ai.system_instructions = preamble,
414                gen_ai.response.id = tracing::field::Empty,
415                gen_ai.response.model = tracing::field::Empty,
416                gen_ai.usage.output_tokens = tracing::field::Empty,
417                gen_ai.usage.input_tokens = tracing::field::Empty,
418                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
419                gen_ai.output.messages = tracing::field::Empty,
420            )
421        } else {
422            tracing::Span::current()
423        };
424
425        request = merge(
426            request,
427            json!({"stream": true, "stream_options": {"include_usage": true}}),
428        );
429
430        let body = serde_json::to_vec(&request)?;
431        let req = self
432            .client
433            .req(Method::POST, "/chat/completions")?
434            .header("Content-Type", "application/json")
435            .body(body)
436            .map_err(http_client::Error::from)?;
437
438        send_compatible_streaming_request(self.client.http_client.clone(), req)
439            .instrument(span)
440            .await
441    }
442}
443
444#[derive(Default, Debug, Deserialize, Serialize)]
445pub enum ToolChoice {
446    None,
447    #[default]
448    Auto,
449}
450
451impl TryFrom<message::ToolChoice> for ToolChoice {
452    type Error = CompletionError;
453
454    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
455        let res = match value {
456            message::ToolChoice::None => Self::None,
457            message::ToolChoice::Auto => Self::Auto,
458            choice => {
459                return Err(CompletionError::ProviderError(format!(
460                    "Unsupported tool choice type: {choice:?}"
461                )));
462            }
463        };
464
465        Ok(res)
466    }
467}