use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
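/// OpenAI-compatible provider backed by the Perplexity AI chat completions API.
///
/// Wraps a [`BaseProvider`] for HTTP transport and keeps a static per-model
/// pricing cache for cost accounting.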
#[derive(Debug, Clone)]
pub struct PerplexityProvider {
base: BaseProvider,
pricing_cache: HashMap<String, ModelPricing>,
}
impl PerplexityProvider {
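    /// Creates a provider from `config`, defaulting the base URL to
    /// `https://api.perplexity.ai` when none is configured.
    ///
    /// Illustrative usage (a sketch; assumes a `ProviderConfig` populated elsewhere):
    /// ```ignore
    /// let provider = PerplexityProvider::new(&config).await?;
    /// assert!(provider.supports_model("llama-3-70b-instruct").await);
    /// ```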
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| "https://api.perplexity.ai".to_string());
let provider = Self {
base: BaseProvider { base_url, ..base },
pricing_cache: Self::initialize_pricing_cache(),
};
info!(
"Perplexity AI provider '{}' initialized successfully",
config.name
);
Ok(provider)
}
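    /// Builds the static per-model pricing table; rates are hard-coded
    /// snapshots rather than fetched from the Perplexity API.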
    fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
        // Perplexity bills input and output at the same USD rate per 1k tokens
        // for these models, so a single rate per model id suffices.
        let rates = [
            ("llama-3-sonar-small-32k-chat", 0.0002),
            ("llama-3-sonar-small-32k-online", 0.0002),
            ("llama-3-sonar-large-32k-chat", 0.001),
            ("llama-3-sonar-large-32k-online", 0.001),
            ("llama-3-8b-instruct", 0.0002),
            ("llama-3-70b-instruct", 0.001),
            ("mixtral-8x7b-instruct", 0.0006),
        ];
        rates
            .iter()
            .map(|&(model, rate)| {
                (
                    model.to_string(),
                    ModelPricing {
                        model: model.to_string(),
                        input_cost_per_1k: rate,
                        output_cost_per_1k: rate,
                        currency: "USD".to_string(),
                        updated_at: chrono::Utc::now(),
                    },
                )
            })
            .collect()
    }
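    /// Converts OpenAI-style chat messages into Perplexity's JSON message
    /// format, flattening multi-part content into a single text string.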
fn convert_messages_to_perplexity(&self, messages: &[ChatMessage]) -> Vec<serde_json::Value> {
messages
.iter()
.map(|message| {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "assistant",
MessageRole::Function => "function", };
let content = match &message.content {
Some(MessageContent::Text(text)) => text.clone(),
Some(MessageContent::Parts(parts)) => {
parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>()
.join(" ")
}
None => String::new(),
};
json!({
"role": role,
"content": content
})
})
.collect()
}
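    /// Maps a raw Perplexity JSON response onto the OpenAI-compatible
    /// [`ChatCompletionResponse`] shape, preserving choices and token usage.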
fn convert_perplexity_response_to_openai(
&self,
perplexity_response: serde_json::Value,
model: &str,
) -> Result<ChatCompletionResponse> {
let choices = perplexity_response
.get("choices")
.and_then(|c| c.as_array())
.ok_or_else(|| ProviderError::Parsing("No choices in response".to_string()))?;
let openai_choices: Result<Vec<ChatCompletionChoice>> = choices
.iter()
.enumerate()
.map(|(index, choice)| {
let message = choice
.get("message")
.ok_or_else(|| ProviderError::Parsing("No message in choice".to_string()))?;
let role = message
.get("role")
.and_then(|r| r.as_str())
.map(|r| match r {
"assistant" => MessageRole::Assistant,
"user" => MessageRole::User,
"system" => MessageRole::System,
_ => MessageRole::Assistant,
})
.unwrap_or(MessageRole::Assistant);
let content = message
.get("content")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
let finish_reason = choice
.get("finish_reason")
.and_then(|fr| fr.as_str())
.map(|fr| fr.to_string());
Ok(ChatCompletionChoice {
index: index as u32,
message: ChatMessage {
role,
content: Some(MessageContent::Text(content)),
name: None,
function_call: None,
                        tool_calls: None,
                        tool_call_id: None,
audio: None,
},
finish_reason,
logprobs: None,
})
})
.collect();
let usage = perplexity_response.get("usage").map(|u| Usage {
prompt_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
completion_tokens: u
.get("completion_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
total_tokens: u.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
prompt_tokens_details: None,
completion_tokens_details: None,
});
Ok(ChatCompletionResponse {
            id: perplexity_response
                .get("id")
                .and_then(|id| id.as_str())
                .map(str::to_string)
                .unwrap_or_else(|| format!("chatcmpl-perplexity-{}", uuid::Uuid::new_v4())),
object: "chat.completion".to_string(),
created: perplexity_response
.get("created")
.and_then(|c| c.as_u64())
.unwrap_or_else(|| chrono::Utc::now().timestamp() as u64),
model: model.to_string(),
            choices: openai_choices?
                .into_iter()
                .map(|choice| ChatChoice {
                    index: choice.index,
                    message: choice.message,
                    // Perplexity does not return logprobs and the intermediate
                    // choice never carries any, so this is always None.
                    logprobs: None,
                    finish_reason: choice.finish_reason,
                })
                .collect(),
usage,
system_fingerprint: None,
})
}
}
#[async_trait]
impl Provider for PerplexityProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Custom("perplexity".to_string())
}
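    // Model support is heuristic: anything configured on the base provider,
    // plus ids matching Perplexity's known model families.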
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model)
|| model.contains("sonar")
|| model.contains("llama-3")
|| model.contains("mixtral")
}
    async fn supports_images(&self) -> bool {
        false
    }
    async fn supports_embeddings(&self) -> bool {
        false
    }
    async fn supports_streaming(&self) -> bool {
        true
    }
async fn list_models(&self) -> Result<Vec<Model>> {
let known_models = vec![
"llama-3-sonar-small-32k-chat",
"llama-3-sonar-small-32k-online",
"llama-3-sonar-large-32k-chat",
"llama-3-sonar-large-32k-online",
"llama-3-8b-instruct",
"llama-3-70b-instruct",
"mixtral-8x7b-instruct",
];
let models = known_models
.into_iter()
.map(|model| Model {
id: model.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "perplexity".to_string(),
})
.collect();
Ok(models)
}
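    // Perplexity has no dedicated health endpoint, so probe the chat API with
    // a one-token request against a small model.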
async fn health_check(&self) -> Result<()> {
debug!("Performing Perplexity AI health check");
let test_messages = vec![ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello".to_string())),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
}];
let body = json!({
"model": "llama-3-8b-instruct",
"messages": self.convert_messages_to_perplexity(&test_messages),
"max_tokens": 1
});
let url = format!("{}/chat/completions", self.base.base_url);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.base.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
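        // A 400 response still proves the endpoint is reachable and the API
        // key was accepted, so it is treated as healthy.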
if response.status().is_success() || response.status().as_u16() == 400 {
Ok(())
} else {
Err(
ProviderError::Unknown(format!("Health check failed: {}", response.status()))
.into(),
)
}
}
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("Perplexity AI chat completion for model: {}", request.model);
let messages = self.convert_messages_to_perplexity(&request.messages);
let mut body = json!({
"model": request.model,
"messages": messages
});
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
body["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
body["top_p"] = json!(top_p);
}
if let Some(stream) = request.stream {
body["stream"] = json!(stream);
}
if let Some(presence_penalty) = request.presence_penalty {
body["presence_penalty"] = json!(presence_penalty);
}
if let Some(frequency_penalty) = request.frequency_penalty {
body["frequency_penalty"] = json!(frequency_penalty);
}
let url = format!("{}/chat/completions", self.base.base_url);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.base.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 => ProviderError::Authentication(error_text),
429 => ProviderError::RateLimit(error_text),
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
let perplexity_response: serde_json::Value =
self.base.parse_json_response(response).await?;
self.convert_perplexity_response_to_openai(perplexity_response, &request.model)
}
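    // Perplexity only exposes a chat endpoint; legacy text completions are
    // adapted by wrapping the prompt in a single user message.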
async fn completion(
&self,
request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
let chat_request = ChatCompletionRequest {
model: request.model.clone(),
messages: vec![ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text(request.prompt)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
}],
max_tokens: request.max_tokens,
max_completion_tokens: None,
temperature: request.temperature.map(|t| t as f32),
top_p: request.top_p.map(|t| t as f32),
n: request.n,
stream: request.stream,
stream_options: None,
stop: request.stop,
presence_penalty: request.presence_penalty.map(|p| p as f32),
frequency_penalty: request.frequency_penalty.map(|f| f as f32),
logit_bias: request
.logit_bias
.map(|lb| lb.into_iter().map(|(k, v)| (k, v as f32)).collect()),
user: request.user,
functions: None,
function_call: None,
tools: None,
tool_choice: None,
response_format: None,
seed: None,
logprobs: None,
top_logprobs: None,
modalities: None,
audio: None,
};
        let chat_response = self.chat_completion(chat_request, _context).await?;
        // Avoid panicking when the upstream response contains no choices.
        let first_choice = chat_response.choices.first().ok_or_else(|| {
            ProviderError::Parsing("No choices in chat completion response".to_string())
        })?;
        let text = match &first_choice.message.content {
            Some(MessageContent::Text(text)) => text.clone(),
            Some(MessageContent::Parts(parts)) => parts
                .iter()
                .filter_map(|part| match part {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<String>>()
                .join(" "),
            None => String::new(),
        };
        let finish_reason = first_choice.finish_reason.clone();
Ok(CompletionResponse {
id: chat_response.id.replace("chatcmpl", "cmpl"),
object: "text_completion".to_string(),
created: chat_response.created,
model: request.model,
choices: vec![CompletionChoice {
text,
index: 0,
logprobs: None,
                finish_reason,
}],
usage: chat_response.usage,
})
}
async fn embedding(
&self,
_request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
Err(
ProviderError::InvalidRequest("Embeddings not supported by Perplexity AI".to_string())
.into(),
)
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
Err(ProviderError::InvalidRequest(
"Image generation not supported by Perplexity AI".to_string(),
)
.into())
}
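    // Falls back to a mid-range default rate for model ids that are not in
    // the static pricing cache.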
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.0005,
output_cost_per_1k: 0.0005,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
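    // Cost = tokens / 1000 * per-1k rate, summed over input and output.
    // Illustrative example: 1500 input + 500 output tokens on
    // llama-3-70b-instruct costs 1.5 * 0.001 + 0.5 * 0.001 = 0.002 USD.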
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}