1use crate::{
12 OneOrMany,
13 client::{VerifyClient, VerifyError},
14 completion::{self, CompletionError, MessageError, message},
15 http_client, impl_conversion_traits, json_utils,
16};
17
18use crate::client::{CompletionClient, ProviderClient};
19use crate::completion::CompletionRequest;
20use crate::json_utils::merge;
21use crate::providers::openai;
22use crate::providers::openai::send_compatible_streaming_request;
23use crate::streaming::StreamingCompletionResponse;
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use tracing::{Instrument, info_span};
27
28const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
32
33pub struct ClientBuilder<'a, T = reqwest::Client> {
34 api_key: &'a str,
35 base_url: &'a str,
36 http_client: T,
37}
38
39impl<'a, T> ClientBuilder<'a, T>
40where
41 T: Default,
42{
43 pub fn new(api_key: &'a str) -> Self {
44 Self {
45 api_key,
46 base_url: PERPLEXITY_API_BASE_URL,
47 http_client: Default::default(),
48 }
49 }
50}
51
52impl<'a, T> ClientBuilder<'a, T> {
53 pub fn base_url(mut self, base_url: &'a str) -> Self {
54 self.base_url = base_url;
55 self
56 }
57
58 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
59 ClientBuilder {
60 api_key: self.api_key,
61 base_url: self.base_url,
62 http_client,
63 }
64 }
65
66 pub fn build(self) -> Client<T> {
67 Client {
68 base_url: self.base_url.to_string(),
69 api_key: self.api_key.to_string(),
70 http_client: self.http_client,
71 }
72 }
73}
74
75#[derive(Clone)]
76pub struct Client<T = reqwest::Client> {
77 base_url: String,
78 api_key: String,
79 http_client: T,
80}
81
82impl<T> std::fmt::Debug for Client<T>
83where
84 T: std::fmt::Debug,
85{
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("Client")
88 .field("base_url", &self.base_url)
89 .field("http_client", &self.http_client)
90 .field("api_key", &"<REDACTED>")
91 .finish()
92 }
93}
94
95impl<T> Client<T>
96where
97 T: Default,
98{
99 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
110 ClientBuilder::new(api_key)
111 }
112
113 pub fn new(api_key: &str) -> Self {
118 Self::builder(api_key).build()
119 }
120}
121
122impl Client<reqwest::Client> {
123 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
124 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
125 self.http_client.post(url).bearer_auth(&self.api_key)
126 }
127}
128
129impl ProviderClient for Client<reqwest::Client> {
130 fn from_env() -> Self {
133 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
134 Self::new(&api_key)
135 }
136
137 fn from_val(input: crate::client::ProviderValue) -> Self {
138 let crate::client::ProviderValue::Simple(api_key) = input else {
139 panic!("Incorrect provider value type")
140 };
141 Self::new(&api_key)
142 }
143}
144
145impl CompletionClient for Client<reqwest::Client> {
146 type CompletionModel = CompletionModel<reqwest::Client>;
147
148 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
149 CompletionModel::new(self.clone(), model)
150 }
151}
152
153impl VerifyClient for Client<reqwest::Client> {
154 #[cfg_attr(feature = "worker", worker::send)]
155 async fn verify(&self) -> Result<(), VerifyError> {
156 Ok(())
158 }
159}
160
// Generates the `As*` conversion-trait impls for capabilities this client does
// not implement directly (transcription, embeddings, image and audio generation).
// NOTE(review): the expansion lives in `impl_conversion_traits!`, defined
// elsewhere — presumably these impls report the capability as unavailable; confirm.
impl_conversion_traits!(
    AsTranscription,
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
167
168#[derive(Debug, Deserialize)]
169struct ApiErrorResponse {
170 message: String,
171}
172
173#[derive(Debug, Deserialize)]
174#[serde(untagged)]
175enum ApiResponse<T> {
176 Ok(T),
177 Err(ApiErrorResponse),
178}
179
/// Model identifier for Perplexity's `sonar-pro`.
pub const SONAR_PRO: &str = "sonar-pro";
/// Model identifier for Perplexity's `sonar`.
pub const SONAR: &str = "sonar";
187
188#[derive(Debug, Deserialize, Serialize)]
189pub struct CompletionResponse {
190 pub id: String,
191 pub model: String,
192 pub object: String,
193 pub created: u64,
194 #[serde(default)]
195 pub choices: Vec<Choice>,
196 pub usage: Usage,
197}
198
199#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
200pub struct Message {
201 pub role: Role,
202 pub content: String,
203}
204
205#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
206#[serde(rename_all = "lowercase")]
207pub enum Role {
208 System,
209 User,
210 Assistant,
211}
212
213#[derive(Deserialize, Debug, Serialize)]
214pub struct Delta {
215 pub role: Role,
216 pub content: String,
217}
218
219#[derive(Deserialize, Debug, Serialize)]
220pub struct Choice {
221 pub index: usize,
222 pub finish_reason: String,
223 pub message: Message,
224 pub delta: Delta,
225}
226
227#[derive(Deserialize, Debug, Serialize)]
228pub struct Usage {
229 pub prompt_tokens: u32,
230 pub completion_tokens: u32,
231 pub total_tokens: u32,
232}
233
234impl std::fmt::Display for Usage {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 write!(
237 f,
238 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
239 self.prompt_tokens, self.completion_tokens, self.total_tokens
240 )
241 }
242}
243
244impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
245 type Error = CompletionError;
246
247 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
248 let choice = response.choices.first().ok_or_else(|| {
249 CompletionError::ResponseError("Response contained no choices".to_owned())
250 })?;
251
252 match &choice.message {
253 Message {
254 role: Role::Assistant,
255 content,
256 } => Ok(completion::CompletionResponse {
257 choice: OneOrMany::one(content.clone().into()),
258 usage: completion::Usage {
259 input_tokens: response.usage.prompt_tokens as u64,
260 output_tokens: response.usage.completion_tokens as u64,
261 total_tokens: response.usage.total_tokens as u64,
262 },
263 raw_response: response,
264 }),
265 _ => Err(CompletionError::ResponseError(
266 "Response contained no assistant message".to_owned(),
267 )),
268 }
269 }
270}
271
272#[derive(Clone)]
273pub struct CompletionModel<T> {
274 client: Client<T>,
275 pub model: String,
276}
277
278impl<T> CompletionModel<T> {
279 pub fn new(client: Client<T>, model: &str) -> Self {
280 Self {
281 client,
282 model: model.to_string(),
283 }
284 }
285
286 fn create_completion_request(
287 &self,
288 completion_request: CompletionRequest,
289 ) -> Result<Value, CompletionError> {
290 if completion_request.tool_choice.is_some() {
291 tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
292 }
293
294 let mut partial_history = vec![];
296 if let Some(docs) = completion_request.normalized_documents() {
297 partial_history.push(docs);
298 }
299 partial_history.extend(completion_request.chat_history);
300
301 let mut full_history: Vec<Message> =
303 completion_request
304 .preamble
305 .map_or_else(Vec::new, |preamble| {
306 vec![Message {
307 role: Role::System,
308 content: preamble,
309 }]
310 });
311
312 full_history.extend(
314 partial_history
315 .into_iter()
316 .map(message::Message::try_into)
317 .collect::<Result<Vec<Message>, _>>()?,
318 );
319
320 let request = json!({
322 "model": self.model,
323 "messages": full_history,
324 "temperature": completion_request.temperature,
325 });
326
327 let request = if let Some(ref params) = completion_request.additional_params {
328 json_utils::merge(request, params.clone())
329 } else {
330 request
331 };
332
333 Ok(request)
334 }
335}
336
337impl TryFrom<message::Message> for Message {
338 type Error = MessageError;
339
340 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
341 Ok(match message {
342 message::Message::User { content } => {
343 let collapsed_content = content
344 .into_iter()
345 .map(|content| match content {
346 message::UserContent::Text(message::Text { text }) => Ok(text),
347 _ => Err(MessageError::ConversionError(
348 "Only text content is supported by Perplexity".to_owned(),
349 )),
350 })
351 .collect::<Result<Vec<_>, _>>()?
352 .join("\n");
353
354 Message {
355 role: Role::User,
356 content: collapsed_content,
357 }
358 }
359
360 message::Message::Assistant { content, .. } => {
361 let collapsed_content = content
362 .into_iter()
363 .map(|content| {
364 Ok(match content {
365 message::AssistantContent::Text(message::Text { text }) => text,
366 _ => return Err(MessageError::ConversionError(
367 "Only text assistant message content is supported by Perplexity"
368 .to_owned(),
369 )),
370 })
371 })
372 .collect::<Result<Vec<_>, _>>()?
373 .join("\n");
374
375 Message {
376 role: Role::Assistant,
377 content: collapsed_content,
378 }
379 }
380 })
381 }
382}
383
384impl From<Message> for message::Message {
385 fn from(message: Message) -> Self {
386 match message.role {
387 Role::User => message::Message::user(message.content),
388 Role::Assistant => message::Message::assistant(message.content),
389
390 Role::System => message::Message::user(message.content),
393 }
394 }
395}
396
397impl completion::CompletionModel for CompletionModel<reqwest::Client> {
398 type Response = CompletionResponse;
399 type StreamingResponse = openai::StreamingCompletionResponse;
400
401 #[cfg_attr(feature = "worker", worker::send)]
402 async fn completion(
403 &self,
404 completion_request: completion::CompletionRequest,
405 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
406 let preamble = completion_request.preamble.clone();
407 let request = self.create_completion_request(completion_request)?;
408
409 let span = if tracing::Span::current().is_disabled() {
410 info_span!(
411 target: "rig::completions",
412 "chat",
413 gen_ai.operation.name = "chat",
414 gen_ai.provider.name = "perplexity",
415 gen_ai.request.model = self.model,
416 gen_ai.system_instructions = preamble,
417 gen_ai.response.id = tracing::field::Empty,
418 gen_ai.response.model = tracing::field::Empty,
419 gen_ai.usage.output_tokens = tracing::field::Empty,
420 gen_ai.usage.input_tokens = tracing::field::Empty,
421 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
422 gen_ai.output.messages = tracing::field::Empty,
423 )
424 } else {
425 tracing::Span::current()
426 };
427
428 let async_block = async move {
429 let response = self
430 .client
431 .reqwest_post("/chat/completions")
432 .json(&request)
433 .send()
434 .await
435 .map_err(|e| http_client::Error::Instance(e.into()))?;
436
437 if response.status().is_success() {
438 match response
439 .json::<ApiResponse<CompletionResponse>>()
440 .await
441 .map_err(|e| http_client::Error::Instance(e.into()))?
442 {
443 ApiResponse::Ok(completion) => {
444 let span = tracing::Span::current();
445 span.record("gen_ai.usage.input_tokens", completion.usage.prompt_tokens);
446 span.record(
447 "gen_ai.usage.output_tokens",
448 completion.usage.completion_tokens,
449 );
450 span.record(
451 "gen_ai.output.messages",
452 serde_json::to_string(&completion.choices).unwrap(),
453 );
454 span.record("gen_ai.response.id", completion.id.to_string());
455 span.record("gen_ai.response.model_name", completion.model.to_string());
456 Ok(completion.try_into()?)
457 }
458 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
459 }
460 } else {
461 Err(CompletionError::ProviderError(
462 response
463 .text()
464 .await
465 .map_err(|e| http_client::Error::Instance(e.into()))?,
466 ))
467 }
468 };
469
470 async_block.instrument(span).await
471 }
472
473 #[cfg_attr(feature = "worker", worker::send)]
474 async fn stream(
475 &self,
476 completion_request: completion::CompletionRequest,
477 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
478 let preamble = completion_request.preamble.clone();
479 let mut request = self.create_completion_request(completion_request)?;
480
481 request = merge(request, json!({"stream": true}));
482
483 let builder = self.client.reqwest_post("/chat/completions").json(&request);
484
485 let span = if tracing::Span::current().is_disabled() {
486 info_span!(
487 target: "rig::completions",
488 "chat_streaming",
489 gen_ai.operation.name = "chat_streaming",
490 gen_ai.provider.name = "perplexity",
491 gen_ai.request.model = self.model,
492 gen_ai.system_instructions = preamble,
493 gen_ai.response.id = tracing::field::Empty,
494 gen_ai.response.model = tracing::field::Empty,
495 gen_ai.usage.output_tokens = tracing::field::Empty,
496 gen_ai.usage.input_tokens = tracing::field::Empty,
497 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
498 gen_ai.output.messages = tracing::field::Empty,
499 )
500 } else {
501 tracing::Span::current()
502 };
503 send_compatible_streaming_request(builder)
504 .instrument(span)
505 .await
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// A lowercase `"user"` role and its content deserialize into [`Message`].
    #[test]
    fn test_deserialize_message() {
        let raw = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let parsed: Message = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.role, Role::User);
        assert_eq!(parsed.content, "Hello, how can I help you?");
    }

    /// [`Message`] serializes with `role` first and the role in lowercase.
    #[test]
    fn test_serialize_message() {
        let reply = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        let serialized = serde_json::to_string(&reply).unwrap();

        assert_eq!(
            serialized,
            r#"{"role":"assistant","content":"I am here to assist you."}"#
        );
    }

    /// User and assistant text messages survive a round trip through the
    /// Perplexity representation unchanged.
    #[test]
    fn test_message_to_message_conversion() {
        let cases = [
            (
                message::Message::user("User message"),
                Role::User,
                "User message",
            ),
            (
                message::Message::assistant("Assistant message"),
                Role::Assistant,
                "Assistant message",
            ),
        ];

        for (original, expected_role, expected_text) in cases {
            let converted: Message = original.clone().try_into().unwrap();
            assert_eq!(converted.role, expected_role);
            assert_eq!(converted.content, expected_text);

            let restored: message::Message = converted.into();
            assert_eq!(original, restored);
        }
    }
}