Skip to main content

rig/providers/
perplexity.rs

//! Perplexity API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::perplexity;
//!
//! let client = perplexity::Client::new("YOUR_API_KEY").unwrap();
//!
//! let sonar = client.completion_model(perplexity::SONAR);
//! ```
11use crate::client::BearerAuth;
12use crate::completion::CompletionRequest;
13use crate::providers::openai;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    OneOrMany,
18    client::{
19        self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
20    },
21    completion::{self, CompletionError, MessageError, message},
22    http_client::{self, HttpClientExt},
23};
24use bytes::Bytes;
25use serde::{Deserialize, Serialize};
26use tracing::{Instrument, info_span};
27
28// ================================================================
// Main Perplexity Client
30// ================================================================
/// Root URL of the Perplexity API; exposed to the client builder as `ProviderBuilder::BASE_URL`.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
32
/// Field-less marker type that stands for the Perplexity provider.
///
/// It stores no data; everything provider-specific is expressed through the
/// trait implementations written for this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityExt;
35
/// Field-less marker type used as the builder half of the Perplexity provider.
///
/// It stores no data; its role is defined by the builder trait implemented for it.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityBuilder;
38
/// API key type used by the Perplexity client; a plain alias for [`BearerAuth`].
type PerplexityApiKey = BearerAuth;
40
41impl Provider for PerplexityExt {
42    type Builder = PerplexityBuilder;
43
44    // There is currently no way to verify a perplexity api key without consuming tokens
45    const VERIFY_PATH: &'static str = "";
46
47    fn build<H>(
48        _: &crate::client::ClientBuilder<
49            Self::Builder,
50            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
51            H,
52        >,
53    ) -> http_client::Result<Self> {
54        Ok(Self)
55    }
56}
57
58impl<H> Capabilities<H> for PerplexityExt {
59    type Completion = Capable<CompletionModel<H>>;
60    type Transcription = Nothing;
61    type Embeddings = Nothing;
62    #[cfg(feature = "image")]
63    type ImageGeneration = Nothing;
64
65    #[cfg(feature = "audio")]
66    type AudioGeneration = Nothing;
67}
68
// Empty impl: `PerplexityExt` overrides nothing and relies entirely on `DebugExt`'s defaults.
impl DebugExt for PerplexityExt {}
70
71impl ProviderBuilder for PerplexityBuilder {
72    type Output = PerplexityExt;
73    type ApiKey = PerplexityApiKey;
74
75    const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;
76}
77
/// Perplexity client: the generic client specialised to [`PerplexityExt`].
/// `H` is the HTTP backend type and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
/// Builder for [`Client`], specialised to [`PerplexityBuilder`] and the Perplexity API key type.
/// `H` is the HTTP backend type and defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
81
82impl ProviderClient for Client {
83    type Input = String;
84
85    /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable.
86    /// Panics if the environment variable is not set.
87    fn from_env() -> Self {
88        let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
89        Self::new(&api_key).unwrap()
90    }
91
92    fn from_val(input: Self::Input) -> Self {
93        Self::new(&input).unwrap()
94    }
95}
96
97#[derive(Debug, Deserialize)]
98struct ApiErrorResponse {
99    message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104enum ApiResponse<T> {
105    Ok(T),
106    Err(ApiErrorResponse),
107}
108
109// ================================================================
110// Perplexity Completion API
111// ================================================================
112
/// `sonar-pro` completion model.
///
/// Perplexity model identifiers are hyphenated; an underscore here would name a
/// model the API does not know.
pub const SONAR_PRO: &str = "sonar-pro";
/// `sonar` completion model.
pub const SONAR: &str = "sonar";
115
116#[derive(Debug, Deserialize, Serialize)]
117pub struct CompletionResponse {
118    pub id: String,
119    pub model: String,
120    pub object: String,
121    pub created: u64,
122    #[serde(default)]
123    pub choices: Vec<Choice>,
124    pub usage: Usage,
125}
126
127#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
128pub struct Message {
129    pub role: Role,
130    pub content: String,
131}
132
133#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
134#[serde(rename_all = "lowercase")]
135pub enum Role {
136    System,
137    User,
138    Assistant,
139}
140
141#[derive(Deserialize, Debug, Serialize)]
142pub struct Delta {
143    pub role: Role,
144    pub content: String,
145}
146
147#[derive(Deserialize, Debug, Serialize)]
148pub struct Choice {
149    pub index: usize,
150    pub finish_reason: String,
151    pub message: Message,
152    pub delta: Delta,
153}
154
155#[derive(Deserialize, Debug, Serialize)]
156pub struct Usage {
157    pub prompt_tokens: u32,
158    pub completion_tokens: u32,
159    pub total_tokens: u32,
160}
161
162impl std::fmt::Display for Usage {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        write!(
165            f,
166            "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
167            self.prompt_tokens, self.completion_tokens, self.total_tokens
168        )
169    }
170}
171
172impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
173    type Error = CompletionError;
174
175    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
176        let choice = response.choices.first().ok_or_else(|| {
177            CompletionError::ResponseError("Response contained no choices".to_owned())
178        })?;
179
180        match &choice.message {
181            Message {
182                role: Role::Assistant,
183                content,
184            } => Ok(completion::CompletionResponse {
185                choice: OneOrMany::one(content.clone().into()),
186                usage: completion::Usage {
187                    input_tokens: response.usage.prompt_tokens as u64,
188                    output_tokens: response.usage.completion_tokens as u64,
189                    total_tokens: response.usage.total_tokens as u64,
190                    cached_input_tokens: 0,
191                },
192                raw_response: response,
193            }),
194            _ => Err(CompletionError::ResponseError(
195                "Response contained no assistant message".to_owned(),
196            )),
197        }
198    }
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202pub(super) struct PerplexityCompletionRequest {
203    model: String,
204    pub messages: Vec<Message>,
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub temperature: Option<f64>,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub max_tokens: Option<u64>,
209    #[serde(flatten, skip_serializing_if = "Option::is_none")]
210    additional_params: Option<serde_json::Value>,
211    pub stream: bool,
212}
213
214impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
215    type Error = CompletionError;
216
217    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
218        let mut partial_history = vec![];
219        if let Some(docs) = req.normalized_documents() {
220            partial_history.push(docs);
221        }
222        partial_history.extend(req.chat_history);
223
224        // Initialize full history with preamble (or empty if non-existent)
225        let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
226            vec![Message {
227                role: Role::System,
228                content: preamble,
229            }]
230        });
231
232        // Convert and extend the rest of the history
233        full_history.extend(
234            partial_history
235                .into_iter()
236                .map(message::Message::try_into)
237                .collect::<Result<Vec<Message>, _>>()?,
238        );
239
240        Ok(Self {
241            model: model.to_string(),
242            messages: full_history,
243            temperature: req.temperature,
244            max_tokens: req.max_tokens,
245            additional_params: req.additional_params,
246            stream: false,
247        })
248    }
249}
250
/// A Perplexity completion model: a client handle paired with the name of the model
/// to request. `T` is the HTTP backend type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier placed in the `model` field of each request.
    pub model: String,
}
256
257impl<T> CompletionModel<T> {
258    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
259        Self {
260            client,
261            model: model.into(),
262        }
263    }
264}
265
266impl TryFrom<message::Message> for Message {
267    type Error = MessageError;
268
269    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
270        Ok(match message {
271            message::Message::User { content } => {
272                let collapsed_content = content
273                    .into_iter()
274                    .map(|content| match content {
275                        message::UserContent::Text(message::Text { text }) => Ok(text),
276                        _ => Err(MessageError::ConversionError(
277                            "Only text content is supported by Perplexity".to_owned(),
278                        )),
279                    })
280                    .collect::<Result<Vec<_>, _>>()?
281                    .join("\n");
282
283                Message {
284                    role: Role::User,
285                    content: collapsed_content,
286                }
287            }
288
289            message::Message::Assistant { content, .. } => {
290                let collapsed_content = content
291                    .into_iter()
292                    .map(|content| {
293                        Ok(match content {
294                            message::AssistantContent::Text(message::Text { text }) => text,
295                            _ => return Err(MessageError::ConversionError(
296                                "Only text assistant message content is supported by Perplexity"
297                                    .to_owned(),
298                            )),
299                        })
300                    })
301                    .collect::<Result<Vec<_>, _>>()?
302                    .join("\n");
303
304                Message {
305                    role: Role::Assistant,
306                    content: collapsed_content,
307                }
308            }
309        })
310    }
311}
312
313impl From<Message> for message::Message {
314    fn from(message: Message) -> Self {
315        match message.role {
316            Role::User => message::Message::user(message.content),
317            Role::Assistant => message::Message::assistant(message.content),
318
319            // System messages get coerced into user messages for ease of error handling.
320            // They should be handled on the outside of `Message` conversions via the preamble.
321            Role::System => message::Message::user(message.content),
322        }
323    }
324}
325
326impl<T> completion::CompletionModel for CompletionModel<T>
327where
328    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
329{
330    type Response = CompletionResponse;
331    type StreamingResponse = openai::StreamingCompletionResponse;
332
333    type Client = Client<T>;
334
335    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
336        Self::new(client.clone(), model)
337    }
338
339    async fn completion(
340        &self,
341        completion_request: completion::CompletionRequest,
342    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
343        let span = if tracing::Span::current().is_disabled() {
344            info_span!(
345                target: "rig::completions",
346                "chat",
347                gen_ai.operation.name = "chat",
348                gen_ai.provider.name = "perplexity",
349                gen_ai.request.model = self.model,
350                gen_ai.system_instructions = tracing::field::Empty,
351                gen_ai.response.id = tracing::field::Empty,
352                gen_ai.response.model = tracing::field::Empty,
353                gen_ai.usage.output_tokens = tracing::field::Empty,
354                gen_ai.usage.input_tokens = tracing::field::Empty,
355            )
356        } else {
357            tracing::Span::current()
358        };
359
360        span.record("gen_ai.system_instructions", &completion_request.preamble);
361
362        if completion_request.tool_choice.is_some() {
363            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
364        }
365
366        if !completion_request.tools.is_empty() {
367            tracing::warn!("WARNING: `tools` not supported on Perplexity");
368        }
369        let request =
370            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
371
372        if tracing::enabled!(tracing::Level::TRACE) {
373            tracing::trace!(target: "rig::completions",
374                "Perplexity completion request: {}",
375                serde_json::to_string_pretty(&request)?
376            );
377        }
378
379        let body = serde_json::to_vec(&request)?;
380
381        let req = self
382            .client
383            .post("/v1/chat/completions")?
384            .body(body)
385            .map_err(http_client::Error::from)?;
386
387        let async_block = async move {
388            let response = self.client.send::<_, Bytes>(req).await?;
389
390            let status = response.status();
391            let response_body = response.into_body().into_future().await?.to_vec();
392
393            if status.is_success() {
394                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
395                    ApiResponse::Ok(response) => {
396                        let span = tracing::Span::current();
397                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
398                        span.record(
399                            "gen_ai.usage.output_tokens",
400                            response.usage.completion_tokens,
401                        );
402                        span.record("gen_ai.response.id", response.id.to_string());
403                        span.record("gen_ai.response.model_name", response.model.to_string());
404                        if tracing::enabled!(tracing::Level::TRACE) {
405                            tracing::trace!(target: "rig::responses",
406                                "Perplexity completion response: {}",
407                                serde_json::to_string_pretty(&response)?
408                            );
409                        }
410                        Ok(response.try_into()?)
411                    }
412                    ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
413                }
414            } else {
415                Err(CompletionError::ProviderError(
416                    String::from_utf8_lossy(&response_body).to_string(),
417                ))
418            }
419        };
420
421        async_block.instrument(span).await
422    }
423
424    async fn stream(
425        &self,
426        completion_request: completion::CompletionRequest,
427    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
428        let span = if tracing::Span::current().is_disabled() {
429            info_span!(
430                target: "rig::completions",
431                "chat_streaming",
432                gen_ai.operation.name = "chat_streaming",
433                gen_ai.provider.name = "perplexity",
434                gen_ai.request.model = self.model,
435                gen_ai.system_instructions = tracing::field::Empty,
436                gen_ai.response.id = tracing::field::Empty,
437                gen_ai.response.model = tracing::field::Empty,
438                gen_ai.usage.output_tokens = tracing::field::Empty,
439                gen_ai.usage.input_tokens = tracing::field::Empty,
440            )
441        } else {
442            tracing::Span::current()
443        };
444
445        span.record("gen_ai.system_instructions", &completion_request.preamble);
446
447        if completion_request.tool_choice.is_some() {
448            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
449        }
450
451        if !completion_request.tools.is_empty() {
452            tracing::warn!("WARNING: `tools` not supported on Perplexity");
453        }
454
455        let mut request =
456            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
457        request.stream = true;
458
459        if tracing::enabled!(tracing::Level::TRACE) {
460            tracing::trace!(target: "rig::completions",
461                "Perplexity streaming completion request: {}",
462                serde_json::to_string_pretty(&request)?
463            );
464        }
465
466        let body = serde_json::to_vec(&request)?;
467
468        let req = self
469            .client
470            .post("/chat/completions")?
471            .body(body)
472            .map_err(http_client::Error::from)?;
473
474        send_compatible_streaming_request(self.client.clone(), req)
475            .instrument(span)
476            .await
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object with a lowercase role parses into the matching `Role` variant.
    #[test]
    fn test_deserialize_message() {
        let raw = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let parsed = serde_json::from_str::<Message>(raw).unwrap();

        assert_eq!(parsed.role, Role::User);
        assert_eq!(parsed.content, "Hello, how can I help you?");
    }

    /// Serialization emits `role` before `content` and writes the role in lowercase.
    #[test]
    fn test_serialize_message() {
        let outgoing = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };
        let expected_json = r#"{"role":"assistant","content":"I am here to assist you."}"#;

        let encoded = serde_json::to_string(&outgoing).unwrap();

        assert_eq!(encoded, expected_json);
    }

    /// Text-only user and assistant messages survive a round trip through the
    /// Perplexity message type.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::user("User message");
        let assistant_message = message::Message::assistant("Assistant message");

        // Generic -> Perplexity
        let provider_user: Message = user_message.clone().try_into().unwrap();
        assert_eq!(provider_user.role, Role::User);
        assert_eq!(provider_user.content, "User message");

        let provider_assistant: Message = assistant_message.clone().try_into().unwrap();
        assert_eq!(provider_assistant.role, Role::Assistant);
        assert_eq!(provider_assistant.content, "Assistant message");

        // Perplexity -> generic
        let user_round_trip: message::Message = provider_user.into();
        let assistant_round_trip: message::Message = provider_assistant.into();

        assert_eq!(user_message, user_round_trip);
        assert_eq!(assistant_message, assistant_round_trip);
    }
}