1use crate::client::BearerAuth;
12use crate::completion::CompletionRequest;
13use crate::providers::openai;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 OneOrMany,
18 client::{
19 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
20 },
21 completion::{self, CompletionError, MessageError, message},
22 http_client::{self, HttpClientExt},
23};
24use bytes::Bytes;
25use serde::{Deserialize, Serialize};
26use tracing::{Instrument, info_span};
27
/// Base URL of the Perplexity API; used as the builder's `BASE_URL`.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
32
/// Zero-sized provider extension type identifying Perplexity; it carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct PerplexityExt;
35
/// Zero-sized builder type that produces [`PerplexityExt`]; it carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct PerplexityBuilder;
38
/// API key type for Perplexity clients: an alias for the crate's bearer-auth credential.
type PerplexityApiKey = BearerAuth;
40
/// Registers [`PerplexityExt`] as a provider whose builder is [`PerplexityBuilder`].
impl Provider for PerplexityExt {
    type Builder = PerplexityBuilder;

    // Deliberately empty: no verification path is configured for this provider.
    const VERIFY_PATH: &'static str = "";
}
47
/// Declares which capabilities the Perplexity provider exposes: completion is the only
/// one marked `Capable`; every other capability is `Nothing`.
impl<H> Capabilities<H> for PerplexityExt {
    type Completion = Capable<CompletionModel<H>>;
    type Transcription = Nothing;
    type Embeddings = Nothing;
    type ModelListing = Nothing;
    // Only present when the crate is built with the `image` feature.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    // Only present when the crate is built with the `audio` feature.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
59
// Empty impl: `PerplexityExt` overrides nothing from `DebugExt`.
impl DebugExt for PerplexityExt {}
61
/// Builder wiring for Perplexity: fixes the extension type, the API key type and the
/// base URL used by clients created through this builder.
impl ProviderBuilder for PerplexityBuilder {
    // The extension is the same stateless marker for every HTTP client type `H`.
    type Extension<H>
        = PerplexityExt
    where
        H: HttpClientExt;
    type ApiKey = PerplexityApiKey;

    const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;

