1use super::{
6 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
7 Role, StreamChunk, ToolDefinition, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
/// Base URL for every StepFun API request; the provider appends `/chat/completions` to it.
const STEPFUN_API_BASE: &str = "https://api.stepfun.ai/v1";
15
16pub struct StepFunProvider {
17 api_key: String,
18 client: reqwest::Client,
19}
20
21impl std::fmt::Debug for StepFunProvider {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("StepFunProvider")
24 .field("api_key", &"<REDACTED>")
25 .field("api_key_len", &self.api_key.len())
26 .field("client", &"<reqwest::Client>")
27 .finish()
28 }
29}
30
31impl StepFunProvider {
32 pub fn new(api_key: String) -> Result<Self> {
33 tracing::debug!(
34 provider = "stepfun",
35 api_key_len = api_key.len(),
36 "Creating StepFun provider"
37 );
38 Ok(Self {
39 api_key,
40 client: reqwest::Client::new(),
41 })
42 }
43
44 fn validate_api_key(&self) -> Result<()> {
46 if self.api_key.is_empty() {
47 anyhow::bail!("StepFun API key is empty");
48 }
49 if self.api_key.len() < 10 {
50 tracing::warn!(provider = "stepfun", "API key seems unusually short");
51 }
52 Ok(())
53 }
54}
55
/// JSON body POSTed to `{STEPFUN_API_BASE}/chat/completions`.
///
/// Every `Option` field marked `skip_serializing_if` is omitted from the JSON when `None`,
/// rather than being sent as `null`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model id, passed through from the caller's request.
    model: String,
    /// Conversation history, already converted to the wire shape.
    messages: Vec<ChatMessage>,
    /// Tools the model may call; left as `None` when the caller offers no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ChatTool>>,
    /// Sampling temperature, passed through from the caller's request.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Token limit, passed through from the caller's request.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    /// `Some(true)` for streaming calls (reply read as `data:` lines), `Some(false)` otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// One outgoing conversation message.
///
/// `role` is one of `"system"`, `"user"`, `"assistant"` or `"tool"`, as produced when
/// converting provider-neutral messages.
// NOTE(review): `Deserialize` is derived, but nothing in this file deserializes a ChatMessage.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    /// Text content; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Tool invocations made by an assistant turn.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    /// For `"tool"` messages: the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// A tool offered to the model. `r#type` serializes as the JSON key `type` and is set to
/// `"function"` when tools are converted.
#[derive(Debug, Serialize)]
struct ChatTool {
    r#type: String,
    function: ChatFunction,
}

/// Name, description and parameter description of a callable function tool.
#[derive(Debug, Serialize)]
struct ChatFunction {
    name: String,
    description: String,
    /// Copied verbatim from the tool definition (presumably a JSON Schema object — confirm
    /// against `ToolDefinition`).
    parameters: serde_json::Value,
}
95
/// Body of a successful non-streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Response id; only used for debug logging.
    id: String,
    /// Candidate completions; the provider uses only the first one.
    choices: Vec<ChatChoice>,
    /// Token accounting; token counts fall back to 0 when this is absent.
    usage: Option<ChatUsage>,
}

/// One candidate completion in a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct ChatChoice {
    index: usize,
    message: ChatResponseMessage,
    /// Raw finish reason string, e.g. `"stop"`, `"length"` or `"tool_calls"`.
    finish_reason: Option<String>,
}

/// The assistant message carried by a [`ChatChoice`].
#[derive(Debug, Deserialize)]
struct ChatResponseMessage {
    role: String,
    /// Missing or `null` deserializes as `None`.
    #[serde(default)]
    content: Option<String>,
    /// Missing or `null` deserializes as `None`.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}

/// A complete tool call. It is used both on assistant messages sent to the API and on
/// messages received from it; `r#type` maps to the JSON key `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCall {
    id: String,
    r#type: String,
    function: ToolCallFunction,
}

/// Function name plus its arguments as a raw string (presumably JSON-encoded); the arguments
/// are passed through without being parsed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolCallFunction {
    name: String,
    arguments: String,
}

/// Token counts reported for a completion.
#[derive(Debug, Deserialize)]
struct ChatUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

/// Error body shaped `{"error": {"message": ..., "code": ...}}`, parsed when the HTTP status
/// is not a success. If the body does not match this shape, the raw body is reported instead.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ErrorDetail,
}

