rig/providers/
moonshot.rs

//! Moonshot API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::moonshot;
//!
//! let client = moonshot::Client::new("YOUR_API_KEY");
//!
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
//! ```
11
12use crate::{
13    agent::AgentBuilder,
14    completion::{self, CompletionError, CompletionRequest},
15    extractor::ExtractorBuilder,
16    json_utils,
17    providers::openai,
18};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
// ================================================================
// Main Moonshot Client
// ================================================================

/// Default base URL of the Moonshot API; [`Client::new`] uses this value.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
27
28#[derive(Clone)]
29pub struct Client {
30    base_url: String,
31    http_client: reqwest::Client,
32}
33
34impl Client {
35    /// Create a new Moonshot client with the given API key.
36    pub fn new(api_key: &str) -> Self {
37        Self::from_url(api_key, MOONSHOT_API_BASE_URL)
38    }
39
40    /// Create a new Moonshot client with the given API key and base API URL.
41    pub fn from_url(api_key: &str, base_url: &str) -> Self {
42        Self {
43            base_url: base_url.to_string(),
44            http_client: reqwest::Client::builder()
45                .default_headers({
46                    let mut headers = reqwest::header::HeaderMap::new();
47                    headers.insert(
48                        "Authorization",
49                        format!("Bearer {}", api_key)
50                            .parse()
51                            .expect("Bearer token should parse"),
52                    );
53                    headers
54                })
55                .build()
56                .expect("Moonshot reqwest client should build"),
57        }
58    }
59
60    /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
61    /// Panics if the environment variable is not set.
62    pub fn from_env() -> Self {
63        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
64        Self::new(&api_key)
65    }
66
67    fn post(&self, path: &str) -> reqwest::RequestBuilder {
68        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69        self.http_client.post(url)
70    }
71
72    /// Create a completion model with the given name.
73    ///
74    /// # Example
75    /// ```
76    /// use rig::providers::moonshot::{Client, self};
77    ///
78    /// // Initialize the Moonshot client
79    /// let moonshot = Client::new("your-moonshot-api-key");
80    ///
81    /// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT);
82    /// ```
83    pub fn completion_model(&self, model: &str) -> CompletionModel {
84        CompletionModel::new(self.clone(), model)
85    }
86
87    /// Create an agent builder with the given completion model.
88    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
89        AgentBuilder::new(self.completion_model(model))
90    }
91
92    /// Create an extractor builder with the given completion model.
93    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
94        &self,
95        model: &str,
96    ) -> ExtractorBuilder<T, CompletionModel> {
97        ExtractorBuilder::new(self.completion_model(model))
98    }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103    error: MoonshotError,
104}
105
106#[derive(Debug, Deserialize)]
107struct MoonshotError {
108    message: String,
109}
110
111#[derive(Debug, Deserialize)]
112#[serde(untagged)]
113enum ApiResponse<T> {
114    Ok(T),
115    Err(ApiErrorResponse),
116}
117
// ================================================================
// Moonshot Completion API
// ================================================================

/// Model identifier `moonshot-v1-128k`, suitable for passing to
/// [`Client::completion_model`].
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
122
123#[derive(Clone)]
124pub struct CompletionModel {
125    client: Client,
126    pub model: String,
127}
128
129impl CompletionModel {
130    pub fn new(client: Client, model: &str) -> Self {
131        Self {
132            client,
133            model: model.to_string(),
134        }
135    }
136}
137
138impl completion::CompletionModel for CompletionModel {
139    type Response = openai::CompletionResponse;
140
141    #[cfg_attr(feature = "worker", worker::send)]
142    async fn completion(
143        &self,
144        completion_request: CompletionRequest,
145    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
146        // Add preamble to chat history (if available)
147        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
148            Some(preamble) => vec![openai::Message::system(preamble)],
149            None => vec![],
150        };
151
152        // Convert prompt to user message
153        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
154
155        // Convert existing chat history
156        let chat_history: Vec<openai::Message> = completion_request
157            .chat_history
158            .into_iter()
159            .map(|message| message.try_into())
160            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
161            .into_iter()
162            .flatten()
163            .collect();
164
165        // Combine all messages into a single history
166        full_history.extend(chat_history);
167        full_history.extend(prompt);
168
169        let request = if completion_request.tools.is_empty() {
170            json!({
171                "model": self.model,
172                "messages": full_history,
173                "temperature": completion_request.temperature,
174            })
175        } else {
176            json!({
177                "model": self.model,
178                "messages": full_history,
179                "temperature": completion_request.temperature,
180                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
181                "tool_choice": "auto",
182            })
183        };
184
185        let response = self
186            .client
187            .post("/chat/completions")
188            .json(
189                &if let Some(params) = completion_request.additional_params {
190                    json_utils::merge(request, params)
191                } else {
192                    request
193                },
194            )
195            .send()
196            .await?;
197
198        if response.status().is_success() {
199            let t = response.text().await?;
200            tracing::debug!(target: "rig", "MoonShot completion error: {}", t);
201
202            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
203                ApiResponse::Ok(response) => {
204                    tracing::info!(target: "rig",
205                        "MoonShot completion token usage: {:?}",
206                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
207                    );
208                    response.try_into()
209                }
210                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
211            }
212        } else {
213            Err(CompletionError::ProviderError(response.text().await?))
214        }
215    }
216}