1use super::{
6 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7 Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
/// Base URL for StepFun API requests; the chat endpoint is formed by appending
/// `/chat/completions` to it.
const STEPFUN_API_BASE: &str = "https://api.stepfun.ai/v1";
15
/// Chat-completion provider backed by the StepFun HTTP API.
///
/// `Debug` is implemented by hand so the API key itself is never printed.
pub struct StepFunProvider {
    /// Sent as a `Bearer` token in the `Authorization` header of every request.
    api_key: String,
    /// HTTP client shared by all requests made through this provider.
    client: reqwest::Client,
}
20
21impl std::fmt::Debug for StepFunProvider {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("StepFunProvider")
24 .field("api_key", &"<REDACTED>")
25 .field("api_key_len", &self.api_key.len())
26 .field("client", &"<reqwest::Client>")
27 .finish()
28 }
29}
30
31impl StepFunProvider {
32 pub fn new(api_key: String) -> Result<Self> {
33 tracing::debug!(
34 provider = "stepfun",
35 api_key_len = api_key.len(),
36 "Creating StepFun provider"
37 );
38 Ok(Self {
39 api_key,
40 client: reqwest::Client::new(),
41 })
42 }
43
44 fn validate_api_key(&self) -> Result<()> {
46 if self.api_key.is_empty() {
47 anyhow::bail!("StepFun API key is empty");
48 }
49 if self.api_key.len() < 10 {
50 tracing::warn!(provider = "stepfun", "API key seems unusually short");
51 }
52 Ok(())
53 }
54}
55
/// JSON body sent to `POST {STEPFUN_API_BASE}/chat/completions`.
///
/// Optional fields set to `None` are omitted from the JSON instead of being sent as `null`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model id, forwarded from `CompletionRequest::model`.
    model: String,
    /// Conversation history, already converted by `convert_messages`.
    messages: Vec<ChatMessage>,
    /// Tools offered to the model; `None` when the caller supplied none.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ChatTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    /// `Some(true)` for the streaming (`data:` line) mode, `Some(false)` otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
71
/// One message of the request conversation.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    /// One of `"system"`, `"user"`, `"assistant"`, or `"tool"` as produced by `convert_messages`.
    role: String,
    /// Text content; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Tool invocations recorded on an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    /// On `"tool"` messages: id of the tool call this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
