Skip to main content

codetether_agent/provider/
stepfun.rs

1//! StepFun provider implementation (direct API, not via OpenRouter)
2//!
3//! StepFun models: step-1-8k, step-1-32k, step-1-128k, step-1-256k, step-3.5-flash
4
5use super::{
6    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7    Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
/// Base URL of the StepFun REST API; `/chat/completions` is appended per request.
const STEPFUN_API_BASE: &str = "https://api.stepfun.ai/v1";

/// [`Provider`] that calls the StepFun chat-completions API directly
/// (not via OpenRouter).
pub struct StepFunProvider {
    /// Bearer token sent in the `Authorization` header. The hand-written
    /// `Debug` impl prints only its length, never the value.
    api_key: String,
    /// HTTP client used for every request made by this provider.
    client: reqwest::Client,
}
20
21impl std::fmt::Debug for StepFunProvider {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("StepFunProvider")
24            .field("api_key", &"<REDACTED>")
25            .field("api_key_len", &self.api_key.len())
26            .field("client", &"<reqwest::Client>")
27            .finish()
28    }
29}
30
31impl StepFunProvider {
32    pub fn new(api_key: String) -> Result<Self> {
33        tracing::debug!(
34            provider = "stepfun",
35            api_key_len = api_key.len(),
36            "Creating StepFun provider"
37        );
38        Ok(Self {
39            api_key,
40            client: reqwest::Client::new(),
41        })
42    }
43
44    /// Validate that the API key is non-empty
45    fn validate_api_key(&self) -> Result<()> {
46        if self.api_key.is_empty() {
47            anyhow::bail!("StepFun API key is empty");
48        }
49        if self.api_key.len() < 10 {
50            tracing::warn!(provider = "stepfun", "API key seems unusually short");
51        }
52        Ok(())
53    }
54}
55
56// ============== Request Types ==============
57
/// JSON body of a `POST {STEPFUN_API_BASE}/chat/completions` request.
///
/// Fields that are `None` are omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// StepFun model id (e.g. `step-1-8k`), passed through from the caller.
    model: String,
    /// Conversation history in wire format.
    messages: Vec<ChatMessage>,
    /// Tool declarations; `None` when the caller supplied no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ChatTool>>,
    /// Sampling temperature, passed through from the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Upper bound on generated tokens, passed through from the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    /// `Some(true)` requests a server-sent-events stream; the non-streaming
    /// path sends `Some(false)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// One message of the wire-format conversation.
///
/// `role` is one of `"system"`, `"user"`, `"assistant"` or `"tool"`, as
/// produced by `convert_messages`.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    /// Text content; omitted from the JSON when `None`. Assistant messages that
    /// carry tool calls but no text are sent with `Some("")`.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Tool invocations requested by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    /// For `"tool"` messages: id of the tool call this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Tool declaration advertised to the model; `type` is always `"function"`
/// in this file.
#[derive(Debug, Serialize)]
struct ChatTool {
    r#type: String,
    function: ChatFunction,
}

/// Function signature of a declared tool.
#[derive(Debug, Serialize)]
struct ChatFunction {
    name: String,
    description: String,
    /// Copied from `ToolDefinition::parameters` (presumably a JSON Schema
    /// object; not validated here).
    parameters: serde_json::Value,
}
95
96// ============== Response Types ==============
97
/// Successful (non-streaming) chat-completions response body.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Response id; only used for debug logging.
    id: String,
    /// Candidate completions; only the first one is used.
    choices: Vec<ChatChoice>,
    /// Token accounting; all counts are treated as zero when absent.
    usage: Option<ChatUsage>,
}

/// One candidate completion.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// Position of this choice; only used for debug logging.
    index: usize,
    message: ChatResponseMessage,
    /// Raw finish reason such as `"stop"`, `"length"` or `"tool_calls"`.
    finish_reason: Option<String>,
}

