Skip to main content

rig_core/providers/
perplexity.rs

1//! Perplexity API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::perplexity};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = perplexity::Client::new("YOUR_API_KEY")?;
9//!
10//! let sonar = client.completion_model(perplexity::SONAR);
11//! # Ok(())
12//! # }
13//! ```
14use crate::client::BearerAuth;
15use crate::completion::CompletionRequest;
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20    OneOrMany,
21    client::{
22        self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
23    },
24    completion::{self, CompletionError, MessageError, message},
25    http_client::{self, HttpClientExt},
26};
27use bytes::Bytes;
28use serde::{Deserialize, Serialize};
29use tracing::{Instrument, info_span};
30
// ================================================================
// Main Perplexity Client
// ================================================================
34const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
35
36#[derive(Debug, Default, Clone, Copy)]
37pub struct PerplexityExt;
38
39#[derive(Debug, Default, Clone, Copy)]
40pub struct PerplexityBuilder;
41
42type PerplexityApiKey = BearerAuth;
43
44impl Provider for PerplexityExt {
45    type Builder = PerplexityBuilder;
46
47    // There is currently no way to verify a perplexity api key without consuming tokens
48    const VERIFY_PATH: &'static str = "";
49}
50
51impl<H> Capabilities<H> for PerplexityExt {
52    type Completion = Capable<CompletionModel<H>>;
53    type Transcription = Nothing;
54    type Embeddings = Nothing;
55    type ModelListing = Nothing;
56    #[cfg(feature = "image")]
57    type ImageGeneration = Nothing;
58
59    #[cfg(feature = "audio")]
60    type AudioGeneration = Nothing;
61}
62
63impl DebugExt for PerplexityExt {}
64
65impl ProviderBuilder for PerplexityBuilder {
66    type Extension<H>
67        = PerplexityExt
68    where
69        H: HttpClientExt;
70    type ApiKey = PerplexityApiKey;
71
72    const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;
73
74    fn build<H>(
75        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
76    ) -> http_client::Result<Self::Extension<H>>
77    where
78        H: HttpClientExt,
79    {
80        Ok(PerplexityExt)
81    }
82}
83
84pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
85pub type ClientBuilder<H = crate::markers::Missing> =
86    client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
87
88impl ProviderClient for Client {
89    type Input = String;
90    type Error = crate::client::ProviderClientError;
91
92    /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable.
93    fn from_env() -> Result<Self, Self::Error> {
94        let api_key = crate::client::required_env_var("PERPLEXITY_API_KEY")?;
95        Self::new(&api_key).map_err(Into::into)
96    }
97
98    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
99        Self::new(&input).map_err(Into::into)
100    }
101}
102
103#[derive(Debug, Deserialize)]
104struct ApiErrorResponse {
105    message: String,
106}
107
108#[derive(Debug, Deserialize)]
109#[serde(untagged)]
110enum ApiResponse<T> {
111    Ok(T),
112    Err(ApiErrorResponse),
113}
114
115// ================================================================
116// Perplexity Completion API
117// ================================================================
118
119pub const SONAR_PRO: &str = "sonar_pro";
120pub const SONAR: &str = "sonar";
121
122#[derive(Debug, Deserialize, Serialize)]
123pub struct CompletionResponse {
124    pub id: String,
125    pub model: String,
126    pub object: String,
127    pub created: u64,
128    #[serde(default)]
129    pub choices: Vec<Choice>,
130    pub usage: Usage,
131}
132
133#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
134pub struct Message {
135    pub role: Role,
136    pub content: String,
137}
138
139#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
140#[serde(rename_all = "lowercase")]
141pub enum Role {
142    System,
143    User,
144    Assistant,
145}
146
147#[derive(Deserialize, Debug, Serialize)]
148pub struct Delta {
149    pub role: Role,
150    pub content: String,
151}
152
153#[derive(Deserialize, Debug, Serialize)]
154pub struct Choice {
155    pub index: usize,
156    pub finish_reason: String,
157    pub message: Message,
158    pub delta: Delta,
159}
160
161#[derive(Deserialize, Debug, Serialize)]
162pub struct Usage {
163    pub prompt_tokens: u32,
164    pub completion_tokens: u32,
165    pub total_tokens: u32,
166}
167
168impl std::fmt::Display for Usage {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        write!(
171            f,
172            "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
173            self.prompt_tokens, self.completion_tokens, self.total_tokens
174        )
175    }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179    type Error = CompletionError;
180
181    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182        let choice = response.choices.first().ok_or_else(|| {
183            CompletionError::ResponseError("Response contained no choices".to_owned())
184        })?;
185
186        match &choice.message {
187            Message {
188                role: Role::Assistant,
189                content,
190            } => Ok(completion::CompletionResponse {
191                choice: OneOrMany::one(content.clone().into()),
192                usage: completion::Usage {
193                    input_tokens: response.usage.prompt_tokens as u64,
194                    output_tokens: response.usage.completion_tokens as u64,
195                    total_tokens: response.usage.total_tokens as u64,
196                    cached_input_tokens: 0,
197                    cache_creation_input_tokens: 0,
198                    tool_use_prompt_tokens: 0,
199                    reasoning_tokens: 0,
200                },
201                raw_response: response,
202                message_id: None,
203            }),
204            _ => Err(CompletionError::ResponseError(
205                "Response contained no assistant message".to_owned(),
206            )),
207        }
208    }
209}
210
211#[derive(Debug, Serialize, Deserialize)]
212pub(super) struct PerplexityCompletionRequest {
213    model: String,
214    pub messages: Vec<Message>,
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub temperature: Option<f64>,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub max_tokens: Option<u64>,
219    #[serde(flatten, skip_serializing_if = "Option::is_none")]
220    additional_params: Option<serde_json::Value>,
221    pub stream: bool,
222}
223
224impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
225    type Error = CompletionError;
226
227    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
228        if req.output_schema.is_some() {
229            tracing::warn!("Structured outputs currently not supported for Perplexity");
230        }
231        let model = req.model.clone().unwrap_or_else(|| model.to_string());
232        let mut partial_history = vec![];
233        if let Some(docs) = req.normalized_documents() {
234            partial_history.push(docs);
235        }
236        partial_history.extend(req.chat_history);
237
238        // Initialize full history with preamble (or empty if non-existent)
239        let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
240            vec![Message {
241                role: Role::System,
242                content: preamble,
243            }]
244        });
245
246        // Convert and extend the rest of the history
247        full_history.extend(
248            partial_history
249                .into_iter()
250                .map(message::Message::try_into)
251                .collect::<Result<Vec<Message>, _>>()?,
252        );
253
254        Ok(Self {
255            model: model.to_string(),
256            messages: full_history,
257            temperature: req.temperature,
258            max_tokens: req.max_tokens,
259            additional_params: req.additional_params,
260            stream: false,
261        })
262    }
263}
264
265#[derive(Clone)]
266pub struct CompletionModel<T = reqwest::Client> {
267    client: Client<T>,
268    pub model: String,
269}
270
271impl<T> CompletionModel<T> {
272    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
273        Self {
274            client,
275            model: model.into(),
276        }
277    }
278}
279
280impl TryFrom<message::Message> for Message {
281    type Error = MessageError;
282
283    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
284        Ok(match message {
285            message::Message::System { content } => Message {
286                role: Role::System,
287                content,
288            },
289            message::Message::User { content } => {
290                let collapsed_content = content
291                    .into_iter()
292                    .map(|content| match content {
293                        message::UserContent::Text(message::Text { text, .. }) => Ok(text),
294                        _ => Err(MessageError::ConversionError(
295                            "Only text content is supported by Perplexity".to_owned(),
296                        )),
297                    })
298                    .collect::<Result<Vec<_>, _>>()?
299                    .join("\n");
300
301                Message {
302                    role: Role::User,
303                    content: collapsed_content,
304                }
305            }
306
307            message::Message::Assistant { content, .. } => {
308                let collapsed_content = content
309                    .into_iter()
310                    .map(|content| {
311                        Ok(match content {
312                            message::AssistantContent::Text(message::Text { text, .. }) => text,
313                            _ => return Err(MessageError::ConversionError(
314                                "Only text assistant message content is supported by Perplexity"
315                                    .to_owned(),
316                            )),
317                        })
318                    })
319                    .collect::<Result<Vec<_>, _>>()?
320                    .join("\n");
321
322                Message {
323                    role: Role::Assistant,
324                    content: collapsed_content,
325                }
326            }
327        })
328    }
329}
330
331impl From<Message> for message::Message {
332    fn from(message: Message) -> Self {
333        match message.role {
334            Role::User => message::Message::user(message.content),
335            Role::Assistant => message::Message::assistant(message.content),
336
337            // System messages get coerced into user messages for ease of error handling.
338            // They should be handled on the outside of `Message` conversions via the preamble.
339            Role::System => message::Message::user(message.content),
340        }
341    }
342}
343
344impl<T> completion::CompletionModel for CompletionModel<T>
345where
346    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
347{
348    type Response = CompletionResponse;
349    type StreamingResponse = openai::StreamingCompletionResponse;
350
351    type Client = Client<T>;
352
353    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
354        Self::new(client.clone(), model)
355    }
356
357    async fn completion(
358        &self,
359        completion_request: completion::CompletionRequest,
360    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
361        let span = if tracing::Span::current().is_disabled() {
362            info_span!(
363                target: "rig::completions",
364                "chat",
365                gen_ai.operation.name = "chat",
366                gen_ai.provider.name = "perplexity",
367                gen_ai.request.model = self.model,
368                gen_ai.system_instructions = tracing::field::Empty,
369                gen_ai.response.id = tracing::field::Empty,
370                gen_ai.response.model = tracing::field::Empty,
371                gen_ai.usage.output_tokens = tracing::field::Empty,
372                gen_ai.usage.input_tokens = tracing::field::Empty,
373                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
374            )
375        } else {
376            tracing::Span::current()
377        };
378
379        span.record("gen_ai.system_instructions", &completion_request.preamble);
380
381        if completion_request.tool_choice.is_some() {
382            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
383        }
384
385        if !completion_request.tools.is_empty() {
386            tracing::warn!("WARNING: `tools` not supported on Perplexity");
387        }
388        let request =
389            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
390
391        if tracing::enabled!(tracing::Level::TRACE) {
392            tracing::trace!(target: "rig::completions",
393                "Perplexity completion request: {}",
394                serde_json::to_string_pretty(&request)?
395            );
396        }
397
398        let body = serde_json::to_vec(&request)?;
399
400        let req = self
401            .client
402            .post("/v1/chat/completions")?
403            .body(body)
404            .map_err(http_client::Error::from)?;
405
406        let async_block = async move {
407            let response = self.client.send::<_, Bytes>(req).await?;
408
409            let status = response.status();
410            let response_body = response.into_body().into_future().await?.to_vec();
411
412            if status.is_success() {
413                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
414                    ApiResponse::Ok(response) => {
415                        let span = tracing::Span::current();
416                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
417                        span.record(
418                            "gen_ai.usage.output_tokens",
419                            response.usage.completion_tokens,
420                        );
421                        span.record("gen_ai.response.id", response.id.to_string());
422                        span.record("gen_ai.response.model", response.model.to_string());
423                        if tracing::enabled!(tracing::Level::TRACE) {
424                            tracing::trace!(target: "rig::responses",
425                                "Perplexity completion response: {}",
426                                serde_json::to_string_pretty(&response)?
427                            );
428                        }
429                        Ok(response.try_into()?)
430                    }
431                    ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
432                }
433            } else {
434                Err(CompletionError::ProviderError(
435                    String::from_utf8_lossy(&response_body).to_string(),
436                ))
437            }
438        };
439
440        async_block.instrument(span).await
441    }
442
443    async fn stream(
444        &self,
445        completion_request: completion::CompletionRequest,
446    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
447        let span = if tracing::Span::current().is_disabled() {
448            info_span!(
449                target: "rig::completions",
450                "chat_streaming",
451                gen_ai.operation.name = "chat_streaming",
452                gen_ai.provider.name = "perplexity",
453                gen_ai.request.model = self.model,
454                gen_ai.system_instructions = tracing::field::Empty,
455                gen_ai.response.id = tracing::field::Empty,
456                gen_ai.response.model = tracing::field::Empty,
457                gen_ai.usage.output_tokens = tracing::field::Empty,
458                gen_ai.usage.input_tokens = tracing::field::Empty,
459                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
460            )
461        } else {
462            tracing::Span::current()
463        };
464
465        span.record("gen_ai.system_instructions", &completion_request.preamble);
466
467        if completion_request.tool_choice.is_some() {
468            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
469        }
470
471        if !completion_request.tools.is_empty() {
472            tracing::warn!("WARNING: `tools` not supported on Perplexity");
473        }
474
475        let mut request =
476            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
477        request.stream = true;
478
479        if tracing::enabled!(tracing::Level::TRACE) {
480            tracing::trace!(target: "rig::completions",
481                "Perplexity streaming completion request: {}",
482                serde_json::to_string_pretty(&request)?
483            );
484        }
485
486        let body = serde_json::to_vec(&request)?;
487
488        let req = self
489            .client
490            .post("/chat/completions")?
491            .body(body)
492            .map_err(http_client::Error::from)?;
493
494        send_compatible_streaming_request(self.client.clone(), req)
495            .instrument(span)
496            .await
497    }
498}
499
500#[cfg(test)]
501mod tests {
502    use super::*;
503
504    #[test]
505    fn test_deserialize_message() {
506        let json_data = r#"
507        {
508            "role": "user",
509            "content": "Hello, how can I help you?"
510        }
511        "#;
512
513        let message: Message = serde_json::from_str(json_data).unwrap();
514        assert_eq!(message.role, Role::User);
515        assert_eq!(message.content, "Hello, how can I help you?");
516    }
517
518    #[test]
519    fn test_serialize_message() {
520        let message = Message {
521            role: Role::Assistant,
522            content: "I am here to assist you.".to_string(),
523        };
524
525        let json_data = serde_json::to_string(&message).unwrap();
526        let expected_json = r#"{"role":"assistant","content":"I am here to assist you."}"#;
527        assert_eq!(json_data, expected_json);
528    }
529
530    #[test]
531    fn test_message_to_message_conversion() {
532        let user_message = message::Message::user("User message");
533        let assistant_message = message::Message::assistant("Assistant message");
534
535        let converted_user_message: Message = user_message.clone().try_into().unwrap();
536        let converted_assistant_message: Message = assistant_message.clone().try_into().unwrap();
537
538        assert_eq!(converted_user_message.role, Role::User);
539        assert_eq!(converted_user_message.content, "User message");
540
541        assert_eq!(converted_assistant_message.role, Role::Assistant);
542        assert_eq!(converted_assistant_message.content, "Assistant message");
543
544        let back_to_user_message: message::Message = converted_user_message.into();
545        let back_to_assistant_message: message::Message = converted_assistant_message.into();
546
547        assert_eq!(user_message, back_to_user_message);
548        assert_eq!(assistant_message, back_to_assistant_message);
549    }
550    #[test]
551    fn test_client_initialization() {
552        let _client =
553            crate::providers::perplexity::Client::new("dummy-key").expect("Client::new() failed");
554        let _client_from_builder = crate::providers::perplexity::Client::builder()
555            .api_key("dummy-key")
556            .build()
557            .expect("Client::builder() failed");
558    }
559}