82
/// A tool offered to the model in the request.
#[derive(Debug, Serialize)]
struct ChatTool {
    /// Serialized as `type`; this file always sets it to `"function"`.
    r#type: String,
    function: ChatFunction,
}
88
/// Function signature of a [`ChatTool`], copied from a `ToolDefinition`.
#[derive(Debug, Serialize)]
struct ChatFunction {
    name: String,
    description: String,
    /// Parameter description passed through verbatim
    /// (presumably a JSON Schema object — confirm against `ToolDefinition`).
    parameters: serde_json::Value,
}
95
/// Body of a successful non-streaming completion response.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Response id; in this file it is only logged.
    id: String,
    /// Candidate completions; only the first one is used.
    choices: Vec<ChatChoice>,
    /// Token accounting; missing values are reported as zero.
    usage: Option<ChatUsage>,
}
104
/// One candidate completion in a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// Position of the choice; only logged.
    index: usize,
    message: ChatResponseMessage,
    /// Raw finish reason; `"stop"`, `"length"` and `"tool_calls"` are mapped explicitly,
    /// anything else is treated as a normal stop.
    finish_reason: Option<String>,
}
111
/// Assistant message returned inside a [`ChatChoice`].
#[derive(Debug, Deserialize)]
struct ChatResponseMessage {
    /// Role reported by the API; only logged.
    role: String,
    /// Text content; may be absent or `null`, and empty strings are dropped by the caller.
    #[serde(default)]
    content: Option<String>,
    /// Tool calls requested by the model; their presence forces a `ToolCalls` finish reason.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
120
/// A complete tool call, used both when sending assistant history and when
/// reading a non-streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCall {
    id: String,
    /// Serialized as `type`; set to `"function"` when built by this file.
    r#type: String,
    function: ToolCallFunction,
}
127
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCallFunction {
    name: String,
    /// Arguments as an opaque string (presumably JSON-encoded — not parsed here).
    arguments: String,
}
133
/// Token counts reported with a non-streaming response.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}
140
/// `{"error": {...}}` envelope that the non-streaming path tries to parse from a
/// non-success response body.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ErrorDetail,
}
145
/// Details of an API error.
#[derive(Debug, Deserialize)]
struct ErrorDetail {
    /// Human-readable message, used as the returned error text.
    message: String,
    /// Optional error code; logged when present.
    /// NOTE(review): if the API ever sends a numeric code, the whole envelope fails
    /// to parse and the caller falls back to reporting the raw body — confirm type.
    #[serde(default)]
    code: Option<String>,
}
152
/// JSON payload of one `data:` line in streaming mode.
#[derive(Debug, Deserialize)]
struct StreamChunkResponse {
    /// Only the first choice is consumed.
    choices: Vec<StreamChoice>,
}
159
/// Incremental update for one choice in a streamed response.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    /// Any `Some` value is treated as the end of the completion.
    finish_reason: Option<String>,
}
165
/// Newly generated content carried by one streamed event.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    /// Text fragment to append to the output.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments.
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}
173
/// Fragment of a tool call in a streamed event.
#[derive(Debug, Deserialize)]
struct StreamToolCall {
    /// Position of the tool call within the message. Presumably shared by all
    /// fragments of the same call, so it can associate later fragments with
    /// their call. `dead_code` is allowed because this field may be unread.
    #[allow(dead_code)]
    index: usize,
    /// Tool-call id; optional because not every fragment carries it.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamToolFunction>,
}
183
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
struct StreamToolFunction {
    /// Function name; together with an id it signals the start of a tool call.
    #[serde(default)]
    name: Option<String>,
    /// Piece of the argument string; non-empty pieces are forwarded as deltas.
    #[serde(default)]
    arguments: Option<String>,
}
191
192impl StepFunProvider {
193 fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
194 let mut result = Vec::new();
195
196 for msg in messages {
197 match msg.role {
198 Role::System => {
199 let content = msg
200 .content
201 .iter()
202 .filter_map(|p| match p {
203 ContentPart::Text { text } => Some(text.clone()),
204 _ => None,
205 })
206 .collect::<Vec<_>>()
207 .join("\n");
208 result.push(ChatMessage {
209 role: "system".to_string(),
210 content: Some(content),
211 tool_calls: None,
212 tool_call_id: None,
213 });
214 }
215 Role::User => {
216 let content = msg
217 .content
218 .iter()
219 .filter_map(|p| match p {
220 ContentPart::Text { text } => Some(text.clone()),
221 _ => None,
222 })
223 .collect::<Vec<_>>()
224 .join("\n");
225 result.push(ChatMessage {
226 role: "user".to_string(),
227 content: Some(content),
228 tool_calls: None,
229 tool_call_id: None,
230 });
231 }
232 Role::Assistant => {
233 let content = msg
234 .content
235 .iter()
236 .filter_map(|p| match p {
237 ContentPart::Text { text } => Some(text.clone()),
238 _ => None,
239 })
240 .collect::<Vec<_>>()
241 .join("\n");
242
243 let tool_calls: Vec<ToolCall> = msg
244 .content
245 .iter()
246 .filter_map(|p| match p {
247 ContentPart::ToolCall {
248 id,
249 name,
250 arguments,
251 } => Some(ToolCall {
252 id: id.clone(),
253 r#type: "function".to_string(),
254 function: ToolCallFunction {
255 name: name.clone(),
256 arguments: arguments.clone(),
257 },
258 }),
259 _ => None,
260 })
261 .collect();
262
263 result.push(ChatMessage {
264 role: "assistant".to_string(),
265 content: if content.is_empty() && !tool_calls.is_empty() {
267 Some(String::new())
268 } else if content.is_empty() {
269 None
270 } else {
271 Some(content)
272 },
273 tool_calls: if tool_calls.is_empty() {
274 None
275 } else {
276 Some(tool_calls)
277 },
278 tool_call_id: None,
279 });
280 }
281 Role::Tool => {
282 for part in &msg.content {
283 if let ContentPart::ToolResult {
284 tool_call_id,
285 content,
286 } = part
287 {
288 result.push(ChatMessage {
289 role: "tool".to_string(),
290 content: Some(content.clone()),
291 tool_calls: None,
292 tool_call_id: Some(tool_call_id.clone()),
293 });
294 }
295 }
296 }
297 }
298 }
299
300 result
301 }
302
303 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<ChatTool> {
304 tools
305 .iter()
306 .map(|t| ChatTool {
307 r#type: "function".to_string(),
308 function: ChatFunction {
309 name: t.name.clone(),
310 description: t.description.clone(),
311 parameters: t.parameters.clone(),
312 },
313 })
314 .collect()
315 }
316}
317
318#[async_trait]
319impl Provider for StepFunProvider {
320 fn name(&self) -> &str {
321 "stepfun"
322 }
323
324 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
325 Ok(vec![
326 ModelInfo {
327 id: "step-3.5-flash".to_string(),
328 name: "Step 3.5 Flash".to_string(),
329 provider: "stepfun".to_string(),
330 context_window: 128_000,
331 max_output_tokens: Some(8192),
332 supports_vision: false,
333 supports_tools: true,
334 supports_streaming: true,
335 input_cost_per_million: Some(0.0), output_cost_per_million: Some(0.0),
337 },
338 ModelInfo {
339 id: "step-1-8k".to_string(),
340 name: "Step 1 8K".to_string(),
341 provider: "stepfun".to_string(),
342 context_window: 8_000,
343 max_output_tokens: Some(4096),
344 supports_vision: false,
345 supports_tools: true,
346 supports_streaming: true,
347 input_cost_per_million: Some(0.5),
348 output_cost_per_million: Some(1.5),
349 },
350 ModelInfo {
351 id: "step-1-32k".to_string(),
352 name: "Step 1 32K".to_string(),
353 provider: "stepfun".to_string(),
354 context_window: 32_000,
355 max_output_tokens: Some(8192),
356 supports_vision: false,
357 supports_tools: true,
358 supports_streaming: true,
359 input_cost_per_million: Some(1.0),
360 output_cost_per_million: Some(3.0),
361 },
362 ModelInfo {
363 id: "step-1-128k".to_string(),
364 name: "Step 1 128K".to_string(),
365 provider: "stepfun".to_string(),
366 context_window: 128_000,
367 max_output_tokens: Some(8192),
368 supports_vision: false,
369 supports_tools: true,
370 supports_streaming: true,
371 input_cost_per_million: Some(2.0),
372 output_cost_per_million: Some(6.0),
373 },
374 ModelInfo {
375 id: "step-1v-8k".to_string(),
376 name: "Step 1 Vision 8K".to_string(),
377 provider: "stepfun".to_string(),
378 context_window: 8_000,
379 max_output_tokens: Some(4096),
380 supports_vision: true,
381 supports_tools: true,
382 supports_streaming: true,
383 input_cost_per_million: Some(1.0),
384 output_cost_per_million: Some(3.0),
385 },
386 ])
387 }
388
389 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
390 tracing::debug!(
391 provider = "stepfun",
392 model = %request.model,
393 message_count = request.messages.len(),
394 tool_count = request.tools.len(),
395 "Starting completion request"
396 );
397
398 self.validate_api_key()?;
400
401 let messages = self.convert_messages(&request.messages);
402 let tools = self.convert_tools(&request.tools);
403
404 let chat_request = ChatRequest {
405 model: request.model.clone(),
406 messages,
407 tools: if tools.is_empty() { None } else { Some(tools) },
408 temperature: request.temperature,
409 max_tokens: request.max_tokens,
410 stream: Some(false),
411 };
412
413 if let Ok(json_str) = serde_json::to_string_pretty(&chat_request) {
415 tracing::debug!("StepFun request: {}", json_str);
416 }
417
418 let response = self
419 .client
420 .post(format!("{}/chat/completions", STEPFUN_API_BASE))
421 .header("Authorization", format!("Bearer {}", self.api_key))
422 .header("Content-Type", "application/json")
423 .json(&chat_request)
424 .send()
425 .await?;
426
427 let status = response.status();
428 let body = response.text().await?;
429
430 if !status.is_success() {
431 if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
432 if let Some(ref code) = err.error.code {
434 tracing::error!(error_code = %code, "StepFun API error code");
435 }
436 anyhow::bail!("StepFun API error: {}", err.error.message);
437 }
438 anyhow::bail!("StepFun API error ({}): {}", status, body);
439 }
440
441 let chat_response: ChatResponse = serde_json::from_str(&body)
442 .map_err(|e| anyhow::anyhow!("Failed to parse response: {} - Body: {}", e, body))?;
443
444 tracing::debug!(
446 response_id = %chat_response.id,
447 "Received StepFun response"
448 );
449
450 let choice = chat_response
451 .choices
452 .first()
453 .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
454
455 tracing::debug!(
457 choice_index = choice.index,
458 message_role = %choice.message.role,
459 "Processing StepFun choice"
460 );
461
462 tracing::info!(
464 prompt_tokens = chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
465 completion_tokens = chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
466 finish_reason = ?choice.finish_reason,
467 "StepFun completion received"
468 );
469
470 let mut content = Vec::new();
471 let mut has_tool_calls = false;
472
473 if let Some(text) = &choice.message.content {
474 if !text.is_empty() {
475 content.push(ContentPart::Text { text: text.clone() });
476 }
477 }
478
479 if let Some(tool_calls) = &choice.message.tool_calls {
480 has_tool_calls = !tool_calls.is_empty();
481 for tc in tool_calls {
482 content.push(ContentPart::ToolCall {
483 id: tc.id.clone(),
484 name: tc.function.name.clone(),
485 arguments: tc.function.arguments.clone(),
486 });
487 }
488 }
489
490 let finish_reason = if has_tool_calls {
491 FinishReason::ToolCalls
492 } else {
493 match choice.finish_reason.as_deref() {
494 Some("stop") => FinishReason::Stop,
495 Some("length") => FinishReason::Length,
496 Some("tool_calls") => FinishReason::ToolCalls,
497 _ => FinishReason::Stop,
498 }
499 };
500
501 Ok(CompletionResponse {
502 message: Message {
503 role: Role::Assistant,
504 content,
505 },
506 usage: Usage {
507 prompt_tokens: chat_response
508 .usage
509 .as_ref()
510 .map(|u| u.prompt_tokens)
511 .unwrap_or(0),
512 completion_tokens: chat_response
513 .usage
514 .as_ref()
515 .map(|u| u.completion_tokens)
516 .unwrap_or(0),
517 total_tokens: chat_response
518 .usage
519 .as_ref()
520 .map(|u| u.total_tokens)
521 .unwrap_or(0),
522 ..Default::default()
523 },
524 finish_reason,
525 })
526 }
527
528 async fn complete_stream(
529 &self,
530 request: CompletionRequest,
531 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
532 tracing::debug!(
533 provider = "stepfun",
534 model = %request.model,
535 message_count = request.messages.len(),
536 tool_count = request.tools.len(),
537 "Starting streaming completion request"
538 );
539
540 self.validate_api_key()?;
541
542 let messages = self.convert_messages(&request.messages);
543 let tools = self.convert_tools(&request.tools);
544
545 let chat_request = ChatRequest {
546 model: request.model.clone(),
547 messages,
548 tools: if tools.is_empty() { None } else { Some(tools) },
549 temperature: request.temperature,
550 max_tokens: request.max_tokens,
551 stream: Some(true),
552 };
553
554 let response = self
555 .client
556 .post(format!("{}/chat/completions", STEPFUN_API_BASE))
557 .header("Authorization", format!("Bearer {}", self.api_key))
558 .header("Content-Type", "application/json")
559 .json(&chat_request)
560 .send()
561 .await?;
562
563 if !response.status().is_success() {
564 let status = response.status();
565 let body = response.text().await?;
566 anyhow::bail!("StepFun API error ({}): {}", status, body);
567 }
568
569 let stream = response
570 .bytes_stream()
571 .map(|result| match result {
572 Ok(bytes) => {
573 let text = String::from_utf8_lossy(&bytes);
574 let mut chunks = Vec::new();
575
576 for line in text.lines() {
577 if let Some(data) = line.strip_prefix("data: ") {
578 if data.trim() == "[DONE]" {
579 chunks.push(StreamChunk::Done { usage: None });
580 continue;
581 }
582
583 if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
584 if let Some(choice) = chunk.choices.first() {
585 if let Some(content) = &choice.delta.content {
586 chunks.push(StreamChunk::Text(content.clone()));
587 }
588
589 if let Some(tool_calls) = &choice.delta.tool_calls {
590 for tc in tool_calls {
591 if let Some(id) = &tc.id {
592 if let Some(func) = &tc.function {
593 if let Some(name) = &func.name {
594 chunks.push(StreamChunk::ToolCallStart {
595 id: id.clone(),
596 name: name.clone(),
597 });
598 }
599 }
600 }
601 if let Some(func) = &tc.function {
602 if let Some(args) = &func.arguments {
603 if !args.is_empty() {
604 chunks.push(StreamChunk::ToolCallDelta {
605 id: tc.id.clone().unwrap_or_default(),
606 arguments_delta: args.clone(),
607 });
608 }
609 }
610 }
611 }
612 }
613
614 if choice.finish_reason.is_some() {
615 chunks.push(StreamChunk::Done { usage: None });
616 }
617 }
618 }
619 }
620 }
621
622 if chunks.is_empty() {
623 StreamChunk::Text(String::new())
624 } else if chunks.len() == 1 {
625 chunks.pop().unwrap()
626 } else {
627 chunks.remove(0)
629 }
630 }
631 Err(e) => StreamChunk::Error(e.to_string()),
632 })
633 .boxed();
634
635 Ok(stream)
636 }
637}