Skip to main content

codetether_agent/provider/
stepfun.rs

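//! StepFun provider implementation (direct API, not via OpenRouter).
//!
//! Models advertised by `list_models`: step-3.5-flash, step-1-8k, step-1-32k, step-1-128k and
//! step-1v-8k (vision). The model ID in a request is forwarded verbatim, so other StepFun models
//! (e.g. step-1-256k) can also be requested by name.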
4
5use super::{
6    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7    Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
14const STEPFUN_API_BASE: &str = "https://api.stepfun.ai/v1";
15
16pub struct StepFunProvider {
17    api_key: String,
18    client: reqwest::Client,
19}
20
21impl std::fmt::Debug for StepFunProvider {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("StepFunProvider")
24            .field("api_key", &"<REDACTED>")
25            .field("api_key_len", &self.api_key.len())
26            .field("client", &"<reqwest::Client>")
27            .finish()
28    }
29}
30
31impl StepFunProvider {
32    pub fn new(api_key: String) -> Result<Self> {
33        tracing::debug!(
34            provider = "stepfun",
35            api_key_len = api_key.len(),
36            "Creating StepFun provider"
37        );
38        Ok(Self {
39            api_key,
40            client: reqwest::Client::new(),
41        })
42    }
43    
44    /// Validate that the API key is non-empty
45    fn validate_api_key(&self) -> Result<()> {
46        if self.api_key.is_empty() {
47            anyhow::bail!("StepFun API key is empty");
48        }
49        if self.api_key.len() < 10 {
50            tracing::warn!(provider = "stepfun", "API key seems unusually short");
51        }
52        Ok(())
53    }
54}
55
56// ============== Request Types ==============
57
/// Body of a `POST {STEPFUN_API_BASE}/chat/completions` request.
///
/// Optional fields are omitted from the JSON entirely when they are `None`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model ID, forwarded verbatim from the caller's request.
    model: String,
    /// Conversation history, already converted to the wire format.
    messages: Vec<ChatMessage>,
    /// Tool definitions; `None` when the caller supplied no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ChatTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    /// `Some(true)` for streaming requests, `Some(false)` for one-shot completions.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// One entry of the request's `messages` array.
///
/// `role` is one of `"system"`, `"user"`, `"assistant"` or `"tool"` as produced by
/// `convert_messages`. `None` fields are omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Tool invocations made by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    /// For `"tool"` messages: the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// A tool advertised to the model; `r#type` is always `"function"` in this provider.
#[derive(Debug, Serialize)]
struct ChatTool {
    r#type: String,
    function: ChatFunction,
}

/// Name, description and parameter schema of a function tool.
#[derive(Debug, Serialize)]
struct ChatFunction {
    name: String,
    description: String,
    /// Copied verbatim from `ToolDefinition::parameters`
    /// (presumably a JSON Schema object — confirm against the tool definitions).
    parameters: serde_json::Value,
}
95
96// ============== Response Types ==============
97
/// Successful (non-streaming) response body from the chat-completions endpoint.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Response identifier; only used for debug logging.
    id: String,
    /// Candidate completions; only the first one is used.
    choices: Vec<ChatChoice>,
    /// Token accounting; a missing `usage` is reported as zero tokens.
    usage: Option<ChatUsage>,
}

/// One candidate completion.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    index: usize,
    message: ChatResponseMessage,
    /// Raw finish reason (`"stop"`, `"length"`, `"tool_calls"`, ...), mapped to `FinishReason`.
    finish_reason: Option<String>,
}

/// The assistant message inside a choice; either part may be absent.
#[derive(Debug, Deserialize)]
struct ChatResponseMessage {
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}

/// A tool invocation; used both in responses and when replaying assistant messages in requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCall {
    id: String,
    r#type: String,
    function: ToolCallFunction,
}

/// Function name and arguments of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCallFunction {
    name: String,
    /// Kept as a raw string and not parsed here (presumably JSON-encoded — confirm).
    arguments: String,
}

/// Token counts reported by the API.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

/// Error body returned with a non-success HTTP status.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ErrorDetail,
}

/// Details of an API error.
#[derive(Debug, Deserialize)]
struct ErrorDetail {
    /// Human-readable message surfaced in the returned error.
    message: String,
    /// Optional error code, logged when present. A non-string code makes the whole body fail
    /// to parse as `ErrorResponse`; callers then report the raw body instead.
    #[serde(default)]
    code: Option<String>,
}
152
153// ============== Streaming Types ==============
154
/// Payload of one SSE `data:` line in a streaming response.
#[derive(Debug, Deserialize)]
struct StreamChunkResponse {
    choices: Vec<StreamChoice>,
}

/// Incremental update for one choice.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    /// Any non-`None` value makes the stream emit `StreamChunk::Done`.
    finish_reason: Option<String>,
}

/// The new fragments carried by one streaming chunk.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    /// Next piece of assistant text.
    #[serde(default)]
    content: Option<String>,
    /// Fragments of tool calls being built up incrementally.
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}

/// Fragment of one tool call in a streaming chunk.
#[derive(Debug, Deserialize)]
struct StreamToolCall {
    /// Position of the tool call within the choice.
    /// NOTE(review): presumably later fragments carry only `index`, not `id` — confirm
    /// against StepFun's streaming docs.
    // Silences the unused-field lint in case no code reads `index`.
    #[allow(dead_code)]
    index: usize,
    /// Tool-call id; may be absent on some fragments.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamToolFunction>,
}

/// Function name and/or argument text carried by a tool-call fragment.
#[derive(Debug, Deserialize)]
struct StreamToolFunction {
    #[serde(default)]
    name: Option<String>,
    /// Next piece of the argument string.
    #[serde(default)]
    arguments: Option<String>,
}
191
192impl StepFunProvider {
193    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
194        let mut result = Vec::new();
195
196        for msg in messages {
197            match msg.role {
198                Role::System => {
199                    let content = msg
200                        .content
201                        .iter()
202                        .filter_map(|p| match p {
203                            ContentPart::Text { text } => Some(text.clone()),
204                            _ => None,
205                        })
206                        .collect::<Vec<_>>()
207                        .join("\n");
208                    result.push(ChatMessage {
209                        role: "system".to_string(),
210                        content: Some(content),
211                        tool_calls: None,
212                        tool_call_id: None,
213                    });
214                }
215                Role::User => {
216                    let content = msg
217                        .content
218                        .iter()
219                        .filter_map(|p| match p {
220                            ContentPart::Text { text } => Some(text.clone()),
221                            _ => None,
222                        })
223                        .collect::<Vec<_>>()
224                        .join("\n");
225                    result.push(ChatMessage {
226                        role: "user".to_string(),
227                        content: Some(content),
228                        tool_calls: None,
229                        tool_call_id: None,
230                    });
231                }
232                Role::Assistant => {
233                    let content = msg
234                        .content
235                        .iter()
236                        .filter_map(|p| match p {
237                            ContentPart::Text { text } => Some(text.clone()),
238                            _ => None,
239                        })
240                        .collect::<Vec<_>>()
241                        .join("\n");
242
243                    let tool_calls: Vec<ToolCall> = msg
244                        .content
245                        .iter()
246                        .filter_map(|p| match p {
247                            ContentPart::ToolCall { id, name, arguments } => Some(ToolCall {
248                                id: id.clone(),
249                                r#type: "function".to_string(),
250                                function: ToolCallFunction {
251                                    name: name.clone(),
252                                    arguments: arguments.clone(),
253                                },
254                            }),
255                            _ => None,
256                        })
257                        .collect();
258
259                    result.push(ChatMessage {
260                        role: "assistant".to_string(),
261                        // StepFun requires content field to be present (even if empty) when tool_calls exist
262                        content: if content.is_empty() && !tool_calls.is_empty() { 
263                            Some(String::new()) 
264                        } else if content.is_empty() { 
265                            None 
266                        } else { 
267                            Some(content) 
268                        },
269                        tool_calls: if tool_calls.is_empty() {
270                            None
271                        } else {
272                            Some(tool_calls)
273                        },
274                        tool_call_id: None,
275                    });
276                }
277                Role::Tool => {
278                    for part in &msg.content {
279                        if let ContentPart::ToolResult {
280                            tool_call_id,
281                            content,
282                        } = part
283                        {
284                            result.push(ChatMessage {
285                                role: "tool".to_string(),
286                                content: Some(content.clone()),
287                                tool_calls: None,
288                                tool_call_id: Some(tool_call_id.clone()),
289                            });
290                        }
291                    }
292                }
293            }
294        }
295
296        result
297    }
298
299    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<ChatTool> {
300        tools
301            .iter()
302            .map(|t| ChatTool {
303                r#type: "function".to_string(),
304                function: ChatFunction {
305                    name: t.name.clone(),
306                    description: t.description.clone(),
307                    parameters: t.parameters.clone(),
308                },
309            })
310            .collect()
311    }
312}
313
314#[async_trait]
315impl Provider for StepFunProvider {
316    fn name(&self) -> &str {
317        "stepfun"
318    }
319
320    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
321        Ok(vec![
322            ModelInfo {
323                id: "step-3.5-flash".to_string(),
324                name: "Step 3.5 Flash".to_string(),
325                provider: "stepfun".to_string(),
326                context_window: 128_000,
327                max_output_tokens: Some(8192),
328                supports_vision: false,
329                supports_tools: true,
330                supports_streaming: true,
331                input_cost_per_million: Some(0.0), // Free tier
332                output_cost_per_million: Some(0.0),
333            },
334            ModelInfo {
335                id: "step-1-8k".to_string(),
336                name: "Step 1 8K".to_string(),
337                provider: "stepfun".to_string(),
338                context_window: 8_000,
339                max_output_tokens: Some(4096),
340                supports_vision: false,
341                supports_tools: true,
342                supports_streaming: true,
343                input_cost_per_million: Some(0.5),
344                output_cost_per_million: Some(1.5),
345            },
346            ModelInfo {
347                id: "step-1-32k".to_string(),
348                name: "Step 1 32K".to_string(),
349                provider: "stepfun".to_string(),
350                context_window: 32_000,
351                max_output_tokens: Some(8192),
352                supports_vision: false,
353                supports_tools: true,
354                supports_streaming: true,
355                input_cost_per_million: Some(1.0),
356                output_cost_per_million: Some(3.0),
357            },
358            ModelInfo {
359                id: "step-1-128k".to_string(),
360                name: "Step 1 128K".to_string(),
361                provider: "stepfun".to_string(),
362                context_window: 128_000,
363                max_output_tokens: Some(8192),
364                supports_vision: false,
365                supports_tools: true,
366                supports_streaming: true,
367                input_cost_per_million: Some(2.0),
368                output_cost_per_million: Some(6.0),
369            },
370            ModelInfo {
371                id: "step-1v-8k".to_string(),
372                name: "Step 1 Vision 8K".to_string(),
373                provider: "stepfun".to_string(),
374                context_window: 8_000,
375                max_output_tokens: Some(4096),
376                supports_vision: true,
377                supports_tools: true,
378                supports_streaming: true,
379                input_cost_per_million: Some(1.0),
380                output_cost_per_million: Some(3.0),
381            },
382        ])
383    }
384
385    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
386        tracing::debug!(
387            provider = "stepfun",
388            model = %request.model,
389            message_count = request.messages.len(),
390            tool_count = request.tools.len(),
391            "Starting completion request"
392        );
393        
394        // Validate API key before making request
395        self.validate_api_key()?;
396        
397        let messages = self.convert_messages(&request.messages);
398        let tools = self.convert_tools(&request.tools);
399
400        let chat_request = ChatRequest {
401            model: request.model.clone(),
402            messages,
403            tools: if tools.is_empty() { None } else { Some(tools) },
404            temperature: request.temperature,
405            max_tokens: request.max_tokens,
406            stream: Some(false),
407        };
408
409        // Debug: log the request being sent
410        if let Ok(json_str) = serde_json::to_string_pretty(&chat_request) {
411            tracing::debug!("StepFun request: {}", json_str);
412        }
413
414        let response = self
415            .client
416            .post(format!("{}/chat/completions", STEPFUN_API_BASE))
417            .header("Authorization", format!("Bearer {}", self.api_key))
418            .header("Content-Type", "application/json")
419            .json(&chat_request)
420            .send()
421            .await?;
422
423        let status = response.status();
424        let body = response.text().await?;
425
426        if !status.is_success() {
427            if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
428                // Log error code if present for debugging
429                if let Some(ref code) = err.error.code {
430                    tracing::error!(error_code = %code, "StepFun API error code");
431                }
432                anyhow::bail!("StepFun API error: {}", err.error.message);
433            }
434            anyhow::bail!("StepFun API error ({}): {}", status, body);
435        }
436
437        let chat_response: ChatResponse = serde_json::from_str(&body)
438            .map_err(|e| anyhow::anyhow!("Failed to parse response: {} - Body: {}", e, body))?;
439
440        // Log response metadata for debugging
441        tracing::debug!(
442            response_id = %chat_response.id,
443            "Received StepFun response"
444        );
445
446        let choice = chat_response
447            .choices
448            .first()
449            .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
450        
451        // Log choice index and role for debugging
452        tracing::debug!(
453            choice_index = choice.index,
454            message_role = %choice.message.role,
455            "Processing StepFun choice"
456        );
457
458        // Log usage for tracing
459        tracing::info!(
460            prompt_tokens = chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
461            completion_tokens = chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
462            finish_reason = ?choice.finish_reason,
463            "StepFun completion received"
464        );
465
466        let mut content = Vec::new();
467        let mut has_tool_calls = false;
468
469        if let Some(text) = &choice.message.content {
470            if !text.is_empty() {
471                content.push(ContentPart::Text { text: text.clone() });
472            }
473        }
474
475        if let Some(tool_calls) = &choice.message.tool_calls {
476            has_tool_calls = !tool_calls.is_empty();
477            for tc in tool_calls {
478                content.push(ContentPart::ToolCall {
479                    id: tc.id.clone(),
480                    name: tc.function.name.clone(),
481                    arguments: tc.function.arguments.clone(),
482                });
483            }
484        }
485
486        let finish_reason = if has_tool_calls {
487            FinishReason::ToolCalls
488        } else {
489            match choice.finish_reason.as_deref() {
490                Some("stop") => FinishReason::Stop,
491                Some("length") => FinishReason::Length,
492                Some("tool_calls") => FinishReason::ToolCalls,
493                _ => FinishReason::Stop,
494            }
495        };
496
497        Ok(CompletionResponse {
498            message: Message {
499                role: Role::Assistant,
500                content,
501            },
502            usage: Usage {
503                prompt_tokens: chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
504                completion_tokens: chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
505                total_tokens: chat_response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
506                ..Default::default()
507            },
508            finish_reason,
509        })
510    }
511
512    async fn complete_stream(
513        &self,
514        request: CompletionRequest,
515    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
516        tracing::debug!(
517            provider = "stepfun",
518            model = %request.model,
519            message_count = request.messages.len(),
520            tool_count = request.tools.len(),
521            "Starting streaming completion request"
522        );
523        
524        self.validate_api_key()?;
525        
526        let messages = self.convert_messages(&request.messages);
527        let tools = self.convert_tools(&request.tools);
528
529        let chat_request = ChatRequest {
530            model: request.model.clone(),
531            messages,
532            tools: if tools.is_empty() { None } else { Some(tools) },
533            temperature: request.temperature,
534            max_tokens: request.max_tokens,
535            stream: Some(true),
536        };
537
538        let response = self
539            .client
540            .post(format!("{}/chat/completions", STEPFUN_API_BASE))
541            .header("Authorization", format!("Bearer {}", self.api_key))
542            .header("Content-Type", "application/json")
543            .json(&chat_request)
544            .send()
545            .await?;
546
547        if !response.status().is_success() {
548            let status = response.status();
549            let body = response.text().await?;
550            anyhow::bail!("StepFun API error ({}): {}", status, body);
551        }
552
553        let stream = response
554            .bytes_stream()
555            .map(|result| match result {
556                Ok(bytes) => {
557                    let text = String::from_utf8_lossy(&bytes);
558                    let mut chunks = Vec::new();
559
560                    for line in text.lines() {
561                        if let Some(data) = line.strip_prefix("data: ") {
562                            if data.trim() == "[DONE]" {
563                                chunks.push(StreamChunk::Done { usage: None });
564                                continue;
565                            }
566
567                            if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
568                                if let Some(choice) = chunk.choices.first() {
569                                    if let Some(content) = &choice.delta.content {
570                                        chunks.push(StreamChunk::Text(content.clone()));
571                                    }
572
573                                    if let Some(tool_calls) = &choice.delta.tool_calls {
574                                        for tc in tool_calls {
575                                            if let Some(id) = &tc.id {
576                                                if let Some(func) = &tc.function {
577                                                    if let Some(name) = &func.name {
578                                                        chunks.push(StreamChunk::ToolCallStart {
579                                                            id: id.clone(),
580                                                            name: name.clone(),
581                                                        });
582                                                    }
583                                                }
584                                            }
585                                            if let Some(func) = &tc.function {
586                                                if let Some(args) = &func.arguments {
587                                                    if !args.is_empty() {
588                                                        chunks.push(StreamChunk::ToolCallDelta {
589                                                            id: tc.id.clone().unwrap_or_default(),
590                                                            arguments_delta: args.clone(),
591                                                        });
592                                                    }
593                                                }
594                                            }
595                                        }
596                                    }
597
598                                    if choice.finish_reason.is_some() {
599                                        chunks.push(StreamChunk::Done { usage: None });
600                                    }
601                                }
602                            }
603                        }
604                    }
605
606                    if chunks.is_empty() {
607                        StreamChunk::Text(String::new())
608                    } else if chunks.len() == 1 {
609                        chunks.pop().unwrap()
610                    } else {
611                        // Return first chunk, others are lost (simplified)
612                        chunks.remove(0)
613                    }
614                }
615                Err(e) => StreamChunk::Error(e.to_string()),
616            })
617            .boxed();
618
619        Ok(stream)
620    }
621}