    /// Produces the provider extension.
    ///
    /// The client builder argument is ignored and the call always returns
    /// `Ok(PerplexityExt)`.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(PerplexityExt)
    }
}
80
/// Perplexity client; the HTTP client type defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
/// Builder for a Perplexity [`Client`]; the HTTP client type defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
84
85impl ProviderClient for Client {
86 type Input = String;
87
88 fn from_env() -> Self {
91 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
92 Self::new(&api_key).unwrap()
93 }
94
95 fn from_val(input: Self::Input) -> Self {
96 Self::new(&input).unwrap()
97 }
98}
99
/// Error payload deserialized from a response body; only `message` is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, forwarded to callers as a provider error.
    message: String,
}
104
/// Response body that is either the expected payload `T` or an error payload.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an [`ApiErrorResponse`].
    Err(ApiErrorResponse),
}
111
/// `sonar-pro` completion model.
///
/// Perplexity model identifiers are hyphenated; an underscore spelling is not a valid
/// model name for the API.
pub const SONAR_PRO: &str = "sonar-pro";
/// `sonar` completion model.
pub const SONAR: &str = "sonar";
118
/// Body of a non-streaming chat-completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Model name reported in the response.
    pub model: String,
    pub object: String,
    pub created: u64,
    /// Returned choices; deserializes to an empty list when the field is absent.
    #[serde(default)]
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
129
/// A chat message in the provider's wire format: a role plus plain-text content.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// Message text.
    pub content: String,
}
135
/// Author of a [`Message`]; (de)serialized as a lowercase string
/// (`"system"`, `"user"`, `"assistant"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
143
/// Role/content pair carried in the `delta` field of a [`Choice`].
// NOTE(review): presumably the incremental fragment used for streamed output — confirm
// against the API reference.
#[derive(Deserialize, Debug, Serialize)]
pub struct Delta {
    pub role: Role,
    pub content: String,
}
149
/// One entry of [`CompletionResponse::choices`].
///
/// Every field is required: deserialization fails if any of them is missing,
/// including `delta`.
#[derive(Deserialize, Debug, Serialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    pub finish_reason: String,
    /// Message for this choice.
    pub message: Message,
    pub delta: Delta,
}
157
/// Token counts reported in a completion response.
#[derive(Deserialize, Debug, Serialize)]
pub struct Usage {
    /// Tokens in the prompt.
    pub prompt_tokens: u32,
    /// Tokens in the completion.
    pub completion_tokens: u32,
    /// Total reported by the API.
    pub total_tokens: u32,
}
164
165impl std::fmt::Display for Usage {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 write!(
168 f,
169 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
170 self.prompt_tokens, self.completion_tokens, self.total_tokens
171 )
172 }
173}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176 type Error = CompletionError;
177
178 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179 let choice = response.choices.first().ok_or_else(|| {
180 CompletionError::ResponseError("Response contained no choices".to_owned())
181 })?;
182
183 match &choice.message {
184 Message {
185 role: Role::Assistant,
186 content,
187 } => Ok(completion::CompletionResponse {
188 choice: OneOrMany::one(content.clone().into()),
189 usage: completion::Usage {
190 input_tokens: response.usage.prompt_tokens as u64,
191 output_tokens: response.usage.completion_tokens as u64,
192 total_tokens: response.usage.total_tokens as u64,
193 cached_input_tokens: 0,
194 cache_creation_input_tokens: 0,
195 },
196 raw_response: response,
197 message_id: None,
198 }),
199 _ => Err(CompletionError::ResponseError(
200 "Response contained no assistant message".to_owned(),
201 )),
202 }
203 }
204}
205
/// JSON body sent to the chat-completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct PerplexityCompletionRequest {
    /// Model identifier to run.
    model: String,
    /// Conversation, in order.
    pub messages: Vec<Message>,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Extra parameters, flattened into the top level of the body; omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
    /// Whether a streamed response is requested.
    pub stream: bool,
}
218
219impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
220 type Error = CompletionError;
221
222 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
223 if req.output_schema.is_some() {
224 tracing::warn!("Structured outputs currently not supported for Perplexity");
225 }
226 let model = req.model.clone().unwrap_or_else(|| model.to_string());
227 let mut partial_history = vec![];
228 if let Some(docs) = req.normalized_documents() {
229 partial_history.push(docs);
230 }
231 partial_history.extend(req.chat_history);
232
233 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
235 vec![Message {
236 role: Role::System,
237 content: preamble,
238 }]
239 });
240
241 full_history.extend(
243 partial_history
244 .into_iter()
245 .map(message::Message::try_into)
246 .collect::<Result<Vec<Message>, _>>()?,
247 );
248
249 Ok(Self {
250 model: model.to_string(),
251 messages: full_history,
252 temperature: req.temperature,
253 max_tokens: req.max_tokens,
254 additional_params: req.additional_params,
255 stream: false,
256 })
257 }
258}
259
/// A Perplexity completion model: a client handle plus the model name to request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    client: Client<T>,
    /// Default model name, used when a request does not name a model itself.
    pub model: String,
}
265
266impl<T> CompletionModel<T> {
267 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
268 Self {
269 client,
270 model: model.into(),
271 }
272 }
273}
274
275impl TryFrom<message::Message> for Message {
276 type Error = MessageError;
277
278 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
279 Ok(match message {
280 message::Message::System { content } => Message {
281 role: Role::System,
282 content,
283 },
284 message::Message::User { content } => {
285 let collapsed_content = content
286 .into_iter()
287 .map(|content| match content {
288 message::UserContent::Text(message::Text { text }) => Ok(text),
289 _ => Err(MessageError::ConversionError(
290 "Only text content is supported by Perplexity".to_owned(),
291 )),
292 })
293 .collect::<Result<Vec<_>, _>>()?
294 .join("\n");
295
296 Message {
297 role: Role::User,
298 content: collapsed_content,
299 }
300 }
301
302 message::Message::Assistant { content, .. } => {
303 let collapsed_content = content
304 .into_iter()
305 .map(|content| {
306 Ok(match content {
307 message::AssistantContent::Text(message::Text { text }) => text,
308 _ => return Err(MessageError::ConversionError(
309 "Only text assistant message content is supported by Perplexity"
310 .to_owned(),
311 )),
312 })
313 })
314 .collect::<Result<Vec<_>, _>>()?
315 .join("\n");
316
317 Message {
318 role: Role::Assistant,
319 content: collapsed_content,
320 }
321 }
322 })
323 }
324}
325
326impl From<Message> for message::Message {
327 fn from(message: Message) -> Self {
328 match message.role {
329 Role::User => message::Message::user(message.content),
330 Role::Assistant => message::Message::assistant(message.content),
331
332 Role::System => message::Message::user(message.content),
335 }
336 }
337}
338
339impl<T> completion::CompletionModel for CompletionModel<T>
340where
341 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
342{
343 type Response = CompletionResponse;
344 type StreamingResponse = openai::StreamingCompletionResponse;
345
346 type Client = Client<T>;
347
348 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
349 Self::new(client.clone(), model)
350 }
351
352 async fn completion(
353 &self,
354 completion_request: completion::CompletionRequest,
355 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
356 let span = if tracing::Span::current().is_disabled() {
357 info_span!(
358 target: "rig::completions",
359 "chat",
360 gen_ai.operation.name = "chat",
361 gen_ai.provider.name = "perplexity",
362 gen_ai.request.model = self.model,
363 gen_ai.system_instructions = tracing::field::Empty,
364 gen_ai.response.id = tracing::field::Empty,
365 gen_ai.response.model = tracing::field::Empty,
366 gen_ai.usage.output_tokens = tracing::field::Empty,
367 gen_ai.usage.input_tokens = tracing::field::Empty,
368 gen_ai.usage.cached_tokens = tracing::field::Empty,
369 )
370 } else {
371 tracing::Span::current()
372 };
373
374 span.record("gen_ai.system_instructions", &completion_request.preamble);
375
376 if completion_request.tool_choice.is_some() {
377 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
378 }
379
380 if !completion_request.tools.is_empty() {
381 tracing::warn!("WARNING: `tools` not supported on Perplexity");
382 }
383 let request =
384 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
385
386 if tracing::enabled!(tracing::Level::TRACE) {
387 tracing::trace!(target: "rig::completions",
388 "Perplexity completion request: {}",
389 serde_json::to_string_pretty(&request)?
390 );
391 }
392
393 let body = serde_json::to_vec(&request)?;
394
395 let req = self
396 .client
397 .post("/v1/chat/completions")?
398 .body(body)
399 .map_err(http_client::Error::from)?;
400
401 let async_block = async move {
402 let response = self.client.send::<_, Bytes>(req).await?;
403
404 let status = response.status();
405 let response_body = response.into_body().into_future().await?.to_vec();
406
407 if status.is_success() {
408 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
409 ApiResponse::Ok(response) => {
410 let span = tracing::Span::current();
411 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
412 span.record(
413 "gen_ai.usage.output_tokens",
414 response.usage.completion_tokens,
415 );
416 span.record("gen_ai.response.id", response.id.to_string());
417 span.record("gen_ai.response.model_name", response.model.to_string());
418 if tracing::enabled!(tracing::Level::TRACE) {
419 tracing::trace!(target: "rig::responses",
420 "Perplexity completion response: {}",
421 serde_json::to_string_pretty(&response)?
422 );
423 }
424 Ok(response.try_into()?)
425 }
426 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
427 }
428 } else {
429 Err(CompletionError::ProviderError(
430 String::from_utf8_lossy(&response_body).to_string(),
431 ))
432 }
433 };
434
435 async_block.instrument(span).await
436 }
437
438 async fn stream(
439 &self,
440 completion_request: completion::CompletionRequest,
441 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
442 let span = if tracing::Span::current().is_disabled() {
443 info_span!(
444 target: "rig::completions",
445 "chat_streaming",
446 gen_ai.operation.name = "chat_streaming",
447 gen_ai.provider.name = "perplexity",
448 gen_ai.request.model = self.model,
449 gen_ai.system_instructions = tracing::field::Empty,
450 gen_ai.response.id = tracing::field::Empty,
451 gen_ai.response.model = tracing::field::Empty,
452 gen_ai.usage.output_tokens = tracing::field::Empty,
453 gen_ai.usage.input_tokens = tracing::field::Empty,
454 gen_ai.usage.cached_tokens = tracing::field::Empty,
455 )
456 } else {
457 tracing::Span::current()
458 };
459
460 span.record("gen_ai.system_instructions", &completion_request.preamble);
461
462 if completion_request.tool_choice.is_some() {
463 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
464 }
465
466 if !completion_request.tools.is_empty() {
467 tracing::warn!("WARNING: `tools` not supported on Perplexity");
468 }
469
470 let mut request =
471 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
472 request.stream = true;
473
474 if tracing::enabled!(tracing::Level::TRACE) {
475 tracing::trace!(target: "rig::completions",
476 "Perplexity streaming completion request: {}",
477 serde_json::to_string_pretty(&request)?
478 );
479 }
480
481 let body = serde_json::to_vec(&request)?;
482
483 let req = self
484 .client
485 .post("/chat/completions")?
486 .body(body)
487 .map_err(http_client::Error::from)?;
488
489 send_compatible_streaming_request(self.client.clone(), req)
490 .instrument(span)
491 .await
492 }
493}
494
// Unit tests for message (de)serialization, message conversion and client construction.
#[cfg(test)]
mod tests {
    use super::*;

