// rig/providers/moonshot.rs
use crate::json_utils::merge;
use crate::message;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
    agent::AgentBuilder,
    completion::{self, CompletionError, CompletionRequest},
    extractor::ExtractorBuilder,
    json_utils,
    providers::openai,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

/// Default base URL of the Moonshot (Kimi) OpenAI-compatible API.
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";

/// HTTP client for the Moonshot API.
///
/// Holds the resolved base URL and a `reqwest` client pre-configured with an
/// `Authorization: Bearer <key>` default header (see `Client::from_url`).
#[derive(Clone)]
pub struct Client {
    // API endpoint prefix, e.g. "https://api.moonshot.cn/v1".
    base_url: String,
    // reqwest client with the auth header installed as a default header.
    http_client: reqwest::Client,
}
37
38impl Client {
39 pub fn new(api_key: &str) -> Self {
41 Self::from_url(api_key, MOONSHOT_API_BASE_URL)
42 }
43
44 pub fn from_url(api_key: &str, base_url: &str) -> Self {
46 Self {
47 base_url: base_url.to_string(),
48 http_client: reqwest::Client::builder()
49 .default_headers({
50 let mut headers = reqwest::header::HeaderMap::new();
51 headers.insert(
52 "Authorization",
53 format!("Bearer {}", api_key)
54 .parse()
55 .expect("Bearer token should parse"),
56 );
57 headers
58 })
59 .build()
60 .expect("Moonshot reqwest client should build"),
61 }
62 }
63
64 pub fn from_env() -> Self {
67 let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
68 Self::new(&api_key)
69 }
70
71 fn post(&self, path: &str) -> reqwest::RequestBuilder {
72 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
73 self.http_client.post(url)
74 }
75
76 pub fn completion_model(&self, model: &str) -> CompletionModel {
88 CompletionModel::new(self.clone(), model)
89 }
90
91 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
93 AgentBuilder::new(self.completion_model(model))
94 }
95
96 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
98 &self,
99 model: &str,
100 ) -> ExtractorBuilder<T, CompletionModel> {
101 ExtractorBuilder::new(self.completion_model(model))
102 }
103}
104
/// Wire format of a Moonshot error payload: `{"error": {"message": "..."}}`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    error: MoonshotError,
}
109
/// Inner error object; only the human-readable message is surfaced to callers.
#[derive(Debug, Deserialize)]
struct MoonshotError {
    message: String,
}
114
/// Either a successful payload of type `T` or a Moonshot error envelope.
/// `untagged` makes serde try the variants in order, so `Ok(T)` is attempted
/// before falling back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
121
/// The `moonshot-v1-128k` chat completion model.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";

/// A Moonshot completion model bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier sent in the request body (e.g. "moonshot-v1-128k").
    pub model: String,
}
132
133impl CompletionModel {
134 pub fn new(client: Client, model: &str) -> Self {
135 Self {
136 client,
137 model: model.to_string(),
138 }
139 }
140
141 fn create_completion_request(
142 &self,
143 completion_request: CompletionRequest,
144 ) -> Result<Value, CompletionError> {
145 let mut partial_history = vec![];
147 if let Some(docs) = completion_request.normalized_documents() {
148 partial_history.push(docs);
149 }
150 partial_history.extend(completion_request.chat_history);
151
152 let mut full_history: Vec<openai::Message> = completion_request
154 .preamble
155 .map_or_else(Vec::new, |preamble| {
156 vec![openai::Message::system(&preamble)]
157 });
158
159 full_history.extend(
161 partial_history
162 .into_iter()
163 .map(message::Message::try_into)
164 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
165 .into_iter()
166 .flatten()
167 .collect::<Vec<_>>(),
168 );
169
170 let request = if completion_request.tools.is_empty() {
171 json!({
172 "model": self.model,
173 "messages": full_history,
174 "temperature": completion_request.temperature,
175 })
176 } else {
177 json!({
178 "model": self.model,
179 "messages": full_history,
180 "temperature": completion_request.temperature,
181 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
182 "tool_choice": "auto",
183 })
184 };
185
186 let request = if let Some(params) = completion_request.additional_params {
187 json_utils::merge(request, params)
188 } else {
189 request
190 };
191
192 Ok(request)
193 }
194}
195
196impl completion::CompletionModel for CompletionModel {
197 type Response = openai::CompletionResponse;
198
199 #[cfg_attr(feature = "worker", worker::send)]
200 async fn completion(
201 &self,
202 completion_request: CompletionRequest,
203 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
204 let request = self.create_completion_request(completion_request)?;
205
206 let response = self
207 .client
208 .post("/chat/completions")
209 .json(&request)
210 .send()
211 .await?;
212
213 if response.status().is_success() {
214 let t = response.text().await?;
215 tracing::debug!(target: "rig", "MoonShot completion error: {}", t);
216
217 match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
218 ApiResponse::Ok(response) => {
219 tracing::info!(target: "rig",
220 "MoonShot completion token usage: {:?}",
221 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
222 );
223 response.try_into()
224 }
225 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
226 }
227 } else {
228 Err(CompletionError::ProviderError(response.text().await?))
229 }
230 }
231}
232
233impl StreamingCompletionModel for CompletionModel {
234 async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
235 let mut request = self.create_completion_request(request)?;
236
237 request = merge(request, json!({"stream": true}));
238
239 let builder = self.client.post("/chat/completions").json(&request);
240
241 send_compatible_streaming_request(builder).await
242 }
243}