rig/providers/moonshot.rs

//! MoonShot API client and rig provider integration: completion, agent, and extractor
//! builders plus streaming support, all via MoonShot's OpenAI-compatible chat completions endpoint.

use crate::json_utils::merge;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
    agent::AgentBuilder,
    completion::{self, CompletionError, CompletionRequest},
    extractor::ExtractorBuilder,
    json_utils,
    providers::openai,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";

#[derive(Clone)]
pub struct Client {
    base_url: String,
    http_client: reqwest::Client,
}

impl Client {
    /// Create a new MoonShot client with the given API key.
    pub fn new(api_key: &str) -> Self {
        Self::from_url(api_key, MOONSHOT_API_BASE_URL)
    }

    /// Create a new MoonShot client with the given API key and base URL.
    pub fn from_url(api_key: &str, base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Bearer token should parse"),
                    );
                    headers
                })
                .build()
                .expect("Moonshot reqwest client should build"),
        }
    }

    /// Create a new MoonShot client from the `MOONSHOT_API_KEY` environment variable.
    /// Panics if the variable is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
        Self::new(&api_key)
    }

    fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Join the base URL and path without producing a double slash
        // (and without clobbering the `https://` scheme separator).
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Create a completion model with the given name (e.g. `moonshot-v1-128k`).
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Create an agent builder backed by the given MoonShot completion model.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Create an extractor builder for structured extraction with the given model.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}
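
// Illustrative usage sketch, compiled only under `cfg(test)` and ignored by default.
// It assumes the `AgentBuilder` returned by `Client::agent` exposes rig's usual
// `preamble`/`build` methods; treat it as a sketch of how a caller might wire up the
// client, not as part of this provider's guaranteed surface.
#[cfg(test)]
mod client_usage_sketch {
    use super::*;

    #[test]
    #[ignore = "requires MOONSHOT_API_KEY to be set"]
    fn builds_an_agent_from_env() {
        // Reads MOONSHOT_API_KEY (panics if unset) and targets the default base URL.
        let client = Client::from_env();
        // Build an agent around the long-context chat model with a system preamble.
        let _agent = client
            .agent(MOONSHOT_CHAT)
            .preamble("You are a helpful assistant.")
            .build();
    }
}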

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}

#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}
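
// With `#[serde(untagged)]`, serde tries the variants in order: a body that deserializes
// as a completion payload becomes `Ok`, otherwise MoonShot's error envelope
// (`{"error": {"message": ...}}`) is tried and becomes `Err`.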
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

/// The `moonshot-v1-128k` (long-context) chat completion model.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";

#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}

impl CompletionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Start the message list with the system preamble, if one was provided.
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        // Convert the prompt (with any attached context) into OpenAI-style messages.
        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;

        // Convert the existing chat history, flattening multi-part messages.
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Final order: system preamble, prior chat history, then the current prompt.
        full_history.extend(chat_history);
        full_history.extend(prompt);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        // Caller-supplied additional parameters are merged on top of the defaults above.
        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
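
// For reference, a request built by `create_completion_request` above (no tools, no
// additional params) has roughly this shape, matching MoonShot's OpenAI-compatible
// chat completions endpoint. The concrete values shown are placeholders, not captured
// API traffic:
//
//     {
//         "model": "moonshot-v1-128k",
//         "messages": [
//             {"role": "system", "content": "..."},
//             {"role": "user", "content": "..."}
//         ],
//         "temperature": 0.7
//     }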

impl completion::CompletionModel for CompletionModel {
    type Response = openai::CompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            let t = response.text().await?;
            tracing::debug!(target: "rig", "MoonShot completion response: {}", t);

            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "MoonShot completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
            }
        } else {
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }
}

impl StreamingCompletionModel for CompletionModel {
    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
        let mut request = self.create_completion_request(request)?;

        request = merge(request, json!({"stream": true}));

        let builder = self.client.post("/chat/completions").json(&request);

        send_compatible_streaming_request(builder).await
    }
}
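
// Caller-side sketch (assumption, not part of this module): because streaming is enabled
// by merging `"stream": true` into the same OpenAI-compatible body and the SSE response is
// drained by `send_compatible_streaming_request`, a downstream caller can drive this
// provider through rig's streaming helpers much like the OpenAI provider. Method names
// such as `stream_prompt` below are assumptions about rig's streaming API:
//
//     let client = Client::from_env();
//     let agent = client.agent(MOONSHOT_CHAT).build();
//     let mut stream = agent.stream_prompt("Tell me a short story.").await?;
//     while let Some(chunk) = stream.next().await {
//         // handle each streamed chunk
//     }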