use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::provider::Provider;
use crate::types::{
CompletionRequest, CompletionResponse, ContentBlock, ContentDelta, Role, StopReason,
StreamChunk, StreamEventType, Usage,
};
const ALIBABA_API_URL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation";
pub struct AlibabaProvider {
api_key: String,
client: Client,
}
impl AlibabaProvider {
pub fn new(api_key: &str) -> Result<Self> {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()?;
Ok(Self {
api_key: api_key.to_string(),
client,
})
}
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("ALIBABA_API_KEY")
.map_err(|_| Error::config("ALIBABA_API_KEY environment variable not set"))?;
Self::new(&api_key)
}
pub async fn list_models(&self) -> Result<Vec<String>> {
Ok(vec![
"qwen-turbo".to_string(),
"qwen-plus".to_string(),
"qwen-max".to_string(),
"qwen-max-longcontext".to_string(),
"qwen-vl-plus".to_string(),
"qwen-vl-max".to_string(),
"qwen-coder-turbo".to_string(),
"qwen-coder-max".to_string(),
"llama-2-7b-chat".to_string(),
"llama-2-13b-chat".to_string(),
"llama-2-70b-chat".to_string(),
"mistral-7b-instruct".to_string(),
"baichuan-2-7b-chat".to_string(),
"baichuan-2-13b-chat".to_string(),
])
}
fn handle_error_response(&self, status: reqwest::StatusCode, body: &str) -> Error {
match status.as_u16() {
401 | 403 => Error::auth(format!("Alibaba authentication failed: {}", body)),
429 => Error::rate_limited("Alibaba rate limited", None),
500..=599 => Error::server(status.as_u16(), body.to_string()),
_ => Error::other(format!("Alibaba error ({}): {}", status, body)),
}
}
pub fn get_model_info(model: &str) -> Option<AlibabaModelInfo> {
match model {
"qwen-turbo" => Some(AlibabaModelInfo {
name: "qwen-turbo".to_string(),
context_window: 8000,
supports_vision: false,
supports_function_call: true,
max_output_tokens: 2000,
}),
"qwen-plus" => Some(AlibabaModelInfo {
name: "qwen-plus".to_string(),
context_window: 32000,
supports_vision: false,
supports_function_call: true,
max_output_tokens: 4000,
}),
"qwen-max" => Some(AlibabaModelInfo {
name: "qwen-max".to_string(),
context_window: 32000,
supports_vision: false,
supports_function_call: true,
max_output_tokens: 8000,
}),
"qwen-max-longcontext" => Some(AlibabaModelInfo {
name: "qwen-max-longcontext".to_string(),
context_window: 200000,
supports_vision: false,
supports_function_call: true,
max_output_tokens: 8000,
}),
"qwen-vl-plus" => Some(AlibabaModelInfo {
name: "qwen-vl-plus".to_string(),
context_window: 16000,
supports_vision: true,
supports_function_call: false,
max_output_tokens: 1000,
}),
"qwen-vl-max" => Some(AlibabaModelInfo {
name: "qwen-vl-max".to_string(),
context_window: 32000,
supports_vision: true,
supports_function_call: false,
max_output_tokens: 2000,
}),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlibabaModelInfo {
pub name: String,
pub context_window: u32,
pub supports_vision: bool,
pub supports_function_call: bool,
pub max_output_tokens: u32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelSpecialization {
#[default]
General,
Vision,
Code,
Math,
}
#[derive(Debug, Serialize)]
struct AlibabaTextGenRequest {
model: String,
input: AlibabaInput,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<AlibabaParameters>,
}
#[derive(Debug, Serialize)]
struct AlibabaInput {
messages: Vec<AlibabaMessage>,
}
#[derive(Debug, Serialize)]
struct AlibabaMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize)]
struct AlibabaParameters {
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct AlibabaTextGenResponse {
output: Option<AlibabaOutput>,
usage: Option<AlibabaUsage>,
code: Option<String>,
message: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AlibabaOutput {
#[serde(default)]
text: String,
}
#[derive(Debug, Deserialize)]
struct AlibabaUsage {
input_tokens: Option<u32>,
output_tokens: Option<u32>,
}
#[async_trait]
impl Provider for AlibabaProvider {
fn name(&self) -> &str {
"alibaba"
}
fn default_model(&self) -> Option<&str> {
Some("qwen-max-longcontext")
}
fn supports_tools(&self) -> bool {
true
}
fn supports_vision(&self) -> bool {
true
}
fn supports_streaming(&self) -> bool {
false
}
fn supported_models(&self) -> Option<&[&str]> {
Some(&[
"qwen-turbo",
"qwen-plus",
"qwen-max",
"qwen-max-longcontext",
"qwen-vl-plus",
"qwen-vl-max",
"qwen-coder-turbo",
"qwen-coder-max",
"llama-2-7b-chat",
"llama-2-13b-chat",
"llama-2-70b-chat",
"mistral-7b-instruct",
"baichuan-2-7b-chat",
"baichuan-2-13b-chat",
])
}
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
let model = if request.model.is_empty() {
"qwen-max-longcontext".to_string()
} else {
request.model.clone()
};
let mut messages = Vec::new();
for msg in &request.messages {
let role = match msg.role {
Role::User => "user".to_string(),
Role::Assistant => "assistant".to_string(),
Role::System => "system".to_string(),
};
for content in &msg.content {
if let ContentBlock::Text { text } = content {
messages.push(AlibabaMessage {
role: role.clone(),
content: text.clone(),
});
}
}
}
let alibaba_request = AlibabaTextGenRequest {
model: model.clone(),
input: AlibabaInput { messages },
parameters: Some(AlibabaParameters {
temperature: request.temperature,
top_p: request.top_p,
max_tokens: request.max_tokens,
}),
};
let response = self
.client
.post(ALIBABA_API_URL)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&alibaba_request)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(self.handle_error_response(status, &body));
}
let alibaba_response: AlibabaTextGenResponse = serde_json::from_str(&body)
.map_err(|e| Error::other(format!("Failed to parse response: {}", e)))?;
if let Some(code) = alibaba_response.code {
if code != "200" {
return Err(Error::other(format!(
"Alibaba API error {}: {}",
code,
alibaba_response
.message
.unwrap_or_else(|| "Unknown error".to_string())
)));
}
}
let response_text = alibaba_response
.output
.map(|o| o.text)
.filter(|t| !t.is_empty())
.unwrap_or_else(|| "No response from Alibaba".to_string());
Ok(CompletionResponse {
id: uuid::Uuid::new_v4().to_string(),
model,
content: vec![ContentBlock::Text {
text: response_text,
}],
stop_reason: StopReason::EndTurn,
usage: Usage {
input_tokens: alibaba_response
.usage
.as_ref()
.and_then(|u| u.input_tokens)
.unwrap_or(0),
output_tokens: alibaba_response
.usage
.as_ref()
.and_then(|u| u.output_tokens)
.unwrap_or(0),
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
})
}
async fn complete_stream(
&self,
request: CompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
let response = self.complete(request).await?;
let stream = async_stream::try_stream! {
yield StreamChunk {
event_type: StreamEventType::MessageStart,
index: None,
delta: None,
stop_reason: None,
usage: None,
};
for block in response.content {
if let ContentBlock::Text { text } = block {
yield StreamChunk {
event_type: StreamEventType::ContentBlockDelta,
index: Some(0),
delta: Some(ContentDelta::Text { text }),
stop_reason: None,
usage: None,
};
}
}
yield StreamChunk {
event_type: StreamEventType::MessageStop,
index: None,
delta: None,
stop_reason: Some(StopReason::EndTurn),
usage: Some(response.usage),
};
};
Ok(Box::pin(stream))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_alibaba_provider_creation() {
let provider = AlibabaProvider::new("test-key").unwrap();
assert_eq!(provider.api_key, "test-key");
}
#[tokio::test]
async fn test_list_models() {
let provider = AlibabaProvider::new("test-key").unwrap();
let models = provider.list_models().await.unwrap();
assert!(!models.is_empty());
assert!(models.contains(&"qwen-max".to_string()));
assert!(models.contains(&"qwen-vl-max".to_string()));
assert!(models.contains(&"llama-2-70b-chat".to_string()));
assert!(models.contains(&"mistral-7b-instruct".to_string()));
assert!(models.contains(&"baichuan-2-13b-chat".to_string()));
}
#[test]
fn test_get_model_info() {
let info = AlibabaProvider::get_model_info("qwen-max").unwrap();
assert_eq!(info.name, "qwen-max");
assert!(info.supports_function_call);
assert!(!info.supports_vision);
assert_eq!(info.context_window, 32000);
}
#[test]
fn test_vision_model_info() {
let info = AlibabaProvider::get_model_info("qwen-vl-max").unwrap();
assert!(info.supports_vision);
assert!(!info.supports_function_call);
}
#[test]
fn test_model_info_invalid() {
let info = AlibabaProvider::get_model_info("invalid-model");
assert!(info.is_none());
}
#[test]
fn test_model_specialization_default() {
assert_eq!(ModelSpecialization::default(), ModelSpecialization::General);
}
#[test]
fn test_qwen_context_windows() {
let turbo = AlibabaProvider::get_model_info("qwen-turbo").unwrap();
let plus = AlibabaProvider::get_model_info("qwen-plus").unwrap();
let max = AlibabaProvider::get_model_info("qwen-max").unwrap();
let long = AlibabaProvider::get_model_info("qwen-max-longcontext").unwrap();
assert!(turbo.context_window < plus.context_window);
assert_eq!(plus.context_window, max.context_window);
assert!(max.context_window < long.context_window);
}
}