Skip to main content

codetether_agent/provider/
moonshot.rs

1//! Moonshot AI provider implementation (direct API)
2//!
3//! For Kimi K2.5 and other Moonshot models via api.moonshot.ai
4
5use super::{
6    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7    Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::Deserialize;
13use serde_json::{json, Value};
14
15pub struct MoonshotProvider {
16    client: Client,
17    api_key: String,
18    base_url: String,
19}
20
21impl MoonshotProvider {
22    pub fn new(api_key: String) -> Result<Self> {
23        Ok(Self {
24            client: Client::new(),
25            api_key,
26            base_url: "https://api.moonshot.ai/v1".to_string(),
27        })
28    }
29
30    fn convert_messages(messages: &[Message]) -> Vec<Value> {
31        messages
32            .iter()
33            .map(|msg| {
34                let role = match msg.role {
35                    Role::System => "system",
36                    Role::User => "user",
37                    Role::Assistant => "assistant",
38                    Role::Tool => "tool",
39                };
40
41                match msg.role {
42                    Role::Tool => {
43                        if let Some(ContentPart::ToolResult { tool_call_id, content }) = msg.content.first() {
44                            json!({
45                                "role": "tool",
46                                "tool_call_id": tool_call_id,
47                                "content": content
48                            })
49                        } else {
50                            json!({"role": role, "content": ""})
51                        }
52                    }
53                    Role::Assistant => {
54                        let text: String = msg.content.iter()
55                            .filter_map(|p| match p {
56                                ContentPart::Text { text } => Some(text.clone()),
57                                _ => None,
58                            })
59                            .collect::<Vec<_>>()
60                            .join("");
61
62                        let tool_calls: Vec<Value> = msg.content.iter()
63                            .filter_map(|p| match p {
64                                ContentPart::ToolCall { id, name, arguments } => Some(json!({
65                                    "id": id,
66                                    "type": "function",
67                                    "function": {
68                                        "name": name,
69                                        "arguments": arguments
70                                    }
71                                })),
72                                _ => None,
73                            })
74                            .collect();
75
76                        if tool_calls.is_empty() {
77                            json!({"role": "assistant", "content": text})
78                        } else {
79                            // Moonshot requires reasoning_content for K2.5 thinking models
80                            // Include empty string when we don't have the original
81                            json!({
82                                "role": "assistant",
83                                "content": if text.is_empty() { "".to_string() } else { text },
84                                "reasoning_content": "",
85                                "tool_calls": tool_calls
86                            })
87                        }
88                    }
89                    _ => {
90                        let text: String = msg.content.iter()
91                            .filter_map(|p| match p {
92                                ContentPart::Text { text } => Some(text.clone()),
93                                _ => None,
94                            })
95                            .collect::<Vec<_>>()
96                            .join("\n");
97
98                        json!({"role": role, "content": text})
99                    }
100                }
101            })
102            .collect()
103    }
104
105    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
106        tools
107            .iter()
108            .map(|t| {
109                json!({
110                    "type": "function",
111                    "function": {
112                        "name": t.name,
113                        "description": t.description,
114                        "parameters": t.parameters
115                    }
116                })
117            })
118            .collect()
119    }
120}
121
122#[derive(Debug, Deserialize)]
123struct MoonshotResponse {
124    id: String,
125    model: String,
126    choices: Vec<MoonshotChoice>,
127    #[serde(default)]
128    usage: Option<MoonshotUsage>,
129}
130
131#[derive(Debug, Deserialize)]
132struct MoonshotChoice {
133    message: MoonshotMessage,
134    #[serde(default)]
135    finish_reason: Option<String>,
136}
137
138#[derive(Debug, Deserialize)]
139struct MoonshotMessage {
140    #[allow(dead_code)]
141    role: String,
142    #[serde(default)]
143    content: Option<String>,
144    #[serde(default)]
145    tool_calls: Option<Vec<MoonshotToolCall>>,
146    // Kimi K2.5 reasoning
147    #[serde(default)]
148    reasoning_content: Option<String>,
149}
150
151#[derive(Debug, Deserialize)]
152struct MoonshotToolCall {
153    id: String,
154    #[serde(rename = "type")]
155    call_type: String,
156    function: MoonshotFunction,
157}
158
159#[derive(Debug, Deserialize)]
160struct MoonshotFunction {
161    name: String,
162    arguments: String,
163}
164
165#[derive(Debug, Deserialize)]
166struct MoonshotUsage {
167    #[serde(default)]
168    prompt_tokens: usize,
169    #[serde(default)]
170    completion_tokens: usize,
171    #[serde(default)]
172    total_tokens: usize,
173}
174
175#[derive(Debug, Deserialize)]
176struct MoonshotError {
177    #[allow(dead_code)]
178    error: MoonshotErrorDetail,
179}
180
181#[derive(Debug, Deserialize)]
182struct MoonshotErrorDetail {
183    message: String,
184    #[serde(default, rename = "type")]
185    error_type: Option<String>,
186}
187
188#[async_trait]
189impl Provider for MoonshotProvider {
190    fn name(&self) -> &str {
191        "moonshotai"
192    }
193
194    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
195        Ok(vec![
196            ModelInfo {
197                id: "kimi-k2.5".to_string(),
198                name: "Kimi K2.5".to_string(),
199                provider: "moonshotai".to_string(),
200                context_window: 256_000,
201                max_output_tokens: Some(64_000),
202                supports_vision: true,
203                supports_tools: true,
204                supports_streaming: true,
205                input_cost_per_million: Some(0.56),  // ¥4/M tokens
206                output_cost_per_million: Some(2.8),  // ¥20/M tokens
207            },
208            ModelInfo {
209                id: "kimi-k2-thinking".to_string(),
210                name: "Kimi K2 Thinking".to_string(),
211                provider: "moonshotai".to_string(),
212                context_window: 128_000,
213                max_output_tokens: Some(64_000),
214                supports_vision: false,
215                supports_tools: true,
216                supports_streaming: true,
217                input_cost_per_million: Some(0.56),
218                output_cost_per_million: Some(2.8),
219            },
220            ModelInfo {
221                id: "kimi-latest".to_string(),
222                name: "Kimi Latest".to_string(),
223                provider: "moonshotai".to_string(),
224                context_window: 128_000,
225                max_output_tokens: Some(64_000),
226                supports_vision: false,
227                supports_tools: true,
228                supports_streaming: true,
229                input_cost_per_million: Some(0.42),  // Cheaper
230                output_cost_per_million: Some(1.68),
231            },
232        ])
233    }
234
235    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
236        let messages = Self::convert_messages(&request.messages);
237        let tools = Self::convert_tools(&request.tools);
238
239        // Kimi K2.5 requires specific temperatures:
240        // - temperature = 1.0 when thinking is enabled  
241        // - temperature = 0.6 when thinking is disabled
242        let temperature = if request.model.contains("k2") {
243            0.6  // We disable thinking for tool calling workflows
244        } else {
245            request.temperature.unwrap_or(0.7)
246        };
247
248        let mut body = json!({
249            "model": request.model,
250            "messages": messages,
251            "temperature": temperature,
252        });
253
254        // Disable thinking mode to avoid needing to track reasoning_content
255        // across message roundtrips (required for K2.5)
256        if request.model.contains("k2") {
257            body["thinking"] = json!({"type": "disabled"});
258        }
259
260        if !tools.is_empty() {
261            body["tools"] = json!(tools);
262        }
263        if let Some(max) = request.max_tokens {
264            body["max_tokens"] = json!(max);
265        }
266
267        tracing::debug!("Moonshot request to model {}", request.model);
268
269        let response = self
270            .client
271            .post(format!("{}/chat/completions", self.base_url))
272            .header("Authorization", format!("Bearer {}", self.api_key))
273            .header("Content-Type", "application/json")
274            .json(&body)
275            .send()
276            .await
277            .context("Failed to send request to Moonshot")?;
278
279        let status = response.status();
280        let text = response.text().await.context("Failed to read response")?;
281
282        if !status.is_success() {
283            if let Ok(err) = serde_json::from_str::<MoonshotError>(&text) {
284                anyhow::bail!("Moonshot API error: {} ({:?})", err.error.message, err.error.error_type);
285            }
286            anyhow::bail!("Moonshot API error: {} {}", status, text);
287        }
288
289        let response: MoonshotResponse = serde_json::from_str(&text)
290            .context(format!("Failed to parse Moonshot response: {}", &text[..text.len().min(200)]))?;
291
292        // Log response metadata for debugging
293        tracing::debug!(
294            response_id = %response.id,
295            model = %response.model,
296            "Received Moonshot response"
297        );
298
299        let choice = response.choices.first().ok_or_else(|| anyhow::anyhow!("No choices"))?;
300
301        // Log reasoning/thinking content if present (Kimi K2 models)
302        if let Some(ref reasoning) = choice.message.reasoning_content {
303            if !reasoning.is_empty() {
304                tracing::info!(
305                    reasoning_len = reasoning.len(),
306                    "Model reasoning/thinking content received"
307                );
308                tracing::debug!(
309                    reasoning = %reasoning,
310                    "Full model reasoning"
311                );
312            }
313        }
314
315        let mut content = Vec::new();
316        let mut has_tool_calls = false;
317
318        if let Some(text) = &choice.message.content {
319            if !text.is_empty() {
320                content.push(ContentPart::Text { text: text.clone() });
321            }
322        }
323
324        if let Some(tool_calls) = &choice.message.tool_calls {
325            has_tool_calls = !tool_calls.is_empty();
326            for tc in tool_calls {
327                // Log tool call details for debugging (uses role and call_type fields)
328                tracing::debug!(
329                    tool_call_id = %tc.id,
330                    call_type = %tc.call_type,
331                    function_name = %tc.function.name,
332                    "Processing tool call"
333                );
334                content.push(ContentPart::ToolCall {
335                    id: tc.id.clone(),
336                    name: tc.function.name.clone(),
337                    arguments: tc.function.arguments.clone(),
338                });
339            }
340        }
341
342        let finish_reason = if has_tool_calls {
343            FinishReason::ToolCalls
344        } else {
345            match choice.finish_reason.as_deref() {
346                Some("stop") => FinishReason::Stop,
347                Some("length") => FinishReason::Length,
348                Some("tool_calls") => FinishReason::ToolCalls,
349                _ => FinishReason::Stop,
350            }
351        };
352
353        Ok(CompletionResponse {
354            message: Message {
355                role: Role::Assistant,
356                content,
357            },
358            usage: Usage {
359                prompt_tokens: response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
360                completion_tokens: response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
361                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
362                ..Default::default()
363            },
364            finish_reason,
365        })
366    }
367
368    async fn complete_stream(
369        &self,
370        request: CompletionRequest,
371    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
372        tracing::debug!(
373            provider = "moonshotai",
374            model = %request.model,
375            message_count = request.messages.len(),
376            "Starting streaming completion request (falling back to non-streaming)"
377        );
378        
379        // Fall back to non-streaming for now
380        let response = self.complete(request).await?;
381        let text = response.message.content.iter()
382            .filter_map(|p| match p {
383                ContentPart::Text { text } => Some(text.clone()),
384                _ => None,
385            })
386            .collect::<Vec<_>>()
387            .join("");
388        
389        Ok(Box::pin(futures::stream::once(async move {
390            StreamChunk::Text(text)
391        })))
392    }
393}