rig/providers/moonshot.rs

1//! Moonshot API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::moonshot;
6//!
7//! let client = moonshot::Client::new("YOUR_API_KEY");
8//!
9//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
10//! ```
11
12use crate::json_utils::merge;
13use crate::message;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::{StreamingCompletionModel, StreamingResult};
16use crate::{
17    agent::AgentBuilder,
18    completion::{self, CompletionError, CompletionRequest},
19    extractor::ExtractorBuilder,
20    json_utils,
21    providers::openai,
22};
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::{json, Value};
26
// ================================================================
// Main Moonshot Client
// ================================================================
/// Default base URL of the Moonshot API, used by [`Client::new`].
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
31
/// HTTP client for the Moonshot API.
///
/// Holds the base URL that request paths are resolved against, plus a `reqwest` client
/// preconfigured with the `Authorization` header. A clone of this client is handed to each
/// [`CompletionModel`] created from it.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to (defaults to `MOONSHOT_API_BASE_URL`).
    base_url: String,
    /// `reqwest` client carrying the default `Authorization: Bearer ...` header.
    http_client: reqwest::Client,
}
37
38impl Client {
39    /// Create a new Moonshot client with the given API key.
40    pub fn new(api_key: &str) -> Self {
41        Self::from_url(api_key, MOONSHOT_API_BASE_URL)
42    }
43
44    /// Create a new Moonshot client with the given API key and base API URL.
45    pub fn from_url(api_key: &str, base_url: &str) -> Self {
46        Self {
47            base_url: base_url.to_string(),
48            http_client: reqwest::Client::builder()
49                .default_headers({
50                    let mut headers = reqwest::header::HeaderMap::new();
51                    headers.insert(
52                        "Authorization",
53                        format!("Bearer {}", api_key)
54                            .parse()
55                            .expect("Bearer token should parse"),
56                    );
57                    headers
58                })
59                .build()
60                .expect("Moonshot reqwest client should build"),
61        }
62    }
63
64    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
65    /// Panics if the environment variable is not set.
66    pub fn from_env() -> Self {
67        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
68        Self::new(&api_key)
69    }
70
71    fn post(&self, path: &str) -> reqwest::RequestBuilder {
72        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
73        self.http_client.post(url)
74    }
75
76    /// Create a completion model with the given name.
77    ///
78    /// # Example
79    /// ```
80    /// use rig::providers::moonshot::{Client, self};
81    ///
82    /// // Initialize the Moonshot client
83    /// let moonshot = Client::new("your-moonshot-api-key");
84    ///
85    /// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT);
86    /// ```
87    pub fn completion_model(&self, model: &str) -> CompletionModel {
88        CompletionModel::new(self.clone(), model)
89    }
90
91    /// Create an agent builder with the given completion model.
92    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
93        AgentBuilder::new(self.completion_model(model))
94    }
95
96    /// Create an extractor builder with the given completion model.
97    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
98        &self,
99        model: &str,
100    ) -> ExtractorBuilder<T, CompletionModel> {
101        ExtractorBuilder::new(self.completion_model(model))
102    }
103}
104
/// Error body returned by the Moonshot API, shaped like `{"error": {"message": "..."}}`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}

/// Inner error object; only its human-readable `message` is deserialized.
#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}

/// Either a successful payload `T` or an API error body.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. A body is
/// therefore parsed as `Ok(T)` first and only falls back to `Err` when it does not match
/// `T`. Keep `Ok` listed first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
121
// ================================================================
// Moonshot Completion API
// ================================================================
/// Identifier of the `moonshot-v1-128k` chat model; pass it to [`Client::completion_model`].
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
126
/// A Moonshot chat model bound to a [`Client`].
///
/// Implements rig's completion and streaming-completion traits. Requests go to the
/// `/chat/completions` endpoint using OpenAI-compatible message and tool formats.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier sent as the `"model"` field of each request (e.g. [`MOONSHOT_CHAT`]).
    pub model: String,
}
132
133impl CompletionModel {
134    pub fn new(client: Client, model: &str) -> Self {
135        Self {
136            client,
137            model: model.to_string(),
138        }
139    }
140
141    fn create_completion_request(
142        &self,
143        completion_request: CompletionRequest,
144    ) -> Result<Value, CompletionError> {
145        // Build up the order of messages (context, chat_history)
146        let mut partial_history = vec![];
147        if let Some(docs) = completion_request.normalized_documents() {
148            partial_history.push(docs);
149        }
150        partial_history.extend(completion_request.chat_history);
151
152        // Initialize full history with preamble (or empty if non-existent)
153        let mut full_history: Vec<openai::Message> = completion_request
154            .preamble
155            .map_or_else(Vec::new, |preamble| {
156                vec![openai::Message::system(&preamble)]
157            });
158
159        // Convert and extend the rest of the history
160        full_history.extend(
161            partial_history
162                .into_iter()
163                .map(message::Message::try_into)
164                .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
165                .into_iter()
166                .flatten()
167                .collect::<Vec<_>>(),
168        );
169
170        let request = if completion_request.tools.is_empty() {
171            json!({
172                "model": self.model,
173                "messages": full_history,
174                "temperature": completion_request.temperature,
175            })
176        } else {
177            json!({
178                "model": self.model,
179                "messages": full_history,
180                "temperature": completion_request.temperature,
181                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
182                "tool_choice": "auto",
183            })
184        };
185
186        let request = if let Some(params) = completion_request.additional_params {
187            json_utils::merge(request, params)
188        } else {
189            request
190        };
191
192        Ok(request)
193    }
194}
195
196impl completion::CompletionModel for CompletionModel {
197    type Response = openai::CompletionResponse;
198
199    #[cfg_attr(feature = "worker", worker::send)]
200    async fn completion(
201        &self,
202        completion_request: CompletionRequest,
203    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
204        let request = self.create_completion_request(completion_request)?;
205
206        let response = self
207            .client
208            .post("/chat/completions")
209            .json(&request)
210            .send()
211            .await?;
212
213        if response.status().is_success() {
214            let t = response.text().await?;
215            tracing::debug!(target: "rig", "MoonShot completion error: {}", t);
216
217            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
218                ApiResponse::Ok(response) => {
219                    tracing::info!(target: "rig",
220                        "MoonShot completion token usage: {:?}",
221                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
222                    );
223                    response.try_into()
224                }
225                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
226            }
227        } else {
228            Err(CompletionError::ProviderError(response.text().await?))
229        }
230    }
231}
232
233impl StreamingCompletionModel for CompletionModel {
234    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
235        let mut request = self.create_completion_request(request)?;
236
237        request = merge(request, json!({"stream": true}));
238
239        let builder = self.client.post("/chat/completions").json(&request);
240
241        send_compatible_streaming_request(builder).await
242    }
243}