Skip to main content

atlas/providers/
perplexity.rs

1//! Perplexity API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::perplexity;
6//!
7//! let client = perplexity::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_sonar_small_online = client.completion_model(perplexity::LLAMA_3_1_SONAR_SMALL_ONLINE);
10//! ```
11
12use crate::{
13    agent::AgentBuilder,
14    completion::{self, CompletionError},
15    extractor::ExtractorBuilder,
16    json_utils,
17};
18
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
// ================================================================
// Main Perplexity Client
// ================================================================
/// Base URL used by [`Client::new`] when no custom URL is supplied.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
27
28#[derive(Clone)]
29pub struct Client {
30    base_url: String,
31    http_client: reqwest::Client,
32}
33
34impl Client {
35    pub fn new(api_key: &str) -> Self {
36        Self::from_url(api_key, PERPLEXITY_API_BASE_URL)
37    }
38
39    /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable.
40    /// Panics if the environment variable is not set.
41    pub fn from_env() -> Self {
42        let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
43        Self::new(&api_key)
44    }
45
46    pub fn from_url(api_key: &str, base_url: &str) -> Self {
47        Self {
48            base_url: base_url.to_string(),
49            http_client: reqwest::Client::builder()
50                .default_headers({
51                    let mut headers = reqwest::header::HeaderMap::new();
52                    headers.insert(
53                        "Authorization",
54                        format!("Bearer {}", api_key)
55                            .parse()
56                            .expect("Bearer token should parse"),
57                    );
58                    headers
59                })
60                .build()
61                .expect("Perplexity reqwest client should build"),
62        }
63    }
64
65    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
66        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
67        self.http_client.post(url)
68    }
69
70    pub fn completion_model(&self, model: &str) -> CompletionModel {
71        CompletionModel::new(self.clone(), model)
72    }
73
74    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
75        AgentBuilder::new(self.completion_model(model))
76    }
77
78    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
79        &self,
80        model: &str,
81    ) -> ExtractorBuilder<T, CompletionModel> {
82        ExtractorBuilder::new(self.completion_model(model))
83    }
84}
85
86#[derive(Debug, Deserialize)]
87struct ApiErrorResponse {
88    message: String,
89}
90
91#[derive(Debug, Deserialize)]
92#[serde(untagged)]
93enum ApiResponse<T> {
94    Ok(T),
95    Err(ApiErrorResponse),
96}
97
98// ================================================================
99// Perplexity Completion API
100// ================================================================
/// Model identifier `llama-3.1-sonar-small-128k-online`.
pub const LLAMA_3_1_SONAR_SMALL_ONLINE: &str = "llama-3.1-sonar-small-128k-online";
/// Model identifier `llama-3.1-sonar-large-128k-online`.
pub const LLAMA_3_1_SONAR_LARGE_ONLINE: &str = "llama-3.1-sonar-large-128k-online";
/// Model identifier `llama-3.1-sonar-huge-128k-online`.
pub const LLAMA_3_1_SONAR_HUGE_ONLINE: &str = "llama-3.1-sonar-huge-128k-online";
/// Model identifier `llama-3.1-sonar-small-128k-chat`.
pub const LLAMA_3_1_SONAR_SMALL_CHAT: &str = "llama-3.1-sonar-small-128k-chat";
/// Model identifier `llama-3.1-sonar-large-128k-chat`.
pub const LLAMA_3_1_SONAR_LARGE_CHAT: &str = "llama-3.1-sonar-large-128k-chat";
/// Model identifier `llama-3.1-8b-instruct`.
pub const LLAMA_3_1_8B_INSTRUCT: &str = "llama-3.1-8b-instruct";
/// Model identifier `llama-3.1-70b-instruct`.
pub const LLAMA_3_1_70B_INSTRUCT: &str = "llama-3.1-70b-instruct";
115
116#[derive(Debug, Deserialize)]
117pub struct CompletionResponse {
118    pub id: String,
119    pub model: String,
120    pub object: String,
121    pub created: u64,
122    #[serde(default)]
123    pub choices: Vec<Choice>,
124    pub usage: Usage,
125}
126
127#[derive(Deserialize, Debug)]
128pub struct Message {
129    pub role: String,
130    pub content: String,
131}
132
133#[derive(Deserialize, Debug)]
134pub struct Delta {
135    pub role: String,
136    pub content: String,
137}
138
139#[derive(Deserialize, Debug)]
140pub struct Choice {
141    pub index: usize,
142    pub finish_reason: String,
143    pub message: Message,
144    pub delta: Delta,
145}
146
147#[derive(Deserialize, Debug)]
148pub struct Usage {
149    pub prompt_tokens: u32,
150    pub completion_tokens: u32,
151    pub total_tokens: u32,
152}
153
154impl std::fmt::Display for Usage {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(
157            f,
158            "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
159            self.prompt_tokens, self.completion_tokens, self.total_tokens
160        )
161    }
162}
163
164impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
165    type Error = CompletionError;
166
167    fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
168        match value.choices.as_slice() {
169            [Choice {
170                message: Message { content, .. },
171                ..
172            }, ..] => Ok(completion::CompletionResponse {
173                choice: completion::ModelChoice::Message(content.to_string()),
174                raw_response: value,
175            }),
176            _ => Err(CompletionError::ResponseError(
177                "Response did not contain a message or tool call".into(),
178            )),
179        }
180    }
181}
182
183#[derive(Clone)]
184pub struct CompletionModel {
185    client: Client,
186    pub model: String,
187}
188
189impl CompletionModel {
190    pub fn new(client: Client, model: &str) -> Self {
191        Self {
192            client,
193            model: model.to_string(),
194        }
195    }
196}
197
198impl completion::CompletionModel for CompletionModel {
199    type Response = CompletionResponse;
200
201    async fn completion(
202        &self,
203        completion_request: completion::CompletionRequest,
204    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
205        // Add preamble to messages (if available)
206        let mut messages = if let Some(preamble) = &completion_request.preamble {
207            vec![completion::Message {
208                role: "system".into(),
209                content: preamble.clone(),
210            }]
211        } else {
212            vec![]
213        };
214
215        // Add context documents to chat history
216        let prompt_with_context = completion_request.prompt_with_context();
217
218        // Add chat history to messages
219        messages.extend(completion_request.chat_history);
220
221        // Add user prompt to messages
222        messages.push(completion::Message {
223            role: "user".to_string(),
224            content: prompt_with_context,
225        });
226
227        let request = json!({
228            "model": self.model,
229            "messages": messages,
230            "temperature": completion_request.temperature,
231        });
232
233        let response = self
234            .client
235            .post("/chat/completions")
236            .json(
237                &if let Some(ref params) = completion_request.additional_params {
238                    json_utils::merge(request.clone(), params.clone())
239                } else {
240                    request.clone()
241                },
242            )
243            .send()
244            .await?;
245
246        if response.status().is_success() {
247            match response.json::<ApiResponse<CompletionResponse>>().await? {
248                ApiResponse::Ok(completion) => {
249                    tracing::info!(target: "rig",
250                        "Perplexity completion token usage: {}",
251                        completion.usage
252                    );
253                    Ok(completion.try_into()?)
254                }
255                ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
256            }
257        } else {
258            Err(CompletionError::ProviderError(response.text().await?))
259        }
260    }
261}