/// Message body of a choice; both payload fields default to `None` when absent.
#[derive(Debug, Deserialize)]
struct ChatResponseMessage {
    /// Role reported by the API; only used for debug logging.
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}

/// Wire-format tool call, used both in requests (assistant history) and in
/// responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCall {
    id: String,
    /// Serialized as `type`; this file always sends `"function"`.
    r#type: String,
    function: ToolCallFunction,
}

/// Name and arguments of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCallFunction {
    name: String,
    /// Arguments as the string exchanged with the API (presumably JSON-encoded);
    /// passed through without parsing.
    arguments: String,
}

/// Token usage reported by the API.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

/// Error body returned alongside a non-success HTTP status.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ErrorDetail,
}

/// Details of an API error.
#[derive(Debug, Deserialize)]
struct ErrorDetail {
    /// Human-readable message surfaced in the returned error.
    message: String,
    /// Optional error code, logged when present.
    /// NOTE(review): a non-string `code` (e.g. a number) makes the whole body
    /// fail to parse as `ErrorResponse`; the error path in `complete` then
    /// reports the raw body instead.
    #[serde(default)]
    code: Option<String>,
}
152
153// ============== Streaming Types ==============
154
/// Payload of one `data:` line of a streaming (SSE) response.
#[derive(Debug, Deserialize)]
struct StreamChunkResponse {
    /// Streamed choices; only the first one is consumed.
    choices: Vec<StreamChoice>,
}

/// Incremental update for one choice.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    /// When present, the stream is treated as finished.
    finish_reason: Option<String>,
}

/// Newly generated content in this event; both fields default to `None`.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    /// Next fragment of assistant text.
    #[serde(default)]
    content: Option<String>,
    /// Fragments of tool calls being built up.
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}

/// Fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
struct StreamToolCall {
    /// Position of the tool call in the response. In OpenAI-compatible streams,
    /// later fragments of the same call may carry only this index and no `id`
    /// (assumed to apply to StepFun too).
    #[allow(dead_code)]
    index: usize,
    /// Tool-call id, present at least on the first fragment of a call.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamToolFunction>,
}

