use super::{ChatStreamChunk, LlmProvider, Message, MessagePart, TokenUsage, ToolCallChunk};
use anyhow::Result;
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use std::pin::Pin;
use std::time::Duration;
/// LLM provider that talks to the DashScope (Qwen) HTTP API.
///
/// Construct it with `QwenProvider::new()`; chat streaming is exposed through
/// the `LlmProvider` implementation below.
pub struct QwenProvider {
/// HTTP client used for every DashScope request.
client: Client,
/// API key sent as the `Authorization: Bearer` token; `new()` reads it from
/// the `QWEN_API_KEY` environment variable (empty when the variable is unset).
api_key: String,
}
/// JSON body of the chat request POSTed by `stream_chat`.
#[derive(Serialize)]
struct QwenChatRequest {
/// Model identifier, passed through from the caller unchanged.
model: String,
/// Conversation payload (the message list).
input: QwenChatInput,
/// Generation / output-format options.
parameters: QwenChatParameters,
}
/// The `input` object of a chat request: wraps the ordered message list.
#[derive(Serialize)]
struct QwenChatInput {
/// Messages in conversation order (optional system prompt first, then history).
messages: Vec<QwenMessage>,
}
/// A single chat message. Used both when serializing the request and when
/// deserializing the message carried by each streamed response choice.
#[derive(Serialize, Deserialize)]
struct QwenMessage {
/// Speaker role, e.g. "system" or "user".
role: String,
// Kept as raw JSON because this file produces and consumes two shapes:
// a plain string, or an array of `{"text": ...}` / `{"image": ...}` items.
content: serde_json::Value, }
/// The `parameters` object of a chat request.
#[derive(Serialize)]
struct QwenChatParameters {
// `result_format` is set to "message" and `incremental_output` to `true`
// by `stream_chat`, the only place in this file that builds this struct.
result_format: String, incremental_output: bool,
/// Sampling temperature; the key is omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
}
/// One streamed response chunk, parsed from the `data` payload of an SSE event.
#[derive(Deserialize)]
struct QwenResponseChunk {
/// Generated output for this chunk (required field).
output: QwenOutput,
/// Token accounting; may be absent from a chunk.
usage: Option<QwenUsage>,
}
/// Token usage reported alongside a response chunk.
///
/// Every counter is optional; `stream_chat` substitutes 0 for missing values
/// when converting to `TokenUsage`.
#[derive(Deserialize)]
struct QwenUsage {
/// Tokens consumed by the request input.
input_tokens: Option<i64>,
/// Tokens produced in the response.
output_tokens: Option<i64>,
/// Total token count as reported by the service.
total_tokens: Option<i64>,
}
/// The `output` object of a response chunk.
#[derive(Deserialize)]
struct QwenOutput {
/// Candidate completions; `stream_chat` only reads the first entry.
choices: Vec<QwenChoice>,
}
/// A single completion candidate inside `QwenOutput`.
#[derive(Deserialize)]
struct QwenChoice {
/// The message carried by this choice; its `content` holds the streamed text.
message: QwenMessage,
}
impl QwenProvider {
    /// Builds a provider with a 10 s connect timeout and a 300 s overall
    /// request timeout.
    ///
    /// The API key is read from the `QWEN_API_KEY` environment variable and is
    /// left empty when the variable is unset. If the configured HTTP client
    /// cannot be built, a default client is used instead.
    pub fn new() -> Self {
        let api_key = env::var("QWEN_API_KEY").unwrap_or_default();
        let client = Client::builder()
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(300))
            .build()
            .unwrap_or_default();
        Self { client, api_key }
    }

    /// Converts the JSON-encoded parts of a user message into Qwen message
    /// content.
    ///
    /// Returns `(content, has_image)`:
    /// * when at least one image part with a string `data` value is present,
    ///   `content` is an array of `{"image": ...}` / `{"text": ...}` items and
    ///   the flag is `true`;
    /// * otherwise `content` is a single string made of all text items joined
    ///   with newlines and the flag is `false`.
    ///
    /// Parts that fail to deserialize as `Vec<MessagePart>` yield an empty
    /// string and `false`.
    fn build_qwen_content(&self, parts_json: &serde_json::Value) -> (serde_json::Value, bool) {
        // Deserializing from the borrowed value gives the same result as
        // cloning it first; a parse failure is treated as "no parts".
        let parts = Vec::<MessagePart>::deserialize(parts_json).unwrap_or_default();

        let mut items: Vec<serde_json::Value> = Vec::with_capacity(parts.len());
        let mut saw_image = false;

        for part in &parts {
            let kind = part.part_type.as_str();
            match kind {
                "image" => {
                    // Only string payloads are forwarded as images.
                    if let Some(url) = part.data.as_ref().and_then(|d| d.as_str()) {
                        saw_image = true;
                        items.push(serde_json::json!({"image": url}));
                    }
                }
                "kline" | "position" | "balance" | "news" | "order" | "symbolInfo" => {
                    // Structured attachments are inlined as labelled text blocks.
                    if let Some(payload) = &part.data {
                        let label = match kind {
                            "kline" => "Market Chart Data",
                            "position" => "User Positions",
                            "balance" => "Wallet Balance",
                            "news" => "News Article",
                            _ => "Attached Data",
                        };
                        let block =
                            format!("\n--- {} ---\n{}\n--- End {} ---\n", label, payload, label);
                        items.push(serde_json::json!({"text": block}));
                    }
                }
                // "text" parts and any unrecognised part type: forward non-empty content.
                _ => {
                    if let Some(text) = part.content.as_ref().filter(|t| !t.is_empty()) {
                        items.push(serde_json::json!({"text": text}));
                    }
                }
            }
        }

        if saw_image {
            return (serde_json::Value::Array(items), true);
        }

        let joined = items
            .iter()
            .filter_map(|item| item.get("text").and_then(|v| v.as_str()))
            .collect::<Vec<_>>()
            .join("\n");
        (serde_json::Value::String(joined), false)
    }

    /// Joins the `content` of every part with newlines.
    ///
    /// Returns an empty string when `parts_json` does not deserialize as
    /// `Vec<MessagePart>`.
    fn extract_text(&self, parts_json: &serde_json::Value) -> String {
        match Vec::<MessagePart>::deserialize(parts_json) {
            Ok(parts) => {
                let texts: Vec<_> = parts.iter().filter_map(|p| p.content.clone()).collect();
                texts.join("\n")
            }
            Err(_) => String::new(),
        }
    }
}
#[async_trait]
impl LlmProvider for QwenProvider {
async fn stream_chat(
&self,
model: &str,
system_prompt: Option<String>,
history: Vec<Message>,
_tools: Option<Vec<serde_json::Value>>, ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatStreamChunk>> + Send>>> {
let mut qwen_messages = Vec::new();
let mut use_multimodal = false;
if let Some(sp) = system_prompt {
qwen_messages.push(QwenMessage {
role: "system".to_string(),
content: serde_json::Value::String(sp),
});
}
for msg in history {
if msg.role == "user" {
let (content, has_image) = self.build_qwen_content(&msg.parts);
if has_image {
use_multimodal = true;
}
qwen_messages.push(QwenMessage {
role: msg.role,
content,
});
} else {
let text = self.extract_text(&msg.parts);
qwen_messages.push(QwenMessage {
role: msg.role,
content: serde_json::Value::String(text),
});
}
}
let request = QwenChatRequest {
model: model.to_string(),
input: QwenChatInput {
messages: qwen_messages,
},
parameters: QwenChatParameters {
result_format: "message".to_string(),
incremental_output: true,
temperature: Some(0.7),
},
};
let endpoint = if use_multimodal {
"https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
} else {
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
};
let response = self
.client
.post(endpoint)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("X-DashScope-SSE", "enable") .json(&request)
.send()
.await?;
if !response.status().is_success() {
let txt = response.text().await?;
return Err(anyhow::anyhow!("DashScope Chat Error: {}", txt));
}
let stream = response.bytes_stream().eventsource();
let mapped = stream.map(|event_res| {
let event = event_res.map_err(|e| anyhow::anyhow!("SSE Error: {}", e))?;
let chunk: QwenResponseChunk = serde_json::from_str(&event.data)?;
let text = chunk
.output
.choices
.first()
.and_then(|c| match &c.message.content {
serde_json::Value::String(s) => Some(s.clone()),
serde_json::Value::Array(arr) => {
let texts: Vec<&str> = arr
.iter()
.filter_map(|item| item.get("text").and_then(|v| v.as_str()))
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join(""))
}
}
_ => None,
});
let usage = chunk.usage.map(|u| TokenUsage {
input_tokens: u.input_tokens.unwrap_or(0),
output_tokens: u.output_tokens.unwrap_or(0),
total_tokens: u.total_tokens.unwrap_or(0),
});
Ok(ChatStreamChunk {
content: text,
tool_calls: vec![],
usage,
finish_reason: None,
})
});
Ok(Box::pin(mapped))
}
}