1use crate::{
13 agent::AgentBuilder,
14 completion::{self, message, CompletionError, MessageError},
15 extractor::ExtractorBuilder,
16 impl_conversion_traits, json_utils, OneOrMany,
17};
18
19use crate::client::{CompletionClient, ProviderClient};
20use crate::completion::CompletionRequest;
21use crate::json_utils::merge;
22use crate::providers::openai;
23use crate::providers::openai::send_compatible_streaming_request;
24use crate::streaming::StreamingCompletionResponse;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::{json, Value};
28
/// Base URL that [`Client::new`] passes to [`Client::from_url`].
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
33
34#[derive(Clone, Debug)]
35pub struct Client {
36 base_url: String,
37 http_client: reqwest::Client,
38}
39
40impl Client {
41 pub fn new(api_key: &str) -> Self {
42 Self::from_url(api_key, PERPLEXITY_API_BASE_URL)
43 }
44
45 pub fn from_url(api_key: &str, base_url: &str) -> Self {
46 Self {
47 base_url: base_url.to_string(),
48 http_client: reqwest::Client::builder()
49 .default_headers({
50 let mut headers = reqwest::header::HeaderMap::new();
51 headers.insert(
52 "Authorization",
53 format!("Bearer {api_key}")
54 .parse()
55 .expect("Bearer token should parse"),
56 );
57 headers
58 })
59 .build()
60 .expect("Perplexity reqwest client should build"),
61 }
62 }
63
64 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
65 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
66 self.http_client.post(url)
67 }
68
69 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
70 AgentBuilder::new(self.completion_model(model))
71 }
72
73 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
74 &self,
75 model: &str,
76 ) -> ExtractorBuilder<T, CompletionModel> {
77 ExtractorBuilder::new(self.completion_model(model))
78 }
79}
80
81impl ProviderClient for Client {
82 fn from_env() -> Self {
85 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
86 Self::new(&api_key)
87 }
88}
89
90impl CompletionClient for Client {
91 type CompletionModel = CompletionModel;
92
93 fn completion_model(&self, model: &str) -> CompletionModel {
94 CompletionModel::new(self.clone(), model)
95 }
96}
97
98impl_conversion_traits!(
99 AsTranscription,
100 AsEmbeddings,
101 AsImageGeneration,
102 AsAudioGeneration for Client
103);
104
105#[derive(Debug, Deserialize)]
106struct ApiErrorResponse {
107 message: String,
108}
109
110#[derive(Debug, Deserialize)]
111#[serde(untagged)]
112enum ApiResponse<T> {
113 Ok(T),
114 Err(ApiErrorResponse),
115}
116
/// Model identifier string `sonar-pro`.
pub const SONAR_PRO: &str = "sonar-pro";
/// Model identifier string `sonar`.
pub const SONAR: &str = "sonar";
124
125#[derive(Debug, Deserialize)]
126pub struct CompletionResponse {
127 pub id: String,
128 pub model: String,
129 pub object: String,
130 pub created: u64,
131 #[serde(default)]
132 pub choices: Vec<Choice>,
133 pub usage: Usage,
134}
135
136#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
137pub struct Message {
138 pub role: Role,
139 pub content: String,
140}
141
142#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
143#[serde(rename_all = "lowercase")]
144pub enum Role {
145 System,
146 User,
147 Assistant,
148}
149
150#[derive(Deserialize, Debug)]
151pub struct Delta {
152 pub role: Role,
153 pub content: String,
154}
155
156#[derive(Deserialize, Debug)]
157pub struct Choice {
158 pub index: usize,
159 pub finish_reason: String,
160 pub message: Message,
161 pub delta: Delta,
162}
163
164#[derive(Deserialize, Debug)]
165pub struct Usage {
166 pub prompt_tokens: u32,
167 pub completion_tokens: u32,
168 pub total_tokens: u32,
169}
170
171impl std::fmt::Display for Usage {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 write!(
174 f,
175 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
176 self.prompt_tokens, self.completion_tokens, self.total_tokens
177 )
178 }
179}
180
181impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
182 type Error = CompletionError;
183
184 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
185 let choice = response.choices.first().ok_or_else(|| {
186 CompletionError::ResponseError("Response contained no choices".to_owned())
187 })?;
188
189 match &choice.message {
190 Message {
191 role: Role::Assistant,
192 content,
193 } => Ok(completion::CompletionResponse {
194 choice: OneOrMany::one(content.clone().into()),
195 raw_response: response,
196 }),
197 _ => Err(CompletionError::ResponseError(
198 "Response contained no assistant message".to_owned(),
199 )),
200 }
201 }
202}
203
204#[derive(Clone)]
205pub struct CompletionModel {
206 client: Client,
207 pub model: String,
208}
209
210impl CompletionModel {
211 pub fn new(client: Client, model: &str) -> Self {
212 Self {
213 client,
214 model: model.to_string(),
215 }
216 }
217
218 fn create_completion_request(
219 &self,
220 completion_request: CompletionRequest,
221 ) -> Result<Value, CompletionError> {
222 let mut partial_history = vec![];
224 if let Some(docs) = completion_request.normalized_documents() {
225 partial_history.push(docs);
226 }
227 partial_history.extend(completion_request.chat_history);
228
229 let mut full_history: Vec<Message> =
231 completion_request
232 .preamble
233 .map_or_else(Vec::new, |preamble| {
234 vec![Message {
235 role: Role::System,
236 content: preamble,
237 }]
238 });
239
240 full_history.extend(
242 partial_history
243 .into_iter()
244 .map(message::Message::try_into)
245 .collect::<Result<Vec<Message>, _>>()?,
246 );
247
248 let request = json!({
250 "model": self.model,
251 "messages": full_history,
252 "temperature": completion_request.temperature,
253 });
254
255 let request = if let Some(ref params) = completion_request.additional_params {
256 json_utils::merge(request, params.clone())
257 } else {
258 request
259 };
260
261 Ok(request)
262 }
263}
264
265impl TryFrom<message::Message> for Message {
266 type Error = MessageError;
267
268 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
269 Ok(match message {
270 message::Message::User { content } => {
271 let collapsed_content = content
272 .into_iter()
273 .map(|content| match content {
274 message::UserContent::Text(message::Text { text }) => Ok(text),
275 _ => Err(MessageError::ConversionError(
276 "Only text content is supported by Perplexity".to_owned(),
277 )),
278 })
279 .collect::<Result<Vec<_>, _>>()?
280 .join("\n");
281
282 Message {
283 role: Role::User,
284 content: collapsed_content,
285 }
286 }
287
288 message::Message::Assistant { content } => {
289 let collapsed_content = content
290 .into_iter()
291 .map(|content| {
292 Ok(match content {
293 message::AssistantContent::Text(message::Text { text }) => text,
294 _ => return Err(MessageError::ConversionError(
295 "Only text assistant message content is supported by Perplexity"
296 .to_owned(),
297 )),
298 })
299 })
300 .collect::<Result<Vec<_>, _>>()?
301 .join("\n");
302
303 Message {
304 role: Role::Assistant,
305 content: collapsed_content,
306 }
307 }
308 })
309 }
310}
311
312impl From<Message> for message::Message {
313 fn from(message: Message) -> Self {
314 match message.role {
315 Role::User => message::Message::user(message.content),
316 Role::Assistant => message::Message::assistant(message.content),
317
318 Role::System => message::Message::user(message.content),
321 }
322 }
323}
324
325impl completion::CompletionModel for CompletionModel {
326 type Response = CompletionResponse;
327 type StreamingResponse = openai::StreamingCompletionResponse;
328
329 #[cfg_attr(feature = "worker", worker::send)]
330 async fn completion(
331 &self,
332 completion_request: completion::CompletionRequest,
333 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
334 let request = self.create_completion_request(completion_request)?;
335
336 let response = self
337 .client
338 .post("/chat/completions")
339 .json(&request)
340 .send()
341 .await?;
342
343 if response.status().is_success() {
344 match response.json::<ApiResponse<CompletionResponse>>().await? {
345 ApiResponse::Ok(completion) => {
346 tracing::info!(target: "rig",
347 "Perplexity completion token usage: {}",
348 completion.usage
349 );
350 Ok(completion.try_into()?)
351 }
352 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
353 }
354 } else {
355 Err(CompletionError::ProviderError(response.text().await?))
356 }
357 }
358
359 #[cfg_attr(feature = "worker", worker::send)]
360 async fn stream(
361 &self,
362 completion_request: completion::CompletionRequest,
363 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
364 let mut request = self.create_completion_request(completion_request)?;
365
366 request = merge(request, json!({"stream": true}));
367
368 let builder = self.client.post("/chat/completions").json(&request);
369
370 send_compatible_streaming_request(builder).await
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object with `role` and `content` deserializes into [`Message`].
    #[test]
    fn test_deserialize_message() {
        let raw = r#"
        {
            "role": "user",
            "content": "Hello, how can I help you?"
        }
        "#;

        let parsed = serde_json::from_str::<Message>(raw).unwrap();
        assert_eq!(parsed.role, Role::User);
        assert_eq!(parsed.content, "Hello, how can I help you?");
    }

    /// [`Message`] serializes to compact JSON with `role` before `content`
    /// and a lowercase role name.
    #[test]
    fn test_serialize_message() {
        let outgoing = Message {
            role: Role::Assistant,
            content: "I am here to assist you.".to_string(),
        };

        let actual = serde_json::to_string(&outgoing).unwrap();
        let expected = r#"{"role":"assistant","content":"I am here to assist you."}"#;
        assert_eq!(actual, expected);
    }

    /// User and assistant messages survive a round trip through the provider
    /// [`Message`] type.
    #[test]
    fn test_message_to_message_conversion() {
        let original_user = message::Message::user("User message");
        let original_assistant = message::Message::assistant("Assistant message");

        // Crate-level -> provider format.
        let provider_user: Message = original_user.clone().try_into().unwrap();
        let provider_assistant: Message = original_assistant.clone().try_into().unwrap();

        assert_eq!(provider_user.role, Role::User);
        assert_eq!(provider_user.content, "User message");

        assert_eq!(provider_assistant.role, Role::Assistant);
        assert_eq!(provider_assistant.content, "Assistant message");

        // Provider format -> crate-level.
        let round_tripped_user: message::Message = provider_user.into();
        let round_tripped_assistant: message::Message = provider_assistant.into();

        assert_eq!(original_user, round_tripped_user);
        assert_eq!(original_assistant, round_tripped_assistant);
    }
}