/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
struct StreamToolFunction {
    /// Function name, expected on the fragment that starts the call.
    #[serde(default)]
    name: Option<String>,
    /// Next piece of the argument string.
    #[serde(default)]
    arguments: Option<String>,
}
191
192impl StepFunProvider {
193    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
194        let mut result = Vec::new();
195
196        for msg in messages {
197            match msg.role {
198                Role::System => {
199                    let content = msg
200                        .content
201                        .iter()
202                        .filter_map(|p| match p {
203                            ContentPart::Text { text } => Some(text.clone()),
204                            _ => None,
205                        })
206                        .collect::<Vec<_>>()
207                        .join("\n");
208                    result.push(ChatMessage {
209                        role: "system".to_string(),
210                        content: Some(content),
211                        tool_calls: None,
212                        tool_call_id: None,
213                    });
214                }
215                Role::User => {
216                    let content = msg
217                        .content
218                        .iter()
219                        .filter_map(|p| match p {
220                            ContentPart::Text { text } => Some(text.clone()),
221                            _ => None,
222                        })
223                        .collect::<Vec<_>>()
224                        .join("\n");
225                    result.push(ChatMessage {
226                        role: "user".to_string(),
227                        content: Some(content),
228                        tool_calls: None,
229                        tool_call_id: None,
230                    });
231                }
232                Role::Assistant => {
233                    let content = msg
234                        .content
235                        .iter()
236                        .filter_map(|p| match p {
237                            ContentPart::Text { text } => Some(text.clone()),
238                            _ => None,
239                        })
240                        .collect::<Vec<_>>()
241                        .join("\n");
242
243                    let tool_calls: Vec<ToolCall> = msg
244                        .content
245                        .iter()
246                        .filter_map(|p| match p {
247                            ContentPart::ToolCall {
248                                id,
249                                name,
250                                arguments,
251                                ..
252                            } => Some(ToolCall {
253                                id: id.clone(),
254                                r#type: "function".to_string(),
255                                function: ToolCallFunction {
256                                    name: name.clone(),
257                                    arguments: arguments.clone(),
258                                },
259                            }),
260                            _ => None,
261                        })
262                        .collect();
263
264                    result.push(ChatMessage {
265                        role: "assistant".to_string(),
266                        // StepFun requires content field to be present (even if empty) when tool_calls exist
267                        content: if content.is_empty() && !tool_calls.is_empty() {
268                            Some(String::new())
269                        } else if content.is_empty() {
270                            None
271                        } else {
272                            Some(content)
273                        },
274                        tool_calls: if tool_calls.is_empty() {
275                            None
276                        } else {
277                            Some(tool_calls)
278                        },
279                        tool_call_id: None,
280                    });
281                }
282                Role::Tool => {
283                    for part in &msg.content {
284                        if let ContentPart::ToolResult {
285                            tool_call_id,
286                            content,
287                        } = part
288                        {
289                            result.push(ChatMessage {
290                                role: "tool".to_string(),
291                                content: Some(content.clone()),
292                                tool_calls: None,
293                                tool_call_id: Some(tool_call_id.clone()),
294                            });
295                        }
296                    }
297                }
298            }
299        }
300
301        result
302    }
303
304    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<ChatTool> {
305        tools
306            .iter()
307            .map(|t| ChatTool {
308                r#type: "function".to_string(),
309                function: ChatFunction {
310                    name: t.name.clone(),
311                    description: t.description.clone(),
312                    parameters: t.parameters.clone(),
313                },
314            })
315            .collect()
316    }
317}
318
319#[async_trait]
320impl Provider for StepFunProvider {
321    fn name(&self) -> &str {
322        "stepfun"
323    }
324
325    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
326        Ok(vec![
327            ModelInfo {
328                id: "step-3.5-flash".to_string(),
329                name: "Step 3.5 Flash".to_string(),
330                provider: "stepfun".to_string(),
331                context_window: 128_000,
332                max_output_tokens: Some(8192),
333                supports_vision: false,
334                supports_tools: true,
335                supports_streaming: true,
336                input_cost_per_million: Some(0.0), // Free tier
337                output_cost_per_million: Some(0.0),
338            },
339            ModelInfo {
340                id: "step-1-8k".to_string(),
341                name: "Step 1 8K".to_string(),
342                provider: "stepfun".to_string(),
343                context_window: 8_000,
344                max_output_tokens: Some(4096),
345                supports_vision: false,
346                supports_tools: true,
347                supports_streaming: true,
348                input_cost_per_million: Some(0.5),
349                output_cost_per_million: Some(1.5),
350            },
351            ModelInfo {
352                id: "step-1-32k".to_string(),
353                name: "Step 1 32K".to_string(),
354                provider: "stepfun".to_string(),
355                context_window: 32_000,
356                max_output_tokens: Some(8192),
357                supports_vision: false,
358                supports_tools: true,
359                supports_streaming: true,
360                input_cost_per_million: Some(1.0),
361                output_cost_per_million: Some(3.0),
362            },
363            ModelInfo {
364                id: "step-1-128k".to_string(),
365                name: "Step 1 128K".to_string(),
366                provider: "stepfun".to_string(),
367                context_window: 128_000,
368                max_output_tokens: Some(8192),
369                supports_vision: false,
370                supports_tools: true,
371                supports_streaming: true,
372                input_cost_per_million: Some(2.0),
373                output_cost_per_million: Some(6.0),
374            },
375            ModelInfo {
376                id: "step-1v-8k".to_string(),
377                name: "Step 1 Vision 8K".to_string(),
378                provider: "stepfun".to_string(),
379                context_window: 8_000,
380                max_output_tokens: Some(4096),
381                supports_vision: true,
382                supports_tools: true,
383                supports_streaming: true,
384                input_cost_per_million: Some(1.0),
385                output_cost_per_million: Some(3.0),
386            },
387        ])
388    }
389
390    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
391        tracing::debug!(
392            provider = "stepfun",
393            model = %request.model,
394            message_count = request.messages.len(),
395            tool_count = request.tools.len(),
396            "Starting completion request"
397        );
398
399        // Validate API key before making request
400        self.validate_api_key()?;
401
402        let messages = self.convert_messages(&request.messages);
403        let tools = self.convert_tools(&request.tools);
404
405        let chat_request = ChatRequest {
406            model: request.model.clone(),
407            messages,
408            tools: if tools.is_empty() { None } else { Some(tools) },
409            temperature: request.temperature,
410            max_tokens: request.max_tokens,
411            stream: Some(false),
412        };
413
414        // Debug: log the request being sent
415        if let Ok(json_str) = serde_json::to_string_pretty(&chat_request) {
416            tracing::debug!("StepFun request: {}", json_str);
417        }
418
419        let response = self
420            .client
421            .post(format!("{}/chat/completions", STEPFUN_API_BASE))
422            .header("Authorization", format!("Bearer {}", self.api_key))
423            .header("Content-Type", "application/json")
424            .json(&chat_request)
425            .send()
426            .await?;
427
428        let status = response.status();
429        let body = response.text().await?;
430
431        if !status.is_success() {
432            if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
433                // Log error code if present for debugging
434                if let Some(ref code) = err.error.code {
435                    tracing::error!(error_code = %code, "StepFun API error code");
436                }
437                anyhow::bail!("StepFun API error: {}", err.error.message);
438            }
439            anyhow::bail!("StepFun API error ({}): {}", status, body);
440        }
441
442        let chat_response: ChatResponse = serde_json::from_str(&body)
443            .map_err(|e| anyhow::anyhow!("Failed to parse response: {} - Body: {}", e, body))?;
444
445        // Log response metadata for debugging
446        tracing::debug!(
447            response_id = %chat_response.id,
448            "Received StepFun response"
449        );
450
451        let choice = chat_response
452            .choices
453            .first()
454            .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
455
456        // Log choice index and role for debugging
457        tracing::debug!(
458            choice_index = choice.index,
459            message_role = %choice.message.role,
460            "Processing StepFun choice"
461        );
462
463        // Log usage for tracing
464        tracing::info!(
465            prompt_tokens = chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
466            completion_tokens = chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
467            finish_reason = ?choice.finish_reason,
468            "StepFun completion received"
469        );
470
471        let mut content = Vec::new();
472        let mut has_tool_calls = false;
473
474        if let Some(text) = &choice.message.content {
475            if !text.is_empty() {
476                content.push(ContentPart::Text { text: text.clone() });
477            }
478        }
479
480        if let Some(tool_calls) = &choice.message.tool_calls {
481            has_tool_calls = !tool_calls.is_empty();
482            for tc in tool_calls {
483                content.push(ContentPart::ToolCall {
484                    id: tc.id.clone(),
485                    name: tc.function.name.clone(),
486                    arguments: tc.function.arguments.clone(),
487                    thought_signature: None,
488                });
489            }
490        }
491
492        let finish_reason = if has_tool_calls {
493            FinishReason::ToolCalls
494        } else {
495            match choice.finish_reason.as_deref() {
496                Some("stop") => FinishReason::Stop,
497                Some("length") => FinishReason::Length,
498                Some("tool_calls") => FinishReason::ToolCalls,
499                _ => FinishReason::Stop,
500            }
501        };
502
503        Ok(CompletionResponse {
504            message: Message {
505                role: Role::Assistant,
506                content,
507            },
508            usage: Usage {
509                prompt_tokens: chat_response
510                    .usage
511                    .as_ref()
512                    .map(|u| u.prompt_tokens)
513                    .unwrap_or(0),
514                completion_tokens: chat_response
515                    .usage
516                    .as_ref()
517                    .map(|u| u.completion_tokens)
518                    .unwrap_or(0),
519                total_tokens: chat_response
520                    .usage
521                    .as_ref()
522                    .map(|u| u.total_tokens)
523                    .unwrap_or(0),
524                ..Default::default()
525            },
526            finish_reason,
527        })
528    }
529
530    async fn complete_stream(
531        &self,
532        request: CompletionRequest,
533    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
534        tracing::debug!(
535            provider = "stepfun",
536            model = %request.model,
537            message_count = request.messages.len(),
538            tool_count = request.tools.len(),
539            "Starting streaming completion request"
540        );
541
542        self.validate_api_key()?;
543
544        let messages = self.convert_messages(&request.messages);
545        let tools = self.convert_tools(&request.tools);
546
547        let chat_request = ChatRequest {
548            model: request.model.clone(),
549            messages,
550            tools: if tools.is_empty() { None } else { Some(tools) },
551            temperature: request.temperature,
552            max_tokens: request.max_tokens,
553            stream: Some(true),
554        };
555
556        let response = self
557            .client
558            .post(format!("{}/chat/completions", STEPFUN_API_BASE))
559            .header("Authorization", format!("Bearer {}", self.api_key))
560            .header("Content-Type", "application/json")
561            .json(&chat_request)
562            .send()
563            .await?;
564
565        if !response.status().is_success() {
566            let status = response.status();
567            let body = response.text().await?;
568            anyhow::bail!("StepFun API error ({}): {}", status, body);
569        }
570
571        let stream = response
572            .bytes_stream()
573            .map(|result| match result {
574                Ok(bytes) => {
575                    let text = String::from_utf8_lossy(&bytes);
576                    let mut chunks = Vec::new();
577
578                    for line in text.lines() {
579                        if let Some(data) = line.strip_prefix("data: ") {
580                            if data.trim() == "[DONE]" {
581                                chunks.push(StreamChunk::Done { usage: None });
582                                continue;
583                            }
584
585                            if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
586                                if let Some(choice) = chunk.choices.first() {
587                                    if let Some(content) = &choice.delta.content {
588                                        chunks.push(StreamChunk::Text(content.clone()));
589                                    }
590
591                                    if let Some(tool_calls) = &choice.delta.tool_calls {
592                                        for tc in tool_calls {
593                                            if let Some(id) = &tc.id {
594                                                if let Some(func) = &tc.function {
595                                                    if let Some(name) = &func.name {
596                                                        chunks.push(StreamChunk::ToolCallStart {
597                                                            id: id.clone(),
598                                                            name: name.clone(),
599                                                        });
600                                                    }
601                                                }
602                                            }
603                                            if let Some(func) = &tc.function {
604                                                if let Some(args) = &func.arguments {
605                                                    if !args.is_empty() {
606                                                        chunks.push(StreamChunk::ToolCallDelta {
607                                                            id: tc.id.clone().unwrap_or_default(),
608                                                            arguments_delta: args.clone(),
609                                                        });
610                                                    }
611                                                }
612                                            }
613                                        }
614                                    }
615
616                                    if choice.finish_reason.is_some() {
617                                        chunks.push(StreamChunk::Done { usage: None });
618                                    }
619                                }
620                            }
621                        }
622                    }
623
624                    if chunks.is_empty() {
625                        StreamChunk::Text(String::new())
626                    } else if chunks.len() == 1 {
627                        chunks.pop().unwrap()
628                    } else {
629                        // Return first chunk, others are lost (simplified)
630                        chunks.remove(0)
631                    }
632                }
633                Err(e) => StreamChunk::Error(e.to_string()),
634            })
635            .boxed();
636
637        Ok(stream)
638    }
639}