1use crate::client::BearerAuth;
12use crate::completion::CompletionRequest;
13use crate::providers::openai;
14use crate::providers::openai::send_compatible_streaming_request;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 OneOrMany,
18 client::{
19 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
20 },
21 completion::{self, CompletionError, MessageError, message},
22 http_client::{self, HttpClientExt},
23};
24use bytes::Bytes;
25use serde::{Deserialize, Serialize};
26use tracing::{Instrument, info_span};
27
/// Base URL used as the default `BASE_URL` of the Perplexity provider builder.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
32
/// Zero-sized marker type that identifies Perplexity as the provider of a client.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityExt;
35
/// Zero-sized marker type used as the builder half of the Perplexity provider.
#[derive(Clone, Copy, Debug, Default)]
pub struct PerplexityBuilder;
38
39type PerplexityApiKey = BearerAuth;
40
41impl Provider for PerplexityExt {
42 type Builder = PerplexityBuilder;
43
44 const VERIFY_PATH: &'static str = "";
46
47 fn build<H>(
48 _: &crate::client::ClientBuilder<
49 Self::Builder,
50 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
51 H,
52 >,
53 ) -> http_client::Result<Self> {
54 Ok(Self)
55 }
56}
57
58impl<H> Capabilities<H> for PerplexityExt {
59 type Completion = Capable<CompletionModel<H>>;
60 type Transcription = Nothing;
61 type Embeddings = Nothing;
62 type ModelListing = Nothing;
63 #[cfg(feature = "image")]
64 type ImageGeneration = Nothing;
65
66 #[cfg(feature = "audio")]
67 type AudioGeneration = Nothing;
68}
69
70impl DebugExt for PerplexityExt {}
71
72impl ProviderBuilder for PerplexityBuilder {
73 type Output = PerplexityExt;
74 type ApiKey = PerplexityApiKey;
75
76 const BASE_URL: &'static str = PERPLEXITY_API_BASE_URL;
77}
78
79pub type Client<H = reqwest::Client> = client::Client<PerplexityExt, H>;
80pub type ClientBuilder<H = reqwest::Client> =
81 client::ClientBuilder<PerplexityBuilder, PerplexityApiKey, H>;
82
83impl ProviderClient for Client {
84 type Input = String;
85
86 fn from_env() -> Self {
89 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
90 Self::new(&api_key).unwrap()
91 }
92
93 fn from_val(input: Self::Input) -> Self {
94 Self::new(&input).unwrap()
95 }
96}
97
98#[derive(Debug, Deserialize)]
99struct ApiErrorResponse {
100 message: String,
101}
102
103#[derive(Debug, Deserialize)]
104#[serde(untagged)]
105enum ApiResponse<T> {
106 Ok(T),
107 Err(ApiErrorResponse),
108}
109
/// `sonar-pro` completion model.
///
/// Perplexity model identifiers are hyphenated; the previous value `sonar_pro`
/// is not a model name the API recognizes.
pub const SONAR_PRO: &str = "sonar-pro";
/// `sonar` completion model.
pub const SONAR: &str = "sonar";
116
117#[derive(Debug, Deserialize, Serialize)]
118pub struct CompletionResponse {
119 pub id: String,
120 pub model: String,
121 pub object: String,
122 pub created: u64,
123 #[serde(default)]
124 pub choices: Vec<Choice>,
125 pub usage: Usage,
126}
127
128#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
129pub struct Message {
130 pub role: Role,
131 pub content: String,
132}
133
134#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
135#[serde(rename_all = "lowercase")]
136pub enum Role {
137 System,
138 User,
139 Assistant,
140}
141
142#[derive(Deserialize, Debug, Serialize)]
143pub struct Delta {
144 pub role: Role,
145 pub content: String,
146}
147
148#[derive(Deserialize, Debug, Serialize)]
149pub struct Choice {
150 pub index: usize,
151 pub finish_reason: String,
152 pub message: Message,
153 pub delta: Delta,
154}
155
156#[derive(Deserialize, Debug, Serialize)]
157pub struct Usage {
158 pub prompt_tokens: u32,
159 pub completion_tokens: u32,
160 pub total_tokens: u32,
161}
162
163impl std::fmt::Display for Usage {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 write!(
166 f,
167 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
168 self.prompt_tokens, self.completion_tokens, self.total_tokens
169 )
170 }
171}
172
173impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
174 type Error = CompletionError;
175
176 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
177 let choice = response.choices.first().ok_or_else(|| {
178 CompletionError::ResponseError("Response contained no choices".to_owned())
179 })?;
180
181 match &choice.message {
182 Message {
183 role: Role::Assistant,
184 content,
185 } => Ok(completion::CompletionResponse {
186 choice: OneOrMany::one(content.clone().into()),
187 usage: completion::Usage {
188 input_tokens: response.usage.prompt_tokens as u64,
189 output_tokens: response.usage.completion_tokens as u64,
190 total_tokens: response.usage.total_tokens as u64,
191 cached_input_tokens: 0,
192 },
193 raw_response: response,
194 message_id: None,
195 }),
196 _ => Err(CompletionError::ResponseError(
197 "Response contained no assistant message".to_owned(),
198 )),
199 }
200 }
201}
202
203#[derive(Debug, Serialize, Deserialize)]
204pub(super) struct PerplexityCompletionRequest {
205 model: String,
206 pub messages: Vec<Message>,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 pub temperature: Option<f64>,
209 #[serde(skip_serializing_if = "Option::is_none")]
210 pub max_tokens: Option<u64>,
211 #[serde(flatten, skip_serializing_if = "Option::is_none")]
212 additional_params: Option<serde_json::Value>,
213 pub stream: bool,
214}
215
216impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
217 type Error = CompletionError;
218
219 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
220 if req.output_schema.is_some() {
221 tracing::warn!("Structured outputs currently not supported for Perplexity");
222 }
223 let model = req.model.clone().unwrap_or_else(|| model.to_string());
224 let mut partial_history = vec![];
225 if let Some(docs) = req.normalized_documents() {
226 partial_history.push(docs);
227 }
228 partial_history.extend(req.chat_history);
229
230 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
232 vec![Message {
233 role: Role::System,
234 content: preamble,
235 }]
236 });
237
238 full_history.extend(
240 partial_history
241 .into_iter()
242 .map(message::Message::try_into)
243 .collect::<Result<Vec<Message>, _>>()?,
244 );
245
246 Ok(Self {
247 model: model.to_string(),
248 messages: full_history,
249 temperature: req.temperature,
250 max_tokens: req.max_tokens,
251 additional_params: req.additional_params,
252 stream: false,
253 })
254 }
255}
256
/// Completion model handle: a Perplexity client paired with a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    client: Client<T>,
    /// Default model name, used when a request does not name one itself.
    pub model: String,
}
262
263impl<T> CompletionModel<T> {
264 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
265 Self {
266 client,
267 model: model.into(),
268 }
269 }
270}
271
272impl TryFrom<message::Message> for Message {
273 type Error = MessageError;
274
275 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
276 Ok(match message {
277 message::Message::User { content } => {
278 let collapsed_content = content
279 .into_iter()
280 .map(|content| match content {
281 message::UserContent::Text(message::Text { text }) => Ok(text),
282 _ => Err(MessageError::ConversionError(
283 "Only text content is supported by Perplexity".to_owned(),
284 )),
285 })
286 .collect::<Result<Vec<_>, _>>()?
287 .join("\n");
288
289 Message {
290 role: Role::User,
291 content: collapsed_content,
292 }
293 }
294
295 message::Message::Assistant { content, .. } => {
296 let collapsed_content = content
297 .into_iter()
298 .map(|content| {
299 Ok(match content {
300 message::AssistantContent::Text(message::Text { text }) => text,
301 _ => return Err(MessageError::ConversionError(
302 "Only text assistant message content is supported by Perplexity"
303 .to_owned(),
304 )),
305 })
306 })
307 .collect::<Result<Vec<_>, _>>()?
308 .join("\n");
309
310 Message {
311 role: Role::Assistant,
312 content: collapsed_content,
313 }
314 }
315 })
316 }
317}
318
319impl From<Message> for message::Message {
320 fn from(message: Message) -> Self {
321 match message.role {
322 Role::User => message::Message::user(message.content),
323 Role::Assistant => message::Message::assistant(message.content),
324
325 Role::System => message::Message::user(message.content),
328 }
329 }
330}
331
332impl<T> completion::CompletionModel for CompletionModel<T>
333where
334 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
335{
336 type Response = CompletionResponse;
337 type StreamingResponse = openai::StreamingCompletionResponse;
338
339 type Client = Client<T>;
340
341 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
342 Self::new(client.clone(), model)
343 }
344
345 async fn completion(
346 &self,
347 completion_request: completion::CompletionRequest,
348 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
349 let span = if tracing::Span::current().is_disabled() {
350 info_span!(
351 target: "rig::completions",
352 "chat",
353 gen_ai.operation.name = "chat",
354 gen_ai.provider.name = "perplexity",
355 gen_ai.request.model = self.model,
356 gen_ai.system_instructions = tracing::field::Empty,
357 gen_ai.response.id = tracing::field::Empty,
358 gen_ai.response.model = tracing::field::Empty,
359 gen_ai.usage.output_tokens = tracing::field::Empty,
360 gen_ai.usage.input_tokens = tracing::field::Empty,
361 )
362 } else {
363 tracing::Span::current()
364 };
365
366 span.record("gen_ai.system_instructions", &completion_request.preamble);
367
368 if completion_request.tool_choice.is_some() {
369 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
370 }
371
372 if !completion_request.tools.is_empty() {
373 tracing::warn!("WARNING: `tools` not supported on Perplexity");
374 }
375 let request =
376 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
377
378 if tracing::enabled!(tracing::Level::TRACE) {
379 tracing::trace!(target: "rig::completions",
380 "Perplexity completion request: {}",
381 serde_json::to_string_pretty(&request)?
382 );
383 }
384
385 let body = serde_json::to_vec(&request)?;
386
387 let req = self
388 .client
389 .post("/v1/chat/completions")?
390 .body(body)
391 .map_err(http_client::Error::from)?;
392
393 let async_block = async move {
394 let response = self.client.send::<_, Bytes>(req).await?;
395
396 let status = response.status();
397 let response_body = response.into_body().into_future().await?.to_vec();
398
399 if status.is_success() {
400 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
401 ApiResponse::Ok(response) => {
402 let span = tracing::Span::current();
403 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
404 span.record(
405 "gen_ai.usage.output_tokens",
406 response.usage.completion_tokens,
407 );
408 span.record("gen_ai.response.id", response.id.to_string());
409 span.record("gen_ai.response.model_name", response.model.to_string());
410 if tracing::enabled!(tracing::Level::TRACE) {
411 tracing::trace!(target: "rig::responses",
412 "Perplexity completion response: {}",
413 serde_json::to_string_pretty(&response)?
414 );
415 }
416 Ok(response.try_into()?)
417 }
418 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
419 }
420 } else {
421 Err(CompletionError::ProviderError(
422 String::from_utf8_lossy(&response_body).to_string(),
423 ))
424 }
425 };
426
427 async_block.instrument(span).await
428 }
429
430 async fn stream(
431 &self,
432 completion_request: completion::CompletionRequest,
433 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
434 let span = if tracing::Span::current().is_disabled() {
435 info_span!(
436 target: "rig::completions",
437 "chat_streaming",
438 gen_ai.operation.name = "chat_streaming",
439 gen_ai.provider.name = "perplexity",
440 gen_ai.request.model = self.model,
441 gen_ai.system_instructions = tracing::field::Empty,
442 gen_ai.response.id = tracing::field::Empty,
443 gen_ai.response.model = tracing::field::Empty,
444 gen_ai.usage.output_tokens = tracing::field::Empty,
445 gen_ai.usage.input_tokens = tracing::field::Empty,
446 )
447 } else {
448 tracing::Span::current()
449 };
450
451 span.record("gen_ai.system_instructions", &completion_request.preamble);
452
453 if completion_request.tool_choice.is_some() {
454 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
455 }
456
457 if !completion_request.tools.is_empty() {
458 tracing::warn!("WARNING: `tools` not supported on Perplexity");
459 }
460
461 let mut request =
462 PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
463 request.stream = true;
464
465 if tracing::enabled!(tracing::Level::TRACE) {
466 tracing::trace!(target: "rig::completions",
467 "Perplexity streaming completion request: {}",
468 serde_json::to_string_pretty(&request)?
469 );
470 }
471
472 let body = serde_json::to_vec(&request)?;
473
474 let req = self
475 .client
476 .post("/chat/completions")?
477 .body(body)
478 .map_err(http_client::Error::from)?;
479
480 send_compatible_streaming_request(self.client.clone(), req)
481 .instrument(span)
482 .await
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// A user message in wire format decodes into the matching role and text.
    #[test]
    fn test_deserialize_message() {
        let raw = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let parsed = serde_json::from_str::<Message>(raw).unwrap();

        assert_eq!(parsed.role, Role::User);
        assert_eq!(parsed.content, "Hello, how can I help you?");
    }

    /// Serialization emits `role` before `content`, with a lowercase role name.
    #[test]
    fn test_serialize_message() {
        let outgoing = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        let encoded = serde_json::to_string(&outgoing).unwrap();

        assert_eq!(
            encoded,
            r#"{"role":"assistant","content":"I am here to assist you."}"#
        );
    }

    /// Generic messages survive a round trip through the Perplexity message type.
    #[test]
    fn test_message_to_message_conversion() {
        let generic_user = message::Message::user("User message");
        let generic_assistant = message::Message::assistant("Assistant message");

        let wire_user = Message::try_from(generic_user.clone()).unwrap();
        let wire_assistant = Message::try_from(generic_assistant.clone()).unwrap();

        assert_eq!(wire_user.role, Role::User);
        assert_eq!(wire_user.content, "User message");

        assert_eq!(wire_assistant.role, Role::Assistant);
        assert_eq!(wire_assistant.content, "Assistant message");

        let round_trip_user = message::Message::from(wire_user);
        let round_trip_assistant = message::Message::from(wire_assistant);

        assert_eq!(generic_user, round_trip_user);
        assert_eq!(generic_assistant, round_trip_assistant);
    }

    /// Both construction paths succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::perplexity::Client =
            crate::providers::perplexity::Client::new("dummy-key").expect("Client::new() failed");

        let builder = crate::providers::perplexity::Client::builder().api_key("dummy-key");
        let _client_from_builder: crate::providers::perplexity::Client =
            builder.build().expect("Client::builder() failed");
    }
}