1use crate::client::BearerAuth;
12use crate::completion::CompletionRequest;
13use crate::providers::openai;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 OneOrMany,
18 client::{
19 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
20 },
21 completion::{self, CompletionError, MessageError, message},
22 http_client::{self, HttpClientExt},
23};
24use bytes::Bytes;
25use serde::{Deserialize, Serialize};
26use tracing::{Instrument, info_span};
27
/// Default root URL of the Perplexity HTTP API, used as the client builder's `BASE_URL`.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
32
/// Stateless provider marker for Perplexity, plugged into the generic `client::Client`.
#[derive(Debug, Default, Clone, Copy)]
pub struct PerplexityExt;

/// Stateless builder marker whose `ProviderBuilder::Output` is [`PerplexityExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct PerplexityBuilder;
38
39type PerplexityApiKey = BearerAuth;
40
41impl Provider for PerplexityExt {
42 type Builder = PerplexityBuilder;
43
44 const VERIFY_PATH: &'static str = "";
46
47 fn build<H>(
48 _: &crate::client::ClientBuilder<
49 Self::Builder,
50 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
51 H,
52 >,
53 ) -> http_client::Result<Self> {
54 Ok(Self)
55 }
56}
57
58impl<H> Capabilities<H> for PerplexityExt {
59 type Completion = Capable<CompletionModel<H>>;
60 type Transcription = Nothing;
61 type Embeddings = Nothing;
62 #[cfg(feature = "image")]
63 type ImageGeneration = Nothing;
64
65 #[cfg(feature = "audio")]
66 type AudioGeneration = Nothing;
67}
68
69impl DebugExt for PerplexityExt {}
70
71impl ProviderBuilder for PerplexityBuilder {
72 type Output = PerplexityExt;
73 type ApiKey = PerplexityApiKey;
74
75 const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;
76}
77
78pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
79pub type ClientBuilder<H = reqwest::Client> =
80 client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
81
82impl ProviderClient for Client {
83 type Input = String;
84
85 fn from_env() -> Self {
88 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
89 Self::new(&api_key).unwrap()
90 }
91
92 fn from_val(input: Self::Input) -> Self {
93 Self::new(&input).unwrap()
94 }
95}
96
97#[derive(Debug, Deserialize)]
98struct ApiErrorResponse {
99 message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104enum ApiResponse<T> {
105 Ok(T),
106 Err(ApiErrorResponse),
107}
108
/// `sonar-pro` completion model.
///
/// Perplexity model identifiers are hyphenated; the API does not know `sonar_pro`.
pub const SONAR_PRO: &str = "sonar-pro";
/// `sonar` completion model.
pub const SONAR: &str = "sonar";
115
116#[derive(Debug, Deserialize, Serialize)]
117pub struct CompletionResponse {
118 pub id: String,
119 pub model: String,
120 pub object: String,
121 pub created: u64,
122 #[serde(default)]
123 pub choices: Vec<Choice>,
124 pub usage: Usage,
125}
126
127#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
128pub struct Message {
129 pub role: Role,
130 pub content: String,
131}
132
133#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
134#[serde(rename_all = "lowercase")]
135pub enum Role {
136 System,
137 User,
138 Assistant,
139}
140
141#[derive(Deserialize, Debug, Serialize)]
142pub struct Delta {
143 pub role: Role,
144 pub content: String,
145}
146
147#[derive(Deserialize, Debug, Serialize)]
148pub struct Choice {
149 pub index: usize,
150 pub finish_reason: String,
151 pub message: Message,
152 pub delta: Delta,
153}
154
155#[derive(Deserialize, Debug, Serialize)]
156pub struct Usage {
157 pub prompt_tokens: u32,
158 pub completion_tokens: u32,
159 pub total_tokens: u32,
160}
161
162impl std::fmt::Display for Usage {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 write!(
165 f,
166 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
167 self.prompt_tokens, self.completion_tokens, self.total_tokens
168 )
169 }
170}
171
172impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
173 type Error = CompletionError;
174
175 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
176 let choice = response.choices.first().ok_or_else(|| {
177 CompletionError::ResponseError("Response contained no choices".to_owned())
178 })?;
179
180 match &choice.message {
181 Message {
182 role: Role::Assistant,
183 content,
184 } => Ok(completion::CompletionResponse {
185 choice: OneOrMany::one(content.clone().into()),
186 usage: completion::Usage {
187 input_tokens: response.usage.prompt_tokens as u64,
188 output_tokens: response.usage.completion_tokens as u64,
189 total_tokens: response.usage.total_tokens as u64,
190 cached_input_tokens: 0,
191 },
192 raw_response: response,
193 }),
194 _ => Err(CompletionError::ResponseError(
195 "Response contained no assistant message".to_owned(),
196 )),
197 }
198 }
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202pub(super) struct PerplexityCompletionRequest {
203 model: String,
204 pub messages: Vec<Message>,
205 #[serde(skip_serializing_if = "Option::is_none")]
206 pub temperature: Option<f64>,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 pub max_tokens: Option<u64>,
209 #[serde(flatten, skip_serializing_if = "Option::is_none")]
210 additional_params: Option<serde_json::Value>,
211 pub stream: bool,
212}
213
214impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
215 type Error = CompletionError;
216
217 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
218 let mut partial_history = vec![];
219 if let Some(docs) = req.normalized_documents() {
220 partial_history.push(docs);
221 }
222 partial_history.extend(req.chat_history);
223
224 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
226 vec![Message {
227 role: Role::System,
228 content: preamble,
229 }]
230 });
231
232 full_history.extend(
234 partial_history
235 .into_iter()
236 .map(message::Message::try_into)
237 .collect::<Result<Vec<Message>, _>>()?,
238 );
239
240 Ok(Self {
241 model: model.to_string(),
242 messages: full_history,
243 temperature: req.temperature,
244 max_tokens: req.max_tokens,
245 additional_params: req.additional_params,
246 stream: false,
247 })
248 }
249}
250
251#[derive(Clone)]
252pub struct CompletionModel<T = reqwest::Client> {
253 client: Client<T>,
254 pub model: String,
255}
256
257impl<T> CompletionModel<T> {
258 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
259 Self {
260 client,
261 model: model.into(),
262 }
263 }
264}
265
266impl TryFrom<message::Message> for Message {
267 type Error = MessageError;
268
269 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
270 Ok(match message {
271 message::Message::User { content } => {
272 let collapsed_content = content
273 .into_iter()
274 .map(|content| match content {
275 message::UserContent::Text(message::Text { text }) => Ok(text),
276 _ => Err(MessageError::ConversionError(
277 "Only text content is supported by Perplexity".to_owned(),
278 )),
279 })
280 .collect::<Result<Vec<_>, _>>()?
281 .join("\n");
282
283 Message {
284 role: Role::User,
285 content: collapsed_content,
286 }
287 }
288
289 message::Message::Assistant { content, .. } => {
290 let collapsed_content = content
291 .into_iter()
292 .map(|content| {
293 Ok(match content {
294 message::AssistantContent::Text(message::Text { text }) => text,
295 _ => return Err(MessageError::ConversionError(
296 "Only text assistant message content is supported by Perplexity"
297 .to_owned(),
298 )),
299 })
300 })
301 .collect::<Result<Vec<_>, _>>()?
302 .join("\n");
303
304 Message {
305 role: Role::Assistant,
306 content: collapsed_content,
307 }
308 }
309 })
310 }
311}
312
313impl From<Message> for message::Message {
314 fn from(message: Message) -> Self {
315 match message.role {
316 Role::User => message::Message::user(message.content),
317 Role::Assistant => message::Message::assistant(message.content),
318
319 Role::System => message::Message::user(message.content),
322 }
323 }
324}
325
326impl<T> completion::CompletionModel for CompletionModel<T>
327where
328 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
329{
330 type Response = CompletionResponse;
331 type StreamingResponse = openai::StreamingCompletionResponse;
332
333 type Client = Client<T>;
334
335 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
336 Self::new(client.clone(), model)
337 }
338
339 async fn completion(
340 &self,
341 completion_request: completion::CompletionRequest,
342 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
343 let span = if tracing::Span::current().is_disabled() {
344 info_span!(
345 target: "rig::completions",
346 "chat",
347 gen_ai.operation.name = "chat",
348 gen_ai.provider.name = "perplexity",
349 gen_ai.request.model = self.model,
350 gen_ai.system_instructions = tracing::field::Empty,
351 gen_ai.response.id = tracing::field::Empty,
352 gen_ai.response.model = tracing::field::Empty,
353 gen_ai.usage.output_tokens = tracing::field::Empty,
354 gen_ai.usage.input_tokens = tracing::field::Empty,
355 )
356 } else {
357 tracing::Span::current()
358 };
359
360 span.record("gen_ai.system_instructions", &completion_request.preamble);
361
362 if completion_request.tool_choice.is_some() {
363 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
364 }
365
366 if !completion_request.tools.is_empty() {
367 tracing::warn!("WARNING: `tools` not supported on Perplexity");
368 }
369 let request =
370 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
371
372 if tracing::enabled!(tracing::Level::TRACE) {
373 tracing::trace!(target: "rig::completions",
374 "Perplexity completion request: {}",
375 serde_json::to_string_pretty(&request)?
376 );
377 }
378
379 let body = serde_json::to_vec(&request)?;
380
381 let req = self
382 .client
383 .post("/v1/chat/completions")?
384 .body(body)
385 .map_err(http_client::Error::from)?;
386
387 let async_block = async move {
388 let response = self.client.send::<_, Bytes>(req).await?;
389
390 let status = response.status();
391 let response_body = response.into_body().into_future().await?.to_vec();
392
393 if status.is_success() {
394 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
395 ApiResponse::Ok(response) => {
396 let span = tracing::Span::current();
397 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
398 span.record(
399 "gen_ai.usage.output_tokens",
400 response.usage.completion_tokens,
401 );
402 span.record("gen_ai.response.id", response.id.to_string());
403 span.record("gen_ai.response.model_name", response.model.to_string());
404 if tracing::enabled!(tracing::Level::TRACE) {
405 tracing::trace!(target: "rig::responses",
406 "Perplexity completion response: {}",
407 serde_json::to_string_pretty(&response)?
408 );
409 }
410 Ok(response.try_into()?)
411 }
412 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
413 }
414 } else {
415 Err(CompletionError::ProviderError(
416 String::from_utf8_lossy(&response_body).to_string(),
417 ))
418 }
419 };
420
421 async_block.instrument(span).await
422 }
423
424 async fn stream(
425 &self,
426 completion_request: completion::CompletionRequest,
427 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
428 let span = if tracing::Span::current().is_disabled() {
429 info_span!(
430 target: "rig::completions",
431 "chat_streaming",
432 gen_ai.operation.name = "chat_streaming",
433 gen_ai.provider.name = "perplexity",
434 gen_ai.request.model = self.model,
435 gen_ai.system_instructions = tracing::field::Empty,
436 gen_ai.response.id = tracing::field::Empty,
437 gen_ai.response.model = tracing::field::Empty,
438 gen_ai.usage.output_tokens = tracing::field::Empty,
439 gen_ai.usage.input_tokens = tracing::field::Empty,
440 )
441 } else {
442 tracing::Span::current()
443 };
444
445 span.record("gen_ai.system_instructions", &completion_request.preamble);
446
447 if completion_request.tool_choice.is_some() {
448 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
449 }
450
451 if !completion_request.tools.is_empty() {
452 tracing::warn!("WARNING: `tools` not supported on Perplexity");
453 }
454
455 let mut request =
456 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
457 request.stream = true;
458
459 if tracing::enabled!(tracing::Level::TRACE) {
460 tracing::trace!(target: "rig::completions",
461 "Perplexity streaming completion request: {}",
462 serde_json::to_string_pretty(&request)?
463 );
464 }
465
466 let body = serde_json::to_vec(&request)?;
467
468 let req = self
469 .client
470 .post("/chat/completions")?
471 .body(body)
472 .map_err(http_client::Error::from)?;
473
474 send_compatible_streaming_request(self.client.clone(), req)
475 .instrument(span)
476 .await
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// A user message in wire format deserializes into the expected `Message`.
    #[test]
    fn test_deserialize_message() {
        let raw = r#"{ "role": "user", "content": "Hello, how can I help you?" }"#;

        let parsed: Message = serde_json::from_str(raw).unwrap();

        assert_eq!(
            parsed,
            Message {
                role: Role::User,
                content: "Hello, how can I help you?".to_string(),
            }
        );
    }

    /// Serialization emits `role` before `content`, with a lowercase role.
    #[test]
    fn test_serialize_message() {
        let msg = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        assert_eq!(
            serde_json::to_string(&msg).unwrap(),
            r#"{"role":"assistant","content":"I am here to assist you."}"#
        );
    }

    /// User and assistant messages survive a round trip through the Perplexity format.
    #[test]
    fn test_message_to_message_conversion() {
        let original_user = message::Message::user("User message");
        let original_assistant = message::Message::assistant("Assistant message");

        let user: Message = original_user.clone().try_into().unwrap();
        let assistant: Message = original_assistant.clone().try_into().unwrap();

        assert_eq!((&user.role, user.content.as_str()), (&Role::User, "User message"));
        assert_eq!(
            (&assistant.role, assistant.content.as_str()),
            (&Role::Assistant, "Assistant message")
        );

        let round_trip_user: message::Message = user.into();
        let round_trip_assistant: message::Message = assistant.into();

        assert_eq!(original_user, round_trip_user);
        assert_eq!(original_assistant, round_trip_assistant);
    }
}