1use crate::{
12 OneOrMany,
13 agent::AgentBuilder,
14 completion::{self, CompletionError, MessageError, message},
15 extractor::ExtractorBuilder,
16 impl_conversion_traits, json_utils,
17};
18
19use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
20use crate::completion::CompletionRequest;
21use crate::json_utils::merge;
22use crate::providers::openai;
23use crate::providers::openai::send_compatible_streaming_request;
24use crate::streaming::StreamingCompletionResponse;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28
/// Default Perplexity API base URL, used by [`ClientBuilder::new`] unless overridden with
/// [`ClientBuilder::base_url`].
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
33
/// Builder for a Perplexity [`Client`].
///
/// Holds borrowed settings for lifetime `'a`; [`ClientBuilder::build`] copies them into an
/// owned [`Client`].
pub struct ClientBuilder<'a> {
    /// API key; the built client sends it as a bearer token on every request.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Caller-supplied HTTP client; when `None`, `build` creates a default `reqwest::Client`.
    http_client: Option<reqwest::Client>,
}
39
40impl<'a> ClientBuilder<'a> {
41 pub fn new(api_key: &'a str) -> Self {
42 Self {
43 api_key,
44 base_url: PERPLEXITY_API_BASE_URL,
45 http_client: None,
46 }
47 }
48
49 pub fn base_url(mut self, base_url: &'a str) -> Self {
50 self.base_url = base_url;
51 self
52 }
53
54 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
55 self.http_client = Some(client);
56 self
57 }
58
59 pub fn build(self) -> Result<Client, ClientBuilderError> {
60 let http_client = if let Some(http_client) = self.http_client {
61 http_client
62 } else {
63 reqwest::Client::builder().build()?
64 };
65
66 Ok(Client {
67 base_url: self.base_url.to_string(),
68 api_key: self.api_key.to_string(),
69 http_client,
70 })
71 }
72}
73
/// Perplexity API client.
///
/// A clone of the client is handed to each [`CompletionModel`]. `Debug` is implemented by
/// hand so the API key is never printed.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are joined onto.
    base_url: String,
    /// API key sent as a bearer token.
    api_key: String,
    /// HTTP client used for all requests.
    http_client: reqwest::Client,
}
80
81impl std::fmt::Debug for Client {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("Client")
84 .field("base_url", &self.base_url)
85 .field("http_client", &self.http_client)
86 .field("api_key", &"<REDACTED>")
87 .finish()
88 }
89}
90
91impl Client {
92 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
103 ClientBuilder::new(api_key)
104 }
105
106 pub fn new(api_key: &str) -> Self {
111 Self::builder(api_key)
112 .build()
113 .expect("Perplexity client should build")
114 }
115
116 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
117 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
118 self.http_client.post(url).bearer_auth(&self.api_key)
119 }
120
121 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
122 AgentBuilder::new(self.completion_model(model))
123 }
124
125 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
126 &self,
127 model: &str,
128 ) -> ExtractorBuilder<T, CompletionModel> {
129 ExtractorBuilder::new(self.completion_model(model))
130 }
131}
132
133impl ProviderClient for Client {
134 fn from_env() -> Self {
137 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
138 Self::new(&api_key)
139 }
140
141 fn from_val(input: crate::client::ProviderValue) -> Self {
142 let crate::client::ProviderValue::Simple(api_key) = input else {
143 panic!("Incorrect provider value type")
144 };
145 Self::new(&api_key)
146 }
147}
148
149impl CompletionClient for Client {
150 type CompletionModel = CompletionModel;
151
152 fn completion_model(&self, model: &str) -> CompletionModel {
153 CompletionModel::new(self.clone(), model)
154 }
155}
156
// NOTE(review): presumably generates the `AsTranscription`, `AsEmbeddings`,
// `AsImageGeneration` and `AsAudioGeneration` conversion impls for `Client` in their
// "capability not provided" form, since this module implements none of those
// capabilities. Confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsTranscription,
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
163
/// Error payload matched by [`ApiResponse::Err`]; only the `message` field is captured.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, forwarded as `CompletionError::ProviderError`.
    message: String,
}
168
/// A response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order: `Ok(T)` first.
/// It falls back to `Err` only if the body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
175
/// Model identifier `sonar-pro`, for use with [`Client::agent`] or `completion_model`.
pub const SONAR_PRO: &str = "sonar-pro";
/// Model identifier `sonar`, for use with [`Client::agent`] or `completion_model`.
pub const SONAR: &str = "sonar";
183
/// Raw body of a non-streaming Perplexity chat completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub model: String,
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: u64,
    /// Defaults to an empty list when absent. Converting such a response to rig's
    /// response type then fails with "Response contained no choices".
    #[serde(default)]
    pub choices: Vec<Choice>,
    /// Token counts; this field is required.
    pub usage: Usage,
}
194
/// A chat message in Perplexity's wire format: a role plus a single text `content` string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    pub content: String,
}
200
/// Author role of a [`Message`], serialized in lowercase (`"system"`, `"user"`,
/// `"assistant"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
208
/// The `delta` object carried by each [`Choice`].
///
/// Both fields are required. Nothing in this module reads it after deserialization.
#[derive(Deserialize, Debug, Serialize)]
pub struct Delta {
    pub role: Role,
    pub content: String,
}
214
/// One entry of [`CompletionResponse::choices`].
///
/// None of the fields is optional or defaulted, so a choice missing any of them (including
/// `delta`) fails to deserialize.
#[derive(Deserialize, Debug, Serialize)]
pub struct Choice {
    pub index: usize,
    pub finish_reason: String,
    /// The message converted into rig's response (first choice only).
    pub message: Message,
    pub delta: Delta,
}
222
/// Token counts reported with a completion. They are widened to `u64` when converted into
/// rig's `completion::Usage`.
#[derive(Deserialize, Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
229
230impl std::fmt::Display for Usage {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 write!(
233 f,
234 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
235 self.prompt_tokens, self.completion_tokens, self.total_tokens
236 )
237 }
238}
239
240impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
241 type Error = CompletionError;
242
243 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
244 let choice = response.choices.first().ok_or_else(|| {
245 CompletionError::ResponseError("Response contained no choices".to_owned())
246 })?;
247
248 match &choice.message {
249 Message {
250 role: Role::Assistant,
251 content,
252 } => Ok(completion::CompletionResponse {
253 choice: OneOrMany::one(content.clone().into()),
254 usage: completion::Usage {
255 input_tokens: response.usage.prompt_tokens as u64,
256 output_tokens: response.usage.completion_tokens as u64,
257 total_tokens: response.usage.total_tokens as u64,
258 },
259 raw_response: response,
260 }),
261 _ => Err(CompletionError::ResponseError(
262 "Response contained no assistant message".to_owned(),
263 )),
264 }
265 }
266}
267
/// Perplexity completion model: a [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send `/chat/completions` requests.
    client: Client,
    /// Model identifier sent as the `model` field of every request (e.g. [`SONAR`]).
    pub model: String,
}
273
274impl CompletionModel {
275 pub fn new(client: Client, model: &str) -> Self {
276 Self {
277 client,
278 model: model.to_string(),
279 }
280 }
281
282 fn create_completion_request(
283 &self,
284 completion_request: CompletionRequest,
285 ) -> Result<Value, CompletionError> {
286 let mut partial_history = vec![];
288 if let Some(docs) = completion_request.normalized_documents() {
289 partial_history.push(docs);
290 }
291 partial_history.extend(completion_request.chat_history);
292
293 let mut full_history: Vec<Message> =
295 completion_request
296 .preamble
297 .map_or_else(Vec::new, |preamble| {
298 vec![Message {
299 role: Role::System,
300 content: preamble,
301 }]
302 });
303
304 full_history.extend(
306 partial_history
307 .into_iter()
308 .map(message::Message::try_into)
309 .collect::<Result<Vec<Message>, _>>()?,
310 );
311
312 let request = json!({
314 "model": self.model,
315 "messages": full_history,
316 "temperature": completion_request.temperature,
317 });
318
319 let request = if let Some(ref params) = completion_request.additional_params {
320 json_utils::merge(request, params.clone())
321 } else {
322 request
323 };
324
325 Ok(request)
326 }
327}
328
329impl TryFrom<message::Message> for Message {
330 type Error = MessageError;
331
332 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
333 Ok(match message {
334 message::Message::User { content } => {
335 let collapsed_content = content
336 .into_iter()
337 .map(|content| match content {
338 message::UserContent::Text(message::Text { text }) => Ok(text),
339 _ => Err(MessageError::ConversionError(
340 "Only text content is supported by Perplexity".to_owned(),
341 )),
342 })
343 .collect::<Result<Vec<_>, _>>()?
344 .join("\n");
345
346 Message {
347 role: Role::User,
348 content: collapsed_content,
349 }
350 }
351
352 message::Message::Assistant { content, .. } => {
353 let collapsed_content = content
354 .into_iter()
355 .map(|content| {
356 Ok(match content {
357 message::AssistantContent::Text(message::Text { text }) => text,
358 _ => return Err(MessageError::ConversionError(
359 "Only text assistant message content is supported by Perplexity"
360 .to_owned(),
361 )),
362 })
363 })
364 .collect::<Result<Vec<_>, _>>()?
365 .join("\n");
366
367 Message {
368 role: Role::Assistant,
369 content: collapsed_content,
370 }
371 }
372 })
373 }
374}
375
376impl From<Message> for message::Message {
377 fn from(message: Message) -> Self {
378 match message.role {
379 Role::User => message::Message::user(message.content),
380 Role::Assistant => message::Message::assistant(message.content),
381
382 Role::System => message::Message::user(message.content),
385 }
386 }
387}
388
389impl completion::CompletionModel for CompletionModel {
390 type Response = CompletionResponse;
391 type StreamingResponse = openai::StreamingCompletionResponse;
392
393 #[cfg_attr(feature = "worker", worker::send)]
394 async fn completion(
395 &self,
396 completion_request: completion::CompletionRequest,
397 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
398 let request = self.create_completion_request(completion_request)?;
399
400 let response = self
401 .client
402 .post("/chat/completions")
403 .json(&request)
404 .send()
405 .await?;
406
407 if response.status().is_success() {
408 match response.json::<ApiResponse<CompletionResponse>>().await? {
409 ApiResponse::Ok(completion) => {
410 tracing::info!(target: "rig",
411 "Perplexity completion token usage: {}",
412 completion.usage
413 );
414 Ok(completion.try_into()?)
415 }
416 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
417 }
418 } else {
419 Err(CompletionError::ProviderError(response.text().await?))
420 }
421 }
422
423 #[cfg_attr(feature = "worker", worker::send)]
424 async fn stream(
425 &self,
426 completion_request: completion::CompletionRequest,
427 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
428 let mut request = self.create_completion_request(completion_request)?;
429
430 request = merge(request, json!({"stream": true}));
431
432 let builder = self.client.post("/chat/completions").json(&request);
433
434 send_compatible_streaming_request(builder).await
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_message() {
        let json_data = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let message: Message = serde_json::from_str(json_data).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content, "Hello, how can I help you?");
    }

    #[test]
    fn test_serialize_message() {
        let message = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        let json_data = serde_json::to_string(&message).unwrap();
        let expected_json = r#"{"role":"assistant","content":"I am here to assist you."}"#;
        assert_eq!(json_data, expected_json);
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::user("User message");
        let assistant_message = message::Message::assistant("Assistant message");

        let converted_user_message: Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Message = assistant_message.clone().try_into().unwrap();

        assert_eq!(converted_user_message.role, Role::User);
        assert_eq!(converted_user_message.content, "User message");

        assert_eq!(converted_assistant_message.role, Role::Assistant);
        assert_eq!(converted_assistant_message.content, "Assistant message");

        let back_to_user_message: message::Message = converted_user_message.into();
        let back_to_assistant_message: message::Message = converted_assistant_message.into();

        assert_eq!(user_message, back_to_user_message);
        assert_eq!(assistant_message, back_to_assistant_message);
    }

    /// System messages have no rig counterpart and are mapped to user messages.
    #[test]
    fn test_system_message_maps_to_user_message() {
        let system = Message {
            role: Role::System,
            content: "Be concise.".to_string(),
        };

        let converted: message::Message = system.into();
        assert_eq!(converted, message::Message::user("Be concise."));
    }

    /// A full response body deserializes and converts with its token counts intact.
    #[test]
    fn test_deserialize_and_convert_completion_response() {
        let json_data = r#"
        {
            "id": "resp-1",
            "model": "sonar",
            "object": "chat.completion",
            "created": 1700000000,
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": { "role": "assistant", "content": "Hi there" },
                    "delta": { "role": "assistant", "content": "" }
                }
            ],
            "usage": { "prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8 }
        }
        "#;

        let response: CompletionResponse = serde_json::from_str(json_data).unwrap();
        let converted: completion::CompletionResponse<CompletionResponse> =
            response.try_into().unwrap();

        assert_eq!(converted.usage.input_tokens, 5);
        assert_eq!(converted.usage.output_tokens, 3);
        assert_eq!(converted.usage.total_tokens, 8);
        assert_eq!(converted.raw_response.id, "resp-1");
    }

    /// A response with no choices is rejected rather than producing an empty completion.
    #[test]
    fn test_empty_choices_is_an_error() {
        let response = CompletionResponse {
            id: "resp-2".to_string(),
            model: "sonar".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            choices: vec![],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        let result: Result<completion::CompletionResponse<CompletionResponse>, CompletionError> =
            response.try_into();
        assert!(matches!(result, Err(CompletionError::ResponseError(_))));
    }

    /// The first choice must be an assistant message.
    #[test]
    fn test_non_assistant_choice_is_an_error() {
        let response = CompletionResponse {
            id: "resp-3".to_string(),
            model: "sonar".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            choices: vec![Choice {
                index: 0,
                finish_reason: "stop".to_string(),
                message: Message {
                    role: Role::User,
                    content: "echo".to_string(),
                },
                delta: Delta {
                    role: Role::User,
                    content: String::new(),
                },
            }],
            usage: Usage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            },
        };

        let result: Result<completion::CompletionResponse<CompletionResponse>, CompletionError> =
            response.try_into();
        assert!(matches!(result, Err(CompletionError::ResponseError(_))));
    }
}