use crate::provider_api::{
FinishReason, LlmError, LlmProvider, LlmRequest, LlmResponse, TokenUsage,
};
use crate::secret::{EnvSecretProvider, SecretProvider, SecretString};
use serde::{Deserialize, Serialize};
/// [`LlmProvider`] implementation that talks to Alibaba Cloud DashScope for Qwen models.
///
/// Requests are issued synchronously through a blocking `reqwest` client.
pub struct QwenProvider {
    /// API key, sent as a bearer token on every request.
    api_key: SecretString,
    /// Model identifier forwarded in each request body (e.g. `qwen-turbo`).
    model: String,
    /// Blocking HTTP client shared by all requests made through this provider.
    client: reqwest::blocking::Client,
    /// API origin that the endpoint path is appended to.
    base_url: String,
}
impl QwenProvider {
#[must_use]
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
api_key: SecretString::new(api_key),
model: model.into(),
client: reqwest::blocking::Client::new(),
base_url: "https://dashscope.aliyuncs.com".into(),
}
}
pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
Self::from_secret_provider(&EnvSecretProvider, model)
}
pub fn from_secret_provider(
secrets: &dyn SecretProvider,
model: impl Into<String>,
) -> Result<Self, LlmError> {
let api_key = secrets
.get_secret("QWEN_API_KEY")
.map_err(|e| LlmError::auth(format!("QWEN_API_KEY: {e}")))?;
Ok(Self {
api_key,
model: model.into(),
client: reqwest::blocking::Client::new(),
base_url: "https://dashscope.aliyuncs.com".into(),
})
}
#[must_use]
pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = url.into();
self
}
}
/// JSON body POSTed to DashScope's native text-generation endpoint.
#[derive(Serialize)]
struct QwenRequest {
    /// Model identifier, e.g. `qwen-turbo`.
    model: String,
    /// Conversation input (the message list).
    input: QwenInput,
    /// Generation parameters such as token limit and temperature.
    parameters: QwenParameters,
}
/// The `input` object of a request: an ordered list of chat messages.
#[derive(Serialize)]
struct QwenInput {
    /// Messages in conversation order.
    messages: Vec<QwenMessage>,
}
/// One chat message in a request.
#[derive(Serialize)]
struct QwenMessage {
    /// Speaker role; this module sends `"system"` and `"user"`.
    role: String,
    /// Message text.
    content: String,
}
#[derive(Serialize)]
struct QwenParameters {
max_tokens: u32,
temperature: f64,
#[serde(skip_serializing_if = "Vec::is_empty")]
stop: Vec<String>,
}
/// Successful response envelope from the text-generation endpoint.
#[derive(Deserialize)]
struct QwenResponse {
    /// Generated output (the choice list).
    output: QwenOutput,
    /// Token accounting for the call.
    usage: QwenUsage,
    // Required by deserialization but never read; kept to mirror the response schema.
    #[allow(dead_code)]
    request_id: String,
}
/// The `output` object of a successful response.
#[derive(Deserialize)]
struct QwenOutput {
    /// Candidate completions; only the first one is consumed by `complete`.
    choices: Vec<QwenChoice>,
}
/// A single candidate completion.
#[derive(Deserialize)]
struct QwenChoice {
    /// The generated assistant message.
    message: QwenChoiceMessage,
    /// Why generation ended (e.g. `"length"`); may be absent or `null`.
    finish_reason: Option<String>,
}
/// Message payload of a choice; only the text content is read.
#[derive(Deserialize)]
struct QwenChoiceMessage {
    /// Generated text.
    content: String,
}
/// Token counts reported by DashScope for one call.
#[derive(Deserialize)]
// Field names must match the API's JSON keys, which all share the `_tokens` suffix.
#[allow(clippy::struct_field_names)]
struct QwenUsage {
    /// Tokens consumed by the prompt (mapped to `prompt_tokens`).
    input_tokens: u32,
    /// Tokens generated (mapped to `completion_tokens`).
    output_tokens: u32,
    /// Sum reported by the API.
    total_tokens: u32,
}
/// Error envelope returned by DashScope for non-success HTTP statuses.
#[derive(Deserialize)]
struct QwenError {
    /// Machine-readable error code (e.g. `InvalidApiKey`), if present.
    code: Option<String>,
    /// Human-readable description, surfaced as the `LlmError` message.
    message: String,
    // Deserialized for schema fidelity but not currently surfaced.
    #[allow(dead_code)]
    request_id: Option<String>,
}
impl LlmProvider for QwenProvider {
fn name(&self) -> &'static str {
"qwen"
}
fn model(&self) -> &str {
&self.model
}
fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
let url = format!(
"{}/api/v1/services/aigc/text-generation/generation",
self.base_url
);
let mut messages = Vec::new();
if let Some(ref system) = request.system {
messages.push(QwenMessage {
role: "system".to_string(),
content: system.clone(),
});
}
messages.push(QwenMessage {
role: "user".to_string(),
content: request.prompt.clone(),
});
let body = QwenRequest {
model: self.model.clone(),
input: QwenInput { messages },
parameters: QwenParameters {
max_tokens: request.max_tokens,
temperature: request.temperature,
stop: request.stop_sequences.clone(),
},
};
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key.expose()))
.header("Content-Type", "application/json")
.json(&body)
.send()
.map_err(|e| LlmError::network(format!("Request failed: {e}")))?;
let status = response.status();
if !status.is_success() {
let error_body: QwenError = response
.json()
.map_err(|e| LlmError::parse(format!("Failed to parse error: {e}")))?;
let code = error_body.code.as_deref().unwrap_or("unknown");
return match code {
"InvalidApiKey" | "InvalidParameter" => Err(LlmError::auth(error_body.message)),
"Throttling" => Err(LlmError::rate_limit(error_body.message)),
_ => Err(LlmError::provider(error_body.message)),
};
}
let api_response: QwenResponse = response
.json()
.map_err(|e| LlmError::parse(format!("Failed to parse response: {e}")))?;
let content = api_response
.output
.choices
.first()
.map(|c| c.message.content.clone())
.unwrap_or_default();
let finish_reason = match api_response
.output
.choices
.first()
.and_then(|c| c.finish_reason.as_deref())
{
Some("length") => FinishReason::MaxTokens,
_ => FinishReason::Stop, };
Ok(LlmResponse {
content,
model: self.model.clone(),
usage: TokenUsage {
prompt_tokens: api_response.usage.input_tokens,
completion_tokens: api_response.usage.output_tokens,
total_tokens: api_response.usage.total_tokens,
},
finish_reason,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps the requested model, and the provider reports its fixed name.
    #[test]
    fn provider_has_correct_name() {
        let qwen = QwenProvider::new("test-key", "qwen-turbo");

        assert_eq!("qwen", qwen.name());
        assert_eq!("qwen-turbo", qwen.model());
    }
}