1use crate::client::BearerAuth;
15use crate::completion::CompletionRequest;
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 client::{
22 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
23 },
24 completion::{self, CompletionError, MessageError, message},
25 http_client::{self, HttpClientExt},
26};
27use bytes::Bytes;
28use serde::{Deserialize, Serialize};
29use tracing::{Instrument, info_span};
30
/// Root URL that every Perplexity request path is resolved against.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
35
/// Zero-sized marker type identifying Perplexity as the provider of a client.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityExt;
38
/// Zero-sized marker type selecting the Perplexity flavour of the client builder.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityBuilder;
41
42type PerplexityApiKey = BearerAuth;
43
44impl Provider for PerplexityExt {
45 type Builder = PerplexityBuilder;
46
47 const VERIFY_PATH: &'static str = "";
49}
50
51impl<H> Capabilities<H> for PerplexityExt {
52 type Completion = Capable<CompletionModel<H>>;
53 type Transcription = Nothing;
54 type Embeddings = Nothing;
55 type ModelListing = Nothing;
56 #[cfg(feature = "image")]
57 type ImageGeneration = Nothing;
58
59 #[cfg(feature = "audio")]
60 type AudioGeneration = Nothing;
61}
62
63impl DebugExt for PerplexityExt {}
64
65impl ProviderBuilder for PerplexityBuilder {
66 type Extension<H>
67 = PerplexityExt
68 where
69 H: HttpClientExt;
70 type ApiKey = PerplexityApiKey;
71
72 const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;
73
74 fn build<H>(
75 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
76 ) -> http_client::Result<Self::Extension<H>>
77 where
78 H: HttpClientExt,
79 {
80 Ok(PerplexityExt)
81 }
82}
83
84pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
85pub type ClientBuilder<H = crate::markers::Missing> =
86 client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
87
88impl ProviderClient for Client {
89 type Input = String;
90 type Error = crate::client::ProviderClientError;
91
92 fn from_env() -> Result<Self, Self::Error> {
94 let api_key = crate::client::required_env_var("PERPLEXITY_API_KEY")?;
95 Self::new(&api_key).map_err(Into::into)
96 }
97
98 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
99 Self::new(&input).map_err(Into::into)
100 }
101}
102
103#[derive(Debug, Deserialize)]
104struct ApiErrorResponse {
105 message: String,
106}
107
108#[derive(Debug, Deserialize)]
109#[serde(untagged)]
110enum ApiResponse<T> {
111 Ok(T),
112 Err(ApiErrorResponse),
113}
114
/// Model identifier for Perplexity's `sonar-pro` model.
///
/// The API spells the identifier with a hyphen; the underscore appears only
/// in the Rust constant name.
pub const SONAR_PRO: &str = "sonar-pro";
/// Model identifier for Perplexity's `sonar` model.
pub const SONAR: &str = "sonar";
121
122#[derive(Debug, Deserialize, Serialize)]
123pub struct CompletionResponse {
124 pub id: String,
125 pub model: String,
126 pub object: String,
127 pub created: u64,
128 #[serde(default)]
129 pub choices: Vec<Choice>,
130 pub usage: Usage,
131}
132
133#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
134pub struct Message {
135 pub role: Role,
136 pub content: String,
137}
138
139#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
140#[serde(rename_all = "lowercase")]
141pub enum Role {
142 System,
143 User,
144 Assistant,
145}
146
147#[derive(Deserialize, Debug, Serialize)]
148pub struct Delta {
149 pub role: Role,
150 pub content: String,
151}
152
153#[derive(Deserialize, Debug, Serialize)]
154pub struct Choice {
155 pub index: usize,
156 pub finish_reason: String,
157 pub message: Message,
158 pub delta: Delta,
159}
160
161#[derive(Deserialize, Debug, Serialize)]
162pub struct Usage {
163 pub prompt_tokens: u32,
164 pub completion_tokens: u32,
165 pub total_tokens: u32,
166}
167
168impl std::fmt::Display for Usage {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 write!(
171 f,
172 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
173 self.prompt_tokens, self.completion_tokens, self.total_tokens
174 )
175 }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179 type Error = CompletionError;
180
181 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182 let choice = response.choices.first().ok_or_else(|| {
183 CompletionError::ResponseError("Response contained no choices".to_owned())
184 })?;
185
186 match &choice.message {
187 Message {
188 role: Role::Assistant,
189 content,
190 } => Ok(completion::CompletionResponse {
191 choice: OneOrMany::one(content.clone().into()),
192 usage: completion::Usage {
193 input_tokens: response.usage.prompt_tokens as u64,
194 output_tokens: response.usage.completion_tokens as u64,
195 total_tokens: response.usage.total_tokens as u64,
196 cached_input_tokens: 0,
197 cache_creation_input_tokens: 0,
198 tool_use_prompt_tokens: 0,
199 reasoning_tokens: 0,
200 },
201 raw_response: response,
202 message_id: None,
203 }),
204 _ => Err(CompletionError::ResponseError(
205 "Response contained no assistant message".to_owned(),
206 )),
207 }
208 }
209}
210
211#[derive(Debug, Serialize, Deserialize)]
212pub(super) struct PerplexityCompletionRequest {
213 model: String,
214 pub messages: Vec<Message>,
215 #[serde(skip_serializing_if = "Option::is_none")]
216 pub temperature: Option<f64>,
217 #[serde(skip_serializing_if = "Option::is_none")]
218 pub max_tokens: Option<u64>,
219 #[serde(flatten, skip_serializing_if = "Option::is_none")]
220 additional_params: Option<serde_json::Value>,
221 pub stream: bool,
222}
223
224impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
225 type Error = CompletionError;
226
227 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
228 if req.output_schema.is_some() {
229 tracing::warn!("Structured outputs currently not supported for Perplexity");
230 }
231 let model = req.model.clone().unwrap_or_else(|| model.to_string());
232 let mut partial_history = vec![];
233 if let Some(docs) = req.normalized_documents() {
234 partial_history.push(docs);
235 }
236 partial_history.extend(req.chat_history);
237
238 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
240 vec![Message {
241 role: Role::System,
242 content: preamble,
243 }]
244 });
245
246 full_history.extend(
248 partial_history
249 .into_iter()
250 .map(message::Message::try_into)
251 .collect::<Result<Vec<Message>, _>>()?,
252 );
253
254 Ok(Self {
255 model: model.to_string(),
256 messages: full_history,
257 temperature: req.temperature,
258 max_tokens: req.max_tokens,
259 additional_params: req.additional_params,
260 stream: false,
261 })
262 }
263}
264
265#[derive(Clone)]
266pub struct CompletionModel<T = reqwest::Client> {
267 client: Client<T>,
268 pub model: String,
269}
270
271impl<T> CompletionModel<T> {
272 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
273 Self {
274 client,
275 model: model.into(),
276 }
277 }
278}
279
280impl TryFrom<message::Message> for Message {
281 type Error = MessageError;
282
283 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
284 Ok(match message {
285 message::Message::System { content } => Message {
286 role: Role::System,
287 content,
288 },
289 message::Message::User { content } => {
290 let collapsed_content = content
291 .into_iter()
292 .map(|content| match content {
293 message::UserContent::Text(message::Text { text, .. }) => Ok(text),
294 _ => Err(MessageError::ConversionError(
295 "Only text content is supported by Perplexity".to_owned(),
296 )),
297 })
298 .collect::<Result<Vec<_>, _>>()?
299 .join("\n");
300
301 Message {
302 role: Role::User,
303 content: collapsed_content,
304 }
305 }
306
307 message::Message::Assistant { content, .. } => {
308 let collapsed_content = content
309 .into_iter()
310 .map(|content| {
311 Ok(match content {
312 message::AssistantContent::Text(message::Text { text, .. }) => text,
313 _ => return Err(MessageError::ConversionError(
314 "Only text assistant message content is supported by Perplexity"
315 .to_owned(),
316 )),
317 })
318 })
319 .collect::<Result<Vec<_>, _>>()?
320 .join("\n");
321
322 Message {
323 role: Role::Assistant,
324 content: collapsed_content,
325 }
326 }
327 })
328 }
329}
330
331impl From<Message> for message::Message {
332 fn from(message: Message) -> Self {
333 match message.role {
334 Role::User => message::Message::user(message.content),
335 Role::Assistant => message::Message::assistant(message.content),
336
337 Role::System => message::Message::user(message.content),
340 }
341 }
342}
343
344impl<T> completion::CompletionModel for CompletionModel<T>
345where
346 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
347{
348 type Response = CompletionResponse;
349 type StreamingResponse = openai::StreamingCompletionResponse;
350
351 type Client = Client<T>;
352
353 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
354 Self::new(client.clone(), model)
355 }
356
357 async fn completion(
358 &self,
359 completion_request: completion::CompletionRequest,
360 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
361 let span = if tracing::Span::current().is_disabled() {
362 info_span!(
363 target: "rig::completions",
364 "chat",
365 gen_ai.operation.name = "chat",
366 gen_ai.provider.name = "perplexity",
367 gen_ai.request.model = self.model,
368 gen_ai.system_instructions = tracing::field::Empty,
369 gen_ai.response.id = tracing::field::Empty,
370 gen_ai.response.model = tracing::field::Empty,
371 gen_ai.usage.output_tokens = tracing::field::Empty,
372 gen_ai.usage.input_tokens = tracing::field::Empty,
373 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
374 )
375 } else {
376 tracing::Span::current()
377 };
378
379 span.record("gen_ai.system_instructions", &completion_request.preamble);
380
381 if completion_request.tool_choice.is_some() {
382 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
383 }
384
385 if !completion_request.tools.is_empty() {
386 tracing::warn!("WARNING: `tools` not supported on Perplexity");
387 }
388 let request =
389 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
390
391 if tracing::enabled!(tracing::Level::TRACE) {
392 tracing::trace!(target: "rig::completions",
393 "Perplexity completion request: {}",
394 serde_json::to_string_pretty(&request)?
395 );
396 }
397
398 let body = serde_json::to_vec(&request)?;
399
400 let req = self
401 .client
402 .post("/v1/chat/completions")?
403 .body(body)
404 .map_err(http_client::Error::from)?;
405
406 let async_block = async move {
407 let response = self.client.send::<_, Bytes>(req).await?;
408
409 let status = response.status();
410 let response_body = response.into_body().into_future().await?.to_vec();
411
412 if status.is_success() {
413 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
414 ApiResponse::Ok(response) => {
415 let span = tracing::Span::current();
416 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
417 span.record(
418 "gen_ai.usage.output_tokens",
419 response.usage.completion_tokens,
420 );
421 span.record("gen_ai.response.id", response.id.to_string());
422 span.record("gen_ai.response.model", response.model.to_string());
423 if tracing::enabled!(tracing::Level::TRACE) {
424 tracing::trace!(target: "rig::responses",
425 "Perplexity completion response: {}",
426 serde_json::to_string_pretty(&response)?
427 );
428 }
429 Ok(response.try_into()?)
430 }
431 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
432 }
433 } else {
434 Err(CompletionError::ProviderError(
435 String::from_utf8_lossy(&response_body).to_string(),
436 ))
437 }
438 };
439
440 async_block.instrument(span).await
441 }
442
443 async fn stream(
444 &self,
445 completion_request: completion::CompletionRequest,
446 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
447 let span = if tracing::Span::current().is_disabled() {
448 info_span!(
449 target: "rig::completions",
450 "chat_streaming",
451 gen_ai.operation.name = "chat_streaming",
452 gen_ai.provider.name = "perplexity",
453 gen_ai.request.model = self.model,
454 gen_ai.system_instructions = tracing::field::Empty,
455 gen_ai.response.id = tracing::field::Empty,
456 gen_ai.response.model = tracing::field::Empty,
457 gen_ai.usage.output_tokens = tracing::field::Empty,
458 gen_ai.usage.input_tokens = tracing::field::Empty,
459 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
460 )
461 } else {
462 tracing::Span::current()
463 };
464
465 span.record("gen_ai.system_instructions", &completion_request.preamble);
466
467 if completion_request.tool_choice.is_some() {
468 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
469 }
470
471 if !completion_request.tools.is_empty() {
472 tracing::warn!("WARNING: `tools` not supported on Perplexity");
473 }
474
475 let mut request =
476 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
477 request.stream = true;
478
479 if tracing::enabled!(tracing::Level::TRACE) {
480 tracing::trace!(target: "rig::completions",
481 "Perplexity streaming completion request: {}",
482 serde_json::to_string_pretty(&request)?
483 );
484 }
485
486 let body = serde_json::to_vec(&request)?;
487
488 let req = self
489 .client
490 .post("/chat/completions")?
491 .body(body)
492 .map_err(http_client::Error::from)?;
493
494 send_compatible_streaming_request(self.client.clone(), req)
495 .instrument(span)
496 .await
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object with `role` and `content` parses into a [`Message`].
    #[test]
    fn test_deserialize_message() {
        let raw = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let Message { role, content } = serde_json::from_str(raw).unwrap();
        assert_eq!(role, Role::User);
        assert_eq!(content, "Hello, how can I help you?");
    }

    /// Serialization emits `role` before `content`, with the role in lowercase.
    #[test]
    fn test_serialize_message() {
        let outgoing = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        let encoded = serde_json::to_string(&outgoing).unwrap();
        assert_eq!(
            encoded,
            r#"{"role":"assistant","content":"I am here to assist you."}"#
        );
    }

    /// User and assistant messages survive a round trip through [`Message`].
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::user("User message");
        let assistant_message = message::Message::assistant("Assistant message");

        let provider_user = Message::try_from(user_message.clone()).unwrap();
        let provider_assistant = Message::try_from(assistant_message.clone()).unwrap();

        assert_eq!(provider_user.role, Role::User);
        assert_eq!(provider_user.content, "User message");

        assert_eq!(provider_assistant.role, Role::Assistant);
        assert_eq!(provider_assistant.content, "Assistant message");

        let restored_user = message::Message::from(provider_user);
        let restored_assistant = message::Message::from(provider_assistant);

        assert_eq!(user_message, restored_user);
        assert_eq!(assistant_message, restored_assistant);
    }

    /// Both construction paths accept a plain string API key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::perplexity::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::perplexity::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}
559}