/// Human-readable message plus an optional machine-readable code.
#[derive(Debug, Deserialize)]
struct ErrorDetail {
    message: String,
    // NOTE(review): declared as a string; a numeric `code` would make the whole
    // ErrorResponse fail to parse (the caller then falls back to the raw body).
    #[serde(default)]
    code: Option<String>,
}
152
/// JSON payload of one `data:` line in a streaming (`stream: true`) response.
#[derive(Debug, Deserialize)]
struct StreamChunkResponse {
    choices: Vec<StreamChoice>,
}

/// Incremental update for one choice; a present `finish_reason` marks the end of the choice.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

/// Fragment of assistant output. Fields missing from a given event deserialize as `None`.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}

/// Fragment of a tool call. `id` and `function` are optional because a fragment may omit
/// them (presumably only the first fragment of a call carries `id` and the name).
#[derive(Debug, Deserialize)]
struct StreamToolCall {
    /// Position of the call within the response; presumably shared by all fragments of the
    /// same call.
    // The allow keeps the build warning-free if no code reads this field; it must still be
    // declared so the JSON key is accepted.
    #[allow(dead_code)]
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamToolFunction>,
}

/// Function part of a tool-call fragment: an optional name plus an argument-string fragment.
#[derive(Debug, Deserialize)]
struct StreamToolFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
191
192impl StepFunProvider {
193 fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
194 let mut result = Vec::new();
195
196 for msg in messages {
197 match msg.role {
198 Role::System => {
199 let content = msg
200 .content
201 .iter()
202 .filter_map(|p| match p {
203 ContentPart::Text { text } => Some(text.clone()),
204 _ => None,
205 })
206 .collect::<Vec<_>>()
207 .join("\n");
208 result.push(ChatMessage {
209 role: "system".to_string(),
210 content: Some(content),
211 tool_calls: None,
212 tool_call_id: None,
213 });
214 }
215 Role::User => {
216 let content = msg
217 .content
218 .iter()
219 .filter_map(|p| match p {
220 ContentPart::Text { text } => Some(text.clone()),
221 _ => None,
222 })
223 .collect::<Vec<_>>()
224 .join("\n");
225 result.push(ChatMessage {
226 role: "user".to_string(),
227 content: Some(content),
228 tool_calls: None,
229 tool_call_id: None,
230 });
231 }
232 Role::Assistant => {
233 let content = msg
234 .content
235 .iter()
236 .filter_map(|p| match p {
237 ContentPart::Text { text } => Some(text.clone()),
238 _ => None,
239 })
240 .collect::<Vec<_>>()
241 .join("\n");
242
243 let tool_calls: Vec<ToolCall> = msg
244 .content
245 .iter()
246 .filter_map(|p| match p {
247 ContentPart::ToolCall { id, name, arguments } => Some(ToolCall {
248 id: id.clone(),
249 r#type: "function".to_string(),
250 function: ToolCallFunction {
251 name: name.clone(),
252 arguments: arguments.clone(),
253 },
254 }),
255 _ => None,
256 })
257 .collect();
258
259 result.push(ChatMessage {
260 role: "assistant".to_string(),
261 content: if content.is_empty() && !tool_calls.is_empty() {
263 Some(String::new())
264 } else if content.is_empty() {
265 None
266 } else {
267 Some(content)
268 },
269 tool_calls: if tool_calls.is_empty() {
270 None
271 } else {
272 Some(tool_calls)
273 },
274 tool_call_id: None,
275 });
276 }
277 Role::Tool => {
278 for part in &msg.content {
279 if let ContentPart::ToolResult {
280 tool_call_id,
281 content,
282 } = part
283 {
284 result.push(ChatMessage {
285 role: "tool".to_string(),
286 content: Some(content.clone()),
287 tool_calls: None,
288 tool_call_id: Some(tool_call_id.clone()),
289 });
290 }
291 }
292 }
293 }
294 }
295
296 result
297 }
298
299 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<ChatTool> {
300 tools
301 .iter()
302 .map(|t| ChatTool {
303 r#type: "function".to_string(),
304 function: ChatFunction {
305 name: t.name.clone(),
306 description: t.description.clone(),
307 parameters: t.parameters.clone(),
308 },
309 })
310 .collect()
311 }
312}
313
314#[async_trait]
315impl Provider for StepFunProvider {
316 fn name(&self) -> &str {
317 "stepfun"
318 }
319
320 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
321 Ok(vec![
322 ModelInfo {
323 id: "step-3.5-flash".to_string(),
324 name: "Step 3.5 Flash".to_string(),
325 provider: "stepfun".to_string(),
326 context_window: 128_000,
327 max_output_tokens: Some(8192),
328 supports_vision: false,
329 supports_tools: true,
330 supports_streaming: true,
331 input_cost_per_million: Some(0.0), output_cost_per_million: Some(0.0),
333 },
334 ModelInfo {
335 id: "step-1-8k".to_string(),
336 name: "Step 1 8K".to_string(),
337 provider: "stepfun".to_string(),
338 context_window: 8_000,
339 max_output_tokens: Some(4096),
340 supports_vision: false,
341 supports_tools: true,
342 supports_streaming: true,
343 input_cost_per_million: Some(0.5),
344 output_cost_per_million: Some(1.5),
345 },
346 ModelInfo {
347 id: "step-1-32k".to_string(),
348 name: "Step 1 32K".to_string(),
349 provider: "stepfun".to_string(),
350 context_window: 32_000,
351 max_output_tokens: Some(8192),
352 supports_vision: false,
353 supports_tools: true,
354 supports_streaming: true,
355 input_cost_per_million: Some(1.0),
356 output_cost_per_million: Some(3.0),
357 },
358 ModelInfo {
359 id: "step-1-128k".to_string(),
360 name: "Step 1 128K".to_string(),
361 provider: "stepfun".to_string(),
362 context_window: 128_000,
363 max_output_tokens: Some(8192),
364 supports_vision: false,
365 supports_tools: true,
366 supports_streaming: true,
367 input_cost_per_million: Some(2.0),
368 output_cost_per_million: Some(6.0),
369 },
370 ModelInfo {
371 id: "step-1v-8k".to_string(),
372 name: "Step 1 Vision 8K".to_string(),
373 provider: "stepfun".to_string(),
374 context_window: 8_000,
375 max_output_tokens: Some(4096),
376 supports_vision: true,
377 supports_tools: true,
378 supports_streaming: true,
379 input_cost_per_million: Some(1.0),
380 output_cost_per_million: Some(3.0),
381 },
382 ])
383 }
384
385 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
386 tracing::debug!(
387 provider = "stepfun",
388 model = %request.model,
389 message_count = request.messages.len(),
390 tool_count = request.tools.len(),
391 "Starting completion request"
392 );
393
394 self.validate_api_key()?;
396
397 let messages = self.convert_messages(&request.messages);
398 let tools = self.convert_tools(&request.tools);
399
400 let chat_request = ChatRequest {
401 model: request.model.clone(),
402 messages,
403 tools: if tools.is_empty() { None } else { Some(tools) },
404 temperature: request.temperature,
405 max_tokens: request.max_tokens,
406 stream: Some(false),
407 };
408
409 if let Ok(json_str) = serde_json::to_string_pretty(&chat_request) {
411 tracing::debug!("StepFun request: {}", json_str);
412 }
413
414 let response = self
415 .client
416 .post(format!("{}/chat/completions", STEPFUN_API_BASE))
417 .header("Authorization", format!("Bearer {}", self.api_key))
418 .header("Content-Type", "application/json")
419 .json(&chat_request)
420 .send()
421 .await?;
422
423 let status = response.status();
424 let body = response.text().await?;
425
426 if !status.is_success() {
427 if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
428 if let Some(ref code) = err.error.code {
430 tracing::error!(error_code = %code, "StepFun API error code");
431 }
432 anyhow::bail!("StepFun API error: {}", err.error.message);
433 }
434 anyhow::bail!("StepFun API error ({}): {}", status, body);
435 }
436
437 let chat_response: ChatResponse = serde_json::from_str(&body)
438 .map_err(|e| anyhow::anyhow!("Failed to parse response: {} - Body: {}", e, body))?;
439
440 tracing::debug!(
442 response_id = %chat_response.id,
443 "Received StepFun response"
444 );
445
446 let choice = chat_response
447 .choices
448 .first()
449 .ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
450
451 tracing::debug!(
453 choice_index = choice.index,
454 message_role = %choice.message.role,
455 "Processing StepFun choice"
456 );
457
458 tracing::info!(
460 prompt_tokens = chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
461 completion_tokens = chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
462 finish_reason = ?choice.finish_reason,
463 "StepFun completion received"
464 );
465
466 let mut content = Vec::new();
467 let mut has_tool_calls = false;
468
469 if let Some(text) = &choice.message.content {
470 if !text.is_empty() {
471 content.push(ContentPart::Text { text: text.clone() });
472 }
473 }
474
475 if let Some(tool_calls) = &choice.message.tool_calls {
476 has_tool_calls = !tool_calls.is_empty();
477 for tc in tool_calls {
478 content.push(ContentPart::ToolCall {
479 id: tc.id.clone(),
480 name: tc.function.name.clone(),
481 arguments: tc.function.arguments.clone(),
482 });
483 }
484 }
485
486 let finish_reason = if has_tool_calls {
487 FinishReason::ToolCalls
488 } else {
489 match choice.finish_reason.as_deref() {
490 Some("stop") => FinishReason::Stop,
491 Some("length") => FinishReason::Length,
492 Some("tool_calls") => FinishReason::ToolCalls,
493 _ => FinishReason::Stop,
494 }
495 };
496
497 Ok(CompletionResponse {
498 message: Message {
499 role: Role::Assistant,
500 content,
501 },
502 usage: Usage {
503 prompt_tokens: chat_response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
504 completion_tokens: chat_response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
505 total_tokens: chat_response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
506 ..Default::default()
507 },
508 finish_reason,
509 })
510 }
511
512 async fn complete_stream(
513 &self,
514 request: CompletionRequest,
515 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
516 tracing::debug!(
517 provider = "stepfun",
518 model = %request.model,
519 message_count = request.messages.len(),
520 tool_count = request.tools.len(),
521 "Starting streaming completion request"
522 );
523
524 self.validate_api_key()?;
525
526 let messages = self.convert_messages(&request.messages);
527 let tools = self.convert_tools(&request.tools);
528
529 let chat_request = ChatRequest {
530 model: request.model.clone(),
531 messages,
532 tools: if tools.is_empty() { None } else { Some(tools) },
533 temperature: request.temperature,
534 max_tokens: request.max_tokens,
535 stream: Some(true),
536 };
537
538 let response = self
539 .client
540 .post(format!("{}/chat/completions", STEPFUN_API_BASE))
541 .header("Authorization", format!("Bearer {}", self.api_key))
542 .header("Content-Type", "application/json")
543 .json(&chat_request)
544 .send()
545 .await?;
546
547 if !response.status().is_success() {
548 let status = response.status();
549 let body = response.text().await?;
550 anyhow::bail!("StepFun API error ({}): {}", status, body);
551 }
552
553 let stream = response
554 .bytes_stream()
555 .map(|result| match result {
556 Ok(bytes) => {
557 let text = String::from_utf8_lossy(&bytes);
558 let mut chunks = Vec::new();
559
560 for line in text.lines() {
561 if let Some(data) = line.strip_prefix("data: ") {
562 if data.trim() == "[DONE]" {
563 chunks.push(StreamChunk::Done { usage: None });
564 continue;
565 }
566
567 if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
568 if let Some(choice) = chunk.choices.first() {
569 if let Some(content) = &choice.delta.content {
570 chunks.push(StreamChunk::Text(content.clone()));
571 }
572
573 if let Some(tool_calls) = &choice.delta.tool_calls {
574 for tc in tool_calls {
575 if let Some(id) = &tc.id {
576 if let Some(func) = &tc.function {
577 if let Some(name) = &func.name {
578 chunks.push(StreamChunk::ToolCallStart {
579 id: id.clone(),
580 name: name.clone(),
581 });
582 }
583 }
584 }
585 if let Some(func) = &tc.function {
586 if let Some(args) = &func.arguments {
587 if !args.is_empty() {
588 chunks.push(StreamChunk::ToolCallDelta {
589 id: tc.id.clone().unwrap_or_default(),
590 arguments_delta: args.clone(),
591 });
592 }
593 }
594 }
595 }
596 }
597
598 if choice.finish_reason.is_some() {
599 chunks.push(StreamChunk::Done { usage: None });
600 }
601 }
602 }
603 }
604 }
605
606 if chunks.is_empty() {
607 StreamChunk::Text(String::new())
608 } else if chunks.len() == 1 {
609 chunks.pop().unwrap()
610 } else {
611 chunks.remove(0)
613 }
614 }
615 Err(e) => StreamChunk::Error(e.to_string()),
616 })
617 .boxed();
618
619 Ok(stream)
620 }
621}