    // A wire message with role "user" deserializes into `Role::User` with its content intact.
    #[test]
    fn test_deserialize_message() {
        let json_data = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let message: Message = serde_json::from_str(json_data).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content, "Hello, how can I help you?");
    }

    // Serialization emits the role in lowercase and the fields in declaration order.
    #[test]
    fn test_serialize_message() {
        let message = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        let json_data = serde_json::to_string(&message).unwrap();
        let expected_json = r#"{"role":"assistant","content":"I am here to assist you."}"#;
        assert_eq!(json_data, expected_json);
    }

    // User and assistant text messages survive a round trip through the wire type.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::user("User message");
        let assistant_message = message::Message::assistant("Assistant message");

        let converted_user_message: Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Message = assistant_message.clone().try_into().unwrap();

        assert_eq!(converted_user_message.role, Role::User);
        assert_eq!(converted_user_message.content, "User message");

        assert_eq!(converted_assistant_message.role, Role::Assistant);
        assert_eq!(converted_assistant_message.content, "Assistant message");

        let back_to_user_message: message::Message = converted_user_message.into();
        let back_to_assistant_message: message::Message = converted_assistant_message.into();

        assert_eq!(user_message, back_to_user_message);
        assert_eq!(assistant_message, back_to_assistant_message);
    }
    // Both construction paths (direct and via the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::perplexity::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::perplexity::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}