Skip to main content

codetether_agent/provider/
stepfun.rs

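//! StepFun provider implementation (direct API, not via OpenRouter).
//!
//! Models catalogued by `list_models`: step-3.5-flash, step-1-8k, step-1-32k, step-1-128k,
//! step-1v-8k. Other StepFun model ids (e.g. step-1-256k) can still be requested by name,
//! because the requested model id is forwarded to the API unchanged.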
4
5use super::{
6    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7    Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
14const STEPFUN_API_BASE: &str = "https://api.stepfun.ai/v1";
15
16pub struct StepFunProvider {
17    api_key: String,
18    client: reqwest::Client,
19}
20
21impl std::fmt::Debug for StepFunProvider {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("StepFunProvider")
24            .field("api_key", &"<REDACTED>")
25            .field("api_key_len", &self.api_key.len())
26            .field("client", &"<reqwest::Client>")
27            .finish()
28    }
29}
30
31impl StepFunProvider {
32    pub fn new(api_key: String) -> Result<Self> {
33        tracing::debug!(
34            provider = "stepfun",
35            api_key_len = api_key.len(),
36            "Creating StepFun provider"
37        );
38        Ok(Self {
39            api_key,
40            client: reqwest::Client::new(),
41        })
42    }
43
44    /// Validate that the API key is non-empty
45    fn validate_api_key(&self) -> Result<()> {
46        if self.api_key.is_empty() {
47            anyhow::bail!("StepFun API key is empty");
48        }
49        if self.api_key.len() < 10 {
50            tracing::warn!(provider = "stepfun", "API key seems unusually short");
51        }
52        Ok(())
53    }
54}
55
56// ============== Request Types ==============
57
58#[derive(Debug, Serialize)]
59struct ChatRequest {
60    model: String,
61    messages: Vec<ChatMessage>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    tools: Option<Vec<ChatTool>>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    temperature: Option<f32>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    max_tokens: Option<usize>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    stream: Option<bool>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73struct ChatMessage {
74    role: String,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    content: Option<String>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    tool_calls: Option<Vec<ToolCall>>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    tool_call_id: Option<String>,
81}
82
83#[derive(Debug, Serialize)]
84struct ChatTool {
85    r#type: String,
86    function: ChatFunction,
87}
88
89#[derive(Debug, Serialize)]
90struct ChatFunction {
91    name: String,
92    description: String,
93    parameters: serde_json::Value,
94}
95
96// ============== Response Types ==============
97
98#[derive(Debug, Deserialize)]
99struct ChatResponse {
100    id: String,
101    choices: Vec<ChatChoice>,
102    usage: Option<ChatUsage>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ChatChoice {
107    index: usize,
108    message: ChatResponseMessage,
109    finish_reason: Option<String>,
110}
111
112#[derive(Debug, Deserialize)]
113struct ChatResponseMessage {
114    role: String,
115    #[serde(default)]
116    content: Option<String>,
117    #[serde(default)]
118    tool_calls: Option<Vec<ToolCall>>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122struct ToolCall {
123    id: String,
124    r#type: String,
125    function: ToolCallFunction,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129struct ToolCallFunction {
130    name: String,
131    arguments: String,
132}
133
134#[derive(Debug, Deserialize)]
135struct ChatUsage {
136    prompt_tokens: usize,
137    completion_tokens: usize,
138    total_tokens: usize,
139}
140
141#[derive(Debug, Deserialize)]
142struct ErrorResponse {
143    error: ErrorDetail,
144}
145
146#[derive(Debug, Deserialize)]
147struct ErrorDetail {
148    message: String,
149    #[serde(default)]
150    code: Option<String>,
151}
152
153// ============== Streaming Types ==============
154
155#[derive(Debug, Deserialize)]
156struct StreamChunkResponse {
157    choices: Vec<StreamChoice>,
158}
159
160#[derive(Debug, Deserialize)]
161struct StreamChoice {
162    delta: StreamDelta,
163    finish_reason: Option<String>,
164}
165
166#[derive(Debug, Deserialize)]
167struct StreamDelta {
168    #[serde(default)]
169    content: Option<String>,
170    #[serde(default)]
171    tool_calls: Option<Vec<StreamToolCall>>,
172}
173
174#[derive(Debug, Deserialize)]
175struct StreamToolCall {
176    #[allow(dead_code)]
177    index: usize,
178    #[serde(default)]
179    id: Option<String>,
180    #[serde(default)]
181    function: Option<StreamToolFunction>,
182}
183
184#[derive(Debug, Deserialize)]
185struct StreamToolFunction {
186    #[serde(default)]
187    name: Option<String>,
188    #[serde(default)]
189    arguments: Option<String>,
190}
191
192impl StepFunProvider {
193    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
194        let mut result = Vec::new();
195
196        for msg in messages {
197            match msg.role {
198                Role::System => {
199                    let content = msg
200                        .content
201                        .iter()
202                        .filter_map(|p| match p {
203                            ContentPart::Text { text } => Some(text.clone()),
204                            _ => None,
205                        })
206                        .collect::<Vec<_>>()
207                        .join("\n");
208                    result.push(ChatMessage {
209                        role: "system".to_string(),
210                        content: Some(content),
211                        tool_calls: None,
212                        tool_call_id: None,
213                    });
214                }
215                Role::User => {
216                    let content = msg
217                        .content
218                        .iter()
219                        .filter_map(|p| match p {
220                            ContentPart::Text { text } => Some(text.clone()),
221                            _ => None,
222                        })
223                        .collect::<Vec<_>>()
224                        .join("\n");
225                    result.push(ChatMessage {
226                        role: "user".to_string(),
227                        content: Some(content),
228                        tool_calls: None,
229                        tool_call_id: None,
230                    });
231                }
232                Role::Assistant => {
233                    let content = msg
234                        .content
235                        .iter()
236                        .filter_map(|p| match p {
237                            ContentPart::Text { text } => Some(text.clone()),
238                            _ => None,
239                        })
240                        .collect::<Vec<_>>()
241                        .join("\n");
242
243                    let tool_calls: Vec<ToolCall> = msg
244                        .content
245                        .iter()
246                        .filter_map(|p| match p {
247                            ContentPart::ToolCall {
248                                id,
249                                name,
250                                arguments,
251                            } => Some(ToolCall {
252                                id: id.clone(),
253                                r#type: "function".to_string(),
254                                function: ToolCallFunction {
255                                    name: name.clone(),
256                                    arguments: arguments.clone(),
257                                },
258                            }),
259                            _ => None,
260                        })
261                        .collect();
262
263                    result.push(ChatMessage {
264                        role: "assistant".to_string(),
265                        // StepFun requires content field to be present (even if empty) when tool_calls exist
266                        content: if content.is_empty() && !tool_calls.is_empty() {
267                            Some(String::new())
268                        } else if content.is_empty() {
269                            None
270                        } else {
271                            Some(content)
272                        },
273                        tool_calls: if tool_calls.is_empty() {
274                            None
275                        } else {
276                            Some(tool_calls)
277                        },
278                        tool_call_id: None,
279                    });
280                }
281                Role::Tool => {
282                    for part in &msg.content {
283                        if let ContentPart::ToolResult {
284                            tool_call_id,
285                            content,
286                        } = part
287                        {
288                            result.push(ChatMessage {
289                                role: "tool".to_string(),
290                                content: Some(content.clone()),
291                                tool_calls: None,
292                                tool_call_id: Some(tool_call_id.clone()),
293                            });
294                        }
295                    }
296                }
297            }
298        }
299
300        result
301    }
302
303    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<ChatTool> {
304        tools
305            .iter()
306            .map(|t| ChatTool {
307                r#type: "function".to_string(),
308                function: ChatFunction {
309                    name: t.name.clone(),
310                    description: t.description.clone(),
311                    parameters: t.parameters.clone(),
312                },
313            })
314            .collect()
315    }
316}
317
318#[async_trait]
319impl Provider for StepFunProvider {
320    fn name(&self) -> &str {
321        "stepfun"
322    }
323
324    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
325        Ok(vec![
326            ModelInfo {
327                id: "step-3.5-flash".to_string(),
328                name: "Step 3.5 Flash".to_string(),
329                provider: "stepfun".to_string(),
330                context_window: 128_000,
331                max_output_tokens: Some(8192),
332                supports_vision: false,
333                supports_tools: true,
334                supports_streaming: true,
335                input_cost_per_million: Some(0.0), // Free tier
336                output_cost_per_million: Some(0.0),
337            },
338            ModelInfo {
339                id: "step-1-8k".to_string(),
340                name: "Step 1 8K".to_string(),
341                provider: "stepfun".to_string(),
342                context_window: 8_000,
343                max_output_tokens: Some(4096),
344                supports_vision: false,
345                supports_tools: true,
346                supports_streaming: true,
347                input_cost_per_million: Some(0.5),
348                output_cost_per_million: Some(1.5),
349            },
350            ModelInfo {
351                id: "step-1-32k".to_string(),
352                name: "Step 1 32K".to_string(),
353                provider: "stepfun".to_string(),
354                context_window: 32_000,
355                max_output_tokens: Some(8192),
356                supports_vision: false,
357                supports_tools: true,
358                supports_streaming: true,
359                input_cost_per_million: Some(1.0),
360                output_cost_per_million: Some(3.0),
361            },
362            ModelInfo {
363                id: "step-1-128k".to_string(),
364                name: "Step 1 128K".to_string(),
365                provider: "stepfun".to_string(),
366                context_window: 128_000,
367                max_output_tokens: Some(8192),
368                supports_vision: false,
369                supports_tools: true,
370                supports_streaming: true,
371                input_cost_per_million: Some(2.0),
372                output_cost_per_million: Some(6.0),
373            },
374            ModelInfo {
375                id: "step-1v-8k".to_string(),
376                name: "Step 1 Vision 8K".to_string(),
377                provider: "stepfun".to_string(),
378                context_window: 8_000,
379                max_output_tokens: Some(4096),
380                supports_vision: true,
381                supports_tools: true,
382                supports_streaming: true,
383                input_cost_per_million: Some(1.0),
384                output_cost_per_million: Some(3.0),
385            },
386        ])
387    }
388
389    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
390        tracing::debug!(
391            provider = "stepfun",
392            model = %request.model,
393            message_count = request.messages.len(),
394            tool_count = request.tools.len(),
395            "Starting completion request"
396        );
397
398        // Validate API key before making request
399        self.validate_api_key()?;
400
401        let messages = self.convert_messages(&request.messages);
402        let tools = self.convert_tools(&request.tools);
403
404        let chat_request = ChatRequest {
405            model: request.model.clone(),
406            messages,
407            tools: if tools.is_empty() { None } else { Some(tools) },
408            temperature: request.temperature,
409            max_tokens: request.max_tokens,
410            stream: Some(false),
411        };
412
413        // Debug: log the request being sent
414        if let Ok(json_str) = serde_json::to_string_pretty(&chat_request) {
415            tracing::debug!("StepFun request: {}", json_str);
416        }
417
418        let response = self
419            .client
420            .post(format!("{}/chat/completions", STEPFUN_API_BASE))
421            .header("Authorization", format!("Bearer {}", self.api_key))
422            .header("Content-Type", "application/json")
423            .json(&chat_request)
424            .send()
425            .await?;
426
427        let status = response.status();
428        let body = response.text().await?;
429
430        if !status.is_success() {
431            if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
432                // Log error code if present for debugging
433                if let Some(ref code) = err.error.code {
434                    tracing::error!(error_code = %code, "StepFun API error code");
435                }
436                anyhow::bail!("StepFun API error: {}", err.error.message);
437            }
438            anyhow::bail!("StepFun API error ({}): {}", status, body);
439        }
440
441        let chat_response: ChatResponse = serde_json::from_str(&body)
442            .map_err(|e| anyhow::anyhow!("Failed to parse response: {} - Body: {}", e, body))?;
443
444        // Log response metadata for debugging
445        tracing::debug!(
446            response_id = %chat_response.id,
447            "Received StepFun response"
448        );
449
450        let choice = chat_response
451            .choices
452            .first()
453            .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
454
455        // Log choice index and role for debugging
456        tracing::debug!(
457            choice_index = choice.index,
458            message_role = %choice.message.role,
459            "Processing StepFun choice"
460        );
461
462        // Log usage for tracing
463        tracing::info!(
464            prompt_tokens = chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
465            completion_tokens = chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
466            finish_reason = ?choice.finish_reason,
467            "StepFun completion received"
468        );
469
470        let mut content = Vec::new();
471        let mut has_tool_calls = false;
472
473        if let Some(text) = &choice.message.content {
474            if !text.is_empty() {
475                content.push(ContentPart::Text { text: text.clone() });
476            }
477        }
478
479        if let Some(tool_calls) = &choice.message.tool_calls {
480            has_tool_calls = !tool_calls.is_empty();
481            for tc in tool_calls {
482                content.push(ContentPart::ToolCall {
483                    id: tc.id.clone(),
484                    name: tc.function.name.clone(),
485                    arguments: tc.function.arguments.clone(),
486                });
487            }
488        }
489
490        let finish_reason = if has_tool_calls {
491            FinishReason::ToolCalls
492        } else {
493            match choice.finish_reason.as_deref() {
494                Some("stop") => FinishReason::Stop,
495                Some("length") => FinishReason::Length,
496                Some("tool_calls") => FinishReason::ToolCalls,
497                _ => FinishReason::Stop,
498            }
499        };
500
501        Ok(CompletionResponse {
502            message: Message {
503                role: Role::Assistant,
504                content,
505            },
506            usage: Usage {
507                prompt_tokens: chat_response
508                    .usage
509                    .as_ref()
510                    .map(|u| u.prompt_tokens)
511                    .unwrap_or(0),
512                completion_tokens: chat_response
513                    .usage
514                    .as_ref()
515                    .map(|u| u.completion_tokens)
516                    .unwrap_or(0),
517                total_tokens: chat_response
518                    .usage
519                    .as_ref()
520                    .map(|u| u.total_tokens)
521                    .unwrap_or(0),
522                ..Default::default()
523            },
524            finish_reason,
525        })
526    }
527
528    async fn complete_stream(
529        &self,
530        request: CompletionRequest,
531    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
532        tracing::debug!(
533            provider = "stepfun",
534            model = %request.model,
535            message_count = request.messages.len(),
536            tool_count = request.tools.len(),
537            "Starting streaming completion request"
538        );
539
540        self.validate_api_key()?;
541
542        let messages = self.convert_messages(&request.messages);
543        let tools = self.convert_tools(&request.tools);
544
545        let chat_request = ChatRequest {
546            model: request.model.clone(),
547            messages,
548            tools: if tools.is_empty() { None } else { Some(tools) },
549            temperature: request.temperature,
550            max_tokens: request.max_tokens,
551            stream: Some(true),
552        };
553
554        let response = self
555            .client
556            .post(format!("{}/chat/completions", STEPFUN_API_BASE))
557            .header("Authorization", format!("Bearer {}", self.api_key))
558            .header("Content-Type", "application/json")
559            .json(&chat_request)
560            .send()
561            .await?;
562
563        if !response.status().is_success() {
564            let status = response.status();
565            let body = response.text().await?;
566            anyhow::bail!("StepFun API error ({}): {}", status, body);
567        }
568
569        let stream = response
570            .bytes_stream()
571            .map(|result| match result {
572                Ok(bytes) => {
573                    let text = String::from_utf8_lossy(&bytes);
574                    let mut chunks = Vec::new();
575
576                    for line in text.lines() {
577                        if let Some(data) = line.strip_prefix("data: ") {
578                            if data.trim() == "[DONE]" {
579                                chunks.push(StreamChunk::Done { usage: None });
580                                continue;
581                            }
582
583                            if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
584                                if let Some(choice) = chunk.choices.first() {
585                                    if let Some(content) = &choice.delta.content {
586                                        chunks.push(StreamChunk::Text(content.clone()));
587                                    }
588
589                                    if let Some(tool_calls) = &choice.delta.tool_calls {
590                                        for tc in tool_calls {
591                                            if let Some(id) = &tc.id {
592                                                if let Some(func) = &tc.function {
593                                                    if let Some(name) = &func.name {
594                                                        chunks.push(StreamChunk::ToolCallStart {
595                                                            id: id.clone(),
596                                                            name: name.clone(),
597                                                        });
598                                                    }
599                                                }
600                                            }
601                                            if let Some(func) = &tc.function {
602                                                if let Some(args) = &func.arguments {
603                                                    if !args.is_empty() {
604                                                        chunks.push(StreamChunk::ToolCallDelta {
605                                                            id: tc.id.clone().unwrap_or_default(),
606                                                            arguments_delta: args.clone(),
607                                                        });
608                                                    }
609                                                }
610                                            }
611                                        }
612                                    }
613
614                                    if choice.finish_reason.is_some() {
615                                        chunks.push(StreamChunk::Done { usage: None });
616                                    }
617                                }
618                            }
619                        }
620                    }
621
622                    if chunks.is_empty() {
623                        StreamChunk::Text(String::new())
624                    } else if chunks.len() == 1 {
625                        chunks.pop().unwrap()
626                    } else {
627                        // Return first chunk, others are lost (simplified)
628                        chunks.remove(0)
629                    }
630                }
631                Err(e) => StreamChunk::Error(e.to_string()),
632            })
633            .boxed();
634
635        Ok(stream)
636    }
637}