rig/providers/
moonshot.rs

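//! Moonshot API client and Rig integration.
//!
//! Requests and responses reuse the OpenAI provider's message, tool-definition
//! and completion-response types (`providers::openai`).
//!
//! # Example
//! ```
//! use rig::providers::moonshot;
//!
//! let client = moonshot::Client::new("YOUR_API_KEY");
//!
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
//! ```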
11
12use crate::json_utils::merge;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::{StreamingCompletionModel, StreamingResult};
15use crate::{
16    agent::AgentBuilder,
17    completion::{self, CompletionError, CompletionRequest},
18    extractor::ExtractorBuilder,
19    json_utils,
20    providers::openai,
21};
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::{json, Value};
25
26// ================================================================
27// Main Moonshot Client
28// ================================================================
29const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
30
31#[derive(Clone)]
32pub struct Client {
33    base_url: String,
34    http_client: reqwest::Client,
35}
36
37impl Client {
38    /// Create a new Moonshot client with the given API key.
39    pub fn new(api_key: &str) -> Self {
40        Self::from_url(api_key, MOONSHOT_API_BASE_URL)
41    }
42
43    /// Create a new Moonshot client with the given API key and base API URL.
44    pub fn from_url(api_key: &str, base_url: &str) -> Self {
45        Self {
46            base_url: base_url.to_string(),
47            http_client: reqwest::Client::builder()
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert(
51                        "Authorization",
52                        format!("Bearer {}", api_key)
53                            .parse()
54                            .expect("Bearer token should parse"),
55                    );
56                    headers
57                })
58                .build()
59                .expect("Moonshot reqwest client should build"),
60        }
61    }
62
63    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
64    /// Panics if the environment variable is not set.
65    pub fn from_env() -> Self {
66        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
67        Self::new(&api_key)
68    }
69
70    fn post(&self, path: &str) -> reqwest::RequestBuilder {
71        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
72        self.http_client.post(url)
73    }
74
75    /// Create a completion model with the given name.
76    ///
77    /// # Example
78    /// ```
79    /// use rig::providers::moonshot::{Client, self};
80    ///
81    /// // Initialize the Moonshot client
82    /// let moonshot = Client::new("your-moonshot-api-key");
83    ///
84    /// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT);
85    /// ```
86    pub fn completion_model(&self, model: &str) -> CompletionModel {
87        CompletionModel::new(self.clone(), model)
88    }
89
90    /// Create an agent builder with the given completion model.
91    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
92        AgentBuilder::new(self.completion_model(model))
93    }
94
95    /// Create an extractor builder with the given completion model.
96    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
97        &self,
98        model: &str,
99    ) -> ExtractorBuilder<T, CompletionModel> {
100        ExtractorBuilder::new(self.completion_model(model))
101    }
102}
103
104#[derive(Debug, Deserialize)]
105struct ApiErrorResponse {
106    error: MoonshotError,
107}
108
109#[derive(Debug, Deserialize)]
110struct MoonshotError {
111    message: String,
112}
113
114#[derive(Debug, Deserialize)]
115#[serde(untagged)]
116enum ApiResponse<T> {
117    Ok(T),
118    Err(ApiErrorResponse),
119}
120
121// ================================================================
122// Moonshot Completion API
123// ================================================================
124pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
125
126#[derive(Clone)]
127pub struct CompletionModel {
128    client: Client,
129    pub model: String,
130}
131
132impl CompletionModel {
133    pub fn new(client: Client, model: &str) -> Self {
134        Self {
135            client,
136            model: model.to_string(),
137        }
138    }
139
140    fn create_completion_request(
141        &self,
142        completion_request: CompletionRequest,
143    ) -> Result<Value, CompletionError> {
144        // Add preamble to chat history (if available)
145        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
146            Some(preamble) => vec![openai::Message::system(preamble)],
147            None => vec![],
148        };
149
150        // Convert prompt to user message
151        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
152
153        // Convert existing chat history
154        let chat_history: Vec<openai::Message> = completion_request
155            .chat_history
156            .into_iter()
157            .map(|message| message.try_into())
158            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
159            .into_iter()
160            .flatten()
161            .collect();
162
163        // Combine all messages into a single history
164        full_history.extend(chat_history);
165        full_history.extend(prompt);
166
167        let request = if completion_request.tools.is_empty() {
168            json!({
169                "model": self.model,
170                "messages": full_history,
171                "temperature": completion_request.temperature,
172            })
173        } else {
174            json!({
175                "model": self.model,
176                "messages": full_history,
177                "temperature": completion_request.temperature,
178                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
179                "tool_choice": "auto",
180            })
181        };
182
183        let request = if let Some(params) = completion_request.additional_params {
184            json_utils::merge(request, params)
185        } else {
186            request
187        };
188
189        Ok(request)
190    }
191}
192
193impl completion::CompletionModel for CompletionModel {
194    type Response = openai::CompletionResponse;
195
196    #[cfg_attr(feature = "worker", worker::send)]
197    async fn completion(
198        &self,
199        completion_request: CompletionRequest,
200    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
201        let request = self.create_completion_request(completion_request)?;
202
203        let response = self
204            .client
205            .post("/chat/completions")
206            .json(&request)
207            .send()
208            .await?;
209
210        if response.status().is_success() {
211            let t = response.text().await?;
212            tracing::debug!(target: "rig", "MoonShot completion error: {}", t);
213
214            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
215                ApiResponse::Ok(response) => {
216                    tracing::info!(target: "rig",
217                        "MoonShot completion token usage: {:?}",
218                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
219                    );
220                    response.try_into()
221                }
222                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
223            }
224        } else {
225            Err(CompletionError::ProviderError(response.text().await?))
226        }
227    }
228}
229
230impl StreamingCompletionModel for CompletionModel {
231    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
232        let mut request = self.create_completion_request(request)?;
233
234        request = merge(request, json!({"stream": true}));
235
236        let builder = self.client.post("/chat/completions").json(&request);
237
238        send_compatible_streaming_request(builder).await
239    }
240}