rig/providers/
moonshot.rs1use crate::{
13 agent::AgentBuilder,
14 completion::{self, CompletionError, CompletionRequest},
15 extractor::ExtractorBuilder,
16 json_utils,
17 providers::openai,
18};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
/// Base URL of the public Moonshot API (versioned root; endpoint paths are
/// appended to it).
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";
27
28#[derive(Clone)]
29pub struct Client {
30 base_url: String,
31 http_client: reqwest::Client,
32}
33
34impl Client {
35 pub fn new(api_key: &str) -> Self {
37 Self::from_url(api_key, MOONSHOT_API_BASE_URL)
38 }
39
40 pub fn from_url(api_key: &str, base_url: &str) -> Self {
42 Self {
43 base_url: base_url.to_string(),
44 http_client: reqwest::Client::builder()
45 .default_headers({
46 let mut headers = reqwest::header::HeaderMap::new();
47 headers.insert(
48 "Authorization",
49 format!("Bearer {}", api_key)
50 .parse()
51 .expect("Bearer token should parse"),
52 );
53 headers
54 })
55 .build()
56 .expect("Moonshot reqwest client should build"),
57 }
58 }
59
60 pub fn from_env() -> Self {
63 let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
64 Self::new(&api_key)
65 }
66
67 fn post(&self, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69 self.http_client.post(url)
70 }
71
72 pub fn completion_model(&self, model: &str) -> CompletionModel {
84 CompletionModel::new(self.clone(), model)
85 }
86
87 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
89 AgentBuilder::new(self.completion_model(model))
90 }
91
92 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
94 &self,
95 model: &str,
96 ) -> ExtractorBuilder<T, CompletionModel> {
97 ExtractorBuilder::new(self.completion_model(model))
98 }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103 error: MoonshotError,
104}
105
106#[derive(Debug, Deserialize)]
107struct MoonshotError {
108 message: String,
109}
110
111#[derive(Debug, Deserialize)]
112#[serde(untagged)]
113enum ApiResponse<T> {
114 Ok(T),
115 Err(ApiErrorResponse),
116}
117
/// Model identifier `moonshot-v1-128k`, suitable for
/// `Client::completion_model` / `Client::agent`.
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";
122
123#[derive(Clone)]
124pub struct CompletionModel {
125 client: Client,
126 pub model: String,
127}
128
129impl CompletionModel {
130 pub fn new(client: Client, model: &str) -> Self {
131 Self {
132 client,
133 model: model.to_string(),
134 }
135 }
136}
137
138impl completion::CompletionModel for CompletionModel {
139 type Response = openai::CompletionResponse;
140
141 #[cfg_attr(feature = "worker", worker::send)]
142 async fn completion(
143 &self,
144 completion_request: CompletionRequest,
145 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
146 let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
148 Some(preamble) => vec![openai::Message::system(preamble)],
149 None => vec![],
150 };
151
152 let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
154
155 let chat_history: Vec<openai::Message> = completion_request
157 .chat_history
158 .into_iter()
159 .map(|message| message.try_into())
160 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
161 .into_iter()
162 .flatten()
163 .collect();
164
165 full_history.extend(chat_history);
167 full_history.extend(prompt);
168
169 let request = if completion_request.tools.is_empty() {
170 json!({
171 "model": self.model,
172 "messages": full_history,
173 "temperature": completion_request.temperature,
174 })
175 } else {
176 json!({
177 "model": self.model,
178 "messages": full_history,
179 "temperature": completion_request.temperature,
180 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
181 "tool_choice": "auto",
182 })
183 };
184
185 let response = self
186 .client
187 .post("/chat/completions")
188 .json(
189 &if let Some(params) = completion_request.additional_params {
190 json_utils::merge(request, params)
191 } else {
192 request
193 },
194 )
195 .send()
196 .await?;
197
198 if response.status().is_success() {
199 let t = response.text().await?;
200 tracing::debug!(target: "rig", "MoonShot completion error: {}", t);
201
202 match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
203 ApiResponse::Ok(response) => {
204 tracing::info!(target: "rig",
205 "MoonShot completion token usage: {:?}",
206 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
207 );
208 response.try_into()
209 }
210 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
211 }
212 } else {
213 Err(CompletionError::ProviderError(response.text().await?))
214 }
215 }
216}