Skip to main content

rig/providers/
perplexity.rs

1//! Perplexity API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::perplexity;
6//!
7//! let client = perplexity::Client::new("YOUR_API_KEY").expect("failed to build Perplexity client");
8//!
9//! let sonar = client.completion_model(perplexity::SONAR);
10//! ```
11use crate::client::BearerAuth;
12use crate::completion::CompletionRequest;
13use crate::providers::openai;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    OneOrMany,
18    client::{
19        self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
20    },
21    completion::{self, CompletionError, MessageError, message},
22    http_client::{self, HttpClientExt},
23};
24use bytes::Bytes;
25use serde::{Deserialize, Serialize};
26use tracing::{Instrument, info_span};
27
28// ================================================================
29// Main Perplexity Client
30// ================================================================
31const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
32
/// Zero-sized marker type that identifies Perplexity as a provider.
///
/// It carries no state; the provider-specific behaviour is attached through
/// the trait implementations on this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityExt;
35
/// Zero-sized marker type used as the builder half of the Perplexity provider.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityBuilder;
38
39type PerplexityApiKey = BearerAuth;
40
41impl Provider for PerplexityExt {
42    type Builder = PerplexityBuilder;
43
44    // There is currently no way to verify a perplexity api key without consuming tokens
45    const VERIFY_PATH: &'static str = "";
46}
47
48impl<H> Capabilities<H> for PerplexityExt {
49    type Completion = Capable<CompletionModel<H>>;
50    type Transcription = Nothing;
51    type Embeddings = Nothing;
52    type ModelListing = Nothing;
53    #[cfg(feature = "image")]
54    type ImageGeneration = Nothing;
55
56    #[cfg(feature = "audio")]
57    type AudioGeneration = Nothing;
58}
59
60impl DebugExt for PerplexityExt {}
61
62impl ProviderBuilder for PerplexityBuilder {
63    type Extension<H>
64        = PerplexityExt
65    where
66        H: HttpClientExt;
67    type ApiKey = PerplexityApiKey;
68
69    const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;
70
71    fn build<H>(
72        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
73    ) -> http_client::Result<Self::Extension<H>>
74    where
75        H: HttpClientExt,
76    {
77        Ok(PerplexityExt)
78    }
79}
80
81pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
82pub type ClientBuilder<H = reqwest::Client> =
83    client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
84
85impl ProviderClient for Client {
86    type Input = String;
87
88    /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable.
89    /// Panics if the environment variable is not set.
90    fn from_env() -> Self {
91        let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
92        Self::new(&api_key).unwrap()
93    }
94
95    fn from_val(input: Self::Input) -> Self {
96        Self::new(&input).unwrap()
97    }
98}
99
100#[derive(Debug, Deserialize)]
101struct ApiErrorResponse {
102    message: String,
103}
104
105#[derive(Debug, Deserialize)]
106#[serde(untagged)]
107enum ApiResponse<T> {
108    Ok(T),
109    Err(ApiErrorResponse),
110}
111
112// ================================================================
113// Perplexity Completion API
114// ================================================================
115
/// Model identifier for Perplexity's `sonar-pro` model.
///
/// The API spells this id with a hyphen; an underscore spelling is not a
/// model name the API lists.
pub const SONAR_PRO: &str = "sonar-pro";
/// Model identifier for Perplexity's `sonar` model.
pub const SONAR: &str = "sonar";
118
119#[derive(Debug, Deserialize, Serialize)]
120pub struct CompletionResponse {
121    pub id: String,
122    pub model: String,
123    pub object: String,
124    pub created: u64,
125    #[serde(default)]
126    pub choices: Vec<Choice>,
127    pub usage: Usage,
128}
129
130#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
131pub struct Message {
132    pub role: Role,
133    pub content: String,
134}
135
136#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
137#[serde(rename_all = "lowercase")]
138pub enum Role {
139    System,
140    User,
141    Assistant,
142}
143
144#[derive(Deserialize, Debug, Serialize)]
145pub struct Delta {
146    pub role: Role,
147    pub content: String,
148}
149
150#[derive(Deserialize, Debug, Serialize)]
151pub struct Choice {
152    pub index: usize,
153    pub finish_reason: String,
154    pub message: Message,
155    pub delta: Delta,
156}
157
158#[derive(Deserialize, Debug, Serialize)]
159pub struct Usage {
160    pub prompt_tokens: u32,
161    pub completion_tokens: u32,
162    pub total_tokens: u32,
163}
164
165impl std::fmt::Display for Usage {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        write!(
168            f,
169            "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
170            self.prompt_tokens, self.completion_tokens, self.total_tokens
171        )
172    }
173}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176    type Error = CompletionError;
177
178    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179        let choice = response.choices.first().ok_or_else(|| {
180            CompletionError::ResponseError("Response contained no choices".to_owned())
181        })?;
182
183        match &choice.message {
184            Message {
185                role: Role::Assistant,
186                content,
187            } => Ok(completion::CompletionResponse {
188                choice: OneOrMany::one(content.clone().into()),
189                usage: completion::Usage {
190                    input_tokens: response.usage.prompt_tokens as u64,
191                    output_tokens: response.usage.completion_tokens as u64,
192                    total_tokens: response.usage.total_tokens as u64,
193                    cached_input_tokens: 0,
194                    cache_creation_input_tokens: 0,
195                },
196                raw_response: response,
197                message_id: None,
198            }),
199            _ => Err(CompletionError::ResponseError(
200                "Response contained no assistant message".to_owned(),
201            )),
202        }
203    }
204}
205
206#[derive(Debug, Serialize, Deserialize)]
207pub(super) struct PerplexityCompletionRequest {
208    model: String,
209    pub messages: Vec<Message>,
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub temperature: Option<f64>,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    pub max_tokens: Option<u64>,
214    #[serde(flatten, skip_serializing_if = "Option::is_none")]
215    additional_params: Option<serde_json::Value>,
216    pub stream: bool,
217}
218
219impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
220    type Error = CompletionError;
221
222    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
223        if req.output_schema.is_some() {
224            tracing::warn!("Structured outputs currently not supported for Perplexity");
225        }
226        let model = req.model.clone().unwrap_or_else(|| model.to_string());
227        let mut partial_history = vec![];
228        if let Some(docs) = req.normalized_documents() {
229            partial_history.push(docs);
230        }
231        partial_history.extend(req.chat_history);
232
233        // Initialize full history with preamble (or empty if non-existent)
234        let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
235            vec![Message {
236                role: Role::System,
237                content: preamble,
238            }]
239        });
240
241        // Convert and extend the rest of the history
242        full_history.extend(
243            partial_history
244                .into_iter()
245                .map(message::Message::try_into)
246                .collect::<Result<Vec<Message>, _>>()?,
247        );
248
249        Ok(Self {
250            model: model.to_string(),
251            messages: full_history,
252            temperature: req.temperature,
253            max_tokens: req.max_tokens,
254            additional_params: req.additional_params,
255            stream: false,
256        })
257    }
258}
259
260#[derive(Clone)]
261pub struct CompletionModel<T = reqwest::Client> {
262    client: Client<T>,
263    pub model: String,
264}
265
266impl<T> CompletionModel<T> {
267    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
268        Self {
269            client,
270            model: model.into(),
271        }
272    }
273}
274
275impl TryFrom<message::Message> for Message {
276    type Error = MessageError;
277
278    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
279        Ok(match message {
280            message::Message::System { content } => Message {
281                role: Role::System,
282                content,
283            },
284            message::Message::User { content } => {
285                let collapsed_content = content
286                    .into_iter()
287                    .map(|content| match content {
288                        message::UserContent::Text(message::Text { text }) => Ok(text),
289                        _ => Err(MessageError::ConversionError(
290                            "Only text content is supported by Perplexity".to_owned(),
291                        )),
292                    })
293                    .collect::<Result<Vec<_>, _>>()?
294                    .join("\n");
295
296                Message {
297                    role: Role::User,
298                    content: collapsed_content,
299                }
300            }
301
302            message::Message::Assistant { content, .. } => {
303                let collapsed_content = content
304                    .into_iter()
305                    .map(|content| {
306                        Ok(match content {
307                            message::AssistantContent::Text(message::Text { text }) => text,
308                            _ => return Err(MessageError::ConversionError(
309                                "Only text assistant message content is supported by Perplexity"
310                                    .to_owned(),
311                            )),
312                        })
313                    })
314                    .collect::<Result<Vec<_>, _>>()?
315                    .join("\n");
316
317                Message {
318                    role: Role::Assistant,
319                    content: collapsed_content,
320                }
321            }
322        })
323    }
324}
325
326impl From<Message> for message::Message {
327    fn from(message: Message) -> Self {
328        match message.role {
329            Role::User => message::Message::user(message.content),
330            Role::Assistant => message::Message::assistant(message.content),
331
332            // System messages get coerced into user messages for ease of error handling.
333            // They should be handled on the outside of `Message` conversions via the preamble.
334            Role::System => message::Message::user(message.content),
335        }
336    }
337}
338
339impl<T> completion::CompletionModel for CompletionModel<T>
340where
341    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
342{
343    type Response = CompletionResponse;
344    type StreamingResponse = openai::StreamingCompletionResponse;
345
346    type Client = Client<T>;
347
348    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
349        Self::new(client.clone(), model)
350    }
351
352    async fn completion(
353        &self,
354        completion_request: completion::CompletionRequest,
355    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
356        let span = if tracing::Span::current().is_disabled() {
357            info_span!(
358                target: "rig::completions",
359                "chat",
360                gen_ai.operation.name = "chat",
361                gen_ai.provider.name = "perplexity",
362                gen_ai.request.model = self.model,
363                gen_ai.system_instructions = tracing::field::Empty,
364                gen_ai.response.id = tracing::field::Empty,
365                gen_ai.response.model = tracing::field::Empty,
366                gen_ai.usage.output_tokens = tracing::field::Empty,
367                gen_ai.usage.input_tokens = tracing::field::Empty,
368                gen_ai.usage.cached_tokens = tracing::field::Empty,
369            )
370        } else {
371            tracing::Span::current()
372        };
373
374        span.record("gen_ai.system_instructions", &completion_request.preamble);
375
376        if completion_request.tool_choice.is_some() {
377            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
378        }
379
380        if !completion_request.tools.is_empty() {
381            tracing::warn!("WARNING: `tools` not supported on Perplexity");
382        }
383        let request =
384            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
385
386        if tracing::enabled!(tracing::Level::TRACE) {
387            tracing::trace!(target: "rig::completions",
388                "Perplexity completion request: {}",
389                serde_json::to_string_pretty(&request)?
390            );
391        }
392
393        let body = serde_json::to_vec(&request)?;
394
395        let req = self
396            .client
397            .post("/v1/chat/completions")?
398            .body(body)
399            .map_err(http_client::Error::from)?;
400
401        let async_block = async move {
402            let response = self.client.send::<_, Bytes>(req).await?;
403
404            let status = response.status();
405            let response_body = response.into_body().into_future().await?.to_vec();
406
407            if status.is_success() {
408                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
409                    ApiResponse::Ok(response) => {
410                        let span = tracing::Span::current();
411                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
412                        span.record(
413                            "gen_ai.usage.output_tokens",
414                            response.usage.completion_tokens,
415                        );
416                        span.record("gen_ai.response.id", response.id.to_string());
417                        span.record("gen_ai.response.model_name", response.model.to_string());
418                        if tracing::enabled!(tracing::Level::TRACE) {
419                            tracing::trace!(target: "rig::responses",
420                                "Perplexity completion response: {}",
421                                serde_json::to_string_pretty(&response)?
422                            );
423                        }
424                        Ok(response.try_into()?)
425                    }
426                    ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
427                }
428            } else {
429                Err(CompletionError::ProviderError(
430                    String::from_utf8_lossy(&response_body).to_string(),
431                ))
432            }
433        };
434
435        async_block.instrument(span).await
436    }
437
438    async fn stream(
439        &self,
440        completion_request: completion::CompletionRequest,
441    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
442        let span = if tracing::Span::current().is_disabled() {
443            info_span!(
444                target: "rig::completions",
445                "chat_streaming",
446                gen_ai.operation.name = "chat_streaming",
447                gen_ai.provider.name = "perplexity",
448                gen_ai.request.model = self.model,
449                gen_ai.system_instructions = tracing::field::Empty,
450                gen_ai.response.id = tracing::field::Empty,
451                gen_ai.response.model = tracing::field::Empty,
452                gen_ai.usage.output_tokens = tracing::field::Empty,
453                gen_ai.usage.input_tokens = tracing::field::Empty,
454                gen_ai.usage.cached_tokens = tracing::field::Empty,
455            )
456        } else {
457            tracing::Span::current()
458        };
459
460        span.record("gen_ai.system_instructions", &completion_request.preamble);
461
462        if completion_request.tool_choice.is_some() {
463            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
464        }
465
466        if !completion_request.tools.is_empty() {
467            tracing::warn!("WARNING: `tools` not supported on Perplexity");
468        }
469
470        let mut request =
471            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
472        request.stream = true;
473
474        if tracing::enabled!(tracing::Level::TRACE) {
475            tracing::trace!(target: "rig::completions",
476                "Perplexity streaming completion request: {}",
477                serde_json::to_string_pretty(&request)?
478            );
479        }
480
481        let body = serde_json::to_vec(&request)?;
482
483        let req = self
484            .client
485            .post("/chat/completions")?
486            .body(body)
487            .map_err(http_client::Error::from)?;
488
489        send_compatible_streaming_request(self.client.clone(), req)
490            .instrument(span)
491            .await
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object with a lowercase role deserializes into the wire `Message`.
    #[test]
    fn test_deserialize_message() {
        let raw = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let Message { role, content } = serde_json::from_str::<Message>(raw).unwrap();
        assert_eq!(role, Role::User);
        assert_eq!(content, "Hello, how can I help you?");
    }

    /// Serialization emits `role` before `content`, with a lowercase role.
    #[test]
    fn test_serialize_message() {
        let encoded = serde_json::to_string(&Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        })
        .unwrap();

        let expected = r#"{"role":"assistant","content":"I am here to assist you."}"#;
        assert_eq!(encoded, expected);
    }

    /// Plain-text user and assistant messages survive a round trip through
    /// the wire format.
    #[test]
    fn test_message_to_message_conversion() {
        let original_user = message::Message::user("User message");
        let original_assistant = message::Message::assistant("Assistant message");

        let wire_user = Message::try_from(original_user.clone()).unwrap();
        let wire_assistant = Message::try_from(original_assistant.clone()).unwrap();

        assert_eq!(wire_user.role, Role::User);
        assert_eq!(wire_user.content, "User message");

        assert_eq!(wire_assistant.role, Role::Assistant);
        assert_eq!(wire_assistant.content, "Assistant message");

        let restored_user = message::Message::from(wire_user);
        let restored_assistant = message::Message::from(wire_assistant);

        assert_eq!(original_user, restored_user);
        assert_eq!(original_assistant, restored_assistant);
    }

    /// Both construction paths accept a dummy key.
    #[test]
    fn test_client_initialization() {
        let _via_new =
            crate::providers::perplexity::Client::new("dummy-key").expect("Client::new() failed");

        let builder = crate::providers::perplexity::Client::builder().api_key("dummy-key");
        let _via_builder = builder.build().expect("Client::builder() failed");
    }
}