1use super::{
6 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7 Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
/// Base URL of the StepFun HTTP API; endpoint paths such as `/chat/completions` are appended to it.
const STEPFUN_API_BASE: &str = "https://api.stepfun.ai/v1";
15
16pub struct StepFunProvider {
17 api_key: String,
18 client: reqwest::Client,
19}
20
21impl std::fmt::Debug for StepFunProvider {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("StepFunProvider")
24 .field("api_key", &"<REDACTED>")
25 .field("api_key_len", &self.api_key.len())
26 .field("client", &"<reqwest::Client>")
27 .finish()
28 }
29}
30
31impl StepFunProvider {
32 pub fn new(api_key: String) -> Result<Self> {
33 tracing::debug!(
34 provider = "stepfun",
35 api_key_len = api_key.len(),
36 "Creating StepFun provider"
37 );
38 Ok(Self {
39 api_key,
40 client: reqwest::Client::new(),
41 })
42 }
43
44 fn validate_api_key(&self) -> Result<()> {
46 if self.api_key.is_empty() {
47 anyhow::bail!("StepFun API key is empty");
48 }
49 if self.api_key.len() < 10 {
50 tracing::warn!(provider = "stepfun", "API key seems unusually short");
51 }
52 Ok(())
53 }
54}
55
56#[derive(Debug, Serialize)]
59struct ChatRequest {
60 model: String,
61 messages: Vec<ChatMessage>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 tools: Option<Vec<ChatTool>>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 temperature: Option<f32>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 max_tokens: Option<usize>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 stream: Option<bool>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73struct ChatMessage {
74 role: String,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 content: Option<String>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 tool_calls: Option<Vec<ToolCall>>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 tool_call_id: Option<String>,
81}
82
83#[derive(Debug, Serialize)]
84struct ChatTool {
85 r#type: String,
86 function: ChatFunction,
87}
88
89#[derive(Debug, Serialize)]
90struct ChatFunction {
91 name: String,
92 description: String,
93 parameters: serde_json::Value,
94}
95
96#[derive(Debug, Deserialize)]
99struct ChatResponse {
100 id: String,
101 choices: Vec<ChatChoice>,
102 usage: Option<ChatUsage>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ChatChoice {
107 index: usize,
108 message: ChatResponseMessage,
109 finish_reason: Option<String>,
110}
111
112#[derive(Debug, Deserialize)]
113struct ChatResponseMessage {
114 role: String,
115 #[serde(default)]
116 content: Option<String>,
117 #[serde(default)]
118 tool_calls: Option<Vec<ToolCall>>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122struct ToolCall {
123 id: String,
124 r#type: String,
125 function: ToolCallFunction,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129struct ToolCallFunction {
130 name: String,
131 arguments: String,
132}
133
134#[derive(Debug, Deserialize)]
135struct ChatUsage {
136 prompt_tokens: usize,
137 completion_tokens: usize,
138 total_tokens: usize,
139}
140
141#[derive(Debug, Deserialize)]
142struct ErrorResponse {
143 error: ErrorDetail,
144}
145
146#[derive(Debug, Deserialize)]
147struct ErrorDetail {
148 message: String,
149 #[serde(default)]
150 code: Option<String>,
151}
152
153#[derive(Debug, Deserialize)]
156struct StreamChunkResponse {
157 choices: Vec<StreamChoice>,
158}
159
160#[derive(Debug, Deserialize)]
161struct StreamChoice {
162 delta: StreamDelta,
163 finish_reason: Option<String>,
164}
165
166#[derive(Debug, Deserialize)]
167struct StreamDelta {
168 #[serde(default)]
169 content: Option<String>,
170 #[serde(default)]
171 tool_calls: Option<Vec<StreamToolCall>>,
172}
173
174#[derive(Debug, Deserialize)]
175struct StreamToolCall {
176 #[allow(dead_code)]
177 index: usize,
178 #[serde(default)]
179 id: Option<String>,
180 #[serde(default)]
181 function: Option<StreamToolFunction>,
182}
183
184#[derive(Debug, Deserialize)]
185struct StreamToolFunction {
186 #[serde(default)]
187 name: Option<String>,
188 #[serde(default)]
189 arguments: Option<String>,
190}
191
192impl StepFunProvider {
193 fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
194 let mut result = Vec::new();
195
196 for msg in messages {
197 match msg.role {
198 Role::System => {
199 let content = msg
200 .content
201 .iter()
202 .filter_map(|p| match p {
203 ContentPart::Text { text } => Some(text.clone()),
204 _ => None,
205 })
206 .collect::<Vec<_>>()
207 .join("\n");
208 result.push(ChatMessage {
209 role: "system".to_string(),
210 content: Some(content),
211 tool_calls: None,
212 tool_call_id: None,
213 });
214 }
215 Role::User => {
216 let content = msg
217 .content
218 .iter()
219 .filter_map(|p| match p {
220 ContentPart::Text { text } => Some(text.clone()),
221 _ => None,
222 })
223 .collect::<Vec<_>>()
224 .join("\n");
225 result.push(ChatMessage {
226 role: "user".to_string(),
227 content: Some(content),
228 tool_calls: None,
229 tool_call_id: None,
230 });
231 }
232 Role::Assistant => {
233 let content = msg
234 .content
235 .iter()
236 .filter_map(|p| match p {
237 ContentPart::Text { text } => Some(text.clone()),
238 _ => None,
239 })
240 .collect::<Vec<_>>()
241 .join("\n");
242
243 let tool_calls: Vec<ToolCall> = msg
244 .content
245 .iter()
246 .filter_map(|p| match p {
247 ContentPart::ToolCall {
248 id,
249 name,
250 arguments,
251 ..
252 } => Some(ToolCall {
253 id: id.clone(),
254 r#type: "function".to_string(),
255 function: ToolCallFunction {
256 name: name.clone(),
257 arguments: arguments.clone(),
258 },
259 }),
260 _ => None,
261 })
262 .collect();
263
264 result.push(ChatMessage {
265 role: "assistant".to_string(),
266 content: if content.is_empty() && !tool_calls.is_empty() {
268 Some(String::new())
269 } else if content.is_empty() {
270 None
271 } else {
272 Some(content)
273 },
274 tool_calls: if tool_calls.is_empty() {
275 None
276 } else {
277 Some(tool_calls)
278 },
279 tool_call_id: None,
280 });
281 }
282 Role::Tool => {
283 for part in &msg.content {
284 if let ContentPart::ToolResult {
285 tool_call_id,
286 content,
287 } = part
288 {
289 result.push(ChatMessage {
290 role: "tool".to_string(),
291 content: Some(content.clone()),
292 tool_calls: None,
293 tool_call_id: Some(tool_call_id.clone()),
294 });
295 }
296 }
297 }
298 }
299 }
300
301 result
302 }
303
304 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<ChatTool> {
305 tools
306 .iter()
307 .map(|t| ChatTool {
308 r#type: "function".to_string(),
309 function: ChatFunction {
310 name: t.name.clone(),
311 description: t.description.clone(),
312 parameters: t.parameters.clone(),
313 },
314 })
315 .collect()
316 }
317}
318
319#[async_trait]
320impl Provider for StepFunProvider {
321 fn name(&self) -> &str {
322 "stepfun"
323 }
324
325 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
326 Ok(vec![
327 ModelInfo {
328 id: "step-3.5-flash".to_string(),
329 name: "Step 3.5 Flash".to_string(),
330 provider: "stepfun".to_string(),
331 context_window: 128_000,
332 max_output_tokens: Some(8192),
333 supports_vision: false,
334 supports_tools: true,
335 supports_streaming: true,
336 input_cost_per_million: Some(0.0), output_cost_per_million: Some(0.0),
338 },
339 ModelInfo {
340 id: "step-1-8k".to_string(),
341 name: "Step 1 8K".to_string(),
342 provider: "stepfun".to_string(),
343 context_window: 8_000,
344 max_output_tokens: Some(4096),
345 supports_vision: false,
346 supports_tools: true,
347 supports_streaming: true,
348 input_cost_per_million: Some(0.5),
349 output_cost_per_million: Some(1.5),
350 },
351 ModelInfo {
352 id: "step-1-32k".to_string(),
353 name: "Step 1 32K".to_string(),
354 provider: "stepfun".to_string(),
355 context_window: 32_000,
356 max_output_tokens: Some(8192),
357 supports_vision: false,
358 supports_tools: true,
359 supports_streaming: true,
360 input_cost_per_million: Some(1.0),
361 output_cost_per_million: Some(3.0),
362 },
363 ModelInfo {
364 id: "step-1-128k".to_string(),
365 name: "Step 1 128K".to_string(),
366 provider: "stepfun".to_string(),
367 context_window: 128_000,
368 max_output_tokens: Some(8192),
369 supports_vision: false,
370 supports_tools: true,
371 supports_streaming: true,
372 input_cost_per_million: Some(2.0),
373 output_cost_per_million: Some(6.0),
374 },
375 ModelInfo {
376 id: "step-1v-8k".to_string(),
377 name: "Step 1 Vision 8K".to_string(),
378 provider: "stepfun".to_string(),
379 context_window: 8_000,
380 max_output_tokens: Some(4096),
381 supports_vision: true,
382 supports_tools: true,
383 supports_streaming: true,
384 input_cost_per_million: Some(1.0),
385 output_cost_per_million: Some(3.0),
386 },
387 ])
388 }
389
390 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
391 tracing::debug!(
392 provider = "stepfun",
393 model = %request.model,
394 message_count = request.messages.len(),
395 tool_count = request.tools.len(),
396 "Starting completion request"
397 );
398
399 self.validate_api_key()?;
401
402 let messages = self.convert_messages(&request.messages);
403 let tools = self.convert_tools(&request.tools);
404
405 let chat_request = ChatRequest {
406 model: request.model.clone(),
407 messages,
408 tools: if tools.is_empty() { None } else { Some(tools) },
409 temperature: request.temperature,
410 max_tokens: request.max_tokens,
411 stream: Some(false),
412 };
413
414 if let Ok(json_str) = serde_json::to_string_pretty(&chat_request) {
416 tracing::debug!("StepFun request: {}", json_str);
417 }
418
419 let response = self
420 .client
421 .post(format!("{}/chat/completions", STEPFUN_API_BASE))
422 .header("Authorization", format!("Bearer {}", self.api_key))
423 .header("Content-Type", "application/json")
424 .json(&chat_request)
425 .send()
426 .await?;
427
428 let status = response.status();
429 let body = response.text().await?;
430
431 if !status.is_success() {
432 if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
433 if let Some(ref code) = err.error.code {
435 tracing::error!(error_code = %code, "StepFun API error code");
436 }
437 anyhow::bail!("StepFun API error: {}", err.error.message);
438 }
439 anyhow::bail!("StepFun API error ({}): {}", status, body);
440 }
441
442 let chat_response: ChatResponse = serde_json::from_str(&body)
443 .map_err(|e| anyhow::anyhow!("Failed to parse response: {} - Body: {}", e, body))?;
444
445 tracing::debug!(
447 response_id = %chat_response.id,
448 "Received StepFun response"
449 );
450
451 let choice = chat_response
452 .choices
453 .first()
454 .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
455
456 tracing::debug!(
458 choice_index = choice.index,
459 message_role = %choice.message.role,
460 "Processing StepFun choice"
461 );
462
463 tracing::info!(
465 prompt_tokens = chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
466 completion_tokens = chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
467 finish_reason = ?choice.finish_reason,
468 "StepFun completion received"
469 );
470
471 let mut content = Vec::new();
472 let mut has_tool_calls = false;
473
474 if let Some(text) = &choice.message.content {
475 if !text.is_empty() {
476 content.push(ContentPart::Text { text: text.clone() });
477 }
478 }
479
480 if let Some(tool_calls) = &choice.message.tool_calls {
481 has_tool_calls = !tool_calls.is_empty();
482 for tc in tool_calls {
483 content.push(ContentPart::ToolCall {
484 id: tc.id.clone(),
485 name: tc.function.name.clone(),
486 arguments: tc.function.arguments.clone(),
487 thought_signature: None,
488 });
489 }
490 }
491
492 let finish_reason = if has_tool_calls {
493 FinishReason::ToolCalls
494 } else {
495 match choice.finish_reason.as_deref() {
496 Some("stop") => FinishReason::Stop,
497 Some("length") => FinishReason::Length,
498 Some("tool_calls") => FinishReason::ToolCalls,
499 _ => FinishReason::Stop,
500 }
501 };
502
503 Ok(CompletionResponse {
504 message: Message {
505 role: Role::Assistant,
506 content,
507 },
508 usage: Usage {
509 prompt_tokens: chat_response
510 .usage
511 .as_ref()
512 .map(|u| u.prompt_tokens)
513 .unwrap_or(0),
514 completion_tokens: chat_response
515 .usage
516 .as_ref()
517 .map(|u| u.completion_tokens)
518 .unwrap_or(0),
519 total_tokens: chat_response
520 .usage
521 .as_ref()
522 .map(|u| u.total_tokens)
523 .unwrap_or(0),
524 ..Default::default()
525 },
526 finish_reason,
527 })
528 }
529
530 async fn complete_stream(
531 &self,
532 request: CompletionRequest,
533 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
534 tracing::debug!(
535 provider = "stepfun",
536 model = %request.model,
537 message_count = request.messages.len(),
538 tool_count = request.tools.len(),
539 "Starting streaming completion request"
540 );
541
542 self.validate_api_key()?;
543
544 let messages = self.convert_messages(&request.messages);
545 let tools = self.convert_tools(&request.tools);
546
547 let chat_request = ChatRequest {
548 model: request.model.clone(),
549 messages,
550 tools: if tools.is_empty() { None } else { Some(tools) },
551 temperature: request.temperature,
552 max_tokens: request.max_tokens,
553 stream: Some(true),
554 };
555
556 let response = self
557 .client
558 .post(format!("{}/chat/completions", STEPFUN_API_BASE))
559 .header("Authorization", format!("Bearer {}", self.api_key))
560 .header("Content-Type", "application/json")
561 .json(&chat_request)
562 .send()
563 .await?;
564
565 if !response.status().is_success() {
566 let status = response.status();
567 let body = response.text().await?;
568 anyhow::bail!("StepFun API error ({}): {}", status, body);
569 }
570
571 let stream = response
572 .bytes_stream()
573 .map(|result| match result {
574 Ok(bytes) => {
575 let text = String::from_utf8_lossy(&bytes);
576 let mut chunks = Vec::new();
577
578 for line in text.lines() {
579 if let Some(data) = line.strip_prefix("data: ") {
580 if data.trim() == "[DONE]" {
581 chunks.push(StreamChunk::Done { usage: None });
582 continue;
583 }
584
585 if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
586 if let Some(choice) = chunk.choices.first() {
587 if let Some(content) = &choice.delta.content {
588 chunks.push(StreamChunk::Text(content.clone()));
589 }
590
591 if let Some(tool_calls) = &choice.delta.tool_calls {
592 for tc in tool_calls {
593 if let Some(id) = &tc.id {
594 if let Some(func) = &tc.function {
595 if let Some(name) = &func.name {
596 chunks.push(StreamChunk::ToolCallStart {
597 id: id.clone(),
598 name: name.clone(),
599 });
600 }
601 }
602 }
603 if let Some(func) = &tc.function {
604 if let Some(args) = &func.arguments {
605 if !args.is_empty() {
606 chunks.push(StreamChunk::ToolCallDelta {
607 id: tc.id.clone().unwrap_or_default(),
608 arguments_delta: args.clone(),
609 });
610 }
611 }
612 }
613 }
614 }
615
616 if choice.finish_reason.is_some() {
617 chunks.push(StreamChunk::Done { usage: None });
618 }
619 }
620 }
621 }
622 }
623
624 if chunks.is_empty() {
625 StreamChunk::Text(String::new())
626 } else if chunks.len() == 1 {
627 chunks.pop().unwrap()
628 } else {
629 chunks.remove(0)
631 }
632 }
633 Err(e) => StreamChunk::Error(e.to_string()),
634 })
635 .boxed();
636
637 Ok(stream)
638 }
639}